import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold

from treefarms import TREEFARMS
from gosdt import GOSDT

import os

import pickle


def noise_y(y, noise: float, random_seed=11):
    """Return a copy of the labels ``y`` with each entry flipped with probability ``noise``.

    A uniform draw in [0, 1) is taken per row from ``np.random.default_rng(random_seed)``;
    rows whose draw is below ``noise`` are XOR-ed with 1. ``y`` is expected to hold 0/1
    labels, since XOR flipping is only meaningful for binary targets. The result is cast
    to ``int``.
    """
    generator = np.random.default_rng(random_seed)
    flip_mask = generator.random(y.shape[0]) < noise
    flipped = y ^ flip_mask
    return flipped.astype(int)

# Names of the binarized datasets, in SLURM array order: main() reads
# data_list[SLURM_ARRAY_TASK_ID - 1] from ../datasets/binarized/<name>.csv.
data_list = [
    'amsterdam',
    'NIJ_Recidivism',
    'broward',
    'compas',
    'fico',
    'bankfull',
    'australian_credit',
    'german_credit',
    'GiveMeSomeCredit',
    'polish_companies',
    'telco_churn',
    'iranian_churn',
    'occupancy_detection',
    'car_evaluation',
    'monks2',
    'monks1',
    'monks3',
    'bar7',
    'bcw_bin',
    'carryout_takeaway',
    'restaurant_20',
    'bar',
    'coffee_house',
]


def make_expected_noise_dataset(X, y, h, noise, num_draws, seed_scale=11):
    """Stack ``num_draws`` copies of ``(X, y)``, each copy with independently noised labels.

    Draw ``i`` uses ``noise_y(y, noise, random_seed=seed_scale * i)``, so the output is
    reproducible for fixed arguments.

    Parameters
    ----------
    X : array-like, 2-D (n_samples, n_features)
        Feature matrix, repeated unchanged for every draw.
    y : array-like, 1-D (n_samples,)
        Labels passed to ``noise_y``; expected to be 0/1.
    h : sequence of column labels
        Column names for the returned DataFrame.
    noise : float
        Per-label flip probability for each draw.
    num_draws : int
        Number of noisy copies to stack; must be >= 1.
    seed_scale : int, default 11
        Multiplier that produces the per-draw seed.

    Returns
    -------
    (pd.DataFrame, pd.Series)
        ``num_draws * n_samples`` rows with a fresh 0..N-1 index.

    Raises
    ------
    ValueError
        If ``num_draws`` is less than 1.
    """
    if num_draws < 1:
        raise ValueError(f"num_draws must be >= 1, got {num_draws}")

    # Collect all pieces and concatenate once; growing the array pairwise inside
    # the loop would copy the accumulated data on every iteration (quadratic).
    X_exp_all = np.concatenate([np.asarray(X)] * num_draws)
    y_exp_all = np.concatenate(
        [np.asarray(noise_y(y, noise, random_seed=seed_scale * i)) for i in range(num_draws)]
    )
    return pd.DataFrame(X_exp_all, columns=h), pd.Series(y_exp_all)

def tune_lambda(X: np.ndarray, y, h, reg_array: np.ndarray, k: int) -> float:
    """
    Tune regularization for sparse decision trees using k-fold cross-validation.

    For every candidate in ``reg_array``, a GOSDT model (depth_budget=5,
    allow_small_reg=True) is fit on each training fold and scored on the matching
    held-out fold. The candidate with the highest mean held-out score over all ``k``
    folds is returned. Ties go to the later candidate (``>=`` comparison).

    Parameters
    ----------
    X : np.ndarray
        Feature matrix, indexed positionally by fold indices.
    y : array supporting integer-array indexing
        Labels aligned with ``X``.
    h : sequence of column labels
        Column names used to wrap each fold in a DataFrame.
    reg_array : np.ndarray
        Candidate regularization values.
    k : int
        Number of CV folds.

    Returns
    -------
    float
        The selected regularization value (0.0 if ``reg_array`` is empty).
    """
    best_score = np.float64(0)
    best_lambda = 0.

    # A fixed random_state yields the same shuffled splits on every call, so build
    # them once and reuse them; every candidate is evaluated on identical folds.
    kf = KFold(n_splits=k, shuffle=True, random_state=0)
    splits = list(kf.split(X))

    # Loop over hyperparameters
    for a in reg_array:
        config = {
            "regularization": a,
            "depth_budget": 5,
            "allow_small_reg": True,
        }

        # One slot per fold, indexed by fold number, so the mean covers all k folds.
        test_scores = np.zeros(len(splits))

        for j, (train_index, test_index) in enumerate(splits):
            X_train = pd.DataFrame(X[train_index], columns=h)
            X_test = pd.DataFrame(X[test_index], columns=h)
            y_train = pd.DataFrame(y[train_index])
            y_test = pd.DataFrame(y[test_index])

            # Train on the training folds, score on the held-out fold.
            model = GOSDT(config)
            model.fit(X_train, y_train)
            test_scores[j] = model.score(X_test, y_test)

        avg_score = np.mean(test_scores)
        if avg_score >= best_score:
            best_score = avg_score
            best_lambda = a

    return best_lambda


def load_data(job_id: int) -> pd.DataFrame:
    """
    Read the binarized CSV for the dataset at position ``job_id`` in ``data_list``.

    ``job_id`` is a 0-based index. main() computes it as SLURM_ARRAY_TASK_ID - 1,
    so running all N datasets requires submitting with ``sbatch --array=1-N``.
    """
    dataset_name = data_list[job_id]
    print(f"Loading {dataset_name} dataset.")
    csv_path = f'../datasets/binarized/{dataset_name}.csv'
    return pd.read_csv(csv_path)

def main():
    """Fit TREEFARMS Rashomon sets on one dataset at several label-noise levels and save summaries.

    The dataset is selected by the 1-based ``SLURM_ARRAY_TASK_ID`` (converted to a
    0-based index into ``data_list``). One model is fit on the clean data. One more is
    fit per noise level (0.1, 0.2, 0.3), each on 150 stacked noisy draws from
    ``make_expected_noise_dataset``. For every model, the tree count and the per-tree
    ``leaves()`` values are pickled to
    ``./results/noisy_rset/varying_noise/<dataset>.pkl``.

    Raises
    ------
    RuntimeError
        If ``SLURM_ARRAY_TASK_ID`` is not set in the environment.
    """
    task_id = os.getenv('SLURM_ARRAY_TASK_ID')
    if task_id is None:
        raise RuntimeError(
            "SLURM_ARRAY_TASK_ID is not set; submit with `sbatch --array=1-N` "
            "or export it manually (1-based dataset index)."
        )
    job_id = int(task_id) - 1
    df = load_data(job_id)
    df_name = data_list[job_id]
    print(f"{df_name} dataset has {df.shape[0]} examples and {df.shape[1]} features.")
    X, y = df.iloc[:, :-1], df.iloc[:, -1]
    h = df.columns[:-1]

    # Fixed regularization. To select it by CV instead:
    # reg_lambda = tune_lambda(X.values, y.values, h, np.linspace(0.005, 0.03, 100), 5)
    reg_lambda = 0.02

    config = {
        "regularization": reg_lambda,  # penalizes trees with more leaves; a relatively high value favors sparse trees
        "rashomon_bound_multiplier": 0.05,  # controls how large the Rashomon set is
        "verbose": False
    }

    # (result-key prefix, label-flip probability); None means fit on the clean data.
    noise_levels = [('clean', None), ('noisy', 0.1), ('noisier', 0.2), ('noisiest', 0.3)]
    num_draws = 150

    tree_counts = {}
    leaves = {}
    # Fit and summarize one model at a time so each model and its (large) stacked
    # noisy dataset can be released before the next one is built.
    for label, noise in noise_levels:
        if noise is None:
            fit_X, fit_y = X, y
        else:
            fit_X, fit_y = make_expected_noise_dataset(X, y, h, noise, num_draws)
        model = TREEFARMS(config)
        model.fit(fit_X, fit_y)
        tree_counts[label] = model.get_tree_count()
        leaves[label] = [model[i].leaves() for i in range(tree_counts[label])]

    # Key order: dataset, lambda, all *_tree_count, then all *_leaves.
    entry = {'dataset': df_name, 'lambda': reg_lambda}
    entry.update({f'{label}_tree_count': count for label, count in tree_counts.items()})
    entry.update({f'{label}_leaves': tree_leaves for label, tree_leaves in leaves.items()})

    out_dir = './results/noisy_rset/varying_noise'
    os.makedirs(out_dir, exist_ok=True)

    with open(f'{out_dir}/{df_name}.pkl', 'wb') as f:
        pickle.dump(entry, f)


if __name__ == '__main__':
    main()