

import torch
import numpy as np
import random
from torch.utils.data import DataLoader, TensorDataset
import sys

def set_seed(seed):
    """
    Seed PyTorch, Python's `random` module and NumPy's global RNG with one value.

    seed: the integer seed passed unchanged to each generator.
    """
    # Same seeding order as before: torch first, then random, then numpy.
    for seeder in (torch.manual_seed, random.seed, np.random.seed):
        seeder(seed)

def Convert(string, type_=float):
    """
    Split `string` on '-' and cast every piece with `type_`.

    Example: Convert("1-2-3", int) -> [1, 2, 3]
    Note: because '-' is the separator, negative numbers cannot be expressed.
    """
    return list(map(type_, string.split("-")))

def str_to_bool(text):
    """
    Parse the strings 'true' / 'false' (case-insensitive) into a bool.

    Surrounding whitespace is ignored, and a value that is already a bool is returned
    unchanged, so the function can be used as an argparse ``type=`` converter.

    Raises:
        ValueError: if `text` is anything other than 'true' or 'false'. Previously such
            input silently returned None, which is falsy and easy to mistake for False.
    """
    if isinstance(text, bool):
        return text
    normalized = text.strip().lower()
    if normalized == 'true':
        return True
    if normalized == 'false':
        return False
    # argparse converts a ValueError from a type= callable into a clean usage error.
    raise ValueError("Expected 'true' or 'false', got {!r}".format(text))
    

def fast_xtdx(X, diag):
    """
    Compute X^T D X for a diagonal matrix D without building D explicitly.

    X:    array of shape (n, d)
    diag: array of shape (n,) holding the diagonal entries of D
    Returns an array of shape (d, d).
    """
    # Scaling row i of X by diag[i] is equivalent to left-multiplying X by D.
    return X.T @ (X * diag[:, np.newaxis])


def subsample_clip_p_min(n_dict):
    """
    Map every group g to the tensor 1 / n_dict[g].

    n_dict: mapping from group to its number of samples.
    For n_g >= 1 each value lies in (0, 1]. The name suggests these serve as the
    per-group minimum probabilities (p_min) — confirm against callers.
    """
    return {group: torch.as_tensor(1 / count) for group, count in n_dict.items()}


# calculate standard weights 
def calc_subsample_ood_weights(p_train, n_train):
    """
    Per-group weights n_s / n_g.

    n_g = ceil(p_train[g] * n_train) is the (rounded-up) size of group g, and n_s is
    the smallest of those sizes.

    p_train: mapping from group to its training proportion; values must be torch
             tensors, since torch.ceil is applied to them.
    n_train: total number of training samples.
    Returns a dict mapping each group to a float32 tensor.
    """
    group_sizes = {
        group: torch.ceil(p_train[group] * n_train).int()
        for group in p_train.keys()
    }
    smallest = min(group_sizes.values())
    return {
        group: torch.as_tensor(smallest / size).to(torch.float32)
        for group, size in group_sizes.items()
    }

def key_of_min(d):
    """Return the key of `d` with the smallest value (the first such key on ties)."""
    return min(d.keys(), key=lambda k: d[k])


     

def get_fraction_original_data(fraction, X, y, g):
    """
    Draw a random subsample containing a `fraction` of the rows of (X, y, g).

    The number of kept rows is ceil(n * fraction), where n = X.shape[0]. Rows are
    picked without replacement using NumPy's global RNG (np.random), so the result is
    reproducible after seeding it (e.g. via np.random.seed).

    Args:
        fraction: share of rows to keep; values >= 1 keep all rows (in shuffled order).
        X: 2-D, indexed as X[indices, :]; its first dimension gives n.
        y, g: indexed with the same integer index array as X.

    Returns:
        (X_tilde, y_tilde, g_tilde, indices), where `indices` are the selected row
        positions in the original data.
    """
    n = X.shape[0]

    # Remove floating-point noise before taking the ceiling: e.g. 10 * 0.3 evaluates
    # to 3.0000000000000004, which np.ceil alone would turn into 4 instead of 3.
    n_tilde = int(np.ceil(np.round(n * fraction, 10)))
    print('n_tilde: {}'.format(n_tilde))

    # Random permutation of all row positions; keep the first n_tilde of them.
    indices = np.random.permutation(np.arange(n))[:n_tilde]

    X_tilde = X[indices, :]
    y_tilde = y[indices]
    g_tilde = g[indices]

    return X_tilde, y_tilde, g_tilde, indices


def get_best_hyper_param_method(method, param_search_df, selection_criteria='wg_val'):
    """
    Select the best hyper-parameter configuration of `method` from finished search results.

    We assume the hyper-parameter search has already been run; this only picks the
    winning row.

    Args:
        method: one of 'ERM-GW', 'ERM-SUBG', 'DFR', 'GDRO', 'JTT'. It decides which
            hyper-parameter columns are relevant.
        param_search_df: dataframe with the hyper-parameter search results for the
            respective dataset. It must contain the method's hyper-parameter columns,
            the `selection_criteria` column and a 'key' column. It is not modified.
        selection_criteria: column to select on. For 'wg_val' (worst-group accuracy)
            the row with the HIGHEST value wins; for any other column the row with
            the LOWEST value wins.

    Returns:
        (best_param, param_to_search):
        - best_param: the selected row (a pandas Series over the columns above). The
          criterion value is reported as stored in the dataframe.
        - param_to_search: the list of hyper-parameter column names for `method`.
        When several rows tie on the criterion, the one with the smallest
        'penalty_strength' is returned and the tie is reported with a print.

    Raises:
        ValueError: if `method` is unknown, or if no row has a usable criterion value
            (e.g. the column is empty or all-NaN).
    """
    params_per_method = {
        'ERM-GW': ['penalty_strength'],
        'ERM-SUBG': ['penalty_strength'],
        'DFR': ['penalty_strength'],
        'GDRO': ['penalty_strength', 'C_GDRO', 'eta_param_GDRO'],
        'JTT': ['penalty_strength', 'lambda_JTT'],
    }
    if method not in params_per_method:
        raise ValueError(
            "Unknown method {!r}; expected one of {}".format(method, sorted(params_per_method))
        )
    # Copy so callers cannot mutate the lookup table through the returned list.
    param_to_search = list(params_per_method[method])

    # Restrict to the columns relevant for the selection.
    cols = param_to_search + [selection_criteria] + ['key']
    candidates = param_search_df[cols]

    # Worst-group accuracy is "higher is better"; every other criterion is treated as
    # "lower is better". Comparing against max/min directly avoids negating the column,
    # which previously leaked a negated value into the returned row.
    scores = candidates[selection_criteria]
    if selection_criteria == 'wg_val':
        best_mask = scores == scores.max()
    else:
        best_mask = scores == scores.min()

    # A boolean mask through .loc always yields a DataFrame (possibly with several rows).
    best_rows = candidates.loc[best_mask, :]

    if best_rows.shape[0] == 0:
        raise ValueError(
            "No row with a usable '{}' value found in the search results".format(selection_criteria)
        )

    if best_rows.shape[0] > 1:
        # Tie-break: prefer the smallest penalty strength (first row after ascending sort).
        best_param = best_rows.sort_values(by='penalty_strength', ascending=True).iloc[0]
        print('More than one best hyper-parameter found: {}'.format(best_param))
    else:
        best_param = best_rows.iloc[0]

    return best_param, param_to_search




     




       
        
        
            