env_name: bigfish
# Filled in by the launch.sh script (consistent across stages 1-3).
exp_name: null
# Updated in stages 1-3; used as a suffix for wandb logging.
stage_exp_name: null
seed: 11
test_every: 200
rollout_policy_every: 500

# configure model architecture
model:
  wm_scale: 24
  idm_impala_scale: 4
  policy_impala_scale: 4
  decoder_hidden_sizes: [192, 128, 64]
  vq:
    enabled: true
    num_codebooks: 2
    num_discrete_latents: 4
    emb_dim: 16
    # Each of the num_discrete_latents latents is a categorical variable
    # with num_embs possible values.
    num_embs: 64
    commitment_cost: 0.05
    decay: 0.999

# stage-1 (latent idm & wm training) hyperparameters
lapo_stage1:
  # Written with a decimal point (3.0e-4, not 3e-4) so YAML 1.1 loaders such
  # as PyYAML's safe_load resolve it as a float instead of the string "3e-4".
  lr: 3.0e-4
  bs: 128
  steps: 50_000

# stage-2 (latent behavior cloning) hyperparameters (Standard LAPO)
lapo_stage2:
  # Written with a decimal point (2.0e-4, not 2e-4) so YAML 1.1 loaders such
  # as PyYAML's safe_load resolve it as a float instead of the string "2e-4".
  lr: 2.0e-4
  bs: 128
  steps: 60_000

# stage-3 (bc decoding) hyperparameters (Standard LAPO BC)
lapo_stage3:
  # Written with a decimal point (2.0e-4, not 2e-4) so YAML 1.1 loaders such
  # as PyYAML's safe_load resolve it as a float instead of the string "2e-4".
  lr: 2.0e-4
  bs: 128
  steps: 10_000
  n_observed_samples: 4_000
  freeze_backbone: true
  num_envs: 64  # for eval only
  gamma: 0.999  # for eval only
  bc_only: false  # no lapo pretraining, just bc from scratch

# stage-2 (idm decoding) hyperparameters (LAPO+ IDM Decoding)
lapo_plus_stage2:
  # Written with a decimal point (2.0e-4, not 2e-4) so YAML 1.1 loaders such
  # as PyYAML's safe_load resolve it as a float instead of the string "2e-4".
  lr: 2.0e-4
  bs: 128
  steps: 10_000
  n_observed_samples: 4_000
  freeze_backbone: true
  idm_only: false  # no lapo pretraining, just idm from scratch

# stage-3 (policy from idm labels) hyperparameters (LAPO+ Policy from IDM labels)
lapo_plus_stage3:
  # Written with a decimal point (2.0e-4, not 2e-4) so YAML 1.1 loaders such
  # as PyYAML's safe_load resolve it as a float instead of the string "2e-4".
  lr: 2.0e-4
  bs: 128
  steps: 60_000
  num_envs: 64  # for eval only
  gamma: 0.999  # for eval only
  # No lapo pretraining, just idm from scratch.
  # NOTE(review): same key and description as lapo_plus_stage2.idm_only;
  # confirm whether the two values must be kept in sync.
  idm_only: false