## imports
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from preprocessing import *
from sklearn.metrics import recall_score, precision_score, balanced_accuracy_score, f1_score, roc_auc_score, average_precision_score
from sklearn.svm import SVC
import random
from ComplexityRejection import ComplRej
import os
from PredictionAmbiguity import AmbRejClas, AmbRejReg
from FaithfulnessRejection import FaithRej
from StabilityRejection import StabRej
from RandomRejection import RandRej
from NoveltyRejection import NovRej
from PASTA import PASTA
import click
import warnings
from ULER import ULER
warnings.filterwarnings("ignore")

def _rejection_metrics(true_labels, pred_labels, scores):
    """Return [precision, recall, f1, balanced accuracy, AUROC, AUPR] as a list.

    The first four metrics are computed from the hard ``pred_labels``; AUROC
    and AUPR are computed from the continuous ``scores``.
    """
    return [
        precision_score(true_labels, pred_labels, zero_division=0),
        recall_score(true_labels, pred_labels),
        f1_score(true_labels, pred_labels),
        balanced_accuracy_score(true_labels, pred_labels),
        roc_auc_score(true_labels, scores),
        average_precision_score(true_labels, scores),
    ]


def _thresholded_metrics(true_labels, scores, threshold):
    """Binarise ``scores`` at ``threshold`` (>= means 1) and return the six metrics."""
    pred_labels = np.where(scores >= threshold, 1, 0)
    return _rejection_metrics(true_labels, pred_labels, scores)


def _append_result_row(path, columns, values):
    """Append one result row to the CSV at ``path``; write the header only for a new file."""
    # Create the output directory on demand so a fresh checkout does not crash
    # on the very first write.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    # Going through np.array stringifies the mixed-type row, which keeps the
    # CSV formatting identical to previously written result files.
    row = pd.DataFrame(data=np.array(values).reshape(1, -1), columns=columns)
    row.to_csv(path, mode='a', index=False, header=not os.path.exists(path))


def _features_and_targets(X, idx):
    """Return (feature frame without 'pred_labels', copy of the 'pred_labels' values) for rows ``idx``."""
    subset = X.loc[idx]
    targets = subset['pred_labels'].values.copy()
    return subset.drop(columns=['pred_labels']), targets


def compute_results(dataset, X, y, function, categorical, iteration, task, wrapper):
    """Run the Q1 comparison of rejectors for one dataset and one iteration.

    Reads ``./explanations/{dataset}.csv`` and ``./llm_labels/{dataset}.csv``
    (both indexed by their ``id`` column), keeps the ids present in both files
    and in ``X``, derives a binary explanation label (1 when the ``label``
    column of the LLM file is below 3, else 0), and splits the ids into
    train / validation / test with stratification on that label. Every
    rejector is fitted on the train part and evaluated on the test part; one
    row per configuration is appended to ``./results/Q1/baselines.csv``,
    ``./results/Q1/PASTARej.csv`` or ``./results/Q1/LtX.csv``.

    Parameters
    ----------
    dataset : str
        Dataset name; used to build the input file names and stored in each row.
    X : pandas.DataFrame
        Feature rows whose index is matched against the ``id`` columns.
        The frame is copied, so the caller's object is not modified.
    y : array-like
        Values stored next to ``X`` as the ``pred_labels`` column.
    function : callable
        Passed as ``func`` to the ambiguity and faithfulness rejectors.
    categorical :
        Forwarded unchanged to ``FaithRej.fit``.
    iteration : int
        Used as ``random_state`` for both splits and as ``seed`` of every rejector.
    task : str
        ``'regression'`` selects ``AmbRejReg``; any other value selects ``AmbRejClas``.
    wrapper :
        Object exposing ``encoder.transform``; when ``None`` the raw feature
        frames are used in place of the encoded ones.

    Returns
    -------
    None
        Results are written to disk and a completion message is printed.
    """
    # Drop the first three non-index columns, then zero out values outside [-1, 1].
    shaps = pd.read_csv(f'./explanations/{dataset}.csv').set_index("id")
    shaps = shaps.iloc[:, 3:].copy()
    shaps[shaps > 1] = 0
    shaps[shaps < -1] = 0

    llm_labels = pd.read_csv(f'./llm_labels/{dataset}.csv').set_index("id")

    # Work on a copy so the caller's frame is not mutated.
    X = X.copy()
    X['pred_labels'] = y
    X = X.dropna()

    # Keep only the ids available in all three sources, in the same sorted order.
    common = np.intersect1d(X.index, llm_labels.index)
    common = np.intersect1d(common, shaps.index)
    X = X.loc[common].sort_index()
    shaps = shaps.loc[common].sort_index()
    llm_labels = llm_labels.loc[common].sort_index()

    expl_labels = pd.Series(np.where(llm_labels['label'].values >= 3, 0, 1), index=shaps.index)

    # 50% test; of the remaining half, 1/8 validation and 7/8 train.
    idx_train, idx_test = train_test_split(shaps.index, test_size=0.5, random_state=iteration, stratify=expl_labels)
    idx_train, idx_val = train_test_split(idx_train, test_size=1/8, random_state=iteration, stratify=expl_labels.loc[idx_train])

    shaps_train = shaps.loc[idx_train]
    shaps_val = shaps.loc[idx_val]
    shaps_test = shaps.loc[idx_test]
    train_expl_labels = expl_labels.loc[idx_train].values
    val_expl_labels = expl_labels.loc[idx_val].values
    test_expl_labels = expl_labels.loc[idx_test].values

    # No NaN filling is needed here: dropna above already removed every row
    # that had a missing value.
    Xtr, ytrain = _features_and_targets(X, idx_train)
    Xval, _ = _features_and_targets(X, idx_val)
    Xte, _ = _features_and_targets(X, idx_test)

    features_matrix = llm_labels.loc[idx_train].iloc[:, 2:].values

    contamination = train_expl_labels.mean()
    batch_size = 128
    epochs = 100
    if wrapper is not None:
        Xtr_enc = wrapper.encoder.transform(Xtr)
        Xva_enc = wrapper.encoder.transform(Xval)
        Xte_enc = wrapper.encoder.transform(Xte)
    else:
        Xtr_enc, Xva_enc, Xte_enc = Xtr, Xval, Xte

    metric_columns = ["precision", "recall", "f1", "balanced_accuracy", "AUROC", "AUPR"]
    train_metric_columns = ["train_" + name for name in metric_columns]

    ## Q1
    baseline_path = './results/Q1/baselines.csv'
    baseline_columns = ["dataset", "contamination", "iteration", "method", "k", "val_AUROC"] + metric_columns

    def record_baseline(method, k, val_auroc, pred_labels, scores):
        # One baselines.csv row: identification fields followed by the test metrics.
        row = [dataset, contamination, iteration, method, k, val_auroc]
        row += _rejection_metrics(test_expl_labels, pred_labels, scores)
        _append_result_row(baseline_path, baseline_columns, row)

    ## Standard rejection setting
    rand_rejector = RandRej(seed=iteration, contamination=contamination)
    rand_rejector.fit(shaps_train)
    rand_scores = rand_rejector.score(shaps_test)
    rand_labels = rand_rejector.reject(shaps_test)
    record_baseline("RandRej", 0, 0, rand_labels, rand_scores)

    # NOTE(review): the regression rejector is fitted on encoded features but
    # scored on raw ones, while the classification rejector is fitted on raw
    # features but scored on encoded ones. This mirrors the existing behaviour
    # exactly; confirm against AmbRejReg / AmbRejClas whether it is intended.
    if task == 'regression':
        amb_rejector = AmbRejReg(func=function, seed=iteration, contamination=contamination)
        amb_rejector.fit(Xtr_enc, ytrain, Xtr_enc, Xva_enc)
        amb_test_input = Xte
    else:
        amb_rejector = AmbRejClas(func=function, seed=iteration, contamination=contamination)
        amb_rejector.fit(Xtr, ytrain)
        amb_test_input = Xte_enc
    amb_scores = amb_rejector.score(amb_test_input)
    amb_labels = amb_rejector.reject(amb_test_input)
    record_baseline("PredAmb", 0, 0, amb_labels, amb_scores)

    # Novelty rejection, first on the (encoded) feature space, then on the
    # explanations (explanation-based rejection baselines).
    novelty_spaces = [
        ("NovRej_{X}", Xtr_enc, Xva_enc, Xte_enc),
        ("NovRej_{Z}", shaps_train, shaps_val, shaps_test),
    ]
    for method, nov_train, nov_val, nov_test in novelty_spaces:
        for k in [1, 10, 100]:
            nov_rejector = NovRej(k=k, seed=iteration, contamination=contamination)
            nov_rejector.fit(nov_train, train_expl_labels)
            nov_val_scores = nov_rejector.score(nov_val)
            nov_val_auroc = roc_auc_score(val_expl_labels, nov_val_scores)
            nov_scores = nov_rejector.score(nov_test)
            nov_labels = nov_rejector.reject(nov_test)
            record_baseline(method, k, nov_val_auroc, nov_labels, nov_scores)

    faith_rejector = FaithRej(dataset=dataset, func=function, seed=iteration, contamination=contamination)
    faith_rejector.fit(pd.DataFrame(Xtr), shaps_train, categorical)
    faith_scores = faith_rejector.score(Xte, shaps_test)
    faith_labels = faith_rejector.reject(Xte, shaps_test)
    record_baseline("FaithRej", 0, 0, faith_labels, faith_scores)

    stab_rejector = StabRej(dataset=dataset, seed=iteration, contamination=contamination)
    stab_rejector.fit(shaps_train)
    stab_scores = stab_rejector.score(shaps_test)
    stab_labels = stab_rejector.reject(shaps_test)
    record_baseline("StabRej", 0, 0, stab_labels, stab_scores)

    compl_rejector = ComplRej(seed=iteration, contamination=contamination)
    compl_rejector.fit(shaps_train)
    compl_scores = compl_rejector.score(shaps_test)
    compl_labels = compl_rejector.reject(shaps_test)
    record_baseline("ComplRej", 0, 0, compl_labels, compl_scores)

    # PASTA grid: the decision threshold is chosen on the validation scores and
    # then applied to both the test and the train scores.
    pasta_columns = (["dataset", "contamination", "iteration", "method", "alpha", "beta", "gamma", "val_AUROC"]
                     + metric_columns + train_metric_columns)
    for alpha in [0.1, 1, 10]:
        for beta in [0.001, 0.01, 0.1]:
            for gamma in [0.01, 0.1, 1]:
                pasta_rejector = PASTA(dataset=dataset, seed=iteration, batch_size=batch_size, epochs=epochs,
                                       contamination=contamination, alpha=alpha, beta=beta, gamma=gamma)
                pasta_rejector.fit(shaps_train, train_expl_labels, shaps_val, val_expl_labels)
                pasta_val_scores = pasta_rejector.score(shaps_val)
                pasta_threshold = find_best_threshold(pasta_val_scores, val_expl_labels)
                pasta_val_auroc = roc_auc_score(val_expl_labels, pasta_val_scores)
                pasta_test_scores = pasta_rejector.score(shaps_test)
                pasta_test_metrics = _thresholded_metrics(test_expl_labels, pasta_test_scores, pasta_threshold)
                pasta_train_scores = pasta_rejector.score(shaps_train)
                pasta_train_metrics = _thresholded_metrics(train_expl_labels, pasta_train_scores, pasta_threshold)
                row = [dataset, contamination, iteration, "PASTARej", alpha, beta, gamma, pasta_val_auroc]
                row += pasta_test_metrics + pasta_train_metrics
                _append_result_row('./results/Q1/PASTARej.csv', pasta_columns, row)

    # ULER grid, using the same validation-based thresholding as above.
    uler_columns = (["dataset", "contamination", "iteration", "method", "kernel", "C", "k", "eps", "val_AUROC"]
                    + metric_columns + train_metric_columns)
    for kernel in ['linear', 'poly', 'rbf']:
        for C in [0.01, 1, 100]:
            for k in [3, 10, 20]:
                for eps in [0.01, 0.1, 1]:
                    uler_rejector = ULER(C=C, kernel=kernel, eps=eps, k=k, seed=iteration, contamination=contamination)
                    uler_rejector.fit(shaps_train, train_expl_labels, features_matrix)
                    uler_test_scores = uler_rejector.score(shaps_test)
                    uler_val_scores = uler_rejector.score(shaps_val)
                    uler_threshold = find_best_threshold(uler_val_scores, val_expl_labels)
                    uler_val_auroc = roc_auc_score(val_expl_labels, uler_val_scores)
                    uler_test_metrics = _thresholded_metrics(test_expl_labels, uler_test_scores, uler_threshold)
                    uler_train_scores = uler_rejector.score(shaps_train)
                    uler_train_metrics = _thresholded_metrics(train_expl_labels, uler_train_scores, uler_threshold)
                    row = [dataset, contamination, iteration, "ULER", kernel, C, k, eps, uler_val_auroc]
                    row += uler_test_metrics + uler_train_metrics
                    _append_result_row('./results/Q1/LtX.csv', uler_columns, row)

    print(f"Finished Q1 dataset {dataset}, iteration {iteration}")
    print()

# NOTE(review): the '--iteration' help text says "Number of iterations", but the
# value is forwarded as a single iteration index to compute_results.
@click.command()
@click.option('--dataset', default='compas', help='Dataset to use')
@click.option('--iteration', default=0, help='Number of iterations')
@click.option('--seed', default=9, help='Seed for reproducibility')
def run_experiment(dataset, iteration, seed):
    """Load a dataset, fit the SVM wrapper and run the Q1 evaluation once.

    ``seed`` seeds the global NumPy and Python random generators and is passed
    to ``load_data`` and ``OHEWrapper``. ``iteration`` is forwarded to
    ``compute_results``. When more than 2000 explanation rows are available, a
    random subset of 2000 rows (without replacement) is used.
    """
    # Honour the --seed option; the default (9) matches the value that was
    # previously hard-coded here.
    np.random.seed(seed)
    random.seed(seed)
    Xtrain, Xexpl, ytrain, yexpl, categorical, cat_idx, task = load_data(dataset, seed)
    if len(Xexpl) > 2000:
        index = np.random.choice(len(Xexpl), 2000, replace=False)
        Xexpl = Xexpl.iloc[index]
        # Select the targets positionally as well, so they stay aligned with
        # Xexpl.iloc[index] even when yexpl is a pandas object whose index is
        # not 0..n-1 (plain [] on a Series would select by label).
        yexpl = yexpl.iloc[index] if hasattr(yexpl, 'iloc') else yexpl[index]
    wrapper = OHEWrapper(categorical, task, seed, model='svm')
    wrapper.fit(Xtrain, ytrain)
    function = wrapper.predict
    compute_results(dataset, Xexpl, yexpl, function, categorical, iteration, task, wrapper)
    
    
# Script entry point: run_experiment is a click command, so calling it without
# arguments lets click read the options from the command line.
if __name__ == '__main__':
    run_experiment()
    