import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.isotonic import spearmanr
from sklearn.metrics import roc_auc_score
from ULER import ULER
from sklearn.model_selection import train_test_split
from preprocessing import *
import random
from ComplexityRejection import ComplRej
import os
from FaithfulnessRejection import FaithRej
from StabilityRejection import StabRej
import click
import warnings
warnings.filterwarnings("ignore")

def compute_results(dataset, X, y, function, categorical, iteration, task, wrapper):
    """Correlate ULER scores with faithfulness, stability and complexity scores.

    Reads ``./explanations/{dataset}.csv`` and ``./llm_labels/{dataset}.csv``,
    keeps the rows whose index is present in ``X`` and in both files, and
    splits them into train / validation / test with stratification on the
    binary explanation label (1 when the ``label`` column is below 3, else 0).
    The three baseline rejectors (``FaithRej``, ``StabRej``, ``ComplRej``) are
    fitted once on the training split and scored on the test split.  For every
    ULER hyper-parameter combination, one row holding the validation AUROC and
    the Pearson / Spearman correlations between the ULER test scores and each
    baseline is appended to ``./results/Appendix/correlation/simulated.csv``.

    Parameters
    ----------
    dataset : str
        Dataset name; selects the input CSV files and is stored in each row.
    X : pandas.DataFrame
        Instances to explain.  The frame is copied, not modified.
    y : array-like
        Labels attached to ``X``; only used to drop rows with missing values.
    function : callable
        Prediction function handed to ``FaithRej``.
    categorical
        Categorical feature description forwarded to ``FaithRej.fit``.
    iteration : int
        Random state of the splits and seed of every rejector.
    task, wrapper
        Not used by this function; kept so existing callers keep working.

    Returns
    -------
    None
        Results are written to disk and a progress message is printed.
    """
    # Local imports keep this block self-contained.  ``spearmanr`` is taken
    # from SciPy directly rather than through the incidental re-export in
    # ``sklearn.isotonic`` that the module-level import relies on.
    from itertools import product
    from scipy.stats import spearmanr

    results_path = './results/Appendix/correlation/simulated.csv'
    # NOTE(review): the "method" column previously had no matching value, so
    # the row had 14 entries for 15 columns and DataFrame construction failed.
    # 'ULER' is assumed to be the intended label - confirm.
    method = 'ULER'

    shaps = pd.read_csv(f'./explanations/{dataset}.csv')
    shaps = shaps.iloc[:, 1:]  # the first CSV column is not an attribution
    # Attributions outside [-1, 1] are reset to zero.
    shaps[shaps > 1] = 0
    shaps[shaps < -1] = 0

    llm_labels = pd.read_csv(f'./llm_labels/{dataset}.csv').set_index("id")

    # Work on a copy so the caller's frame is left untouched.  The labels are
    # attached only so that rows with a missing label are dropped as well.
    X = X.copy()
    X['pred_labels'] = y
    X = X.dropna()

    indexes = np.intersect1d(X.index, llm_labels.index)
    indexes = np.intersect1d(indexes, shaps.index)

    # No NaN can remain after dropna(), so no further imputation is needed.
    features = X.loc[indexes].sort_index().drop(columns=['pred_labels'])
    shaps = shaps.loc[indexes].sort_index()
    llm_labels = llm_labels.loc[indexes].sort_index()

    expl_labels = pd.Series(
        np.where(llm_labels['label'].values >= 3, 0, 1), index=shaps.index
    )

    # 50% test; the remaining half is split 7/8 train, 1/8 validation.
    idx_train, idx_test = train_test_split(
        shaps.index, test_size=0.5, random_state=iteration, stratify=expl_labels
    )
    idx_train, idx_val = train_test_split(
        idx_train, test_size=1/8, random_state=iteration,
        stratify=expl_labels.loc[idx_train],
    )

    shaps_train = shaps.loc[idx_train]
    shaps_val = shaps.loc[idx_val]
    shaps_test = shaps.loc[idx_test]
    train_expl_labels = expl_labels.loc[idx_train].values
    val_expl_labels = expl_labels.loc[idx_val].values

    Xtr = features.loc[idx_train]
    Xte = features.loc[idx_test]

    # Columns from the third one onwards of the label file are passed to ULER.
    features_matrix = llm_labels.loc[idx_train].iloc[:, 2:].values

    # Fraction of training explanations labelled 1.
    contamination = train_expl_labels.mean()

    faith_rejector = FaithRej(dataset=dataset, func=function, seed=iteration, contamination=contamination)
    faith_rejector.fit(pd.DataFrame(Xtr), shaps_train, categorical)
    faithfulness_test = faith_rejector.score(Xte, shaps_test)

    stab_rejector = StabRej(dataset=dataset, seed=iteration, contamination=contamination)
    stab_rejector.fit(shaps_train)
    stability_test = stab_rejector.score(shaps_test)

    compl_rejector = ComplRej(seed=iteration, contamination=contamination)
    compl_rejector.fit(shaps_train)
    complexity_test = compl_rejector.score(shaps_test)

    # Order matters: it must match the corr_*/spear_* column order below.
    reference_scores = (faithfulness_test, stability_test, complexity_test)

    columns = ["dataset", "contamination", "iteration", "method", "kernel", "C", "k", "eps", "val_AUROC",
               "corr_faith", "spear_faith",
               "corr_stab", "spear_stab",
               "corr_comp", "spear_comp"]

    # to_csv does not create missing directories.
    os.makedirs(os.path.dirname(results_path), exist_ok=True)

    grid = product(['linear', 'poly', 'rbf'], [0.01, 1, 100], [3, 10, 20], [0.01, 0.1, 1])
    for kernel, C, k, eps in grid:
        uler = ULER(C=C, kernel=kernel, eps=eps, k=k, seed=iteration, contamination=contamination)
        uler.fit(shaps_train, train_expl_labels, features_matrix)
        test_scores = uler.score(shaps_test)
        val_scores = uler.score(shaps_val)
        val_auroc = roc_auc_score(val_expl_labels, val_scores)

        row = [dataset, contamination, iteration, method, kernel, C, k, eps, val_auroc]
        for reference in reference_scores:
            row.append(np.corrcoef(test_scores, reference)[0, 1])
            row.append(spearmanr(test_scores, reference)[0])

        # Append one row per configuration; the header is written only when
        # the file does not exist yet.
        pd.DataFrame([row], columns=columns).to_csv(
            results_path, mode='a', index=False,
            header=not os.path.exists(results_path),
        )

    print(f"Finished APP correlation dataset {dataset}, iteration {iteration}")
    print()

@click.command()
@click.option('--dataset', default='compas', help='Dataset to use')
@click.option('--iteration', default=0, help='Iteration index, used as the random state of the data splits')
@click.option('--seed', default=9, help='Seed for reproducibility')
def run_experiment(dataset, iteration, seed):
    """Run the ULER correlation experiment for one dataset and iteration."""
    max_explained = 2000  # cap on the number of instances passed on

    # Honour --seed everywhere; these calls previously used a hard-coded 9
    # (the option's default), so the flag had no effect on them.
    np.random.seed(seed)
    random.seed(seed)
    Xtrain, Xexpl, ytrain, yexpl, categorical, cat_idx, task = load_data(dataset, seed)

    if len(Xexpl) > max_explained:
        index = np.random.choice(len(Xexpl), max_explained, replace=False)
        Xexpl = Xexpl.iloc[index]
        # Select the labels by position as well.  Plain ``yexpl[index]`` is
        # label-based when yexpl is a pandas object, which would not match the
        # positional ``iloc`` above.  NOTE(review): the type returned by
        # load_data is not visible here - confirm.
        yexpl = yexpl.iloc[index] if hasattr(yexpl, 'iloc') else yexpl[index]

    wrapper = OHEWrapper(categorical, task, seed, model='svm')
    wrapper.fit(Xtrain, ytrain)
    function = wrapper.predict
    compute_results(dataset, Xexpl, yexpl, function, categorical, iteration, task, wrapper)
    
    
# Script entry point: click parses the command-line options and invokes the
# experiment.
if __name__ == "__main__":
    run_experiment()
    