"""Default Hyperparameter configuration."""

import ml_collections


def d(**kwargs):
  """Build an `ml_collections.ConfigDict` from keyword arguments.

  Args:
    **kwargs: Field names and values for the new config dict.

  Returns:
    A `ConfigDict` initialized with `kwargs` as its initial dictionary.
  """
  config_dict = ml_collections.ConfigDict(initial_dictionary=kwargs)
  return config_dict


# Seed value passed as `seed` to both `train_data_config` and
# `train_eval_config` in get_config().
seed = 0


def get_config():
  """Build and return the default hyperparameter configuration.

  Returns:
    An `ml_collections.ConfigDict` holding the sub-configs
    `train_data_config`, `eval_data_config`, `train_eval_config` and
    `model_config`, plus the top-level `ckpt_restore_dir` entry.
  """
  # --- Training data -------------------------------------------------------
  train_data_config = d(
    seed=seed,
    data_spec="data/dig=all-speaker=theo-split=train-stft.npy",
    data_dim=33,  # Used to construct mlp; reused by model_config below.
    shuffle=True,
    batchsize=100,
    mode='swrep',
    fixed_batch=False,  # Use one batch from the empirical distribution.
  )

  # --- Evaluation data -----------------------------------------------------
  # Disabled alternative: eval_data_config = dict(reuse_train=True)
  eval_data_config = d(
    shuffle=False,
    data_spec="data/dig=all-speaker=theo-split=test-stft.npy",
    batchsize=100,  # Using 'all' would cause OOM.
    mode='once',
    fixed_batch=False,
  )

  # --- Train / eval loop ---------------------------------------------------
  # Disabled options: num_eval_steps=1, warm_start="".
  train_eval_config = d(
    num_train_steps=30000,
    num_eval_steps=None,  # Run eval by looping through the (finite) eval_ds.
    log_metrics_every_steps=100,
    log_imgs_every_steps=100,
    checkpoint_every_steps=10**9,
    eval_every_steps=1000,
    seed=seed,
  )

  # --- Model ---------------------------------------------------------------
  data_dim = train_data_config.data_dim

  # Disabled options: decoder_activation in {'relu', 'leaky_relu', None},
  # ar_hidden_units=[10, 10], ar_activation='softplus', maf_stacks=2,
  # iaf_stacks=0.
  transform_config = d(
    decoder_units=[data_dim, data_dim],  # NERD uses 2-layer MLP
    decoder_activation='softplus',
    prior_type='std_gaussian',
  )

  # Disabled options: name='sgd', lr_decay=True, lr_decay_rate=1e-3,
  # clip_norm=1.0.
  optimizer_config = d(
    name='adam',
    learning_rate=5e-4,
  )

  model_config = d(
    rd_lambda=10.0,
    latent_dim=33,
    num_samples=40000,  # NERDlagr.py
    data_dim=data_dim,
    distort_type='mse',
    scheduled_num_steps=train_eval_config.num_train_steps,
    transform_config=transform_config,
    optimizer_config=optimizer_config,
  )

  # --- Assemble the top-level config (same key order as the sections above) -
  config = ml_collections.ConfigDict()
  config.train_data_config = train_data_config
  config.eval_data_config = eval_data_config
  config.train_eval_config = train_eval_config
  config.model_config = model_config
  # Note: this is the string 'None', not the `None` singleton.
  config.ckpt_restore_dir = 'None'
  return config


def get_cfg_str(config):
  """Format selected hyperparameters of `config` via `utils.config_dict_to_str`.

  Args:
    config: A config exposing `model_config.num_samples`,
      `model_config.rd_lambda` and
      `model_config.transform_config.decoder_units`.

  Returns:
    The result of `common.utils.config_dict_to_str` applied to an ordered
    mapping with keys 'n', 'rd_lambda' and 'units' (in that order).
  """
  from collections import OrderedDict
  model_cfg = config.model_config
  name_fields = OrderedDict([
    # ('ldim', model_cfg.latent_dim),
    ('n', model_cfg.num_samples),
    ('rd_lambda', model_cfg.rd_lambda),
    ('units', model_cfg.transform_config.decoder_units),
    # ('tseed', config.train_eval_config.seed),
  ])

  from common import utils
  return utils.config_dict_to_str(name_fields)


def get_hyper():
  """Produce the hyperparameter sweep overriding values from get_config().

  Sweeps `model_config.rd_lambda` and
  `model_config.transform_config.decoder_units` and combines the two sweeps
  with `hyper.product`.

  Returns:
    The result of `hyper.product`: a list of flattened dicts, each containing
    a hparam configuration overriding the one in get_config(), corresponding
    to one hparam trial/experiment/work unit.
  """
  from common import hyper

  # Listed from largest to smallest rd_lambda.
  rd_lambda_values = [300, 100, 30, 10, 3, 1, 0.3]
  rd_lambdas = hyper.sweep('model_config.rd_lambda', rd_lambda_values)
  # tseeds = hyper.sweep('train_eval_config.seed', list(range(5)))

  # Two equal-width decoder layers per setting.
  decoder_unit_values = [[u, u] for u in (33, 50)]
  decoder_units = hyper.sweep(
    'model_config.transform_config.decoder_units', decoder_unit_values)

  return hyper.product(rd_lambdas, decoder_units)
