Introduction
When we use decision trees to make decisions, we usually judge how well they perform with metrics such as accuracy. But if we want to see how the tree is built and how it reaches its decisions, we need to plot the decision tree fitted to our dataset. This blog looks at how to plot a decision tree and at the different ways Python offers to do so.
A decision tree is a common supervised learning method in machine learning, used for both classification and regression problems. It is convenient because it doesn't require feature scaling and is easy to understand. To interpret a tree, we can plot it, see how the model makes its decisions, and adjust the model accordingly. To know more about decision trees, you can refer here.
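As a quick check of the claim that decision trees do not need feature scaling, the sketch below fits the same classifier on raw and on standardized Iris features and compares the two trees. It assumes scikit-learn's `StandardScaler` as the scaling step and a fixed `random_state`; since a tree only uses the ordering of values within each feature, a per-feature rescaling like this should move the thresholds but leave the chosen split features and the predictions unchanged.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
# Assumption: standardization (zero mean, unit variance) as the "scaling" step
X_scaled = StandardScaler().fit_transform(X)

# Same hyperparameters and seed; only the feature scales differ
tree_raw = DecisionTreeClassifier(random_state=0).fit(X, y)
tree_scaled = DecisionTreeClassifier(random_state=0).fit(X_scaled, y)

# Compare the feature used at every node and the resulting predictions
same_features = np.array_equal(tree_raw.tree_.feature, tree_scaled.tree_.feature)
same_predictions = np.array_equal(tree_raw.predict(X), tree_scaled.predict(X_scaled))
print("Same split features:", same_features)
print("Same predictions:", same_predictions)
```

Only the stored thresholds differ between the two trees, because they are expressed in the units of whichever features the tree was trained on.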
Various ways to visualize a decision tree
Python libraries offer several ways to visualize decision trees. Some of them are as follows:
Visualizing Decision Trees as text using sklearn's export_text method
Visualizing Decision Trees using sklearn's plot_tree method (drawn with Matplotlib)
Visualizing Decision Trees using Graphviz
Visualizing Decision Trees using dtreeviz
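For a quick look before the full helper class, here is a minimal standalone sketch of scikit-learn's two built-in options: `tree.export_text` for a plain-text view and `tree.plot_tree`, which draws the tree with Matplotlib. The `max_depth=2` limit (to keep the output short), the off-screen `Agg` backend, and the filename `iris_tree.png` are illustrative choices, not part of the original code.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen so the sketch also runs without a display
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text, plot_tree

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)

# 1) Text view: one line per split, leaves show the predicted class
text = export_text(clf, feature_names=list(iris.feature_names))
print(text)

# 2) Graphical view drawn by Matplotlib and saved as a PNG
fig, ax = plt.subplots(figsize=(10, 6))
plot_tree(clf, feature_names=list(iris.feature_names),
          class_names=list(iris.target_names), filled=True, ax=ax)
fig.savefig("iris_tree.png")
plt.close(fig)
```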
Let’s look at each of these methods with the help of code.
Code in Python
We will use the built-in Iris dataset that ships with sklearn and is loaded from sklearn.datasets.
```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
# dtreeviz >= 2.0 exposes dtreeviz.model(); releases before 2.0
# used "from dtreeviz.trees import dtreeviz" instead
import dtreeviz
import graphviz
import warnings
warnings.filterwarnings("ignore")

### this class helps visualise decision trees:
### pass it the dataset dataframe (target in the last column)
### and it trains, fits, and visualises a decision tree model
class DecisionTreeVisualiser:
    # initialiser that stores the dataframe of the dataset along with
    # the parameters required by a DecisionTreeClassifier
    def __init__(self, dataset_df, min_samples_leaf=1, min_samples_split=2,
                 criterion="gini", splitter="best", max_depth=None,
                 max_features=None, min_weight_fraction_leaf=0.0,
                 max_leaf_nodes=None):
        self.criterion = criterion
        self.max_depth = max_depth
        self.max_features = max_features
        self.splitter = splitter
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.min_weight_fraction_leaf = min_weight_fraction_leaf
        self.max_leaf_nodes = max_leaf_nodes
        self.dataset_df = dataset_df
        self.X_train, self.Y_train = None, None
        self.X_test, self.Y_test = None, None

    # returns the first num_entries rows of the dataframe (5 by default)
    def visualiseDataset(self, num_entries=5):
        return self.dataset_df.head(num_entries)

    # splits the dataset into train_size % training data and
    # test_size % testing data, where train_size + test_size = 100
    # (80/20 by default); set shuffle=False to keep the original row order
    def train_test_split(self, train_size=80, test_size=20, shuffle=True):
        if train_size + test_size != 100:
            raise ValueError("train_size and test_size must add up to 100")
        if shuffle:
            self.dataset_df = self.dataset_df.sample(frac=1)
        # all columns except the last are features; the last is the target
        # (kept separate so the integer class labels keep their dtype)
        features = self.dataset_df.iloc[:, :-1].to_numpy()
        target = self.dataset_df.iloc[:, -1].to_numpy()
        split = int((train_size / 100) * len(features))
        self.X_train, self.Y_train = features[:split], target[:split]
        self.X_test, self.Y_test = features[split:], target[split:]

    # fits the model, prints training/testing accuracy, and plots the tree
    # with the chosen plot_method ("export_text", "plot_tree", "graphviz"
    # or "dtreeviz"), using the dataset's feature and class names
    def fit_and_predict_and_PlotDT(self, plot_method="export_text",
                                   feature_names=None, cls_names=None):
        # decision tree classifier object initialised
        dt = DecisionTreeClassifier(criterion=self.criterion,
                                    max_depth=self.max_depth,
                                    max_features=self.max_features,
                                    splitter=self.splitter,
                                    min_samples_split=self.min_samples_split,
                                    min_samples_leaf=self.min_samples_leaf,
                                    min_weight_fraction_leaf=self.min_weight_fraction_leaf,
                                    max_leaf_nodes=self.max_leaf_nodes)
        # fitting the training dataset
        dt.fit(self.X_train, self.Y_train)
        # print the accuracy on the training and testing sets
        print("The score on the training set is: ", dt.score(self.X_train, self.Y_train))
        print("The score on the testing set is: ", dt.score(self.X_test, self.Y_test))
        if plot_method == "export_text":
            # plain-text rules, one line per split
            print(tree.export_text(dt, feature_names=feature_names))
        elif plot_method == "plot_tree":
            # Matplotlib figure saved as a PNG
            fig = plt.figure(figsize=(25, 20))
            tree.plot_tree(dt, feature_names=feature_names, class_names=cls_names, filled=True)
            fig.savefig("decision_tree.png")
        elif plot_method == "graphviz":
            # DOT source rendered by Graphviz (the Graphviz binaries must be installed)
            dot_data = tree.export_graphviz(dt, out_file=None, feature_names=feature_names,
                                            class_names=cls_names, filled=True)
            graph = graphviz.Source(dot_data, format="png")
            graph.render("decision_tree_graphviz")
        elif plot_method == "dtreeviz":
            # visualise with the training data the tree was fitted on
            viz_model = dtreeviz.model(dt, X_train=self.X_train, y_train=self.Y_train,
                                       feature_names=feature_names, target_name="target",
                                       class_names=list(cls_names))
            viz_model.view().save("decision_tree_dtreeviz.svg")

# loading the iris dataset
iris_dataset = load_iris()
# store the data in a dataframe
df = pd.DataFrame(iris_dataset.data, columns=iris_dataset.feature_names)
# add the target column to the dataframe
df['target'] = iris_dataset.target
# create the decision tree visualiser object
dtv = DecisionTreeVisualiser(df)
# call the train_test_split function
dtv.train_test_split()
# fit, predict, and plot the decision tree
dtv.fit_and_predict_and_PlotDT(plot_method="graphviz", feature_names=iris_dataset.feature_names,
                               cls_names=iris_dataset.target_names)
```
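The class above splits the data by slicing a shuffled array, which does not guarantee that each Iris species is equally represented in the training and testing sets. If a stratified split is wanted (an assumption, not something the original code does), scikit-learn's own `train_test_split` from `sklearn.model_selection` can shuffle and stratify in one call. Here `test_size=0.2` mirrors the 80/20 default above, and `random_state=42` is an arbitrary seed for reproducibility.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# stratify=y keeps the 50/50/50 class balance of Iris in both splits
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42)

print("Training samples per class:", np.bincount(y_train))
print("Testing samples per class:", np.bincount(y_test))

dt = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)
print("The score on the training set is:", dt.score(X_train, y_train))
print("The score on the testing set is:", dt.score(X_test, y_test))
```

The resulting arrays can be passed to any of the plotting methods above in place of the class's own split.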
Q1. What are the advantages of decision trees? Ans. Decision trees are widely used because they are easy to understand and interpret and are quick to train.
Q2. Why is dtreeviz used? Ans. dtreeviz is a Python library for visualizing decision trees; it gives a more explanatory view than a plain tree diagram, including plots of the training data at each node.
Q3. How do you tune decision trees to get better results? Ans. One option is to manually tune the max_depth parameter and the maximum number of leaf nodes (max_leaf_nodes). Another is to change the split criterion (criterion, for example "gini" or "entropy").
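Instead of tuning these parameters by hand, one option is to cross-validate a small grid of them. The sketch below uses scikit-learn's `GridSearchCV` over `max_depth`, `max_leaf_nodes`, and `criterion`; the grid values and `cv=5` are illustrative assumptions, and on a dataset as small as Iris several settings may tie for the best score.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

param_grid = {
    "max_depth": [2, 3, 4, None],      # None lets the tree grow until leaves are pure
    "max_leaf_nodes": [4, 8, None],    # None means no limit on the number of leaves
    "criterion": ["gini", "entropy"],  # measure of split quality
}

# 5-fold cross-validation over every combination in the grid
search = GridSearchCV(DecisionTreeClassifier(random_state=0), param_grid, cv=5)
search.fit(X, y)

print("Best parameters:", search.best_params_)
print("Mean cross-validated accuracy:", round(search.best_score_, 3))
```

The tuned model is available as `search.best_estimator_` and can be plotted with any of the methods above.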
Key takeaways
This article briefly explained how to visualize decision trees using various methods available in Python libraries. We looked at the export_text method, the plot_tree method, the graphviz method, and the dtreeviz method. To dive deeper into machine learning, check out our industry-level courses on Coding Ninjas.