"""Defining hyper parameters and their distributions for HPO"""

import numpy as np

from woods.objectives.assign_objectives import OBJECTIVES

def get_training_hparams(dataset_name, seed, sample=False):
    """ Get training related hyper parameters (class_balance, weight_decay, lr, batch_size)

    The dataset specific definition is the module-level function named
    ``<dataset_name>_train``. Each entry of that definition is evaluated with
    its own freshly created ``np.random.RandomState(seed)``.

    Args:
        dataset_name (str): dataset that will be trained on for the run
        seed (int): seed of the random state handed to every hyper parameter definition
        sample (bool, optional): If ''True'', hyper parameters are sampled randomly according to their given distributions. Defaults to ''False'' where the default value is chosen.

    Raises:
        NotImplementedError: No ``<dataset_name>_train`` definition exists in this module

    Returns:
        dict: Dictionary mapping hyper parameter names to their values
    """
    definition_name = dataset_name + '_train'
    module_namespace = globals()
    if definition_name not in module_namespace:
        raise NotImplementedError("dataset not found: {}".format(dataset_name))

    definitions = module_namespace[definition_name](sample)
    # Every entry gets an identically seeded random state of its own.
    return {
        name: draw(np.random.RandomState(seed))
        for name, draw in definitions.items()
    }

def DSADS_CROSSPOSITION_train(sample):
    """ DSADS_CROSSPOSITION training hparam definition

    Args:
        sample (bool): If ''True'', hyper parameters are sampled randomly according to their given distributions. Otherwise the default values are used.

    Returns:
        dict: Hyper parameter name -> callable that takes a random state ``r`` and returns the value.
    """
    if sample:
        return {
            'class_balance': lambda r: True,
            'weight_decay': lambda r: 0.,
            # Log-uniform lr in [1e-5, 1e-3). numpy's uniform() is undefined for high < low,
            # so the bounds are given as (low, high).
            'lr': lambda r: 10**r.uniform(-5, -3),
            'batch_size': lambda r: int(2**r.uniform(3, 7))
        }
    else:
        return {
            'class_balance': lambda r: True,
            'weight_decay': lambda r: 5 * 10**-4,
            'lr': lambda r: 1e-3,
            'batch_size': lambda r: 32
        }

def Basic_Fourier_train(sample):
    """ Basic Fourier training hparam definition

    Args:
        sample (bool): If ''True'', each entry draws its value from a search distribution; otherwise each entry returns a fixed default.

    Returns:
        dict: Hyper parameter name -> callable that takes a random state ``r`` and returns the value.
    """
    if not sample:
        return dict(
            class_balance=lambda r: True,
            weight_decay=lambda r: 0,
            lr=lambda r: 1e-3,
            batch_size=lambda r: 64,
        )
    # lr is log-uniform over exponents [-4.5, -2.5); batch size is int(2**u) with u in [3, 9).
    return dict(
        class_balance=lambda r: True,
        weight_decay=lambda r: 0.,
        lr=lambda r: 10**r.uniform(-4.5, -2.5),
        batch_size=lambda r: int(2**r.uniform(3, 9)),
    )
    
def Spurious_Fourier_train(sample):
    """ Spurious Fourier training hparam definition

    Args:
        sample (bool): If ''True'', each entry draws its value from a search distribution; otherwise each entry returns a fixed default.

    Returns:
        dict: Hyper parameter name -> callable that takes a random state ``r`` and returns the value.
    """
    if not sample:
        return dict(
            class_balance=lambda r: True,
            weight_decay=lambda r: 1e-5,
            lr=lambda r: 1e-2,
            batch_size=lambda r: 64,
        )
    # lr is log-uniform over exponents [-4.5, -2.5); batch size is int(2**u) with u in [3, 9).
    return dict(
        class_balance=lambda r: True,
        weight_decay=lambda r: 0.,
        lr=lambda r: 10**r.uniform(-4.5, -2.5),
        batch_size=lambda r: int(2**r.uniform(3, 9)),
    )

def TMNIST_train(sample):
    """ TMNIST training hparam definition

    Args:
        sample (bool): If ''True'', each entry draws its value from a search distribution; otherwise each entry returns a fixed default.

    Returns:
        dict: Hyper parameter name -> callable that takes a random state ``r`` and returns the value.
    """
    if not sample:
        return dict(
            class_balance=lambda r: True,
            weight_decay=lambda r: 0,
            lr=lambda r: 1e-3,
            batch_size=lambda r: 64,
        )
    # lr is log-uniform over exponents [-4.5, -2.5); batch size is int(2**u) with u in [3, 9).
    return dict(
        class_balance=lambda r: True,
        weight_decay=lambda r: 0.,
        lr=lambda r: 10**r.uniform(-4.5, -2.5),
        batch_size=lambda r: int(2**r.uniform(3, 9)),
    )

def TCMNIST_Source_train(sample):
    """ TCMNIST_Source training hparam definition

    Args:
        sample (bool): If ''True'', each entry draws its value from a search distribution; otherwise each entry returns a fixed default.

    Returns:
        dict: Hyper parameter name -> callable that takes a random state ``r`` and returns the value.
    """
    if not sample:
        return dict(
            class_balance=lambda r: True,
            weight_decay=lambda r: 0,
            lr=lambda r: 1e-3,
            batch_size=lambda r: 64,
        )
    # lr is log-uniform over exponents [-4.5, -2.5); batch size is int(2**u) with u in [3, 9).
    return dict(
        class_balance=lambda r: True,
        weight_decay=lambda r: 0.,
        lr=lambda r: 10**r.uniform(-4.5, -2.5),
        batch_size=lambda r: int(2**r.uniform(3, 9)),
    )

def TCMNIST_Time_train(sample):
    """ TCMNIST_Time training hparam definition

    Args:
        sample (bool): If ''True'', each entry draws its value from a search distribution; otherwise each entry returns a fixed default.

    Returns:
        dict: Hyper parameter name -> callable that takes a random state ``r`` and returns the value.
    """
    if not sample:
        return dict(
            class_balance=lambda r: True,
            weight_decay=lambda r: 0,
            lr=lambda r: 1e-3,
            batch_size=lambda r: 64,
        )
    # lr is log-uniform over exponents [-4.5, -2.5); batch size is int(2**u) with u in [3, 9).
    return dict(
        class_balance=lambda r: True,
        weight_decay=lambda r: 0.,
        lr=lambda r: 10**r.uniform(-4.5, -2.5),
        batch_size=lambda r: int(2**r.uniform(3, 9)),
    )

def CAP_train(sample):
    """ CAP training hparam definition

    Args:
        sample (bool): If ''True'', each entry draws its value from a search distribution; otherwise each entry returns a fixed default.

    Returns:
        dict: Hyper parameter name -> callable that takes a random state ``r`` and returns the value.
    """
    if not sample:
        return dict(
            class_balance=lambda r: True,
            weight_decay=lambda r: 0,
            lr=lambda r: 10**-4,
            batch_size=lambda r: 8,
        )
    # lr is log-uniform over exponents [-5, -3); batch size is int(2**u) with u in [3, 4).
    return dict(
        class_balance=lambda r: True,
        weight_decay=lambda r: 0.,
        lr=lambda r: 10**r.uniform(-5, -3),
        batch_size=lambda r: int(2**r.uniform(3, 4)),
    )

# def bendr_train(sample):
#     """ bendr model hparam definition 
    
#     Args:
#         sample (bool): If ''True'', hyper parameters are gonna be sampled randomly according to their given distributions. Defaults to ''False'' where the default value is chosen.
#     """
#     if sample:
#         return {
#             'class_balance': lambda r: True,
#             'weight_decay': lambda r: 0.,
#             'lr': lambda r: 10**r.uniform(-5, -3),
#             'batch_size': lambda r: int(2**r.uniform(3, 4))
#         }
#     else:
#         return {
#             'class_balance': lambda r: True,
#             'weight_decay': lambda r: 0,
#             'lr': lambda r: 5*10**-5,
#             'batch_size': lambda r: 8
#         }

def SEDFx_train(sample):
    """ SEDFx training hparam definition

    Args:
        sample (bool): If ''True'', each entry draws its value from a search distribution; otherwise each entry returns a fixed default.

    Returns:
        dict: Hyper parameter name -> callable that takes a random state ``r`` and returns the value.
    """
    if not sample:
        return dict(
            class_balance=lambda r: True,
            weight_decay=lambda r: 0,
            lr=lambda r: 10**-4,
            batch_size=lambda r: 8,
        )
    # lr is log-uniform over exponents [-5, -3); batch size is int(2**u) with u in [3, 4).
    return dict(
        class_balance=lambda r: True,
        weight_decay=lambda r: 0.,
        lr=lambda r: 10**r.uniform(-5, -3),
        batch_size=lambda r: int(2**r.uniform(3, 4)),
    )

def PCL_train(sample):
    """ PCL training hparam definition

    Args:
        sample (bool): If ''True'', each entry draws its value from a search distribution; otherwise each entry returns a fixed default.

    Returns:
        dict: Hyper parameter name -> callable that takes a random state ``r`` and returns the value.
    """
    if not sample:
        return dict(
            class_balance=lambda r: True,
            weight_decay=lambda r: 0.,
            lr=lambda r: 10**-3,
            batch_size=lambda r: 16,
        )
    # lr is log-uniform over exponents [-5, -3); batch size is int(2**u) with u in [3, 5).
    return dict(
        class_balance=lambda r: True,
        weight_decay=lambda r: 0.,
        lr=lambda r: 10**r.uniform(-5, -3),
        batch_size=lambda r: int(2**r.uniform(3, 5)),
    )

def HHAR_train(sample):
    """ HHAR training hparam definition

    Args:
        sample (bool): If ''True'', each entry draws its value from a search distribution; otherwise each entry returns a fixed default.

    Returns:
        dict: Hyper parameter name -> callable that takes a random state ``r`` and returns the value.
    """
    if not sample:
        return dict(
            class_balance=lambda r: True,
            weight_decay=lambda r: 0,
            lr=lambda r: 10**-3,
            batch_size=lambda r: 16,
        )
    # lr is log-uniform over exponents [-4, -2); batch size is int(2**u) with u in [3, 4).
    return dict(
        class_balance=lambda r: True,
        weight_decay=lambda r: 0.,
        lr=lambda r: 10**r.uniform(-4, -2),
        batch_size=lambda r: int(2**r.uniform(3, 4)),
    )

def LSA64_train(sample):
    """ LSA64 training hparam definition

    Args:
        sample (bool): If ''True'', each entry draws its value from a search distribution; otherwise each entry returns a fixed default.

    Returns:
        dict: Hyper parameter name -> callable that takes a random state ``r`` and returns the value.
    """
    if not sample:
        return dict(
            class_balance=lambda r: True,
            weight_decay=lambda r: 0,
            lr=lambda r: 10**-4,
            batch_size=lambda r: 2,
        )
    # lr is log-uniform over exponents [-5, -3); batch size is int(2**u) with u in [3, 4).
    return dict(
        class_balance=lambda r: True,
        weight_decay=lambda r: 0.,
        lr=lambda r: 10**r.uniform(-5, -3),
        batch_size=lambda r: int(2**r.uniform(3, 4)),
    )

def AusElectricityUnbalanced_train(sample):
    """ AusElectricityUnbalanced training hparam definition

    Args:
        sample (bool): If ''True'', each entry draws its value from a search distribution; otherwise each entry returns a fixed default.

    Returns:
        dict: Hyper parameter name -> callable that takes a random state ``r`` and returns the value.
    """
    if not sample:
        return dict(
            class_balance=lambda r: True,
            weight_decay=lambda r: 0,
            lr=lambda r: 10**-4,
            batch_size=lambda r: 10,
        )
    # lr is log-uniform over exponents [-5, -3); batch size is int(2**u) with u in [5, 7).
    return dict(
        class_balance=lambda r: True,
        weight_decay=lambda r: 0.,
        lr=lambda r: 10**r.uniform(-5, -3),
        batch_size=lambda r: int(2**r.uniform(5, 7)),
    )

def AusElectricity_train(sample):
    """ AusElectricity training hparam definition

    Args:
        sample (bool): If ''True'', each entry draws its value from a search distribution; otherwise each entry returns a fixed default.

    Returns:
        dict: Hyper parameter name -> callable that takes a random state ``r`` and returns the value.
    """
    if not sample:
        return dict(
            class_balance=lambda r: True,
            weight_decay=lambda r: 0,
            lr=lambda r: 10**-4,
            batch_size=lambda r: 2,
        )
    # lr is log-uniform over exponents [-5, -3); batch size is int(2**u) with u in [3, 5).
    return dict(
        class_balance=lambda r: True,
        weight_decay=lambda r: 0.,
        lr=lambda r: 10**r.uniform(-5, -3),
        batch_size=lambda r: int(2**r.uniform(3, 5)),
    )
    
def IEMOCAPOriginal_train(sample):
    """ IEMOCAPOriginal training hparam definition

    Args:
        sample (bool): If ''True'', hyper parameters are sampled randomly according to their given distributions. Otherwise the default values are used.

    Returns:
        dict: Hyper parameter name -> callable that takes a random state ``r`` and returns the value.
    """
    if sample:
        return {
            'class_balance': lambda r: True,
            # Log-uniform lr in [1e-5, 1e-3). RandomState.random() only accepts a ``size``
            # argument, so the bounds must go through uniform(low, high).
            'lr': lambda r: 10**r.uniform(-5, -3),
            'weight_decay': lambda r: 0.,
            'batch_size': lambda r: int(2**r.uniform(4, 6)),
        }
    else:
        return {
            'class_balance': lambda r: True,
            'lr': lambda r: 1e-4,
            'weight_decay': lambda r: 1e-5,
            'batch_size': lambda r: 30,
        }

def IEMOCAPUnbalanced_train(sample):
    """ IEMOCAPUnbalanced training hparam definition

    Args:
        sample (bool): If ''True'', hyper parameters are sampled randomly according to their given distributions. Otherwise the default values are used.

    Returns:
        dict: Hyper parameter name -> callable that takes a random state ``r`` and returns the value.
    """
    if sample:
        return {
            'class_balance': lambda r: True,
            # Log-uniform lr in [1e-5, 1e-3). numpy's uniform() is undefined for high < low.
            'lr': lambda r: 10**r.uniform(-5, -3),
            'weight_decay': lambda r: 0.,
            'batch_size': lambda r: int(2**r.uniform(4, 6)),
        }
    else:
        return {
            'class_balance': lambda r: True,
            'lr': lambda r: 1e-4,
            'weight_decay': lambda r: 1e-5,
            'batch_size': lambda r: 30,
        }

def IEMOCAP_train(sample):
    """ IEMOCAP training hparam definition

    Args:
        sample (bool): If ''True'', hyper parameters are sampled randomly according to their given distributions. Otherwise the default values are used.

    Returns:
        dict: Hyper parameter name -> callable that takes a random state ``r`` and returns the value.
    """
    if sample:
        return {
            'class_balance': lambda r: True,
            # Log-uniform lr in [1e-5, 1e-3). numpy's uniform() is undefined for high < low.
            'lr': lambda r: 10**r.uniform(-5, -3),
            'weight_decay': lambda r: 0.,
            'batch_size': lambda r: int(2**r.uniform(1, 4)),
        }
    else:
        return {
            'class_balance': lambda r: True,
            'lr': lambda r: 1e-4,
            'weight_decay': lambda r: 1e-5,
            'batch_size': lambda r: 3,
        }


def get_objective_hparams(objective_name, seed, sample=False):
    """ Get the objective related hyper parameters

    The objective specific definition is the module-level function named
    ``<objective_name>_hyper``. Each entry of that definition is evaluated
    with its own freshly created ``np.random.RandomState(seed)``.

    Args:
        objective_name (str): objective that will be trained on for the run
        seed (int): seed of the random state handed to every hyper parameter definition
        sample (bool, optional): If ''True'', hyper parameters are sampled randomly according to their given distributions. Defaults to ''False'' where the default value is chosen.

    Raises:
        NotImplementedError: No ``<objective_name>_hyper`` definition exists in this module

    Returns:
        dict: Dictionary mapping hyper parameter names to their values
    """
    definition_name = objective_name + '_hyper'
    module_namespace = globals()
    if definition_name not in module_namespace:
        raise NotImplementedError("objective not found: {}".format(objective_name))

    definitions = module_namespace[definition_name](sample)
    # Every entry gets an identically seeded random state of its own.
    return {
        name: draw(np.random.RandomState(seed))
        for name, draw in definitions.items()
    }

def IIB_hyper(sample):
    """ IIB objective hparam definition

    Args:
        sample (bool): If ''True'', each entry draws its value from a search distribution; otherwise each entry returns a fixed default.

    Returns:
        dict: Hyper parameter name -> callable that takes a random state ``r`` and returns the value.
    """
    if not sample:
        return dict(
            lambda_beta=lambda r: 1e-3,
            lambda_inv_risks=lambda r: 1,
            enable_bn=lambda r: True,
            nonlinear_classifier=lambda r: False,
        )
    return dict(
        # NOTE(review): this is (1e-3)**u for u in [-2, 2), which spans roughly 1e-6..1e6;
        # confirm whether a base-10 log-uniform range was intended.
        lambda_beta=lambda r: 1e-3 ** r.uniform(-2, 2),
        lambda_inv_risks=lambda r: int(10 ** r.uniform(0, 2)),
        # NOTE(review): both choices are True, so this always yields True.
        enable_bn=lambda r: bool(r.choice([True, True])),
        nonlinear_classifier=lambda r: bool(r.choice([False, True])),
    )

def GILE_hyper(sample):
    """ GILE objective hparam definition

    Args:
        sample (bool): If ''True'', hyper parameters are sampled randomly according to their given distributions. Otherwise the default values are used.

    Returns:
        dict: Hyper parameter name -> callable that takes a random state ``r`` and returns the value.
    """
    if sample:
        return {
            # Log-uniform in [1e-5, 1e-3). numpy's uniform() is undefined for high < low.
            'aux_optim_learning_rate': lambda r: 10**r.uniform(-5, -3),
        }
    else:
        return {
            'aux_optim_learning_rate': lambda r: 1e-3
        }

def AdaRNN_hyper(sample):
    """ AdaRNN objective hparam definition

    Args:
        sample (bool): If ''True'', hyper parameters are sampled randomly according to their given distributions. Otherwise the default values are used.

    Returns:
        dict: Hyper parameter name -> callable that takes a random state ``r`` and returns the value.
    """
    if sample:
        return {
            # Pick one candidate weight so the sampled 'dw' is a scalar like the default,
            # rather than the whole candidate list.
            'dw': lambda r: float(r.choice([0.01, 0.05, 5.0])),
            'len_win': lambda r: 0,
        }
    else:
        return {
            'dw': lambda r: 0.5,
            'len_win': lambda r: 0,
        }

def FEDNet_hyper(sample):
    """ FEDNet objective hparam definition

    Args:
        sample (bool): If ''True'', hyper parameters are sampled randomly according to their given distributions. Otherwise the default values are used.

    Returns:
        dict: Hyper parameter name -> callable that takes a random state ``r`` and returns the value.
    """
    if sample:
        return {
            # Log-uniform in [1e-5, 1e-3). numpy's uniform() is undefined for high < low.
            'aux_optim_learning_rate': lambda r: 10**r.uniform(-5, -3),
        }
    else:
        return {
            'aux_optim_learning_rate': lambda r: 1e-3
        }

def ERM_hyper(sample):
    """ ERM objective hparam definition

    ERM defines no objective specific hyper parameters, so an empty dict is
    returned regardless of ``sample``.

    Args:
        sample (bool): Accepted for interface compatibility with the other ``*_hyper`` definitions; unused.

    Returns:
        dict: An empty dictionary.
    """
    return dict()

def GroupDRO_hyper(sample):
    """ GroupDRO objective hparam definition

    Args:
        sample (bool): If ''True'', each entry draws its value from a search distribution; otherwise each entry returns a fixed default.

    Returns:
        dict: Hyper parameter name -> callable that takes a random state ``r`` and returns the value.
    """
    if not sample:
        return dict(eta=lambda r: 1e-2)
    # eta is log-uniform over exponents [-3, -1).
    return dict(eta=lambda r: 10**r.uniform(-3, -1))

def IRM_hyper(sample):
    """ IRM objective hparam definition

    Args:
        sample (bool): If ''True'', each entry draws its value from a search distribution; otherwise each entry returns a fixed default.

    Returns:
        dict: Hyper parameter name -> callable that takes a random state ``r`` and returns the value.
    """
    if not sample:
        return dict(
            penalty_weight=lambda r: 1e3,
            anneal_iters=lambda r: 4000,
        )
    # NOTE(review): sampled anneal_iters is a float in [0, 2000) while the default is the int 4000.
    return dict(
        penalty_weight=lambda r: 10**r.uniform(-1, 5),
        anneal_iters=lambda r: r.uniform(0, 2000),
    )

def VREx_hyper(sample):
    """ VREx objective hparam definition

    Args:
        sample (bool): If ''True'', each entry draws its value from a search distribution; otherwise each entry returns a fixed default.

    Returns:
        dict: Hyper parameter name -> callable that takes a random state ``r`` and returns the value.
    """
    if not sample:
        return dict(
            penalty_weight=lambda r: 1e4,
            anneal_iters=lambda r: 2000,
        )
    # penalty_weight is log-uniform over exponents [-1, 5); anneal_iters is uniform in [0, 2000).
    return dict(
        penalty_weight=lambda r: 10**r.uniform(-1, 5),
        anneal_iters=lambda r: r.uniform(0, 2000),
    )

def SD_hyper(sample):
    """ SD objective hparam definition

    Args:
        sample (bool): If ''True'', each entry draws its value from a search distribution; otherwise each entry returns a fixed default.

    Returns:
        dict: Hyper parameter name -> callable that takes a random state ``r`` and returns the value.
    """
    if not sample:
        return dict(penalty_weight=lambda r: 1)
    # penalty_weight is log-uniform over exponents [-5, -1).
    return dict(penalty_weight=lambda r: 10**r.uniform(-5, -1))
        
def IGA_hyper(sample):
    """ IGA objective hparam definition

    Args:
        sample (bool): If ''True'', each entry draws its value from a search distribution; otherwise each entry returns a fixed default.

    Returns:
        dict: Hyper parameter name -> callable that takes a random state ``r`` and returns the value.
    """
    if not sample:
        return dict(penalty_weight=lambda r: 1e1)
    # penalty_weight is log-uniform over exponents [1, 5).
    return dict(penalty_weight=lambda r: 10**r.uniform(1, 5))

def ANDMask_hyper(sample):
    """ ANDMask objective hparam definition

    Args:
        sample (bool): If ''True'', each entry draws its value from a search distribution; otherwise each entry returns a fixed default.

    Returns:
        dict: Hyper parameter name -> callable that takes a random state ``r`` and returns the value.
    """
    if not sample:
        return dict(tau=lambda r: 1)
    # tau is uniform in [0, 1).
    return dict(tau=lambda r: r.uniform(0, 1))


def Fish_hyper(sample):
    """ Fish objective hparam definition

    Args:
        sample (bool): If ''True'', hyper parameters are sampled randomly according to their given distributions. Otherwise the default values are used.

    Returns:
        dict: Hyper parameter name -> callable that takes a random state ``r`` and returns the value.
    """
    if sample:
        return {
            # Searched over a small grid of meta learning rates.
            'meta_lr': lambda r: float(r.choice([0.05, 0.1, 0.5]))
        }
    else:
        return {
            # Fixed default; previously the branches were swapped so that the default
            # varied with the seed while sampling always returned this constant.
            'meta_lr': lambda r: 0.5
        }
        
def SANDMask_hyper(sample):
    """ SANDMask objective hparam definition

    Args:
        sample (bool): If ''True'', each entry draws its value from a search distribution; otherwise each entry returns a fixed default.

    Returns:
        dict: Hyper parameter name -> callable that takes a random state ``r`` and returns the value.
    """
    if not sample:
        return dict(
            tau=lambda r: 0.5,
            k=lambda r: 1e+1,
            betas=lambda r: 0.9,
        )
    # tau uniform in [0, 1); k log-uniform over exponents [-3, 5); betas uniform in [0.9, 0.999).
    return dict(
        tau=lambda r: r.uniform(0.0, 1.),
        k=lambda r: 10**r.uniform(-3, 5),
        betas=lambda r: r.uniform(0.9, 0.999),
    )
        
def IB_ERM_hyper(sample):
    """ IB_ERM objective hparam definition

    Args:
        sample (bool): If ''True'', each entry draws its value from a search distribution; otherwise each entry returns a fixed default.

    Returns:
        dict: Hyper parameter name -> callable that takes a random state ``r`` and returns the value.
    """
    if not sample:
        return dict(ib_weight=lambda r: 0.1)
    # ib_weight is log-uniform over exponents [-3, 0).
    return dict(ib_weight=lambda r: 10**r.uniform(-3, 0))
        
def IB_IRM_hyper(sample):
    """ IB_IRM objective hparam definition

    Args:
        sample (bool): If ''True'', each entry draws its value from a search distribution; otherwise each entry returns a fixed default.

    Returns:
        dict: Hyper parameter name -> callable that takes a random state ``r`` and returns the value.
    """
    if not sample:
        return dict(
            penalty_weight=lambda r: 0,
            anneal_iters=lambda r: 500,
            ib_lambda=lambda r: 10,
            ib_penalty_anneal_iters=lambda r: 500,
        )
    return dict(
        penalty_weight=lambda r: 10**r.uniform(-1, 5),
        anneal_iters=lambda r: r.uniform(0, 2000),
        # NOTE(review): int() truncates, so exponents below 0 yield ib_lambda == 0.
        ib_lambda=lambda r: int(10 ** r.uniform(-1, 5)),
        ib_penalty_anneal_iters=lambda r: r.uniform(0, 2000),
    )

def CAD_hyper(sample):
    """ CAD objective hparam definition

    Args:
        sample (bool): If ''True'', 'lmbda' and 'temperature' are drawn from small candidate grids; otherwise fixed defaults are returned. The boolean flags are the same in both modes.

    Returns:
        dict: Hyper parameter name -> callable that takes a random state ``r`` and returns the value.
    """
    if not sample:
        return dict(
            lmbda=lambda r: 0.01,
            temperature=lambda r: 0.1,
            is_project=lambda r: False,
            is_normalized=lambda r: False,
            is_flipped=lambda r: True,
        )
    return dict(
        lmbda=lambda r: r.choice([1e-4, 1e-3, 1e-2, 1e-1, 1, 1e1, 1e2]),
        temperature=lambda r: r.choice([0.05, 0.1]),
        is_project=lambda r: False,
        is_normalized=lambda r: False,
        is_flipped=lambda r: True,
    )

def CondCAD_hyper(sample):
    """ CondCAD objective hparam definition

    Args:
        sample (bool): If ''True'', 'lmbda' and 'temperature' are drawn from small candidate grids; otherwise fixed defaults are returned. The boolean flags are the same in both modes.

    Returns:
        dict: Hyper parameter name -> callable that takes a random state ``r`` and returns the value.
    """
    if not sample:
        return dict(
            lmbda=lambda r: 0.01,
            temperature=lambda r: 0.1,
            is_project=lambda r: False,
            is_normalized=lambda r: False,
            is_flipped=lambda r: True,
        )
    return dict(
        lmbda=lambda r: r.choice([1e-4, 1e-3, 1e-2, 1e-1, 1, 1e1, 1e2]),
        temperature=lambda r: r.choice([0.05, 0.1]),
        is_project=lambda r: False,
        is_normalized=lambda r: False,
        is_flipped=lambda r: True,
    )

def Transfer_hyper(sample):
    """ Transfer objective hparam definition

    Args:
        sample (bool): If ''True'', each entry draws its value from a search distribution; otherwise each entry returns a fixed default.

    Returns:
        dict: Hyper parameter name -> callable that takes a random state ``r`` and returns the value.
    """
    if not sample:
        return dict(
            t_lambda=lambda r: 0.01,
            delta=lambda r: 2.0,
            d_steps_per_g=lambda r: 10,
            weight_decay_d=lambda r: 0.,
            gda=lambda r: False,
            beta1=lambda r: 0.5,
            lr_d=lambda r: 1e-3,
        )
    # NOTE(review): the default d_steps_per_g (10) lies outside the sampled grid [1, 2, 5].
    return dict(
        t_lambda=lambda r: 10**r.uniform(-2, 1),
        delta=lambda r: r.uniform(0.1, 3.0),
        d_steps_per_g=lambda r: int(r.choice([1, 2, 5])),
        weight_decay_d=lambda r: 10**r.uniform(-6, -2),
        gda=lambda r: False,
        beta1=lambda r: r.choice([0., 0.5]),
        lr_d=lambda r: 10**r.uniform(-4.5, -2.5),
    )