from utils import set_seed
import numpy as np
import hashlib

# Datasets that hparams treats as "small images": they get a larger default
# learning rate (1e-3, sampled from 10**U(-4.5, -2.5)) and the shared
# 'weight_decay' (plus DANN's 'weight_decay_g') pinned to 0.
SMALL_IMAGES = ['ColoredMNIST']

def seed_hash(*args):
    """
    Map an arbitrary tuple of arguments to a reproducible seed in [0, 2**31).

    The argument tuple is rendered with ``str``, UTF-8 encoded and hashed with
    MD5; the 128-bit digest is read as a big-endian integer and reduced modulo
    2**31. Arguments with the same ``str`` form therefore always produce the
    same seed, independent of the process (unlike the salted built-in ``hash``).
    """
    payload = str(args).encode("utf-8")
    digest = hashlib.md5(payload).digest()
    return int.from_bytes(digest, "big") % (2**31)

class hparams:
    """Build a list of hyperparameter combinations for one dataset/algorithm.

    After construction ``self.hparams`` holds ``n_hparam_comb`` dicts. The
    first one (index 0) contains every hyperparameter's default value; each
    later one contains values sampled from per-hyperparameter ranges, seeded
    deterministically from ``(seed + index, trial_no, name)``.
    """

    # Algorithms load_hparams knows how to configure; anything else raises
    # NotImplementedError.
    SUPPORTED_ALGORITHMS = ('ERM', 'DANN', 'CDANN', 'Fish', 'IRM', 'GroupDRO',
                            'MMD', 'CORAL', 'TwinModel')

    def __init__(self, dataset='ColoredMNIST', algorithm="ERM", n_hparam_comb=20, trial_no=0, seed=42):
        """
        Args:
            dataset: dataset name; members of SMALL_IMAGES use the small-image
                learning-rate / weight-decay ranges.
            algorithm: one of SUPPORTED_ALGORITHMS.
            n_hparam_comb: number of combinations to generate.
            trial_no: extra seed component, so different trials sample
                different values for the same ``seed``.
            seed: base seed; combination ``i`` is seeded from ``seed + i``.

        Raises:
            NotImplementedError: if ``algorithm`` is not supported.
        """
        # NOTE(review): utils.set_seed is not visible here; presumably it seeds
        # global RNGs as a side effect of constructing this object.
        set_seed(seed)
        self.n_hparam_comb = n_hparam_comb
        self.dataset = dataset
        self.trial_no = trial_no
        self.algorithm = algorithm
        self.seed = seed
        self.hparams = self.load_hparams()

    def load_hparams(self):
        """Return ``n_hparam_comb`` hyperparameter dicts (index 0 = defaults).

        Raises:
            NotImplementedError: if ``self.algorithm`` is not supported. This is
                checked before sampling, so it also fires when
                ``n_hparam_comb`` is 0.
        """
        if self.algorithm not in self.SUPPORTED_ALGORITHMS:
            raise NotImplementedError(
                f"Unsupported algorithm {self.algorithm!r}; expected one of "
                f"{', '.join(self.SUPPORTED_ALGORITHMS)}"
            )
        # range() yields plain ints even for an integer-like seed (e.g. a NumPy
        # integer); that matters because seed_hash hashes str() of its args.
        return [
            self._sample_combination(n, state)
            for n, state in enumerate(range(self.seed, self.seed + self.n_hparam_comb))
        ]

    def _sample_combination(self, n, state):
        """Return combination number ``n`` as a dict; ``state`` seeds it.

        Combination 0 takes every default value. Any other combination draws
        each value from a fresh ``np.random.RandomState`` seeded from
        ``(state, trial_no, name)``, so one hyperparameter's value does not
        depend on which others were drawn before it. Keys are inserted in a
        fixed order: algorithm-specific ones first, then 'lr', 'weight_decay'.
        """
        combo = {}
        small_images = self.dataset in SMALL_IMAGES

        def _hparam(name, default_val, random_val_fn):
            """Record hyperparameter ``name``. random_val_fn takes a
            RandomState and returns a random hyperparameter value."""
            if n == 0:
                combo[name] = default_val
            else:
                random_state = np.random.RandomState(
                    seed_hash(seed_hash(state, self.trial_no), name)
                )
                combo[name] = random_val_fn(random_state)

        if self.algorithm in ('DANN', 'CDANN'):
            _hparam('lambda_', 1.0, lambda r: 10**r.uniform(-2, 2))
            _hparam('d_steps_per_g', 1, lambda r: int(2**r.uniform(0, 3)))
            _hparam('grad_penalty', 0., lambda r: 10**r.uniform(-2, 1))
            _hparam('weight_decay_d', 0., lambda r: 10**r.uniform(-6, -2))

            if small_images:
                _hparam('lr_g', 1e-3, lambda r: 10**r.uniform(-4.5, -2.5))
                _hparam('lr_d', 1e-3, lambda r: 10**r.uniform(-4.5, -2.5))
                _hparam('weight_decay_g', 0., lambda r: 0.)
            else:
                _hparam('lr_g', 5e-5, lambda r: 10**r.uniform(-5, -3.5))
                _hparam('lr_d', 5e-5, lambda r: 10**r.uniform(-5, -3.5))
                _hparam('weight_decay_g', 0., lambda r: 10**r.uniform(-6, -2))

        elif self.algorithm == 'Fish':
            _hparam('meta_lr', 0.5, lambda r: r.choice([0.05, 0.1, 0.5]))

        elif self.algorithm == 'IRM':
            _hparam('irm_lambda', 1e2, lambda r: 10**r.uniform(-1, 5))
            _hparam('irm_penalty_anneal_iters', 500,
                    lambda r: int(10**r.uniform(0, 4)))

        elif self.algorithm == 'GroupDRO':
            _hparam('groupdro_eta', 1e-2, lambda r: 10**r.uniform(-3, -1))

        elif self.algorithm in ('MMD', 'CORAL'):
            # MMD and CORAL share the same penalty-weight hyperparameter name.
            _hparam('mmd_gamma', 1., lambda r: 10**r.uniform(-1, 1))

        elif self.algorithm == 'TwinModel':
            _hparam('weight_decay_label', 0., lambda r: 10**r.uniform(-6, -2))
            _hparam('weight_decay_group', 0., lambda r: 10**r.uniform(-6, -2))

            _hparam('lambda_metric', 1., lambda r: 10**r.uniform(-1, 2))
            _hparam('lambda_group', 1., lambda r: 10**r.uniform(-1, 2))

            if small_images:
                _hparam('lr_label', 1e-3, lambda r: 10**r.uniform(-4.5, -2.5))
                _hparam('lr_group', 1e-3, lambda r: 10**r.uniform(-4.5, -2.5))
            else:
                _hparam('lr_label', 5e-5, lambda r: 10**r.uniform(-5, -3.5))
                _hparam('lr_group', 5e-5, lambda r: 10**r.uniform(-5, -3.5))

        # 'ERM' has no algorithm-specific hyperparameters; unsupported names
        # were already rejected by load_hparams.

        # Hyperparameters shared by every algorithm.
        if small_images:
            _hparam('lr', 1e-3, lambda r: 10**r.uniform(-4.5, -2.5))
            _hparam('weight_decay', 0., lambda r: 0.)
        else:
            _hparam('lr', 5e-5, lambda r: 10**r.uniform(-5, -3.5))
            _hparam('weight_decay', 0., lambda r: 10**r.uniform(-6, -2))

        return combo