import numpy as np
import hashlib
from utils.misc import seed_hash


def dataset_hparams(dataset, hparams):
    """Write the fixed per-dataset settings for ``dataset`` into ``hparams``.

    ``hparams['dataset']`` is assigned first, so it is set even when the
    dataset name turns out to be unsupported.

    Args:
        dataset: One of ``'mimiccxr'``, ``'adni'`` or ``'areds'``.
        hparams: Mutable mapping that is updated in place.

    Returns:
        The same ``hparams`` object that was passed in.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    # Ordered (name, settings) table; settings are written into ``hparams``
    # in exactly the order listed. A 'ukb-mi' configuration
    # (data_dir 'datasets/ukbiobank') is currently not supported.
    known_settings = (
        ('mimiccxr', (
            ('data_dir', 'datasets/mimiccxr'),
            ('batch_size', 128),
            ('test_batch_size', 128),
            ('feature_dim', 64),
            ('hidden_dim', 16),
            ('n_channel', 3),
            ('num_step', 4000),
            ('checkpoint_freq', 400),
            ('sen_attr_n_class', 2),
        )),
        ('adni', (
            ('data_dir', 'datasets/ADNI'),
            ('batch_size', 48),
            ('test_batch_size', 48),
            ('feature_dim', 64),
            ('n_channel', 3),
            ('hidden_dim', 16),
            ('num_step', 1000),
            ('checkpoint_freq', 100),
            ('sen_attr_n_class', 2),
        )),
        ('areds', (
            ('data_dir', 'datasets/AREDS'),
            ('batch_size', 96),
            ('test_batch_size', 128),
            ('feature_dim', 64),
            ('hidden_dim', 16),
            ('n_channel', 3),
            ('num_step', 4000),
            ('checkpoint_freq', 400),
            ('sen_attr_n_class', 2),
        )),
    )

    hparams['dataset'] = dataset
    for name, settings in known_settings:
        if dataset == name:
            for key, value in settings:
                hparams[key] = value
            return hparams
    raise ValueError('Unknown dataset: %s' % dataset)

def random_hparams(surv_model, fair_model, dataset, sensitive_attr, metric, hparams_seed, random_seed):
    """Build the hyperparameter dict for one sweep trial.

    Args:
        surv_model: Survival model name; ``'DeepHit'`` adds ``alpha``/``sigma``.
        fair_model: Fairness method name; selects method-specific hparams
            (``eta`` for ``'GroupDRO'``, ``fair_weight`` for
            ``'Regularization'``).
        dataset: Dataset name, forwarded to ``dataset_hparams``.
        sensitive_attr: Stored under ``'sensitive_attribute'``.
        metric: Stored under ``'metric'``.
        hparams_seed: ``0`` selects fixed ``lr``/``decay``; any other value
            samples them at random.
        random_seed: Combined with each hparam name via ``seed_hash`` to seed
            that hparam's random draw.

    Returns:
        dict of hyperparameters.

    Raises:
        ValueError: Propagated from ``dataset_hparams`` for unknown datasets.
    """
    default_lr = 1e-4
    default_decay = 1e-5

    hparams = {'surv_model': surv_model, 'fair_model': fair_model, 'dataset': dataset,
               'sensitive_attribute': sensitive_attr, 'metric': metric}

    def _hparam(name, random_val_fn):
        """Define hyperparameter ``name`` as ``random_val_fn(RandomState)``.

        Each hparam gets a fresh RandomState seeded from
        ``seed_hash(random_seed, name)``, so its value does not depend on
        which other hparams were sampled before it.
        """
        assert name not in hparams, 'hparam %r defined twice' % (name,)
        random_state = np.random.RandomState(seed_hash(random_seed, name))
        hparams[name] = random_val_fn(random_state)

    # General hparam definitions; hparams_seed 0 is the fixed default trial.
    if hparams_seed == 0:
        hparams['lr'] = default_lr
        hparams['decay'] = default_decay
    else:
        _hparam('lr', lambda r: 10**r.uniform(-4, -3))
        _hparam('decay', lambda r: 10**r.uniform(-6, -4))

    # Dataset hparam definitions (updates ``hparams`` in place).
    dataset_hparams(dataset, hparams)

    # Survival-model hparam definitions.
    if surv_model == 'DeepHit':
        hparams['alpha'] = 0.2
        hparams['sigma'] = 0.1

    # Fairness-method hparam definitions.
    # NOTE(review): these methods pin lr/decay back to the defaults, so values
    # sampled above only survive for other fair_model values. Confirm intended.
    if fair_model in ('GroupDRO', 'Regularization', 'DomainInd',
                      'DomainIndAggregated', 'Reweighting'):
        hparams['lr'] = default_lr
        hparams['decay'] = default_decay

    if fair_model == 'GroupDRO':
        _hparam('eta', lambda r: 10**r.uniform(-3, -1))
    elif fair_model == 'Regularization':
        # 63-value log grid: 1e-5, 2e-5, ..., 9e-5, 1e-4, ..., 9, 10, 20, ..., 90.
        # Built from decimal literals instead of float-step np.arange, whose
        # output length is not numerically guaranteed and whose entries pick
        # up rounding error (e.g. 0.30000000000000004 instead of 0.3).
        fair_weight_grid = np.array([
            float(f'{mantissa}e{exponent}')
            for exponent in range(-5, 2)
            for mantissa in range(1, 10)
        ])
        _hparam('fair_weight', lambda r: r.choice(fair_weight_grid))
    return hparams


def default_hparams(surv_model, fair_model, dataset, sensitive_attr, metric):
    """Return the fixed (non-sampled) hyperparameter dict for a run.

    Combines the run identifiers, a fixed learning rate (1e-4) and weight
    decay (1e-5), the per-dataset settings from ``dataset_hparams`` and, for
    ``'DeepHit'``, its ``alpha``/``sigma`` values.

    NOTE(review): unlike ``random_hparams``, no fairness-method hparams
    (``eta``, ``fair_weight``) are set here; confirm callers do not need them.

    Raises:
        ValueError: Propagated from ``dataset_hparams`` for unknown datasets.
    """
    config = {
        'surv_model': surv_model,
        'fair_model': fair_model,
        'dataset': dataset,
        'sensitive_attribute': sensitive_attr,
        'metric': metric,
        'lr': 1e-4,
        'decay': 1e-5,
    }

    # Dataset-specific settings are merged into ``config`` in place.
    dataset_hparams(dataset, config)

    if surv_model == 'DeepHit':
        config.update(alpha=0.2, sigma=0.1)

    return config


if __name__ == '__main__':
    # Debug helper: print seed_hash(5, 'lr'), the kind of per-hparam seed
    # that random_hparams feeds to np.random.RandomState.
    lr_seed = seed_hash(5, 'lr')
    print(lr_seed)