Decision Trees
Intuitive and Interpretable Machine Learning! 🌳
Decision trees are one of the most intuitive machine learning algorithms, mimicking human decision-making processes. From their simple if-then structure to their powerful ensemble variants, decision trees form the foundation of many state-of-the-art algorithms. Master tree construction, pruning, and visualization to build interpretable yet powerful models.
Decision Tree Fundamentals
graph TD
A[Decision Trees] --> B[Tree Structure]
A --> C[Splitting Criteria]
A --> D[Tree Types]
B --> E[Root Node]
B --> F[Internal Nodes]
B --> G[Leaf Nodes]
B --> H[Branches]
C --> I[Gini Impurity]
C --> J[Information Gain]
C --> K[Gain Ratio]
C --> L[MSE Reduction]
D --> M[Classification Trees]
D --> N[Regression Trees]
style A fill:#f9f,stroke:#333,stroke-width:2px
style C fill:#bbf,stroke:#333,stroke-width:2px
style D fill:#fbf,stroke:#333,stroke-width:2px
Decision Tree Implementation from Scratch
Building a Classification Tree
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import make_classification, load_iris, load_wine
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.tree import plot_tree, export_text
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
import warnings
# Suppress every warning for the rest of the session. This keeps the tutorial
# output uncluttered, but it also hides deprecation notices from the libraries.
warnings.filterwarnings('ignore')
# Global plotting defaults: matplotlib style sheet and seaborn colour palette.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")
class DecisionTreeNode:
    """A single node of a binary decision tree.

    Internal nodes carry a split (``feature``, ``threshold``) and two
    children; leaves carry a prediction in ``value``.  ``samples`` and
    ``impurity`` are bookkeeping values used when the tree is printed.
    """

    def __init__(self, feature=None, threshold=None, left=None, right=None,
                 value=None, samples=None, impurity=None):
        # Split description (internal nodes only).
        self.feature, self.threshold = feature, threshold
        # Sub-trees for samples going left / right at the split.
        self.left, self.right = left, right
        # Prediction stored at a leaf; stays None for internal nodes.
        self.value = value
        # Statistics of the training data that reached this node.
        self.samples, self.impurity = samples, impurity

    def is_leaf(self):
        """Return True when the node stores a prediction (i.e. is a leaf)."""
        # A prediction of 0 is valid, so test against None, not truthiness.
        has_prediction = self.value is not None
        return has_prediction
class DecisionTreeFromScratch:
    """Binary classification tree grown greedily with threshold splits.

    Every internal node tests ``feature <= threshold``; the split is chosen
    to maximise the reduction in Gini impurity or entropy.  Class labels
    must be non-negative integers (or castable to them) because class
    counts are obtained with ``np.bincount``.

    Parameters
    ----------
    max_depth : int or None
        Maximum depth of the tree; ``None`` disables the depth limit.
    min_samples_split : int
        A node holding fewer samples than this becomes a leaf.
    min_samples_leaf : int
        Minimum number of samples each child of a split must receive.
        Candidate splits that violate it are not considered.
    criterion : {'gini', 'entropy'}
        Impurity measure used to score splits.
    random_state : int or None
        When given, seeds NumPy's global RNG (kept for backward
        compatibility; the split search itself is deterministic).
    """

    def __init__(self, max_depth=5, min_samples_split=2, min_samples_leaf=1,
                 criterion='gini', random_state=None):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.criterion = criterion
        self.random_state = random_state
        self.root = None
        self.n_features_ = None
        self.n_classes_ = None
        self.feature_names = None
        self.class_names = None
        if random_state is not None:
            np.random.seed(random_state)

    def fit(self, X, y, feature_names=None, class_names=None):
        """Grow the tree on ``X`` (n_samples, n_features) and labels ``y``.

        ``feature_names`` / ``class_names`` may be any sequence (including a
        NumPy array); missing or empty names are replaced by generated ones.
        Returns ``self``.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        self.n_features_ = X.shape[1]
        self.n_classes_ = len(np.unique(y))
        # `names or default` raises ValueError for NumPy arrays, so the
        # "not provided" case is tested explicitly.
        if feature_names is None or len(feature_names) == 0:
            feature_names = [f'Feature_{i}' for i in range(self.n_features_)]
        if class_names is None or len(class_names) == 0:
            class_names = [f'Class_{i}' for i in range(self.n_classes_)]
        self.feature_names = list(feature_names)
        self.class_names = list(class_names)
        # Build tree recursively
        self.root = self._build_tree(X, y, depth=0)
        return self

    def _make_leaf(self, y, n_samples, impurity):
        """Create a leaf predicting the majority class of ``y``."""
        return DecisionTreeNode(value=self._most_common_label(y),
                                samples=n_samples, impurity=impurity)

    def _build_tree(self, X, y, depth):
        """Recursively build the subtree for the samples ``X``, ``y``."""
        n_samples = X.shape[0]
        impurity = self._calculate_impurity(y)
        depth_exhausted = (self.max_depth is not None
                           and depth >= self.max_depth)
        is_pure = len(np.unique(y)) == 1
        if depth_exhausted or n_samples < self.min_samples_split or is_pure:
            return self._make_leaf(y, n_samples, impurity)

        best_feature, best_threshold = self._best_split(X, y)
        if best_feature is None:
            # No split satisfies the leaf-size constraint.
            return self._make_leaf(y, n_samples, impurity)

        # The right child is the complement of the left one, which mirrors
        # the if/else routing used at prediction time.
        left_indices = X[:, best_feature] <= best_threshold
        right_indices = ~left_indices
        left_child = self._build_tree(X[left_indices], y[left_indices],
                                      depth + 1)
        right_child = self._build_tree(X[right_indices], y[right_indices],
                                       depth + 1)
        return DecisionTreeNode(feature=best_feature,
                                threshold=best_threshold,
                                left=left_child, right=right_child,
                                samples=n_samples, impurity=impurity)

    def _best_split(self, X, y):
        """Return ``(feature_index, threshold)`` of the best admissible split.

        Every unique value of every feature is tried as a threshold.  A
        candidate is admissible only if both children receive at least
        ``min_samples_leaf`` samples (and at least one).  Returns
        ``(None, None)`` when no admissible candidate exists.
        """
        best_gain = -1
        best_feature = None
        best_threshold = None
        parent_impurity = self._calculate_impurity(y)
        n_samples = len(y)
        min_leaf = max(1, self.min_samples_leaf)

        for feature_idx in range(self.n_features_):
            feature_values = X[:, feature_idx]
            for threshold in np.unique(feature_values):
                left_mask = feature_values <= threshold
                n_left = int(np.sum(left_mask))
                # Enforce the leaf-size limit while searching, so a valid
                # split is still found when the top-scoring one is too
                # lopsided.
                if n_left < min_leaf or n_samples - n_left < min_leaf:
                    continue
                gain = self._information_gain(y, left_mask, ~left_mask,
                                              parent_impurity)
                if gain > best_gain:
                    best_gain = gain
                    best_feature = feature_idx
                    best_threshold = threshold
        return best_feature, best_threshold

    def _information_gain(self, y, left_mask, right_mask, parent_impurity):
        """Return the impurity reduction achieved by the given split."""
        n_total = len(y)
        n_left = np.sum(left_mask)
        n_right = np.sum(right_mask)
        left_impurity = self._calculate_impurity(y[left_mask])
        right_impurity = self._calculate_impurity(y[right_mask])
        # Children are weighted by the share of samples they receive.
        weighted_impurity = ((n_left / n_total) * left_impurity
                             + (n_right / n_total) * right_impurity)
        return parent_impurity - weighted_impurity

    def _calculate_impurity(self, y):
        """Return the Gini impurity or entropy of the labels ``y``.

        An empty label set has impurity 0.  Raises ``ValueError`` for an
        unknown ``criterion``.
        """
        if self.criterion not in ('gini', 'entropy'):
            raise ValueError(f"Unknown criterion: {self.criterion}")
        if len(y) == 0:
            return 0.0
        proportions = np.bincount(y.astype(int)) / len(y)
        if self.criterion == 'gini':
            # Gini impurity: 1 - sum(p_i^2)
            return 1 - np.sum(proportions ** 2)
        # Entropy: -sum(p_i * log2(p_i)); drop empty classes to avoid log(0).
        proportions = proportions[proportions > 0]
        return -np.sum(proportions * np.log2(proportions))

    def _most_common_label(self, y):
        """Return the most frequent class label (lowest label wins ties)."""
        return np.bincount(y.astype(int)).argmax()

    def predict(self, X):
        """Return an array with the predicted class of every row of ``X``."""
        return np.array([self._predict_sample(sample)
                         for sample in np.asarray(X)])

    def _predict_sample(self, sample):
        """Walk from the root to a leaf and return that leaf's class."""
        node = self.root
        while not node.is_leaf():
            if sample[node.feature] <= node.threshold:
                node = node.left
            else:
                node = node.right
        return node.value

    def print_tree(self, node=None, depth=0):
        """Print the (sub)tree rooted at ``node`` as nested if/else rules."""
        if node is None:
            node = self.root
        if node.is_leaf():
            print(f"{' ' * depth}Predict: {self.class_names[node.value]} "
                  f"(samples={node.samples}, impurity={node.impurity:.3f})")
        else:
            print(f"{' ' * depth}If {self.feature_names[node.feature]} <= {node.threshold:.3f} "
                  f"(samples={node.samples}, impurity={node.impurity:.3f})")
            self.print_tree(node.left, depth + 1)
            print(f"{' ' * depth}Else:")
            self.print_tree(node.right, depth + 1)

    def get_depth(self, node=None):
        """Return the number of splits on the longest root-to-leaf path."""
        if node is None:
            node = self.root
        if node.is_leaf():
            return 0
        return 1 + max(self.get_depth(node.left), self.get_depth(node.right))

    def get_n_leaves(self, node=None):
        """Return the number of leaves in the (sub)tree rooted at ``node``."""
        if node is None:
            node = self.root
        if node.is_leaf():
            return 1
        return self.get_n_leaves(node.left) + self.get_n_leaves(node.right)
# Generate sample data: a synthetic 3-class problem with 200 samples and
# 4 features, 3 of them informative and none redundant.
X, y = make_classification(n_samples=200, n_features=4, n_informative=3,
n_redundant=0, n_classes=3, random_state=42)
# Split data 70/30, stratified on y so both parts keep the class proportions.
# X_train / X_test / y_train / y_test are reused by later sections of the file.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3,
random_state=42, stratify=y)
# Train custom decision tree
print("="*60)
print("DECISION TREE FROM SCRATCH")
print("="*60)
tree_custom = DecisionTreeFromScratch(max_depth=3, min_samples_split=5,
criterion='gini', random_state=42)
tree_custom.fit(X_train, y_train,
feature_names=['Feature_A', 'Feature_B', 'Feature_C', 'Feature_D'],
class_names=['Class_0', 'Class_1', 'Class_2'])
# Make predictions on the held-out part and score them
y_pred_custom = tree_custom.predict(X_test)
accuracy_custom = accuracy_score(y_test, y_pred_custom)
print(f"\nCustom Tree Statistics:")
print(f" Depth: {tree_custom.get_depth()}")
print(f" Number of leaves: {tree_custom.get_n_leaves()}")
print(f" Test Accuracy: {accuracy_custom:.3f}")
print(f"\nTree Structure:")
tree_custom.print_tree()
# Compare with scikit-learn, using the same max_depth, min_samples_split and
# criterion as the custom tree above.
tree_sklearn = DecisionTreeClassifier(max_depth=3, min_samples_split=5,
criterion='gini', random_state=42)
tree_sklearn.fit(X_train, y_train)
y_pred_sklearn = tree_sklearn.predict(X_test)
accuracy_sklearn = accuracy_score(y_test, y_pred_sklearn)
print(f"\nScikit-learn Tree Accuracy: {accuracy_sklearn:.3f}")
Understanding Splitting Criteria
Gini Impurity vs Information Gain
class SplittingCriteriaAnalysis:
    """Helpers that compute and plot decision-tree splitting criteria.

    Label arrays passed to the ``calculate_*`` methods must hold
    non-negative integers, since class counts come from ``np.bincount``.
    """

    def __init__(self):
        self.criteria = ['gini', 'entropy']

    def calculate_gini(self, y):
        """Return the Gini impurity of ``y`` (0 for an empty array)."""
        n_labels = len(y)
        if n_labels == 0:
            return 0
        class_freq = np.bincount(y) / n_labels
        return 1 - np.sum(class_freq ** 2)

    def calculate_entropy(self, y):
        """Return the base-2 entropy of ``y`` (0 for an empty array)."""
        n_labels = len(y)
        if n_labels == 0:
            return 0
        class_freq = np.bincount(y) / n_labels
        # Classes that do not occur would produce log2(0); leave them out.
        present = class_freq[class_freq > 0]
        return -np.sum(present * np.log2(present))

    def calculate_gain_ratio(self, y, left_mask, right_mask):
        """Return the C4.5 gain ratio of the split given by the two masks.

        The ratio is information gain divided by split information; it is
        0 when the split information is 0 (all samples on one side).
        """
        n_total = len(y)
        share_left = np.sum(left_mask) / n_total
        share_right = np.sum(right_mask) / n_total

        # Information gain: parent entropy minus weighted child entropy.
        children_entropy = (share_left * self.calculate_entropy(y[left_mask])
                            + share_right * self.calculate_entropy(y[right_mask]))
        info_gain = self.calculate_entropy(y) - children_entropy

        # Split information: entropy of the left/right partition itself.
        split_info = 0
        for share in (share_left, share_right):
            if share > 0:
                split_info -= share * np.log2(share)

        if split_info == 0:
            return 0
        return info_gain / split_info

    def visualize_impurity_functions(self):
        """Plot Gini, entropy and misclassification error for two classes."""
        # Stay strictly inside (0, 1) so that log2 is always defined.
        p = np.linspace(0.001, 0.999, 100)
        q = 1 - p
        curves = [
            (2 * p * q, 'Gini Impurity'),
            (-p * np.log2(p) - q * np.log2(q), 'Entropy'),
            (1 - np.maximum(p, q), 'Misclassification Error'),
        ]

        fig, ax = plt.subplots(figsize=(10, 6))
        for values, label in curves:
            ax.plot(p, values, label=label, linewidth=2)
        ax.set_xlabel('Proportion of Class 1')
        ax.set_ylabel('Impurity')
        ax.set_title('Comparison of Impurity Measures (Binary Classification)')
        ax.legend()
        ax.grid(True, alpha=0.3)
        # All three measures peak at an even class mix.
        ax.axvline(x=0.5, color='red', linestyle='--', alpha=0.5)
        ax.text(0.5, 0.1, 'Maximum\nImpurity', ha='center', fontsize=9)
        plt.tight_layout()
        plt.show()

    def compare_split_quality(self, X, y):
        """Score a median split of every feature under three criteria.

        Each column of ``X`` is split at its median; Gini gain, information
        gain and gain ratio of that split are collected, drawn as three bar
        charts and returned as a DataFrame with one row per feature.
        """
        rows = []
        n_total = len(y)
        for feature_idx in range(X.shape[1]):
            column = X[:, feature_idx]
            threshold = np.median(column)
            left_mask = column <= threshold
            right_mask = column > threshold
            n_left, n_right = np.sum(left_mask), np.sum(right_mask)
            # A split that leaves one side empty carries no information.
            if n_left == 0 or n_right == 0:
                continue

            w_left, w_right = n_left / n_total, n_right / n_total
            y_left, y_right = y[left_mask], y[right_mask]

            gini_gain = self.calculate_gini(y) - (
                w_left * self.calculate_gini(y_left)
                + w_right * self.calculate_gini(y_right))
            info_gain = self.calculate_entropy(y) - (
                w_left * self.calculate_entropy(y_left)
                + w_right * self.calculate_entropy(y_right))
            gain_ratio = self.calculate_gain_ratio(y, left_mask, right_mask)

            rows.append({
                'Feature': f'Feature_{feature_idx}',
                'Threshold': threshold,
                'Gini Gain': gini_gain,
                'Info Gain': info_gain,
                'Gain Ratio': gain_ratio
            })
        results_df = pd.DataFrame(rows)

        # One bar chart per criterion: (column, y-label, title, bar options).
        panels = [
            ('Gini Gain', 'Gini Gain', 'Gini Gain by Feature', {}),
            ('Info Gain', 'Information Gain', 'Information Gain by Feature',
             {'color': 'orange'}),
            ('Gain Ratio', 'Gain Ratio', 'Gain Ratio by Feature',
             {'color': 'green'}),
        ]
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        for ax, (column, ylabel, title, bar_options) in zip(axes, panels):
            ax.bar(range(len(results_df)), results_df[column], alpha=0.7,
                   **bar_options)
            ax.set_xlabel('Feature')
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            ax.set_xticks(range(len(results_df)))
            ax.set_xticklabels(results_df['Feature'], rotation=45)
            ax.grid(True, alpha=0.3)
        plt.suptitle('Comparison of Splitting Criteria', fontsize=14)
        plt.tight_layout()
        plt.show()
        return results_df
# Analyze splitting criteria
print("\n" + "="*60)
print("SPLITTING CRITERIA ANALYSIS")
print("="*60)
criteria_analysis = SplittingCriteriaAnalysis()
# Visualize impurity functions (opens a matplotlib figure)
print("\nVisualizing impurity functions...")
criteria_analysis.visualize_impurity_functions()
# Compare split quality on the training part of the synthetic dataset;
# the returned DataFrame has one row per feature.
print("\nComparing split quality metrics...")
split_comparison = criteria_analysis.compare_split_quality(X_train, y_train)
print("\nSplit Quality Comparison:")
print(split_comparison.to_string(index=False))
Tree Visualization and Interpretation
Visualizing Decision Trees
# Load iris dataset for visualization
iris = load_iris()
# Keep only columns 2 and 3 so the decision boundary can be drawn in 2D
# (labelled 'Petal Length' / 'Petal Width' where the visualizer is created).
X_iris, y_iris = iris.data[:, [2, 3]], iris.target # Use only 2 features for visualization
# Split data 70/30, stratified on the class label
X_train_iris, X_test_iris, y_train_iris, y_test_iris = train_test_split(
X_iris, y_iris, test_size=0.3, random_state=42, stratify=y_iris
)
class TreeVisualization:
    """Plot and explain a fitted scikit-learn decision tree.

    Parameters
    ----------
    tree_model : fitted tree estimator
        Must provide ``predict``, ``feature_importances_``,
        ``decision_path``, ``apply`` and ``tree_``.
    X, y : arrays
        Data shown in the plots; ``X`` is indexed as a 2-D NumPy array.
    feature_names, class_names : sequence of str, optional
        Any sequence, including a NumPy array such as
        ``iris.target_names``.  Missing or empty names are generated as
        ``Feature_i`` / ``Class_i``.
    """

    def __init__(self, tree_model, X, y, feature_names=None, class_names=None):
        self.tree = tree_model
        self.X = X
        self.y = y
        self.feature_names = self._resolve_names(
            feature_names, 'Feature', X.shape[1])
        self.class_names = self._resolve_names(
            class_names, 'Class', len(np.unique(y)))

    @staticmethod
    def _resolve_names(names, prefix, count):
        """Return ``names`` as a list, or ``count`` generated default names."""
        # `names or default` raises ValueError for NumPy arrays with more
        # than one element, so the "not provided" case is tested explicitly.
        if names is None or len(names) == 0:
            return [f'{prefix}_{i}' for i in range(count)]
        return list(names)

    def plot_tree_structure(self):
        """Plot tree structure using sklearn's plot_tree"""
        fig, ax = plt.subplots(figsize=(20, 10))
        plot_tree(self.tree,
                  feature_names=self.feature_names,
                  class_names=self.class_names,
                  filled=True,
                  rounded=True,
                  fontsize=10,
                  ax=ax)
        plt.title('Decision Tree Structure', fontsize=14)
        plt.tight_layout()
        plt.show()

    def plot_decision_boundary(self):
        """Plot the predicted class regions together with the data points.

        Only available for two-feature data; otherwise a message is printed
        and nothing is drawn.
        """
        if self.X.shape[1] != 2:
            print("Decision boundary plot requires exactly 2 features")
            return
        # Evaluate the model on a regular grid padded by 1 unit on each side.
        h = 0.02
        x_min, x_max = self.X[:, 0].min() - 1, self.X[:, 0].max() + 1
        y_min, y_max = self.X[:, 1].min() - 1, self.X[:, 1].max() + 1
        xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                             np.arange(y_min, y_max, h))
        Z = self.tree.predict(np.c_[xx.ravel(), yy.ravel()])
        Z = Z.reshape(xx.shape)

        fig, ax = plt.subplots(figsize=(10, 8))
        # Decision regions
        ax.contourf(xx, yy, Z, alpha=0.4, cmap='viridis')
        # Data points
        scatter = ax.scatter(self.X[:, 0], self.X[:, 1], c=self.y,
                             cmap='viridis', edgecolor='black', s=50)
        ax.set_xlabel(self.feature_names[0])
        ax.set_ylabel(self.feature_names[1])
        ax.set_title('Decision Tree Decision Boundary')
        plt.colorbar(scatter, ax=ax, label='Class')
        plt.tight_layout()
        plt.show()

    def plot_feature_importance(self):
        """Plot feature importances and return them as a sorted DataFrame.

        Bars are ordered by decreasing importance; a second axis shows the
        cumulative importance with a reference line at 0.95.
        """
        importances = self.tree.feature_importances_
        indices = np.argsort(importances)[::-1]

        fig, ax = plt.subplots(figsize=(10, 6))
        ax.bar(range(len(importances)), importances[indices], alpha=0.7)
        ax.set_xlabel('Feature')
        ax.set_ylabel('Importance')
        ax.set_title('Feature Importances')
        ax.set_xticks(range(len(importances)))
        ax.set_xticklabels([self.feature_names[i] for i in indices],
                           rotation=45)
        ax.grid(True, alpha=0.3)

        # Cumulative importance on a secondary y-axis.
        cumsum = np.cumsum(importances[indices])
        ax2 = ax.twinx()
        ax2.plot(range(len(importances)), cumsum, 'r-', marker='o',
                 label='Cumulative')
        ax2.set_ylabel('Cumulative Importance', color='r')
        ax2.tick_params(axis='y', labelcolor='r')
        ax2.axhline(y=0.95, color='r', linestyle='--', alpha=0.5)
        plt.tight_layout()
        plt.show()

        return pd.DataFrame({
            'Feature': self.feature_names,
            'Importance': importances
        }).sort_values('Importance', ascending=False)

    def get_decision_path(self, sample):
        """Print the sequence of tests a single ``sample`` passes through."""
        decision_path = self.tree.decision_path([sample])
        leaf = self.tree.apply([sample])[0]
        feature = self.tree.tree_.feature
        threshold = self.tree.tree_.threshold
        # Row 0 of the indicator matrix marks the nodes visited by the sample.
        node_indicator = decision_path.toarray()[0]
        node_index = np.where(node_indicator)[0]

        print(f"Decision path for sample: {sample}")
        print("="*50)
        for node_id in node_index:
            if leaf == node_id:
                print(f"→ Leaf node {node_id}: Predict {self.class_names[np.argmax(self.tree.tree_.value[node_id])]}")
            else:
                if sample[feature[node_id]] <= threshold[node_id]:
                    threshold_sign = "<="
                else:
                    threshold_sign = ">"
                print(f"→ Node {node_id}: "
                      f"{self.feature_names[feature[node_id]]} "
                      f"{threshold_sign} {threshold[node_id]:.3f}")
# Create and train tree for visualization (two iris features, depth <= 4)
tree_vis = DecisionTreeClassifier(max_depth=4, random_state=42)
tree_vis.fit(X_train_iris, y_train_iris)
# Visualize tree
print("\n" + "="*60)
print("TREE VISUALIZATION")
print("="*60)
# NOTE(review): iris.target_names is passed as class_names here; if it is a
# NumPy array, TreeVisualization.__init__ must not evaluate its truth value.
visualizer = TreeVisualization(tree_vis, X_train_iris, y_train_iris,
feature_names=['Petal Length', 'Petal Width'],
class_names=iris.target_names)
# Plot tree structure
print("\nPlotting tree structure...")
visualizer.plot_tree_structure()
# Plot decision boundary
print("\nPlotting decision boundary...")
visualizer.plot_decision_boundary()
# Feature importance (also returns the importances as a DataFrame)
print("\nAnalyzing feature importance...")
importance_df = visualizer.plot_feature_importance()
print("\nFeature Importance:")
print(importance_df.to_string(index=False))
# Decision path for a sample: trace the first held-out iris sample
print("\nTracing decision path for a sample...")
sample = X_test_iris[0]
visualizer.get_decision_path(sample)
Tree Pruning and Regularization
Preventing Overfitting
class TreePruning:
    """Experiments showing how depth limits and pruning affect a tree.

    All methods fit ``DecisionTreeClassifier`` models on the training data
    and score them on the test data given to the constructor.  The "optimal"
    settings they report are chosen by test-set accuracy, which is fine for
    illustration but gives an optimistic estimate; use a validation split or
    cross-validation for real model selection.
    """

    def __init__(self, X_train, y_train, X_test, y_test):
        self.X_train = X_train
        self.y_train = y_train
        self.X_test = X_test
        self.y_test = y_test

    @staticmethod
    def _format_setting(value):
        """Render an integer hyper-parameter read back from a DataFrame.

        pandas stores ``None`` as NaN and upcasts integers to float in a
        numeric column, so NaN is shown as ``'None'`` and ``3.0`` as ``'3'``.
        """
        # `value != value` is True only for NaN.
        if value is None or value != value:
            return 'None'
        return str(int(value))

    def analyze_depth_impact(self, max_depths=range(1, 21)):
        """Plot accuracy and leaf count against ``max_depth``.

        Returns the depth with the highest test accuracy (the smallest such
        depth on ties).
        """
        max_depths = list(max_depths)
        train_scores = []
        test_scores = []
        n_leaves = []
        for depth in max_depths:
            tree = DecisionTreeClassifier(max_depth=depth, random_state=42)
            tree.fit(self.X_train, self.y_train)
            train_scores.append(
                accuracy_score(self.y_train, tree.predict(self.X_train)))
            test_scores.append(
                accuracy_score(self.y_test, tree.predict(self.X_test)))
            n_leaves.append(tree.get_n_leaves())

        optimal_depth = max_depths[int(np.argmax(test_scores))]

        fig, axes = plt.subplots(1, 2, figsize=(12, 5))
        # Accuracy vs Depth
        axes[0].plot(max_depths, train_scores, 'o-', label='Train', linewidth=2)
        axes[0].plot(max_depths, test_scores, 's-', label='Test', linewidth=2)
        # Draw the marker before legend() so that its label is included.
        axes[0].axvline(x=optimal_depth, color='red', linestyle='--',
                        label=f'Optimal ({optimal_depth})')
        axes[0].set_xlabel('Max Depth')
        axes[0].set_ylabel('Accuracy')
        axes[0].set_title('Accuracy vs Tree Depth')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)
        # Number of leaves vs Depth
        axes[1].plot(max_depths, n_leaves, 'o-', color='green', linewidth=2)
        axes[1].set_xlabel('Max Depth')
        axes[1].set_ylabel('Number of Leaves')
        axes[1].set_title('Tree Complexity vs Depth')
        axes[1].grid(True, alpha=0.3)
        plt.suptitle('Tree Depth Analysis', fontsize=14)
        plt.tight_layout()
        plt.show()
        return optimal_depth

    def compare_regularization_parameters(self):
        """Grid over ``max_depth`` x ``min_samples_split`` and report results.

        Prints the configuration with the best test accuracy, plots the ten
        best configurations and returns a DataFrame with one row per
        configuration (``max_depth=None`` is stored as NaN by pandas).
        """
        params = {
            'max_depth': [None, 3, 5, 10, 15],
            'min_samples_split': [2, 5, 10, 20],
        }
        results = []
        for depth in params['max_depth']:
            for min_split in params['min_samples_split']:
                tree = DecisionTreeClassifier(
                    max_depth=depth,
                    min_samples_split=min_split,
                    random_state=42
                )
                tree.fit(self.X_train, self.y_train)
                train_score = accuracy_score(
                    self.y_train, tree.predict(self.X_train))
                test_score = accuracy_score(
                    self.y_test, tree.predict(self.X_test))
                results.append({
                    'max_depth': depth,
                    'min_samples_split': min_split,
                    'train_accuracy': train_score,
                    'test_accuracy': test_score,
                    'overfit_gap': train_score - test_score,
                    'n_leaves': tree.get_n_leaves()
                })
        results_df = pd.DataFrame(results)

        # Find best parameters
        best_idx = results_df['test_accuracy'].argmax()
        best_params = results_df.iloc[best_idx]
        print("\nRegularization Parameter Comparison:")
        print("="*50)
        print(f"Best Parameters:")
        print(f" max_depth: {self._format_setting(best_params['max_depth'])}")
        print(f" min_samples_split: "
              f"{self._format_setting(best_params['min_samples_split'])}")
        print(f" Test Accuracy: {best_params['test_accuracy']:.3f}")
        print(f" Overfit Gap: {best_params['overfit_gap']:.3f}")

        # Visualize top results
        top_results = results_df.nlargest(10, 'test_accuracy')
        fig, ax = plt.subplots(figsize=(12, 6))
        x = range(len(top_results))
        width = 0.35
        ax.bar([i - width/2 for i in x], top_results['train_accuracy'],
               width, label='Train', alpha=0.7)
        ax.bar([i + width/2 for i in x], top_results['test_accuracy'],
               width, label='Test', alpha=0.7)
        ax.set_xlabel('Configuration')
        ax.set_ylabel('Accuracy')
        ax.set_title('Top 10 Regularization Configurations')
        ax.set_xticks(x)
        ax.set_xticklabels(
            [f"d={self._format_setting(row['max_depth'])},"
             f"s={self._format_setting(row['min_samples_split'])}"
             for _, row in top_results.iterrows()],
            rotation=45)
        ax.legend()
        ax.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()
        return results_df

    def cost_complexity_pruning(self):
        """Demonstrate minimal cost-complexity pruning.

        Computes the pruning path on the training data, fits one tree per
        effective alpha, plots accuracy, size, depth and impurity against
        alpha, and prints the alpha with the best test accuracy.
        """
        tree = DecisionTreeClassifier(random_state=42)
        path = tree.cost_complexity_pruning_path(self.X_train, self.y_train)
        # Guard against tiny negative alphas from floating-point round-off,
        # which the estimator would reject.
        ccp_alphas = np.clip(path.ccp_alphas, 0.0, None)
        impurities = path.impurities

        # Train trees with different alpha values
        trees = []
        for ccp_alpha in ccp_alphas:
            tree = DecisionTreeClassifier(random_state=42, ccp_alpha=ccp_alpha)
            tree.fit(self.X_train, self.y_train)
            trees.append(tree)

        train_scores = [accuracy_score(self.y_train, tree.predict(self.X_train))
                        for tree in trees]
        test_scores = [accuracy_score(self.y_test, tree.predict(self.X_test))
                       for tree in trees]
        n_nodes = [tree.tree_.node_count for tree in trees]
        depths = [tree.tree_.max_depth for tree in trees]

        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        # Accuracy vs alpha
        axes[0, 0].plot(ccp_alphas, train_scores, marker='o', label='Train')
        axes[0, 0].plot(ccp_alphas, test_scores, marker='s', label='Test')
        axes[0, 0].set_xlabel('Alpha (Cost-Complexity Parameter)')
        axes[0, 0].set_ylabel('Accuracy')
        axes[0, 0].set_title('Accuracy vs Alpha')
        axes[0, 0].legend()
        axes[0, 0].grid(True, alpha=0.3)
        # Number of nodes vs alpha
        axes[0, 1].plot(ccp_alphas, n_nodes, marker='o', color='green')
        axes[0, 1].set_xlabel('Alpha')
        axes[0, 1].set_ylabel('Number of Nodes')
        axes[0, 1].set_title('Tree Size vs Alpha')
        axes[0, 1].grid(True, alpha=0.3)
        # Depth vs alpha
        axes[1, 0].plot(ccp_alphas, depths, marker='o', color='orange')
        axes[1, 0].set_xlabel('Alpha')
        axes[1, 0].set_ylabel('Tree Depth')
        axes[1, 0].set_title('Tree Depth vs Alpha')
        axes[1, 0].grid(True, alpha=0.3)
        # Impurity vs alpha
        axes[1, 1].plot(ccp_alphas, impurities, marker='o', color='red')
        axes[1, 1].set_xlabel('Alpha')
        axes[1, 1].set_ylabel('Total Impurity')
        axes[1, 1].set_title('Impurity vs Alpha')
        axes[1, 1].grid(True, alpha=0.3)
        plt.suptitle('Cost Complexity Pruning Analysis', fontsize=14)
        plt.tight_layout()
        plt.show()

        # Find optimal alpha (first maximum, i.e. the least pruning on ties)
        optimal_idx = np.argmax(test_scores)
        optimal_alpha = ccp_alphas[optimal_idx]
        print(f"\nOptimal Alpha: {optimal_alpha:.6f}")
        print(f"Test Accuracy: {test_scores[optimal_idx]:.3f}")
        print(f"Number of Nodes: {n_nodes[optimal_idx]}")
        print(f"Tree Depth: {depths[optimal_idx]}")
# Analyze pruning and regularization on the synthetic train/test split
# created with make_classification earlier in the file.
print("\n" + "="*60)
print("PRUNING AND REGULARIZATION")
print("="*60)
pruning = TreePruning(X_train, y_train, X_test, y_test)
# Analyze depth impact (uses the method's default depths, range(1, 21))
print("\nAnalyzing impact of tree depth...")
optimal_depth = pruning.analyze_depth_impact()
print(f"Optimal depth: {optimal_depth}")
# Compare regularization parameters; the full result table is kept in
# reg_results.
print("\nComparing regularization parameters...")
reg_results = pruning.compare_regularization_parameters()
# Cost complexity pruning
print("\nDemonstrating cost complexity pruning...")
pruning.cost_complexity_pruning()
Practice Exercises
Exercise 1: Custom Splitting Criterion
Implement a decision tree with a custom splitting criterion:
- Create a weighted Gini impurity that considers class imbalance
- Implement Chi-square test for categorical features
- Add support for handling missing values
- Compare performance with standard criteria
Exercise 2: Tree Ensemble
Build a simple ensemble of decision trees:
- Implement bagging (bootstrap aggregating)
- Create a voting classifier with multiple trees
- Add feature randomness (random subspace)
- Compare with a single deep tree
Exercise 3: Interpretability Tools
Create advanced interpretation tools:
- Extract and visualize all decision rules
- Calculate feature interaction strengths
- Generate natural language explanations
- Create counterfactual explanations
Key Takeaways
- 🌳 Decision trees are intuitive and interpretable
- 📊 Gini impurity and entropy measure node impurity (lower values mean purer nodes)
- ✂️ Pruning prevents overfitting
- 📈 Feature importance shows variable contribution
- 🎯 Trees can handle both classification and regression
- ⚖️ Regularization parameters control tree complexity
- 🔍 Decision paths provide transparent predictions
- 📉 Shallow trees often generalize better
- 🎨 Visualization helps understand model decisions
- ⚡ Trees are fast to train and predict