In this notebook I walk through how a classification decision tree is fit, how inference is performed, how overfitting can be reduced, and how splits are determined. It is based on a notebook I created while I was a fellow at Insight Data Science, studying for interviews.

import graphviz
import matplotlib.pyplot as plt
import numpy as np
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

scikit-learn now comes with a way to plot trees, but I prefer graphviz, so here is a quick function that renders a tree, which we monkey-patch onto the DecisionTreeClassifier class.

def display_tree(self):
    dot_data = tree.export_graphviz(self, out_file=None, filled=True, rounded=True, special_characters=True)
    graph = graphviz.Source(dot_data)
    return graph

DecisionTreeClassifier.plot = display_tree

We'll use the Iris dataset as a quick way to discuss classification trees. To learn more about this dataset use help(load_iris).

iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

Fitting

Decision trees search over all possible ways of splitting the features and choose the split that is most informative about the target variable. The parent node is split into two child nodes according to that rule. The children are then split in the same manner, recursively, until all leaves are pure or some other stopping condition is met. A leaf is a node that has no children; a pure leaf contains samples of only one class.

Classification trees split using the Gini impurity (the scikit-learn default criterion), which is defined as:

$I_{G}(p) = \sum_{i=1}^{J}p_{i}(1-p_{i})$

Here $p_{i}$ is the fraction of items in the node belonging to class $i$ (equivalently, the probability of drawing an item of class $i$ at random from the node), and $1 - p_{i}$ is the probability of misclassifying that item. $J$ is the number of classes. The Gini impurity reaches zero when all samples in the node fall into a single class.
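
To make the formula concrete, here is a minimal sketch of the Gini impurity computed from the labels in a node (gini_impurity is a helper of my own, not scikit-learn's internal implementation):

def gini_impurity(labels):
    """Gini impurity of a node, given the array of class labels in it."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()   # class fractions p_i
    return np.sum(p * (1 - p))  # sum_i p_i (1 - p_i)

# A pure node has impurity 0; a balanced two-class node has impurity 0.5
gini_impurity(np.array([0, 0, 0, 0])), gini_impurity(np.array([0, 0, 1, 1]))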

Alternatively, one can use information gain to decide where to split, where information gain is defined as the difference in entropy of the parent and the weighted sum of the entropies of the children. Entropy is defined as:

$H(p) = -\sum_{i=1}^{J}p_{i}\log_{2}(p_{i})$
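
As a quick check on the formula: a node split evenly between two classes has $p_{1} = p_{2} = 0.5$, so $H = -(0.5\log_{2}0.5 + 0.5\log_{2}0.5) = 1$ bit, while a pure node has $H = -1\log_{2}1 = 0$.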

Let's train a decision tree:

clf = DecisionTreeClassifier(criterion="entropy").fit(X_train, y_train)
clf.plot()
[graphviz rendering of the fitted tree:]

X[2] ≤ 2.35 (entropy = 1.581, samples = 112, value = [37, 34, 41])
├── True:  leaf (entropy = 0.0, samples = 37, value = [37, 0, 0])
└── False: X[2] ≤ 4.95 (entropy = 0.994, samples = 75, value = [0, 34, 41])
    ├── True:  X[3] ≤ 1.65 (entropy = 0.414, samples = 36, value = [0, 33, 3])
    │   ├── True:  leaf (entropy = 0.0, samples = 32, value = [0, 32, 0])
    │   └── False: X[1] ≤ 3.1 (entropy = 0.811, samples = 4, value = [0, 1, 3])
    │       ├── True:  leaf (entropy = 0.0, samples = 3, value = [0, 0, 3])
    │       └── False: leaf (entropy = 0.0, samples = 1, value = [0, 1, 0])
    └── False: X[3] ≤ 1.75 (entropy = 0.172, samples = 39, value = [0, 1, 38])
        ├── True:  X[3] ≤ 1.65 (entropy = 0.811, samples = 4, value = [0, 1, 3])
        │   ├── True:  leaf (entropy = 0.0, samples = 3, value = [0, 0, 3])
        │   └── False: leaf (entropy = 0.0, samples = 1, value = [0, 1, 0])
        └── False: leaf (entropy = 0.0, samples = 35, value = [0, 0, 35])

Here's the training score, which shows the tree classifies the flowers in the training set perfectly. Growing the tree until every leaf is pure tends to result in overfitting to the training set.

clf.score(X_train, y_train)
1.0

This is an easy dataset to classify, so the overfitting is not evident here.

clf.score(X_test, y_test)
0.9736842105263158

There are 7 leaves in our tree. Note that the leaves do not have to be depicted at the bottom of the tree in the diagram. A leaf is just a node without any children and could be represented near the top of the tree.

clf.get_n_leaves()
7

Inference

Now that we have trained our model, we can perform inference.

When inference is performed on a new sample, the sample is simply run through the "rules" created by the feature splits. Starting from the topmost node (the root node) in our example above: if feature 2 has a value less than or equal to 2.35, go to the left child node; otherwise go to the right. This process continues all the way down until the sample lands in a leaf.

The predicted class for a sample is the majority class among the training samples in the leaf it lands in: simply break down the training samples in that leaf by class and choose the class with the most training samples. The predicted probability of a class is the number of training samples in that leaf belonging to that class divided by the total number of training samples in that leaf.

Since all of our leaves are pure, the classifier will always report a probability of 100% for its predictions. We'll re-examine this when we have impure leaves below.

Here are the features for the first test sample.

X_test[0]
array([5.8, 2.8, 5.1, 2.4])

The path for this sample follows down the right side of the tree. Note that features are zero-indexed.

Root node: Is 5.1 <= 2.35? No, so go right.

Is 5.1 <= 4.95? No, so go right.

Is 2.4 <= 1.75? No, so go right.

That brings it to the leaf on the right with 35 samples, where the 3rd class (index 2) is predicted:

clf.predict([X_test[0]])
array([2])
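
If you'd rather not trace the path by hand, scikit-learn can report it. Here is a minimal sketch using decision_path together with the fitted tree's feature and threshold arrays (the loop and variable names are my own):

sample = X_test[0].reshape(1, -1)
node_indicator = clf.decision_path(sample)  # sparse matrix of the nodes this sample visits
leaf_id = clf.apply(sample)[0]              # id of the leaf the sample ends up in

feature = clf.tree_.feature      # feature index tested at each internal node (-2 for leaves)
threshold = clf.tree_.threshold  # threshold tested at each internal node

for node_id in node_indicator.indices:
    if node_id == leaf_id:
        print(f"Reached leaf {node_id}")
        break
    value = sample[0, feature[node_id]]
    sign = "<=" if value <= threshold[node_id] else ">"
    print(f"Node {node_id}: X[{feature[node_id]}] = {value} {sign} {threshold[node_id]:.2f}")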

Again, the probability is 100% since there are no training samples in that leaf from the other two classes. In the plot the node is colored dark purple. For this tree purple represents the 3rd class and the darker the shade the more probable it is.

clf.predict_proba([X_test[0]])
array([[0., 0., 1.]])
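
We can also reproduce that probability ourselves by finding the leaf the sample lands in and normalizing the per-class values stored at that node (a small sketch; depending on the scikit-learn version, tree_.value holds class counts or class fractions, so we normalize rather than assume raw counts):

leaf_id = clf.apply(X_test[:1])[0]             # leaf node id for the first test sample
leaf_value = clf.tree_.value[leaf_id].ravel()  # per-class values stored at that leaf
leaf_value / leaf_value.sum()                  # normalizing reproduces predict_proba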

Pruning

One way to prevent overfitting is to pre-prune the tree by specifying the maximum depth and/or the maximum number of leaves. Here we set the maximum depth to 3.

clf = DecisionTreeClassifier(max_depth=3, random_state=0, criterion="entropy").fit(X_train, y_train)
clf.plot()
[graphviz rendering of the pruned tree:]

X[3] ≤ 0.8 (entropy = 1.581, samples = 112, value = [37, 34, 41])
├── True:  leaf (entropy = 0.0, samples = 37, value = [37, 0, 0])
└── False: X[2] ≤ 4.95 (entropy = 0.994, samples = 75, value = [0, 34, 41])
    ├── True:  X[3] ≤ 1.65 (entropy = 0.414, samples = 36, value = [0, 33, 3])
    │   ├── True:  leaf (entropy = 0.0, samples = 32, value = [0, 32, 0])
    │   └── False: leaf (entropy = 0.811, samples = 4, value = [0, 1, 3])
    └── False: X[2] ≤ 5.05 (entropy = 0.172, samples = 39, value = [0, 1, 38])
        ├── True:  leaf (entropy = 0.811, samples = 4, value = [0, 1, 3])
        └── False: leaf (entropy = 0.0, samples = 35, value = [0, 0, 35])
clf.score(X_train, y_train)
0.9821428571428571
clf.score(X_test, y_test)
0.9736842105263158

Here is a sample where we are only 75% sure that it is class 2, since only 3 of the 4 samples in its leaf are class 2.

clf.predict([X_test[20]])
array([2])
clf.predict_proba([X_test[20]])
array([[0.  , 0.25, 0.75]])
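
The other pre-pruning option mentioned above, capping the number of leaves, works the same way via the max_leaf_nodes parameter. A quick sketch (clf_leaves is just a name of my own, so it doesn't overwrite the classifier used above):

clf_leaves = DecisionTreeClassifier(max_leaf_nodes=4, random_state=0, criterion="entropy").fit(X_train, y_train)
clf_leaves.get_n_leaves(), clf_leaves.score(X_test, y_test)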

A popular and effective way to reduce overfitting is to create a "forest" of weak decision trees and use them together (for example, random forests). This is beyond the scope of this notebook.

Information gain & splits

Let's talk a little bit more about how trees use entropy (or alternatively Gini) to determine splits.

The best split is found by cycling through all possible splits in the training set. Practically this is the process:

  1. Select the first feature.
  2. Sort the unique values of that feature and pick the halfway point between the first value and the second value.
  3. Calculate the entropy of the two child nodes if a split is made at that point.
  4. Repeat steps 2 and 3 for every midpoint between consecutive values of this feature.
  5. Go back to step 1 and repeat for the remaining features.

At the end, pick the feature and split point whose two child nodes have the lowest weighted sum of entropies.

Again, information gain is the difference between the entropy of the parent node and the weighted sum of the entropies of the two child nodes. Since the entropy of the parent node is the same for every potential split we try, we only need to calculate the child entropies for each split and find the split that minimizes their weighted sum, since that maximizes the information gained.

Here's the formula for information gain (IG), where $H_{parent}$ is the entropy of the parent node, $N_{left}$ and $H_{left}$ are the number of samples and entropy of the left child, and $N_{right}$ and $H_{right}$ are the same quantities for the right child:

$IG = H_{parent} - \frac{N_{left} H_{left} + N_{right} H_{right}}{N_{left} + N_{right}}$
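
For the root split of the first tree above, for example, the parent entropy is 1.581, the left child (37 samples) has entropy 0.0 and the right child (75 samples) has entropy 0.994, so $IG = 1.581 - \frac{37 \cdot 0.0 + 75 \cdot 0.994}{112} \approx 1.581 - 0.666 = 0.915$.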

Here's a tree with a max depth of 1 using entropy to split:

clf = DecisionTreeClassifier(random_state=42, max_depth=1, criterion="entropy").fit(X_train, y_train)
clf.plot()
[graphviz rendering of the depth-1 tree:]

X[2] ≤ 2.35 (entropy = 1.581, samples = 112, value = [37, 34, 41])
├── True:  leaf (entropy = 0.0, samples = 37, value = [37, 0, 0])
└── False: leaf (entropy = 0.994, samples = 75, value = [0, 34, 41])

Let's implement our entropy calculation. As a reminder, the formula is:

$H(p) = -\sum_{i=1}^{J}p_{i}\log_{2}(p_{i})$

This takes an array of target labels and calculates the entropy for that node.

from collections import Counter

def entropy(values):
    """Entropy (in bits) of a node, given the array of target labels in it."""
    total = values.shape[0]
    if total == 0:
        return 0

    # Count how many samples belong to each class
    c = Counter(values)

    s = 0
    for count in c.values():
        p = count / total          # p_i: fraction of samples in this class
        s += p * np.log2(p)

    return -s
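
A quick check: a pure node should have zero entropy and a balanced two-class node should have exactly one bit.

entropy(np.array([0, 0, 0, 0])), entropy(np.array([0, 0, 1, 1]))  # (-0.0, 1.0)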

Plotting entropy in the binary class case

Let's take a quick look at an example with two classes, 0 and 1, and 100 samples (so 50 of each class) to get an idea of what entropy is describing. If we split our data perfectly, with all the 0's in the left bucket and all the 1's in the right bucket, each bucket has zero entropy. Whatever (hypothetical) decision rule produced that split has separated our data perfectly.

left, right = np.zeros(50), np.ones(50)

Because each node is pure, the entropy is 0 for each node:

entropy(left), entropy(right)
(-0.0, -0.0)
Now let's swap a pair of samples: one of the 0's in the left node trades places with one of the 1's in the right node.

left[0], right[0] = 1, 0

Here are the entropies of those child nodes. In both cases, the entropy goes up by the same amount.

entropy(left), entropy(right)
(0.14144054254182067, 0.14144054254182067)

Let's make this discussion even simpler by looking at just a single node containing 100 samples, each labeled 0 or 1. We'll start with a pure node containing just 0's, incrementally flip those 0's to 1's, and see what happens to the entropy:

node = np.zeros(100)  # start with a pure node of all 0's
x = []                # fraction of positive (1) samples
y = []                # entropy of the node at that fraction
for i in range(99):
    x.append((i + 1) / 100.)
    node[i] = 1       # flip one more 0 to a 1
    y.append(entropy(node))

Here's the plot of the entropy of that node as a function of the fraction of positive-class samples (1's) in the node. Initially, when the fraction is zero (the node is all 0's), the node is pure, so the entropy is 0. As we swap 0's for 1's the entropy increases, reaching 1 when the fraction is 0.5 (equal numbers of 1's and 0's). From there the entropy decreases as the fraction of the positive class continues to increase, reaching zero again when the node is all 1's and thus pure.

plt.plot(x,y)
plt.ylabel("Entropy")
plt.xlabel("Fraction positive class")
plt.show()
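
In the two-class case the entropy curve also has a closed form, $H(p) = -p\log_{2}p - (1-p)\log_{2}(1-p)$, so the same plot can be produced without the loop (a quick sketch; the endpoints are nudged away from 0 and 1 to avoid taking the log of zero):

p = np.linspace(0.001, 0.999, 199)               # fraction of positive-class samples
H = -p * np.log2(p) - (1 - p) * np.log2(1 - p)   # binary entropy
plt.plot(p, H)
plt.ylabel("Entropy")
plt.xlabel("Fraction positive class")
plt.show()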

Remember, splits are determined by information gain, that is, the difference between the entropy of the parent and the weighted sum of the entropies of the children. If a child node has an entropy of 1.0, that node contributes nothing to our understanding of how to separate the two classes. If a child node has an entropy of 0.0, we have perfectly segmented out those samples. But again, we want the weighted combination of the children's entropies to be low: one child might have very low entropy and the other very high, and it can still be better to split elsewhere to make the weighted sum lower.
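
To make that concrete (with made-up numbers): a split producing children with entropies 0.0 and 1.0 that hold 10 and 90 samples has a weighted entropy of $\frac{10 \cdot 0.0 + 90 \cdot 1.0}{100} = 0.9$, while a split producing two 50-sample children each with entropy 0.65 has a weighted entropy of 0.65, so the second split is preferred even though the first produced a pure child.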

Finding the best split through iteration

Here's our original list of targets from the Iris example:

y_train
array([1, 1, 2, 0, 2, 0, 0, 1, 2, 2, 2, 2, 1, 2, 1, 1, 2, 2, 2, 2, 1, 2,
       1, 0, 2, 1, 1, 1, 1, 2, 0, 0, 2, 1, 0, 0, 1, 0, 2, 1, 0, 1, 2, 1,
       0, 2, 2, 2, 2, 0, 0, 2, 2, 0, 2, 0, 2, 2, 0, 0, 2, 0, 0, 0, 1, 2,
       2, 0, 0, 0, 1, 1, 0, 0, 1, 0, 2, 1, 2, 1, 0, 2, 0, 2, 0, 0, 2, 0,
       2, 1, 1, 1, 2, 2, 1, 1, 0, 1, 2, 2, 0, 1, 1, 1, 1, 0, 0, 0, 2, 1,
       2, 0])

Here's the associated entropy, which matches the value shown in the root node above. This is the entropy of the root node.

entropy(y_train)
1.5807197138422104

Now let's get the entropies of the child nodes for a split on feature 3 at 0.8. (As we'll see below, this ties with the split on feature 2 at 2.35 that the fitted trees above chose; both isolate exactly the 37 class-0 training samples.)

feat = 3
x = 0.8

Here are the samples that went into the left node. As you can see, all of them are of class 0, so there is no entropy associated with this node - the node is pure.

left = y_train[X_train[:,feat] <= x]
left
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
entropy(left)
-0.0

Here are the samples that went into the right node:

right = y_train[X_train[:,feat] > x]
right
array([1, 1, 2, 2, 1, 2, 2, 2, 2, 1, 2, 1, 1, 2, 2, 2, 2, 1, 2, 1, 2, 1,
       1, 1, 1, 2, 2, 1, 1, 2, 1, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       1, 2, 2, 1, 1, 1, 2, 1, 2, 1, 2, 2, 2, 2, 1, 1, 1, 2, 2, 1, 1, 1,
       2, 2, 1, 1, 1, 1, 2, 1, 2])
entropy(right)
0.993707106604508
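
Putting the two child entropies together, the weighted sum and the resulting information gain for this split are (using only the quantities computed above):

n = y_train.shape[0]
weighted = (left.shape[0] * entropy(left) + right.shape[0] * entropy(right)) / n
ig = entropy(y_train) - weighted
weighted, ig  # approximately (0.665, 0.915)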

Now we need to cycle through each feature, and within each feature cycle through each possible split.

def calc_entropies(X, y):

    n = y.shape[0]   # number of samples
    m = X.shape[1]   # number of features
    min_ents = np.zeros(m)    # lowest weighted child entropy found for each feature
    min_splits = np.zeros(m)  # threshold that achieves it, for each feature

    for feat in range(m):

        # Use set to remove dups; sort it to get halfway points
        points = sorted(set(X[:, feat]))

        # Candidate thresholds: midpoints between consecutive values
        splits = [(points[i-1] + points[i]) / 2. for i in range(1, len(points))]
        entropies = []

        for x in splits:

            l = y[X[:, feat] <= x]
            r = y[X[:, feat] > x]

            # Weighted sum of the child entropies
            e = (l.shape[0] * entropy(l) + r.shape[0] * entropy(r)) / n
            entropies.append(e)

        feat_min_ent = np.argmin(entropies)
        min_ents[feat] = entropies[feat_min_ent]
        min_splits[feat] = splits[feat_min_ent]

    min_feat = np.argmin(min_ents)

    return min_feat, min_splits[min_feat], min_ents[min_feat]
calc_entropies(X_train, y_train)
(2, 2.35, 0.6654288660298044)

Note that in this example splitting feature 3 at 0.8 gives the same weighted entropy as the split returned above (feature 2 at 2.35). np.argmin returns the index of the first minimum in the case of a tie, which is why feature 2 is reported here.
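
As a final check, the root split stored in the depth-1 tree fit earlier should agree with this brute-force search; tree_.feature and tree_.threshold hold the split used at each node (the stored threshold may differ from the printed 2.35 in the trailing decimal places):

clf.tree_.feature[0], clf.tree_.threshold[0]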