defaults:

  # General settings. Sections later in this file (minecraft, atari, debug,
  # ...) re-declare a subset of the keys under `defaults`, e.g. `task` and
  # `replay_size`.
  seed: 0
  method: name
  # Other sections use task ids of the same `<suite>_<name>` form
  # (atari_pong, dmc_walker_walk, ...).
  task: dummy_disc
  # NOTE(review): /dev/null looks like a placeholder that is expected to be
  # overridden per run -- confirm against the launcher.
  logdir: /dev/null
  replay_size: 1e6
  replay_online: False
  eval_dir: ''
  # Regex-looking value; read as a regex, '.*' matches every string.
  filter: '.*'
  tensorboard_videos: True

  # Settings grouped under `jax`.
  # NOTE(review): presumably consumed by the JAX backend setup, which is not
  # part of this file -- confirm the exact key semantics there.
  jax:
    platform: gpu
    jit: True
    # Compute and parameter dtypes are configured separately
    # (float16 vs. float32 here).
    compute_dtype: float16
    param_dtype: float32
    prealloc: True
    checks: False
    logical_cpus: 0
    debug: False
    # Both device lists contain only index 0 by default; the `multicpu`
    # section lists indices 0-1 and 2-7 instead.
    policy_devices: [0]
    train_devices: [0]
    sync_every: 10
    profiler: False  # True
    transfer_guard: True
    assert_num_devices: -1

  # Settings grouped under `run`.
  # NOTE(review): the units of the `*_every`, `duration` and `*_timeout`
  # values are not stated in this file -- confirm against the run scripts.
  run:
    # Other sections in this file select `train_save` or `train_eval` here.
    script: train
    steps: 1e10
    duration: 0
    num_envs: 4
    expl_until: 0
    log_every: 120
    save_every: 900
    eval_every: 1e6
    eval_initial: True
    eval_eps: 1
    eval_samples: 1
    train_ratio: 32.0
    train_fill: 0
    eval_fill: 0
    log_zeros: True
    log_keys_video: [image]
    # Read as regexes, '^$' matches only the empty string. The minecraft,
    # crafter and loconav sections replace some of these three values with
    # '^log_...' patterns.
    log_keys_sum: '^$'
    log_keys_avg: '^$'
    log_keys_max: '^$'
    log_video_fps: 20
    log_video_streams: 4
    log_episode_timeout: 60
    from_checkpoint: ''
    # NOTE(review): `{auto}` looks like a placeholder that is filled in at
    # launch time -- confirm against the code that reads these addresses.
    actor_addr: 'tcp://localhost:{auto}'
    replay_addr: 'tcp://localhost:{auto}'
    logger_addr: 'tcp://localhost:{auto}'
    actor_batch: 32
    actor_threads: 4
    env_replica: -1
    ipv6: False
    usage: {psutil: True, nvsmi: True, gputil: False, malloc: False, gc: False}
    timer: True  # A bit slow but useful.
    driver_parallel: True

  # Wrapper options; the `debug` section overrides `length` and `checks`.
  wrapper:
    length: 0
    reset: True
    discretize: 0
    checks: False

  # Per-suite environment options, one sub-mapping per suite name. Every suite
  # declares a 64x64 `size`; the remaining keys differ per suite.
  env:
    atari:
      size: [64, 64]
      repeat: 4
      sticky: True
      gray: True
      actions: all
      lives: unused
      noops: 0
      pooling: 2
      aggregate: max
      resize: pillow
    crafter:
      size: [64, 64]
      logs: False
    atari100k:
      size: [64, 64]
      repeat: 4
      sticky: False
      gray: False
      actions: needed
      lives: unused
      noops: 30
      resize: pillow
    dmlab:
      size: [64, 64]
      repeat: 4
      episodic: True
      use_seed: True
    minecraft:
      size: [64, 64]
      break_speed: 100.0
      logs: False
    dmc:
      size: [64, 64]
      repeat: 2
      camera: -1
    loconav:
      size: [64, 64]
      repeat: 2
      camera: -1

  # Agent
  task_behavior: Director
  # NOTE(review): a plain `None` is not a YAML null (`null` / `~`), so standard
  # loaders read this value as the string "None" -- presumably it is compared
  # as a name by the consumer; confirm.
  expl_behavior: None
  # The `multicpu` and `debug` sections override both batch dimensions.
  batch_size: 16
  batch_length: 64
  random_agent: False

  # Director
  director_jointly: True
  train_skill_duration: 8
  env_skill_duration: 8

  # Goal encoder / decoder: same depth, width, activation and norm; they
  # differ in `dist`, `outscale` and `inputs`.
  goal_enc:
    layers: 5
    units: 1024
    act: silu
    norm: layer
    dist: onehot
    outscale: 1.0
    inputs: [goal]
  goal_dec:
    layers: 5
    units: 1024
    act: silu
    norm: layer
    dist: mse
    outscale: 0.1
    inputs: [skill]
  goal_opt:
    opt: adam
    lr: 1e-4
    eps: 1e-6
    clip: 100.0
    wd: 1e-2
    wd_pattern: 'kernel'
  goal_kl_scale: 0.1  # 1.0
  goal_kl_free: 1.0
  skill_shape: [8, 8]

  # Reward weights per source: the manager weights `extr` and `expl`, the
  # worker weights only `goal`.
  manager_rews:
    extr: 1.0
    expl: 0.1
    goal: 0.0
  worker_rews:
    extr: 0.0
    expl: 0.0
    goal: 1.0
  worker_inputs: [deter, stoch, goal]
  worker_goals: [manager]
  worker_report_horizon: 64
  manager_actent: 3e-4
  worker_actent: 3e-4

  # World Model
  grad_heads: [decoder, reward, cont]
  rssm_type: rssm
  rssm:
    deter: 6144
    units: 1024
    stoch: 32
    classes: 32
    act: silu
    norm: layer
    unimix: 0.01
    unroll: False
    bottleneck: 2048
    winit: normal
    fan: avg

  # Encoder and decoder share the `mlp_*` / `cnn_*` key layout. By default
  # both key patterns are '.*'; most task sections below narrow them.
  encoder:
    mlp_keys: '.*'
    cnn_keys: '.*'
    act: silu
    norm: layer
    mlp_layers: 5
    mlp_units: 1024
    cnn: resnet
    cnn_depth: 96
    cnn_blocks: 0
    resize: stride
    winit: normal
    fan: avg
    minres: 4
    symlog: True
  decoder:
    mlp_keys: '.*'
    cnn_keys: '.*'
    act: silu
    norm: layer
    mlp_layers: 5
    mlp_units: 1024
    cnn: resnet
    cnn_depth: 96
    cnn_blocks: 0
    cnn_dist: mse
    mlp_dist: symlog_mse
    inputs: [deter, stoch]
    resize: stride
    winit: normal
    fan: avg
    outscale: 1.0
    minres: 4
    cnn_sigmoid: False

  # Prediction heads; both read [deter, stoch].
  reward_head:
    layers: 5
    units: 1024
    act: silu
    norm: layer
    dist: symexp_twohot
    outscale: 0.0
    inputs: [deter, stoch]
    winit: normal
    fan: avg
    bins: 255
  cont_head:
    layers: 5
    units: 1024
    act: silu
    norm: layer
    dist: binary
    outscale: 1.0
    inputs: [deter, stoch]
    winit: normal
    fan: avg

  # Per-term loss weights.
  loss_scales:
    dec_cnn: 1.0
    dec_mlp: 1.0
    reward: 1.0
    cont: 1.0
    dyn: 0.5
    rep: 0.1
    actor: 1.0
    critic: 1.0
    slowreg: 1.0
  rssm_loss:
    free: 1.0
  model_opt:
    opt: adam
    lr: 1e-4
    eps: 1e-8
    clip: 1000.0
    wd: 0.0
    warmup: 0
    adaclip: 0.0

  # Actor Critic
  actor:
    layers: 5
    units: 1024
    act: silu
    norm: layer
    minstd: 0.1
    maxstd: 1.0
    outscale: 1.0
    unimix: 0.01
    inputs: [deter, stoch]
    winit: normal
    fan: avg
  critic:
    layers: 5
    units: 1024
    act: silu
    norm: layer
    dist: symexp_twohot
    outscale: 0.0
    inputs: [deter, stoch]
    winit: normal
    fan: avg
    bins: 255

  # Actor and critic optimizers are configured with identical values.
  actor_opt:
    opt: adam
    lr: 3e-5
    eps: 1e-5
    clip: 100.0
    wd: 0.0
    warmup: 0
    adaclip: 0.0
  critic_opt:
    opt: adam
    lr: 3e-5
    eps: 1e-5
    clip: 100.0
    wd: 0.0
    warmup: 0
    adaclip: 0.0
  reward_scales:
    extr: 1.0
    disag: 100.0
    expl: 100.0
    goal: 1.0

  # Separate settings for discrete (`*_disc`) and continuous (`*_cont`)
  # actions.
  actor_dist_disc: onehot
  actor_dist_cont: normal
  actor_grad_disc: reinforce
  actor_grad_cont: backprop
  critic_type: vfunction
  imag_horizon: 15
  imag_unroll: False
  imag_cont: mean  # mode
  horizon: 333
  return_lambda: 0.95
  critic_slowreg: logprob
  slow_critic_update: 1
  slow_critic_fraction: 0.02
  slow_critic_target: False
  retnorm:
    impl: perc_ema
    decay: 0.99
    max: 1.0
    perclo: 5.0
    perchi: 95.0
  actent: 3e-4

  # Exploration
  expl_rewards:
    extr: 1.0
    disag: 0.1
  expl_opt:
    opt: adam
    lr: 1e-4
    eps: 1e-5
    clip: 100.0
    wd: 0.0
    warmup: 0
  # Unlike the other heads in `defaults`, this one also lists `action` among
  # its inputs.
  disag_head:
    layers: 5
    units: 1024
    act: silu
    norm: layer
    dist: mse
    outscale: 1.0
    inputs: [deter, stoch, action]
    winit: normal
    fan: avg
  disag_target: [stoch]
  disag_models: 8

# Section `minecraft`: re-declares `task`, several `run` values and the
# encoder/decoder key patterns from `defaults`.
minecraft:
  task: minecraft_diamond
  run:
    script: train_save
    num_envs: 16
    eval_fill: 1e5
    train_ratio: 16
    log_keys_max: '^log_inventory.*'
  # The encoder pattern additionally lists `reward`; the decoder pattern does
  # not.
  encoder:
    mlp_keys: 'inventory|inventory_max|equipped|health|hunger|breath|reward'
    cnn_keys: 'image'
  decoder:
    mlp_keys: 'inventory|inventory_max|equipped|health|hunger|breath'
    cnn_keys: 'image'

# Section `dmlab`: re-declares `task`, the encoder/decoder key patterns and
# two `run` values from `defaults`.
dmlab:
  task: dmlab_explore_goal_locations_small
  # Read as a regex, '$^' cannot match any non-empty string -- presumably this
  # leaves only the `image` key in use; confirm against the consumer.
  encoder:
    mlp_keys: '$^'
    cnn_keys: 'image'
  decoder:
    mlp_keys: '$^'
    cnn_keys: 'image'
  run:
    num_envs: 8
    train_ratio: 64

# Section `atari`: re-declares `task`, four `run` values and the
# encoder/decoder key patterns from `defaults`.
atari:
  task: atari_pong
  run:
    steps: 5.5e7
    eval_eps: 10
    num_envs: 8
    train_ratio: 64
  # Read as a regex, '$^' cannot match any non-empty string -- presumably this
  # leaves only the `image` key in use; confirm against the consumer.
  encoder:
    mlp_keys: '$^'
    cnn_keys: 'image'
  decoder:
    mlp_keys: '$^'
    cnn_keys: 'image'

# Section `atari100k`: re-declares `task`, several `run` values, the compute
# dtype, model sizes and the encoder/decoder key patterns from `defaults`.
atari100k:
  task: atari_pong
  run:
    script: train_eval
    steps: 1.5e5
    num_envs: 1
    eval_every: 1e5
    eval_initial: False
    eval_eps: 100
    train_ratio: 1024
  # Targets `compute_dtype` because that is the dtype key `defaults.jax`
  # declares (float16 there); `defaults.jax` has no `precision` key.
  jax.compute_dtype: float32
  rssm.deter: 512
  # Regex-looking keys -- presumably matched against dotted key paths by the
  # config loader; confirm there.
  .*\.cnn_depth: 32
  .*\.layers: 2
  .*\.units$: 512
  # NOTE(review): `actor_eval_sample` has no counterpart under `defaults` in
  # this file -- confirm that the config loader accepts undeclared keys.
  actor_eval_sample: True
  # Read as a regex, '$^' cannot match any non-empty string -- presumably this
  # leaves only the `image` key in use; confirm against the consumer.
  encoder: {mlp_keys: '$^', cnn_keys: 'image'}
  decoder: {mlp_keys: '$^', cnn_keys: 'image'}

# Section `crafter`: re-declares `task`, logging-related `run` values, the
# train ratio and the encoder/decoder key patterns from `defaults`.
crafter:
  task: crafter_reward
  run:
    num_envs: 1
    log_keys_max: '^log_achievement_.*'
    log_keys_sum: '^log_reward$'
    log_video_fps: 10
  # Written as a dotted key next to the nested `run` mapping above.
  run.train_ratio: 512
  # Read as a regex, '$^' cannot match any non-empty string -- presumably this
  # leaves only the `image` key in use; confirm against the consumer.
  encoder:
    mlp_keys: '$^'
    cnn_keys: 'image'
  decoder:
    mlp_keys: '$^'
    cnn_keys: 'image'

# Section `dmc_vision`: same task and sizes as `dmc_proprio`, but the key
# patterns select `image` for the CNN and nothing for the MLP.
dmc_vision:
  task: dmc_walker_walk
  run.train_ratio: 512
  rssm.deter: 512
  # Regex-looking keys -- presumably matched against dotted key paths by the
  # config loader; confirm there.
  .*\.cnn_depth: 32
  .*\.layers: 2
  .*\.units: 512
  # Read as a regex, '$^' cannot match any non-empty string.
  encoder:
    mlp_keys: '$^'
    cnn_keys: 'image'
  decoder:
    mlp_keys: '$^'
    cnn_keys: 'image'

# Section `dmc_proprio`: same task and sizes as `dmc_vision`, but the key
# patterns are inverted: '.*' for the MLP and '$^' for the CNN.
dmc_proprio:
  task: dmc_walker_walk
  run.train_ratio: 512
  rssm.deter: 512
  # Regex-looking keys -- presumably matched against dotted key paths by the
  # config loader; confirm there.
  .*\.cnn_depth: 32
  .*\.layers: 2
  .*\.units: 512
  # Read as a regex, '$^' cannot match any non-empty string.
  encoder:
    mlp_keys: '.*'
    cnn_keys: '$^'
  decoder:
    mlp_keys: '.*'
    cnn_keys: '$^'

# Section `bsuite`: re-declares `task`, three `run` values and the model sizes
# from `defaults`. The encoder/decoder key patterns are not re-declared here.
bsuite:
  task: bsuite_mnist/0
  run:
    num_envs: 1
    script: train
    train_ratio: 1024  # 128 for cartpole
  rssm.deter: 512
  # Regex-looking keys -- presumably matched against dotted key paths by the
  # config loader; confirm there.
  .*\.cnn_depth: 32
  .*\.layers: 2
  .*\.units: 512

# Section `loconav`: re-declares `task`, the action repeat, two `run` values
# and the encoder/decoder key patterns from `defaults`.
loconav:
  task: loconav_ant_maze_m
  # Alternative (equal to the `defaults` value): env.loconav.repeat: 2
  env.loconav.repeat: 1
  run:
    train_ratio: 512
    log_keys_max: '^log_.*'
  encoder:
    mlp_keys: '.*'
    cnn_keys: 'image'
  decoder:
    mlp_keys: '.*'
    cnn_keys: 'image'

# Section `small`: smallest of the four size sections (small, medium, large,
# xlarge), which all set the same five keys.
# NOTE(review): if the loader treats `.*\.units` / `.*\.layers` as regexes over
# dotted key paths, they require a literal `.` directly before the name, so
# `encoder.mlp_units` / `encoder.mlp_layers` would not be matched -- confirm
# this is intended.
small:
  rssm.deter: 512
  rssm.bottleneck: -1
  .*\.cnn_depth: 32
  .*\.units: 512
  .*\.layers: 2

# Section `medium`: second of the four size sections; sets the same five keys
# as `small`, `large` and `xlarge` with larger values than `small`.
medium:
  rssm.deter: 1024
  rssm.bottleneck: -1
  # Regex-looking keys -- presumably matched against dotted key paths by the
  # config loader; confirm there.
  .*\.cnn_depth: 48
  .*\.units: 640
  .*\.layers: 3

# Section `large`: third of the four size sections; sets the same five keys
# as `small`, `medium` and `xlarge`.
large:
  rssm.deter: 2048
  rssm.bottleneck: -1
  # Regex-looking keys -- presumably matched against dotted key paths by the
  # config loader; confirm there.
  .*\.cnn_depth: 64
  .*\.units: 768
  .*\.layers: 4

# Section `xlarge`: largest of the four size sections. Its cnn_depth, units
# and layers values equal the ones written in `defaults` (96, 1024, 5); only
# `rssm.deter` (6144 in `defaults`) and `rssm.bottleneck` (2048 in `defaults`)
# differ.
xlarge:
  rssm.deter: 4096
  rssm.bottleneck: -1
  # Regex-looking keys -- presumably matched against dotted key paths by the
  # config loader; confirm there.
  .*\.cnn_depth: 96
  .*\.units: 1024
  .*\.layers: 5

# Section `multicpu`: requests 8 logical CPUs and lists device indices 0-1 for
# the policy and 2-7 for training, with smaller batch dimensions.
# NOTE(review): `jax.platform` is not set here, so the `defaults` value (gpu)
# stays in effect unless another section changes it -- confirm intended.
multicpu:
  jax:
    logical_cpus: 8
    policy_devices: [0, 1]
    train_devices: [2, 3, 4, 5, 6, 7]
  run:
    num_envs: 8
    actor_batch: 4
  batch_size: 12
  batch_length: 10

# Section `debug`: CPU platform, short logging/saving intervals and very small
# model sizes.
debug:

  jax:
    debug: True
    jit: True
    prealloc: False
    platform: cpu
    profiler: False
    checks: False
  wrapper:
    length: 100
    checks: True
  run:
    num_envs: 4
    eval_every: 1000
    log_every: 5
    save_every: 15
    train_ratio: 8
    actor_batch: 2
    driver_parallel: False
  batch_size: 8
  batch_length: 12
  replay_size: 1e5
  encoder.cnn_depth: 8
  decoder.cnn_depth: 8
  rssm:
    deter: 32
    bottleneck: 8
    units: 16
    stoch: 4
    classes: 4
  # Regex-looking keys -- presumably matched against dotted key paths by the
  # config loader; confirm there. Kept after the explicit entries above, in
  # the original order.
  .*unroll: False
  .*\.layers: 2
  .*\.units: 16
  .*\.wd$: 0.0
