import numpy as np
import pandas as pd
import os
import pickle
from scipy.stats import wilcoxon
import itertools

# Directory that receives the summary CSV written at the end of the script.
os.makedirs('result', exist_ok=True)

# Experiment grid. Each value is interpolated verbatim into the score-file
# name, which is why seeds and flags are kept as strings.
model_list = ['DeepHit', 'NnetSurv', 'PMFSurv']
fairness_list = ['None']
dataset_list = ['mimiccxr', 'adni', 'areds']
sensitive_attr_list = ['sex', 'age', 'race']
metric_list = ['ctd', 'brier', 'auc']
pretrain_list = ['True']
hparam_seed_list = [str(i) for i in range(10)]
seed_list = [str(i) for i in range(1)]
shift_list = ['None']
group_shift_list = ['None']

# Column order of the output CSV. Declaring it explicitly also lets an empty
# result set be sorted and written instead of failing on missing columns.
result_columns = ['dataset', 'sensitive_attr', 'metric',
                  'p-value', 'avg_g0', 'avg_g1', 'avg_gap']


def load_per_group_scores(fairness, dataset, sensitive_attr, metric, pretrain):
    """Collect the test-split per-group scores of every available run.

    Iterates over all (model, hparam_seed, seed, shift, group_shift)
    combinations from the module-level lists, loads the matching pickle under
    ``output/<fairness>/`` and stacks ``result['test']['per_group']`` from
    each one. Missing files are reported on stdout and skipped.

    Returns:
        ``np.ndarray`` with one row per run that was found; an empty array
        when no file exists for this combination.
    """
    scores = []
    for model, hparam_seed, seed, shift, group_shift in itertools.product(
            model_list, hparam_seed_list, seed_list, shift_list,
            group_shift_list):
        filename = f'output/{fairness}/score_{model}_{fairness}_{dataset}_{sensitive_attr}_{metric}_{pretrain}_{shift}_{group_shift}_{hparam_seed}_{seed}.pkl'
        if not os.path.exists(filename):
            print(f'{filename} not found')
            continue
        # NOTE: pickle.load can execute arbitrary code; only point this script
        # at score files produced by trusted runs.
        with open(filename, 'rb') as f:
            result = pickle.load(f)
        scores.append(result['test']['per_group'])
    return np.array(scores)


def summarize_group_gap(values):
    """Compare column 0 against column 1 of ``values`` across runs.

    Args:
        values: array with one row per run; columns 0 and 1 are read
            (presumably the scores of the first two groups - confirm against
            the code that writes ``per_group``).

    Returns:
        ``None`` when ``values`` is empty (nothing to test); otherwise a dict
        with the Wilcoxon signed-rank ``'p-value'``, the column means
        ``'avg_g0'`` / ``'avg_g1'`` and their difference ``'avg_gap'``.
    """
    if values.size == 0:
        return None

    g0, g1 = values[:, 0], values[:, 1]
    try:
        _, p_value = wilcoxon(g0, g1)
    except ValueError as err:
        # Some SciPy versions raise when the test is undefined (e.g. all
        # paired differences are zero). Keep the row, mark the p-value.
        print(f'wilcoxon test failed ({err}); recording NaN p-value')
        p_value = float('nan')

    mean_g0 = np.mean(g0)
    mean_g1 = np.mean(g1)
    return {
        'p-value': p_value,
        'avg_g0': mean_g0,
        'avg_g1': mean_g1,
        'avg_gap': mean_g0 - mean_g1,
    }


rows = []
for fairness, dataset, sensitive_attr, metric, pretrain in itertools.product(
        fairness_list, dataset_list, sensitive_attr_list, metric_list,
        pretrain_list):

    # skip invalid combinations
    if dataset == 'adni' and sensitive_attr == 'race':
        continue

    values = load_per_group_scores(fairness, dataset, sensitive_attr,
                                   metric, pretrain)

    # perform wilcoxon signed-rank test
    summary = summarize_group_gap(values)
    if summary is None:
        # Previously this case crashed the whole script with an IndexError.
        print(f'no scores found for {dataset}/{sensitive_attr}/{metric}; '
              'skipping statistical test')
        continue

    rows.append({
        'dataset': dataset,
        'sensitive_attr': sensitive_attr,
        'metric': metric,
        **summary,
    })

stat_test_results = pd.DataFrame(rows, columns=result_columns)
stat_test_results = stat_test_results.sort_values(
    ['dataset', 'sensitive_attr', 'metric'])
stat_test_results.to_csv('result/stat_test_group_disparity.csv', index=False)
