import numpy as np
import pandas as pd
import os
import pickle
from scipy.stats import wilcoxon
import itertools

# Paired significance tests: pretrain='True' vs pretrain='False'.
#
# For every (fairness, dataset, sensitive attribute, metric) configuration the
# script gathers the test scores of all (model, hparam seed, seed, shift,
# group shift) runs that exist for BOTH pretrain settings, runs one-sided
# Wilcoxon signed-rank tests on the paired scores, and writes the p-values to
# result/stat_test_pretrained.csv.

os.makedirs('result', exist_ok=True)

# Experiment grid. Each combination of these values names one pickled score
# file under output/<fairness>/ (see the f-string in _load_paired_scores).
model_list = ['DeepHit', 'NnetSurv', 'PMFSurv']
fairness_list = ['None']
dataset_list = ['mimiccxr', 'adni', 'areds']
sensitive_attr_list = ['sex', 'age', 'race']
metric_list = ['ctd', 'brier', 'auc']
# Order matters: columns 0/1 of each paired row come from the first entry,
# columns 2/3 from the second.
pretrain_list = ['True', 'False']
hparam_seed_list = [str(i) for i in range(10)]
seed_list = [str(i) for i in range(1)]
shift_list = ['None']
group_shift_list = ['None']

# Column layout of the exported table.
_RESULT_COLUMNS = ['dataset', 'sensitive_attr', 'metric',
                   'p_value_acc', 'p_value_fair']


def _load_paired_scores(fairness, dataset, sensitive_attr, metric):
    """Return the paired score rows available for one configuration.

    Each row is ``[accuracy, fairness, accuracy, fairness]`` where the first
    pair comes from ``pretrain_list[0]`` and the second from
    ``pretrain_list[1]``, both for the same run. A run is kept only when the
    score files of both pretrain settings exist; missing files are reported
    on stdout and the run is dropped.
    """
    rows = []
    for model, hparam_seed, seed, shift, group_shift in itertools.product(
            model_list, hparam_seed_list, seed_list, shift_list,
            group_shift_list):
        row = []
        for pretrain in pretrain_list:
            filename = f'output/{fairness}/score_{model}_{fairness}_{dataset}_{sensitive_attr}_{metric}_{pretrain}_{shift}_{group_shift}_{hparam_seed}_{seed}.pkl'
            try:
                # NOTE(security): pickle.load can execute arbitrary code;
                # only point this script at score files from trusted runs.
                with open(filename, 'rb') as f:
                    result = pickle.load(f)
            except FileNotFoundError:
                print(f'{filename} not found')
                continue
            row.append(result['test']['accuracy'])
            row.append(result['test']['fairness'])
        # Two scores for each of the two pretrain settings -> complete pair.
        if len(row) == 4:
            rows.append(row)
    return rows


def _paired_p_value(x, y, alternative):
    """Return the one-sided Wilcoxon signed-rank p-value for x vs. y.

    Returns NaN when every paired difference is zero: the test is undefined
    in that case (older SciPy releases raise ValueError).
    """
    if not np.any(np.asarray(x) != np.asarray(y)):
        return np.nan
    _, p_value = wilcoxon(x, y, alternative=alternative)
    return p_value


stat_test_results = []
for fairness, dataset, sensitive_attr, metric in itertools.product(
        fairness_list, dataset_list, sensitive_attr_list, metric_list):
    # skip invalid combinations
    if dataset == 'adni' and sensitive_attr == 'race':
        continue

    paired_rows = _load_paired_scores(fairness, dataset, sensitive_attr, metric)
    if paired_rows:
        values = np.array(paired_rows)
        # For 'ctd' and 'auc' the test asks whether the pretrain='True'
        # scores are greater, for every other metric ('brier') whether they
        # are smaller -- presumably because ctd/auc are higher-is-better and
        # brier is lower-is-better.
        acc_alternative = 'greater' if metric in ['ctd', 'auc'] else 'less'
        p_acc = _paired_p_value(values[:, 0], values[:, 2], acc_alternative)
        # The fairness score is always tested for being smaller with
        # pretrain='True'.
        p_fair = _paired_p_value(values[:, 1], values[:, 3], 'less')
    else:
        # Without a single complete pair there is nothing to test. Record
        # NaN so the gap is visible in the CSV instead of aborting the run.
        print(f'no complete pretrain pairs for {dataset}/{sensitive_attr}/{metric}; '
              'p-values set to NaN')
        p_acc = p_fair = np.nan

    stat_test_results.append({
        'dataset': dataset,
        'sensitive_attr': sensitive_attr,
        'metric': metric,
        'p_value_acc': p_acc,
        'p_value_fair': p_fair
    })

# Explicit columns keep the sort below valid even for an empty result list.
stat_test_results = pd.DataFrame(stat_test_results, columns=_RESULT_COLUMNS)
stat_test_results = stat_test_results.sort_values(
    ['dataset', 'sensitive_attr', 'metric'])
stat_test_results.to_csv('result/stat_test_pretrained.csv', index=False)