import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split, KFold
from gosdt import GOSDT

from tqdm import tqdm
import os
from os import chdir, devnull

from contextlib import contextmanager,redirect_stderr,redirect_stdout

# Number of evenly spaced label-noise rates swept by run_experiment.
N_NOISE_VALUES = 51
# Number of noisy relabelings stacked into each expected-noise dataset.
N_NOISE_DRAWS = 250

@contextmanager
def suppress_stdout_stderr():
    """Send Python-level stdout and stderr to the OS null device.

    Yields an ``(err, out)`` tuple holding the objects returned by the
    stderr and stdout redirectors, in that order.
    """
    with open(devnull, 'w') as sink, \
            redirect_stderr(sink) as err_stream, \
            redirect_stdout(sink) as out_stream:
        yield (err_stream, out_stream)
            
def noise_y(y, noise: float, random_seed=42):
    """Return ``y`` with a random subset of its entries XOR-flipped.

    A generator seeded with ``random_seed`` draws one uniform number per
    entry along the first axis of ``y``; entries whose draw is below
    ``noise`` are flipped. The result is cast to ``int``.
    """
    generator = np.random.default_rng(random_seed)
    draws = generator.random(y.shape[0])
    flip_mask = draws < noise
    flipped = y ^ flip_mask
    return flipped.astype(int)

def make_expected_noise_dataset(X, y, h, noise, num_draws):
    """Stack ``num_draws`` noisy relabelings of ``(X, y)`` into one dataset.

    Draw ``i`` relabels ``y`` with ``noise_y`` using seed ``11 * i``; the
    features ``X`` are repeated unchanged for every draw.

    Parameters
    ----------
    X : array of features, repeated once per draw.
    y : array of labels passed to ``noise_y``.
    h : column names for the returned feature frame.
    noise : flip threshold forwarded to ``noise_y``.
    num_draws : number of relabelings to stack; must be at least 1.

    Returns
    -------
    (features, labels) : two ``pandas.DataFrame`` objects whose rows are
        the draws concatenated in order.

    Raises
    ------
    ValueError
        If ``num_draws`` is smaller than 1.
    """
    if num_draws < 1:
        raise ValueError(f"num_draws must be at least 1, got {num_draws}")

    # Collect every draw first and concatenate once; growing the arrays
    # inside the loop re-copies all previous draws on each iteration.
    noisy_labels = [
        noise_y(y, noise, random_seed=11 * i) for i in range(num_draws)
    ]
    X_exp_all = np.concatenate([X] * num_draws)
    y_exp_all = np.concatenate(noisy_labels)
    return pd.DataFrame(X_exp_all, columns=h), pd.DataFrame(y_exp_all)
        
def tune_lambda(X: np.ndarray, y, h, reg_array: np.ndarray, k: int) -> float:
    """
    Tune regularization for sparse decision trees using k-fold cross-validation.

    Every candidate in ``reg_array`` is used to fit a GOSDT model (depth
    budget 5, small regularization allowed) on each training fold; the
    model is scored on the matching held-out fold and the scores are
    averaged over all folds.

    Parameters
    ----------
    X : feature array, indexed row-wise with the fold index arrays.
    y : label array aligned with ``X``.
    h : column names used when wrapping the folds in DataFrames.
    reg_array : candidate regularization values.
    k : number of cross-validation folds.

    Returns
    -------
    The candidate with the highest mean held-out score. Ties go to the
    candidate that appears later in ``reg_array``. ``0.0`` is returned when
    ``reg_array`` is empty.
    """
    best_score = np.float64(0)
    best_lambda = 0.

    # The shuffle is seeded, so the folds are identical for every
    # candidate; compute the index arrays once instead of per candidate.
    kf = KFold(n_splits=k, shuffle=True, random_state=0)
    folds = list(kf.split(X))

    # Loop over hyperparameters
    for a in reg_array:
        config = {
            "regularization": a,
            "depth_budget": 5,
            "allow_small_reg": True,
        }
        # One slot per fold, so the mean below averages over every fold.
        test_scores = np.zeros(len(folds))

        # Loop over folds
        for j, (train_index, test_index) in enumerate(folds):
            X_train = pd.DataFrame(X[train_index], columns=h)
            X_test = pd.DataFrame(X[test_index], columns=h)
            y_train = pd.DataFrame(y[train_index])
            y_test = pd.DataFrame(y[test_index])

            # Train the model on the training data
            model = GOSDT(config)
            with suppress_stdout_stderr():
                model.fit(X_train, y_train)

            # Store the held-out score under the fold index
            test_scores[j] = model.score(X_test, y_test)

        # Compute average test score and update best hyperparameters if necessary
        avg_score = np.mean(test_scores)
        if avg_score >= best_score:
            best_score = avg_score
            best_lambda = a

    return best_lambda

def run_experiment(X_train, X_test, y_train, y_test):
    """Compare clean-label training against expected-noise training.

    A base regularization is first chosen with ``tune_lambda`` (100
    candidates between 0.0005 and 0.03, 5 folds). Then, for each noise rate
    ``rho`` in ``N_NOISE_VALUES`` evenly spaced values from 0 to 0.3, two
    GOSDT models are fit:

    * a "clean" model on ``(X_train, y_train)`` with the regularization
      rescaled to ``lambda / (1 - 2 * rho)``;
    * a "noisy" model on the dataset built by
      ``make_expected_noise_dataset`` with ``N_NOISE_DRAWS`` relabelings,
      using the unscaled regularization.

    Parameters
    ----------
    X_train, X_test : feature DataFrames.
    y_train, y_test : label Series aligned with the feature frames.

    Returns
    -------
    (clean_train_log, noisy_train_log) : two DataFrames with one row per
        noise rate holding the regularization used, accuracies and tree
        size statistics.
    """
    base_lambda = tune_lambda(
        X_train.values, y_train.values, X_train.columns,
        np.linspace(0.0005, 0.03, 100), 5,
    )
    config = {
        'regularization': base_lambda,
        # Same depth option that tune_lambda uses, so the final models are
        # trained under the depth limit the regularization was tuned for.
        'depth_budget': 5,
        # NOTE(review): tune_lambda also sets "allow_small_reg": True but
        # this config does not -- confirm whether that is intended.
    }
    clean_config = config.copy()

    noise_values = np.linspace(0, 0.3, N_NOISE_VALUES)
    clean_rows = []
    noisy_rows = []
    for rho in tqdm(noise_values):
        # rho never exceeds 0.3, so the denominator stays at or above 0.4.
        scaled_lambda = base_lambda / (1 - 2 * rho)
        clean_config['regularization'] = scaled_lambda

        model = GOSDT(clean_config)
        with suppress_stdout_stderr():
            model = model.fit(X_train, y_train)
        clean_rows.append({
            'rho': rho,
            'lambda': scaled_lambda,
            'train_accuracy': model.score(X_train, y_train),
            'test_accuracy': model.score(X_test, y_test),
            'depth': model.max_depth(),
            'n_leaves': model.leaves(),
            'n_nodes': model.nodes(),
        })

        X_exp, y_exp = make_expected_noise_dataset(
            X_train.values, y_train.values, X_train.columns, rho, N_NOISE_DRAWS
        )
        model = GOSDT(config)
        with suppress_stdout_stderr():
            model = model.fit(X_exp, y_exp)
        noisy_rows.append({
            'rho': rho,
            'lambda': base_lambda,
            'clean_train_accuracy': model.score(X_train, y_train),
            'noisy_train_accuracy': model.score(X_exp, y_exp),
            'test_accuracy': model.score(X_test, y_test),
            'depth': model.max_depth(),
            'n_leaves': model.leaves(),
            'n_nodes': model.nodes(),
        })

    clean_train_log = pd.DataFrame(clean_rows)
    noisy_train_log = pd.DataFrame(noisy_rows)
    return clean_train_log, noisy_train_log

# Dataset names, in the order they are indexed by load_data.
data_list = [
    'amsterdam',
    'NIJ_Recidivism',
    'broward',
    'compas',
    'fico',
    'bankfull',
    'australian_credit',
    'german_credit',
    'GiveMeSomeCredit',
    'polish_companies',
    'telco_churn',
    'iranian_churn',
    'occupancy_detection',
    'car_evaluation',
    'monks2',
    'monks1',
    'monks3',
    'bar7',
    'bcw_bin',
    'carryout_takeaway',
    'restaurant_20',
    'bar',
    'coffee_house',
]

def load_data(job_id: int) -> pd.DataFrame:
    """
    Read the binarized CSV of the dataset at position ``job_id`` in ``data_list``.

    ``job_id`` is used directly as a list index, i.e. it is zero-based.
    ``main`` derives it by subtracting 1 from the SLURM array task id, so
    for N datasets the job must be submitted with ``--array=1-N``.
    """
    dataset_name = data_list[job_id]
    print(f"Loading {dataset_name} dataset.")
    csv_path = f'../datasets/binarized/{dataset_name}.csv'
    return pd.read_csv(csv_path)

def main():
    """Run the noise experiment for the dataset selected by the SLURM task id.

    Reads the 1-based ``SLURM_ARRAY_TASK_ID`` environment variable, loads
    the matching dataset, treats the last column as the label, makes an
    80/20 train/test split, runs ``run_experiment`` and pickles both result
    logs under ``./results/expected_noise/``.

    Raises
    ------
    RuntimeError
        If ``SLURM_ARRAY_TASK_ID`` is not set.
    ValueError
        If the task id does not map to an entry of ``data_list``.
    """
    task_id = os.getenv('SLURM_ARRAY_TASK_ID')
    if task_id is None:
        raise RuntimeError(
            "SLURM_ARRAY_TASK_ID is not set; submit this script with "
            f"sbatch --array=1-{len(data_list)}."
        )
    job_id = int(task_id) - 1
    # Reject ids outside 1..N explicitly: a task id of 0 would otherwise
    # wrap around to the last dataset through negative indexing.
    if not 0 <= job_id < len(data_list):
        raise ValueError(
            f"SLURM_ARRAY_TASK_ID must be between 1 and {len(data_list)}, "
            f"got {task_id}."
        )

    df = load_data(job_id)
    df_name = data_list[job_id]
    X, y = df.iloc[:, :-1], df.iloc[:, -1]
    # Report the number of feature columns, excluding the label column.
    print(f"{df_name} dataset has {df.shape[0]} examples and {X.shape[1]} features.")

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    clean_train_log, noisy_train_log = run_experiment(X_train, X_test, y_train, y_test)

    # exist_ok makes a separate existence check unnecessary.
    base_dir = f'./results/expected_noise/values{N_NOISE_VALUES}_{N_NOISE_DRAWS}/'
    os.makedirs(base_dir, exist_ok=True)

    clean_train_log.to_pickle(base_dir + f'{df_name}_clean_d5.pkl')
    noisy_train_log.to_pickle(base_dir + f'{df_name}_noise_d5.pkl')
    

# Run the experiment only when this file is executed as a script.
if __name__ == "__main__":
    main()