

import torch
import numpy as np

# calculate the worst-group accuracy for each model
def wg_acc(y, y_pred, g, return_all=False):
    """Compute the worst-group accuracy of predictions.

    Every distinct value found in ``g`` is treated as one group. For each
    group, the accuracy is the mean of ``(y == y_pred)`` over the positions
    where ``g`` equals that value.

    Args:
        y: tensor of true labels.
        y_pred: tensor of predicted labels, compared element-wise with ``y``.
        g: tensor of group identifiers used to mask ``y == y_pred``
            (presumably the same shape as ``y`` -- TODO confirm with callers).
        return_all: when true, also return the per-group accuracies and their
            unweighted mean.

    Returns:
        The smallest per-group accuracy (an element of the per-group accuracy
        tensor, picked with the built-in ``min``). If ``return_all`` is true,
        a tuple ``(worst, per_group, mean_acc)`` where ``per_group`` is a
        tensor ordered by ascending group value and ``mean_acc`` is a Python
        float giving every group equal weight.
    """
    # distinct group values, ascending
    group_ids = sorted(torch.unique(g))

    # 1.0 where the prediction matches the label, 0.0 elsewhere
    hits = (y == y_pred).float()

    # one accuracy slot per group, filled in ascending group order
    per_group = torch.zeros(len(group_ids))
    for slot, group_id in enumerate(group_ids):
        members = (g == group_id)
        per_group[slot] = hits[members].mean().item()

    # every group counts equally, regardless of how many samples it has
    mean_acc = per_group.mean().item()

    if not return_all:
        return min(per_group)
    return min(per_group), per_group, mean_acc


def return_results(results, seeds, type_result='val'):
    """Summarise per-seed results as mean and standard error.

    Reads the entries ``loss_<type_result>``, ``wg_<type_result>`` and
    ``weighted_acc_<type_result>`` from ``results``.

    Args:
        results: mapping from metric key to the values recorded for it.
        seeds: collection of seeds; only its length is used, as the sample
            size in the standard-error denominator.
        type_result: suffix selecting which entries to read (default 'val').

    Returns:
        Tuple ``(loss_avg, loss_se, wg_avg, wg_se, weighted_acc_avg,
        weighted_acc_se)``. Each ``*_se`` is ``np.std(values)`` (numpy's
        default ``ddof``) divided by ``sqrt(len(seeds))``.
    """
    # the denominator is shared by all three metrics
    root_n = np.sqrt(len(seeds))

    summary = []
    for metric in ('loss', 'wg', 'weighted_acc'):
        values = results['{}_{}'.format(metric, type_result)]
        summary.append(np.mean(values))
        summary.append(np.std(values) / root_n)

    return tuple(summary)

def get_key(method, dataset, penalty_type, penalty_strength, solver, tol, batch_size_model, augmentation_data, early_stopping_model, lambda_JTT, DFR_data,  penalty_type_predictor, penalty_strength_predictor, create_identifier, p_y_JTT, T, C_GDRO, eta_param_GDRO, eta_q_GDRO, fraction_original_data, seed):
    """Build the string key identifying one run configuration.

    The key always contains the dataset, method, penalty type and strength,
    solver, tolerance, batch size, augmentation flag, early-stopping flag,
    DFR data, T, fraction of original data and seed. When ``method`` is
    'JTT' or 'GDRO', the parameters specific to that method are appended;
    for any other method they are ignored.

    Returns:
        The formatted key as a string.
    """
    # part of the key shared by all methods
    base = 'dataset_{}_method_{}_penalty_{}_strength_{}_solver_{}_tol_{}_bs_{}_ad_{}_es_{}_DFR_data_{}_T_{}_fraction_{}_seed_{}'.format(
        dataset, method, penalty_type, penalty_strength, solver, tol,
        batch_size_model, augmentation_data, early_stopping_model, DFR_data,
        T, fraction_original_data, seed,
    )

    if method == 'JTT':
        # JTT-only settings
        suffix = '_lambda_JTT_{}_create_identifier_{}_p_y_JTT_{}_penalty_predictor_{}_penalty_strength_predictor_{}'.format(
            lambda_JTT, create_identifier, p_y_JTT,
            penalty_type_predictor, penalty_strength_predictor,
        )
    elif method == 'GDRO':
        # GDRO-only settings
        suffix = '_C_GDRO_{}_eta_param_GDRO_{}_eta_q_GDRO_{}'.format(
            C_GDRO, eta_param_GDRO, eta_q_GDRO,
        )
    else:
        suffix = ''

    return base + suffix
    

def create_result_entry(result_dict, method, dataset, penalty_type, penalty_strength, solver, tol, batch_size_model, augmentation_data, early_stopping_model, lambda_JTT, DFR_data,  penalty_type_predictor, penalty_strength_predictor, create_identifier, p_y_JTT, T, C_GDRO, eta_param_GDRO, eta_q_GDRO, fraction_original_data):
    """Record a run's configuration in ``result_dict`` and return it.

    ``result_dict`` is modified in place. The settings shared by all methods
    are always written; the JTT-specific or GDRO-specific settings are
    written only when ``method`` is 'JTT' or 'GDRO' respectively. Entries
    already present under other keys are left as they are.

    Returns:
        The same ``result_dict`` object that was passed in.
    """
    # settings stored for every method, in the order they are written
    fields = [
        ('dataset', dataset),
        ('method', method),
        ('solver', solver),
        ('tol', tol),
        ('batch_size_model', batch_size_model),
        ('augmentation_data', augmentation_data),
        ('early_stopping_model', early_stopping_model),
        ('DFR_data', DFR_data),
        ('T', T),
        ('fraction_original_data', fraction_original_data),
        ('penalty_type', penalty_type),
        ('penalty_strength', penalty_strength),
    ]

    if method == 'JTT':
        # JTT-only settings ('T' is written a second time with the same value)
        fields += [
            ('lambda_JTT', lambda_JTT),
            ('create_identifier', create_identifier),
            ('p_y_JTT', p_y_JTT),
            ('T', T),
            ('penalty_type_predictor', penalty_type_predictor),
            ('penalty_strength_predictor', penalty_strength_predictor),
        ]
    elif method == 'GDRO':
        # GDRO-only settings
        fields += [
            ('C_GDRO', C_GDRO),
            ('eta_param_GDRO', eta_param_GDRO),
            ('eta_q_GDRO', eta_q_GDRO),
        ]

    for name, value in fields:
        result_dict[name] = value

    return result_dict
    