"""Default Hyperparameter configuration."""

import ml_collections


def d(**kwargs):
  """Wrap the given keyword arguments in an `ml_collections.ConfigDict`.

  Shorthand so that nested config sections can be written as `d(a=1, b=2)`.
  """
  initial = kwargs
  return ml_collections.ConfigDict(initial_dictionary=initial)


# Seed shared by `train_data_config.seed` and `train_eval_config.seed` in
# get_config() below; change it here to reseed both consistently.
seed = 0


def get_config():
  """Get the default hyperparameter configuration.

  Returns:
    An `ml_collections.ConfigDict` with the sections `train_data_config`,
    `eval_data_config`, `train_eval_config`, `model_config` and
    `optimizer_config`, plus the top-level string `ckpt_restore_dir`.
  """
  config = ml_collections.ConfigDict()
  config.train_data_config = d(
    seed=seed,
    data_spec="basenji",
    data_dim=2,
    batchsize=8,
    fixed_batch=False,
  )
  # Built with `d` like every other section so the section is explicitly a
  # ConfigDict rather than relying on implicit dict conversion on assignment.
  config.eval_data_config = d(reuse_train=True)
  config.train_eval_config = d(
    num_train_steps=15000,
    substeps=1,
    num_eval_steps=100,
    log_metrics_every_steps=100,
    log_imgs_every_steps=100,
    # Interval far exceeds num_train_steps, so periodic checkpointing is
    # presumably never triggered in a default run.
    checkpoint_every_steps=int(1e10),
    eval_every_steps=200,
    seed=seed,
    lr_decay_last_steps_ratio=0.1,
    # Optional key, left unset by default: warm_start="<path>".
  )

  config.model_config = d(
    # Presumably the rate-distortion trade-off weight; swept in get_hyper().
    rd_lambda=100.0,
    distort_type='mse',
    nu_support_size=1000,
  )
  config.optimizer_config = d(
    name='adam',
    args=dict(),  # Extra optimizer kwargs (b1, b2, etc.).
    learning_rate=1e-3,
    lr_decay=True,
    decay_type='inv_sqrt',
    decay_factor=0.1,
    # Optional key, left unset by default: gradient_clip_norm=1.0.
  )
  # NOTE(review): this is the *string* 'None', not the None object; consumers
  # presumably compare against 'None' (e.g. for flag parsing) — confirm before
  # changing it.
  config.ckpt_restore_dir = 'None'
  return config


def get_cfg_str(config):
  """Summarize the key hyperparameters of `config` as a run-name string.

  The fields are emitted in a fixed order:
    idim      <- config.train_data_config.data_dim
    n         <- config.model_config.nu_support_size
    rd_lambda <- config.model_config.rd_lambda
  Formatting is delegated to `common.utils.config_dict_to_str`, with falsy
  values kept (skip_falsy=False).
  """
  from collections import OrderedDict

  data_cfg = config.train_data_config
  model_cfg = config.model_config
  fields = OrderedDict([
    ('idim', data_cfg.data_dim),
    ('n', model_cfg.nu_support_size),
    ('rd_lambda', model_cfg.rd_lambda),
  ])

  from common import utils
  return utils.config_dict_to_str(fields, skip_falsy=False)


def get_hyper():
  """Build the hyperparameter sweep for this experiment.

  Produces a list of flattened dicts. Each dict's dotted keys
  (e.g. "model_config.rd_lambda") override the defaults from get_config(),
  and each dict corresponds to one trial / experiment / work unit.

  Two axes are swept, `model_config.nu_support_size` and
  `model_config.rd_lambda`, and combined with `hyper.product`.

  Returns:
    The list returned by `hyper.product(support_size_sweep, rd_lambda_sweep)`.
  """
  from common import hyper

  # rd_lambda values taken from the results file
  # tfc-models/results/gan/rdub-d=2_4-datasets=basenji-method=all.pkl.
  rd_lambda_sweep = hyper.sweep(
    'model_config.rd_lambda', [16000, 8000, 4000, 1000, 500, 300])
  support_size_sweep = hyper.sweep(
    "model_config.nu_support_size", [5000, 1000])

  return hyper.product(support_size_sweep, rd_lambda_sweep)
