import numpy as np

from advbench.lib import misc
from advbench import datasets

def default_hparams(algorithm, dataset):
    """Return the default value of every hyperparameter registered for
    ``algorithm`` on ``dataset``.

    The registry is queried with seed 0; the seed only influences the random
    half of each (default, random) pair, which is discarded here.
    """
    registry = _hparams(algorithm, dataset, 0)
    return {name: pair[0] for name, pair in registry.items()}

def random_hparams(algorithm, dataset, seed):
    """Return a randomly drawn value for every hyperparameter registered for
    ``algorithm`` on ``dataset``.

    ``seed`` is forwarded to the registry, which seeds each hyperparameter's
    draw from the pair (seed, hyperparameter name).
    """
    registry = _hparams(algorithm, dataset, seed)
    return {name: pair[1] for name, pair in registry.items()}

def _hparams(algorithm: str, dataset: str, random_seed: int):
    """Global registry of hyperparams. Each entry is a (default, random) tuple.
    New algorithms / networks / etc. should add entries here.

    Dataset dispatch uses substring matching ('MNIST' in dataset,
    'CIFAR10' in dataset) plus an exact match for 'SVHN'. In the branches
    that mention SVHN, it shares the CIFAR10 settings. MNIST is always checked
    before CIFAR10/SVHN. Algorithm-specific entries are registered based on
    ``algorithm``: some branches match it by substring, others by exact name.

    Args:
        algorithm: name of the training algorithm.
        dataset: name of the dataset.
        random_seed: combined with each hyperparameter's name (via
            ``misc.seed_hash``) to seed that hyperparameter's random draw.

    Returns:
        dict mapping hyperparameter name -> (default_value, random_value).
    """

    hparams = {}

    def _hparam(name, default_val, random_val_fn):
        """Define a hyperparameter. random_val_fn takes a RandomState and
        returns a random hyperparameter value."""

        # Each name may be registered at most once per call.
        assert(name not in hparams)
        # A fresh RandomState seeded only from (random_seed, name) makes each
        # random draw independent of registration order.
        random_state = np.random.RandomState(misc.seed_hash(random_seed, name))
        hparams[name] = (default_val, random_val_fn(random_state))

    # _hparam('batch_size', 64, lambda r: int(2 ** r.uniform(3, 8))) # default 64
    # _hparam('learning_rate', 0.01, lambda r: 10 ** r.uniform(-4.5, -2.5))
    # _hparam('sgd_momentum', 0.9, lambda r: r.uniform(0.8, 0.95))
    # _hparam('weight_decay', 3.5e-3, lambda r: 10 ** r.uniform(-6, -3))

    # Standard training
    if 'MNIST' in dataset:
        _hparam('batch_size', 256, lambda r: 256)
        _hparam('learning_rate', 0.075, lambda r: 0.075)
        _hparam('sgd_momentum', 0.9, lambda r: 0.9)
        _hparam('weight_decay', 3e-6, lambda r: 3e-6)
    elif 'CIFAR10' in dataset or dataset == 'SVHN':
        # SVHN is grouped with CIFAR10 here, as in the adversarial-training
        # branch below, so that it also receives optimizer settings.
        # NOTE(review): these values are shared with CIFAR10; confirm they
        # suit SVHN.
        _hparam('batch_size', 128, lambda r: 128)
        _hparam('learning_rate', 0.1, lambda r: 0.1)
        _hparam('sgd_momentum', 0.9, lambda r: 0.9)
        _hparam('weight_decay', 5e-4, lambda r: 5e-4)

    # Adversarial training
    if 'MNIST' in dataset:
        _hparam('epsilon', 0.3, lambda r: 0.3)
        _hparam('pgd_n_steps', 10, lambda r: 10)
        _hparam('pgd_step_size', 0.1, lambda r: 0.1)
    elif 'CIFAR10' in dataset or dataset == 'SVHN':
        _hparam('epsilon', 8/255., lambda r: 8/255.)
        _hparam('pgd_n_steps', 10, lambda r: 10)
        _hparam('pgd_step_size', 2/255., lambda r: 2/255.)

    # Algorithm specific
    # Note: the MMD/Entropy/Mirror and BETA branches register nothing for SVHN.
    if 'MMD' in algorithm or 'EntropyTraining' in algorithm or 'Mirror' in algorithm:
        if 'MNIST' in dataset:
            # _hparam('mmd_alpha', 1e-1, lambda r: 10 ** r.uniform(-3, 0))
            _hparam('mmd_alpha', 0.05243, lambda r: 10 ** r.uniform(-3, 0))
            _hparam('mmd_beta', 0.00119, lambda r: 10 ** r.uniform(-3, 0))
            _hparam('mmd_grad_steps', 11, lambda r: int(r.choice(range(5, 15))))
        elif 'CIFAR10' in dataset:
            _hparam('mmd_alpha', 0.1, lambda r: 10 ** r.uniform(-1, 0))
            _hparam('mmd_beta', 0.1, lambda r: 10 ** r.uniform(-1, 0))
            _hparam('mmd_grad_steps', 1, lambda r: 1) # int(r.choice(range(1, 5))))

    if algorithm == 'BETA' or algorithm == 'SBETA':
        if 'MNIST' in dataset:
            _hparam('beta_lr', 0.01, lambda r: 10 ** r.uniform(-3, -1))
            _hparam('beta_n_steps', 10, lambda r: int(r.choice(range(5, 50))))
        elif 'CIFAR10' in dataset:
            _hparam('beta_lr', 0.01, lambda r: 0.01)
            _hparam('beta_n_steps', 10, lambda r: 2) #int(r.choice(range(10, 30))))

    if algorithm == 'SBETA':
        # Default is a float, but the random draw is an int in [1, 99].
        _hparam('sbeta_temperature', 10.0, lambda r: int(r.choice(range(1, 100))))

    if algorithm == 'TRADES':
        if 'MNIST' in dataset:
            _hparam('trades_n_steps', 10, lambda r: 7)
            _hparam('trades_step_size', 0.1, lambda r: r.uniform(0.01, 0.1))
            _hparam('trades_beta', 6.0, lambda r: r.uniform(0.1, 10.0))
        elif 'CIFAR10' in dataset or dataset == 'SVHN':
            _hparam('trades_n_steps', 10, lambda r: 15)
            _hparam('trades_step_size', 2/255., lambda r: r.uniform(0.01, 0.1))
            _hparam('trades_beta', 6.0, lambda r: r.uniform(0.1, 10.0))

    if algorithm == 'MART':
        if 'MNIST' in dataset:
            _hparam('mart_beta', 1.0, lambda r: r.uniform(0.1, 10.0))
        elif 'CIFAR10' in dataset or dataset == 'SVHN':
            _hparam('mart_beta', 5.0, lambda r: r.uniform(0.1, 10.0))

    if algorithm == 'TwoStepAT':
        _hparam('two_step_eta', 0.075, lambda r: 10 ** r.uniform(-4, -2))

    # Substring match: applies to every algorithm name containing 'GaussianDALE'.
    if 'GaussianDALE' in algorithm:
        if 'MNIST' in dataset:
            _hparam('g_dale_n_steps', 7, lambda r: 7)
            _hparam('g_dale_step_size', 0.1, lambda r: 0.1)
            _hparam('g_dale_noise_coeff', 0.001, lambda r: 10 ** r.uniform(-6.0, -2.0))
        elif 'CIFAR10' in dataset or dataset == 'SVHN':
            _hparam('g_dale_n_steps', 10, lambda r: 10)
            _hparam('g_dale_step_size', 0.007, lambda r: 0.007)
            _hparam('g_dale_noise_coeff', 0, lambda r: 0)
        _hparam('g_dale_nu', 0.1, lambda r: 0.1)

    # Substring match: applies to every algorithm name containing 'LaplacianDALE'.
    if 'LaplacianDALE' in algorithm:
        if 'MNIST' in dataset:
            _hparam('l_dale_n_steps', 7, lambda r: 7)
            _hparam('l_dale_step_size', 0.1, lambda r: 0.1)
            _hparam('l_dale_noise_coeff', 0.001, lambda r: 10 ** r.uniform(-6.0, -2.0))
        elif 'CIFAR10' in dataset or dataset == 'SVHN':
            _hparam('l_dale_n_steps', 10, lambda r: 10)
            _hparam('l_dale_step_size', 0.007, lambda r: 0.007)
            _hparam('l_dale_noise_coeff', 1e-2, lambda r: 1e-2)
        _hparam('l_dale_nu', 0.1, lambda r: 0.1)

    if 'DALE_PD' in algorithm:
        _hparam('g_dale_pd_step_size', 0.001, lambda r: 0.001)
        _hparam('g_dale_pd_margin', 0.1, lambda r: 0.1)

    if 'CVAR_SGD' in algorithm:
        _hparam('cvar_sgd_t_step_size', 1.0, lambda r: 0.001)
        _hparam('cvar_sgd_beta', 0.5, lambda r: 0.1)
        _hparam('cvar_sgd_M', 20, lambda r: 10)
        _hparam('cvar_sgd_n_steps', 5, lambda r: 10)

    if algorithm == 'TERM':
        _hparam('term_t', 2.0, lambda r: 1.0)

    if algorithm == 'RandSmoothing':
        if 'MNIST' in dataset:
            # NOTE(review): the random values for sigma / n_steps equal the
            # CIFAR10 defaults rather than the MNIST ones — confirm intended.
            _hparam('rand_smoothing_sigma', 0.5, lambda r: 0.12)
            _hparam('rand_smoothing_n_steps', 7, lambda r: 10)
            _hparam('rand_smoothing_step_size', 0.1, lambda r: r.uniform(0.01, 0.1))
            _hparam('rand_smoothing_n_samples', 10, lambda r: 1)
        elif 'CIFAR10' in dataset or dataset == 'SVHN':
            _hparam('rand_smoothing_sigma', 0.12, lambda r: 0.12)
            _hparam('rand_smoothing_n_steps', 10, lambda r: 7)
            _hparam('rand_smoothing_step_size', 2/255., lambda r: r.uniform(0.01, 0.1))
            _hparam('rand_smoothing_n_samples', 10, lambda r: 1)

    # if 'MNIST' in dataset:
    #     _hparam('stochastic_pgd_beta', 0.01, lambda r: r.uniform(0.01, 0.99))

    return hparams

def test_hparams(algorithm: str, dataset: str):

    hparams = {}

    def _hparam(name, default_val):
        """Define a hyperparameter for test adversaries."""

        assert(name not in hparams)
        hparams[name] = default_val

    _hparam('batch_size', 100)

    # _hparam('test_betas', [0.1, 0.05, 0.01])
    # _hparam('aug_n_samples', 100)

    ##### PGD #####
    if 'MNIST' in dataset:
        _hparam('epsilon', 0.3)
        _hparam('pgd_n_steps', 40)
        _hparam('pgd_step_size', 0.01)
    elif 'CIFAR10' in dataset or dataset == 'SVHN':
        _hparam('epsilon', 8/255.)
        _hparam('pgd_n_steps', 10)
        _hparam('pgd_step_size', 2/255.)

    ##### TRADES #####
    # if 'MNIST' in dataset:
    #     _hparam('trades_n_steps', 10)
    #     _hparam('trades_step_size', 0.1)
    # elif dataset == 'CIFAR10' or dataset == 'SVHN':
    #     _hparam('trades_n_steps', 20)
    #     _hparam('trades_step_size', 2/255.)

    if 'MNIST' in dataset:
        _hparam('beta_lr', 0.01)
        _hparam('beta_n_steps', 40)
    else:
        _hparam('beta_lr', 0.01)
        _hparam('beta_n_steps', 10)

    ##### CVaR SGD #####
    # _hparam('cvar_sgd_t_step_size', 0.5)
    # _hparam('cvar_sgd_beta', 0.05)
    # _hparam('cvar_sgd_M', 10)
    # _hparam('cvar_sgd_n_steps', 10)

    return hparams