import numpy as np
import pandas as pd
import os
import pickle
from scipy.stats import wilcoxon
import itertools

# The summary CSV is written under result/, so make sure it exists.
os.makedirs('result', exist_ok=True)

# --- Experiment grid -------------------------------------------------------
# Available models are 'DeepHit', 'NnetSurv' and 'PMFSurv'; this run is
# restricted to DeepHit only.
model_list = ['DeepHit']
# Values are kept as strings because they are spliced verbatim into the
# score file names (e.g. the literal text 'None' / 'True').
fairness_list = ['None']
dataset_list = ['mimiccxr', 'adni', 'areds']
sensitive_attr_list = ['sex', 'age', 'race']
metric_list = ['ctd', 'brier', 'auc']
pretrain_list = ['True']
# Hyper-parameter seeds '0'..'9' and a single training seed '0'.
hparam_seed_list = list(map(str, range(10)))
seed_list = list(map(str, range(1)))
shift_list = ['None']
group_shift_list = ['None']

# For every evaluation configuration, load the score of every run in the
# (model, hparam_seed, seed, shift, group_shift) grid and keep the best run.
output = []
for fairness, dataset, sensitive_attr, metric, pretrain in itertools.product(
        fairness_list, dataset_list, sensitive_attr_list, metric_list, pretrain_list):
    # skip invalid combinations: ADNI is not evaluated for the 'race' attribute
    if dataset == 'adni' and sensitive_attr == 'race':
        continue

    values_val = []
    values_test = []
    values_test_fair = []
    values_per_group = []

    # Runs are enumerated in itertools.product order (last axis varies
    # fastest); the flat run index is decoded back with the same axes below.
    run_axes = (model_list, hparam_seed_list, seed_list, shift_list, group_shift_list)
    for model, hparam_seed, seed, shift, group_shift in itertools.product(*run_axes):
        filename = f'output/{fairness}/score_{model}_{fairness}_{dataset}_{sensitive_attr}_{metric}_{pretrain}_{shift}_{group_shift}_{hparam_seed}_{seed}.pkl'
        if not os.path.exists(filename):
            raise FileNotFoundError(f'{filename} not found')
        # pickle.load can execute arbitrary code: only point this at trusted,
        # locally produced score files.
        with open(filename, 'rb') as f:
            result = pickle.load(f)
        values_val.append(result['val']['accuracy'])
        values_test.append(result['test']['accuracy'])
        values_test_fair.append(result['test']['fairness'])
        values_per_group.append(result['test']['per_group'])
    values_val = np.array(values_val)
    values_test = np.array(values_test)

    # NOTE(review): the best run is chosen by its TEST score, and that same
    # test score is reported, which optimistically biases the result.
    # Set this to `values_val` to select on the validation split instead.
    selection_scores = values_test
    # ctd and auc are maximised; any other metric (here: brier) is minimised.
    # The nan-aware variants skip runs whose score is NaN. Plain argmax/argmin
    # would return the first NaN as "best". An all-NaN grid raises ValueError.
    if metric in ['ctd', 'auc']:
        best_run = np.nanargmax(selection_scores)
    else:
        best_run = np.nanargmin(selection_scores)

    # Decode the flat index with the real grid shape instead of assuming
    # exactly 10 hparam seeds and a single seed/shift/group_shift.
    grid_shape = tuple(len(axis) for axis in run_axes)
    model_idx, hparam_seed_idx, _, _, _ = np.unravel_index(best_run, grid_shape)
    output.append({
        'dataset': dataset,
        'sensitive_attr': sensitive_attr,
        'metric': metric,
        'pretrain': pretrain,
        'model': model_list[model_idx],
        'hparam_seed': hparam_seed_idx,  # position in hparam_seed_list
        'val_acc': values_val[best_run],
        'test_acc': values_test[best_run],
        'test_fair': values_test_fair[best_run],
        # assumes per_group holds (at least) two groups indexable by 0 and 1
        # — TODO confirm against the score-file writer
        'group_0': values_per_group[best_run][0],
        'group_1': values_per_group[best_run][1],
    })
# One row per (dataset, sensitive_attr, metric, pretrain) configuration,
# ordered by those keys, written without the pandas index.
# (The all-models variant of this summary used result/tte_model_selection.csv.)
output = pd.DataFrame(output).sort_values(
    by=['dataset', 'sensitive_attr', 'metric', 'pretrain'])
output.to_csv('result/tte_model_selection_deephit.csv', index=False)
