## imports
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from preprocessing import *
from sklearn.metrics import confusion_matrix, balanced_accuracy_score, f1_score, roc_auc_score
from ULER import ULER
import random
import os
from PASTA import PASTA
import click
import warnings
warnings.filterwarnings("ignore")


# Result-row schemas. Each list matches, position by position, the values
# written for one rejection rate of one hyper-parameter configuration.
# NOTE(review): a results file created by another script version with a
# different header will not line up with appended rows -- confirm before reuse.
_METRIC_COLUMNS = ["tn", "fp", "fn", "tp", "f1", "balanced_accuracy"]
_PASTA_COLUMNS = ["contamination", "iteration", "method", "rejection_rate",
                  "alpha", "beta", "gamma", "val_AUROC"] + _METRIC_COLUMNS
_ULER_COLUMNS = ["contamination", "iteration", "method", "rejection_rate",
                 "kernel", "C", "k", "eps", "val_AUROC"] + _METRIC_COLUMNS


def _load_form_table(path, num_forms):
    """Read a per-(Form, QID) CSV, drop the id columns, keep forms up to num_forms."""
    table = pd.read_csv(path)
    table.set_index(['Form', 'QID'], inplace=True)
    table.drop(columns=['game_id', 'action_id', 'event_id'], inplace=True)
    # Label-based slice on the first index level, so num_forms is inclusive.
    return table.loc[:num_forms, :]


def _melt_scores(all_labels, marker, prefix, value_name):
    """Turn the wide score columns containing `marker` into long (Form, UserID, QID) rows."""
    score_columns = [col for col in all_labels.columns if marker in col]
    # Stripping the prefix leaves the part of the column name used as QID.
    scores = all_labels[score_columns].rename(columns=lambda col: col.replace(prefix, ''))
    return pd.melt(scores.reset_index(), id_vars=['Form', 'UserID'],
                   var_name='QID', value_name=value_name)


def _load_rating_dataset(folder, num_forms):
    """Build one row per (Form, QID, UserID) with example, explanation and score columns."""
    explanations = _load_form_table(folder + 'form_explanations.csv', num_forms)
    examples = _load_form_table(folder + 'form_instances.csv', num_forms)

    all_labels = pd.read_csv(folder + 'merged_form.csv')
    all_labels.set_index(['Form', 'UserID'], inplace=True)
    # Only answers with AttentionCheck > 1 are kept.
    all_labels = all_labels[all_labels['AttentionCheck'] > 1]

    explanation_labels = _melt_scores(
        all_labels, 'Explanation', 'ExplanationAgreement', 'expl_score'
    ).set_index(['Form', 'QID'])
    prediction_labels = _melt_scores(
        all_labels, 'Prediction', 'PredictionAgreement', 'pred_score'
    ).set_index(['Form', 'QID', 'UserID'])

    # QID comes out of melt as text; cast every key to int so the joins below
    # match the integer keys of the form tables.
    explanation_labels.index = pd.MultiIndex.from_tuples(
        [(int(form), int(qid)) for form, qid in explanation_labels.index],
        names=['Form', 'QID'])
    explanation_labels['UserID'] = explanation_labels['UserID'].astype(int)
    prediction_labels.index = pd.MultiIndex.from_tuples(
        [(int(form), int(qid), int(user)) for form, qid, user in prediction_labels.index],
        names=['Form', 'QID', 'UserID'])
    explanation_labels.sort_index(inplace=True)
    prediction_labels.sort_index(inplace=True)

    # Overlapping column names get '_pred' (example side) / '_expl' (explanation side).
    example_explanations = examples.join(explanations, on=['Form', 'QID'],
                                         rsuffix='_expl', lsuffix='_pred')
    # Each (Form, QID) row fans out to one row per user that scored the explanation.
    with_expl_scores = example_explanations.join(explanation_labels, on=['Form', 'QID'])
    with_expl_scores = with_expl_scores.reset_index().set_index(['Form', 'QID', 'UserID'])
    return with_expl_scores.join(prediction_labels, on=['Form', 'QID', 'UserID'])


def _rejection_metrics(y_true, rejected):
    """Return [tn, fp, fn, tp, f1, balanced accuracy] for binary reject labels."""
    # labels=[0, 1] keeps the matrix 2x2 even when only one class occurs, so the
    # four-way unpacking cannot fail at extreme rejection rates.
    tn, fp, fn, tp = confusion_matrix(y_true, rejected, labels=[0, 1]).ravel()
    return [tn, fp, fn, tp,
            f1_score(y_true, rejected),
            balanced_accuracy_score(y_true, rejected)]


def _append_rows(path, rows, columns):
    """Append rows to the CSV at path, writing the header only when the file is created."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    is_new_file = not os.path.exists(path)
    # Building the frame from a list of rows keeps numeric values numeric
    # instead of coercing the whole row to strings through np.array.
    frame = pd.DataFrame(rows, columns=columns)
    frame.to_csv(path, mode='w' if is_new_file else 'a', index=False, header=is_new_file)


def compute_results(seed: int, iteration: int) -> None:
    """Run one Q3 iteration: evaluate PASTA and ULER over rejection rates.

    Loads the form CSVs from ``../data/``, averages the user scores per
    (Form, QID), keeps the questions whose explanation-score standard
    deviation is below 1, and labels a question 1 when its mean explanation
    score is below 3. The data is split 70/10/20 into train/validation/test
    (stratified on the label). For every hyper-parameter configuration of
    each rejector, one row per rejection rate is appended to
    ``./results/Q3/PASTARej_rr.csv`` or ``./results/Q3/LtX_rr.csv``.

    Args:
        seed: Seed for the global ``numpy`` and ``random`` generators.
        iteration: Run index; used as ``random_state`` of both splits and as
            the ``seed`` argument of PASTA and ULER.

    NOTE(review): PASTA and ULER are project classes; their fit/score/reject
    contracts are taken from how they are called here, and ``score`` is
    assumed to be deterministic for a fitted model -- confirm.
    """
    np.random.seed(seed)
    random.seed(seed)

    folder = '../data/'
    dataset = _load_rating_dataset(folder, num_forms=35)
    expl_label_column = 'expl_score'
    expl_column = [col for col in dataset.columns if '_expl' in col]

    # (Form, UserID) pairs whose answers are removed before aggregation.
    to_remove = [(1, 0), (8, 4), (7, 2), (5, 3), (6, 0), (8, 3), (4, 2), (9, 3)]
    for form, user in to_remove:
        forms = dataset.index.get_level_values('Form')
        users = dataset.index.get_level_values('UserID')
        dataset = dataset[~((forms == form) & (users == user))]

    per_question = dataset.groupby(['Form', 'QID'])
    dataset_average = per_question.mean()
    dataset_std = per_question.std()
    # Keep only questions on which the users roughly agree (NaN std is dropped too).
    dataset_average = dataset_average[dataset_std[expl_label_column] < 1]

    Z = dataset_average[expl_column].apply(pd.to_numeric, errors='coerce').fillna(-1)
    y = pd.to_numeric(dataset_average[expl_label_column], errors='coerce').fillna(-1)
    # Positive class (1): mean explanation score below 3.
    y = (y < 3).astype(int)

    contamination = y.mean()
    batch_size = 128
    epochs = 100

    # 80/20 train+val/test, then 1/8 of the remainder as validation.
    idx_train, idx_test = train_test_split(
        Z.index, test_size=0.2, random_state=iteration, stratify=y)
    idx_train, idx_val = train_test_split(
        idx_train, test_size=1/8, random_state=iteration, stratify=y.loc[idx_train])

    Ztr, ytr = Z.loc[idx_train], y.loc[idx_train]
    Zval, yval = Z.loc[idx_val], y.loc[idx_val]
    Zte, yte = Z.loc[idx_test], y.loc[idx_test]

    # Binary feature matrix for the training questions (threshold 0.5), passed to ULER.fit.
    matrix = pd.read_csv(folder + 'features.csv')
    matrix.set_index(['Form', 'QID'], inplace=True)
    features = matrix.loc[idx_train]
    features = features.apply(pd.to_numeric, errors='coerce').fillna(0)
    features = (features >= 0.5).astype(int)

    pasta_path = './results/Q3/PASTARej_rr.csv'
    for alpha in [0.1, 1, 10]:
        for beta in [0.001, 0.01, 0.1]:
            for gamma in [0.01, 0.1, 1]:
                PASTA_rejector = PASTA(dataset="xG", seed=iteration, batch_size=batch_size,
                                       epochs=epochs, contamination=contamination,
                                       alpha=alpha, beta=beta, gamma=gamma)
                PASTA_rejector.fit(Ztr, ytr.values, Zval, yval.values)
                val_auroc = roc_auc_score(yval, PASTA_rejector.score(Zval))
                rows = []
                for rr in np.linspace(0.01, 0.5, 50):
                    rr = np.round(rr, 2)
                    PASTA_rejector.set_rejection_rate(rr)
                    rejector_labels = PASTA_rejector.reject(Zte)
                    rows.append([contamination, iteration, "PASTARej", rr,
                                 alpha, beta, gamma, val_auroc]
                                + _rejection_metrics(yte, rejector_labels))
                _append_rows(pasta_path, rows, _PASTA_COLUMNS)

    uler_path = './results/Q3/LtX_rr.csv'
    for kernel in ['linear', 'poly', 'rbf']:
        for C in [0.01, 1, 100]:
            for k in [3, 10, 20]:
                for eps in [0.01, 0.1, 1]:
                    uler_rejector = ULER(C=C, kernel=kernel, eps=eps, k=k,
                                         seed=iteration, contamination=contamination)
                    uler_rejector.fit(Ztr, ytr, features.values)
                    train_scores = uler_rejector.score(Ztr)
                    val_auroc = roc_auc_score(yval, uler_rejector.score(Zval))
                    # The test scores do not depend on the rejection rate: score once.
                    test_scores = uler_rejector.score(Zte)
                    rows = []
                    for rr in np.linspace(0.01, 0.25, 25):
                        rr = np.round(rr, 2)
                        # Reject whatever scores above the (1 - rr) training quantile.
                        threshold = np.quantile(train_scores, 1 - rr)
                        rejector_labels = np.where(test_scores >= threshold, 1, 0)
                        rows.append([contamination, iteration, "ULER", rr,
                                     kernel, C, k, eps, val_auroc]
                                    + _rejection_metrics(yte, rejector_labels))
                    _append_rows(uler_path, rows, _ULER_COLUMNS)

    print(f"Finished Q3 iteration {iteration}")
    print()

@click.command()
@click.option('--iteration', default=0,
              help='Run index; used as random_state of the data splits and as the model seed')
@click.option('--seed', default=9,
              help='Seed for the global numpy and random generators')
def run_experiment(iteration, seed):
    """Run one Q3 rejection-rate experiment (PASTA and ULER) for the given run index."""
    # compute_results takes (seed, iteration), the reverse of the option order above.
    compute_results(seed, iteration)
    
    
# Script entry point: run_experiment is a click command, so --iteration and
# --seed are parsed from the command line when this file is executed directly.
if __name__ == '__main__':
    run_experiment()