# Baseline values for every setting in this file.
# NOTE(review): presumably the other top-level blocks are override sets that
# are layered on top of these values — confirm against the config loader.
defaults:

  # General run identity, output location and replay settings.
  seed: 0
  method: name
  task: dummy_disc
  logdir: ./logdir/
  replay: uniform
  # NOTE(review): exponent literals written without a decimal point or signed
  # exponent (1e6 here, and similar values elsewhere in this file) are resolved
  # as strings by YAML 1.1 loaders such as PyYAML — confirm that the loader in
  # use resolves them as floats.
  replay_size: 1e6
  replay_online: False
  # Empty string by default.
  eval_dir: ''
  # Regex-looking pattern; '.*' matches any string.
  filter: '.*'

  # Settings grouped under `jax`: platform, JIT toggle, precision, debugging
  # switches and device index lists.
  jax:
    platform: gpu
    jit: True
    precision: float16
    prealloc: True
    debug_nans: False
    logical_cpus: 0
    debug: False
    # Device index lists; both default to the single index 0.
    policy_devices: [0]
    train_devices: [0]
    metrics_every: 10

  # Settings grouped under `run`: the script to launch, its step budget, and
  # the logging / saving / evaluation cadence.
  # NOTE(review): the units of the `*_every` values are not visible in this
  # file — confirm against the run script.
  run:
    script: train
    steps: 1e10
    expl_until: 0
    log_every: 300
    save_every: 900
    eval_every: 1e6
    eval_initial: True
    eval_eps: 1
    eval_samples: 1
    train_ratio: 32.0
    train_fill: 0
    eval_fill: 0
    log_zeros: False
    # Selection of logged keys per aggregation. The string values look like
    # regex patterns; '^$' matches only the empty string.
    # NOTE(review): presumably '^$' is used to select no keys — confirm.
    log_keys_video: [image]
    log_keys_sum: '^$'
    log_keys_mean: '(log_entropy)'
    log_keys_max: '^$'
    from_checkpoint: ''
    sync_every: 10
    # Actor address. The IPC endpoint below is the active value; the TCP
    # endpoint is kept commented out as an alternative:
    # actor_addr: 'tcp://127.0.0.1:5551'
    actor_addr: 'ipc:///tmp/5551'
    actor_batch: 32

  # Settings grouped under `envs`.
  envs:
    amount: 4
    parallel: process
    length: 0
    reset: True
    restart: True
    discretize: 0
    checks: False

  # Settings grouped under `wrapper`; the key names overlap with `envs`.
  wrapper:
    length: 0
    reset: True
    discretize: 0
    checks: False

  # Per-environment options, keyed by environment name.
  env:
    seaquest:
      size: [64, 64]
      repeat: 4
      sticky: True
      gray: False
      actions: all
      lives: unused
      noops: 30
      resize: opencv
      property_file: "property_1"
      cost_val: 1.0
      reset_on_death: False

  # Agent
  # Behavior selection and batch dimensions.
  task_behavior: Greedy
  # NOTE(review): `None` is a plain YAML string here, not null — presumably the
  # consumer compares against the string 'None'; confirm.
  expl_behavior: None
  batch_size: 16
  batch_length: 64
  data_loaders: 8

  # World Model
  # One mapping per component, written one key per line.
  grad_heads: [decoder, reward, cont, cost]

  rssm:
    deter: 4096
    units: 1024
    stoch: 32
    classes: 32
    act: silu
    norm: layer
    initial: learned
    unimix: 0.01
    unroll: False
    action_clip: 1.0
    winit: normal
    fan: avg

  # `mlp_keys` / `cnn_keys` are regex-looking patterns; '.*' matches any string.
  encoder:
    mlp_keys: '.*'
    cnn_keys: '.*'
    act: silu
    norm: layer
    mlp_layers: 5
    mlp_units: 1024
    cnn: resnet
    cnn_depth: 96
    cnn_blocks: 0
    resize: stride
    winit: normal
    fan: avg
    symlog_inputs: True
    minres: 4

  decoder:
    mlp_keys: '.*'
    cnn_keys: '.*'
    act: silu
    norm: layer
    mlp_layers: 5
    mlp_units: 1024
    cnn: resnet
    cnn_depth: 96
    cnn_blocks: 0
    image_dist: mse
    vector_dist: symlog_mse
    inputs: [deter, stoch]
    resize: stride
    winit: normal
    fan: avg
    outscale: 1.0
    minres: 4
    cnn_sigmoid: False

  reward_head:
    layers: 5
    units: 1024
    act: silu
    norm: layer
    dist: symlog_disc
    outscale: 0.0
    outnorm: False
    inputs: [deter, stoch]
    winit: normal
    fan: avg
    bins: 255

  cont_head:
    layers: 5
    units: 1024
    act: silu
    norm: layer
    dist: binary
    outscale: 1.0
    outnorm: False
    inputs: [deter, stoch]
    winit: normal
    fan: avg

  # Same values as `cont_head`.
  cost_head:
    layers: 5
    units: 1024
    act: silu
    norm: layer
    dist: binary
    outscale: 1.0
    outnorm: False
    inputs: [deter, stoch]
    winit: normal
    fan: avg

  # One scale factor per named loss term.
  loss_scales:
    image: 1.0
    vector: 1.0
    reward: 1.0
    cont: 1.0
    cost: 1.0
    dyn: 0.5
    rep: 0.1
    actor: 1.0
    critic: 1.0
    slowreg: 1.0

  dyn_loss:
    impl: kl
    free: 1.0

  rep_loss:
    impl: kl
    free: 1.0

  model_opt:
    opt: adam
    lr: 1e-4
    eps: 1e-8
    clip: 1000.0
    wd: 0.0
    warmup: 0
    lateclip: 0.0

  # Actor Critic
  # Network and optimizer mappings are written one key per line; the scalar
  # settings follow.
  actor:
    layers: 5
    units: 1024
    act: silu
    norm: layer
    minstd: 0.1
    maxstd: 1.0
    outscale: 1.0
    outnorm: False
    unimix: 0.01
    inputs: [deter, stoch]
    winit: normal
    fan: avg
    symlog_inputs: False

  critic:
    layers: 5
    units: 1024
    act: silu
    norm: layer
    dist: symlog_disc
    outscale: 0.0
    outnorm: False
    inputs: [deter, stoch]
    winit: normal
    fan: avg
    bins: 255
    symlog_inputs: False

  actor_opt:
    opt: adam
    lr: 3e-5
    eps: 1e-5
    clip: 100.0
    wd: 0.0
    warmup: 0
    lateclip: 0.0

  # Same values as `actor_opt`.
  critic_opt:
    opt: adam
    lr: 3e-5
    eps: 1e-5
    clip: 100.0
    wd: 0.0
    warmup: 0
    lateclip: 0.0

  # Separate settings for the `_disc` and `_cont` variants.
  actor_dist_disc: onehot
  actor_dist_cont: normal
  actor_grad_disc: reinforce
  actor_grad_cont: backprop

  critic_type: vfunction
  imag_horizon: 15
  imag_unroll: False
  horizon: 333
  return_lambda: 0.95
  critic_slowreg: logprob
  critic_cont_fn: cont
  slow_critic_update: 1
  slow_critic_fraction: 0.02

  retnorm:
    impl: perc_ema
    decay: 0.99
    max: 1.0
    perclo: 5.0
    perchi: 95.0

  actent: 3e-4

  # Exploration
  expl_rewards:
    extr: 1.0
    disag: 0.1

  # Unlike the other `*_opt` mappings in `defaults`, this one has no
  # `lateclip` entry.
  expl_opt:
    opt: adam
    lr: 1e-4
    eps: 1e-5
    clip: 100.0
    wd: 0.0
    warmup: 0

  disag_head:
    layers: 5
    units: 1024
    act: silu
    norm: layer
    dist: mse
    outscale: 1.0
    inputs: [deter, stoch, action]
    winit: normal
    fan: avg

  disag_target: [stoch]
  disag_models: 8

  # Constrained Policy Optimisation
  # Scalar coefficients for the constrained objective.
  # NOTE(review): how the multipliers, the power term and the threshold combine
  # is not visible in this file — confirm against the agent code.
  penalty_mult: 5e-9
  lagrange_mult: 0.01
  penalty_pow: 1e-6
  cost_threshold: 1.0

  # Safety Critic
  # The settings below carry the same values as the corresponding `critic`
  # settings in `defaults`.
  bootstrap_cost: True

  safety_critic:
    layers: 5
    units: 1024
    act: silu
    norm: layer
    dist: symlog_disc
    outscale: 0.0
    outnorm: False
    inputs: [deter, stoch]
    winit: normal
    fan: avg
    bins: 255
    symlog_inputs: False

  safety_critic_opt:
    opt: adam
    lr: 3e-5
    eps: 1e-5
    clip: 100.0
    wd: 0.0
    warmup: 0
    lateclip: 0.0

  safety_critic_slowreg: logprob
  safety_critic_type: 'vfunction'
  safety_critic_cont_fn: cont
  slow_safety_critic_update: 1
  slow_safety_critic_fraction: 0.02
  safety_horizon: 333

# Settings for the `seaquest` task.
# NOTE(review): presumably applied as overrides on top of `defaults` — confirm.
seaquest:

  task: seaquest
  # Dotted key naming the `amount` entry of `envs`.
  envs.amount: 8
  wrapper:
    length: 27000
  run:
    steps: 5.5e7
    eval_eps: 10
    train_ratio: 64
  # '$^' is a pattern that matches only the empty string; 'image' names the
  # image key.
  encoder:
    mlp_keys: '$^'
    cnn_keys: 'image'
  decoder:
    mlp_keys: '$^'
    cnn_keys: 'image'

# Reduced-size settings: smaller batches, replay size and network widths than
# the values in `defaults`.
colab:

  batch_size: 8
  batch_length: 12
  replay_size: 1e5
  encoder.cnn_depth: 8
  decoder.cnn_depth: 8
  rssm:
    deter: 32
    units: 16
    stoch: 4
    classes: 4
  # Keys beginning with `.*` look like regex patterns.
  # NOTE(review): presumably matched against dotted config key paths — confirm
  # against the config loader.
  .*unroll: False
  .*\.layers: 2
  .*\.units: 16
  .*\.wd$: 0.0

# Same reduced-size values as the `colab` block, plus a CPU `jax` mapping with
# preallocation off and debug on.
lenovo:
  jax:
    jit: True
    prealloc: False
    debug: True
    platform: cpu
  batch_size: 8
  batch_length: 12
  replay_size: 1e5
  encoder.cnn_depth: 8
  decoder.cnn_depth: 8
  rssm:
    deter: 32
    units: 16
    stoch: 4
    classes: 4
  # Keys beginning with `.*` look like regex patterns.
  # NOTE(review): presumably matched against dotted config key paths — confirm
  # against the config loader.
  .*unroll: False
  .*\.layers: 2
  .*\.units: 16
  .*\.wd$: 0.0

# Size preset: smallest of the small / medium / large / xlarge series.
# NOTE(review): keys beginning with `.*` look like regex patterns, presumably
# matched against dotted config key paths — confirm against the config loader.
small:
  rssm.deter: 512
  .*\.cnn_depth: 32
  .*\.units: 512
  .*\.layers: 2

# Size preset: second of the small / medium / large / xlarge series.
# NOTE(review): keys beginning with `.*` look like regex patterns, presumably
# matched against dotted config key paths — confirm against the config loader.
medium:
  rssm.deter: 1024
  .*\.cnn_depth: 48
  .*\.units: 640
  .*\.layers: 3

# Size preset: third of the small / medium / large / xlarge series.
# NOTE(review): keys beginning with `.*` look like regex patterns, presumably
# matched against dotted config key paths — confirm against the config loader.
large:
  rssm.deter: 2048
  .*\.cnn_depth: 64
  .*\.units: 768
  .*\.layers: 4

# Size preset: largest of the small / medium / large / xlarge series. The four
# values below equal the corresponding values written in `defaults`.
# NOTE(review): keys beginning with `.*` look like regex patterns, presumably
# matched against dotted config key paths — confirm against the config loader.
xlarge:
  rssm.deter: 4096
  .*\.cnn_depth: 96
  .*\.units: 1024
  .*\.layers: 5

# Settings using 8 logical CPUs, with device indices 0-1 listed for the policy
# and 2-7 listed for training.
multicpu:

  jax:
    logical_cpus: 8
    # The two lists are disjoint and together cover indices 0 through 7.
    policy_devices: [0, 1]
    train_devices: [2, 3, 4, 5, 6, 7]
  run:
    actor_batch: 4
  envs:
    amount: 8
  # NOTE(review): 12 is divisible by the six train devices above — presumably
  # intentional; confirm.
  batch_size: 12
  batch_length: 10

# Debugging settings: CPU platform, small networks and batches, and short
# logging / saving / evaluation values.
debug:

  jax:
    jit: True
    prealloc: False
    debug: True
    platform: cpu
  envs:
    restart: False
    amount: 3
  wrapper:
    length: 100
    checks: True
  run:
    eval_every: 1000
    log_every: 5
    save_every: 10
    train_ratio: 32
    actor_batch: 2
  batch_size: 8
  batch_length: 12
  replay_size: 1e5
  encoder.cnn_depth: 8
  decoder.cnn_depth: 8
  rssm:
    deter: 32
    units: 16
    stoch: 4
    classes: 4
  # Keys beginning with `.*` look like regex patterns.
  # NOTE(review): presumably matched against dotted config key paths — confirm
  # against the config loader.
  .*unroll: False
  .*\.layers: 2
  .*\.units: 16
  .*\.wd$: 0.0