import numpy as np
from data import *
import pandas as pd
from dpdt.utils import Data, eval_in_mdp, extract_tree
import pickle
from sklearn.model_selection import train_test_split
from statistics import mean
from sklearn.tree import DecisionTreeClassifier

# Single registry pairing each benchmark's display name with its loader.
# `names_` and `functions_` are consumed by index further down
# (`functions_[f]` for the f-th entry of `names_`), so both lists are derived
# from this one table: enabling or disabling a dataset is a one-line change
# and the two lists cannot drift out of alignment.
_DATASETS = [
    ("balance-scale", loadBalanceScale),
    ("breast-cancer", loadBreastCancer),
    ("car-evaluation", loadCarEvaluation),
    ("hayes-roth", loadHayesRoth),
    ("house-votes-84", loadHouseVotes84),
    ("soybean-small", loadSoybean),
    ("spect", loadSpect),
    ("tic-tac-toe", loadTicTacToe),
    # ("monks", loadMonks),
]
functions_ = [loader for _, loader in _DATASETS]
names_ = [name for name, _ in _DATASETS]


def _single_value(rows, column, description):
    """Return the only entry of ``rows[column]``.

    Raises:
        ValueError: if ``rows`` does not hold exactly one row; the message
            names the lookup (``description``) instead of pandas' generic
            "can only convert an array of size 1" error.
    """
    if len(rows) != 1:
        raise ValueError(
            "expected exactly one result row for {}, found {}".format(
                description, len(rows)
            )
        )
    return rows[column].item()


def get_opt_results_for_compare(
    dataset,
    seed,
    S=None,
    Y=None,
    S_test=None,
    Y_test=None,
    depth=5,
    algos=("oct", "mfoct", "boct", "sk"),
    results_dir="data/small_data/res/",
):
    """Collect train accuracy, train time and test accuracy per baseline.

    For every name in ``algos`` the file ``<results_dir>/<algo>.csv`` is read
    and the row matching ``depth``, ``seed`` and ``dataset`` (column
    ``instance``) is selected. Every algorithm except "boct" and "sk" is
    additionally filtered on ``alpha == 0.0``.

    For "sk" the accuracies are not read from the CSV: a
    ``DecisionTreeClassifier`` (entropy criterion, ``max_depth=depth``,
    ``random_state=0``) is fitted on ``S, Y`` and scored on the train and test
    data; only its ``train_time`` comes from the CSV.

    Args:
        dataset: value to match against the ``instance`` column.
        seed: value to match against the ``seed`` column.
        S, Y: training features / labels; required when "sk" is in ``algos``.
        S_test, Y_test: test features / labels; required when "sk" is in
            ``algos``.
        depth: tree depth to look up (and to fit for "sk"). Defaults to 5.
        algos: algorithm names to report, in output order.
        results_dir: directory holding one ``<algo>.csv`` per algorithm.

    Returns:
        A tuple ``(perfs, times, tests)`` of lists ordered like ``algos``:
        train accuracies, train times and test accuracies.

    Raises:
        ValueError: if "sk" is requested without train/test data, or if a
            lookup does not match exactly one CSV row.
    """
    import os

    # Fail early with a readable message instead of deep inside scikit-learn.
    if "sk" in algos and any(part is None for part in (S, Y, S_test, Y_test)):
        raise ValueError(
            'S, Y, S_test and Y_test are required to evaluate the "sk" baseline'
        )

    perfs, times, tests = [], [], []
    for algo in algos:
        res = pd.read_csv(os.path.join(results_dir, algo + ".csv"))

        # Build the row filter once and reuse it for every column we read.
        mask = (
            (res["depth"] == depth)
            & (res["seed"] == seed)
            & (res["instance"] == dataset)
        )
        if algo not in ("boct", "sk"):
            # Only these result files are filtered on the `alpha` column.
            mask = mask & (res["alpha"] == 0.0)
        rows = res.loc[mask]
        description = "algo={!r}, dataset={!r}, seed={!r}, depth={!r}".format(
            algo, dataset, seed, depth
        )

        if algo == "sk":
            clf = DecisionTreeClassifier(
                criterion="entropy", max_depth=depth, random_state=0
            )
            clf.fit(S, Y)
            perfs.append(clf.score(S, Y))
            tests.append(clf.score(S_test, Y_test))
        else:
            perfs.append(_single_value(rows, "train_acc", description))
            tests.append(_single_value(rows, "test_acc", description))
        times.append(_single_value(rows, "train_time", description))
    return perfs, times, tests


# Reporting script: for every dataset in `names_`, average train accuracy,
# test accuracy and training time over the seeds for the two saved DPDT runs
# (file tags `cart4` and `cart5`) and for the OCT / MFOCT / BOCT / sklearn
# baselines, then print one LaTeX table row per dataset.
#
# NOTE(review): the saved results and policies are read with
# `np.load(..., allow_pickle=True)` and `pickle.load`, which execute arbitrary
# code from the file. Only run this on artefacts you produced yourself.
zetas = np.zeros(1)
dict_res = dict()
for f, dataset in enumerate(names_):
    if dataset != "monks":
        avg_train_dpdt, avg_test_dpdt, avg_time_dpdt = [], [], []
        avg_train_dpdt5, avg_test_dpdt5, avg_time_dpdt5 = [], [], []
        avg_train_oct, avg_test_oct, avg_time_oct = [], [], []
        avg_train_mfoct, avg_test_mfoct, avg_time_mfoct = [], [], []
        avg_train_boct, avg_test_boct, avg_time_boct = [], [], []
        avg_train_sk, avg_test_sk, avg_time_sk = [], [], []

        for seed in [37, 42, 53]:
            # File-name stems shared by the results, timing and policy files.
            run4 = "{}_seed_{}_cart4_selector_depth5".format(dataset, seed)
            run5 = "{}_seed_{}_cart5_selector_depth5".format(dataset, seed)

            res = np.load("results_npz/" + run4 + ".npy", allow_pickle=True)[1]
            time = np.load(
                "results_npz/time_" + run4 + ".npy", allow_pickle=True
            )[0]

            res5 = np.load("results_npz/" + run5 + ".npy", allow_pickle=True)[1]
            time5 = np.load(
                "results_npz/time_" + run5 + ".npy", allow_pickle=True
            )[0]

            S, Y = functions_[f]()

            # 50/50 split, seeded per run.
            S, S_test, Y, Y_test = train_test_split(
                S, Y, test_size=0.5, random_state=seed
            )
            data = Data(np.array(S, dtype=np.float64), np.array(Y, dtype=np.uint8))
            # Initial observation: per-feature min and max of the training
            # data, widened by 1e-3 on each side.
            init_obs = np.concatenate(
                (data.x.min(axis=0) - 1e-3, data.x.max(axis=0) + 1e-3), dtype=np.float64
            )

            with open("saved_policies/" + run4 + ".pkl", "rb") as pol:
                policy = pickle.load(pol)
            with open("saved_policies/" + run5 + ".pkl", "rb") as pol:
                policy5 = pickle.load(pol)

            # One slot per entry of `zetas` (sized from it so the loop below
            # cannot index out of bounds). `depths` / `nodes` are allocated
            # but not filled here.
            scores, depths, nodes = (
                np.zeros(len(zetas), dtype=np.float64),
                np.zeros(len(zetas), dtype=np.uint8),
                np.zeros(len(zetas), dtype=np.uint8),
            )
            for i, zeta in enumerate(zetas):
                scores[i] = eval_in_mdp(S_test, Y_test, policy, init_obs, i)

            avg_train_dpdt.append(res[-1])
            avg_test_dpdt.append(scores[-1])
            avg_time_dpdt.append(time)

            scores5, depths5, nodes5 = (
                np.zeros(len(zetas), dtype=np.float64),
                np.zeros(len(zetas), dtype=np.uint8),
                np.zeros(len(zetas), dtype=np.uint8),
            )
            for i, zeta in enumerate(zetas):
                scores5[i] = eval_in_mdp(S_test, Y_test, policy5, init_obs, i)

            avg_train_dpdt5.append(res5[-1])
            avg_test_dpdt5.append(scores5[-1])
            avg_time_dpdt5.append(time5)

            # Baselines, returned in the order oct, mfoct, boct, sk.
            # `res` and `time` are deliberately rebound here, after the DPDT
            # values above have been stored.
            res, time, tests = get_opt_results_for_compare(
                dataset, seed, S, Y, S_test, Y_test
            )

            avg_train_oct.append(res[0])
            avg_train_mfoct.append(res[1])
            avg_train_boct.append(res[2])
            avg_train_sk.append(res[3])

            avg_test_oct.append(tests[0])
            avg_test_mfoct.append(tests[1])
            avg_test_boct.append(tests[2])
            avg_test_sk.append(tests[3])

            avg_time_oct.append(time[0])
            avg_time_mfoct.append(time[1])
            avg_time_boct.append(time[2])
            avg_time_sk.append(time[3])

        print("############################")

        # LaTeX row: name, #samples, #features, #classes, 6 train accuracies,
        # 6 test accuracies, 6 training times. `\\%` prints a literal `\%`
        # without relying on an invalid escape sequence.
        # NOTE(review): the row ends with a single backslash, not the LaTeX
        # row terminator `\\` - left unchanged to keep the output identical;
        # confirm whether that is intended.
        row_template = (
            "{} & ${}$ & {} & {} & "
            "$ {} \\%$ & $ {} \\%$ & $ {} \\%$ & $ {} \\%$ & $ {} \\%$ & $ {} \\%$ & "
            "$ {} \\%$ & $ {} \\%$ &$ {} \\%$ & $ {} \\%$ & $ {} \\%$ & $ {} \\%$ & "
            "$ {} $ & $ {} $ & $ {} $ & $ {} $ & $ {} $ & $ {} $ \\"
        )
        print(
            row_template.format(
                dataset,
                # Full dataset size. Doubling the training half would
                # under-count by one whenever the dataset size is odd.
                int(S.shape[0] + S_test.shape[0]),
                int(S.shape[1]),
                int(len(np.unique(Y))),
                round(mean(avg_train_dpdt) * 100, 1),
                round(mean(avg_train_dpdt5) * 100, 1),
                round(mean(avg_train_oct) * 100, 1),
                round(mean(avg_train_mfoct) * 100, 1),
                round(mean(avg_train_boct) * 100, 1),
                round(mean(avg_train_sk) * 100, 1),
                round(mean(avg_test_dpdt) * 100, 1),
                round(mean(avg_test_dpdt5) * 100, 1),
                round(mean(avg_test_oct) * 100, 1),
                round(mean(avg_test_mfoct) * 100, 1),
                round(mean(avg_test_boct) * 100, 1),
                round(mean(avg_test_sk) * 100, 1),
                round(mean(avg_time_dpdt), 2),
                round(mean(avg_time_dpdt5), 2),
                round(mean(avg_time_oct), 2),
                round(mean(avg_time_mfoct), 2),
                round(mean(avg_time_boct), 2),
                round(mean(avg_time_sk), 3),
            )
        )

    else:
        # "monks" comes in three variants ("monks-1" .. "monks-3"), each with
        # its own result / policy files and loader suffix.
        for idx in [1, 2, 3]:
            for seed in [37, 42, 53]:
                variant = dataset + "-" + str(idx)
                run4 = "{}_seed_{}_cart4_selector_depth5".format(variant, seed)

                res = np.load(
                    "results_npz/" + run4 + ".npy", allow_pickle=True
                )[1]
                time = np.load(
                    "results_npz/time_" + run4 + ".npy", allow_pickle=True
                )[0]

                S, Y = functions_[f]("-" + str(idx))
                S, S_test, Y, Y_test = train_test_split(
                    S, Y, test_size=0.5, random_state=seed
                )
                data = Data(np.array(S, dtype=np.float64), np.array(Y, dtype=np.uint8))
                init_obs = np.concatenate(
                    (data.x.min(axis=0) - 1e-3, data.x.max(axis=0) + 1e-3),
                    dtype=np.float64,
                )

                with open("saved_policies/" + run4 + ".pkl", "rb") as pol:
                    policy = pickle.load(pol)

                scores, depths, nodes = (
                    np.zeros(len(zetas), dtype=np.float64),
                    np.zeros(len(zetas), dtype=np.uint8),
                    np.zeros(len(zetas), dtype=np.uint8),
                )
                for i, zeta in enumerate(zetas):
                    scores[i] = eval_in_mdp(S_test, Y_test, policy, init_obs, i)
                # The split data is passed so the sklearn baseline can be
                # fitted; without it that baseline cannot be evaluated.
                res, time, tests = get_opt_results_for_compare(
                    variant, seed, S, Y, S_test, Y_test
                )
