Controlling Tree Depth to Prevent Overfitting
Learners will train trees of varying max_depth, observe the overfitting-underfitting trade-off, and choose depth via validation score.
How Depth Leads to Overfitting
An unconstrained decision tree keeps splitting until every leaf is pure, which in the extreme gives each training sample its own leaf. Unless two identical samples carry different labels, it reaches 100% training accuracy by memorising every data point, including noise. This is the extreme case of overfitting: the tree learns the idiosyncrasies of the training data rather than general patterns. On new data, such a tree usually scores lower than it does on the training set because its rules are too specific. Controlling tree depth is a primary regularisation mechanism for decision trees, analogous to choosing alpha in regularised linear models or k in KNN.
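One way to see the memorisation of noise directly is to corrupt some training labels on purpose and check whether an unrestricted tree still fits them. The sketch below is illustrative only: the 10% flip rate, the seed, and the names `y_noisy`, `deep`, and `shallow` are assumptions made for this example, not part of the lesson's own code.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=42)

# Assumption for illustration: flip 10% of the training labels to simulate noise.
rng = np.random.default_rng(0)
flip = rng.choice(len(y_tr), size=len(y_tr) // 10, replace=False)
y_noisy = y_tr.copy()
y_noisy[flip] = 1 - y_noisy[flip]  # labels are 0/1, so this swaps the class

# Unrestricted tree versus a depth-limited tree, both trained on the noisy labels.
deep = DecisionTreeClassifier(random_state=42).fit(X_tr, y_noisy)
shallow = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X_tr, y_noisy)

# Accuracy measured against the noisy labels the trees were trained on.
noisy_train_acc_deep = deep.score(X_tr, y_noisy)
noisy_train_acc_shallow = shallow.score(X_tr, y_noisy)

print(f'Unrestricted (depth {deep.get_depth()}): fits noisy labels at {noisy_train_acc_deep:.3f}')
print(f'max_depth=3: fits noisy labels at {noisy_train_acc_shallow:.3f}')
print(f'Test accuracy, unrestricted: {deep.score(X_te, y_te):.3f}')
print(f'Test accuracy, max_depth=3:  {shallow.score(X_te, y_te):.3f}')
```

If the unrestricted tree reproduces the flipped labels almost perfectly while the shallow tree cannot, that gap is the memorisation described above: the extra depth is being spent on fitting noise.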
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
X, y = load_breast_cancer(return_X_y=True)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=42)
# Unlimited depth: memorises training data
tree = DecisionTreeClassifier(random_state=42) # no max_depth
tree.fit(X_tr, y_tr)
print('Unlimited depth tree:')
print(' Tree depth:', tree.get_depth())
print(f' Train accuracy: {tree.score(X_tr, y_tr):.3f}') # 1.000
print(f' Test accuracy: {tree.score(X_te, y_te):.3f}') # < 1.000
The max_depth Parameter
max_depth limits how many levels of splits the tree can grow. With max_depth=1, the tree makes exactly one decision (a 'stump'). With max_depth=3, each prediction passes through at most three sequential questions. Shallower trees are less prone to overfitting but may underfit; deeper trees fit the training data more closely but risk overfitting. max_depth is a hyperparameter, and a suitable value is found through cross-validation. A useful heuristic: start at a max_depth of about 3 to 5 and tune from there using a validation curve.
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
X, y = load_breast_cancer(return_X_y=True)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=42)
for depth in [1, 2, 3, 5, 10, None]:
    tree = DecisionTreeClassifier(max_depth=depth, random_state=42)
    tree.fit(X_tr, y_tr)
    print(f'max_depth={str(depth):4}: train={tree.score(X_tr, y_tr):.3f}, '
          f'test={tree.score(X_te, y_te):.3f}')
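Comparing depths on the test split is useful for observing the trade-off, but choosing a depth that way leaks test information into the model selection. A minimal sketch of choosing max_depth by cross-validated validation score instead, using scikit-learn's `validation_curve`, could look like this. The depth grid, the 5-fold setting, and the names `depths`, `best_depth`, and `final_tree` are assumptions chosen for illustration.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split, validation_curve
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=42)

# Candidate depths to compare (an illustrative grid, not a recommendation).
depths = [1, 2, 3, 4, 5, 6, 8, 10]

# Each row of the returned arrays is one depth, each column one CV fold.
train_scores, val_scores = validation_curve(
    DecisionTreeClassifier(random_state=42),
    X_tr, y_tr,
    param_name='max_depth',
    param_range=depths,
    cv=5,
    scoring='accuracy',
)

mean_train = train_scores.mean(axis=1)
mean_val = val_scores.mean(axis=1)
for d, tr, va in zip(depths, mean_train, mean_val):
    print(f'max_depth={d:2d}: cv train={tr:.3f}, cv validation={va:.3f}')

# Pick the depth with the highest mean validation accuracy
# (np.argmax returns the first, i.e. shallowest, depth on ties).
best_depth = depths[int(np.argmax(mean_val))]

# Refit on the full training split; the test split is used only once, at the end.
final_tree = DecisionTreeClassifier(max_depth=best_depth, random_state=42)
final_tree.fit(X_tr, y_tr)
test_acc = final_tree.score(X_te, y_te)
print(f'Chosen max_depth={best_depth}, test accuracy={test_acc:.3f}')
```

In the printed table, training accuracy should keep rising with depth while validation accuracy levels off or drops; the depth where validation accuracy peaks is the one to keep. Because decision trees are high-variance models, the chosen depth may shift with a different seed or fold count.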