# Baseline configuration. Every other top-level section in this file sets a
# subset of these same keys (presumably layered on top of `defaults` by the
# config loader -- confirm against the loader).
#
# NOTE(review): exponent literals written without a decimal point (1e6, 1e10,
# 1e-4, ...) resolve to floats under YAML 1.2 but stay plain strings under
# YAML 1.1 loaders such as PyYAML -- confirm which loader reads this file.
defaults:

  seed: 0
  method: name
  task: dummy_disc
  # NOTE(review): logdir and the wandb project/name/dir below all embed the
  # same experiment and run names; they have to be edited together when a
  # different run is started.
  logdir: ./logdir/dreamerv3-jax_pushing7x7/run_monolithic-pretrain_0
  replay: uniform
  replay_size: 1e6
  replay_online: False
  eval_dir: ''
  filter: '.*'

  wandb:
    project: dreamerv3-jax_pushing7x7
    name: run_monolithic-pretrain_0
    group: monolithic-pretrain
    # A real YAML null (unlike `expl_behavior: None` further down).
    run_id: null
    dir: ./wandb/dreamerv3-jax_pushing7x7

  jax:
    platform: gpu
    jit: True
    precision: float16
    prealloc: True
    debug_nans: False
    logical_cpus: 0
    debug: False
    # Policy and training both list device index 0 by default.
    policy_devices: [0]
    train_devices: [0]
    metrics_every: 10

  run:
    script: train
    steps: 1e10
    expl_until: 0
    log_every: 30
    log_video_every: 5000
    save_every: 3600
    eval_every: 10000
    eval_initial: True
    eval_eps: 30
    eval_samples: 1
    # Written as a float; presets override it with integers (16, 64, 1024).
    train_ratio: 512.0
    train_fill: 0
    eval_fill: 0
    prefill_steps: 100
    log_zeros: False
    log_keys_video: [image]
    # '^$' only matches the empty string, so no key name is selected.
    log_keys_sum: '^$'
    log_keys_mean: '(log_entropy)'
    log_keys_max: '^$'
    from_checkpoint: ''
    sync_every: 10
    # Alternative TCP endpoint, kept for reference; the active value below is
    # an ipc:// address.
    # actor_addr: 'tcp://127.0.0.1:5551'
    actor_addr: 'ipc:///tmp/5551'
    actor_batch: 32

  # `envs` and `wrapper` share the length/reset/discretize/checks keys with
  # identical default values.
  envs: {amount: 4, parallel: process, length: 0, reset: True, restart: True, discretize: 0, checks: False}
  wrapper: {length: 0, reset: True, discretize: 0, checks: False}
  # Per-suite environment options; every suite defaults to 64x64 frames.
  env:
    atari: {size: [64, 64], repeat: 4, sticky: True, gray: False, actions: all, lives: unused, noops: 0, resize: opencv}
    dmlab: {size: [64, 64], repeat: 4, episodic: True}
    minecraft: {size: [64, 64], break_speed: 100.0}
    dmc: {size: [64, 64], repeat: 2, camera: -1}
    loconav: {size: [64, 64], repeat: 2, camera: -1}

  # Agent
  task_behavior: Greedy
  # `None` is not a YAML null (that would be `null` or `~`); this value is
  # loaded as the string 'None'.
  expl_behavior: None
  batch_size: 16
  batch_length: 64
  data_loaders: 8

  # World Model
  # The sizes used below (rssm.deter 4096, cnn_depth 96, units 1024, layers 5)
  # are the same values the `xlarge` section of this file sets.
  grad_heads: [decoder, reward, cont]
  rssm: {deter: 4096, units: 1024, stoch: 32, classes: 32, act: silu, norm: layer, initial: learned, unimix: 0.01, unroll: False, action_clip: 1.0, winit: normal, fan: avg}
  # mlp_keys '$^' is a pattern that cannot match any non-empty name, while
  # cnn_keys selects 'image' (presumably: image-only inputs -- confirm).
  encoder: {mlp_keys: '$^', cnn_keys: 'image', act: silu, norm: layer, mlp_layers: 5, mlp_units: 1024, cnn: resnet, cnn_depth: 96, cnn_blocks: 0, resize: stride, winit: normal, fan: avg, symlog_inputs: True, minres: 4}
  decoder: {mlp_keys: '$^', cnn_keys: 'image', act: silu, norm: layer, mlp_layers: 5, mlp_units: 1024, cnn: resnet, cnn_depth: 96, cnn_blocks: 0, image_dist: mse, vector_dist: symlog_mse, inputs: [deter, stoch], resize: stride, winit: normal, fan: avg, outscale: 1.0, minres: 4, cnn_sigmoid: False}
  reward_head: {layers: 5, units: 1024, act: silu, norm: layer, dist: symlog_disc, outscale: 0.0, outnorm: False, inputs: [deter, stoch], winit: normal, fan: avg, bins: 255}
  cont_head: {layers: 5, units: 1024, act: silu, norm: layer, dist: binary, outscale: 1.0, outnorm: False, inputs: [deter, stoch], winit: normal, fan: avg}
  loss_scales: {image: 1.0, vector: 1.0, reward: 1.0, cont: 1.0, dyn: 0.5, rep: 0.1, actor: 1.0, critic: 1.0, slowreg: 1.0}
  dyn_loss: {impl: kl, free: 1.0}
  rep_loss: {impl: kl, free: 1.0}
  # NOTE(review): frozen_param_prefixes lists 'agent/wm/enc'; presumably
  # parameters under that prefix are excluded from model_opt updates --
  # confirm in the optimizer code.
  model_opt: {opt: adam, lr: 1e-4, eps: 1e-8, clip: 1000.0, wd: 0.0, warmup: 0, lateclip: 0.0, frozen_param_prefixes: ['agent/wm/enc']}

  # Actor Critic
  actor: {layers: 5, units: 1024, act: silu, norm: layer, minstd: 0.1, maxstd: 1.0, outscale: 1.0, outnorm: False, unimix: 0.01, inputs: [deter, stoch], winit: normal, fan: avg, symlog_inputs: False}
  critic: {layers: 5, units: 1024, act: silu, norm: layer, dist: symlog_disc, outscale: 0.0, outnorm: False, inputs: [deter, stoch], winit: normal, fan: avg, bins: 255, symlog_inputs: False}
  # actor_opt and critic_opt carry identical settings.
  actor_opt: {opt: adam, lr: 3e-5, eps: 1e-5, clip: 100.0, wd: 0.0, warmup: 0, lateclip: 0.0}
  critic_opt: {opt: adam, lr: 3e-5, eps: 1e-5, clip: 100.0, wd: 0.0, warmup: 0, lateclip: 0.0}
  actor_dist_disc: onehot
  actor_dist_cont: normal
  actor_grad_disc: reinforce
  actor_grad_cont: backprop
  critic_type: vfunction
  imag_horizon: 15
  imag_unroll: False
  horizon: 333
  return_lambda: 0.95
  critic_slowreg: logprob
  slow_critic_update: 1
  slow_critic_fraction: 0.02
  retnorm: {impl: perc_ema, decay: 0.99, max: 1.0, perclo: 5.0, perchi: 95.0}
  actent: 3e-4

  # Exploration
  expl_rewards: {extr: 1.0, disag: 0.1}
  expl_opt: {opt: adam, lr: 1e-4, eps: 1e-5, clip: 100.0, wd: 0.0, warmup: 0}
  disag_head: {layers: 5, units: 1024, act: silu, norm: layer, dist: mse, outscale: 1.0, inputs: [deter, stoch, action], winit: normal, fan: avg}
  disag_target: [stoch]
  disag_models: 8

minecraft:

  # Preset for the minecraft_diamond task: 16 environments and a train ratio
  # of 16 instead of the values given under `defaults`.
  task: minecraft_diamond
  envs.amount: 16
  run:
    script: train_save
    eval_fill: 1e5
    train_ratio: 16
    log_keys_max: '^log_inventory.*'
  # Vector keys are selected next to the image. The encoder pattern also
  # lists `reward`; the decoder pattern does not.
  encoder:
    mlp_keys: 'inventory|inventory_max|equipped|health|hunger|breath|reward'
    cnn_keys: 'image'
  decoder:
    mlp_keys: 'inventory|inventory_max|equipped|health|hunger|breath'
    cnn_keys: 'image'

dmlab:

  # Preset for a DMLab task with 8 environments and a train ratio of 64.
  task: dmlab_explore_goal_locations_small
  envs.amount: 8
  # '$^' cannot match a non-empty key name, so only 'image' is selected.
  encoder:
    mlp_keys: '$^'
    cnn_keys: 'image'
  decoder:
    mlp_keys: '$^'
    cnn_keys: 'image'
  run.train_ratio: 64

atari:

  # Preset for Atari (task atari_pong) with 8 environments, a 5.5e7 step
  # budget and a train ratio of 64.
  task: atari_pong
  envs.amount: 8
  run:
    steps: 5.5e7
    eval_eps: 10
    train_ratio: 64
  # '$^' cannot match a non-empty key name, so only 'image' is selected.
  encoder:
    mlp_keys: '$^'
    cnn_keys: 'image'
  decoder:
    mlp_keys: '$^'
    cnn_keys: 'image'

atari100k:

  # Preset for the low-data Atari setting: a single environment, a 1.5e5 step
  # budget, float32 precision and the same network sizes as the `small`
  # section of this file.
  task: atari_pong
  envs:
    amount: 1
  env.atari:
    gray: False
    repeat: 4
    sticky: False
    noops: 30
    actions: needed
  run:
    script: train_eval
    steps: 1.5e5
    eval_every: 1e5
    eval_initial: False
    eval_eps: 100
    train_ratio: 1024
  jax.precision: float32
  rssm.deter: 512
  # Keys containing regex syntax are presumably matched against the dotted
  # config key names by the loader -- confirm. They are quoted so the regex
  # characters are read literally.
  '.*\.cnn_depth': 32
  '.*\.layers': 2
  '.*\.units$': 512
  # NOTE(review): actor_eval_sample has no counterpart under `defaults` in
  # this file, unlike every other key set by this preset. If the loader
  # rejects keys that are absent from the defaults, selecting this preset
  # will fail -- confirm and either add a default or drop this line.
  actor_eval_sample: True
  encoder:
    mlp_keys: '$^'
    cnn_keys: 'image'
  decoder:
    mlp_keys: '$^'
    cnn_keys: 'image'

crafter:

  # Preset for the crafter_reward task with a single environment.
  task: crafter_reward
  envs.amount: 1
  # All run.* overrides live in one mapping, as in the `atari` and `loconav`
  # sections. Splitting them between a nested `run:` mapping and a separate
  # dotted `run.train_ratio` key makes it easy to define the same setting
  # twice without noticing.
  run:
    log_keys_max: '^log_achievement_.*'
    log_keys_sum: '^log_reward$'
    train_ratio: 512
  # '$^' cannot match a non-empty key name, so only 'image' is selected.
  encoder: {mlp_keys: '$^', cnn_keys: 'image'}
  decoder: {mlp_keys: '$^', cnn_keys: 'image'}

dmc_vision:

  # Image-based preset for dmc_walker_walk, using the same network sizes as
  # the `small` section of this file.
  task: dmc_walker_walk
  run.train_ratio: 512
  rssm.deter: 512
  # Regex-style keys, quoted so the pattern characters are read literally.
  '.*\.cnn_depth': 32
  '.*\.layers': 2
  '.*\.units': 512
  # '$^' cannot match a non-empty key name, so only 'image' is selected.
  encoder:
    mlp_keys: '$^'
    cnn_keys: 'image'
  decoder:
    mlp_keys: '$^'
    cnn_keys: 'image'

dmc_proprio:

  # Counterpart of `dmc_vision` with the key patterns swapped: mlp_keys
  # matches every name and cnn_keys ('$^') matches no non-empty name.
  task: dmc_walker_walk
  run.train_ratio: 512
  rssm.deter: 512
  # Regex-style keys, quoted so the pattern characters are read literally.
  '.*\.cnn_depth': 32
  '.*\.layers': 2
  '.*\.units': 512
  encoder:
    mlp_keys: '.*'
    cnn_keys: '$^'
  decoder:
    mlp_keys: '.*'
    cnn_keys: '$^'

bsuite:

  # Preset for bsuite tasks: one environment and the same network sizes as
  # the `small` section of this file.
  task: bsuite_mnist/0
  envs:
    amount: 1
    # `none` is loaded as a plain string, not as a YAML null.
    parallel: none
  run:
    script: train
    train_ratio: 1024  # 128 for cartpole
  rssm.deter: 512
  # Regex-style keys, quoted so the pattern characters are read literally.
  '.*\.cnn_depth': 32
  '.*\.layers': 2
  '.*\.units': 512

loconav:

  # Preset for the loconav_ant_maze_m task.
  task: loconav_ant_maze_m
  # Same value as env.loconav.repeat under `defaults`.
  env.loconav.repeat: 2
  run:
    train_ratio: 512
    log_keys_max: '^log_.*'
  # mlp_keys matches every key name; cnn_keys selects 'image'.
  encoder:
    mlp_keys: '.*'
    cnn_keys: 'image'
  decoder:
    mlp_keys: '.*'
    cnn_keys: 'image'

# Smallest of the four size presets in this file (small, medium, large,
# xlarge). Regex-style keys are quoted so the pattern characters are read
# literally.
small:
  rssm.deter: 512
  '.*\.cnn_depth': 32
  '.*\.units': 512
  '.*\.layers': 2

# Second of the four size presets in this file (small, medium, large,
# xlarge). Regex-style keys are quoted so the pattern characters are read
# literally.
medium:
  rssm.deter: 1024
  '.*\.cnn_depth': 48
  '.*\.units': 640
  '.*\.layers': 3

# Third of the four size presets in this file (small, medium, large,
# xlarge). Regex-style keys are quoted so the pattern characters are read
# literally.
large:
  rssm.deter: 2048
  '.*\.cnn_depth': 64
  '.*\.units': 768
  '.*\.layers': 4

# Largest of the four size presets in this file. Its values equal the sizes
# already written under `defaults`. Regex-style keys are quoted so the
# pattern characters are read literally.
xlarge:
  rssm.deter: 4096
  '.*\.cnn_depth': 96
  '.*\.units': 1024
  '.*\.layers': 5

multicpu:

  # Requests 8 logical CPUs and assigns device indices 0-1 to the policy and
  # 2-7 to training.
  jax:
    logical_cpus: 8
    policy_devices: [0, 1]
    train_devices: [2, 3, 4, 5, 6, 7]
  # Single-key overrides, written in the compact flow form used elsewhere in
  # this file.
  run: {actor_batch: 4}
  envs: {amount: 8}
  batch_size: 12
  batch_length: 10

debug:

  # Small, CPU-only settings: tiny networks, short batches and frequent
  # logging/saving compared with `defaults`.
  jax:
    jit: True
    prealloc: False
    debug: True
    platform: cpu
  envs:
    restart: False
    amount: 3
  wrapper:
    length: 100
    checks: True
  run:
    eval_every: 1000
    log_every: 5
    save_every: 10
    train_ratio: 32
    actor_batch: 2
  batch_size: 8
  batch_length: 12
  replay_size: 1e5
  encoder.cnn_depth: 8
  decoder.cnn_depth: 8
  rssm:
    deter: 32
    units: 16
    stoch: 4
    classes: 4
  # Regex-style keys, quoted so the pattern characters are read literally.
  # They are kept after the explicit keys above, in the original order.
  '.*unroll': False
  '.*\.layers': 2
  '.*\.units': 16
  '.*\.wd$': 0.0
