## imports
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from preprocessing import *
from sklearn.metrics import recall_score, precision_score, balanced_accuracy_score, f1_score, roc_auc_score, average_precision_score
from sklearn.svm import SVC
import random
import os
import click
import warnings
from ULER import ULER
from ULER_XZ import ULER_XZ
from ULER_YZ import ULER_YZ
from ULER_XYZ import ULER_XYZ
# Suppress all warnings for the whole process.
# NOTE(review): this also hides any library warnings (e.g. from pandas or
# scikit-learn) raised by the code below — consider narrowing the filter.
warnings.filterwarnings("ignore")

def _append_result_row(path, columns, values):
    """Append a single result row to the CSV file at ``path``.

    The parent directory is created if missing, and the header is written
    only when the file does not exist yet, so successive calls accumulate
    rows in one file instead of overwriting it.

    Args:
        path: Destination CSV path.
        columns: Column names; written as the header on the first call.
        values: One value per entry of ``columns``.
    """
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    write_header = not os.path.exists(path)
    row = pd.DataFrame(data=np.array(values).reshape(1, -1), columns=columns)
    # mode='a' also creates the file when it does not exist yet.
    row.to_csv(path, mode='a', index=False, header=write_header)


def _take_split(X, shaps, idx):
    """Select rows ``idx`` and separate features, explanations and labels.

    Args:
        X: Feature frame that still carries the ``pred_labels`` column.
        shaps: Explanation frame whose last column is ``labels``.
        idx: Row labels to select.

    Returns:
        ``(features, explanations, explanation_labels)``: ``X`` rows without
        ``pred_labels``, ``shaps`` rows without the trailing ``labels``
        column, and the ``labels`` values as a NumPy array.
    """
    features = X.loc[idx].drop(columns=['pred_labels'])
    part = shaps.loc[idx]
    return features, part.iloc[:, :-1], part['labels'].values


def _evaluate_filter(test_scores, val_scores, test_labels, val_labels):
    """Pick a threshold on validation scores and compute test metrics.

    The threshold comes from ``find_best_threshold`` on the validation
    scores; test scores ``>=`` that threshold are predicted as class 1.

    Returns:
        ``[val_AUROC, precision, recall, f1, balanced_accuracy, AUROC, AUPR]``,
        in the order of the corresponding result-CSV columns.
    """
    threshold = find_best_threshold(val_scores, val_labels)
    val_auroc = roc_auc_score(val_labels, val_scores)
    test_pred = np.where(test_scores >= threshold, 1, 0)
    return [
        val_auroc,
        precision_score(test_labels, test_pred, zero_division=0),
        recall_score(test_labels, test_pred),
        f1_score(test_labels, test_pred),
        balanced_accuracy_score(test_labels, test_pred),
        roc_auc_score(test_labels, test_scores),
        average_precision_score(test_labels, test_scores),
    ]


def compute_results(dataset, X, y, function, categorical, iteration, task, wrapper):
    """Benchmark the ULER explanation filters on one split and log the metrics.

    Loads ``./explanations/{dataset}.csv`` and ``./llm_labels/{dataset}.csv``
    (both indexed by ``id``), aligns them with ``X`` on their shared ids, and
    makes a stratified train/validation/test split. For every kernel/C pair
    it fits ULER, ULER_XZ, ULER_YZ and ULER_XYZ, then appends one metrics row
    per model to ``./csvFiles/Q2/LtX.csv``.

    Args:
        dataset: Dataset name; selects the input CSVs and is logged per row.
        X: Feature DataFrame indexed by the same ids as the CSVs. Not modified.
        y: Values stored as the ``pred_labels`` column; rows of ``X`` where
            any value (including this one) is missing are dropped.
        function: Prediction function applied to the feature splits. When
            ``task == 'classification'``, its output is indexed ``[:, 1]``.
        categorical: Unused here; kept for interface compatibility.
        iteration: Seed for the splits and the filters; logged per row.
        task: Task type string; see ``function``.
        wrapper: Object whose ``encoder.transform`` encodes the features for
            ULER_XZ / ULER_XYZ, or ``None`` to pass the raw features.
    """
    results_path = './csvFiles/Q2/LtX.csv'
    columns = ["dataset", "contamination", "iteration", "method", "kernel", "C", "k", "eps", "val_AUROC",
               "precision", "recall", "f1", "balanced_accuracy", "AUROC", "AUPR"]

    shaps = pd.read_csv(f'./explanations/{dataset}.csv')
    shaps.set_index("id", inplace=True)
    # Keep only the columns after the first three (presumably per-row
    # metadata — TODO confirm against the explanations CSV layout).
    shaps = shaps.iloc[:, 3:]
    # Zero out attribution values outside [-1, 1].
    shaps[shaps > 1] = 0
    shaps[shaps < -1] = 0

    llm_labels = pd.read_csv(f'./llm_labels/{dataset}.csv')
    llm_labels.set_index("id", inplace=True)

    # Work on a copy so the caller's frame is not mutated.
    X = X.copy()
    X['pred_labels'] = y
    # Removes every row with a missing value, so X is NaN-free from here on.
    X.dropna(inplace=True)

    # Restrict all three sources to their shared ids, in the same order.
    indexes = np.intersect1d(X.index, llm_labels.index)
    indexes = np.intersect1d(indexes, shaps.index)
    X = X.loc[indexes].sort_index()
    shaps = shaps.loc[indexes].sort_index()
    llm_labels = llm_labels.loc[indexes].sort_index()

    # Binarize the LLM rating: >= 3 -> 0, otherwise 1 (the positive class).
    shaps['labels'] = np.where(llm_labels['label'].values >= 3, 0, 1)

    # Stratified split: 50% test; the other half is split 7:1 into train/val.
    idx_train, idx_test = train_test_split(
        shaps.index, test_size=0.5, random_state=iteration, stratify=shaps['labels'])
    idx_train, idx_val = train_test_split(
        idx_train, test_size=1/8, random_state=iteration, stratify=shaps.loc[idx_train]['labels'])

    Xtr, shaps_train, train_expl_labels = _take_split(X, shaps, idx_train)
    Xval, shaps_val, val_expl_labels = _take_split(X, shaps, idx_val)
    Xte, shaps_test, test_expl_labels = _take_split(X, shaps, idx_test)

    # LLM-label columns from the third onward, for the training rows only.
    features_matrix = llm_labels.loc[idx_train].iloc[:, 2:].values

    # Fraction of positive (label 1) explanations in the training split.
    contamination = train_expl_labels.mean()

    def encode(frame):
        return wrapper.encoder.transform(frame) if wrapper is not None else frame

    def predict(frame):
        out = function(frame)
        return out[:, 1] if task == 'classification' else out

    Xtr_enc, Xva_enc, Xte_enc = encode(Xtr), encode(Xval), encode(Xte)
    ypred_tr, ypred_val, ypred_te = predict(Xtr), predict(Xval), predict(Xte)

    def log_row(method, kernel, C, metrics):
        # None of these filters takes k or eps, so both are logged as 0.
        values = [dataset, contamination, iteration, method, kernel, C, 0, 0] + metrics
        _append_result_row(results_path, columns, values)

    # Each filter is fit exactly once per (kernel, C): nothing here depends
    # on k or eps, so looping over them would only refit identical models.
    for kernel in ['linear', 'poly', 'rbf']:
        for C in [0.01, 1, 100]:
            # NOTE(review): logged as "ULER-AUG" although augment_the_data=False
            # — confirm the intended method label.
            model = ULER(kernel=kernel, C=C, seed=iteration, contamination=contamination, augment_the_data=False)
            model.fit(shaps_train, train_expl_labels, features_matrix)
            test_scores = model.score(shaps_test)
            val_scores = model.score(shaps_val)
            log_row("ULER-AUG", kernel, C,
                    _evaluate_filter(test_scores, val_scores, test_expl_labels, val_expl_labels))

            model = ULER_XZ(kernel=kernel, C=C, seed=iteration, contamination=contamination)
            model.fit(Xtr_enc, shaps_train, train_expl_labels, features_matrix)
            test_scores = model.score(Xte_enc, shaps_test)
            val_scores = model.score(Xva_enc, shaps_val)
            log_row("ULER_XZ", kernel, C,
                    _evaluate_filter(test_scores, val_scores, test_expl_labels, val_expl_labels))

            model = ULER_YZ(kernel=kernel, C=C, seed=iteration, contamination=contamination)
            model.fit(ypred_tr, shaps_train, train_expl_labels, features_matrix)
            test_scores = model.score(ypred_te, shaps_test)
            val_scores = model.score(ypred_val, shaps_val)
            log_row("ULER_YZ", kernel, C,
                    _evaluate_filter(test_scores, val_scores, test_expl_labels, val_expl_labels))

            model = ULER_XYZ(kernel=kernel, C=C, seed=iteration, contamination=contamination)
            model.fit(Xtr_enc, ypred_tr, shaps_train, train_expl_labels, features_matrix)
            test_scores = model.score(Xte_enc, ypred_te, shaps_test)
            val_scores = model.score(Xva_enc, ypred_val, shaps_val)
            log_row("ULER_XYZ", kernel, C,
                    _evaluate_filter(test_scores, val_scores, test_expl_labels, val_expl_labels))

    print(f"Finished Q2 dataset {dataset}, iteration {iteration}")
    print()

@click.command()
@click.option('--dataset', default='compas', help='Dataset to use')
@click.option('--iteration', default=0, help='Iteration index; seeds the data splits and the filters')
@click.option('--seed', default=9, help='Seed for reproducibility')
def run_experiment(dataset, iteration, seed):
    """CLI entry point: train the SVM wrapper and run the Q2 benchmark.

    Loads ``dataset`` via ``load_data(dataset, seed)``, subsamples the
    explanation set to at most 2000 rows, fits an ``OHEWrapper`` with
    ``model='svm'`` on the training split, and hands its ``predict`` to
    ``compute_results``.

    Args:
        dataset: Dataset name passed to ``load_data`` and ``compute_results``.
        iteration: Iteration index forwarded to ``compute_results``.
        seed: Seed for the global NumPy/``random`` RNGs, ``load_data`` and
            the wrapper.
    """
    # Seed the global RNGs from --seed so the subsample below follows the option.
    np.random.seed(seed)
    random.seed(seed)
    Xtrain, Xexpl, ytrain, yexpl, categorical, cat_idx, task = load_data(dataset, seed)
    # Cap the explanation set at 2000 rows, sampled without replacement.
    if len(Xexpl) > 2000:
        index = np.random.choice(len(Xexpl), 2000, replace=False)
        Xexpl = Xexpl.iloc[index]
        # NOTE(review): positional selection assumes yexpl is an ndarray (or a
        # Series with a default RangeIndex) — confirm against load_data.
        yexpl = yexpl[index]
    wrapper = OHEWrapper(categorical, task, seed, model='svm')
    wrapper.fit(Xtrain, ytrain)
    compute_results(dataset, Xexpl, yexpl, wrapper.predict, categorical, iteration, task, wrapper)
    
    
if __name__ == "__main__":
    # Hand control to click, which parses the command-line options.
    run_experiment()
    