import numpy as np
import pandas as pd
from pprint import pprint
from decision_tree_functions import decision_tree_algorithm, make_predictions, calculate_accuracy
from helper_functions import generate_data, create_plot, train_test_split
# Seed NumPy's global RNG before building the training set
# (assumes generate_data draws from np.random, so this makes it reproducible — TODO confirm).
np.random.seed(0)
# Request 300 samples; specific_outliers presumably injects the listed (x, y)
# point into the data — verify against helper_functions.generate_data.
df_train = generate_data(n=300, specific_outliers=[(5.4, 8.4)])
# Build a classification tree from the training data, passing max_depth=10.
tree = decision_tree_algorithm(df_train, ml_task="classification", max_depth=10)
create_plot(df_train, tree, title="Training Data")
# Re-seed with a different value and generate a second 300-sample set (no
# specific_outliers argument) to use as validation data.
np.random.seed(7)
df_val = generate_data(n=300)
# Plot the validation data together with the tree built from the training data.
create_plot(df_val, tree, title="Validation Data")
# NOTE(review): this rebinds `tree` to a hand-written single-question tree,
# discarding the tree trained above. Presumably a minimal example for trying
# out post_pruning — confirm this is intentional before relying on `tree` below.
tree = {'x <= 5.0': [True, False]}
def post_pruning(tree, df_train, df_val):
    """Decide whether a single decision node should be collapsed into a leaf.

    The node is replaced by the most frequent ``label`` of ``df_train`` when
    that constant prediction makes no more mistakes on ``df_val`` than the
    node itself does.

    Args:
        tree: A dict with one key (the question) mapped to a two-element
            sequence ``[yes_answer, no_answer]``. Anything that is not a
            non-empty dict is treated as an already-final leaf.
        df_train: DataFrame with a ``label`` column; supplies the candidate
            leaf value.
        df_val: DataFrame with a ``label`` column; used to count the errors
            of the leaf and of the decision node.

    Returns:
        The candidate leaf value if pruning does not increase the validation
        error count, otherwise ``tree`` itself (the same object).

    Note:
        Only nodes whose two answers are both leaves are evaluated. If either
        answer is itself a dict (a sub-tree), ``tree`` is returned unchanged;
        the recursive descent into sub-trees is not implemented here.
    """
    # A bare leaf (or an empty dict) has no question to evaluate, so there is
    # nothing to prune; hand it back instead of failing on the key lookup.
    if not isinstance(tree, dict) or not tree:
        return tree

    question = next(iter(tree))
    yes_answer, no_answer = tree[question]

    # recursive part: sub-trees are not pruned yet, the node is kept as is
    if isinstance(yes_answer, dict) or isinstance(no_answer, dict):
        return tree

    # base case: both answers are leaves
    label_counts = df_train.label.value_counts()
    if label_counts.empty:
        # No (non-null) training labels means no candidate leaf can be
        # chosen, so keep the decision node rather than raising IndexError.
        return tree

    # value_counts() sorts by descending frequency, so the first index entry
    # is the majority label.
    leaf = label_counts.index[0]

    # Vectorised error counts on the validation set.
    # assumes make_predictions returns one prediction per row of df_val,
    # aligned with df_val.label — TODO confirm against decision_tree_functions.
    errors_leaf = (df_val.label != leaf).sum()
    errors_decision_node = (df_val.label != make_predictions(df_val, tree)).sum()

    # Ties favour the simpler model (the leaf).
    if errors_leaf <= errors_decision_node:
        return leaf
    return tree