import os

import numpy as np
from dpdt import dpdt
from dpdt.utils import CartAIGSelector
from data import *

# Benchmark datasets as (short name, loader) pairs. Each loader is one of the
# `get_*_data` callables brought in by `from data import *` and is invoked with
# no arguments. Pairing each name with its loader on a single line keeps the
# two derived lists below index-aligned by construction.
_DATASETS = [
    ("avila", get_avila_data),
    ("bank", get_bank_data),
    ("bean", get_bean_data),
    ("bidding", get_bidding_data),
    ("eeg", get_eeg_data),
    ("fault", get_fault_data),
    ("htru", get_htru_data),
    ("occupancy", get_occupancy_data),
    ("magic", get_magic_data),
    ("page", get_page_data),
    ("rice", get_rice_data),
    ("raisin", get_raisin_data),
    ("segment", get_segment_data),
    ("room", get_room_data),
    ("wilt", get_wilt_data),
    ("skin", get_skin_data),
]

# Loader callables, in registry order.
functions_ = [loader for _, loader in _DATASETS]
# Dataset short names, in the same order as `functions_`.
names_ = [name for name, _ in _DATASETS]

# Run DPDT on every registered dataset and save its results.
#
# For each dataset this prints its size. It then calls `dpdt` with a
# CartAIGSelector and 1000 zeta values in [-1, 0], saving the learned policy
# under saved_policies/. Finally it writes two .npy files into results_npz/:
#   <tag>.npy       -> [nodes_dp, scores_dp]
#   time_<tag>.npy  -> [time_]
# where <tag> is "<dataset>_cart_adapt_more_selector_depth<MAX_DEPTH>".
MAX_DEPTH = 10
RESULTS_DIR = "results_npz"

# The registry is two parallel lists; fail loudly if they ever drift apart
# instead of silently skipping datasets or raising a bare IndexError.
if len(functions_) != len(names_):
    raise ValueError(
        "functions_ and names_ must have the same length "
        "({} != {})".format(len(functions_), len(names_))
    )

# np.save does not create missing parent directories.
os.makedirs(RESULTS_DIR, exist_ok=True)

for get_data, dataset in zip(functions_, names_):
    S, Y = get_data()
    print("{} N: {}, p: {}".format(dataset, S.shape[0], S.shape[1]))
    # Shared suffix for the policy folder and both result files. Derived from
    # MAX_DEPTH so the name cannot disagree with the depth actually used.
    tag = "{}_cart_adapt_more_selector_depth{}".format(dataset, MAX_DEPTH)
    scores_dp, depths_dp, nodes_dp, time_ = dpdt(
        S,
        Y,
        max_depth=MAX_DEPTH,
        # NOTE(review): 9 tree sizes for max_depth=10. Confirm whether dpdt
        # expects one entry per depth level or handles the shorter list.
        aig_fn=CartAIGSelector(max_tree_sizes=[3, 3, 1, 1, 1, 1, 1, 1, 1]),
        zetas=np.linspace(-1, 0, 1000),
        save_policy=True,
        policy_folder="saved_policies/{}".format(tag),
        verbose=False,
        tree_folder="plots/trees/{}/".format(dataset),
        plot_tree=False,
    )
    np.save(os.path.join(RESULTS_DIR, tag), [nodes_dp, scores_dp])
    np.save(os.path.join(RESULTS_DIR, "time_" + tag), [time_])
