import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer
from sklearn.datasets import load_iris
from sklearn.datasets import load_digits
from sklearn.datasets import load_wine
from sklearn.tree import DecisionTreeClassifier
from csv import reader

# One shared figure/axes; each plot_path() call below adds a curve to it.
fig, ax = plt.subplots()
ax.set(
    xlabel=r"$\eta$*size",
    ylabel=r"accuracy",
    title="Accuracy vs inexplainability",
)


# Weight applied per tree node: a tree's x-coordinate is eta * node_count.
eta = 0.01


# Load a CSV file
def load_csv(filename):
	file = open(filename, "rt")
	lines = reader(file)
	dataset = list(lines)
	return dataset

def str_column_to_float(dataset, column):
    """Parse ``column`` of every row in ``dataset`` as a float, in place.

    Surrounding whitespace is stripped from each cell before parsing.
    Nothing is returned; the rows themselves are mutated.
    """
    for record in dataset:
        text = record[column].strip()
        record[column] = float(text)

def plot_path(X, y, dataset, rs, ax=None, node_cost=None):
    """Plot test accuracy against a tree-size penalty along a pruning path.

    Splits ``(X, y)`` into train/test sets and computes the minimal
    cost-complexity pruning path of an entropy decision tree on the
    training split. It refits one tree per ``ccp_alpha`` on that path and
    plots each tree's test score (``clf.score``) against
    ``node_cost * node_count``.

    Parameters
    ----------
    X, y : array-like
        Features and labels, passed straight to ``train_test_split``.
    dataset : str
        Legend label for the plotted curve.
    rs : int
        ``random_state`` for ``train_test_split``.
    ax : matplotlib Axes, optional
        Axes to draw on; defaults to the current axes (``plt.gca()``).
    node_cost : float, optional
        Weight per tree node on the x-axis; defaults to the module-level
        ``eta``.
    """
    if ax is None:
        ax = plt.gca()
    if node_cost is None:
        node_cost = eta

    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=rs)

    # Effective alphas at which the fully grown tree would be pruned further.
    path = DecisionTreeClassifier(
        criterion="entropy", random_state=0
    ).cost_complexity_pruning_path(X_train, y_train)

    # One tree per alpha on the path (fit() returns the fitted estimator).
    clfs = [
        DecisionTreeClassifier(
            criterion="entropy", random_state=0, ccp_alpha=ccp_alpha
        ).fit(X_train, y_train)
        for ccp_alpha in path.ccp_alphas
    ]
    # Per the sklearn pruning docs, the largest alpha prunes down to the
    # root-only tree, so it is dropped. It is kept when it is the only tree
    # on the path (e.g. pure training labels); otherwise nothing is left.
    if len(clfs) > 1:
        clfs = clfs[:-1]

    test_scores = [clf.score(X_test, y_test) for clf in clfs]
    exp_scores = [node_cost * clf.tree_.node_count for clf in clfs]

    ax.plot(exp_scores, test_scores, marker="x", label=dataset)
    ax.legend()


# (loader, legend label, random_state for the train/test split)
for loader, name, seed in (
    (load_breast_cancer, 'breast_cancer', 0),
    (load_wine, 'wine', 3),
    (load_iris, 'iris', 2),
):
    X, y = loader(return_X_y=True)
    plot_path(X, y, name, seed)

# Banknote data from a local CSV: every column is parsed as a float and the
# last column is used as the label.
filename = 'data_banknote_authentication.csv'
dataset = load_csv(filename)
# Convert every column (label column included) from string to float.
for column in range(len(dataset[0])):
    str_column_to_float(dataset, column)

# Split each row into its feature values and its trailing label.
X = [row[:-1] for row in dataset]
y = [row[-1] for row in dataset]
plot_path(X, y, 'banknote', 0)

plt.xlim([0.02, 0.15])
plt.show()