In [1]:
import numpy as np
import pandas as pd

from pprint import pprint

from decision_tree_functions import decision_tree_algorithm, make_predictions, calculate_accuracy
from helper_functions import generate_data, create_plot, train_test_split

1. Post-Pruning

In [2]:
np.random.seed(0)  # fixed seed -> reproducible synthetic training data
# (5.4, 8.4) is presumably an injected outlier point -- TODO confirm against
# helper_functions.generate_data
df_train = generate_data(n=300, specific_outliers=[(5.4, 8.4)])
# max_depth=10 grows a deep tree, presumably so it overfits and post-pruning
# has something to remove -- NOTE(review): verify against the tutorial intent
tree = decision_tree_algorithm(df_train, ml_task="classification", max_depth=10)
create_plot(df_train, tree, title="Training Data")

np.random.seed(7)  # different seed -> validation sample independent of training
df_val = generate_data(n=300)
create_plot(df_val, tree, title="Validation Data")
In [3]:
tree = {'x <= 5.0': [True, False]}
In [4]:
def post_pruning(tree, df_train, df_val):
    """Recursively post-prune a fitted decision tree using a validation set.

    A tree is a nested dict {question: [yes_answer, no_answer]} (see the toy
    example above: {'x <= 5.0': [True, False]}), where each answer is either a
    sub-tree (dict) or a leaf label.  Working bottom-up, every decision node is
    replaced by its majority-class training leaf whenever that leaf makes no
    more validation errors than the node itself.

    Parameters
    ----------
    tree : dict
        (Sub-)tree produced by decision_tree_algorithm.
    df_train : pandas.DataFrame
        Training rows reaching this node; must contain a "label" column.
    df_val : pandas.DataFrame
        Validation rows reaching this node; used to count errors.

    Returns
    -------
    dict or scalar
        The pruned (sub-)tree, or a single leaf label if the whole
        (sub-)tree was collapsed.
    """
    question = list(tree.keys())[0]
    yes_answer, no_answer = tree[question]

    # base case: both children are leaves -> decide leaf vs. node directly
    if not isinstance(yes_answer, dict) and not isinstance(no_answer, dict):
        return _pruning_result(tree, df_train, df_val)

    # recursive part: prune the sub-trees first (bottom-up), then decide
    # whether this node itself should also be collapsed into a leaf
    else:
        df_train_yes, df_train_no = _filter_df(df_train, question)
        df_val_yes, df_val_no = _filter_df(df_val, question)

        if isinstance(yes_answer, dict):
            yes_answer = post_pruning(yes_answer, df_train_yes, df_val_yes)
        if isinstance(no_answer, dict):
            no_answer = post_pruning(no_answer, df_train_no, df_val_no)

        tree = {question: [yes_answer, no_answer]}
        return _pruning_result(tree, df_train, df_val)


def _pruning_result(tree, df_train, df_val):
    """Compare a decision node against its majority-class leaf.

    Returns the majority training label when it makes no more validation
    errors than the node; otherwise returns the node unchanged.
    """
    # No training rows reach this node -> no majority class to form a leaf;
    # keep the node instead of crashing on value_counts().index[0].
    if len(df_train) == 0:
        return tree

    leaf = df_train.label.value_counts().index[0]
    errors_leaf = sum(df_val.label != leaf)
    errors_decision_node = sum(df_val.label != make_predictions(df_val, tree))

    # "<=" keeps the leaf on ties: prefer the simpler model when errors match
    if errors_leaf <= errors_decision_node:
        return leaf
    else:
        return tree


def _filter_df(df, question):
    """Split df into (yes_rows, no_rows) according to a node's question.

    Questions look like 'x <= 5.0' (continuous split, grounded by the toy
    tree above).  Any other operator is treated as a string-equality test --
    assumed to be the categorical-question format used by
    decision_tree_functions; TODO confirm against that module.
    """
    feature, comparison_operator, value = question.split(" ", 2)

    if comparison_operator == "<=":
        df_yes = df[df[feature] <= float(value)]
        df_no = df[df[feature] > float(value)]
    else:
        df_yes = df[df[feature].astype(str) == value]
        df_no = df[df[feature].astype(str) != value]

    return df_yes, df_no
In [ ]: