## imports
import numpy as np
from RandomRejection import RandRej
from NoveltyRejection import NovRej
from PredictionAmbiguity import AmbRejClas, AmbRejReg
from FaithfulnessRejection import FaithRej
from StabilityRejection import StabRej
from ComplexityRejection import ComplRej
import pandas as pd
from sklearn.model_selection import train_test_split
from preprocessing import *
from sklearn.metrics import recall_score, precision_score, balanced_accuracy_score, f1_score, roc_auc_score, average_precision_score
import random
import os
import click
import warnings
from pathlib import Path
import sys

# Silence all warnings for the whole script run.
# NOTE(review): this blanket filter also hides warnings that may signal real
# problems (e.g. ill-defined metrics); consider narrowing it to specific
# categories/modules.
warnings.filterwarnings("ignore")

def _load_user_study_data(folder):
    """Build the user-study design matrices and rejection labels.

    Reads ``form_explanations.csv``, ``form_instances.csv`` and
    ``merged_form.csv`` from *folder*. Each shown (Form, QID) question is joined
    with the per-user agreement scores, the scores are averaged per question,
    and only questions whose explanation-score std is below 1 are kept.

    Parameters
    ----------
    folder : str
        Directory prefix (including trailing separator) of the CSV files.

    Returns
    -------
    X : pandas.DataFrame
        Instance columns, indexed by (game_id, action_id, event_id).
    Z : pandas.DataFrame
        Explanation columns, same index as ``X``.
    y : pandas.DataFrame
        Single column ``expl_score``: 1 when the averaged explanation-agreement
        score is below 3 (the positive / "reject" class), else 0.
    predictions : pandas.Series
        The ``prediction`` column of ``form_instances.csv``, indexed by
        (game_id, action_id, event_id).
    """
    id_columns = ['game_id', 'action_id', 'event_id']
    # Each raw file is read once; the derived frames below never mutate them.
    raw_explanations = pd.read_csv(folder + 'form_explanations.csv')
    raw_instances = pd.read_csv(folder + 'form_instances.csv')

    # Per-question tables indexed by (Form, QID), without instance identifiers.
    explanations = raw_explanations.set_index(['Form', 'QID']).drop(columns=id_columns)
    examples = raw_instances.set_index(['Form', 'QID']).drop(columns=id_columns)
    num_forms = 35
    explanations = explanations.loc[:num_forms, :]
    examples = examples.loc[:num_forms, :]

    all_labels = pd.read_csv(folder + 'merged_form.csv')
    all_labels.set_index(['Form', 'UserID'], inplace=True)
    # Keep rows with AttentionCheck > 1 (presumably participants who passed the check).
    all_labels = all_labels[all_labels['AttentionCheck'] > 1]

    # Split the wide per-user answer table into explanation / prediction scores
    # and reshape each to long format (one row per Form, QID, UserID).
    explanation_score_columns = [col for col in all_labels.columns if 'Explanation' in col]
    prediction_score_columns = [col for col in all_labels.columns if 'Prediction' in col]
    explanation_labels = all_labels[explanation_score_columns]
    prediction_labels = all_labels[prediction_score_columns]

    # Column names become the bare question id (QID) once the prefix is stripped.
    explanation_labels.columns = [col.replace('ExplanationAgreement', '') for col in explanation_labels.columns]
    prediction_labels.columns = [col.replace('PredictionAgreement', '') for col in prediction_labels.columns]
    explanation_labels = pd.melt(explanation_labels.reset_index(), id_vars=['Form', 'UserID'], var_name='QID', value_name='expl_score').set_index(['Form', 'QID'])
    prediction_labels = pd.melt(prediction_labels.reset_index(), id_vars=['Form', 'UserID'], var_name='QID', value_name='pred_score').set_index(['Form', 'QID', 'UserID'])
    # QID comes from column names (strings); cast index levels to int so they
    # align with the integer indices of the form CSVs.
    explanation_labels.index = pd.MultiIndex.from_tuples([(int(form), int(qid)) for form, qid in explanation_labels.index], names=['Form', 'QID'])
    explanation_labels['UserID'] = explanation_labels['UserID'].astype(int)
    prediction_labels.index = pd.MultiIndex.from_tuples([(int(form), int(qid), int(user)) for form, qid, user in prediction_labels.index], names=['Form', 'QID', 'UserID'])
    explanation_labels.sort_index(inplace=True)
    prediction_labels.sort_index(inplace=True)

    # Overlapping instance/explanation columns get '_pred' / '_expl' suffixes.
    example_explanations = examples.join(explanations, on=['Form', 'QID'], rsuffix='_expl', lsuffix='_pred')
    # For each (Form, QID), create one row per UserID carrying that user's scores.
    per_user = example_explanations.join(explanation_labels, on=['Form', 'QID'])
    per_user = per_user.reset_index().set_index(['Form', 'QID', 'UserID'])
    dataset = per_user.join(prediction_labels, on=['Form', 'QID', 'UserID'])

    expl_label_column = 'expl_score'
    expl_column = [col for col in dataset.columns if '_expl' in col]
    example_column = [col for col in dataset.columns if '_pred' in col]

    # Exclude specific (Form, UserID) responses (reason not recorded in this file).
    to_remove = [(1, 0), (8, 4), (7, 2), (5, 3), (6, 0), (8, 3), (4, 2), (9, 3)]
    for form, user in to_remove:
        dataset = dataset[~((dataset.index.get_level_values('Form') == form) & (dataset.index.get_level_values('UserID') == user))]

    # Average over users; keep questions whose explanation-score std is < 1.
    dataset_average = dataset.groupby(['Form', 'QID']).mean()
    dataset_std = dataset.groupby(['Form', 'QID']).std()
    dataset_average = dataset_average[(dataset_std[expl_label_column] < 1)]

    # NOTE(review): the averaged '_pred' / '_expl' columns selected here are
    # dropped again after the joins below; only their (Form, QID) index
    # survives, so the numeric coercion / fillna(-1) does not reach the final
    # X / Z. Kept as-is to preserve behavior.
    X = dataset_average[example_column]
    X = X.apply(pd.to_numeric, errors='coerce').fillna(-1)
    Z = dataset_average[expl_column]
    Z = Z.apply(pd.to_numeric, errors='coerce').fillna(-1)

    # Label = 1 (reject) when the averaged explanation-agreement score is < 3.
    y = dataset_average[expl_label_column]
    y = y.apply(pd.to_numeric, errors='coerce').fillna(-1)
    y = y.apply(lambda x: np.where(x < 3, 1, 0))

    # Re-index the labels by (game_id, action_id, event_id) via the explanations file.
    explanations_map = raw_explanations.set_index(['Form', 'QID'])
    mapping_df = explanations_map.loc[y.index]
    y_with_ids = y.to_frame()
    y_with_ids[id_columns] = mapping_df[id_columns].values
    y = y_with_ids.set_index(id_columns, inplace=False)

    predictions = raw_instances.set_index(id_columns)['prediction']
    examples_with_ids = raw_instances.set_index(['Form', 'QID'])

    # Attach the raw per-question rows (carrying the instance ids) and re-index.
    Z = Z.join(explanations_map, on=['Form', 'QID'])
    Z.set_index(id_columns, inplace=True)
    X = X.join(examples_with_ids, on=['Form', 'QID'])
    X.set_index(id_columns, inplace=True)
    # Drop the averaged (suffixed) columns; for X this also drops 'prediction'.
    Z.drop(columns=[col for col in Z.columns if 'expl' in col], inplace=True)
    X.drop(columns=[col for col in X.columns if 'pred' in col], inplace=True)
    return X, Z, y, predictions


def _append_result_row(path, columns, values):
    """Append one row to the CSV at *path*.

    The header is written only when the file does not exist yet, and missing
    parent directories are created first.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    # A list-of-rows DataFrame keeps numeric dtypes, whereas np.array(values)
    # would coerce every entry to a string because of the method name.
    row = pd.DataFrame([values], columns=columns)
    if path.exists():
        row.to_csv(path, mode='a', index=False, header=False)
    else:
        row.to_csv(path, mode='w', index=False, header=True)


def _score_then_reject(rejector, *test_args):
    """Return ``(scores, labels)`` from ``rejector.score`` then ``rejector.reject``.

    The call order (score before reject) matches the original experiment, which
    matters if a rejector consumes random numbers in these calls.
    """
    scores = rejector.score(*test_args)
    labels = rejector.reject(*test_args)
    return scores, labels


def _rejection_metrics(y_true, labels, scores):
    """Return (precision, recall, f1, balanced_accuracy, AUROC, AUPR).

    Threshold metrics use the hard reject *labels*; ranking metrics use *scores*.
    """
    return (
        precision_score(y_true, labels, zero_division=0),
        recall_score(y_true, labels),
        f1_score(y_true, labels),
        balanced_accuracy_score(y_true, labels),
        roc_auc_score(y_true, scores),
        average_precision_score(y_true, scores),
    )


def compute_results(seed, iteration, folder='../data/',
                    results_file='./results/Appendix/UserStudy/baselines.csv'):
    """Evaluate the baseline rejectors on the user-study data for one iteration.

    Splits the questions 70/10/20 into train/validation/test (stratified on the
    rejection label), fits every baseline rejector, and appends one result row
    per method (and per ``k`` for the novelty rejectors) to *results_file*.

    Parameters
    ----------
    seed : int
        Seed for NumPy's and Python's global RNGs.
    iteration : int
        Used as ``random_state`` for both data splits and as ``seed`` for every
        rejector; also recorded in each result row.
    folder : str, optional
        Directory prefix of the input CSV files.
    results_file : str, optional
        CSV to append results to (created, with header, if missing).
    """
    np.random.seed(seed)
    random.seed(seed)

    X, Z, y, predictions = _load_user_study_data(folder)
    # Fraction of positive (to-be-rejected) questions; given to every rejector.
    contamination = y.values.mean()

    # 20% test, then 1/8 of the remaining 80% (= 10% overall) for validation.
    idx_train, idx_test = train_test_split(Z.index, test_size=0.2, random_state=iteration, stratify=y)
    idx_train, idx_val = train_test_split(idx_train, test_size=1/8, random_state=iteration, stratify=y.loc[idx_train])

    Xtr, Ztr, ytr, ypred_tr = X.loc[idx_train], Z.loc[idx_train], y.loc[idx_train], predictions.loc[idx_train]
    Xval, Zval, yval = X.loc[idx_val], Z.loc[idx_val], y.loc[idx_val]
    Xte, Zte, yte, ypred_te = X.loc[idx_test], Z.loc[idx_test], y.loc[idx_test], predictions.loc[idx_test]

    ## Q1
    columns = ["contamination", "iteration", "method", "k", "val_AUROC", "precision", "recall", "f1", "balanced_accuracy", "AUROC", "AUPR"]

    def report(method, k, val_auroc, scores, labels):
        # One CSV row: run metadata followed by the test-set metrics.
        metrics = _rejection_metrics(yte, labels, scores)
        _append_result_row(results_file, columns, [contamination, iteration, method, k, val_auroc, *metrics])

    ## Standard rejection setting
    rand_rejector = RandRej(seed=iteration, contamination=contamination)
    rand_rejector.fit(Ztr)
    report("RandRej", 0, 0, *_score_then_reject(rand_rejector, Zte))

    amb_rejector = AmbRejClas(seed=iteration, contamination=contamination)
    amb_rejector.fit(Xtr, ypred_tr)
    report("PredAmb", 0, 0, *_score_then_reject(amb_rejector, Xte, ypred_te))

    # Novelty rejectors: first on instance features X, then on explanations Z
    # (the explanation-based rejection baselines).
    for method, feats_tr, feats_val, feats_te in (("NovRej_{X}", Xtr, Xval, Xte),
                                                  ("NovRej_{Z}", Ztr, Zval, Zte)):
        for k in [1, 10, 100]:
            nov_rejector = NovRej(k=k, seed=iteration, contamination=contamination)
            nov_rejector.fit(feats_tr, ytr)
            nov_val_auroc = roc_auc_score(yval, nov_rejector.score(feats_val))
            report(method, k, nov_val_auroc, *_score_then_reject(nov_rejector, feats_te))

    faith_rejector = FaithRej(dataset='xG', seed=iteration, contamination=contamination)
    faith_rejector.fit(Xtr, Ztr, [])
    report("FaithRej", 0, 0, *_score_then_reject(faith_rejector, Xte, Zte))

    stab_rejector = StabRej(dataset='xG', seed=iteration, contamination=contamination)
    stab_rejector.fit(Ztr)
    report("StabRej", 0, 0, *_score_then_reject(stab_rejector, Zte))

    compl_rejector = ComplRej(seed=iteration, contamination=contamination)
    compl_rejector.fit(Ztr)
    report("ComplRej", 0, 0, *_score_then_reject(compl_rejector, Zte))

    print(f"Finished APP competitors user study iteration {iteration}")
    print()

@click.command()
@click.option('--iteration', default=0, help='Iteration index: random_state for the data splits and seed for every rejector')
@click.option('--seed', default=9, help='Seed for reproducibility')
def run_experiment(iteration, seed):
    """Run one iteration of the user-study rejection baselines and append the results."""
    compute_results(seed, iteration)


if __name__ == '__main__':
    # click parses sys.argv and dispatches to the command.
    run_experiment()
    