## imports
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from preprocessing import *
from sklearn.metrics import balanced_accuracy_score, f1_score, roc_auc_score, average_precision_score, confusion_matrix
import random
from ComplexityRejection import ComplRej
from ULER import ULER
import os
from PredictionAmbiguity import AmbRejClas, AmbRejReg
from FaithfulnessRejection import FaithRej
from StabilityRejection import StabRej
from RandomRejection import RandRej
from NoveltyRejection import NovRej
from PASTA import PASTA
import click
import warnings
# Silence every Python warning for the whole process.
# NOTE(review): this also hides any pandas/sklearn warnings raised by the experiment
# below; consider narrowing the filter to specific categories.
warnings.filterwarnings("ignore")

# Result files appended to by compute_results: one row per method, hyper-parameter
# combination and rejection rate.
_RESULTS_DIR = './results/Q1'
_BASELINES_CSV = './results/Q1/baselines_rr.csv'
_PASTA_CSV = './results/Q1/PASTARej_rr.csv'
_ULER_CSV = './results/Q1/LtX_rr.csv'


def _rejection_rates():
    """Return the rejection rates swept by every method: 0.01, 0.02, ..., 0.25."""
    return np.round(np.linspace(0.01, 0.25, 25), 2)


def _rejection_metrics(y_true, y_pred):
    """Return ``(tn, fp, fn, tp, f1, balanced_accuracy)`` for binary 0/1 reject labels.

    ``labels=[0, 1]`` forces a 2x2 confusion matrix, so the four-way unpacking
    cannot fail when only one class occurs in ``y_true`` and ``y_pred``.
    """
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    f1 = f1_score(y_true, y_pred)
    bacc = balanced_accuracy_score(y_true, y_pred)
    return tn, fp, fn, tp, f1, bacc


def _append_row(path, values, columns):
    """Append one result row to the CSV at ``path``.

    The header is written only when the file does not exist yet.
    """
    row = pd.DataFrame(data=np.array(values).reshape(1, -1), columns=columns)
    if os.path.exists(path):
        row.to_csv(path, mode='a', index=False, header=False)
    else:
        row.to_csv(path, mode='w', index=False, header=True)


def _sweep(predict_at, y_true, head, params, columns, path):
    """Score ``predict_at(rr)`` against ``y_true`` for every rejection rate ``rr``.

    One row per rate is appended to ``path``, laid out as
    ``head + [rr] + params + [tn, fp, fn, tp, f1, balanced_accuracy]``.
    That layout must match the order of ``columns``.
    """
    for rr in _rejection_rates():
        y_pred = predict_at(rr)
        metrics = _rejection_metrics(y_true, y_pred)
        _append_row(path, [*head, rr, *params, *metrics], columns)


def _set_rate_and_reject(rejector, *inputs):
    """Return ``rr -> labels``.

    The callable sets ``rejector``'s rate to ``rr`` and then calls
    ``reject(*inputs)``.
    """
    def predict_at(rr):
        rejector.set_rejection_rate(rr)
        return rejector.reject(*inputs)
    return predict_at


def _threshold_reject(train_scores, test_scores):
    """Return ``rr -> labels`` that flags test scores at or above the train (1 - rr) quantile."""
    def predict_at(rr):
        threshold = np.quantile(train_scores, 1 - rr)
        return np.where(test_scores >= threshold, 1, 0)
    return predict_at


def _load_aligned_data(dataset, X, y):
    """Load the explanations and LLM labels for ``dataset`` and align them with ``X``.

    Returns ``(X, shaps, llm_labels)``:
    - All three are restricted to the ids present in every source, sorted by id.
    - ``X`` is a copy with ``y`` stored as the ``'pred_labels'`` column and
      rows containing NaN dropped.
    - ``shaps`` has an extra trailing ``'labels'`` column: 0 when the LLM
      ``'label'`` is >= 3, else 1.
    """
    shaps = pd.read_csv(f'./explanations/{dataset}.csv')
    shaps.set_index("id", inplace=True)
    # Drop the first three columns after 'id' (presumably metadata, not attributions — TODO confirm).
    shaps = shaps.iloc[:, 3:].copy()
    # Attributions outside [-1, 1] are zeroed.
    # NOTE(review): this assumes such values are invalid; confirm.
    shaps[shaps > 1] = 0
    shaps[shaps < -1] = 0

    llm_labels = pd.read_csv(f'./llm_labels/{dataset}.csv')
    llm_labels.set_index("id", inplace=True)

    # Work on a copy so the caller's DataFrame is not modified.
    X = X.copy()
    X['pred_labels'] = y
    X.dropna(inplace=True)

    indexes = np.intersect1d(X.index, llm_labels.index)
    indexes = np.intersect1d(indexes, shaps.index)
    X = X.loc[indexes].sort_index()
    shaps = shaps.loc[indexes].sort_index()
    llm_labels = llm_labels.loc[indexes].sort_index()

    # 1 is the positive class for every metric computed later (f1_score default pos_label).
    # Presumably a rating below 3 marks an explanation that should be rejected.
    shaps['labels'] = np.where(llm_labels['label'].values >= 3, 0, 1)
    return X, shaps, llm_labels


def _split_features(X, idx):
    """Return ``(features, targets)`` for rows ``idx`` of ``X``.

    ``targets`` is the ``'pred_labels'`` column; ``features`` is every other column.
    """
    # X has already had NaN rows dropped; fillna(-1) is kept as a safety net.
    part = X.loc[idx].fillna(-1)
    targets = part['pred_labels'].values.copy()
    return part.drop(columns=['pred_labels']), targets


def _split_explanations(shaps, idx):
    """Return ``(explanations, labels)`` for rows ``idx`` of ``shaps``.

    ``labels`` comes from the trailing ``'labels'`` column, which is excluded
    from ``explanations``.
    """
    part = shaps.loc[idx]
    return part.iloc[:, :-1], part['labels'].values


def compute_results(dataset, X, y, function, categorical, iteration, task, wrapper):
    """Run the Q1 rejection experiment for one dataset and one iteration.

    The explanations are split into stratified train (~43.75%), validation
    (~6.25%) and test (50%) sets. Every rejector is fit on train (and validation
    where applicable) and evaluated on test over rejection rates 0.01-0.25.
    Each result row is appended to a CSV file:
    - ``./results/Q1/baselines_rr.csv``: RandRej, PredAmb, NovRej_{X},
      NovRej_{Z}, FaithRej, StabRej and ComplRej.
    - ``./results/Q1/PASTARej_rr.csv``: PASTA grid over alpha, beta and gamma.
    - ``./results/Q1/LtX_rr.csv``: ULER grid over kernel, C, k and eps.

    Args:
        dataset: Name used to locate ``./explanations/{dataset}.csv`` and
            ``./llm_labels/{dataset}.csv``; also forwarded to FaithRej,
            StabRej and PASTA.
        X: Feature DataFrame. It is joined to both CSVs through its index
            (the CSVs' ``id`` column). It is not modified.
        y: Targets stored as ``'pred_labels'`` and used to fit the
            prediction-ambiguity rejector.
        function: Prediction callable forwarded to AmbRej* and FaithRej.
        categorical: Forwarded to ``FaithRej.fit``.
        iteration: ``random_state`` of both splits and ``seed`` of every
            rejector; it is also recorded in each row.
        task: ``'regression'`` uses AmbRejReg on encoded features. Any other
            value uses AmbRejClas on raw features.
        wrapper: Object exposing ``encoder.transform`` used to encode
            features, or ``None`` to use raw features.
    """
    X, shaps, llm_labels = _load_aligned_data(dataset, X, y)

    # Stratified 50/50 train/test split on the explanation labels.
    # Then 1/8 of the train half is held out for validation.
    idx_train, idx_test = train_test_split(shaps.index, test_size=0.5, random_state=iteration, stratify=shaps['labels'])
    idx_train, idx_val = train_test_split(idx_train, test_size=1/8, random_state=iteration, stratify=shaps.loc[idx_train]['labels'])

    shaps_train, train_expl_labels = _split_explanations(shaps, idx_train)
    shaps_val, val_expl_labels = _split_explanations(shaps, idx_val)
    shaps_test, test_expl_labels = _split_explanations(shaps, idx_test)

    Xtr, ytrain = _split_features(X, idx_train)
    Xval, _ = _split_features(X, idx_val)
    Xte, _ = _split_features(X, idx_test)

    # Columns after the first two of the LLM-label file, passed to ULER.fit.
    # Presumably per-criterion ratings — TODO confirm.
    features_matrix = llm_labels.loc[idx_train].iloc[:, 2:].values

    contamination = train_expl_labels.mean()
    batch_size = 128
    epochs = 100
    if wrapper is not None:
        Xtr_enc = wrapper.encoder.transform(Xtr)
        Xva_enc = wrapper.encoder.transform(Xval)
        Xte_enc = wrapper.encoder.transform(Xte)
    else:
        Xtr_enc, Xva_enc, Xte_enc = Xtr, Xval, Xte

    os.makedirs(_RESULTS_DIR, exist_ok=True)
    head = [dataset, contamination, iteration]

    ## Q1 — standard rejection baselines
    columns = ["dataset", "contamination", "iteration", "method", "rr", "k", "val_AUROC","tn","fp","fn","tp","f1", "balanced_accuracy"]

    rand_rejector = RandRej(seed=iteration, contamination=contamination)
    rand_rejector.fit(shaps_train)
    _sweep(_set_rate_and_reject(rand_rejector, shaps_test), test_expl_labels,
           head + ["RandRej"], [0, 0], columns, _BASELINES_CSV)

    # Fit and reject in the same feature space: encoded features for regression,
    # raw features otherwise.
    if task == 'regression':
        amb_rejector = AmbRejReg(func=function, seed=iteration, contamination=contamination)
        amb_rejector.fit(Xtr_enc, ytrain, Xtr_enc, Xva_enc)
        amb_test = Xte_enc
    else:
        amb_rejector = AmbRejClas(func=function, seed=iteration, contamination=contamination)
        amb_rejector.fit(Xtr, ytrain)
        amb_test = Xte
    _sweep(_set_rate_and_reject(amb_rejector, amb_test), test_expl_labels,
           head + ["PredAmb"], [0, 0], columns, _BASELINES_CSV)

    # Novelty rejection in input space (X), then in explanation space (Z).
    novelty_spaces = (
        ("NovRej_{X}", Xtr_enc, Xva_enc, Xte_enc),
        ("NovRej_{Z}", shaps_train, shaps_val, shaps_test),
    )
    for method, train_in, val_in, test_in in novelty_spaces:
        for k in [1, 10, 100]:
            nov_rejector = NovRej(k=k, seed=iteration, contamination=contamination)
            nov_rejector.fit(train_in, train_expl_labels)
            val_auroc = roc_auc_score(val_expl_labels, nov_rejector.score(val_in))
            _sweep(_set_rate_and_reject(nov_rejector, test_in), test_expl_labels,
                   head + [method], [k, val_auroc], columns, _BASELINES_CSV)

    ## Explanation-based rejection baselines
    faith_rejector = FaithRej(dataset=dataset, func=function, seed=iteration, contamination=contamination)
    faith_rejector.fit(pd.DataFrame(Xtr), shaps_train, categorical)

    def faith_at(rr):
        # A fresh DataFrame wrapper per call, as reject() may modify its input.
        faith_rejector.set_rejection_rate(rr)
        return faith_rejector.reject(pd.DataFrame(Xte), shaps_test)

    _sweep(faith_at, test_expl_labels, head + ["FaithRej"], [0, 0], columns, _BASELINES_CSV)

    stab_rejector = StabRej(dataset=dataset, seed=iteration, contamination=contamination)
    stab_rejector.fit(shaps_train)
    _sweep(_set_rate_and_reject(stab_rejector, shaps_test), test_expl_labels,
           head + ["StabRej"], [0, 0], columns, _BASELINES_CSV)

    complexity_rejector = ComplRej(seed=iteration, contamination=contamination)
    complexity_rejector.fit(shaps_train)
    _sweep(_set_rate_and_reject(complexity_rejector, shaps_test), test_expl_labels,
           head + ["ComplRej"], [0, 0], columns, _BASELINES_CSV)

    ## PASTA hyper-parameter grid
    name = "PASTARej"
    pasta_columns = ["dataset", "contamination", "iteration", "method", "rr", "alpha", "beta", "gamma", "val_AUROC",
                     "tn","fp","fn","tp","f1", "balanced_accuracy"]
    for alpha in [0.1, 1, 10]:
        for beta in [0.001, 0.01, 0.1]:
            for gamma in [0.01, 0.1, 1]:
                pasta_rejector = PASTA(dataset=dataset, name=name, seed=iteration, batch_size=batch_size, epochs=epochs,
                                       contamination=contamination, alpha=alpha, beta=beta, gamma=gamma)
                pasta_rejector.fit(shaps_train, train_expl_labels, shaps_val, val_expl_labels)
                val_auroc = roc_auc_score(val_expl_labels, pasta_rejector.score(shaps_val))
                _sweep(_set_rate_and_reject(pasta_rejector, shaps_test), test_expl_labels,
                       head + [name], [alpha, beta, gamma, val_auroc], pasta_columns, _PASTA_CSV)

    ## ULER hyper-parameter grid
    uler_columns = ["dataset", "contamination", "iteration", "method", "rr", "kernel","C", "k","eps","val_AUROC",
                    "tn","fp","fn","tp","f1", "balanced_accuracy"]
    for kernel in ['linear', 'poly', 'rbf']:
        for C in [0.01, 1, 100]:
            for k in [3, 10, 20]:
                for eps in [0.01, 0.1, 1]:
                    uler = ULER(k=k, eps=eps, C=C, kernel=kernel, seed=iteration, contamination=contamination)
                    uler.fit(shaps_train, train_expl_labels, features_matrix)
                    train_scores = uler.score(shaps_train)
                    val_auroc = roc_auc_score(val_expl_labels, uler.score(shaps_val))
                    # Test scores do not depend on rr, so they are computed once per model.
                    test_scores = uler.score(shaps_test)
                    _sweep(_threshold_reject(train_scores, test_scores), test_expl_labels,
                           head + ["ULER"], [kernel, C, k, eps, val_auroc], uler_columns, _ULER_CSV)

    print(f"Finished Q1 dataset {dataset}, iteration {iteration}")
    print()

@click.command()
@click.option('--dataset', default='compas', help='Dataset to use')
@click.option('--iteration', default=0, help='Iteration index (random_state of the splits and seed of every rejector)')
@click.option('--seed', default=9, help='Seed for reproducibility')
def run_experiment(dataset, iteration, seed):
    """Fit an SVM model on DATASET and run the Q1 rejection experiment on its explanation set."""
    # Seed the global RNGs from --seed so the option also controls the subsample below.
    # The default (9) reproduces the previously hard-coded value.
    np.random.seed(seed)
    random.seed(seed)
    Xtrain, Xexpl, ytrain, yexpl, categorical, _, task = load_data(dataset, seed)
    if len(Xexpl) > 2000:
        index = np.random.choice(len(Xexpl), 2000, replace=False)
        Xexpl = Xexpl.iloc[index]
        # Select positionally, matching Xexpl.iloc above.
        # A Series indexed with [] would select by label instead.
        yexpl = yexpl.iloc[index] if isinstance(yexpl, pd.Series) else yexpl[index]
    wrapper = OHEWrapper(categorical, task, seed, model='svm')
    wrapper.fit(Xtrain, ytrain)
    compute_results(dataset, Xexpl, yexpl, wrapper.predict, categorical, iteration, task, wrapper)
    
    
if __name__ == '__main__':
    # click parses --dataset/--iteration/--seed from the command line and invokes the command.
    run_experiment()
    