import ml_collections

def get_config():
  """Build the default experiment configuration.

  The returned config describes an EGNN model (``model.name = 'egnn'``)
  trained on the QM9 dataset (``data.dataset = 'qm9'``) with the ``'rfsde'``
  SDE (``training.sde``). It has the sub-configs ``training``, ``sampling``,
  ``eval``, ``data``, ``model`` and ``optim``, plus a top-level ``seed``.

  Returns:
    ml_collections.ConfigDict: a newly constructed config on every call, so
    callers may mutate it without affecting other callers.
  """
  config = ml_collections.ConfigDict()

  # training
  config.training = training = ml_collections.ConfigDict()
  training.batch_size = 256
  training.sde = 'rfsde'
  training.n_iters = 500001
  training.snapshot_freq = 50000
  training.log_freq = 50
  training.eval_freq = 100
  # Snapshot arguments
  ## produce samples at each snapshot.
  training.snapshot_sampling = True
  ## store additional checkpoints for preemption in cloud computing environments
  training.snapshot_freq_for_preemption = 10000
  training.zero_snapshot = True
  # TODO: despite the name this is not used for FID, only for extra evaluation.
  training.snapshot_fid_sample = 5000
  training.snapshot_save_freq = 50000
  # loss / optimisation-loop options
  training.likelihood_weighting = False
  training.continuous = True
  training.n_jitted_steps = 5
  training.reduce_mean = True
  training.n_epochs = 20

  # sampling
  config.sampling = sampling = ml_collections.ConfigDict()
  sampling.noise_removal = True
  sampling.probability_flow = False
  sampling.timestep_style = 'uniform'
  sampling.method = 'pc'
  sampling.predictor = 'rf_solver'
  sampling.corrector = 'none'
  sampling.tol = 1e-5

  # evaluation
  config.eval = evaluate = ml_collections.ConfigDict()
  evaluate.begin_step = 20000
  evaluate.end_step = 100000
  evaluate.enable_sampling = True
  evaluate.num_samples = 50000
  evaluate.enable_loss = False
  evaluate.enable_bpd = False
  evaluate.bpd_dataset = 'test'
  evaluate.save_trajectory = False
  evaluate.batch_size = 2048
  evaluate.num_scales = 250

  # data
  config.data = data = ml_collections.ConfigDict()
  data.include_charges = True
  data.augment_noise = 0
  data.conditioning = tuple()
  data.dataset = 'qm9'
  data.remove_h = False
  data.norm_values = (1.0, 1.0, 1.0)
  # NOTE(review): first entry is None rather than a float — confirm the
  # consumer treats None as "no bias" for that component.
  data.norm_biases = (None, 0.0, 0.0)
  data.cat_loss_step = -1
  data.on_hold_batch = -1
  data.condition_time = True

  # model
  config.model = model = ml_collections.ConfigDict()
  model.name = 'egnn'
  model.initial_count = 1
  # architectural parameters
  model.scale_by_sigma = False
  model.ema_rate = 0.9999  # Default 0.9999
  model.variable_ema_rate = False

  # EGNN parameters
  model.n_layers = 6
  model.inv_sublayers = 1
  model.nf = 128
  model.tanh = True
  model.attention = True
  model.norm_constant = 1
  model.sin_embedding = False
  model.norm_diff = True
  # These duplicate data.condition_time / data.include_charges above and
  # presumably must agree with them — keep both copies in sync when editing.
  model.condition_time = True
  model.include_charges = True

  # Additional method-specific parameters.
  model.lambda1 = 1.0
  model.lambda2 = 1.0
  model.aug_dim = 0

  # optimization
  config.optim = optim = ml_collections.ConfigDict()
  optim.weight_decay = 0
  optim.optimizer = 'Adam'
  optim.lr = 2e-4
  optim.warmup = 0
  optim.grad_clip = 1.

  config.seed = 42

  return config