Table of contents
1. Introduction
2. A Brief Introduction to Decision Trees
3. Various ways to visualize the Decision tree
3.1. Code in Python
3.2. Output Using graphviz method
3.3. Output Using plot_tree method
3.4. Output Using export_text method
3.5. Output Using dtreeviz method
4. Frequently Asked Questions
5. Key takeaways
Last Updated: Mar 27, 2024

Visualizing Decision Trees

Author aniket verma

Introduction

Whenever we use decision trees to make decisions, we usually judge them by metrics such as accuracy. But if we want to see how the tree is built and how its decisions are made, we need to plot the decision tree learned from the dataset fed to the model. This blog shows how to plot a decision tree and walks through several ways of doing so in Python.


A Brief Introduction to Decision Trees

A decision tree is a common supervised learning method in machine learning, used for both classification and regression problems. It is convenient because it doesn't require feature scaling and is easy to understand. To interpret a trained tree, we can plot it, see how the model arrives at its decisions, and adjust the model accordingly. To know more about decision trees, you can refer here.
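The claim that trees don't need feature scaling can be checked directly. The minimal sketch below, assuming scikit-learn and NumPy are installed, fits one tree on the raw Iris features and another on standardized copies of them (same hyperparameters, same random_state) and compares the results. Because a split only depends on the ordering of values within a feature, the thresholds change but the learned partitions, and therefore the predictions, should not.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# standardize each feature to zero mean and unit variance
X_scaled = (X - X.mean(axis=0)) / X.std(axis=0)

# identical settings, so only the scale of the features differs
tree_raw = DecisionTreeClassifier(random_state=0).fit(X, y)
tree_scaled = DecisionTreeClassifier(random_state=0).fit(X_scaled, y)

pred_raw = tree_raw.predict(X)
pred_scaled = tree_scaled.predict(X_scaled)

# split thresholds differ between the two trees, but the structure should match
print("same node count:", tree_raw.tree_.node_count == tree_scaled.tree_.node_count)
print("same split features:", np.array_equal(tree_raw.tree_.feature, tree_scaled.tree_.feature))
print("same predictions:", np.array_equal(pred_raw, pred_scaled))
```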

Various ways to visualize the Decision tree

Python libraries offer several ways to visualize decision trees. This article covers the following:

  1. Visualizing Decision Trees as text using the sklearn export_text method
  2. Visualizing Decision Trees using the sklearn plot_tree method (drawn with Matplotlib)
  3. Visualizing Decision Trees using Graphviz (via sklearn's export_graphviz)
  4. Visualizing Decision Trees using dtreeviz
     

Let’s look at each of these methods with the help of some code.

Code in Python

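We will use the built-in Iris dataset that ships with sklearn, loaded with load_iris from sklearn.datasets. Note that the graphviz method needs the Graphviz system binaries in addition to the graphviz Python package, and the dtreeviz method needs the dtreeviz package.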
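On scikit-learn 0.23 or later, load_iris can also return the data directly as a pandas DataFrame through as_frame=True, with the four feature columns and a "target" column already combined. This is an optional shortcut, sketched below under that version assumption, for the DataFrame the script below builds by hand.

```python
from sklearn.datasets import load_iris

# as_frame=True (scikit-learn >= 0.23) returns pandas objects instead of NumPy arrays
iris = load_iris(as_frame=True)

# .frame holds the four feature columns followed by the "target" column
df = iris.frame
print(df.shape)
print(df.head())
```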

import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
import graphviz
# dtreeviz >= 2.0 exposes dtreeviz.model(); the old
# "from dtreeviz.trees import dtreeviz" import no longer exists in 2.x
import dtreeviz
import warnings
warnings.filterwarnings("ignore")

### this class helps visualise Decision Trees:
### pass it the dataset dataframe (target in the last column) and the
### tree hyperparameters, then train, score and plot the fitted tree
class DecisionTreeVisualiser:

    # initialiser that stores the dataframe of the dataset along with
    # the hyperparameters passed on to the DecisionTreeClassifier
    def __init__(self, dataset_df, min_samples_leaf=1, min_samples_split=2, criterion="gini",
                 splitter="best", max_depth=None, max_features=None,
                 min_weight_fraction_leaf=0.0, max_leaf_nodes=None):
        self.criterion = criterion
        self.max_depth = max_depth
        self.max_features = max_features
        self.splitter = splitter
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.min_weight_fraction_leaf = min_weight_fraction_leaf
        self.max_leaf_nodes = max_leaf_nodes
        self.dataset_df = dataset_df
        self.X_train, self.Y_train = None, None
        self.X_test, self.Y_test = None, None

    # returns the first num_entries rows of the dataframe (default 5)
    # so you can inspect the dataset; change num_entries as you like
    def visualiseDataset(self, num_entries=5):
        return self.dataset_df.head(num_entries)

    # splits the dataset into train_size percent training data and the
    # remaining (100 - train_size) percent testing data, 80/20 by default;
    # the last column is the target and all other columns are features.
    # set shuffle=False if you don't want to shuffle the rows first
    def train_test_split(self, train_size=80, shuffle=True, random_state=None):
        if shuffle:
            self.dataset_df = self.dataset_df.sample(frac=1, random_state=random_state)
        # take features and target separately so the target keeps its integer dtype
        X = self.dataset_df.iloc[:, :-1].to_numpy()
        Y = self.dataset_df.iloc[:, -1].to_numpy()
        split_idx = int((train_size / 100) * len(X))
        self.X_train, self.Y_train = X[:split_idx], Y[:split_idx]
        self.X_test, self.Y_test = X[split_idx:], Y[split_idx:]

    # fits the model, prints the train/test scores and plots the fitted tree.
    # plot_method is one of "export_text", "plot_tree", "graphviz" or "dtreeviz";
    # feature_names and cls_names label the features and classes in the output
    def fit_and_predict_and_PlotDT(self, plot_method="export_text", feature_names=None, cls_names=None):
        # decision tree classifier initialised with the stored hyperparameters
        dt = DecisionTreeClassifier(criterion=self.criterion,
                                    max_depth=self.max_depth,
                                    max_features=self.max_features,
                                    splitter=self.splitter,
                                    min_samples_split=self.min_samples_split,
                                    min_samples_leaf=self.min_samples_leaf,
                                    min_weight_fraction_leaf=self.min_weight_fraction_leaf,
                                    max_leaf_nodes=self.max_leaf_nodes)
        # fit the model on the training set
        dt.fit(self.X_train, self.Y_train)

        # print the mean accuracy on the training and testing sets
        print("The score on the training set is: ", dt.score(self.X_train, self.Y_train))
        print("The score on the testing set is: ", dt.score(self.X_test, self.Y_test))

        # plot the tree with whichever method plot_method names
        if plot_method == "export_text":
            text_representation = tree.export_text(dt, feature_names=feature_names)
            print(text_representation)
        elif plot_method == "plot_tree":
            fig = plt.figure(figsize=(25, 20))
            tree.plot_tree(dt, feature_names=feature_names, class_names=cls_names, filled=True)
            fig.savefig("decision_tree.png")
        elif plot_method == "graphviz":
            dot_data = tree.export_graphviz(dt, out_file=None, feature_names=feature_names,
                                            class_names=cls_names, filled=True)
            # rendering needs the Graphviz system binaries (the "dot" executable)
            graph = graphviz.Source(dot_data, format="png")
            graph.render("decision_tree_graphviz")  # writes decision_tree_graphviz.png
        elif plot_method == "dtreeviz":
            # dtreeviz >= 2.0: wrap the model and its training data, then render the view
            viz_model = dtreeviz.model(dt, X_train=self.X_train, y_train=self.Y_train,
                                       target_name="target", feature_names=feature_names,
                                       class_names=list(cls_names) if cls_names is not None else None)
            viz_model.view().save("decision_tree_dtreeviz.svg")
        else:
            raise ValueError(f"Unknown plot_method: {plot_method}")
        return dt

# load the iris dataset
iris_dataset = load_iris()
# store the features in a dataframe
df = pd.DataFrame(iris_dataset.data, columns=iris_dataset.feature_names)
# add the target column to the dataframe
df['target'] = iris_dataset.target

# create the DecisionTreeVisualiser object
dtv = DecisionTreeVisualiser(df)
# split the data into 80% training and 20% testing rows
dtv.train_test_split()
# fit, score and plot the decision tree
dtv.fit_and_predict_and_PlotDT(plot_method="graphviz", feature_names=iris_dataset.feature_names, cls_names=iris_dataset.target_names)
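The class above implements its own shuffled 80/20 split. As an alternative (a swapped-in scikit-learn utility, not the author's code), sklearn.model_selection.train_test_split does the same job and can also stratify by class, so both parts keep the Iris class proportions. The random_state=42 value is an arbitrary choice for reproducibility.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# 80/20 split; stratify=y keeps the same class balance in both parts
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, shuffle=True, stratify=y, random_state=42
)

dt = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)
print("The score on the training set is: ", dt.score(X_train, y_train))
print("The score on the testing set is: ", dt.score(X_test, y_test))
```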

 

Output Using graphviz method
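export_graphviz on its own only produces text in the DOT language; the graphviz package and the Graphviz system binaries are needed only to turn that text into an image. The minimal sketch below, assuming only scikit-learn, builds the DOT source for an Iris tree limited to max_depth=2 (an illustrative choice to keep the graph small), so it can be inspected or pasted into any Graphviz renderer.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

iris = load_iris()
dt = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)

# out_file=None makes export_graphviz return the DOT source as a string
dot_data = export_graphviz(
    dt,
    out_file=None,
    feature_names=iris.feature_names,
    class_names=list(iris.target_names),
    filled=True,
    rounded=True,
)
print(dot_data)

# rendering needs `pip install graphviz` plus the Graphviz "dot" executable:
# import graphviz
# graphviz.Source(dot_data, format="png").render("decision_tree_graphviz")
```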

 

Output Using plot_tree method
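plot_tree draws onto a Matplotlib Axes, so it works without Graphviz. The sketch below assumes scikit-learn and Matplotlib, selects the non-interactive Agg backend so it also runs without a display, and uses plot_tree's own max_depth argument to draw only the top two levels of a fully grown tree. The output filename is an illustrative choice.

```python
import matplotlib
matplotlib.use("Agg")  # render to files only; no window needed
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree

iris = load_iris()
dt = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)

fig, ax = plt.subplots(figsize=(12, 8))
# max_depth here limits only how much of the fitted tree is drawn
annotations = plot_tree(
    dt,
    max_depth=2,
    feature_names=iris.feature_names,
    class_names=list(iris.target_names),
    filled=True,
    ax=ax,
)
fig.savefig("decision_tree_top_levels.png", dpi=100)
plt.close(fig)
```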

 

Output Using export_text method
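export_text needs neither Matplotlib nor Graphviz, which makes it handy in a terminal or a log file. In the sketch below, assuming only scikit-learn, passing feature_names replaces the default feature_0 ... feature_3 labels with the Iris column names, and decimals controls how the thresholds are rounded; max_depth=3 is an illustrative choice to keep the output short.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
dt = DecisionTreeClassifier(max_depth=3, random_state=0).fit(iris.data, iris.target)

# each "|---" line is a split condition; leaf lines show "class: <label>"
report = export_text(dt, feature_names=iris.feature_names, decimals=2)
print(report)
```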

 

Output Using dtreeviz method

Frequently Asked Questions

Q1. What are the advantages of decision trees?
Ans. Decision trees are commonly used because they are easy to understand and visualize, need no feature scaling, and are usually quick to train.

Q2. Why is dtreeviz used?
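Ans. dtreeviz is a Python library for visualizing decision trees. Compared with the basic sklearn plots, it gives a more explanatory view, for example by plotting the distribution of the training data at each split node.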

Q3. How to tune the decision trees to get better results?
Ans. To get better results, you can tune hyperparameters such as max_depth (the maximum depth of the tree) and max_leaf_nodes (the maximum number of leaf nodes). Another option is to change the split criterion, for example from "gini" to "entropy".
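Rather than adjusting these parameters by hand, you can search over them with cross-validation. The sketch below uses scikit-learn's GridSearchCV over max_depth, max_leaf_nodes and criterion on the Iris data; the grid values and cv=5 are illustrative assumptions, not recommended settings.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# candidate values for the parameters discussed above
param_grid = {
    "max_depth": [2, 3, 4, None],
    "max_leaf_nodes": [4, 8, None],
    "criterion": ["gini", "entropy"],
}

# 5-fold cross-validated accuracy for every combination in the grid
search = GridSearchCV(DecisionTreeClassifier(random_state=0), param_grid, cv=5)
search.fit(X, y)

print("best parameters:", search.best_params_)
print("best cross-validated accuracy:", round(search.best_score_, 3))
```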

Key takeaways

This article briefly explained how to visualize decision trees in Python using several methods: the graphviz method, the export_text method, the plot_tree method, and the dtreeviz method. To dive deeper into machine learning, check out our industry-level courses on Coding Ninjas.
