import numpy as np
from data import *
import pickle
from dpdt.utils import extract_tree, Data, average_traj_length_in_mdp
import matplotlib.pyplot as plt
from baselines import cart_study_post_pruning

# Benchmark datasets, one row per dataset: the short name used to build file
# paths and plot titles, followed by the loader that is called without
# arguments and unpacked as (S, Y) further down in this script.
_DATASET_TABLE = (
    ("avila", get_avila_data),
    ("bank", get_bank_data),
    ("bean", get_bean_data),
    ("bidding", get_bidding_data),
    ("eeg", get_eeg_data),
    ("fault", get_fault_data),
    ("htru", get_htru_data),
    ("occupancy", get_occupancy_data),
    ("magic", get_magic_data),
    ("page", get_page_data),
    ("rice", get_rice_data),
    ("raisin", get_raisin_data),
    ("segment", get_segment_data),
    ("room", get_room_data),
    ("wilt", get_wilt_data),
    ("skin", get_skin_data),
)

# Parallel lists kept in table order: names_[k] is the label for the dataset
# returned by functions_[k]().
names_, functions_ = (list(column) for column in zip(*_DATASET_TABLE))

# Grid of 1000 values in [-1, 0]. Only its length is used below: the DPDT
# helpers are called with the grid *index*, not the value.
# NOTE(review): presumably the saved policy stores one entry per grid index;
# confirm against dpdt.utils before switching to the zeta values themselves.
zetas = np.linspace(-1, 0, 1000)
policy_folder = ""  # NOTE(review): not referenced anywhere in the visible script


def _plot_trade_off(
    x_dpdt, x_cart, scores_dpdt, scores_cart, label_dpdt, label_cart, xlabel, title, out_path
):
    """Plot DPDT and CART accuracy against one complexity measure and save it.

    The current matplotlib figure is cleared and reused, so successive calls
    do not accumulate open figures.

    Args:
        x_dpdt: x values (complexity measure) of the DPDT curve.
        x_cart: x values (complexity measure) of the CART curve.
        scores_dpdt: y values (accuracy) of the DPDT curve.
        scores_cart: y values (accuracy) of the CART curve.
        label_dpdt: legend entry for the DPDT curve.
        label_cart: legend entry for the CART curve.
        xlabel: x-axis label naming the complexity measure.
        title: figure title.
        out_path: destination file passed to ``plt.savefig``; its parent
            directory is created when missing.
    """
    import os  # function-local so this block does not depend on file-level imports

    plt.clf()
    plt.plot(x_dpdt, scores_dpdt, linewidth=4, label=label_dpdt)
    plt.plot(x_cart, scores_cart, linewidth=4, label=label_cart)
    plt.xlabel(xlabel, fontdict={"size": 14})
    plt.ylabel("accuracy", fontdict={"size": 14})
    plt.grid()
    plt.title(title, fontdict={"size": 18})
    plt.legend(loc="lower right", prop={"size": 18})

    # Create the output directory up front: a missing folder would otherwise
    # fail the save only after the whole sweep has been computed.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(out_path)


for load_dataset, dataset in zip(functions_, names_):
    S, Y = load_dataset()

    # DPDT
    data = Data(np.array(S, dtype=np.float64), np.array(Y, dtype=np.uint8))
    # Column-wise minimum and maximum of data.x, each widened by 1e-3 and
    # concatenated into a single vector.
    init_obs = np.concatenate(
        (data.x.min(axis=0) - 1e-3, data.x.max(axis=0) + 1e-3), dtype=np.float64
    )
    # SECURITY: unpickling runs arbitrary code; only load policy files that
    # were produced by a trusted run of this project.
    with open(f"saved_policies/{dataset}_cart3_selector_depth5.pkl", "rb") as pol:
        policy = pickle.load(pol)

    n_zetas = zetas.shape[0]
    scores = np.zeros(n_zetas, dtype=np.float64)
    lengths = np.zeros(n_zetas, dtype=np.float64)
    # int64 rather than uint8: uint8 cannot represent counts above 255.
    nodes = np.zeros(n_zetas, dtype=np.int64)
    depths = np.zeros(n_zetas, dtype=np.int64)
    for i in range(n_zetas):
        scores[i], lengths[i] = average_traj_length_in_mdp(S, Y, policy, init_obs, i)
        _, nodes[i], depths[i] = extract_tree(policy, init_obs, zeta=i)

    # First element of the stored timing array; shown as whole seconds in the legend.
    time_ = np.load(
        f"results_npz/time_{dataset}_cart3_selector_depth5.npy",
        allow_pickle=True,
    )[0]

    # CART
    (
        scores_cart,
        depths_cart,
        nodes_cart,
        time_cart,
        lengths_cart,
    ) = cart_study_post_pruning(S, Y, max_depth=5)

    label_dpdt = f"DPDT-3 in {int(time_)}s"
    label_cart = f"CART-postpruning in {int(time_cart)}s"

    # One figure per complexity measure: (DPDT x, CART x, x-axis label, output folder).
    figures = (
        (nodes, nodes_cart, "nodes", "nodes"),
        (lengths, lengths_cart, "average tests per sample", "simulability"),
        (depths, depths_cart, "depth", "depths"),
    )
    for x_dpdt, x_cart, xlabel, folder in figures:
        _plot_trade_off(
            x_dpdt,
            x_cart,
            scores,
            scores_cart,
            label_dpdt,
            label_cart,
            xlabel,
            dataset,
            f"plots/trade_offs/{folder}/{dataset}.pdf",
        )
