## imports
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from preprocessing import *
from sklearn.metrics import balanced_accuracy_score, f1_score, roc_auc_score, average_precision_score, confusion_matrix
import random
import os
from RandomRejection import RandRej
from NoveltyRejection import NovRej
from PredictionAmbiguity import AmbRejClas, AmbRejReg
from FaithfulnessRejection import FaithRej
from StabilityRejection import StabRej
from ComplexityRejection import ComplRej
import click
import warnings
warnings.filterwarnings("ignore")


def _append_result_row(results_path, columns, values):
    """Append one result row to the CSV file at ``results_path``.

    Parameters
    ----------
    results_path : str
        Path of the CSV file that collects the results.
    columns : list of str
        Column names, one per entry of ``values``.
    values : list
        The metric/metadata values forming a single row.

    The header line is written only when the file does not exist yet, so
    callers can append without tracking who writes first.
    """
    # The values are packed into a 1 x len(values) array and framed as one row.
    row = pd.DataFrame(data=np.array(values).reshape(1, -1), columns=columns)
    is_new_file = not os.path.exists(results_path)
    row.to_csv(results_path, mode='w' if is_new_file else 'a', index=False, header=is_new_file)


def _sweep_rejection_rates(rejector, reject_inputs, yte, method, contamination,
                           iteration, columns, results_path, k=0, val_auroc=0):
    """Evaluate a fitted rejector over the rejection-rate grid and log each result.

    Parameters
    ----------
    rejector : object
        Fitted rejector exposing ``set_rejection_rate(rr)`` and ``reject(*inputs)``.
    reject_inputs : tuple
        Positional arguments forwarded to ``rejector.reject``.
    yte : pandas.DataFrame
        Ground-truth test labels the rejector output is compared against.
    method : str
        Name stored in the ``method`` column of the results file.
    contamination, iteration :
        Metadata copied into every row.
    columns : list of str
        Column names of the results file.
    results_path : str
        Path of the CSV file rows are appended to.
    k, val_auroc :
        Values stored in the ``k`` and ``val_AUROC`` columns (0 when not applicable).
    """
    # 25 evenly spaced rates from 0.01 to 0.25, rounded to two decimals.
    for rr in np.linspace(0.01, 0.25, 25):
        rr = np.round(rr, 2)
        rejector.set_rejection_rate(rr)
        rejector_labels = rejector.reject(*reject_inputs)
        tn, fp, fn, tp = confusion_matrix(yte, rejector_labels).ravel()
        f1 = f1_score(yte, rejector_labels)
        bacc = balanced_accuracy_score(yte, rejector_labels)
        values = [contamination, iteration, method, rr, k, val_auroc, tn, fp, fn, tp, f1, bacc]
        _append_result_row(results_path, columns, values)


def compute_results(seed, iteration):
    """Run the user-study rejection baselines once and append their metrics to a CSV.

    The function reads the form CSV files from ``../data/``, averages the
    per-user scores for every (Form, QID) item, derives a binary label from the
    mean explanation score, splits the items into train/validation/test, and
    evaluates each rejector over a grid of rejection rates. One row per
    (method, k, rejection rate) is appended to
    ``./results/Appendix/UserStudy/baselines_rr.csv``; the directory is created
    when missing.

    Parameters
    ----------
    seed : int
        Seed applied to ``numpy.random`` and ``random``.
    iteration : int
        Used as ``random_state`` for the data splits, as ``seed`` for every
        rejector, and stored in the ``iteration`` column of the results.

    Returns
    -------
    None
    """
    np.random.seed(seed)
    random.seed(seed)

    folder = '../data/'
    id_columns = ['game_id', 'action_id', 'event_id']
    num_forms = 35

    # Each raw CSV is read once and re-indexed as needed further down.
    explanations_map = pd.read_csv(folder + 'form_explanations.csv').set_index(['Form', 'QID'])
    instances_raw = pd.read_csv(folder + 'form_instances.csv')
    examples_map = instances_raw.set_index(['Form', 'QID'])
    predictions = instances_raw.set_index(id_columns)['prediction']

    # Feature tables without the id columns, restricted to forms up to num_forms.
    explanations = explanations_map.drop(columns=id_columns).loc[:num_forms, :]
    examples = examples_map.drop(columns=id_columns).loc[:num_forms, :]

    all_labels = pd.read_csv(folder + 'merged_form.csv')
    all_labels.set_index(['Form', 'UserID'], inplace=True)
    # Keep only the rows whose AttentionCheck value exceeds 1.
    all_labels = all_labels[all_labels['AttentionCheck'] > 1]

    explanation_score_columns = [col for col in all_labels.columns if 'Explanation' in col]
    prediction_score_columns = [col for col in all_labels.columns if 'Prediction' in col]
    explanation_labels = all_labels[explanation_score_columns]
    prediction_labels = all_labels[prediction_score_columns]

    # Strip the column prefixes so the remaining text can be used as the QID.
    explanation_labels.columns = [col.replace('ExplanationAgreement', '') for col in explanation_labels.columns]
    prediction_labels.columns = [col.replace('PredictionAgreement', '') for col in prediction_labels.columns]
    # Wide -> long: one row per (Form, UserID, QID) with the corresponding score.
    explanation_labels = pd.melt(explanation_labels.reset_index(), id_vars=['Form', 'UserID'], var_name='QID', value_name='expl_score').set_index(['Form', 'QID'])
    prediction_labels = pd.melt(prediction_labels.reset_index(), id_vars=['Form', 'UserID'], var_name='QID', value_name='pred_score').set_index(['Form', 'QID', 'UserID'])
    # Cast the index levels (and UserID) to int so they match the keys of the feature tables.
    explanation_labels.index = pd.MultiIndex.from_tuples([(int(form), int(qid)) for form, qid in explanation_labels.index], names=['Form', 'QID'])
    explanation_labels['UserID'] = explanation_labels['UserID'].astype(int)
    prediction_labels.index = pd.MultiIndex.from_tuples([(int(form), int(qid), int(user)) for form, qid, user in prediction_labels.index], names=['Form', 'QID', 'UserID'])
    explanation_labels.sort_index(inplace=True)
    prediction_labels.sort_index(inplace=True)

    example_explanations = examples.join(explanations, on=['Form', 'QID'], rsuffix='_expl', lsuffix='_pred')
    # For each (Form, QID) item, create one row per user by joining the explanation labels.
    example_explanation_explscores = example_explanations.join(explanation_labels, on=['Form', 'QID'])
    example_explanation_explscores = example_explanation_explscores.reset_index().set_index(['Form', 'QID', 'UserID'])
    dataset = example_explanation_explscores.join(prediction_labels, on=['Form', 'QID', 'UserID'])

    expl_label_column = 'expl_score'
    expl_column = [col for col in dataset.columns if '_expl' in col]
    example_column = [col for col in dataset.columns if '_pred' in col]

    # (Form, UserID) pairs whose rows are excluded from the dataset.
    to_remove = [(1, 0), (8, 4), (7, 2), (5, 3), (6, 0), (8, 3), (4, 2), (9, 3)]
    for form, user in to_remove:
        dataset = dataset[~((dataset.index.get_level_values('Form') == form) & (dataset.index.get_level_values('UserID') == user))]

    dataset_average = dataset.groupby(['Form', 'QID']).mean()
    dataset_std = dataset.groupby(['Form', 'QID']).std()

    # Keep only the items whose explanation scores have a standard deviation below 1.
    dataset_average = dataset_average[(dataset_std[expl_label_column] < 1)]

    X = dataset_average[example_column]
    X = X.apply(pd.to_numeric, errors='coerce').fillna(-1)
    Z = dataset_average[expl_column]
    Z = Z.apply(pd.to_numeric, errors='coerce').fillna(-1)
    y = dataset_average[expl_label_column]
    y = y.apply(pd.to_numeric, errors='coerce').fillna(-1)
    # Binary target: 1 when the mean explanation score is below 3, else 0.
    y = y.apply(lambda x: np.where(x < 3, 1, 0))

    # Re-key y, Z and X from (Form, QID) to (game_id, action_id, event_id).
    y_with_ids = y.copy().to_frame()
    y_with_ids[id_columns] = explanations_map.loc[y.index][id_columns].values
    y = y_with_ids.set_index(id_columns)

    Z = Z.join(explanations_map, on=['Form', 'QID'])
    Z.set_index(id_columns, inplace=True)

    X = X.join(examples_map, on=['Form', 'QID'])
    X.set_index(id_columns, inplace=True)

    # Drop every column whose name contains 'expl' (Z) or 'pred' (X).
    Z.drop(columns=[col for col in Z.columns if 'expl' in col], inplace=True)
    X.drop(columns=[col for col in X.columns if 'pred' in col], inplace=True)
    # Fraction of items labelled 1; handed to every rejector as ``contamination``.
    contamination = y.values.mean()

    # Stratified split: 20% test, then 1/8 of the remainder (10% overall) for validation.
    idx_train, idx_test = train_test_split(Z.index, test_size=0.2, random_state=iteration, stratify=y)
    idx_train, idx_val = train_test_split(idx_train, test_size=1/8, random_state=iteration, stratify=y.loc[idx_train])

    Xtr, Ztr, ytr, ypred_tr = X.loc[idx_train], Z.loc[idx_train], y.loc[idx_train], predictions.loc[idx_train]
    Xval, Zval, yval = X.loc[idx_val], Z.loc[idx_val], y.loc[idx_val]
    Xte, Zte, yte, ypred_te = X.loc[idx_test], Z.loc[idx_test], y.loc[idx_test], predictions.loc[idx_test]

    columns = ["contamination", "iteration", "method", "rr", "k", "val_AUROC", "tn", "fp", "fn", "tp", "f1", "balanced_accuracy"]
    results_path = './results/Appendix/UserStudy/baselines_rr.csv'
    # Make sure the output directory exists before the first row is written.
    os.makedirs(os.path.dirname(results_path), exist_ok=True)

    # Arguments shared by every sweep below.
    shared = dict(yte=yte, contamination=contamination, iteration=iteration,
                  columns=columns, results_path=results_path)

    ### Standard rejection setting
    rand_rejector = RandRej(seed=iteration, contamination=contamination)
    rand_rejector.fit(Ztr)
    _sweep_rejection_rates(rand_rejector, (Zte,), method="RandRej", **shared)

    amb_rejector = AmbRejClas(seed=iteration, contamination=contamination)
    amb_rejector.fit(Xtr, ypred_tr)
    _sweep_rejection_rates(amb_rejector, (Xte, ypred_te), method="PredAmb", **shared)

    # Novelty rejection, first on the example features (X), then on the
    # explanation features (Z); the validation AUROC is logged with every row.
    for method, train_data, val_data, test_data in (
        ("NovRej_{X}", Xtr, Xval, Xte),
        ("NovRej_{Z}", Ztr, Zval, Zte),
    ):
        for k in [1, 10, 100]:
            nov_rejector = NovRej(k=k, seed=iteration, contamination=contamination)
            nov_rejector.fit(train_data, ytr)
            scores = nov_rejector.score(val_data)
            val_auroc = roc_auc_score(yval, scores)
            _sweep_rejection_rates(nov_rejector, (test_data,), method=method,
                                   k=k, val_auroc=val_auroc, **shared)

    ## Explanation-based rejection baselines
    metr_rejector = FaithRej(dataset='xG', seed=iteration, contamination=contamination)
    metr_rejector.fit(Xtr, Ztr, [])
    _sweep_rejection_rates(metr_rejector, (pd.DataFrame(Xte), Zte), method="FaithRej", **shared)

    metr_rejector = StabRej(dataset='xG', seed=iteration, contamination=contamination)
    metr_rejector.fit(Ztr)
    _sweep_rejection_rates(metr_rejector, (Zte,), method="StabRej", **shared)

    complexity_rejector = ComplRej(seed=iteration, contamination=contamination)
    complexity_rejector.fit(Ztr)
    _sweep_rejection_rates(complexity_rejector, (Zte,), method="ComplRej", **shared)

    print(f"Finished APP competitors user study rr iteration {iteration}")
    print()

@click.command()
@click.option('--threshold_features', default=0.5, help='Threshold for features')
@click.option('--threshold', default=2.6, help='Threshold for scores')
@click.option('--iteration', default=0, help='Number of iterations')
@click.option('--seed', default=9, help='Seed for reproducibility')
@click.option('--filter_std', type=float, default=5.1, help='Standard deviation for filter')
def run_experiment(iteration, seed, threshold_features, threshold, filter_std):
    """Run the user-study rejection baselines for one iteration."""
    # click passes every declared option to the callback as a keyword argument,
    # so all five must appear in the signature even though only `seed` and
    # `iteration` are forwarded. `threshold_features`, `threshold` and
    # `filter_std` are accepted for command-line compatibility and are
    # currently not used by compute_results.
    compute_results(seed, iteration)
    
    
# Script entry point: invoke the click command only when the file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    run_experiment()
    