Skip to main content

Decision Trees

Intuitive and Interpretable Machine Learning! 🌳

Decision trees are one of the most intuitive machine learning algorithms, mimicking human decision-making processes. From their simple if-then structure to their powerful ensemble variants, decision trees form the foundation of many state-of-the-art algorithms. Master tree construction, pruning, and visualization to build interpretable yet powerful models.

Decision Tree Fundamentals

graph TD
    A[Decision Trees] --> B[Tree Structure]
    A --> C[Splitting Criteria]
    A --> D[Tree Types]
    B --> E[Root Node]
    B --> F[Internal Nodes]
    B --> G[Leaf Nodes]
    B --> H[Branches]
    C --> I[Gini Impurity]
    C --> J[Information Gain]
    C --> K[Gain Ratio]
    C --> L[MSE Reduction]
    D --> M[Classification Trees]
    D --> N[Regression Trees]
    style A fill:#f9f,stroke:#333,stroke-width:2px
    style C fill:#bbf,stroke:#333,stroke-width:2px
    style D fill:#fbf,stroke:#333,stroke-width:2px

Decision Tree Implementation from Scratch

Building a Classification Tree

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import make_classification, load_iris, load_wine
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.tree import plot_tree, export_text
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
import warnings
warnings.filterwarnings('ignore')

# Set style
# Global plotting defaults applied to every figure created below: a dark-grid
# matplotlib style sheet plus seaborn's "husl" colour palette.
# NOTE(review): the 'seaborn-v0_8-*' style names appear to require a recent
# matplotlib release — confirm against the version this tutorial targets.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

class DecisionTreeNode:
    """A single node of a binary decision tree.

    Internal nodes describe a split (``feature`` / ``threshold``) and point to
    two children; leaf nodes store a prediction in ``value``.  ``samples`` and
    ``impurity`` are bookkeeping values used when the tree is displayed.
    """

    def __init__(self, feature=None, threshold=None, left=None, right=None,
                 value=None, samples=None, impurity=None):
        # Split description and children (meaningful on internal nodes)
        self.feature, self.threshold = feature, threshold
        self.left, self.right = left, right

        # Prediction (meaningful on leaf nodes)
        self.value = value

        # Statistics kept for visualization
        self.samples, self.impurity = samples, impurity

    def is_leaf(self):
        """Return True when the node stores a prediction, i.e. ``value`` is set."""
        if self.value is None:
            return False
        return True

class DecisionTreeFromScratch:
    """Binary classification tree grown greedily with NumPy.

    Every internal node tests ``feature <= threshold``.  Candidate thresholds
    are the observed feature values, and the candidate with the largest
    impurity reduction is chosen.  Class counts are obtained with
    ``np.bincount`` after an ``int`` cast, so labels must be non-negative
    integers (or floats holding such values).

    Parameters
    ----------
    max_depth : int or None
        Maximum depth of the tree; ``None`` disables the limit.
    min_samples_split : int
        A node with fewer samples than this becomes a leaf.
    min_samples_leaf : int
        Only splits leaving at least this many samples on each side are
        considered.
    criterion : {'gini', 'entropy'}
        Impurity measure.
    random_state : int or None
        When given, seeds NumPy's global random generator.  The algorithm
        itself is deterministic and draws no random numbers.
    """

    def __init__(self, max_depth=5, min_samples_split=2, min_samples_leaf=1,
                 criterion='gini', random_state=None):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.criterion = criterion
        self.random_state = random_state
        self.root = None
        self.n_features_ = None
        self.n_classes_ = None
        self.feature_names = None
        self.class_names = None

        if random_state is not None:
            np.random.seed(random_state)

    @staticmethod
    def _names_or_default(names, prefix, count):
        """Return ``names`` as a list, or ``prefix_0 .. prefix_{count-1}`` if absent.

        Uses an explicit None/empty test instead of ``names or default`` so
        that NumPy arrays (whose truth value is ambiguous) are accepted.
        """
        if names is None or len(names) == 0:
            return [f'{prefix}_{i}' for i in range(count)]
        return list(names)

    def fit(self, X, y, feature_names=None, class_names=None):
        """Build the tree from training data.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
        y : array-like of shape (n_samples,)
            Non-negative integer class labels.
        feature_names, class_names : sequence of str, optional
            Display names used by ``print_tree``; defaults are generated.

        Returns
        -------
        self
        """
        # Accept lists and other array-likes; the builder relies on NumPy
        # column slicing and boolean-mask indexing.
        X = np.asarray(X)
        y = np.asarray(y)

        self.n_features_ = X.shape[1]
        self.n_classes_ = len(np.unique(y))
        self.feature_names = self._names_or_default(
            feature_names, 'Feature', self.n_features_)
        self.class_names = self._names_or_default(
            class_names, 'Class', self.n_classes_)

        # Build tree recursively
        self.root = self._build_tree(X, y, depth=0)
        return self

    def _make_leaf(self, y, n_samples, impurity):
        """Create a leaf predicting the majority class of ``y``."""
        return DecisionTreeNode(value=self._most_common_label(y),
                                samples=n_samples, impurity=impurity)

    def _build_tree(self, X, y, depth):
        """Recursively grow the subtree for the samples ``(X, y)`` at ``depth``."""
        n_samples = X.shape[0]
        n_classes = len(np.unique(y))
        impurity = self._calculate_impurity(y)

        # Stopping criteria: depth limit reached (None = unlimited), node too
        # small to split, or node already pure.
        depth_exhausted = self.max_depth is not None and depth >= self.max_depth
        if (depth_exhausted or
                n_samples < self.min_samples_split or
                n_classes == 1):
            return self._make_leaf(y, n_samples, impurity)

        # _best_split only returns splits that honour min_samples_leaf, so no
        # further size check is needed here.
        best_feature, best_threshold = self._best_split(X, y)
        if best_feature is None:
            return self._make_leaf(y, n_samples, impurity)

        left_indices = X[:, best_feature] <= best_threshold
        right_indices = X[:, best_feature] > best_threshold

        left_child = self._build_tree(X[left_indices], y[left_indices], depth + 1)
        right_child = self._build_tree(X[right_indices], y[right_indices], depth + 1)

        return DecisionTreeNode(feature=best_feature, threshold=best_threshold,
                                left=left_child, right=right_child,
                                samples=n_samples, impurity=impurity)

    def _best_split(self, X, y):
        """Find the admissible split with the largest impurity reduction.

        A split is admissible when both sides receive at least
        ``min_samples_leaf`` samples (and at least one sample).

        Returns
        -------
        (feature_index, threshold), or ``(None, None)`` when no admissible
        split exists.
        """
        best_gain = -1
        best_feature = None
        best_threshold = None

        parent_impurity = self._calculate_impurity(y)
        min_leaf = max(1, self.min_samples_leaf)

        for feature_idx in range(X.shape[1]):
            feature_values = X[:, feature_idx]

            # The largest value sends every sample left, so it can never be a
            # usable threshold; drop it up front.
            for threshold in np.unique(feature_values)[:-1]:
                left_mask = feature_values <= threshold
                right_mask = feature_values > threshold

                # Enforce the leaf-size constraint while searching, so a
                # lopsided best split cannot hide a valid second-best one.
                if np.sum(left_mask) < min_leaf or np.sum(right_mask) < min_leaf:
                    continue

                gain = self._information_gain(y, left_mask, right_mask,
                                              parent_impurity)
                if gain > best_gain:
                    best_gain = gain
                    best_feature = feature_idx
                    best_threshold = threshold

        return best_feature, best_threshold

    def _information_gain(self, y, left_mask, right_mask, parent_impurity):
        """Return the parent impurity minus the size-weighted child impurity."""
        n_total = len(y)
        n_left = np.sum(left_mask)
        n_right = np.sum(right_mask)

        left_impurity = self._calculate_impurity(y[left_mask])
        right_impurity = self._calculate_impurity(y[right_mask])

        weighted_impurity = ((n_left / n_total) * left_impurity +
                             (n_right / n_total) * right_impurity)

        return parent_impurity - weighted_impurity

    def _calculate_impurity(self, y):
        """Return the Gini impurity or entropy of the labels ``y``.

        An empty label set has impurity 0.

        Raises
        ------
        ValueError
            If ``self.criterion`` is neither 'gini' nor 'entropy'.
        """
        if self.criterion not in ('gini', 'entropy'):
            raise ValueError(f"Unknown criterion: {self.criterion}")
        if len(y) == 0:
            return 0.0

        proportions = np.bincount(y.astype(int)) / len(y)

        if self.criterion == 'gini':
            # Gini impurity: 1 - Σ(p_i^2)
            return 1 - np.sum(proportions ** 2)

        # Entropy: -Σ(p_i * log2(p_i)); zero proportions are dropped to avoid
        # log(0).  Adding 0.0 turns the -0.0 of a pure node into 0.0.
        proportions = proportions[proportions > 0]
        return -np.sum(proportions * np.log2(proportions)) + 0.0

    def _most_common_label(self, y):
        """Return the most frequent class label (lowest label wins ties)."""
        return np.bincount(y.astype(int)).argmax()

    def predict(self, X):
        """Predict a class label for every row of ``X``; returns a 1-D array."""
        return np.array([self._predict_sample(sample) for sample in np.asarray(X)])

    def _predict_sample(self, sample):
        """Walk from the root to a leaf for one sample and return its label."""
        node = self.root

        while not node.is_leaf():
            if sample[node.feature] <= node.threshold:
                node = node.left
            else:
                node = node.right

        return node.value

    def print_tree(self, node=None, depth=0):
        """Print the (sub)tree rooted at ``node`` as indented if/else rules.

        ``node=None`` starts at the root.
        """
        if node is None:
            node = self.root

        indent = '  ' * depth
        if node.is_leaf():
            print(f"{indent}Predict: {self.class_names[node.value]} "
                  f"(samples={node.samples}, impurity={node.impurity:.3f})")
        else:
            print(f"{indent}If {self.feature_names[node.feature]} <= {node.threshold:.3f} "
                  f"(samples={node.samples}, impurity={node.impurity:.3f})")
            self.print_tree(node.left, depth + 1)
            print(f"{indent}Else:")
            self.print_tree(node.right, depth + 1)

    def get_depth(self, node=None):
        """Return the number of splits on the longest root-to-leaf path."""
        if node is None:
            node = self.root

        if node.is_leaf():
            return 0

        return 1 + max(self.get_depth(node.left), self.get_depth(node.right))

    def get_n_leaves(self, node=None):
        """Return the number of leaf nodes below (and including) ``node``."""
        if node is None:
            node = self.root

        if node.is_leaf():
            return 1

        return self.get_n_leaves(node.left) + self.get_n_leaves(node.right)

# Synthetic 3-class problem: 200 samples, 4 features of which 3 are informative
X, y = make_classification(
    n_samples=200,
    n_features=4,
    n_informative=3,
    n_redundant=0,
    n_classes=3,
    random_state=42,
)

# Stratified 70/30 train/test split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)

# --- Hand-written tree ---
print("=" * 60, "DECISION TREE FROM SCRATCH", "=" * 60, sep="\n")

tree_custom = DecisionTreeFromScratch(
    max_depth=3, min_samples_split=5, criterion='gini', random_state=42
)
tree_custom.fit(
    X_train,
    y_train,
    feature_names=[f'Feature_{letter}' for letter in 'ABCD'],
    class_names=[f'Class_{idx}' for idx in range(3)],
)

# Evaluate on the held-out split
y_pred_custom = tree_custom.predict(X_test)
accuracy_custom = accuracy_score(y_test, y_pred_custom)

print("\nCustom Tree Statistics:")
print(f"  Depth: {tree_custom.get_depth()}")
print(f"  Number of leaves: {tree_custom.get_n_leaves()}")
print(f"  Test Accuracy: {accuracy_custom:.3f}")

print("\nTree Structure:")
tree_custom.print_tree()

# --- scikit-learn reference with the same hyper-parameters ---
tree_sklearn = DecisionTreeClassifier(
    max_depth=3, min_samples_split=5, criterion='gini', random_state=42
)
tree_sklearn.fit(X_train, y_train)
y_pred_sklearn = tree_sklearn.predict(X_test)
accuracy_sklearn = accuracy_score(y_test, y_pred_sklearn)

print(f"\nScikit-learn Tree Accuracy: {accuracy_sklearn:.3f}")

Understanding Splitting Criteria

Gini Impurity vs Information Gain

class SplittingCriteriaAnalysis:
    """Compute and plot split-quality measures for classification labels.

    Covers Gini impurity, entropy, and the gain ratio of a binary split.
    Labels are cast to ``int`` before counting, so they must be non-negative
    integers (or floats holding such values).
    """

    def __init__(self):
        self.criteria = ['gini', 'entropy']

    @staticmethod
    def _class_proportions(y):
        """Return the relative frequency of each class; empty array for empty input."""
        labels = np.asarray(y)
        if labels.size == 0:
            return np.array([])
        # The int cast lets float-typed label arrays be counted too.
        return np.bincount(labels.astype(int)) / labels.size

    def calculate_gini(self, y):
        """Return the Gini impurity ``1 - Σ p_i²`` of ``y`` (0 for empty input)."""
        proportions = self._class_proportions(y)
        if proportions.size == 0:
            return 0
        return 1 - np.sum(proportions ** 2)

    def calculate_entropy(self, y):
        """Return the entropy ``-Σ p_i log2 p_i`` of ``y`` (0 for empty input)."""
        proportions = self._class_proportions(y)
        if proportions.size == 0:
            return 0
        proportions = proportions[proportions > 0]  # avoid log2(0)
        # Adding 0.0 turns the -0.0 produced for a pure set into 0.0.
        return -np.sum(proportions * np.log2(proportions)) + 0.0

    def calculate_gain_ratio(self, y, left_mask, right_mask):
        """Return the gain ratio (C4.5 criterion) of a binary split.

        Parameters
        ----------
        y : array-like of class labels
        left_mask, right_mask : boolean arrays selecting each side of the split

        Returns
        -------
        Information gain divided by split information, or 0 when the split
        information is zero (one side is empty) or ``y`` is empty.
        """
        y = np.asarray(y)
        n_total = len(y)
        if n_total == 0:
            return 0

        p_left = np.sum(left_mask) / n_total
        p_right = np.sum(right_mask) / n_total

        # Information gain
        children_entropy = (p_left * self.calculate_entropy(y[left_mask]) +
                            p_right * self.calculate_entropy(y[right_mask]))
        info_gain = self.calculate_entropy(y) - children_entropy

        # Split information: entropy of the partition sizes themselves
        split_info = 0
        for share in (p_left, p_right):
            if share > 0:
                split_info -= share * np.log2(share)

        if split_info == 0:
            return 0
        return info_gain / split_info

    def visualize_impurity_functions(self):
        """Plot Gini, entropy and misclassification error for a two-class node."""
        # Stay strictly inside (0, 1) so log2 is defined everywhere.
        p = np.linspace(0.001, 0.999, 100)

        curves = [
            ('Gini Impurity', 2 * p * (1 - p)),
            ('Entropy', -p * np.log2(p) - (1 - p) * np.log2(1 - p)),
            ('Misclassification Error', 1 - np.maximum(p, 1 - p)),
        ]

        fig, ax = plt.subplots(figsize=(10, 6))
        for label, values in curves:
            ax.plot(p, values, label=label, linewidth=2)

        ax.set_xlabel('Proportion of Class 1')
        ax.set_ylabel('Impurity')
        ax.set_title('Comparison of Impurity Measures (Binary Classification)')
        ax.legend()
        ax.grid(True, alpha=0.3)

        # All three measures peak at an even class mix.
        ax.axvline(x=0.5, color='red', linestyle='--', alpha=0.5)
        ax.text(0.5, 0.1, 'Maximum\nImpurity', ha='center', fontsize=9)

        plt.tight_layout()
        plt.show()

    def compare_split_quality(self, X, y, feature_names=None):
        """Score a median split of every feature under three criteria and plot them.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
        y : array-like of class labels
        feature_names : sequence of str, optional
            Row labels; defaults to ``Feature_0``, ``Feature_1``, ...

        Returns
        -------
        pandas.DataFrame with columns 'Feature', 'Threshold', 'Gini Gain',
        'Info Gain' and 'Gain Ratio'.  Features whose median split leaves one
        side empty are omitted; if none remain the frame is empty and no
        figure is drawn.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        n_total = len(y)

        # Parent impurities do not depend on the feature; compute them once.
        parent_gini = self.calculate_gini(y)
        parent_entropy = self.calculate_entropy(y)

        results = []
        for feature_idx in range(X.shape[1]):
            feature_values = X[:, feature_idx]
            threshold = np.median(feature_values)

            left_mask = feature_values <= threshold
            right_mask = feature_values > threshold
            n_left, n_right = np.sum(left_mask), np.sum(right_mask)
            if n_left == 0 or n_right == 0:
                continue

            w_left, w_right = n_left / n_total, n_right / n_total

            gini_gain = parent_gini - (
                w_left * self.calculate_gini(y[left_mask]) +
                w_right * self.calculate_gini(y[right_mask]))
            info_gain = parent_entropy - (
                w_left * self.calculate_entropy(y[left_mask]) +
                w_right * self.calculate_entropy(y[right_mask]))

            if feature_names is None:
                name = f'Feature_{feature_idx}'
            else:
                name = feature_names[feature_idx]

            results.append({
                'Feature': name,
                'Threshold': threshold,
                'Gini Gain': gini_gain,
                'Info Gain': info_gain,
                'Gain Ratio': self.calculate_gain_ratio(y, left_mask, right_mask)
            })

        columns = ['Feature', 'Threshold', 'Gini Gain', 'Info Gain', 'Gain Ratio']
        results_df = pd.DataFrame(results, columns=columns)

        if results_df.empty:
            # Without this guard the column look-ups below would fail.
            print("No feature produced a valid median split; nothing to plot.")
            return results_df

        # One panel per criterion: (column, axis label, bar colour).
        # colour None keeps matplotlib's default for the first panel.
        panels = [
            ('Gini Gain', 'Gini Gain', None),
            ('Info Gain', 'Information Gain', 'orange'),
            ('Gain Ratio', 'Gain Ratio', 'green'),
        ]
        positions = range(len(results_df))

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        for ax, (column, label, color) in zip(axes, panels):
            ax.bar(positions, results_df[column], alpha=0.7, color=color)
            ax.set_xlabel('Feature')
            ax.set_ylabel(label)
            ax.set_title(f'{label} by Feature')
            ax.set_xticks(positions)
            ax.set_xticklabels(results_df['Feature'], rotation=45)
            ax.grid(True, alpha=0.3)

        plt.suptitle('Comparison of Splitting Criteria', fontsize=14)
        plt.tight_layout()
        plt.show()

        return results_df

# Splitting-criteria walkthrough on the training split created above
print("", "=" * 60, "SPLITTING CRITERIA ANALYSIS", "=" * 60, sep="\n")

criteria_analysis = SplittingCriteriaAnalysis()

# Impurity curves for the two-class case
print("\nVisualizing impurity functions...")
criteria_analysis.visualize_impurity_functions()

# Per-feature split scores, shown as a plain-text table
print("\nComparing split quality metrics...")
split_comparison = criteria_analysis.compare_split_quality(X_train, y_train)
print("\nSplit Quality Comparison:",
      split_comparison.to_string(index=False),
      sep="\n")

Tree Visualization and Interpretation

Visualizing Decision Trees

# Iris data restricted to columns 2 and 3 so decision regions can be drawn in 2-D
iris = load_iris()
X_iris = iris.data[:, [2, 3]]
y_iris = iris.target

# Stratified 70/30 train/test split
X_train_iris, X_test_iris, y_train_iris, y_test_iris = train_test_split(
    X_iris,
    y_iris,
    test_size=0.3,
    random_state=42,
    stratify=y_iris,
)

class TreeVisualization:
    """Plotting and inspection helpers for a fitted scikit-learn decision tree.

    Parameters
    ----------
    tree_model : fitted tree estimator
        Used through ``predict``, ``feature_importances_``, ``decision_path``,
        ``apply`` and ``tree_``.
    X : array of shape (n_samples, n_features)
        Data shown in the plots.
    y : array of shape (n_samples,)
        Class labels used to colour the points.
    feature_names, class_names : sequence of str, optional
        Display names (lists, tuples or NumPy arrays); defaults are generated
        as ``Feature_i`` / ``Class_i``.
    """

    def __init__(self, tree_model, X, y, feature_names=None, class_names=None):
        self.tree = tree_model
        self.X = X
        self.y = y
        self.feature_names = self._resolve_names(
            feature_names, 'Feature', X.shape[1])
        self.class_names = self._resolve_names(
            class_names, 'Class', len(np.unique(y)))

    @staticmethod
    def _resolve_names(names, prefix, count):
        """Return ``names`` as a list, or generated defaults when absent.

        ``names or default`` would raise ValueError for a multi-element NumPy
        array (ambiguous truth value), so test None/empty explicitly.
        """
        if names is None or len(names) == 0:
            return [f'{prefix}_{i}' for i in range(count)]
        return list(names)

    def plot_tree_structure(self):
        """Draw the tree with sklearn's ``plot_tree``."""
        fig, ax = plt.subplots(figsize=(20, 10))

        plot_tree(self.tree,
                  feature_names=self.feature_names,
                  class_names=self.class_names,
                  filled=True,
                  rounded=True,
                  fontsize=10,
                  ax=ax)

        plt.title('Decision Tree Structure', fontsize=14)
        plt.tight_layout()
        plt.show()

    def plot_decision_boundary(self):
        """Plot the predicted class regions; only available for 2 features."""
        if self.X.shape[1] != 2:
            print("Decision boundary plot requires exactly 2 features")
            return

        # Evaluation grid covering the data with a margin of 1 on each side
        h = 0.02
        x_min, x_max = self.X[:, 0].min() - 1, self.X[:, 0].max() + 1
        y_min, y_max = self.X[:, 1].min() - 1, self.X[:, 1].max() + 1
        xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                             np.arange(y_min, y_max, h))

        # Predict every grid point, then restore the grid shape
        Z = self.tree.predict(np.c_[xx.ravel(), yy.ravel()])
        Z = Z.reshape(xx.shape)

        fig, ax = plt.subplots(figsize=(10, 8))

        # Decision regions
        ax.contourf(xx, yy, Z, alpha=0.4, cmap='viridis')

        # Data points
        scatter = ax.scatter(self.X[:, 0], self.X[:, 1], c=self.y,
                             cmap='viridis', edgecolor='black', s=50)

        ax.set_xlabel(self.feature_names[0])
        ax.set_ylabel(self.feature_names[1])
        ax.set_title('Decision Tree Decision Boundary')

        plt.colorbar(scatter, ax=ax, label='Class')

        plt.tight_layout()
        plt.show()

    def plot_feature_importance(self):
        """Plot feature importances with a cumulative line.

        Returns
        -------
        pandas.DataFrame with columns 'Feature' and 'Importance', sorted by
        descending importance.
        """
        importances = self.tree.feature_importances_
        indices = np.argsort(importances)[::-1]  # most important first
        positions = range(len(importances))

        fig, ax = plt.subplots(figsize=(10, 6))

        ax.bar(positions, importances[indices], alpha=0.7)
        ax.set_xlabel('Feature')
        ax.set_ylabel('Importance')
        ax.set_title('Feature Importances')
        ax.set_xticks(positions)
        ax.set_xticklabels([self.feature_names[i] for i in indices], rotation=45)
        ax.grid(True, alpha=0.3)

        # Cumulative importance on a secondary axis, with a 95% guide line
        cumsum = np.cumsum(importances[indices])
        ax2 = ax.twinx()
        ax2.plot(positions, cumsum, 'r-', marker='o', label='Cumulative')
        ax2.set_ylabel('Cumulative Importance', color='r')
        ax2.tick_params(axis='y', labelcolor='r')
        ax2.axhline(y=0.95, color='r', linestyle='--', alpha=0.5)
        # The line carries a label, so draw the legend that displays it.
        ax2.legend(loc='center right')

        plt.tight_layout()
        plt.show()

        return pd.DataFrame({
            'Feature': self.feature_names,
            'Importance': importances
        }).sort_values('Importance', ascending=False)

    def get_decision_path(self, sample):
        """Print the sequence of tests one sample passes through, ending at its leaf."""
        decision_path = self.tree.decision_path([sample])
        leaf = self.tree.apply([sample])[0]

        feature = self.tree.tree_.feature
        threshold = self.tree.tree_.threshold

        # Ids of the nodes visited by this (single) sample
        node_indicator = decision_path.toarray()[0]
        node_index = np.where(node_indicator)[0]

        print(f"Decision path for sample: {sample}")
        print("="*50)

        for node_id in node_index:
            if leaf == node_id:
                predicted = np.argmax(self.tree.tree_.value[node_id])
                print(f"→ Leaf node {node_id}: Predict {self.class_names[predicted]}")
            else:
                if sample[feature[node_id]] <= threshold[node_id]:
                    threshold_sign = "<="
                else:
                    threshold_sign = ">"

                print(f"→ Node {node_id}: "
                      f"{self.feature_names[feature[node_id]]} "
                      f"{threshold_sign} {threshold[node_id]:.3f}")

# Create and train tree for visualization
tree_vis = DecisionTreeClassifier(max_depth=4, random_state=42)
tree_vis.fit(X_train_iris, y_train_iris)

# Visualize tree
print("\n" + "="*60)
print("TREE VISUALIZATION")
print("="*60)

# iris.target_names is a NumPy array; pass a plain list so any truthiness
# check on the names cannot fail with "truth value of an array is ambiguous".
visualizer = TreeVisualization(tree_vis, X_train_iris, y_train_iris,
                              feature_names=['Petal Length', 'Petal Width'],
                              class_names=list(iris.target_names))

# Plot tree structure
print("\nPlotting tree structure...")
visualizer.plot_tree_structure()

# Plot decision boundary
print("\nPlotting decision boundary...")
visualizer.plot_decision_boundary()

# Feature importance (returned as a table sorted by importance)
print("\nAnalyzing feature importance...")
importance_df = visualizer.plot_feature_importance()
print("\nFeature Importance:")
print(importance_df.to_string(index=False))

# Decision path for the first held-out sample
print("\nTracing decision path for a sample...")
sample = X_test_iris[0]
visualizer.get_decision_path(sample)

Tree Pruning and Regularization

Preventing Overfitting

class TreePruning:
    """Experiments on regularizing a scikit-learn decision tree.

    Compares depth limits, ``min_samples_split`` settings and cost-complexity
    pruning by fitting trees on the training split and scoring them on both
    splits.

    NOTE(review): every "optimal" value reported here is selected on the test
    split, so the reported test accuracy is optimistic; use a validation split
    or cross-validation for real model selection.
    """

    def __init__(self, X_train, y_train, X_test, y_test):
        self.X_train = X_train
        self.y_train = y_train
        self.X_test = X_test
        self.y_test = y_test

    @staticmethod
    def _format_depth(depth):
        """Render a ``max_depth`` value for display.

        Inside a DataFrame ``None`` becomes NaN and integers become floats;
        map those back to ``'None'`` and a plain integer string.
        """
        if depth is None or depth != depth:  # NaN is the only value != itself
            return 'None'
        return str(int(depth))

    def _fit_and_score(self, **tree_params):
        """Fit a tree (random_state=42) and return (tree, train_acc, test_acc)."""
        tree = DecisionTreeClassifier(random_state=42, **tree_params)
        tree.fit(self.X_train, self.y_train)
        train_score = accuracy_score(self.y_train, tree.predict(self.X_train))
        test_score = accuracy_score(self.y_test, tree.predict(self.X_test))
        return tree, train_score, test_score

    def analyze_depth_impact(self, max_depths=range(1, 21)):
        """Plot accuracy and leaf count against ``max_depth``.

        Parameters
        ----------
        max_depths : iterable of int
            Depth limits to evaluate.

        Returns
        -------
        The depth with the highest test accuracy (first one on ties).

        Raises
        ------
        ValueError
            If ``max_depths`` is empty.
        """
        # Materialize so generators work and the result can be indexed.
        max_depths = list(max_depths)
        if not max_depths:
            raise ValueError("max_depths must contain at least one depth")

        train_scores = []
        test_scores = []
        n_leaves = []

        for depth in max_depths:
            tree, train_score, test_score = self._fit_and_score(max_depth=depth)
            train_scores.append(train_score)
            test_scores.append(test_score)
            n_leaves.append(tree.get_n_leaves())

        optimal_depth = max_depths[int(np.argmax(test_scores))]

        fig, axes = plt.subplots(1, 2, figsize=(12, 5))

        # Accuracy vs Depth
        axes[0].plot(max_depths, train_scores, 'o-', label='Train', linewidth=2)
        axes[0].plot(max_depths, test_scores, 's-', label='Test', linewidth=2)
        axes[0].axvline(x=optimal_depth, color='red', linestyle='--',
                        label=f'Optimal ({optimal_depth})')
        axes[0].set_xlabel('Max Depth')
        axes[0].set_ylabel('Accuracy')
        axes[0].set_title('Accuracy vs Tree Depth')
        # Build the legend only after every labelled artist exists, otherwise
        # the 'Optimal' marker is missing from it.
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)

        # Number of leaves vs Depth
        axes[1].plot(max_depths, n_leaves, 'o-', color='green', linewidth=2)
        axes[1].set_xlabel('Max Depth')
        axes[1].set_ylabel('Number of Leaves')
        axes[1].set_title('Tree Complexity vs Depth')
        axes[1].grid(True, alpha=0.3)

        plt.suptitle('Tree Depth Analysis', fontsize=14)
        plt.tight_layout()
        plt.show()

        return optimal_depth

    def compare_regularization_parameters(self):
        """Grid over ``max_depth`` x ``min_samples_split`` and report the results.

        Prints the configuration with the highest test accuracy, plots the ten
        best configurations, and returns a DataFrame with one row per
        configuration (columns: max_depth, min_samples_split, train_accuracy,
        test_accuracy, overfit_gap, n_leaves).
        """
        # Only these two parameters are varied.
        params = {
            'max_depth': [None, 3, 5, 10, 15],
            'min_samples_split': [2, 5, 10, 20],
        }

        results = []
        for depth in params['max_depth']:
            for min_split in params['min_samples_split']:
                tree, train_score, test_score = self._fit_and_score(
                    max_depth=depth, min_samples_split=min_split)

                results.append({
                    'max_depth': depth,
                    'min_samples_split': min_split,
                    'train_accuracy': train_score,
                    'test_accuracy': test_score,
                    'overfit_gap': train_score - test_score,
                    'n_leaves': tree.get_n_leaves()
                })

        results_df = pd.DataFrame(results)

        # Find best parameters
        best_idx = results_df['test_accuracy'].argmax()
        best_params = results_df.iloc[best_idx]

        # A row taken from the frame is upcast to float; convert back for
        # display so the output reads "None" / "5" rather than "nan" / "5.0".
        print("\nRegularization Parameter Comparison:")
        print("="*50)
        print("Best Parameters:")
        print(f"  max_depth: {self._format_depth(best_params['max_depth'])}")
        print(f"  min_samples_split: {int(best_params['min_samples_split'])}")
        print(f"  Test Accuracy: {best_params['test_accuracy']:.3f}")
        print(f"  Overfit Gap: {best_params['overfit_gap']:.3f}")

        # Visualize top results
        top_results = results_df.nlargest(10, 'test_accuracy')

        fig, ax = plt.subplots(figsize=(12, 6))

        x = range(len(top_results))
        width = 0.35

        ax.bar([i - width/2 for i in x], top_results['train_accuracy'],
               width, label='Train', alpha=0.7)
        ax.bar([i + width/2 for i in x], top_results['test_accuracy'],
               width, label='Test', alpha=0.7)

        tick_labels = [
            f"d={self._format_depth(row['max_depth'])},"
            f"s={int(row['min_samples_split'])}"
            for _, row in top_results.iterrows()
        ]

        ax.set_xlabel('Configuration')
        ax.set_ylabel('Accuracy')
        ax.set_title('Top 10 Regularization Configurations')
        ax.set_xticks(x)
        ax.set_xticklabels(tick_labels, rotation=45)
        ax.legend()
        ax.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

        return results_df

    def cost_complexity_pruning(self):
        """Demonstrate minimal cost-complexity pruning.

        Computes the pruning path on the training split, fits one tree per
        alpha, plots accuracy / size / depth / impurity against alpha, and
        prints the alpha with the highest test accuracy.
        """
        # Get cost complexity pruning path
        tree = DecisionTreeClassifier(random_state=42)
        path = tree.cost_complexity_pruning_path(self.X_train, self.y_train)
        impurities = path.impurities
        # NOTE(review): round-off can reportedly produce alphas marginally
        # below zero, which the estimator's ccp_alpha validation rejects;
        # clamping them to 0 is harmless either way.
        ccp_alphas = [max(alpha, 0.0) for alpha in path.ccp_alphas]

        # Train one tree per alpha and collect its statistics
        trees = []
        train_scores = []
        test_scores = []
        for ccp_alpha in ccp_alphas:
            fitted, train_score, test_score = self._fit_and_score(
                ccp_alpha=ccp_alpha)
            trees.append(fitted)
            train_scores.append(train_score)
            test_scores.append(test_score)

        n_nodes = [fitted.tree_.node_count for fitted in trees]
        depths = [fitted.tree_.max_depth for fitted in trees]

        fig, axes = plt.subplots(2, 2, figsize=(12, 10))

        # Accuracy vs alpha
        axes[0, 0].plot(ccp_alphas, train_scores, marker='o', label='Train')
        axes[0, 0].plot(ccp_alphas, test_scores, marker='s', label='Test')
        axes[0, 0].set_xlabel('Alpha (Cost-Complexity Parameter)')
        axes[0, 0].set_ylabel('Accuracy')
        axes[0, 0].set_title('Accuracy vs Alpha')
        axes[0, 0].legend()
        axes[0, 0].grid(True, alpha=0.3)

        # Remaining panels share one layout: (axis, values, colour, ylabel, title)
        panels = [
            (axes[0, 1], n_nodes, 'green', 'Number of Nodes', 'Tree Size vs Alpha'),
            (axes[1, 0], depths, 'orange', 'Tree Depth', 'Tree Depth vs Alpha'),
            (axes[1, 1], impurities, 'red', 'Total Impurity', 'Impurity vs Alpha'),
        ]
        for ax, values, color, ylabel, title in panels:
            ax.plot(ccp_alphas, values, marker='o', color=color)
            ax.set_xlabel('Alpha')
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            ax.grid(True, alpha=0.3)

        plt.suptitle('Cost Complexity Pruning Analysis', fontsize=14)
        plt.tight_layout()
        plt.show()

        # Find optimal alpha
        optimal_idx = int(np.argmax(test_scores))
        optimal_alpha = ccp_alphas[optimal_idx]

        print(f"\nOptimal Alpha: {optimal_alpha:.6f}")
        print(f"Test Accuracy: {test_scores[optimal_idx]:.3f}")
        print(f"Number of Nodes: {n_nodes[optimal_idx]}")
        print(f"Tree Depth: {depths[optimal_idx]}")

# Pruning / regularization walkthrough on the synthetic train/test split
print("", "=" * 60, "PRUNING AND REGULARIZATION", "=" * 60, sep="\n")

pruning = TreePruning(X_train, y_train, X_test, y_test)

# 1) Effect of the depth limit
print("\nAnalyzing impact of tree depth...")
optimal_depth = pruning.analyze_depth_impact()
print(f"Optimal depth: {optimal_depth}")

# 2) Grid over regularization parameters
print("\nComparing regularization parameters...")
reg_results = pruning.compare_regularization_parameters()

# 3) Cost-complexity pruning path
print("\nDemonstrating cost complexity pruning...")
pruning.cost_complexity_pruning()

Practice Exercises

Exercise 1: Custom Splitting Criterion

Implement a decision tree with a custom splitting criterion:

  1. Create a weighted Gini impurity that considers class imbalance
  2. Implement Chi-square test for categorical features
  3. Add support for handling missing values
  4. Compare performance with standard criteria

Exercise 2: Tree Ensemble

Build a simple ensemble of decision trees:

  1. Implement bagging (bootstrap aggregating)
  2. Create a voting classifier with multiple trees
  3. Add feature randomness (random subspace)
  4. Compare with a single deep tree

Exercise 3: Interpretability Tools

Create advanced interpretation tools:

  1. Extract and visualize all decision rules
  2. Calculate feature interaction strengths
  3. Generate natural language explanations
  4. Create counterfactual explanations

Key Takeaways

Further Resources