Monday, 24 February 2025

Build interpretable models using Decision Trees for Regression

Implementation Steps Explained

  1. Data Loading and Splitting:

    • Dataset: We use the California Housing dataset from scikit-learn, a standard regression benchmark whose target is the median house value of a California district.
    • Splitting: The dataset is split into training and testing sets (80/20 split) using train_test_split.
  2. Training a Decision Tree Regressor:

    • A DecisionTreeRegressor is instantiated and trained on the training data.
    • By default the tree grows until its leaves are (nearly) pure, so this initial unpruned tree is likely to overfit by memorizing noise in the training data.
  3. Visualizing the Decision Tree:

    • We use plot_tree from scikit-learn to visualize the tree structure.
    • To keep the visualization readable, only the first 3 levels are shown.
    • This step helps in understanding how the tree is splitting the data based on features.
  4. Pruning the Decision Tree:

    • Cost Complexity Pruning: We obtain the range of effective α values from the tree’s cost complexity pruning path. Minimal cost complexity pruning selects the subtree T that minimizes R_α(T) = R(T) + α·|leaves(T)|, where R(T) is the subtree’s total training error, so α controls the trade-off between tree complexity and fit to the training data (α = 0 keeps the full tree; larger α prunes more aggressively).
    • Model Selection: For each candidate α, a new tree is trained and evaluated on the test set using Mean Squared Error (MSE), the average squared difference between predictions and true values.
    • Best Alpha Selection: The α that yields the lowest test MSE is selected. The pruned tree is less complex and should generalize better. (Strictly speaking, choosing α on the test set leaks test information into model selection; a cross-validated alternative is sketched after the code.)
  5. Visualizing the Pruned Tree:

    • Finally, the pruned tree is visualized (again showing just the first 3 levels) to see how pruning has simplified the tree structure.

This workflow reduces overfitting by pruning unnecessary branches, yielding a model that is simpler and often more robust on unseen data.
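
As a complement to plot_tree, the fitted tree can also be printed as plain-text if/else rules, which is often the quickest way to read off its decision logic, and feature importances summarize which variables drive the splits. A minimal sketch (assuming tree_reg and feature_names have been defined as in the full script below):

from sklearn.tree import export_text

# Print the tree as nested if/else rules; truncate to 3 levels for readability.
rules = export_text(tree_reg, feature_names=feature_names, max_depth=3)
print(rules)

# Impurity-based feature importances: a global view of which features matter.
for name, importance in zip(feature_names, tree_reg.feature_importances_):
    print(f"{name}: {importance:.3f}")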

 

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor, plot_tree
from sklearn.metrics import mean_squared_error, r2_score

# 1. Load and split the dataset
data = fetch_california_housing()
X, y = data.data, data.target
feature_names = data.feature_names

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 2. Train a Decision Tree regressor (without pruning)
tree_reg = DecisionTreeRegressor(random_state=42)
tree_reg.fit(X_train, y_train)

# 3. Visualize the tree structure (first 3 levels for clarity)
plt.figure(figsize=(20, 10))
plot_tree(tree_reg, filled=True, feature_names=feature_names, rounded=True, max_depth=3)
plt.title("Unpruned Decision Tree Regressor (First 3 Levels)")
plt.show()

# Evaluate unpruned tree performance
y_pred_unpruned = tree_reg.predict(X_test)
print("Unpruned Tree MSE:", mean_squared_error(y_test, y_pred_unpruned))
print("Unpruned Tree R2:", r2_score(y_test, y_pred_unpruned))

# 4. Prune the tree to avoid overfitting using cost complexity pruning
#    We first obtain the effective alphas and corresponding total impurities.
path = tree_reg.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas, impurities = path.ccp_alphas, path.impurities

# Train a tree for each candidate alpha and evaluate it on the test set.
# Note: the pruning path on this dataset can contain thousands of alphas, so
# this loop may take a while; subsampling (e.g. ccp_alphas[::10]) speeds it up.
trees = []
mse_scores = []
for ccp_alpha in ccp_alphas:
    reg = DecisionTreeRegressor(random_state=42, ccp_alpha=ccp_alpha)
    reg.fit(X_train, y_train)
    trees.append(reg)
    y_pred = reg.predict(X_test)
    mse_scores.append(mean_squared_error(y_test, y_pred))

# Choose the best alpha (i.e., the one that minimizes MSE)
best_index = np.argmin(mse_scores)
best_alpha = ccp_alphas[best_index]
best_tree = trees[best_index]
print("Best ccp_alpha:", best_alpha)
print("Pruned Tree MSE:", mse_scores[best_index])
print("Pruned Tree R2:", r2_score(y_test, best_tree.predict(X_test)))

# Visualize the pruned tree (again, first 3 levels for clarity)
plt.figure(figsize=(20, 10))
plot_tree(best_tree, filled=True, feature_names=feature_names, rounded=True, max_depth=3)
plt.title("Pruned Decision Tree Regressor (First 3 Levels)")
plt.show()
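
Selecting α directly on the test set, as above, keeps the example simple but leaks test information into model selection. A more careful variant (a sketch using scikit-learn's GridSearchCV; the cv=5 folds and the alpha subsampling step are illustrative assumptions, not part of the original script) picks ccp_alpha by cross-validation on the training data only:

from sklearn.model_selection import GridSearchCV

# Subsample the candidate alphas to keep the grid search affordable (illustrative).
candidate_alphas = ccp_alphas[::10]

grid = GridSearchCV(
    DecisionTreeRegressor(random_state=42),
    param_grid={"ccp_alpha": candidate_alphas},
    cv=5,
    scoring="neg_mean_squared_error",
)
grid.fit(X_train, y_train)

cv_tree = grid.best_estimator_
print("CV-selected ccp_alpha:", grid.best_params_["ccp_alpha"])
print("CV-selected Tree MSE:", mean_squared_error(y_test, cv_tree.predict(X_test)))

# Pruning should shrink the tree itself, not just improve test error:
print("Unpruned depth/leaves:", tree_reg.get_depth(), tree_reg.get_n_leaves())
print("Pruned depth/leaves:  ", best_tree.get_depth(), best_tree.get_n_leaves())

Either way, comparing depth and leaf counts before and after pruning makes the simplification from step 5 concrete.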


 

Output: the script prints the unpruned tree's MSE and R², the selected ccp_alpha, and the pruned tree's MSE and R², and displays two plot_tree figures: the unpruned and the pruned decision tree, each truncated to its first 3 levels.
