Random Forests in Depth
Comprehensive guide to Random Forests: theory, implementation, tuning, and interpretation.
What are Random Forests?
A Random Forest is an ensemble learning method that builds many decision trees and combines their predictions, which reduces the variance of any single tree.
Key Concepts:
- Bagging (Bootstrap Aggregating): train each tree on a bootstrap sample of the training data (random sampling with replacement)
- Feature Randomness: each split considers only a random subset of the features
- Ensemble: combine predictions by majority vote (classification) or averaging (regression); see the from-scratch sketch after this list
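To make these three ideas concrete, here is a minimal from-scratch sketch (an illustration of the idea, not how scikit-learn implements it): each tree is fit on a bootstrap sample, `max_features='sqrt'` gives every split a random feature subset, and the forest predicts by majority vote. The toy dataset and variable names are assumptions for this example only.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=500, n_features=10, random_state=0)
rng = np.random.RandomState(0)

trees = []
for _ in range(25):
    # Bagging: draw a bootstrap sample (sampling with replacement)
    idx = rng.randint(0, len(X), size=len(X))
    # Feature randomness: each split considers a random subset of features
    tree = DecisionTreeClassifier(max_features='sqrt', random_state=rng)
    tree.fit(X[idx], y[idx])
    trees.append(tree)

# Ensemble: majority vote over the individual tree predictions
all_votes = np.array([t.predict(X) for t in trees])  # shape (n_trees, n_samples)
majority_vote = np.array([np.bincount(col).argmax() for col in all_votes.T])

print("Training accuracy of the hand-rolled forest:", (majority_vote == y).mean())
```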
Basic Usage
Classification
```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report

# Generate sample data
X, y = make_classification(n_samples=1000, n_features=20, n_informative=15,
                           n_redundant=5, random_state=42)

# Split data
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Train Random Forest
rf = RandomForestClassifier(
    n_estimators=100,       # Number of trees
    max_depth=None,         # Maximum depth (None = unlimited)
    min_samples_split=2,    # Minimum samples to split a node
    min_samples_leaf=1,     # Minimum samples in a leaf
    max_features='sqrt',    # Features considered per split
    bootstrap=True,         # Use bootstrap samples
    random_state=42,
    n_jobs=-1               # Use all CPU cores
)

rf.fit(X_train, y_train)

# Predict
y_pred = rf.predict(X_test)
y_pred_proba = rf.predict_proba(X_test)

# Evaluate
print(f"Accuracy: {accuracy_score(y_test, y_pred):.3f}")
print("\nClassification Report:")
print(classification_report(y_test, y_pred))
```
Regression
```python
from sklearn.ensemble import RandomForestRegressor
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
import numpy as np

# Generate sample data
X, y = make_regression(n_samples=1000, n_features=20, noise=0.1, random_state=42)

# Split data
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Train Random Forest Regressor
rf = RandomForestRegressor(
    n_estimators=100,
    max_depth=None,
    min_samples_split=2,
    min_samples_leaf=1,
    max_features='sqrt',
    bootstrap=True,
    random_state=42,
    n_jobs=-1
)

rf.fit(X_train, y_train)

# Predict
y_pred = rf.predict(X_test)

# Evaluate
mse = mean_squared_error(y_test, y_pred)
rmse = np.sqrt(mse)
r2 = r2_score(y_test, y_pred)

print(f"RMSE: {rmse:.3f}")
print(f"R²: {r2:.3f}")
```
Hyperparameter Tuning
Key Hyperparameters
```python
# Number of trees
n_estimators = [50, 100, 200, 500]

# Maximum depth of trees
max_depth = [10, 20, 30, None]

# Minimum samples required to split a node
min_samples_split = [2, 5, 10]

# Minimum samples required at a leaf node
min_samples_leaf = [1, 2, 4]

# Number of features to consider at each split
max_features = ['sqrt', 'log2', None]  # 'sqrt' = sqrt(n_features), None = all features

# Whether to use bootstrap samples
bootstrap = [True, False]

# Split quality criterion
criterion_classification = ['gini', 'entropy']
criterion_regression = ['squared_error', 'absolute_error']
```
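Before running a full search over grids like the one above, it can help to see how a single knob behaves on its own. The sketch below is an optional illustration, assuming the `X_train`/`y_train` split from the classification example; it uses `validation_curve` to plot cross-validated accuracy against `n_estimators`.

```python
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import validation_curve

n_estimators_range = [10, 50, 100, 200, 400]

# Cross-validated accuracy for each value of n_estimators
train_scores, val_scores = validation_curve(
    RandomForestClassifier(random_state=42, n_jobs=-1),
    X_train, y_train,
    param_name='n_estimators',
    param_range=n_estimators_range,
    cv=5,
    scoring='accuracy',
    n_jobs=-1
)

plt.plot(n_estimators_range, train_scores.mean(axis=1), 'o-', label='Train')
plt.plot(n_estimators_range, val_scores.mean(axis=1), 'o-', label='Cross-validation')
plt.xlabel('n_estimators')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
```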
Grid Search
```python
from sklearn.model_selection import GridSearchCV

param_grid = {
    'n_estimators': [100, 200, 300],
    'max_depth': [10, 20, 30, None],
    'min_samples_split': [2, 5, 10],
    'min_samples_leaf': [1, 2, 4],
    'max_features': ['sqrt', 'log2'],
    'bootstrap': [True, False]
}

rf = RandomForestClassifier(random_state=42, n_jobs=-1)

grid_search = GridSearchCV(
    estimator=rf,
    param_grid=param_grid,
    cv=5,
    scoring='accuracy',
    n_jobs=-1,
    verbose=2
)

grid_search.fit(X_train, y_train)

print("Best parameters:", grid_search.best_params_)
print("Best score:", grid_search.best_score_)

# Use the best model
best_rf = grid_search.best_estimator_
```
Random Search (Faster)
```python
from sklearn.model_selection import RandomizedSearchCV
from scipy.stats import randint

param_dist = {
    'n_estimators': randint(100, 500),
    'max_depth': [10, 20, 30, 40, None],
    'min_samples_split': randint(2, 20),
    'min_samples_leaf': randint(1, 10),
    'max_features': ['sqrt', 'log2', None],
    'bootstrap': [True, False]
}

rf = RandomForestClassifier(random_state=42, n_jobs=-1)

random_search = RandomizedSearchCV(
    estimator=rf,
    param_distributions=param_dist,
    n_iter=100,
    cv=5,
    scoring='accuracy',
    n_jobs=-1,
    random_state=42,
    verbose=2
)

random_search.fit(X_train, y_train)

print("Best parameters:", random_search.best_params_)
print("Best score:", random_search.best_score_)
```
Feature Importance
Basic Feature Importance
```python
import matplotlib.pyplot as plt
import pandas as pd

# Get feature importances
importances = rf.feature_importances_
indices = np.argsort(importances)[::-1]

# Print feature ranking
print("Feature ranking:")
for i, idx in enumerate(indices):
    print(f"{i + 1}. Feature {idx} ({importances[idx]:.4f})")

# Plot
plt.figure(figsize=(10, 6))
plt.title("Feature Importances")
plt.bar(range(X_train.shape[1]), importances[indices])
plt.xticks(range(X_train.shape[1]), indices)
plt.xlabel("Feature Index")
plt.ylabel("Importance")
plt.tight_layout()
plt.show()

# With feature names
if isinstance(X_train, pd.DataFrame):
    feature_importance_df = pd.DataFrame({
        'feature': X_train.columns,
        'importance': importances
    }).sort_values('importance', ascending=False)

    print(feature_importance_df)

    plt.figure(figsize=(10, 6))
    plt.barh(feature_importance_df['feature'][:20],
             feature_importance_df['importance'][:20])
    plt.xlabel("Importance")
    plt.title("Top 20 Feature Importances")
    plt.tight_layout()
    plt.show()
```
Permutation Importance
Permutation importance is usually more reliable than the default impurity-based importances, which are computed on the training data and tend to overrate high-cardinality features.
```python
from sklearn.inspection import permutation_importance

# Calculate permutation importance on the held-out test set
perm_importance = permutation_importance(
    rf, X_test, y_test,
    n_repeats=10,
    random_state=42,
    n_jobs=-1
)

# Sort features by mean importance
sorted_idx = perm_importance.importances_mean.argsort()[::-1]

# Plot the distribution of importances across repeats
plt.figure(figsize=(10, 6))
plt.boxplot(perm_importance.importances[sorted_idx].T,
            labels=np.array(range(X_test.shape[1]))[sorted_idx],
            vert=False)
plt.xlabel("Permutation Importance")
plt.title("Permutation Feature Importance")
plt.tight_layout()
plt.show()
```
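To see how the two measures can disagree, the impurity-based and permutation rankings can be placed side by side. This small sketch reuses `rf`, `X_test`, and `perm_importance` from the snippets above.

```python
import pandas as pd

# One row per feature, comparing the two importance measures
comparison = pd.DataFrame({
    'impurity_importance': rf.feature_importances_,
    'permutation_importance': perm_importance.importances_mean,
}, index=[f"feature_{i}" for i in range(X_test.shape[1])])

# Features that rank high on impurity but near zero on permutation importance
# are likely overrated by the default measure
print(comparison.sort_values('permutation_importance', ascending=False))
```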
Model Interpretation
Partial Dependence Plots
```python
from sklearn.inspection import PartialDependenceDisplay

# Plot partial dependence for top features
features = [0, 1, (0, 1)]  # Single features and a two-way interaction
fig, ax = plt.subplots(figsize=(12, 4))
PartialDependenceDisplay.from_estimator(
    rf, X_train, features, ax=ax
)
plt.tight_layout()
plt.show()
```
SHAP Values
```python
import shap

# Create explainer
explainer = shap.TreeExplainer(rf)

# Calculate SHAP values
# Note: for classifiers, older shap releases return a list with one array
# per class (assumed below); newer releases may return a single 3-D array.
shap_values = explainer.shap_values(X_test)

# Summary plot
shap.summary_plot(shap_values, X_test)

# Force plot for a single prediction (class 1, first test sample)
shap.force_plot(
    explainer.expected_value[1],
    shap_values[1][0],
    X_test[0]
)

# Dependence plot for feature 0 (class 1 SHAP values)
shap.dependence_plot(0, shap_values[1], X_test)
```
Out-of-Bag (OOB) Score
Each tree in a Random Forest is trained on a bootstrap sample, so roughly one-third of the training data is "out of bag" for that tree. Scoring each sample with only the trees that never saw it gives an estimate of generalization error without a separate validation set or cross-validation.
```python
# Enable OOB score
rf = RandomForestClassifier(
    n_estimators=100,
    oob_score=True,    # Enable OOB scoring
    bootstrap=True,    # Required for OOB
    random_state=42,
    n_jobs=-1
)

rf.fit(X_train, y_train)

# OOB score (similar to a cross-validation score)
print(f"OOB Score: {rf.oob_score_:.3f}")

# OOB class-probability estimates for each training sample
oob_pred = rf.oob_decision_function_
```
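Because the OOB estimate comes almost for free, it is a convenient way to judge how many trees are enough. The following sketch is an optional illustration, reusing the training split from above: it grows the forest incrementally with `warm_start=True` (also discussed under Speed Optimization) and records the OOB score as trees are added.

```python
import matplotlib.pyplot as plt

rf = RandomForestClassifier(
    oob_score=True,
    bootstrap=True,
    warm_start=True,   # keep existing trees and add more on each fit
    random_state=42,
    n_jobs=-1
)

tree_counts = [25, 50, 100, 150, 200, 300]
oob_scores = []

for n in tree_counts:
    rf.n_estimators = n
    rf.fit(X_train, y_train)   # only the newly added trees are trained
    oob_scores.append(rf.oob_score_)

plt.plot(tree_counts, oob_scores, 'o-')
plt.xlabel('Number of trees')
plt.ylabel('OOB score')
plt.show()
```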
Handling Imbalanced Data
Class Weights
```python
from sklearn.utils import class_weight

# Compute class weights (inverse-frequency weighting)
class_weights = class_weight.compute_class_weight(
    'balanced',
    classes=np.unique(y_train),
    y=y_train
)

# Use in the model: 'balanced' computes the same weights automatically,
# or pass a dict such as dict(zip(np.unique(y_train), class_weights))
rf = RandomForestClassifier(
    n_estimators=100,
    class_weight='balanced',
    random_state=42
)

rf.fit(X_train, y_train)
```
Balanced Random Forest
```python
from imblearn.ensemble import BalancedRandomForestClassifier  # pip install imbalanced-learn

# Under-samples each bootstrap sample so every tree sees balanced classes
brf = BalancedRandomForestClassifier(
    n_estimators=100,
    sampling_strategy='auto',  # Balance all classes
    replacement=True,
    random_state=42,
    n_jobs=-1
)

brf.fit(X_train, y_train)
```
Ensemble Methods with Random Forests
Stacking
```python
from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC

# Base models
estimators = [
    ('rf', RandomForestClassifier(n_estimators=100, random_state=42)),
    ('svm', SVC(probability=True, random_state=42))
]

# Stacking with logistic regression as the meta-model
stacking = StackingClassifier(
    estimators=estimators,
    final_estimator=LogisticRegression(),
    cv=5
)

stacking.fit(X_train, y_train)
y_pred = stacking.predict(X_test)
```
Voting
```python
from sklearn.ensemble import VotingClassifier

# Create voting classifier
voting = VotingClassifier(
    estimators=[
        ('rf1', RandomForestClassifier(n_estimators=100, max_depth=10)),
        ('rf2', RandomForestClassifier(n_estimators=200, max_depth=20)),
        ('rf3', RandomForestClassifier(n_estimators=300, max_depth=None))
    ],
    voting='soft'  # 'hard' for majority vote, 'soft' for probability averaging
)

voting.fit(X_train, y_train)
y_pred = voting.predict(X_test)
```
Advanced Techniques
Extremely Randomized Trees (Extra Trees)
Extra Trees draw split thresholds at random instead of searching for the best one, so training is faster and the extra randomness sometimes improves generalization.
```python
from sklearn.ensemble import ExtraTreesClassifier

# Extra Trees: more randomness in the splits
et = ExtraTreesClassifier(
    n_estimators=100,
    max_features='sqrt',
    random_state=42,
    n_jobs=-1
)

et.fit(X_train, y_train)
```
Isolation Forest (Anomaly Detection)
```python
from sklearn.ensemble import IsolationForest

# Detect anomalies
iso_forest = IsolationForest(
    n_estimators=100,
    contamination=0.1,  # Expected proportion of outliers
    random_state=42
)

# Fit and predict (-1 for outliers, 1 for inliers)
predictions = iso_forest.fit_predict(X)
anomalies = X[predictions == -1]
```
Quantile Regression Forest
Approximate prediction intervals by taking percentiles over the predictions of the individual trees (a simple stand-in for a full quantile regression forest).
```python
from sklearn.ensemble import RandomForestRegressor
import numpy as np

# Train a regression forest
rf = RandomForestRegressor(n_estimators=100, random_state=42)
rf.fit(X_train, y_train)

# Get predictions from every individual tree
all_predictions = np.array([tree.predict(X_test) for tree in rf.estimators_])

# Calculate per-sample quantiles across trees
lower_bound = np.percentile(all_predictions, 5, axis=0)
upper_bound = np.percentile(all_predictions, 95, axis=0)
median = np.percentile(all_predictions, 50, axis=0)

# Plot prediction intervals against the actual values
order = np.argsort(y_test)
plt.figure(figsize=(10, 6))
plt.scatter(y_test[order], median[order], alpha=0.5, label='Predictions')
plt.fill_between(y_test[order], lower_bound[order], upper_bound[order],
                 alpha=0.2, label='90% Prediction Interval')
plt.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()],
         'r--', label='Perfect Prediction')
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.legend()
plt.show()
```
Optimization Tips
Memory Optimization
```python
# Reduce memory usage
rf = RandomForestClassifier(
    n_estimators=100,
    max_depth=10,          # Limit tree depth
    min_samples_leaf=5,    # Increase minimum leaf size
    max_features='sqrt',   # Limit features per split
    warm_start=False,      # Default: rebuild the forest from scratch on each fit
    n_jobs=-1
)
```
Speed Optimization
```python
# Faster training
rf = RandomForestClassifier(
    n_estimators=100,
    max_depth=20,           # Limit depth
    min_samples_split=10,   # Larger minimum split
    max_features='log2',    # Fewer features per split
    bootstrap=True,
    n_jobs=-1,              # Parallel processing
    random_state=42
)

# Use warm_start for incremental training
rf = RandomForestClassifier(n_estimators=50, warm_start=True)
rf.fit(X_train, y_train)

# Add more trees
rf.n_estimators = 100
rf.fit(X_train, y_train)  # Only trains the 50 new trees
```
Common Pitfalls
Overfitting
```python
# ❌ BAD: Overfitting
rf = RandomForestClassifier(
    n_estimators=1000,
    max_depth=None,        # Unlimited depth
    min_samples_split=2,   # Split on 2 samples
    min_samples_leaf=1     # Leaf can have 1 sample
)

# ✅ GOOD: Regularization
rf = RandomForestClassifier(
    n_estimators=100,
    max_depth=20,            # Limit depth
    min_samples_split=10,    # Require more samples to split
    min_samples_leaf=5,      # Require more samples in a leaf
    max_features='sqrt'      # Limit features per split
)
```
Class Imbalance
```python
# ❌ BAD: Ignoring imbalance
rf = RandomForestClassifier()

# ✅ GOOD: Handle imbalance with class weights
rf = RandomForestClassifier(class_weight='balanced')

# Or resample the training data instead
from imblearn.over_sampling import SMOTE

smote = SMOTE()
X_resampled, y_resampled = smote.fit_resample(X_train, y_train)

rf = RandomForestClassifier()  # weights not needed once classes are resampled
rf.fit(X_resampled, y_resampled)
```
Comparison with Other Algorithms
| Algorithm | Pros | Cons |
|---|---|---|
| Random Forest | Robust, captures non-linear relationships, provides feature importances | Slower prediction, less interpretable, large memory footprint |
| Gradient Boosting | Often higher accuracy; handles missing values in implementations such as XGBoost and LightGBM | Slower training, more prone to overfitting |
| Logistic Regression | Fast, interpretable, probabilistic | Linear decision boundary only, needs feature engineering |
| SVM | Effective in high dimensions, kernel trick | Slow on large datasets, hard to interpret |
| Neural Networks | Handles complex patterns, flexible | Needs lots of data, hard to tune |
Production Deployment
```python
import joblib

# Save model
joblib.dump(rf, 'random_forest_model.joblib')

# Load model
loaded_rf = joblib.load('random_forest_model.joblib')

# Predict
predictions = loaded_rf.predict(new_data)

# Model size check: measure the serialized size of the fitted forest
import pickle

def get_model_size(model):
    """Estimate serialized model size in MB."""
    return len(pickle.dumps(model)) / (1024 * 1024)

print(f"Model size: {get_model_size(rf):.2f} MB")
```
Further Reading
- Scikit-learn Random Forest Documentation
- Random Forests Paper (Breiman, 2001)
- Understanding Random Forests
- Feature Importances with Random Forests