# @package _global_

# to execute this experiment run:
# python train.py experiment=example

# all parameters below the `defaults` list are merged with the default configurations it selects;
# this allows you to overwrite only the specified parameters

# Hydra defaults list: switch the `energy` config group to `lj55` and the
# `model/net` config group to `egnn`. The `override` keyword replaces an option
# that the primary config already selected instead of adding a new entry.
defaults:
  - override /energy: lj55
  - override /model/net: egnn

# Labels attached to this run; `logger.wandb.tags` reuses this list through
# the `${tags}` interpolation.
tags:
  - LJ55
  - CFM
  - NLL_eval

# Random seed value (presumably applied by the training entry point — confirm).
seed: 12345

# Overrides for the `wandb` entry of the logger config.
logger:
  wandb:
    # Resolves to the top-level `tags` list via OmegaConf interpolation.
    tags: '${tags}'
    # presumably the Weights & Biases run group — confirm against the logger target
    group: lj55

# You need to fill energy.data_path_train at the command line with samples generated
# by running dem/eval.py on a checkpoint
energy:
  # Placeholder only: this config does not supply a path. Pass one as a
  # command-line override, e.g. energy.data_path_train=<path to generated samples>
  data_path_train: null

# Override for the data config group.
data:
  # presumably the number of validation batches run per epoch — confirm
  # against the class targeted by the default data config
  n_val_batches_per_epoch: 4

# Overrides for the trainer config group.
# NOTE(review): these key names match PyTorch Lightning `Trainer` arguments —
# confirm against the `_target_` of the default trainer config.
trainer:
  # Validation frequency, in epochs.
  check_val_every_n_epoch: 1
  # Epoch limit for the run.
  max_epochs: 2000

model:
  # Boolean switches and tolerance forwarded to the model; their semantics are
  # defined by the model class chosen in the default model config (not visible
  # in this file).
  debug_use_train_data: true
  use_otcfm: false
  use_ema: false
  nll_with_cfm: true
  logz_with_cfm: true
  # Written with an explicit mantissa dot: YAML 1.1 loaders such as plain PyYAML
  # read a dot-less `1e-3` as the string "1e-3" rather than the float 0.001.
  tol: 1.0e-3
  use_exact_likelihood: false
  # Overrides for the network selected by `override /model/net: egnn` in the
  # defaults list.
  net:
    # Same particle count as `partial_prior.n_particles` in this file.
    n_particles: 55
    n_layers: 5
    # presumably the hidden feature width of each layer — confirm in the net class
    hidden_nf: 128

  # Noise schedule: `_target_` names the class Hydra instantiates, and the
  # remaining keys are passed to it as keyword arguments.
  noise_schedule:
    _target_: dem.models.components.noise_schedules.GeometricNoiseSchedule
    # Lower and upper noise levels handed to the schedule.
    sigma_min: 0.5
    sigma_max: 4

  # Prior specification. With `_partial_: true` Hydra returns a
  # functools.partial of MeanFreePrior with the keyword arguments below bound;
  # presumably the caller supplies the remaining arguments later — confirm.
  partial_prior:
    _target_: dem.energies.base_prior.MeanFreePrior
    _partial_: true
    # Same particle count as `net.n_particles` in this file.
    n_particles: 55
    spatial_dim: 3

  # Optimizer override; only the learning rate is changed here.
  optimizer:
    # Written with an explicit mantissa dot: YAML 1.1 loaders such as plain
    # PyYAML read a dot-less `1e-3` as the string "1e-3", not the float 0.001.
    lr: 1.0e-3

  # Selects NoLambdaWeighter as the lambda weighter. `_partial_: true` makes
  # Hydra return a functools.partial instead of constructing the object here.
  lambda_weighter:
    _target_: dem.models.components.lambda_weighter.NoLambdaWeighter
    _partial_: true

  # Clipper settings: score clipping is switched on with a maximum norm of 20;
  # log-reward clipping is switched off and `min_log_reward` is left null.
  # Booleans use lowercase `true` / `false`, matching the rest of this file.
  clipper:
    _target_: dem.models.components.clipper.Clipper
    should_clip_scores: true
    should_clip_log_rewards: false
    max_score_norm: 20
    min_log_reward: null

  # Remaining scalar overrides for the model. Their exact semantics are defined
  # by the model class, which is not visible in this file.
  diffusion_scale: 0.5

  # Sample counts and evaluation batch size.
  # NOTE(review): the names suggest a sample buffer that is filled and then
  # drawn from each epoch — confirm in the model class.
  num_init_samples: 1024
  num_samples_to_generate_per_epoch: 128
  num_samples_to_sample_from_buffer: 512
  num_samples_to_save: 1000
  eval_batch_size: 16

  init_from_prior: true

  # presumably the ODE solver name used when integrating for the NLL
  # estimate — confirm which solver library consumes it
  nll_integration_method: dopri5

  # presumably `num_negative_time_steps` is only read when `negative_time` is
  # true — confirm
  negative_time: false
  num_negative_time_steps: 10

# Overrides for the `model_checkpoint` callback.
callbacks:
  model_checkpoint:
    # Metric watched by the checkpoint callback.
    # NOTE(review): NLL is lower-is-better; confirm the default callback config
    # sets `mode: min`, since no mode is given here.
    monitor: 'val/nll'
    # Assuming Lightning's ModelCheckpoint, -1 keeps every checkpoint rather
    # than only the best k.
    save_top_k: -1
