Gini Impurity and Information Gain
Learners will calculate Gini impurity and entropy for sample splits and understand why the tree picks the split that maximises information gain.
The Splitting Problem: Which Feature to Ask About?
When building a decision tree, at each node we must choose which feature and which threshold produces the most useful split. The goal is to create child nodes where samples are as pure as possible — ideally each child contains only one class. We need a mathematical measure of impurity that tells us how mixed the classes are in a node. Lower impurity is better: a node with all samples belonging to one class has zero impurity (perfect purity). Two widely-used impurity measures are Gini impurity and entropy.
# Impurity measures how mixed the classes are in a node
# Perfect purity: all samples belong to one class -> impurity = 0
# Maximum impurity: classes are equally distributed
import numpy as np
# Node A: all class 0 -> pure
node_a = [0, 0, 0, 0] # impurity = 0
# Node B: 50/50 mix -> maximally impure
node_b = [0, 0, 1, 1] # impurity = maximum
# Node C: mostly one class
node_c = [0, 0, 0, 1] # impurity = low
for name, node in [('A', node_a), ('B', node_b), ('C', node_c)]:
    print(f'Node {name}: classes = {node}')

Gini Impurity: The Default Criterion
Gini impurity is the probability that a randomly chosen sample from a node would be mislabelled if its label were drawn at random from the class distribution in that node. The formula is: Gini = 1 - sum(p_i^2), where p_i is the proportion of samples in class i. For two classes, Gini ranges from 0 (pure) to 0.5 (an equal split); for K classes, the maximum is 1 - 1/K, reached when all classes are equally represented. Gini impurity is the default criterion in scikit-learn's DecisionTreeClassifier; a commonly cited advantage is that it is cheap to compute because it needs no logarithms.
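As a worked check of the formula, here is the arithmetic for a 75/25 node (three samples of one class, one of another) and for the case where K classes are equally represented. This is a sketch in LaTeX notation, with p_i and K as defined in the paragraph above:

```latex
\begin{align*}
\text{Gini}(75/25) &= 1 - \left(0.75^2 + 0.25^2\right) = 1 - 0.625 = 0.375 \\
\text{Gini}_{\max}(K) &= 1 - \sum_{i=1}^{K} \left(\tfrac{1}{K}\right)^2
                       = 1 - K \cdot \tfrac{1}{K^2} = 1 - \tfrac{1}{K}
\end{align*}
```

For K = 2 this gives 0.5, and for K = 3 it gives roughly 0.667.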
import numpy as np
def gini_impurity(y):
    classes, counts = np.unique(y, return_counts=True)
    probabilities = counts / len(y)
    return 1 - np.sum(probabilities ** 2)
# Pure node
print('Pure [0,0,0,0]:', gini_impurity([0,0,0,0])) # 0.0
# 50/50 split
print('50/50 [0,0,1,1]:', gini_impurity([0,0,1,1])) # 0.5
# 75/25 split
print('75/25 [0,0,0,1]:', gini_impurity([0,0,0,1])) # 0.375
# Three classes equal
print('3-class equal:', gini_impurity([0,1,2,0,1,2])) # ~0.667
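The lesson's title and objective also mention entropy and information gain. The sketch below assumes the usual definitions: entropy is -sum(p_i * log2(p_i)), and information gain is the parent's impurity minus the size-weighted average impurity of the children. The helper names `entropy` and `information_gain`, the `splits` dictionary, and the toy labels are illustrative assumptions rather than part of the lesson; `gini_impurity` is repeated so the block runs on its own.

```python
import numpy as np

def gini_impurity(y):
    """Gini impurity: 1 - sum(p_i^2)."""
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return float(1.0 - np.sum(p ** 2))

def entropy(y):
    """Shannon entropy in bits: -sum(p_i * log2(p_i))."""
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()  # np.unique only returns classes that occur, so p > 0
    # Adding 0.0 turns the -0.0 produced by a pure node into 0.0
    return float(-np.sum(p * np.log2(p))) + 0.0

def information_gain(parent, children, impurity=entropy):
    """Parent impurity minus the size-weighted impurity of the children."""
    n = len(parent)
    weighted = sum(len(child) / n * impurity(child) for child in children)
    return impurity(parent) - weighted

# Hypothetical parent node: four samples of class 0 and four of class 1
parent = [0, 0, 0, 0, 1, 1, 1, 1]

# Three hypothetical candidate splits of that parent
splits = {
    'perfect': ([0, 0, 0, 0], [1, 1, 1, 1]),  # each child is pure
    'partial': ([0, 0, 0, 1], [0, 1, 1, 1]),  # each child is 75/25
    'useless': ([0, 0, 1, 1], [0, 0, 1, 1]),  # each child is still 50/50
}

for name, children in splits.items():
    gain_entropy = information_gain(parent, children, impurity=entropy)
    gain_gini = information_gain(parent, children, impurity=gini_impurity)
    print(f'{name}: entropy gain = {gain_entropy:.3f}, Gini decrease = {gain_gini:.3f}')
# perfect: entropy gain = 1.000, Gini decrease = 0.500
# partial: entropy gain = 0.189, Gini decrease = 0.125
# useless: entropy gain = 0.000, Gini decrease = 0.000
```

Under these assumptions a tree would pick the 'perfect' split, because it has the largest gain. Both criteria rank these three candidates in the same order, although they can disagree on less clear-cut data.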