import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.isotonic import spearmanr
from sklearn.model_selection import train_test_split
from preprocessing import *
from sklearn.metrics import roc_auc_score
from ULER import ULER
import random
from ComplexityRejection import ComplRej
import os
from FaithfulnessRejection import FaithRej
from StabilityRejection import StabRej
from PASTA import PASTA
import click
import warnings
warnings.filterwarnings("ignore")

def compute_results(seed, iteration, threshold=3, threshold_features=0.5, filter_std=1, aggregation='mean'):
    """Run one iteration of the user-study correlation experiment and append the results to CSV.

    The function loads the form instances, explanations and user labels from
    ``../data/``, averages the labels per (Form, QID), builds a
    train/validation/test split, scores the test explanations with the
    faithfulness, stability and complexity rejectors, and then, for every ULER
    hyper-parameter combination, records the validation AUROC together with
    the Pearson and Spearman correlations between the ULER test scores and
    each of the three metric-based scores. One CSV row is appended per
    hyper-parameter combination to ``./results/Appendix/correlation/user.csv``.

    Args:
        seed: Seed applied to ``numpy.random`` and ``random``.
        iteration: Used as ``random_state`` for both splits, as the seed of
            every rejector, and written to the ``iteration`` result column.
        threshold: Currently unused; the label threshold is hard-coded to 3.
        threshold_features: Currently unused; the feature binarisation
            threshold is hard-coded to 0.5.
        filter_std: Currently unused; the per-question standard-deviation
            filter is hard-coded to 1.
        aggregation: Currently unused; labels are always aggregated with the
            mean.

    Returns:
        None. Results are written to disk and a progress line is printed.
    """
    # Import from SciPy directly: `sklearn.isotonic` only exposes `spearmanr`
    # as a side effect of its own internal import.
    from scipy.stats import spearmanr

    np.random.seed(seed)
    random.seed(seed)

    # ---- Load instances, explanations and user labels --------------------
    folder = '../data/'
    explanations = pd.read_csv(folder+'form_explanations.csv')
    explanations.set_index(['Form', 'QID'], inplace=True)
    explanations.drop(columns=['game_id', 'action_id', 'event_id'], inplace=True)
    examples = pd.read_csv(folder+'form_instances.csv')
    examples.set_index(['Form', 'QID'], inplace=True)
    examples.drop(columns=['game_id', 'action_id', 'event_id'], inplace=True)
    num_forms = 35
    explanations = explanations.loc[:num_forms, :]
    examples = examples.loc[:num_forms, :]
    all_labels = pd.read_csv(folder+'merged_form.csv')
    all_labels.set_index(['Form', 'UserID'], inplace=True)
    # Keep only rows whose AttentionCheck value is greater than 1.
    all_labels = all_labels[all_labels['AttentionCheck'] > 1]

    explanation_score_columns = [col for col in all_labels.columns if 'Explanation' in col]
    prediction_score_columns = [col for col in all_labels.columns if 'Prediction' in col]
    explanation_labels = all_labels[explanation_score_columns]
    prediction_labels = all_labels[prediction_score_columns]

    # Strip the column prefixes so that the remaining column name is the QID,
    # then reshape to one row per (Form, QID, UserID).
    explanation_labels.columns = [col.replace('ExplanationAgreement', '') for col in explanation_labels.columns]
    prediction_labels.columns = [col.replace('PredictionAgreement', '') for col in prediction_labels.columns]
    explanation_labels = pd.melt(explanation_labels.reset_index(), id_vars=['Form', 'UserID'], var_name='QID', value_name='expl_score').set_index(['Form', 'QID'])
    prediction_labels = pd.melt(prediction_labels.reset_index(), id_vars=['Form', 'UserID'], var_name='QID', value_name='pred_score').set_index(['Form', 'QID', 'UserID'])
    explanation_labels.index = pd.MultiIndex.from_tuples([(int(form), int(qid)) for form, qid in explanation_labels.index], names=['Form', 'QID'])
    explanation_labels['UserID'] = explanation_labels['UserID'].astype(int)
    prediction_labels.index = pd.MultiIndex.from_tuples([(int(form), int(qid), int(user)) for form, qid, user in prediction_labels.index], names=['Form', 'QID', 'UserID'])
    explanation_labels.sort_index(inplace=True)
    prediction_labels.sort_index(inplace=True)

    # ---- Join instances, explanations and both label sets ----------------
    example_explanations = examples.join(explanations, on=['Form', 'QID'], rsuffix='_expl', lsuffix='_pred')
    # For each (Form, QID), create one row per user by joining the explanation labels.
    example_explanation_explscores = example_explanations.join(explanation_labels, on=['Form', 'QID'])
    example_explanation_explscores = example_explanation_explscores.reset_index().set_index(['Form', 'QID', 'UserID'])
    dataset = example_explanation_explscores.join(prediction_labels, on=['Form', 'QID', 'UserID'])

    expl_label_column = 'expl_score'
    expl_column = [col for col in dataset.columns if '_expl' in col]
    example_column = [col for col in dataset.columns if '_pred' in col]

    # Drop every row belonging to one of these (Form, UserID) pairs.
    to_remove = [(1, 0), (8, 4), (7, 2), (5, 3), (6, 0), (8, 3), (4, 2), (9, 3)]
    form_level = dataset.index.get_level_values('Form')
    user_level = dataset.index.get_level_values('UserID')
    excluded = np.zeros(len(dataset), dtype=bool)
    for form, user in to_remove:
        excluded |= np.asarray((form_level == form) & (user_level == user))
    dataset = dataset[~excluded]

    # ---- Aggregate per question and build the learning matrices ----------
    grouped = dataset.groupby(['Form', 'QID'])
    dataset_average = grouped.mean()
    dataset_std = grouped.std()
    # Keep only questions whose explanation scores have a standard deviation below 1.
    dataset_average = dataset_average[(dataset_std[expl_label_column] < 1)]

    X = dataset_average[example_column].apply(pd.to_numeric, errors='coerce').fillna(-1)
    Z = dataset_average[expl_column].apply(pd.to_numeric, errors='coerce').fillna(-1)
    y = pd.to_numeric(dataset_average[expl_label_column], errors='coerce').fillna(-1)
    # Binary label: 1 when the mean explanation score is below 3. Missing
    # scores were filled with -1 above and therefore also map to 1.
    y = (y < 3).astype(int)

    contamination = y.mean()

    # 80/20 train/test split, then 1/8 of the training part becomes validation.
    idx_train, idx_test = train_test_split(Z.index, test_size=0.2, random_state=iteration, stratify=y)
    idx_train, idx_val = train_test_split(idx_train, test_size=1/8, random_state=iteration, stratify=y.loc[idx_train])

    Xtr, Ztr, ytr = X.loc[idx_train], Z.loc[idx_train], y.loc[idx_train]
    Zval, yval = Z.loc[idx_val], y.loc[idx_val]
    Xte, Zte = X.loc[idx_test], Z.loc[idx_test]

    # Binarised feature matrix for the training rows (values >= 0.5 become 1).
    matrix = pd.read_csv(folder+'features.csv')
    matrix.set_index(['Form', 'QID'], inplace=True)
    features = matrix.loc[Xtr.index]
    features = features.apply(pd.to_numeric, errors='coerce').fillna(0)
    features = (features >= 0.5).astype(int)

    # ---- Metric-based rejectors, scored once on the test set --------------
    # NOTE(review): `function` is not defined in this file; presumably it is
    # provided by `from preprocessing import *` - confirm.
    faith_rejector = FaithRej(dataset='xG', func=function, mode='', seed=iteration, contamination=contamination)
    faith_rejector.fit(pd.DataFrame(Xtr), Ztr, [])
    faithfulness_test = faith_rejector.score(Xte, Zte)

    stab_rejector = StabRej(dataset='xG', seed=iteration, mode ='xG', contamination=contamination)
    stab_rejector.fit(Ztr)
    stability_test = stab_rejector.score(Zte)

    compl_rejector = ComplRej(seed=iteration, contamination=contamination)
    compl_rejector.fit(Ztr)
    complexity_test = compl_rejector.score(Zte)

    # ---- ULER grid search and result logging ------------------------------
    columns = ["dataset", "contamination", "iteration", "method", "kernel","C", "k","eps","val_AUROC",
            "corr_faith", "spear_faith",
            "corr_stab", "spear_stab",
            "corr_comp", "spear_comp"]
    results_dir = './results/Appendix/correlation'
    results_path = './results/Appendix/correlation/user.csv'
    # Create the output directory up front so the first write cannot fail on a fresh checkout.
    os.makedirs(results_dir, exist_ok=True)

    for kernel in ['linear', 'poly', 'rbf']:
        for C in [0.01,1,100]:
            for k in [3,10,20]:
                for eps in [0.01, 0.1, 1]:
                    uler = ULER(C=C, kernel=kernel, eps=eps, k=k, seed=iteration, contamination=contamination)
                    uler.fit(Ztr, ytr, features)
                    uler_val_scores = uler.score(Zval)
                    uler_test_scores = uler.score(Zte)
                    uler_val_auroc = roc_auc_score(yval, uler_val_scores)

                    correlation_ULER_faithfulness = np.corrcoef(uler_test_scores, faithfulness_test)[0,1]
                    spearman_ULER_faithfulness = spearmanr(uler_test_scores, faithfulness_test)[0]
                    correlation_ULER_stability = np.corrcoef(uler_test_scores, stability_test)[0,1]
                    spearman_ULER_stability = spearmanr(uler_test_scores, stability_test)[0]
                    correlation_ULER_complexity = np.corrcoef(uler_test_scores, complexity_test)[0,1]
                    spearman_ULER_complexity = spearmanr(uler_test_scores, complexity_test)[0]

                    # One value per entry of `columns`, in the same order; the
                    # "method" entry was previously missing, which made the
                    # DataFrame construction fail with a shape mismatch.
                    values = ['xG', contamination, iteration, 'ULER', kernel, C,k,eps, uler_val_auroc,
                            correlation_ULER_faithfulness, spearman_ULER_faithfulness,
                            correlation_ULER_stability, spearman_ULER_stability,
                            correlation_ULER_complexity, spearman_ULER_complexity,]
                    df = pd.DataFrame(data = np.array(values).reshape(1,-1), columns = columns)
                    # Write the header only when the file is being created.
                    write_header = not os.path.exists(results_path)
                    df.to_csv(results_path, mode='a', index=False, header=write_header)

    print(f"Finished APP correlation user study, iteration {iteration}")
    print()

@click.command()
@click.option('--iteration', default=0, help='Number of iterations')
@click.option('--seed', default=9, help='Seed for reproducibility')
def run_experiment(iteration, seed):
    """Run the user-study correlation experiment for one iteration."""
    # compute_results declares six parameters; the last four are not read by
    # its body, which hard-codes these same values. They are passed explicitly
    # so the call does not fail with a missing-argument TypeError.
    compute_results(
        seed,
        iteration,
        threshold=3,
        threshold_features=0.5,
        filter_std=1,
        aggregation='mean',
    )
    
    
# Script entry point: invoke the click command only when this file is executed
# directly, so that importing the module does not start the experiment.
if __name__ == '__main__':
    run_experiment()
    