import numpy as np
import pandas as pd
%matplotlib inline
import random
from pprint import pprint
from decision_tree_functions import decision_tree_algorithm, decision_tree_predictions
from helper_functions import train_test_split, calculate_accuracy
# Load the red-wine data set and expose the target as a "label" column.
# assign() appends "label" after all existing columns and the original
# "quality" column is then dropped, so "label" ends up last.
# NOTE(review): the decision-tree helpers may rely on the label being the
# last column — confirm before reordering columns here.
df = pd.read_csv("../../data/winequality-red.csv")
df = df.assign(label=df.quality).drop("quality", axis=1)

# Spaces in column names become underscores (e.g. "fixed acidity" ->
# "fixed_acidity") so columns work as attribute names.
df.columns = [col.replace(" ", "_") for col in df.columns]
df.head()

# Relative frequency of each quality score, plotted in score order.
wine_quality = df.label.value_counts(normalize=True).sort_index()
wine_quality.plot(kind="bar")
def transform_label(value):
    """Map a numeric quality score to a binary class label.

    Scores less than or equal to 5 become "bad". Every other value becomes
    "good", including values for which ``value <= 5`` is False, such as NaN.
    """
    return "bad" if value <= 5 else "good"
# Collapse the quality score into the two classes "bad"/"good" and show
# their relative frequencies.
df["label"] = df.label.apply(transform_label)
wine_quality = df.label.value_counts(normalize=True)
wine_quality[["bad", "good"]].plot(kind="bar")
wine_quality

# Seed BOTH random number generators.
# - random.seed() only affects the stdlib `random` module; train_test_split
#   presumably uses it (TODO confirm in helper_functions).
# - bootstrapping() draws its row indices from NumPy's global RNG
#   (np.random.randint), which random.seed() does not touch.
# Without np.random.seed, the bootstrap samples, and therefore the forest
# and its reported accuracy, change on every run.
random.seed(0)
np.random.seed(0)
train_df, test_df = train_test_split(df, test_size=0.2)
def bootstrapping(train_df, n_bootstrap):
    """Return a bootstrap sample of ``n_bootstrap`` rows from ``train_df``.

    Row positions are drawn uniformly and with replacement, so rows may
    repeat. The draw uses NumPy's global random state (``np.random``); seed
    it for reproducible samples. Rows are selected by position (``iloc``),
    so the original index labels are kept, duplicates included.
    """
    n_rows = len(train_df)
    sampled_positions = np.random.randint(0, n_rows, size=n_bootstrap)
    return train_df.iloc[sampled_positions]
def random_forest_algorithm(train_df, n_trees, n_bootstrap, n_features, dt_max_depth):
    """Build a list of ``n_trees`` decision trees (a random forest).

    Each tree is fitted on its own bootstrap sample of ``n_bootstrap`` rows
    drawn with replacement from ``train_df``. ``n_features`` is forwarded to
    ``decision_tree_algorithm`` as ``random_subspace`` and ``dt_max_depth``
    as ``max_depth``. Presumably ``random_subspace`` limits the features
    considered when splitting; confirm in decision_tree_functions.

    Returns the trees in construction order.
    """
    return [
        decision_tree_algorithm(
            bootstrapping(train_df, n_bootstrap),
            max_depth=dt_max_depth,
            random_subspace=n_features,
        )
        for _ in range(n_trees)
    ]
def random_forest_predictions(test_df, forest):
    """Predict a label for every row of ``test_df`` by majority vote.

    Each tree in ``forest`` is evaluated with ``decision_tree_predictions``.
    Per row, the most frequent prediction across all trees is returned.

    Parameters
    ----------
    test_df : pandas.DataFrame
        Rows to classify; passed unchanged to ``decision_tree_predictions``.
    forest : list
        Trees as produced by ``random_forest_algorithm``.

    Returns
    -------
    pandas.Series
        The per-row majority label. When several labels tie for the most
        votes, ``DataFrame.mode`` returns all of them (in sorted order) and
        the first is taken. With the labels "bad" and "good", a tie
        therefore resolves to "bad"; this is possible whenever the number
        of trees is even.

    Raises
    ------
    ValueError
        If ``forest`` is empty. Previously this surfaced as an opaque
        ``KeyError: 0``.
    """
    if not forest:
        raise ValueError("forest must contain at least one tree")

    # One column of predictions per tree, named tree_0, tree_1, ...
    votes = pd.DataFrame({
        "tree_{}".format(i): decision_tree_predictions(test_df, tree=tree)
        for i, tree in enumerate(forest)
    })
    # mode(axis=1) yields one column per tied value; column 0 is the winner.
    return votes.mode(axis=1)[0]
# Train the forest and report its accuracy on the held-out test split.
# NOTE(review): with an even number of trees, votes can tie; ties are
# resolved by taking the first value returned by DataFrame.mode.
forest_params = dict(n_trees=4, n_bootstrap=800, n_features=2, dt_max_depth=4)
forest = random_forest_algorithm(train_df, **forest_params)

predictions = random_forest_predictions(test_df, forest)
accuracy = calculate_accuracy(predictions, test_df.label)
print("Accuracy = {}".format(accuracy))