Pruning and Visualizing sklearn DecisionTreeClassifiers
This post serves two purposes:
- It illustrates and compares three different methods of visualizing
DecisionTreeClassifiers from sklearn.
- It shows a simple, quick way of manually pruning selected nodes from the tree.
from dtreeviz.trees import *
from IPython.display import SVG
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
import copy
#for plotting
import matplotlib.pyplot as plt
from sklearn import tree
import graphviz
dtreeviz
We are using the wonderful tree visualization library dtreeviz:
https://github.com/parrt/dtreeviz
def ViewSVG(viz):
    """Write a dtreeviz visualization to disk and wrap it for notebook display.

    Args:
        viz: Object exposing ``save_svg()``, such as the value returned by
            ``dtreeviz(...)``. ``save_svg()`` is expected to return the
            path of the SVG file it wrote.

    Returns:
        An ``IPython.display.SVG`` built from that path.
    """
    # Imported locally so the helper stays usable when copied on its own.
    from IPython.display import SVG

    svg_path = viz.save_svg()
    return SVG(svg_path)
# Fit a shallow decision tree on the iris data.
# random_state is pinned because scikit-learn permutes the features at every
# split: when several splits are equally good, a different one may be picked
# on each run, so the rendered tree would not be reproducible otherwise.
iris = load_iris()
clf1 = tree.DecisionTreeClassifier(max_depth=3, random_state=0)  # limit depth of tree
clf1.fit(iris.data, iris.target)

# Render the fitted tree with dtreeviz and display it as an SVG.
viz1 = dtreeviz(
    clf1,
    iris.data,
    iris.target,
    target_name='variety',
    feature_names=iris.feature_names,
    class_names=["setosa", "versicolor", "virginica"],  # need class_names for classifier
)
ViewSVG(viz1)
We now selectively prune the last two children which belong to parent node #6:
# Work on a deep copy so that clf1 keeps the unpruned tree for comparison.
clf2 = copy.deepcopy(clf1)

# Prune the tree: scikit-learn stores -1 as the child index of a leaf, so
# clearing both child links of node 6 makes that node behave as a leaf.
clf2.tree_.children_left[6] = clf2.tree_.children_right[6] = -1

# Render the pruned tree the same way as the original one.
viz2 = dtreeviz(
    clf2,
    iris.data,
    iris.target,
    target_name='variety',
    feature_names=iris.feature_names,
    class_names=["setosa", "versicolor", "virginica"],  # need class_names for classifier
)
ViewSVG(viz2)
Using plot_tree also works:
# Draw each tree on its own figure.
# plot_tree() draws on the current axes and clears their previous content, and
# rcParams["figure.figsize"] only applies to figures created afterwards. An
# explicit plt.figure() per tree therefore keeps the second plot from
# overwriting the first and makes each size setting take effect.
plt.rcParams["figure.figsize"] = 10, 8
plt.figure()
tmp = tree.plot_tree(clf1)  # assigned so the notebook does not echo the returned list

plt.rcParams["figure.figsize"] = 8, 6
plt.figure()
tmp = tree.plot_tree(clf2)
Graphviz
# Third option: export the original tree as DOT source and render it with graphviz.
plt.rcParams["figure.figsize"] = (5, 5)
dot_data = tree.export_graphviz(
    clf1,
    out_file=None,  # no output file: the DOT source is returned as a string
    class_names=iris.target_names,
    feature_names=iris.feature_names,
    rounded=True,
    filled=True,
    special_characters=True,
)
graph = graphviz.Source(dot_data)
graph  # bare expression so a notebook cell displays the rendered graph
# Same graphviz export for the pruned tree, for side-by-side comparison.
dot_data = tree.export_graphviz(
    clf2,
    out_file=None,  # no output file: the DOT source is returned as a string
    class_names=iris.target_names,
    feature_names=iris.feature_names,
    rounded=True,
    filled=True,
    special_characters=True,
)
graph = graphviz.Source(dot_data)
graph  # bare expression so a notebook cell displays the rendered graph