"""Default NAMM config."""

import ml_collections

def get_config():
  """Builds the default NAMM configuration.

  Returns:
    An `ml_collections.ConfigDict` with the sub-configs `data`, `constraint`,
    `model`, `optim`, `training` and `eval`, plus a top-level `seed`.
  """
  config = ml_collections.ConfigDict()

  # Dataset options.
  data = config.data = ml_collections.ConfigDict()
  # data.image_size = 32
  data.height = 32
  data.width = 32
  data.num_channels = 1
  data.dataset = ''
  data.random_flip = False
  data.random_rotation = False
  data.random_zoom = False
  data.uniform_dequantization = False
  data.centered = False
  # NOTE(review): looks like a user-specific scratch path; override it for
  # your environment.
  data.tfds_dir = '/tmp/bfeng/tensorflow_datasets'
  data.antialias = True
  data.taper = False
  data.taper_frac_radius_min = 0.1
  data.taper_frac_radius_max = 0.65
  data.taper_gaussian_blur_sigma = 2.
  data.constant_flux = False
  data.total_flux = 120.  # (for 64x64)
  data.depth = 0  # depth > 0 means 3D data
  data.num_kolmogorov_states = 8
  data.num_kolmogorov_states_per_row = 4
  data.kolmogorov_representation = 'image'  # image | volume

  # Constraint options.
  constraint = config.constraint = ml_collections.ConfigDict()
  constraint.type = 'flux'
  # NOTE(review): same value as `data.total_flux`; presumably the two should
  # be kept in sync -- confirm which one consumers read.
  constraint.total_flux = 120.
  constraint.reynolds = 1000.
  constraint.inner_steps = 20
  constraint.kolmogorov_dt = 0.01
  constraint.kolmogorov_forcing = False
  constraint.kolmogorov_t0 = 3
  constraint.burgers_t0 = 0
  constraint.burgers_dt = 0.025
  constraint.burgers_inner_steps = 5

  # Model options.
  model = config.model = ml_collections.ConfigDict()
  model.fwd_n_filters = 64
  model.bwd_n_filters = 64
  model.n_res_blocks = 6
  model.dropout_rate = 0.5
  model.n_downsample_layers = 2
  model.upsample_mode = 'deconv'
  # model.residual = False  # if `True`, ResNet estimates residual
  model.fwd_residual = True
  model.bwd_residual = False
  model.fwd_network = 'icnn'
  model.bwd_network = 'resnet'
  model.fwd_activation = 'none'
  model.bwd_activation = 'softplus'
  model.fwd_strong_convexity = 0.9
  model.bwd_strong_convexity = 0.1
  model.fwd_icnn_n_filters = 32
  model.bwd_icnn_n_filters = 64
  model.fwd_icnn_n_layers = 3
  model.bwd_icnn_n_layers = 5
  model.fwd_icnn_kernel_size = 3
  model.bwd_icnn_kernel_size = 3
  model.ema_rate = 0.999

  # Optimization / loss options.
  optim = config.optim = ml_collections.ConfigDict()
  optim.mdm_finetune = False  # whether to fine-tune with dual samples from MDM instead of perturbed G(x)
  optim.grad_clip = -1.  # negative value means no clipping
  optim.learning_rate = 2e-4
  optim.zero_nans = False
  optim.adam_beta1 = 0.5
  optim.cycle_weight = 1.
  optim.regularization_weight = 1e-3
  optim.dsm_weight = 1e-3
  optim.constraint_weight = 0.001  # max weight if annealing constraint
  # optim.anneal_constraint = False
  # optim.constraint_init_weight = 1e-5
  # optim.constraint_annealing_pivot = 50000
  optim.regularization = 'sparse_icnn'
  optim.max_sigma = 0.1  # maximum perturbation level (i.e., noise std. dev. or magnitude of diffusion gamma)
  optim.perturb_type = 'noise'  # noise | diffusion
  optim.fixed_sigma = False
  optim.perturb_length_scale = 0.5
  optim.divergence_weight = 0.1

  # Training-loop options.
  training = config.training = ml_collections.ConfigDict()
  training.batch_size = 16
  training.n_epochs = 50
  training.log_freq = 100
  training.snapshot_epoch_freq = 5
  training.ckpt_epoch_freq = 50
  training.early_stop = False

  # Evaluation options.
  evaluation = config.eval = ml_collections.ConfigDict()
  evaluation.batch_size = 16

  config.seed = 42
  return config