import os
import argparse

import numpy as np

from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import GradientBoostingClassifier

from experiment_utils import run_experiment, find_diffs, graph_single

def cart_init(noise, n, t, parameter=None):
    """Return an unfitted ``DecisionTreeClassifier`` with a deterministic seed.

    The seed is a fixed arithmetic function of ``noise``, ``n`` and ``t``,
    so identical arguments always produce an identically seeded tree.

    Args:
        noise: Numeric value mixed into the seed (presumably the label-noise
            level of the current run -- confirm against the caller).
        n: Numeric value mixed into the seed (presumably a draw or split
            index -- confirm against the caller).
        t: Numeric value mixed into the seed (presumably a draw or split
            index -- confirm against the caller).
        parameter: Forwarded as ``max_depth`` when given. When ``None``,
            ``max_depth`` is not passed and the estimator keeps its default.

    Returns:
        A new, unfitted ``DecisionTreeClassifier``.
    """
    # Keep the multiplication order fixed: the product is a float and is
    # truncated by int(), so reordering the factors could change the seed.
    noise_term = (noise + 3.3) ** 2
    tree_kwargs = {'random_state': int((t + 9) * (n + 1) * noise_term * 1010)}

    if parameter is not None:
        tree_kwargs['max_depth'] = parameter

    return DecisionTreeClassifier(**tree_kwargs)
    
def boosting_init(noise, n, t, parameter=None):
    """Return an unfitted ``GradientBoostingClassifier`` with a deterministic seed.

    Every model uses ``max_depth=2``. The seed is a fixed arithmetic
    function of ``noise``, ``n`` and ``t``, so identical arguments always
    produce an identically seeded model.

    Args:
        noise: Numeric value mixed into the seed (presumably the label-noise
            level of the current run -- confirm against the caller).
        n: Numeric value mixed into the seed (presumably a draw or split
            index -- confirm against the caller).
        t: Numeric value mixed into the seed (presumably a draw or split
            index -- confirm against the caller).
        parameter: Forwarded as ``n_estimators`` when given. When ``None``,
            ``n_estimators`` is not passed and the estimator keeps its
            default.

    Returns:
        A new, unfitted ``GradientBoostingClassifier``.
    """
    # Keep the multiplication order fixed: the product is a float and is
    # truncated by int(), so reordering the factors could change the seed.
    noise_term = (noise + 3.3) ** 2
    booster_kwargs = {
        'max_depth': 2,
        'random_state': int((t + 9) * (n + 1) * noise_term * 1010),
    }

    if parameter is not None:
        booster_kwargs['n_estimators'] = parameter

    return GradientBoostingClassifier(**booster_kwargs)

def main(
    X,
    y, 
    output_dir, 
    num_draws, 
    num_splits, 
    noises, 
    parameter_name, 
    parameter_list, 
    model_init,
    legend_label,
):
    """Run the parameter-selection experiment and save its tables and plots.

    Calls ``run_experiment`` and then ``find_diffs`` (both from
    ``experiment_utils``), writes their results to ``params.csv`` and
    ``scores.csv`` inside ``output_dir``, and calls ``graph_single`` once
    for each result with the path stems ``params`` and ``score_diffs``.

    Args:
        X: Feature data, forwarded unchanged to the experiment helpers.
        y: Label data, forwarded unchanged to the experiment helpers.
        output_dir: Directory that receives the CSV files and graphs. It is
            created (including parents) if it does not exist.
        num_draws: Forwarded to ``run_experiment`` and ``find_diffs``.
        num_splits: Forwarded to ``run_experiment`` and ``find_diffs``.
        noises: Noise levels, forwarded to ``run_experiment`` and
            ``find_diffs``.
        parameter_name: Name of the tuned hyper-parameter; forwarded to
            ``run_experiment`` and used in the first graph's y-axis label.
        parameter_list: Candidate hyper-parameter values, forwarded to
            ``run_experiment``.
        model_init: Model factory (e.g. ``cart_init`` or ``boosting_init``),
            forwarded to the experiment helpers.
        legend_label: Legend text passed to ``graph_single``.

    Raises:
        OSError: If ``output_dir`` cannot be created.
    """
    # Create the output directory up front: the experiment calls below can be
    # long-running, so an unusable output path should fail before any work is
    # done rather than after all results have been computed.
    os.makedirs(output_dir, exist_ok=True)

    params, chosen_params = run_experiment(
        X = X, 
        y = y, 
        noises = noises, 
        model_init = model_init, 
        parameter_name = parameter_name, 
        parameter_list = parameter_list, 
        num_draws = num_draws, 
        num_splits = num_splits,
    )

    # find_diffs consumes the parameters chosen by run_experiment above.
    scores = find_diffs(
        X = X, 
        y = y, 
        chosen_params = chosen_params,
        noises = noises, 
        model_init = model_init, 
        num_draws = num_draws, 
        num_splits = num_splits,
    )

    # Both results expose ``to_csv`` (presumably pandas DataFrames -- confirm
    # in experiment_utils).
    params.to_csv(f'{output_dir}/params.csv', index=False)
    scores.to_csv(f'{output_dir}/scores.csv', index=False)

    graph_single(
        path = f'{output_dir}/params', 
        data = params, 
        y_name = 'Parameter', 
        y_label = f'Best {parameter_name}', 
        legend_label = legend_label,
    )

    graph_single(
        path = f'{output_dir}/score_diffs', 
        data = scores, 
        y_name = 'Train-Val', 
        y_label = r'(Train$-$Val) Accuracy', 
        legend_label = legend_label,
    )

if __name__ == '__main__':
    # Script entry point: parse the command line, load the dataset, pick the
    # model family's tuning settings, and delegate to main().
    cli = argparse.ArgumentParser(description='Runs the experiment described for Figure 1')

    # (flags, keyword options) for each command-line argument, registered in
    # this order.
    option_specs = (
        (('-d', '--data_dir'), {'default': './datasets/monks1'}),
        (('-o', '--output_dir'), {'default': './data/cart/monks1'}),
        (('-nd', '--num_draws'), {'default': 5, 'type': int}),
        (('-ns', '--num_splits'), {'default': 5, 'type': int}),
        (
            ('-n', '--noises'),
            {'nargs': '*', 'type': float, 'default': [0, 0.03, 0.05, 0.1, 0.15, 0.2, 0.25]},
        ),
        (('-m', '--model'), {'choices': ['cart', 'boosting'], 'default': 'cart'}),
        (('-l', '--legend_label'), {'default': 'Monks1'}),
    )
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)

    args = cli.parse_args()

    # The dataset directory must contain X_data.npy and y_data.npy.
    X = np.load(f'{args.data_dir}/X_data.npy')
    y = np.load(f'{args.data_dir}/y_data.npy')

    # Settings per model family: (tuned parameter name, candidate values,
    # model factory). Entries are lazy so that X.shape[1] is only evaluated
    # when the CART configuration is actually selected.
    model_settings = {
        'cart': lambda: (
            'max_depth',
            range(1, X.shape[1] + 1),
            cart_init,
        ),
        'boosting': lambda: (
            'n_estimators',
            [5, *range(10, 151, 10)],  # 5, 10, 20, ..., 150
            boosting_init,
        ),
    }
    parameter_name, parameter_list, model_init = model_settings[args.model]()

    main(
        X=X,
        y=y,
        output_dir=args.output_dir,
        num_draws=args.num_draws,
        num_splits=args.num_splits,
        noises=args.noises,
        parameter_name=parameter_name,
        parameter_list=parameter_list,
        model_init=model_init,
        legend_label=args.legend_label,
    )