from numpy import ndarray, array
from sklearn.tree import DecisionTreeClassifier, plot_tree
from time import time
import matplotlib.pyplot as plt


def cart_study(S: ndarray, Y: ndarray, max_depth: int):
    """Fit entropy-based decision trees for depth limits 1..max_depth.

    For each depth limit a new ``DecisionTreeClassifier`` (``random_state=0``)
    is fitted on ``(S, Y)`` and scored on that same data.

    Args:
        S: Samples passed to ``fit`` and ``score``.
        Y: Labels matching ``S``.
        max_depth: Largest depth limit to try. A non-positive value fits no
            tree and yields empty result columns.

    Returns:
        A tuple ``(scores, depths, node_counts, elapsed)`` where the first
        three entries are 1-D float arrays with one value per depth limit
        (score on ``(S, Y)``, actual tree depth, number of tree nodes) and
        ``elapsed`` is the wall-clock time in seconds for the whole sweep.
    """
    # The timer is started once: resetting it inside the loop would make the
    # reported duration cover only the last fit.
    start = time()
    rows = []
    for depth_limit in range(1, max_depth + 1):
        clf = DecisionTreeClassifier(
            criterion="entropy", max_depth=depth_limit, random_state=0
        )
        clf.fit(S, Y)
        rows.append([clf.score(S, Y), clf.get_depth(), clf.tree_.node_count])
    elapsed = time() - start

    # reshape keeps the table 2-D even when no tree was fitted, so the column
    # slices below return empty arrays instead of raising IndexError.
    table = array(rows, dtype=float).reshape(-1, 3)
    return table[:, 0], table[:, 1], table[:, 2], elapsed


def count_avg_test_cart(clf: DecisionTreeClassifier, S: ndarray):
    """Return the mean row sum of ``clf.decision_path(S)``, minus one.

    NOTE(review): assuming ``decision_path`` yields a per-sample node
    indicator matrix (as the scikit-learn docs describe), this is presumably
    the average number of split tests a sample goes through, i.e. visited
    nodes excluding the final one -- confirm against callers.

    Args:
        clf: A fitted decision tree classifier.
        S: Samples whose decision paths are measured.

    Returns:
        The mean over samples of the per-row sum of the path matrix, minus 1.
    """
    path_matrix = clf.decision_path(S)
    # Summing along axis 1 gives one total per sample (row).
    per_sample_totals = path_matrix.sum(axis=1)
    mean_total = per_sample_totals.mean()
    return mean_total - 1


def cart_study_post_pruning(
    S: ndarray,
    Y: ndarray,
    max_depth: int,
    plot_trees: bool = False,
    tree_folder: str = "",
    S_test = None,
    Y_test = None
):
    """Fit one entropy-based tree per cost-complexity pruning alpha.

    The alphas come from ``cost_complexity_pruning_path(S, Y)`` of a tree
    limited to ``max_depth``. For each alpha a new tree (``random_state=0``,
    same ``max_depth``) is fitted on ``(S, Y)`` and measured.

    Args:
        S: Training samples.
        Y: Training labels.
        max_depth: Depth limit used for the pruning path and for every tree.
        plot_trees: When true, each tree is drawn with ``plot_tree`` and
            saved via ``plt.savefig``.
        tree_folder: String prepended verbatim to each image file name
            (``<node_count>_<train score rounded to 4 digits>.png``); no
            path separator is inserted.
        S_test: Optional held-out samples.
        Y_test: Labels for ``S_test``; required whenever ``S_test`` is given.

    Returns:
        A tuple ``(train_scores, test_scores, depths, node_counts, elapsed,
        avg_tests)``. All entries except ``elapsed`` are 1-D float arrays
        with one value per alpha; ``elapsed`` is the wall-clock time in
        seconds for the whole study (including plotting). When ``S_test`` is
        ``None`` the ``test_scores`` array is filled with NaN.

    Raises:
        ValueError: If ``S_test`` is provided without ``Y_test``.
    """
    has_test = S_test is not None
    if has_test and Y_test is None:
        raise ValueError("Y_test is required when S_test is provided")

    rows = []
    start = time()

    clf = DecisionTreeClassifier(
        criterion="entropy", max_depth=max_depth, random_state=0
    )
    ccp_alphas = clf.cost_complexity_pruning_path(S, Y).ccp_alphas

    for ccp_alpha in ccp_alphas:
        clf = DecisionTreeClassifier(
            random_state=0,
            ccp_alpha=ccp_alpha,
            criterion="entropy",
            max_depth=max_depth,
        )
        clf.fit(S, Y)

        train_score = clf.score(S, Y)
        node_count = clf.tree_.node_count
        # Every row has the same five columns whether or not a test set was
        # supplied; NaN marks a missing test score. The previous 4-column
        # rows made the final column lookup fail with IndexError.
        test_score = clf.score(S_test, Y_test) if has_test else float("nan")
        rows.append(
            [
                train_score,
                test_score,
                clf.get_depth(),
                node_count,
                count_avg_test_cart(clf, S),
            ]
        )

        if plot_trees:
            plot_tree(clf)
            plt.savefig(
                tree_folder
                + str(node_count)
                + "_"
                + str(round(train_score, 4))
                + ".png"
            )
    elapsed = time() - start

    # reshape keeps the table 2-D even if no alpha was produced.
    table = array(rows, dtype=float).reshape(-1, 5)
    return (
        table[:, 0],
        table[:, 1],
        table[:, 2],
        table[:, 3],
        elapsed,
        table[:, 4],
    )
