# Generative Flow Networks (GFN) with Sub-Trajectory Balance (SubTB) losses; see `loss_type` for the variants
# Top-level hyperparameters. Values written as `${target.*}` are interpolations
# into a `target` node that is not defined in the visible part of this file.
name: gfn_subtb
step_size: ${target.gfn_subtb.step_size}
batch_size: ${target.all.batch_size}
iters: ${target.all.iters}
num_steps: ${target.all.num_steps}
grad_clip: 1.0
reference_process: ${target.gfn_subtb.reference_process}  # ou or pinned_brownian
init_std: ${target.gfn_subtb.init_std}  # for ou
max_diffusion: ${target.gfn_subtb.max_diffusion}  # for pinned_brownian or ou
noise_scale: 6.0  # for ou_dds
n_chunks: ${target.gfn_subtb.n_chunks}
logflow_step_size: ${target.gfn_subtb.logflow_step_size}
partial_energy: true
beta_schedule: learnt  # learnt, linear, cosine
beta_step_size: ${target.gfn_subtb.beta_step_size}
init_logZ: ${target.gfn_subtb.init_logZ}  # for ou or ou_dds
logZ_step_size: ${target.gfn_subtb.logZ_step_size}  # for ou or ou_dds
init_invtemp: 1.0
loss_type: tb_subtb  # subtb, tb_subtb, lv_subtb
subtb_weight: 1.0  # used if loss_type is in [tb_subtb, lv_subtb]
# Written with an explicit mantissa dot and exponent sign so that YAML 1.1
# loaders (e.g. PyYAML) resolve it as a float rather than the string "-1e5".
logr_clip: -1.0e+5

# Learning-rate schedule
lr_schedule:
  # One of: multistep | cosine | constant
  type: multistep
  # Iteration indices at which the learning rate is decayed
  milestones: ${target.all.milestones}
  # Factor the learning rate is multiplied by at each milestone
  gamma: 0.3

# Config-group selections composed into this config (presumably a Hydra
# defaults list).
# NOTE(review): there is no `_self_` entry, so whether the `model:` and
# `noise_schedule:` sections of this file override these selections depends on
# the implicit `_self_` placement — confirm for the Hydra version in use.
defaults:
  - model: pisgrad_net
  - noise_schedule: const  # const for pinned_brownian; this is not used for ou_dds

# Network settings; presumably layered on the `pisgrad_net` model config
# selected in `defaults` — confirm composition order.
model:
  learn_flow: true
  share_embeddings: false
  use_lp: ${target.all.use_lp}
  num_hid: ${target.all.num_hid}
  flow_num_hid: ${target.gfn_subtb.flow_num_hid}
  # Initialization of the last layers' weights of the time-dependent network.
  # Spelled with a mantissa dot so that YAML 1.1 loaders (e.g. PyYAML) read a
  # float instead of the string "1e-8".
  weight_init: 1.0e-8
  # Initialization of the last layers' bias of the time-dependent network.
  bias_init: 0.1

# Settings for the noise schedule; presumably layered on the `const` schedule
# selected in `defaults` — confirm composition order.
noise_schedule:
  reverse: false

# Buffer settings
buffer:
  use: true
  max_length_in_batches: 100
  # One of: none, reward, loss, uiw, piw
  prioritize_by: "piw"
  target_ess: 0.05
  # One of: multinomial, stratified, systematic, rank
  sampling_method: "systematic"
  # Only used if sampling_method is rank
  rank_k: 0.01
  sample_with_replacement: true
  # Collect `prefill_steps` batches before starting training
  prefill_steps: 100
  # Number of bwd steps per fwd step
  bwd_to_fwd_ratio: 2
  update_score: false