import numpy as np
import pandas as pd
from pprint import pprint
from decision_tree_functions import decision_tree_algorithm, make_predictions, calculate_accuracy
from helper_functions import generate_data, create_plot, train_test_split
def filter_df(df, question):
    """Split ``df`` into the rows that answer ``question`` with yes / no.

    ``question`` is a space-separated string ``"<feature> <op> <value>"``:

    * ``"x <= 5.0"`` (continuous): yes-rows have ``df[feature] <= float(value)``,
      no-rows have ``df[feature] > float(value)``.
    * ``"color = red"`` (categorical): yes-rows have the string form of
      ``df[feature]`` equal to ``value``; all other rows are no-rows.

    Returns:
        ``(df_yes, df_no)`` -- two boolean-filtered views of ``df``.

    Raises:
        ValueError: if ``question`` does not have the form above or uses an
            operator other than ``<=`` or ``=``.
    """
    # maxsplit=2 keeps a multi-word categorical value (e.g. "dark red") intact.
    feature, operator, value = question.split(maxsplit=2)

    if operator == "<=":
        threshold = float(value)
        df_yes = df[df[feature] <= threshold]
        df_no = df[df[feature] > threshold]
    elif operator == "=":
        # Compare string forms, since the value was parsed out of a string.
        as_text = df[feature].astype(str)
        df_yes = df[as_text == value]
        df_no = df[as_text != value]
    else:
        raise ValueError(f"unsupported operator {operator!r} in question {question!r}")

    return df_yes, df_no
def pruning_result(tree, df_train, df_val):
    """Decide whether to keep ``tree`` or collapse it into a single leaf.

    The candidate leaf is the most frequent ``label`` among ``df_train``
    (the training rows reaching this node). Both the leaf and ``tree`` are
    scored by how many ``df_val`` rows they mislabel. Ties go to the leaf,
    the simpler model.

    Args:
        tree: sub-tree to evaluate, passed to ``make_predictions``.
        df_train: training rows reaching this node (must contain ``label``).
        df_val: validation rows reaching this node (may be empty).

    Returns:
        The majority ``label`` value if the leaf makes no more errors than
        ``tree``, otherwise ``tree`` itself.
    """
    leaf = df_train.label.value_counts().index[0]

    # No validation rows reach this node. Both options then make 0 errors,
    # and the tie rule below would pick the leaf anyway. Returning early
    # avoids calling make_predictions on an empty frame.
    # NOTE(review): df.apply on an empty frame may not return a Series; the
    # early return sidesteps that.
    if df_val.empty:
        return leaf

    errors_leaf = (df_val.label != leaf).sum()
    errors_decision_node = (df_val.label != make_predictions(df_val, tree)).sum()

    # "<=" means ties favour the leaf.
    return leaf if errors_leaf <= errors_decision_node else tree
# Hand-written example of the nested-dict tree format used in this file:
# each node is {question: [yes_answer, no_answer]}, and an answer is either
# another node or a leaf value. The name is rebound to a trained tree further
# down in the script.
tree = {
    "x <= 5.0": [
        {"y <= 5": [True, False]},
        False,
    ],
}
def post_pruning(tree, df_train, df_val):
    """Prune ``tree`` bottom-up (reduced-error pruning) against ``df_val``.

    Each node is a dict ``{question: [yes_answer, no_answer]}``. Any answer
    that is itself a dict is pruned first. Training and validation rows are
    partitioned with ``filter_df`` so that each sub-tree is judged only on
    the rows that reach it. Afterwards ``pruning_result`` decides whether the
    (rebuilt) node is kept or replaced by a leaf.

    Args:
        tree: a nested-dict tree, or a bare leaf value.
        df_train: training rows reaching this node.
        df_val: validation rows reaching this node.

    Returns:
        The pruned tree. This is a new dict or a leaf value; the input dict
        is not mutated. A bare leaf is returned unchanged.
    """
    # A tree that is already a leaf has nothing to prune. Without this
    # guard, .keys() would raise on a non-dict.
    if not isinstance(tree, dict):
        return tree

    question = next(iter(tree))
    yes_answer, no_answer = tree[question]

    if isinstance(yes_answer, dict) or isinstance(no_answer, dict):
        df_train_yes, df_train_no = filter_df(df_train, question)
        df_val_yes, df_val_no = filter_df(df_val, question)

        if isinstance(yes_answer, dict):
            yes_answer = post_pruning(yes_answer, df_train_yes, df_val_yes)
        if isinstance(no_answer, dict):
            no_answer = post_pruning(no_answer, df_train_no, df_val_no)

        # Rebuild rather than mutate, so the caller's tree stays intact.
        tree = {question: [yes_answer, no_answer]}

    return pruning_result(tree, df_train, df_val)
# --- Demo: grow a tree up to depth 10, then post-prune it on separate data ---
# For reproducible runs, seed NumPy before each generate_data call, e.g.
# np.random.seed(0) for training data and np.random.seed(7) for validation.
# NOTE(review): this assumes generate_data draws from NumPy's global RNG.

# Training data (n_random_outliers presumably injects noisy points).
df_train = generate_data(n=300, n_random_outliers=5)
tree = decision_tree_algorithm(
    df_train,
    ml_task="classification",
    max_depth=10,
)
create_plot(df_train, tree, title="Tree before Post-pruning")

# A separately generated validation set drives the pruning decisions.
df_val = generate_data(n=300)
tree_pruned = post_pruning(tree, df_train, df_val)
create_plot(df_val, tree_pruned, title="Tree after Post-pruning")