# Class that Hydra instantiates from this config; the other top-level keys
# in this file become its keyword arguments when passed to
# `hydra.utils.instantiate`.
_target_: dem.models.pis_module.PISLitModule

# Name of the integration method used for NLL evaluation (presumably an ODE
# solver identifier such as Dormand-Prince "dopri5" — confirm in the module).
nll_integration_method: dopri5
use_richardsons: false
logz_with_cfm: false

# Optimizer factory. `_partial_: true` makes Hydra return a
# `functools.partial` of `torch.optim.Adam` rather than an instance, so the
# consumer can supply the model parameters later.
optimizer:
  _target_: torch.optim.Adam
  _partial_: true
  # Floats are written with an explicit mantissa dot: plain YAML 1.1 loaders
  # (e.g. PyYAML) read `1e-3` as a string, while `1.0e-3` is a float under
  # every loader, including OmegaConf.
  lr: 1.0e-3
  weight_decay: 1.0e-7

# Learning-rate scheduler factory. `_partial_: true` defers construction so
# the consumer can pass in the optimizer instance.
scheduler:
  _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
  _partial_: true
  mode: min  # the monitored quantity is expected to decrease
  factor: 0.1  # LR is multiplied by this factor on a plateau
  patience: 10  # scheduler steps without improvement before reducing the LR

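# Disabled inline `net` definition, kept for reference. The active `net` is
# selected through the `defaults` list in this file (`pis_mlp`).
# net: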
#   _target_: dem.models.components.mlp.FourierMLP
#   _partial_: true
#   num_layers: 2
#   channels: 64
#   in_shape: ${energy.dimensionality}
#   out_shape: ${energy.dimensionality}

# Partial constructor for the `TimeConder` component (presumably a
# time-conditioning network — confirm in dem.models.components.mlp).
# `_partial_: true` means any remaining arguments are supplied by the consumer.
tnet:
  _target_: dem.models.components.mlp.TimeConder
  _partial_: true
  channel: 64
  out_dim: 1
  num_layers: 3

# Hydra defaults list: composes options from the `net` and `noise_schedule`
# config groups into this config. Each group is given as a one-element list
# of option names.
defaults:
  - net:
      - pis_mlp
  - noise_schedule:
      - geometric

# Sample buffer, instantiated by Hydra as `SimpleBuffer`.
buffer:
  _target_: dem.models.components.prioritised_replay_buffer.SimpleBuffer
  # Resolved from the energy config via interpolation.
  dim: ${energy.dimensionality}
  max_length: 10000
  min_sample_length: 1000
  initial_sampler: null
  # Resolved from the trainer config via interpolation.
  device: ${trainer.accelerator}
  # Booleans use canonical lowercase, matching the rest of this file.
  sample_with_replacement: true
  fill_buffer_during_init: false
  prioritize: false

# No score scaler is configured.
score_scaler: null

# Sample counts and batch sizes passed to the module.
num_init_samples: 1024
num_estimator_mc_samples: 100
num_samples_to_generate_per_epoch: 1024
num_samples_to_sample_from_buffer: 512
eval_batch_size: 1024

# Number of numerical integration steps (presumably for the sampling
# process — confirm in PISLitModule).
num_integration_steps: 100

# CFM-related options. NOTE(review): "cfm" presumably stands for conditional
# flow matching — confirm in PISLitModule.
nll_with_cfm: false
cfm_sigma: 0.0
# Resolved from the energy config via interpolation.
cfm_prior_std: ${energy.data_normalization_factor}
prioritize_cfm_training_samples: false

# Tied to the trainer's validation interval via interpolation.
lr_scheduler_update_frequency: ${trainer.check_val_every_n_epoch}

# Both scaling factors are left unset (null).
input_scaling_factor: null
output_scaling_factor: null

# Compile the model for faster training (PyTorch 2.0 feature); disabled here.
compile: false

# Weight applied to the CFM loss term (presumably relative to the other loss
# terms — confirm in PISLitModule).
cfm_loss_weight: 1.0
use_ema: false
# Debugging switch; leave false for normal runs.
debug_use_train_data: false

# PIS-specific scalars. Floats are written as `1.0` rather than the
# trailing-dot form `1.` so the type is unambiguous to readers and loaders.
pis_scale: 1.0
time_range: 1.0
num_samples_to_save: 100000
