# Model config for CNFLitModule.
# The `_target_` / `_partial_` keys follow the Hydra instantiate convention
# (assumed consumer — confirm): `_target_` is the dotted import path of the
# class to build, and `_partial_: true` asks for a partially-applied
# constructor so the caller can supply the remaining arguments later.
_target_: src.models.cfm_module.CNFLitModule
_partial_: true

# Optimizer: torch.optim.AdamW with the learning rate and weight decay below.
# `_partial_: true` leaves the remaining constructor arguments (presumably
# the model parameters — TODO confirm) to be passed in by the caller.
optimizer:
  _target_: torch.optim.AdamW
  _partial_: true
  lr: 0.001
  weight_decay: 0.01

# Network: VelocityNet from src.models.components.simple_mlp.
# `hidden_dims` holds three entries of 64 (presumably the hidden layer
# widths — confirm against VelocityNet); batch norm is switched off and the
# activation is selected by name.
net:
  _target_: src.models.components.simple_mlp.VelocityNet
  _partial_: true
  hidden_dims: [64, 64, 64]
  batch_norm: false
  activation: "tanh"

# Settings passed to AugmentationModule. The `*_reg` entries are presumably
# regularization weights, with 0.0 meaning the term is unused (confirm
# against AugmentationModule).
# Floats are written with an explicit decimal point and fractional digit
# (`0.0`, `1.0e-4`): YAML 1.1 loaders such as PyYAML resolve a bare `1e-4`
# as the string "1e-4" instead of a float.
augmentations:
  _target_: src.models.components.augmentation.AugmentationModule
  cnf_estimator: "exact"
  l1_reg: 0.0
  l2_reg: 0.0
  squared_l2_reg: 1.0e-4
  jacobian_frobenius_reg: 1.0e-4
  jacobian_diag_frobenius_reg: 0.0
  jacobian_off_diag_frobenius_reg: 0.0

# Timepoint to leave out during training. Set this to the integer index of
# the timepoint to hold out; the default of -1 presumably means that no
# timepoint is left out (confirm against the consuming module).
leaveout_timepoint: -1
