
import numpy as np
from lib import misc

def _define_hparam(hparams, hparam_name, default_val, random_val_fn):
    hparams[hparam_name] = (hparams, hparam_name, default_val, random_val_fn)


def _hparams(algorithm, dataset, random_seed, args):
    """
    Global registry of hyperparams. Each entry is a (default, random) tuple.
    New algorithms / networks / etc. should add entries here.

    Args:
        algorithm: algorithm name (e.g. 'ERM', 'DANN', 'KLGP').
        dataset: dataset name (e.g. 'RotatedMNIST', 'VisDA17').
        random_seed: seed that is combined with each hparam name via
            ``misc.seed_hash`` to build that hparam's own RandomState.
        args: namespace whose attributes supply some defaults. Depending on
            algorithm/dataset it reads: grad, dis, ls, warm, sigma, lr, sd,
            kl, klaux.

    Returns:
        dict mapping hparam name -> (default_value, random_value).

    Definitions run in order. A later ``_hparam`` call for an existing name
    overwrites the earlier entry in place (the uniqueness assert is
    disabled), so the KL-family dataset blocks at the end act as overrides.
    """
    SMALL_IMAGES = ['Debug28', 'RotatedMNIST', 'ColoredMNIST', 'MNISTUSPS', 'USPSMNIST']
    MEDIUM_IMAGES = ['SVHNMNIST']
    RESNET18 = dataset != 'VisDA17'

    hparams = {}
    def _hparam(name, default_val, random_val_fn):
        """Define a hyperparameter. random_val_fn takes a RandomState and
        returns a random hyperparameter value."""
        # Redefinition is intentionally allowed: later blocks override earlier ones.
        #assert(name not in hparams)
        # The RandomState is seeded from (random_seed, name) only, so each
        # random value does not depend on the order hparams are defined in.
        random_state = np.random.RandomState(
            misc.seed_hash(random_seed, name)
        )
        hparams[name] = (default_val, random_val_fn(random_state))

    # Unconditional hparam definitions.

    _hparam('data_augmentation', True, lambda r: True)
    _hparam('resnet18', RESNET18, lambda r: RESNET18)
    _hparam('resnet_dropout', 0., lambda r: r.choice([0., 0.1, 0.5]))
    _hparam('class_balanced', False, lambda r: False)
    _hparam('nonlinear_classifier', False, lambda r: bool(r.choice([False, True])))

    _hparam('specify_zdim', True, lambda r: bool(r.choice([False, True])))
    # _hparam('grad_decay', args.grad, lambda r: 10 ** r.uniform(-2, -1))
    # _hparam('grad_decay', args.grad, lambda r: r.uniform(0.1, 0.9))


    # Network-specific definitions:
    _hparam('z_dim', 256, lambda r: int(r.choice([16, 128, 256, 512])))



    # Algorithm-specific hparam definitions. Each block of code below
    # corresponds to exactly one algorithm.

    if algorithm in ['KLGP', 'KLGPv2', 'ERM_GP', 'KLCLGP', 'KLCL']:
        #_hparam('lambda', 1.0, lambda r: 10**r.uniform(-2, 2))
        _hparam('grad_decay', args.grad, lambda r: r.uniform(0, 0.9))

    if algorithm in ['KLCLGP', 'KLCL']:
        #_hparam('lambda', 1.0, lambda r: 10**r.uniform(-2, 2))
        # grad_decay for these algorithms is already defined (identically) above.
        _hparam('dis_decay', args.dis, lambda r: r.uniform(0.1, 0.5))

    if algorithm in ['ERM', 'ERM_LS']:
        #_hparam('lambda', 1.0, lambda r: 10**r.uniform(-2, 2))
        _hparam('epsilon_LS', args.ls, lambda r: r.uniform(0.1, 0.5))

    if algorithm in ['CLIGAv1', 'CLIGAv2', 'CLIGAv3', 'CLIGAv4', 'CLIGAv5', 'CLIGAPv1', 'CLIGAPv2']:
        #_hparam('lambda', 1.0, lambda r: 10**r.uniform(-2, 2))
        _hparam('dis_decay', args.dis, lambda r: r.uniform(0.1, 0.5))
        _hparam('warm_up', args.warm, lambda r: r.uniform(10, 300))
        _hparam('grad_decay', args.grad, lambda r: r.uniform(0.1, 0.9))
        _hparam('sigma_decay', args.sigma, lambda r: r.uniform(0.1, 0.5))

    if algorithm in ['LIDAv1', 'LIDAv2']:
        #_hparam('lambda', 1.0, lambda r: 10**r.uniform(-2, 2))
        _hparam('dis_decay', args.dis, lambda r: r.uniform(0.1, 0.5))
        _hparam('warm_up', args.warm, lambda r: r.uniform(10, 300))
        _hparam('grad_decay', args.grad, lambda r: r.uniform(0.1, 0.9))

    if algorithm in ['DANN', 'CDANN']:
        #_hparam('lambda', 1.0, lambda r: 10**r.uniform(-2, 2))
        _hparam('lambda', 10.0, lambda r: 10**r.uniform(-2, 2))
        _hparam('weight_decay_d', 0., lambda r: 10**r.uniform(-6, -2))
        _hparam('d_steps_per_g_step', 1, lambda r: int(2**r.uniform(0, 3)))
        _hparam('grad_penalty', 0., lambda r: 10**r.uniform(-2, 1))
        _hparam('beta1', 0.5, lambda r: r.choice([0., 0.5]))
        _hparam('mlp_width', 256, lambda r: int(2 ** r.uniform(6, 10)))
        _hparam('mlp_depth', 3, lambda r: int(r.choice([3, 4, 5])))
        _hparam('mlp_dropout', 0., lambda r: r.choice([0., 0.1, 0.5]))

    if algorithm == 'WD':
        _hparam('weight_decay_wd', 0., lambda r: 10**r.uniform(-6, -2))
        _hparam('grad_penalty', 10., lambda r: 10**r.uniform(-2, 1))
        _hparam('lambda_wd', 1.0, lambda r: 10**r.uniform(-2, 2))
        _hparam('wd_steps_per_step', 5, lambda r: int(2**r.uniform(1, 3)))
        _hparam('mlp_width', 256, lambda r: int(2 ** r.uniform(6, 10)))
        _hparam('mlp_depth', 3, lambda r: int(r.choice([3, 4, 5])))
        _hparam('mlp_dropout', 0., lambda r: r.choice([0., 0.1, 0.5]))


    if algorithm == "RSC":
        _hparam('rsc_f_drop_factor', 1/3, lambda r: r.uniform(0, 0.5))
        _hparam('rsc_b_drop_factor', 1/3, lambda r: r.uniform(0, 0.5))

    if algorithm == "SagNet":
        _hparam('sag_w_adv', 0.1, lambda r: 10**r.uniform(-2, 1))

    if algorithm == "IRM":
        _hparam('irm_lambda', 1e2, lambda r: 10**r.uniform(-1, 5))
        _hparam('irm_penalty_anneal_iters', 500, lambda r: int(10**r.uniform(0, 4)))

    if algorithm == "Mixup":
        # Search alpha log-uniformly in [0.1, 10]. The former range
        # uniform(-1, -1) was degenerate and always yielded 0.1.
        _hparam('mixup_alpha', 0.2, lambda r: 10**r.uniform(-1, 1))

    if algorithm == "GroupDRO":
        _hparam('groupdro_eta', 1e-2, lambda r: 10**r.uniform(-3, -1))

    if algorithm in ['MMD', 'CORAL', 'PMMD']:
        _hparam('mmd_gamma', 1., lambda r: 10**r.uniform(-1, 1))

    if algorithm == "MLDG":
        _hparam('mldg_beta', 1., lambda r: 10**r.uniform(-1, 1))

    if algorithm == "MTL":
        _hparam('mtl_ema', .99, lambda r: r.choice([0.5, 0.9, 0.99, 1.]))

    if algorithm == "VREx":
        _hparam('vrex_lambda', 1e1, lambda r: 10**r.uniform(-1, 5))
        _hparam('vrex_penalty_anneal_iters', 500, lambda r: int(10**r.uniform(0, 4)))

    if algorithm == "SD":
        _hparam('sd_reg', args.sd, lambda r: 10**r.uniform(-5, -1))

    if algorithm in ['KL', 'PERM', 'PMMD', 'KLUP', 'KLGP', 'KLGPv2', 'CLIGAPv1', 'CLIGAPv2', 'KLCL', 'KLCLGP']:
        _hparam('num_samples', 20, lambda r: 20)



    # Dataset-and-algorithm-specific hparam definitions. Each block of code
    # below corresponds to exactly one hparam. Avoid nested conditionals.


    if dataset in SMALL_IMAGES or dataset in MEDIUM_IMAGES:
        # Small/medium image datasets take their default lr from args.lr.
        _hparam('lr', args.lr, lambda r: 10**r.uniform(-4.5, -2.5))
    else:
        _hparam('lr', 5e-5, lambda r: 10**r.uniform(-5, -3.5))


    if dataset in SMALL_IMAGES or dataset in MEDIUM_IMAGES:
        _hparam('weight_decay', 0., lambda r: 0.)
    else:
        _hparam('weight_decay', 0., lambda r: 10**r.uniform(-6, -2))


    # 'KL' is skipped here because every dataset branch of the KL-family
    # block at the end defines batch_size for it.
    if algorithm != 'KL':
        if dataset in SMALL_IMAGES or dataset in MEDIUM_IMAGES:
            _hparam('batch_size', 64, lambda r: int(2**r.uniform(3, 9)) )
        elif algorithm == 'ARM':
            _hparam('batch_size', 8, lambda r: 8)
        elif dataset == 'DomainNet':
            _hparam('batch_size', 32, lambda r: int(2**r.uniform(3, 5)) )
        else:
            _hparam('batch_size', 64, lambda r: int(2**r.uniform(3, 5.5)) )


    if algorithm in ['DANN', 'CDANN'] and (dataset in SMALL_IMAGES or dataset in MEDIUM_IMAGES):
        _hparam('lr_g', 1e-3, lambda r: 10**r.uniform(-4.5, -2.5) )
    elif algorithm in ['DANN', 'CDANN']:
        _hparam('lr_g', 5e-5, lambda r: 10**r.uniform(-5, -3.5) )


    if algorithm in ['DANN', 'CDANN'] and (dataset in SMALL_IMAGES or dataset in MEDIUM_IMAGES):
        _hparam('lr_d', 1e-3, lambda r: 10**r.uniform(-4.5, -2.5) )
    elif algorithm in ['DANN', 'CDANN']:
        _hparam('lr_d', 5e-5, lambda r: 10**r.uniform(-5, -3.5) )


    if algorithm == 'WD' and (dataset in SMALL_IMAGES or dataset in MEDIUM_IMAGES):
        _hparam('lr_wd', 1e-3, lambda r: 10**r.uniform(-4.5, -2.5) )
    elif algorithm == 'WD':
        _hparam('lr_wd', 5e-5, lambda r: 10**r.uniform(-5, -3.5) )


    if algorithm in ['DANN', 'CDANN'] and (dataset in SMALL_IMAGES or dataset in MEDIUM_IMAGES):
        _hparam('weight_decay_g', 0., lambda r: 0.)
    elif algorithm in ['DANN', 'CDANN']:
        _hparam('weight_decay_g', 0., lambda r: 10**r.uniform(-6, -2) )



    # KL-family overrides: these redefine several hparams set above
    # (batch_size, z_dim, lr, weight_decay, ...) per dataset group.
    if algorithm in ['KL', 'KLUP', 'KLGP', 'KLGPv2', 'KLCL', 'KLCLGP']:
        _hparam('augment_softmax', 0.0, lambda r: r.choice([0.0,0.01,0.05]))
        if dataset in SMALL_IMAGES:
            # _hparam('kl_reg', 0.3, lambda r: 0.3)
            # _hparam('kl_reg_aux', 0.1, lambda r: r.choice([0.0, 0.1]))
            # _hparam('batch_size', 32, lambda r: 32)

            _hparam('kl_reg', args.kl, lambda r: 0.3)
            _hparam('kl_reg_aux', args.klaux, lambda r: r.choice([0.0, 0.1]))

            # int(...) so the sampled batch size is a Python int, like every
            # other batch_size entry. r.choice returns a NumPy integer, which
            # fails isinstance(x, int) checks and JSON serialization.
            _hparam('batch_size', 256, lambda r: int(r.choice([32, 256])))
            # _hparam('resnet_dropout', 0., lambda r: r.choice([0.]))
            _hparam('nonlinear_classifier', False, lambda r: bool(r.choice([False])))
            # _hparam('z_dim', 128, lambda r: int(r.choice([16, 128])))
        elif dataset in MEDIUM_IMAGES:
            # _hparam('kl_reg', 0.055, lambda r: 0.3)
            # _hparam('kl_reg_aux', 0.055, lambda r: r.choice([0.0]))
            _hparam('batch_size', 256, lambda r: 256)
            _hparam('z_dim', 16, lambda r: int(r.choice([16])))
            _hparam('augment_softmax', 0.01, lambda r: r.choice([0.0,0.01,0.05]))

            _hparam('kl_reg', args.kl, lambda r: 0.3)
            _hparam('kl_reg_aux', args.klaux, lambda r: r.choice([0.0, 0.1]))


            # _hparam('kl_reg', 0.3, lambda r: r.choice([0.3, 0.055]))
            # _hparam('kl_reg_aux', 0.1, lambda r: r.choice([0.1, 0.055, 0.0]))
            # _hparam('nonlinear_classifier', True, lambda r: bool(r.choice([False])))
            # _hparam('z_dim', 256, lambda r: int(r.choice([16, 128])))
        elif dataset == 'VisDA17':
            _hparam('kl_reg', 0.002, lambda r: r.choice([0.05, 0.1, 0.2]))
            _hparam('kl_reg_aux', 0.001, lambda r: r.choice([0.0]))
            _hparam('batch_size', 128, lambda r: int(r.choice([64, 128, 256])))
            _hparam('z_dim', 16, lambda r: int(r.choice([16])))
            _hparam('lr', 1e-4, lambda r: 1e-4)
            _hparam('weight_decay', 0.0, lambda r: 0.)

            _hparam('resnet_dropout', 0., lambda r: r.choice([0.]))
            _hparam('nonlinear_classifier', False, lambda r: bool(r.choice([False])))

            _hparam('augment_softmax', 0.05, lambda r: r.choice([0.0,0.01,0.05]))
        else:
            _hparam('kl_reg', 0.1, lambda r: r.choice([0.05, 0.1, 0.2]))
            _hparam('kl_reg_aux', 0.0, lambda r: r.choice([0.0]))
            _hparam('batch_size', 256, lambda r: int(r.choice([64, 128, 256])))
            _hparam('z_dim', 16, lambda r: int(r.choice([16])))
            _hparam('lr', 1e-4, lambda r: 1e-4)
            _hparam('weight_decay', 0.0, lambda r: 0.)

            _hparam('resnet_dropout', 0., lambda r: r.choice([0.]))
            _hparam('nonlinear_classifier', False, lambda r: bool(r.choice([False])))

            _hparam('augment_softmax', 0.05, lambda r: r.choice([0.0,0.01,0.05]))

    return hparams

def default_hparams(algorithm, dataset, args):
    """Return ``{hparam name: default value}`` for ``algorithm`` on ``dataset``.

    Builds the registry with seed 0 and keeps only the default half of each
    ``(default, random)`` pair.
    """
    registry = _hparams(algorithm, dataset, 0, args)
    return {name: default for name, (default, _sampled) in registry.items()}

def random_hparams(algorithm, dataset, seed, args):
    """Return ``{hparam name: sampled value}`` for ``algorithm`` on ``dataset``.

    ``seed`` is forwarded to the registry. Only the random half of each
    ``(default, random)`` pair is kept.
    """
    registry = _hparams(algorithm, dataset, seed, args)
    return {name: sampled for name, (_default, sampled) in registry.items()}
