# @package _global_

# to execute this experiment, run the command below, replacing <name> with
# this file's name (without the .yaml extension):
# python train.py experiment=<name>

# all parameters below are merged with the parameters from the default
# configurations (including those selected in the `defaults` list in this file);
# this lets you overwrite only the parameters specified here

# Labels for this run; the wandb logger section reuses them via ${tags}.
tags:
  - "LJ55"
  - "PIS"

# Fixed seed value for this experiment.
# NOTE(review): presumably consumed by the training entry point to seed the
# RNGs — confirm against train.py.
seed: 12345

# Overrides for the wandb logger: run grouping and the tag list.
logger:
  wandb:
    group: lj55
    # interpolation: resolves to the top-level `tags` list of this config
    tags: "${tags}"

# Hydra defaults list. Each `override` entry replaces the option already
# selected for that config group by the primary config:
#   /energy    -> lj55
#   /model/net -> egnn
defaults:
  - override /energy: lj55
  - override /model/net: egnn

# Data-module overrides.
data:
  # NOTE(review): presumably the number of validation batches drawn per
  # epoch — confirm against the data module.
  n_val_batches_per_epoch: 4

# Trainer overrides.
# NOTE(review): the key names match Lightning `Trainer` arguments (validate
# every 5th epoch, stop after 2000 epochs) — confirm the trainer target.
trainer:
  check_val_every_n_epoch: 5
  max_epochs: 2000

# Model overrides for this experiment. Booleans are written in canonical
# lowercase (`true` / `false`) throughout.
model:
  # Settings merged into the network config selected by
  # `override /model/net: egnn` in the defaults list.
  net:
    n_particles: 55
    n_layers: 5
    hidden_nf: 128

  # Built by Hydra from `_target_`; the remaining keys are passed as
  # keyword arguments to that class.
  noise_schedule:
    _target_: dem.models.components.noise_schedules.GeometricNoiseSchedule
    sigma_min: 0.5
    sigma_max: 4

  # `_partial_: true` makes Hydra return a functools.partial of the target
  # instead of calling it, so the caller supplies the remaining arguments.
  partial_prior:
    _target_: dem.energies.base_prior.MeanFreePrior
    _partial_: true
    # Same value as model.net.n_particles above.
    # NOTE(review): presumably the two must be kept in sync — confirm.
    n_particles: 55
    spatial_dim: 3

  # Partially-instantiated weighter (see `_partial_` note above).
  lambda_weighter:
    _target_: dem.models.components.lambda_weighter.NoLambdaWeighter
    _partial_: true

  # Clipping configuration: score clipping is enabled, log-reward clipping
  # is disabled.
  clipper:
    _target_: dem.models.components.clipper.Clipper
    should_clip_scores: true
    should_clip_log_rewards: false
    max_score_norm: 20
    # NOTE(review): left unset; presumably unused while
    # should_clip_log_rewards is false — confirm in Clipper.
    min_log_reward: null

  diffusion_scale: 0.5

  # Sample counts and batch sizes.
  num_init_samples: 1024
  num_samples_to_generate_per_epoch: 128
  num_samples_to_sample_from_buffer: 128
  eval_batch_size: 128

  init_from_prior: true
  num_samples_to_save: 10000

  # NOTE(review): `dopri5` is presumably the name of an ODE solver passed to
  # the NLL integration routine — confirm against the model code.
  nll_integration_method: dopri5

  negative_time: true
  num_negative_time_steps: 10

# Checkpoint callback overrides.
callbacks:
  model_checkpoint:
    # NOTE(review): if this is Lightning's ModelCheckpoint, -1 keeps every
    # checkpoint rather than only the best k — confirm the callback target.
    save_top_k: -1
    # metric name the checkpoint callback tracks
    monitor: 'val/2-Wasserstein'
