# Sequential Controlled Langevin Diffusions (SCLD)
name: scld
# Total number of steps / bridges. The number of steps per sub-trajectory is
# num_steps / n_sub_traj (presumably this must divide evenly — confirm).
num_steps: 128
n_sub_traj: 1  # Number of sub-trajectories; 1 means no sub-trajectories
batch_size: 2000
init_std: ${target.cmcd.initial_scale}  # Standard deviation of the prior distribution
max_diffusion: ${target.cmcd.max_diffusion}  # Max diffusion value supplied to the noise schedule

grad_clip: 1.  # Threshold for L2-norm gradient clipping. If negative, no gradient clipping is applied
# Clips the value of the gradient of the target used in the Langevin dynamics.
# NOTE(review): -1 presumably disables this clipping (as for grad_clip) — confirm.
target_clip: -1
langevin_norm_clip: 1000000  # Rough guide: something like 100 * sqrt(num_dims) is probably reasonable
# Note: with a KL loss (rev_kl / fwd_kl), set buffer.use_subtraj_buffer to false
# and use a buffer size of 1 (buffer.max_length_in_batches: 1).
loss: "rev_lv"  # Choose between [rev_kl, fwd_kl, rev_lv, fwd_lv, rev_tb, fwd_tb]

# Prior distribution settings
prior:
  learn_variance: true  # Whether the prior's variance is learned
  learn_mean: true  # Whether the prior's mean is learned
  lr: ${target.cmcd.step_size}  # Learning rate for the prior parameters

subtraj_loss_weighting:
  pinns_eps: 0  # 0 disables the PINNs training scheme; 1 is a good default when enabling it

# Learning rate; acts as the peak learning rate when warmup/decay scheduling is enabled
step_size: ${target.cmcd.step_size}

# Optional learning-rate warmup / exponential decay. Key meanings below follow
# the optax schedule linked here (step_size acting as the peak value); confirm
# against the code that builds the schedule.
# https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html#optax.warmup_exponential_decay_schedule
# Floats are written with a decimal point (1.0e-4, not 1e-4) because YAML 1.1
# loaders such as plain PyYAML parse "1e-4" as a string rather than a float.
use_warmup: false
num_warmup_steps: 1000  # Number of warmup steps
initial_lr: 1.0e-4  # Learning rate at the start of warmup

use_decay: false
num_steps_before_start_decay: 1000  # Steps before the decay begins
decay_factor_per_thousand: 0.5  # Multiplicative decay factor per 1000 steps
final_lr: 1.0e-5  # Learning rate at which the decay stops


# Parameters specific to the TB loss
use_pseudohuber: false
use_jensen_trick: false  # If true, use an improved estimator for lnZ
# If > 0, replace the squared loss in TB with a pseudo-Huber loss using this delta.
# NOTE(review): overlaps with use_pseudohuber above — confirm which setting
# actually enables the pseudo-Huber loss.
pseudo_huber_delta: 10
logZ_step_size: 0.05  # Learning rate for updating the learnable log Z (second-moment loss)
init_logZ: 0.0  # Initial value of the learnable log Z (second-moment loss)
# Cheat option: use the true lnZ. Only really works without sub-trajectories,
# since the lnZ of the intermediate distributions is unknown.
leak_true_lnZ: false
true_lnZ: 0  # Set to ${target.fn.log_Z} if available



n_sim: 8000  # Outer-loop iterations: simulate the SDE + MCMC and put the samples in the buffer
n_updates_per_sim: 1  # Inner-loop iterations: use samples from the buffer to optimize the model on sub-trajectories

use_resampling: false  # Whether to use resampling during training
use_resampling_inference: true  # Whether to use resampling at inference time
resample_threshold: 0.3  # Threshold for resampling
# Resampler factory: _target_ is called with the arguments below (Hydra instantiate convention)
resampler:
  _target_: algorithms.scld.resampling.get_resampler
  identifier: multinomial  # Resampling scheme. Choose between [multinomial, systematic]

use_markov: false  # Whether to use MCMC during training
use_markov_inference: true  # Whether to use MCMC at inference time

buffer:
  use_subtraj_buffer: true
  max_length_in_batches: 1.0  # Maximum buffer length in batches. Setting this to 1 corresponds to not using a buffer
  min_length_in_batches: 1.0  # Can be ignored
  sample_with_replacement: false
  prioritized: true  # Toggling this currently has no effect
  sampling_scheme: "vanilla"  # "vanilla" or "new" (confirm the exact spelling expected by the buffer code)
  update_weights: false  # Optional weight updates, as performed in the FAB paper
  temperature: 1  # NOTE(review): semantics not documented here — confirm against the buffer code

# Hydra defaults list: the config-group options composed into this config
defaults:
  - model: pisgrad_net  # Parameterized model
  - noise_schedule: cosine  # Diffusion-coefficient scheduler. Choose between [const, linear, cosine]
  - mcmc: hmc  # MCMC transition kernel. Choose between [hmc, mh]

model_detach_langevin: true
# Merged into the model config selected in the defaults list.
# weight_init is written as 1.0e-8 (not 1e-8) so YAML 1.1 loaders parse it as a float.
model:
  bias_init: 0.0  # Initialization of the last layers' bias of the time-dependent network
  weight_init: 1.0e-8  # Initialization of the last layers' weights of the time-dependent network

learn_max_diffusion: false
noise_schedule:
  reverse: false  # Ensures the noise schedule runs in the correct (time-wise) direction

annealing_schedule:
  schedule_type: "learnt"  # Choose between [uniform, cosine, learnt]
  schedule_lr: 0.01  # Learning rate for the annealing schedule (presumably only used when schedule_type is "learnt")

plot_subtrajs: false  # It's pretty expensive for WandB to do this
plot_ode: false
sweep_mode: true