# Baseline configuration. The other top-level keys in this file (xsmall ... xlarge)
# contain only a few entries each; presumably they are overlays applied on top of
# these defaults -- confirm against the config loader.
#
# NOTE(review): exponent literals without a decimal point (1e6, 1e-4, ...) resolve
# to floats under YAML 1.2 but remain plain strings under YAML 1.1 loaders such as
# PyYAML -- confirm which loader reads this file.
defaults:
  seed: 0
  method: name
  task: dummy_disc
  # Placeholder path; presumably overridden for any real run -- confirm.
  logdir: /dev/null
  replay: uniform
  replay_size: 1e6
  replay_online: False
  # Empty string, not null.
  eval_dir: ""
  # Looks like a regular expression; ".*" would match any string -- confirm usage.
  filter: ".*"

  # Options grouped under `jax`.
  # NOTE(review): the per-key meanings are not visible in this file; confirm them
  # against the code that consumes this section.
  jax:
    platform: gpu
    jit: True
    precision: float16
    prealloc: False
    debug_nans: False
    logical_cpus: 0
    debug: False
    # Both device lists contain the single entry 0.
    policy_devices: [0]
    train_devices: [0]
    metrics_every: 10

  # Options grouped under `run` (script selection, step counts, logging cadence).
  # NOTE(review): units of the *_every / steps values are not stated here -- confirm.
  run:
    script: train
    steps: 1e8
    expl_until: 0
    log_every: 1000
    save_every: 2000
    eval_every: 1e4
    eval_initial: True
    eval_eps: 1
    eval_samples: 1
    train_ratio: 512
    train_fill: 0
    eval_fill: 0
    log_zeros: False
    # A one-element list holding the string `none` (YAML null would be `null` or `~`).
    log_keys_video: [none]
    # The log_keys_* values look like regular expressions; "^$" would match only
    # the empty string -- confirm how they are applied.
    log_keys_sum: "^$"
    log_keys_mean: "(log_entropy)"
    log_keys_max: "^$"
    # Empty string, not null.
    from_checkpoint: ""
    sync_every: 10

  # Environment sections. `length` and `reset` appear in both mappings with the
  # same values (0 and True).
  envs:
    amount: 1
    parallel: process
    length: 0
    reset: True
  wrapper:
    length: 0
    reset: True
    discretize: 0
    checks: False

  # Agent
  # NOTE(review): `Hansome` may be a misspelling of "Handsome". The value is
  # presumably resolved by name in code, so confirm before renaming it.
  task_behavior: Hansome
  # A plain `None` is the string "None" in YAML, not null (null is `null` or `~`).
  expl_behavior: None
  batch_size: 16
  batch_length: 64
  data_loaders: 8

  # World Model
  # NOTE(review): `worker_reward` is listed here and in `loss_scales`, but no
  # matching head section is visible in this file -- confirm where it is configured.
  grad_heads: [decoder, reward, cont, worker_reward]
  # Settings for the `rssm` component of the world model.
  rssm:
    deter: 4096
    units: 1024
    stoch: 32
    classes: 32
    act: silu
    norm: layer
    initial: learned
    unimix: 0.01
    unroll: False
    action_clip: 1.0
    winit: normal
    fan: avg
  # Encoder settings, covering both the CNN and MLP options.
  # NOTE(review): mlp_layers/mlp_units here (2 / 16) differ from the decoder's
  # (5 / 1024) -- confirm this asymmetry is intentional.
  encoder:
    # The string "none", not YAML null.
    cnn_keys: "none"
    mlp_keys: "none"
    act: silu
    norm: layer
    mlp_layers: 2
    mlp_units: 16
    cnn: resnet
    cnn_depth: 96
    cnn_blocks: 0
    resize: stride
    winit: normal
    fan: avg
    symlog_inputs: True
    minres: 4
  # Decoder settings, covering the CNN and MLP options plus output distributions.
  decoder:
    # The string "none", not YAML null.
    cnn_keys: "none"
    mlp_keys: "none"
    act: silu
    norm: layer
    mlp_layers: 5
    mlp_units: 1024
    cnn: resnet
    cnn_depth: 96
    cnn_blocks: 0
    image_dist: mse
    vector_dist: symlog_mse
    inputs: [deter, stoch]
    resize: stride
    winit: normal
    fan: avg
    outscale: 1.0
    minres: 4
    cnn_sigmoid: False
  # Reward head settings; uses the `symlog_disc` distribution with 255 bins.
  reward_head:
    layers: 5
    units: 1024
    act: silu
    norm: layer
    dist: symlog_disc
    outscale: 0.0
    outnorm: False
    inputs: [deter, stoch]
    winit: normal
    fan: avg
    bins: 255
  # `cont` head settings; uses the `binary` distribution.
  cont_head:
    layers: 5
    units: 1024
    act: silu
    norm: layer
    dist: binary
    outscale: 1.0
    outnorm: False
    inputs: [deter, stoch]
    winit: normal
    fan: avg
  # Per-term loss weights. Every entry is 1.0 except `dyn` (0.5) and `rep` (0.1).
  loss_scales:
    image: 1.0
    vector: 1.0
    reward: 1.0
    worker_reward: 1.0
    cont: 1.0
    dyn: 0.5
    rep: 0.1
    actor: 1.0
    critic: 1.0
    slowreg: 1.0
  # `dyn_loss` and `rep_loss` carry identical settings.
  dyn_loss:
    impl: kl
    free: 1.0
  rep_loss:
    impl: kl
    free: 1.0
  # Optimizer settings for the model. Its lr, eps and clip values differ from
  # those in actor_opt / critic_opt.
  model_opt:
    opt: adam
    lr: 1e-4
    eps: 1e-8
    clip: 1000.0
    wd: 0.0
    warmup: 0
    lateclip: 0.0

  # Actor Critic
  # Actor network settings.
  actor:
    layers: 5
    units: 1024
    act: silu
    norm: layer
    minstd: 0.1
    maxstd: 1.0
    outscale: 1.0
    outnorm: False
    unimix: 0.01
    inputs: [deter, stoch]
    winit: normal
    fan: avg
    symlog_inputs: False
  # Critic network settings; same `dist`, `outscale` and `bins` as `reward_head`.
  critic:
    layers: 5
    units: 1024
    act: silu
    norm: layer
    dist: symlog_disc
    outscale: 0.0
    outnorm: False
    inputs: [deter, stoch]
    winit: normal
    fan: avg
    bins: 255
    symlog_inputs: False
  # Optimizer settings for the actor and the critic; the two mappings are identical.
  actor_opt:
    opt: adam
    lr: 3e-5
    eps: 1e-5
    clip: 100.0
    wd: 0.0
    warmup: 0
    lateclip: 0.0
  critic_opt:
    opt: adam
    lr: 3e-5
    eps: 1e-5
    clip: 100.0
    wd: 0.0
    warmup: 0
    lateclip: 0.0
  # Scalar actor-critic options.
  # NOTE(review): the `_disc` / `_cont` suffixes presumably select settings for
  # discrete vs. continuous action spaces -- confirm.
  actor_dist_disc: onehot
  actor_dist_cont: normal
  actor_grad_disc: reinforce
  actor_grad_cont: backprop
  critic_type: vfunction
  # NOTE(review): `imag_horizon` (15) and `horizon` (333) are separate keys; how
  # they relate is not visible here -- confirm.
  imag_horizon: 15
  imag_unroll: False
  horizon: 333
  return_lambda: 0.95
  critic_slowreg: logprob
  slow_critic_update: 1
  slow_critic_fraction: 0.02
  retnorm: { impl: perc_ema, decay: 0.99, max: 1.0, perclo: 5.0, perchi: 95.0 }
  actent: 3e-4

  # Exploration
  # Weights for the `extr` and `disag` reward terms.
  expl_rewards:
    extr: 1.0
    disag: 0.1
  # NOTE(review): unlike model_opt / actor_opt / critic_opt, this optimizer mapping
  # has no `lateclip` key -- confirm the consumer does not require it.
  expl_opt:
    opt: adam
    lr: 1e-4
    eps: 1e-5
    clip: 100.0
    wd: 0.0
    warmup: 0
  # `disag` head settings; its inputs also include `action`.
  disag_head:
    layers: 5
    units: 1024
    act: silu
    norm: layer
    dist: mse
    outscale: 1.0
    inputs: [deter, stoch, action]
    winit: normal
    fan: avg
  disag_target: [stoch]
  disag_models: 8

  # Replay Configuration
  # Hyperparameters for replay.
  # NOTE(review): which `replay` modes read these keys is not visible here -- confirm.
  replay_hyper:
    initial_priority: 1e5
    c: 1e4
    beta: 0.7
    epsilon: 0.01
    alpha: 0.7
    key_find_priority: 1e7

# Size preset with the smallest values of the five presets in this file.
# The keys are written as patterns (e.g. `.*\.units`); presumably the loader
# applies each one to every matching nested key in `defaults` -- confirm.
xsmall:
  rssm.deter: 64
  .*\.cnn_depth: 4
  .*\.units: 64
  .*\.layers: 1

# Size preset. The keys are written as patterns (e.g. `.*\.units`); presumably the
# loader applies each one to every matching nested key in `defaults` -- confirm.
small:
  rssm.deter: 512
  .*\.cnn_depth: 32
  .*\.units: 512
  .*\.layers: 2

# Size preset. The keys are written as patterns (e.g. `.*\.units`); presumably the
# loader applies each one to every matching nested key in `defaults` -- confirm.
medium:
  rssm.deter: 1024
  .*\.cnn_depth: 48
  .*\.units: 640
  .*\.layers: 3

# Size preset. The keys are written as patterns (e.g. `.*\.units`); presumably the
# loader applies each one to every matching nested key in `defaults` -- confirm.
large:
  rssm.deter: 2048
  .*\.cnn_depth: 64
  .*\.units: 768
  .*\.layers: 4

# Size preset. Its four values equal what `defaults` already sets for rssm.deter,
# cnn_depth, units and layers.
# The keys are written as patterns (e.g. `.*\.units`); presumably the loader
# applies each one to every matching nested key in `defaults` -- confirm.
xlarge:
  rssm.deter: 4096
  .*\.cnn_depth: 96
  .*\.units: 1024
  .*\.layers: 5
