# Hydra instantiation target: the class this config node is instantiated as.
_target_: dem.models.dem_module.DEMLitModule

# Optimizer factory. `_partial_: true` makes Hydra return a functools.partial
# of the target instead of calling it, so the remaining arguments (e.g. the
# parameters to optimize) are supplied by the caller later.
optimizer:
  _target_: torch.optim.Adam
  _partial_: true
  lr: 0.001
  weight_decay: 0.0

# Learning-rate scheduler factory. `_partial_: true` defers construction so
# the optimizer instance can be passed in by the caller.
scheduler:
  _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
  _partial_: true
  # `min`: reduce the LR when the monitored quantity stops decreasing.
  mode: min
  # The LR is multiplied by this factor on each reduction.
  factor: 0.1
  # Number of scheduler steps without improvement tolerated before reducing.
  patience: 10

# Hydra defaults list: selects the `mlp` option of the `net` config group and
# the `geometric` option of the `noise_schedule` config group. Each selection
# is given as a one-element list.
defaults:
  - net: [mlp]
  - noise_schedule: [geometric]

# Sample buffer, instantiated by Hydra from the `_target_` class below.
buffer:
  _target_: dem.models.components.prioritised_replay_buffer.SimpleBuffer
  # Resolved from the `energy` config via OmegaConf interpolation.
  dim: ${energy.dimensionality}
  max_length: 10000
  min_sample_length: 1000
  initial_sampler: null
  # device: auto
  # Booleans are written in canonical lowercase, matching the rest of the file.
  sample_with_replacement: true
  fill_buffer_during_init: false
  prioritize: false

# `null` leaves the score scaler unset (None in Python).
score_scaler: null

# Sample-count and batch-size settings.
# NOTE(review): exact semantics are defined by DEMLitModule — confirm there.
test_mode: 'test'
num_init_samples: 1024
num_estimator_mc_samples: 100
num_samples_to_generate_per_epoch: 1024
num_samples_to_sample_from_buffer: 512
use_snis: true
num_samples_to_snis: 10
num_samples_to_snis_numerator: 10
num_samples_to_snis_denominator: 10
eval_batch_size: 1024

# Numerical integration settings.
# NOTE(review): presumably used for sampling / NLL evaluation — confirm in
# DEMLitModule.
num_integration_steps: 1000
nll_integration_method: dopri5
# Written with an explicit mantissa dot: YAML 1.1 loaders such as plain PyYAML
# read a dotless `1e-5` as a string rather than a float.
tol: 1.0e-5

# Evaluation switches, all disabled by default.
# NOTE(review): names suggest NLL / log-Z estimation options — confirm the
# exact semantics in DEMLitModule.
nll_with_cfm: false
nll_with_dem: false
nll_on_buffer: false
# Also compute the NLL on the train data;
# this is in addition to the buffer and test data.
compute_nll_on_train_data: false
logz_with_cfm: false

# CFM settings (presumably conditional flow matching — TODO confirm).
cfm_sigma: 0.0
# Resolved from the `energy` config via OmegaConf interpolation.
cfm_prior_std: ${energy.data_normalization_factor}
use_otcfm: false
prioritize_cfm_training_samples: false

# Takes its value from the trainer's `check_val_every_n_epoch` setting via
# OmegaConf interpolation.
lr_scheduler_update_frequency: ${trainer.check_val_every_n_epoch}

# `null` leaves both scaling factors unset (None in Python).
input_scaling_factor: null
output_scaling_factor: null

# Compile the model for faster training with PyTorch 2.0.
compile: false

use_richardsons: false

cfm_loss_weight: 1.0
use_ema: false
use_exact_likelihood: true

# Train CFM only on the train data and not DEM.
debug_use_train_data: false

# Initialize the buffer with samples from the prior.
init_from_prior: false

# Set to true for iDEM and false for pDEM.
use_buffer: true

# Number of samples to save at the end of training.
num_samples_to_save: 100000

negative_time: false
num_negative_time_steps: 100

nll_batch_size: 256

# Absolute OmegaConf interpolation: resolves to the root-level `seed` key of
# the composed config.
# NOTE(review): assumes this file is composed under a sub-package (e.g.
# `model`); at the root it would reference itself — confirm.
seed: ${seed}
