import os

import numpy as np
from dpdt import dpdt
from dpdt.utils import CartAIGSelector
from data import *
from sklearn.model_selection import train_test_split

# from sklearn.tree import DecisionTreeClassifier

# Benchmark datasets as (name, loader) pairs. Keeping each name next to its
# loader means the two derived lists below cannot drift out of alignment.
_DATASETS = (
    ("balance-scale", loadBalanceScale),
    ("breast-cancer", loadBreastCancer),
    ("car-evaluation", loadCarEvaluation),
    ("hayes-roth", loadHayesRoth),
    ("house-votes-84", loadHouseVotes84),
    ("soybean-small", loadSoybean),
    ("spect", loadSpect),
    ("tic-tac-toe", loadTicTacToe),
    # ("monks", loadMonks),
)

# Loader callables, in the same order as names_.
functions_ = [loader for _, loader in _DATASETS]
# Dataset names; used below to build the output file and folder names.
names_ = [name for name, _ in _DATASETS]


# Run dpdt (max_depth=5, CartAIGSelector(depth=4)) on every dataset for three
# train/test split seeds. Policies ARE saved here (save_policy=True), one
# folder per run under saved_policies/; node counts, scores and timings are
# written as .npy files under results_npz/.

# The two lists are consumed pairwise; fail loudly if they ever get out of sync
# instead of silently pairing a loader with the wrong dataset name.
if len(functions_) != len(names_):
    raise ValueError(
        "functions_ and names_ must have the same length "
        "({} != {})".format(len(functions_), len(names_))
    )

# np.save does not create missing parent directories: make sure the output
# folder exists up front so a finished run is not lost at save time.
os.makedirs("results_npz", exist_ok=True)

for seeds in [37, 42, 53]:
    for load_fn, dataset in zip(functions_, names_):
        # "monks" is excluded from this experiment even if it is re-enabled
        # in the dataset lists.
        if dataset == "monks":
            continue

        S, Y = load_fn()
        print("{} N: {}, p: {}".format(dataset, S.shape[0], S.shape[1]))

        # Keep half of the data for training; the held-out half (S_test,
        # Y_test) is not used further in this script.
        S, S_test, Y, Y_test = train_test_split(
            S, Y, test_size=0.5, random_state=seeds
        )

        # Shared suffix identifying this (dataset, seed) run in every output path.
        run_tag = "{}_seed_{}_cart4_selector_depth5".format(dataset, seeds)

        scores_dp, depths_dp, nodes_dp, time_ = dpdt(
            S,
            Y,
            max_depth=5,
            aig_fn=CartAIGSelector(depth=4),
            zetas=np.zeros(1),  # no trade-offs
            save_policy=True,
            policy_folder="saved_policies/" + run_tag,
            verbose=False,
        )
        print(scores_dp)

        # np.save appends the ".npy" extension to these paths.
        np.save("results_npz/" + run_tag, [nodes_dp, scores_dp])
        np.save("results_npz/time_" + run_tag, [time_])
