# Base configuration. The other top-level sections in this file each list a
# subset of these keys with different values (presumably applied on top of
# these defaults by the config loader — confirm against the loader).
defaults:

  seed: 0
  method: name
  task: carla_0
  # NOTE(review): the default logdir names `dmc_back_walker_walk` while the
  # default task is `carla_0` — confirm this is intended.
  logdir: log/dmc_back_walker_walk
  replay: uniform
  replay_size: 1e6
  replay_online: False
  eval_dir: ''
  # '.*' matches every string.
  filter: '.*'

  # Options grouped under `jax`.
  jax:
    platform: gpu
    jit: True
    precision: float16
    prealloc: True
    debug_nans: False
    logical_cpus: 0
    debug: False
    policy_devices: [0]
    train_devices: [0]
    metrics_every: 10

  # Options grouped under `run`.
  run:
    script: train_eval
    steps: 300000
    expl_until: 0
    log_every: 300
    save_every: 900
    eval_every: 6000
    eval_initial: True
    eval_eps: 1
    eval_samples: 1
    train_ratio: 32.0
    train_fill: 0
    eval_fill: 0
    log_zeros: False
    log_keys_video: [image]
    # The pattern '^$' can only match an empty string (assuming these values
    # are used as regular expressions — confirm).
    log_keys_sum: '^$'
    log_keys_mean: '(log_entropy)'
    log_keys_max: '^$'
    from_checkpoint: ''
    sync_every: 10
    # actor_addr: 'tcp://127.0.0.1:5551'
    actor_addr: 'ipc:///tmp/5551'
    actor_batch: 32

  envs: {amount: 1, parallel: process, length: 0, reset: True, restart: True, discretize: 0, checks: False}
  wrapper: {length: 0, reset: True, discretize: 0, checks: False}
  # One entry per environment suite (presumably selected by the prefix of
  # `task`, e.g. `carla_0` -> `carla` — confirm against the env factory).
  env:
    atari: {size: [64, 64], repeat: 4, sticky: True, gray: False, actions: all, lives: unused, noops: 0, resize: opencv}
    dmlab: {size: [64, 64], repeat: 4, episodic: True}
    minecraft: {size: [64, 64], break_speed: 100.0}
    dmc: {size: [64, 64], repeat: 2, camera: -1}
    loconav: {size: [64, 64], repeat: 2, camera: -1}
    dmcback: {size: [64, 64], repeat: 2, camera: -1, eval_mode: 'video_easy'}
    carla: {size: [64, 64], repeat: 4}
    metaworld: {size: [64, 64], repeat: 1}
    myosuite: {size: [64, 64], repeat: 1}
    rmbench: {size: [64, 64], repeat: 2, camera: -1}
    humanoid: {obs_key: vector, policy_path: "", mean_path: "", var_path: "", policy_type: "", small_obs: "", is_eval: False, actuation: position, reward_dict: {hand_dist: 0.1, target_dist: 0.1, success: 10, terminate: False}}


  # Agent
  task_behavior: Greedy
  # NOTE(review): a plain `None` is parsed by YAML as the string "None", not
  # as null — presumably the consumer compares against that string; confirm.
  expl_behavior: None
  batch_size: 16
  batch_length: 32
  data_loaders: 8
  imag_inst_resample: True
  inst_prior: uniform
  multibit_inst: False
  first_inst: bernoulli
  load_replay: ''

  # World Model
  grad_heads: [decoder, reward, cont]
  rssm: {deter: 4096, units: 1024, stoch: 32, classes: 32, act: silu, norm: layer, initial: learned, unimix: 0.01, unroll: False, action_clip: 1.0, winit: normal, fan: avg}
  # `mlp_keys` / `cnn_keys` of '.*' match every key name; the task sections
  # below narrow them.
  encoder: {mlp_keys: '.*', cnn_keys: '.*', act: silu, norm: layer, mlp_layers: 5, mlp_units: 1024, cnn: resnet, cnn_depth: 96, cnn_blocks: 0, resize: stride, winit: normal, fan: avg, symlog_inputs: True, minres: 4}
  decoder: {mlp_keys: '.*', cnn_keys: '.*', act: silu, norm: layer, mlp_layers: 5, mlp_units: 1024, cnn: resnet, cnn_depth: 96, cnn_blocks: 0, image_dist: mse, vector_dist: symlog_mse, inputs: [deter, stoch], resize: stride, winit: normal, fan: avg, outscale: 1.0, minres: 4, cnn_sigmoid: False}
  reward_head: {layers: 5, units: 1024, act: silu, norm: layer, dist: symlog_disc, outscale: 0.0, outnorm: False, inputs: [deter, stoch], winit: normal, fan: avg, bins: 255}
  cont_head: {layers: 5, units: 1024, act: silu, norm: layer, dist: binary, outscale: 1.0, outnorm: False, inputs: [deter, stoch], winit: normal, fan: avg}

  loss_scales: {image: 1.0, state: 1.0, vector: 1.0, reward: 1.0, cont: 1.0, dyn: 0.5, rep: 0.1, actor: 1.0, critic: 1.0, slowreg: 1.0, contrastive: {'min_kl': 0.1, 'max_kl': 0.1}}
  dyn_loss: {impl: kl, free: 1.0}
  rep_loss: {impl: kl, free: 1.0}
  model_opt: {opt: adam, lr: 1e-4, eps: 1e-8, clip: 1000.0, wd: 0.0, warmup: 0, lateclip: 0.0}

  # Actor Critic
  # The actor is the only head here whose `inputs` include `inst`.
  actor: {layers: 5, units: 1024, act: silu, norm: layer, minstd: 0.1, maxstd: 1.0, outscale: 1.0, outnorm: False, unimix: 0.01, inputs: [deter, stoch, inst], winit: normal, fan: avg, symlog_inputs: False}
  # `inst_head` and `multibit_inst_head` differ only in `dist` (onehot vs binary).
  inst_head: {layers: 5, units: 1024, act: silu, norm: layer, dist: onehot, outscale: 1.0, outnorm: False, unimix: 0.01, inputs: [deter, stoch], winit: normal, fan: avg, symlog_inputs: False}
  multibit_inst_head: {layers: 5, units: 1024, act: silu, norm: layer, dist: binary, outscale: 1.0, outnorm: False, unimix: 0.01, inputs: [deter, stoch], winit: normal, fan: avg, symlog_inputs: False}
  critic: {layers: 5, units: 1024, act: silu, norm: layer, dist: symlog_disc, outscale: 0.0, outnorm: False, inputs: [deter, stoch], winit: normal, fan: avg, bins: 255, symlog_inputs: False}
  actor_opt: {opt: adam, lr: 3e-5, eps: 1e-5, clip: 100.0, wd: 0.0, warmup: 0, lateclip: 0.0}
  critic_opt: {opt: adam, lr: 3e-5, eps: 1e-5, clip: 100.0, wd: 0.0, warmup: 0, lateclip: 0.0}
  actor_dist_disc: onehot
  actor_dist_cont: normal
  actor_grad_disc: reinforce
  actor_grad_cont: backprop
  critic_type: vfunction
  imag_horizon: 15
  imag_unroll: False
  horizon: 333
  return_lambda: 0.95
  critic_slowreg: logprob
  slow_critic_update: 1
  slow_critic_fraction: 0.02
  retnorm: {impl: perc_ema, decay: 0.99, max: 1.0, perclo: 5.0, perchi: 95.0}
  actent: 3e-4
  task_planning: instnet
  expl_planning: instnet
  sg_kl: True
  task_plannum: 16
  expl_plannum: 64
  imag_per_start: 32
  instnet_sparsity: False

  # Action Prior
  # action_prior: False
  # prio_act: {leg: [3,4,5,8,9,16,17,27,28,29,32,33,40,41]}
  # prio_key: leg
  # prior_until: 10000000

  # SMCP
  # The `*_SMCP` sections below set these same keys.
  SMCP_inst_resample: False
  SMCP_weight_resample: True
  SMCP_horizon: 3
  calc_logpi: True
  calc_vs: True

  # Exploration
  expl_rewards: {extr: 1.0, disag: 0.1}
  expl_opt: {opt: adam, lr: 1e-4, eps: 1e-5, clip: 100.0, wd: 0.0, warmup: 0}
  disag_head: {layers: 5, units: 1024, act: silu, norm: layer, dist: mse, outscale: 1.0, inputs: [deter, stoch, action], winit: normal, fan: avg}
  disag_target: [stoch]
  disag_models: 8

minecraft:
  # Section for the `minecraft_diamond` task.
  task: minecraft_diamond
  envs.amount: 16
  run:
    script: train_save
    eval_fill: 1e5
    train_ratio: 16
    log_keys_max: '^log_inventory.*'
  # Both networks use `image` as the cnn key pattern; the encoder's mlp key
  # pattern additionally lists `reward`.
  encoder:
    mlp_keys: 'inventory|inventory_max|equipped|health|hunger|breath|reward'
    cnn_keys: image
  decoder:
    mlp_keys: 'inventory|inventory_max|equipped|health|hunger|breath'
    cnn_keys: image

dmlab:
  # Section for the `dmlab_explore_goal_locations_small` task.
  task: dmlab_explore_goal_locations_small
  envs.amount: 8
  # The mlp key pattern '$^' can only match an empty string; `image` is the
  # cnn key pattern.
  encoder:
    mlp_keys: '$^'
    cnn_keys: image
  decoder:
    mlp_keys: '$^'
    cnn_keys: image
  run.train_ratio: 64


carla:
  # Section for the `carla_0` task.
  task: carla_0
  run:
    steps: 5e5
    eval_every: 5e3  # 1e4
    log_every: 5e3  # 1e4
    save_every: 5e4
  # Lists `decoder` and `reward` only; `defaults` additionally lists `cont`.
  grad_heads:
    - decoder
    - reward
  wrapper:
    length: 1000
  batch_size: 16
  batch_length: 64
  encoder.cnn_depth: 32
  decoder.cnn_depth: 32
  # Commented-out settings:
  # rssm: {deter: 32, units: 16, stoch: 4, classes: 4}
  # prefill: 2500
  # dataset_size: 0
  # pretrain: 100
  # action_step: 25
  # free_step: 50
  # use_free: True
  # rollout_policy: True
  # autoencoder: False

gym:
  # Section for the `gym_humanoid-v4` task.
  task: gym_humanoid-v4
  run:
    steps: 5e6
    eval_every: 2e4  # 1e4
    log_every: 2e4  # 1e4
    save_every: 2e5
    train_ratio: 256
  # Model size. Keys containing regex metacharacters are presumably matched
  # as patterns against the flattened config keys — confirm against the loader.
  rssm.deter: 512
  '.*\.cnn_depth': 32
  '.*\.layers': 2
  '.*\.units': 512
  # The mlp key pattern '$^' can only match an empty string.
  encoder:
    mlp_keys: '$^'
    cnn_keys: image
  decoder:
    mlp_keys: '$^'
    cnn_keys: image
  wrapper.length: 1000

atari:
  # Section for the `atari_pong` task.
  task: atari_pong
  envs.amount: 8
  run:
    steps: 5.5e7
    eval_eps: 10
    train_ratio: 64
  # The mlp key pattern '$^' can only match an empty string.
  encoder:
    mlp_keys: '$^'
    cnn_keys: image
  decoder:
    mlp_keys: '$^'
    cnn_keys: image

atari100k:
  # Section for `atari_pong` with a 1.5e5-step budget and float32 precision.
  task: atari_pong
  envs: {amount: 1}
  env.atari: {gray: False, repeat: 4, sticky: False, noops: 30, actions: needed}
  run:
    script: train_eval
    steps: 1.5e5
    eval_every: 1e5
    eval_initial: False
    eval_eps: 100
    train_ratio: 1024
  jax.precision: float32
  rssm.deter: 512
  .*\.cnn_depth: 32
  .*\.layers: 2
  # Unlike the other sections, this pattern is anchored with `$`.
  .*\.units$: 512
  # `actor_eval_sample` is not declared under `defaults`, so this override has
  # no base key to apply to; a loader that rejects unknown keys aborts on it.
  # Disabled until the key is added to `defaults` (or confirmed unused).
  # actor_eval_sample: True
  # The mlp key pattern '$^' can only match an empty string.
  encoder: {mlp_keys: '$^', cnn_keys: 'image'}
  decoder: {mlp_keys: '$^', cnn_keys: 'image'}

crafter:
  # Section for the `crafter_reward` task.
  task: crafter_reward
  envs.amount: 1
  run:
    log_keys_max: '^log_achievement_.*'
    log_keys_sum: '^log_reward$'
  run.train_ratio: 512
  # The mlp key pattern '$^' can only match an empty string.
  encoder:
    mlp_keys: '$^'
    cnn_keys: image
  decoder:
    mlp_keys: '$^'
    cnn_keys: image

dmc_vision:
  # Section for `dmc_walker_walk` with `image` as the cnn key pattern.
  run:
    steps: 1e7
    eval_every: 5e4  # 1e4
    log_every: 5e4  # 1e4
    save_every: 5e5
  task: dmc_walker_walk
  run.train_ratio: 256
  # Model size. Keys containing regex metacharacters are presumably matched
  # as patterns against the flattened config keys — confirm against the loader.
  rssm.deter: 512
  '.*\.cnn_depth': 32
  '.*\.layers': 2
  '.*\.units': 512
  # The mlp key pattern '$^' can only match an empty string.
  encoder:
    mlp_keys: '$^'
    cnn_keys: image
  decoder:
    mlp_keys: '$^'
    cnn_keys: image

dmc_vision_SMCP:
  # Same task, intervals and model-size keys as `dmc_vision`, plus
  # `run.expl_until` and the planning / SMCP keys at the end.
  run:
    steps: 1e7
    eval_every: 5e4  # 1e4
    log_every: 5e4  # 1e4
    save_every: 5e5
    expl_until: 1e7
  task: dmc_walker_walk
  run.train_ratio: 256
  rssm.deter: 512
  '.*\.cnn_depth': 32
  '.*\.layers': 2
  '.*\.units': 512
  # The mlp key pattern '$^' can only match an empty string.
  encoder:
    mlp_keys: '$^'
    cnn_keys: image
  decoder:
    mlp_keys: '$^'
    cnn_keys: image
  # Planning / SMCP keys. `expl_planning`, `SMCP_inst_resample`,
  # `SMCP_horizon` and `calc_vs` carry the same values as `defaults`.
  expl_planning: instnet
  task_planning: SMCP_instnet
  inst_prior: instnet
  expl_plannum: 16
  task_plannum: 64
  multibit_inst: True
  first_inst: instnet
  imag_per_start: 21
  SMCP_inst_resample: False
  SMCP_weight_resample: False
  SMCP_horizon: 3
  calc_logpi: False
  calc_vs: True

metaworld:
  # Section for the `metaworld_button_press` task.
  task: metaworld_button_press
  run.train_ratio: 512
  # Model size. Keys containing regex metacharacters are presumably matched
  # as patterns against the flattened config keys — confirm against the loader.
  rssm.deter: 512
  '.*\.cnn_depth': 32
  '.*\.layers': 2
  '.*\.units': 512
  # The mlp key pattern '$^' can only match an empty string.
  encoder:
    mlp_keys: '$^'
    cnn_keys: image
  decoder:
    mlp_keys: '$^'
    cnn_keys: image
  run:
    eval_eps: 10

myosuite:
  # Section for the `myo_key_turn` task.
  run:
    steps: 1e6
    eval_every: 5e3  # 1e4
    log_every: 5e3  # 1e4
    eval_eps: 10
    save_every: 5e4
  # envs: {amount: 1, parallel: none}
  task: myo_key_turn
  wrapper.length: 1000
  run.train_ratio: 512
  # Model size. Keys containing regex metacharacters are presumably matched
  # as patterns against the flattened config keys — confirm against the loader.
  rssm.deter: 512
  '.*\.cnn_depth': 32
  '.*\.layers': 2
  '.*\.units': 512
  # `state` is the mlp key pattern and `image` the cnn key pattern.
  encoder:
    mlp_keys: state
    cnn_keys: image
  decoder:
    mlp_keys: state
    cnn_keys: image

myosuite_SMCP:
  # Same task and model-size keys as `myosuite`, with 1e4 eval/log intervals,
  # `run.expl_until`, and the planning / SMCP keys at the end.
  run:
    steps: 1e6
    eval_every: 1e4  # 1e4
    log_every: 1e4  # 1e4
    eval_eps: 10
    save_every: 5e4
    expl_until: 1e7
  task: myo_key_turn
  wrapper.length: 1000
  run.train_ratio: 512
  rssm.deter: 512
  '.*\.cnn_depth': 32
  '.*\.layers': 2
  '.*\.units': 512
  # `state` is the mlp key pattern and `image` the cnn key pattern.
  encoder:
    mlp_keys: state
    cnn_keys: image
  decoder:
    mlp_keys: state
    cnn_keys: image
  # Planning / SMCP keys. `expl_planning`, `SMCP_inst_resample`,
  # `SMCP_horizon` and `calc_vs` carry the same values as `defaults`.
  expl_planning: instnet
  task_planning: SMCP_instnet
  inst_prior: instnet
  expl_plannum: 16
  task_plannum: 64
  multibit_inst: True
  first_inst: instnet
  imag_per_start: 21
  SMCP_inst_resample: False
  SMCP_weight_resample: False
  SMCP_horizon: 3
  calc_logpi: False
  calc_vs: True

dmc_back:
  # Section for the `dmcback_walker_walk` task.
  task: dmcback_walker_walk
  run.script: train_eval
  run.train_ratio: 512
  # Model size. Keys containing regex metacharacters are presumably matched
  # as patterns against the flattened config keys — confirm against the loader.
  rssm.deter: 512
  '.*\.cnn_depth': 32
  '.*\.layers': 2
  '.*\.units': 512
  # NOTE(review): the encoder's mlp key pattern '$^' can only match an empty
  # string, while the decoder's is `state`. Every other section in this file
  # gives encoder and decoder the same cnn/mlp patterns, except `minecraft`
  # (whose encoder pattern adds `reward`) — confirm the asymmetry is intended.
  encoder:
    mlp_keys: '$^'
    cnn_keys: image
  decoder:
    mlp_keys: state
    cnn_keys: image

dmc_proprio:
  # Section for `dmc_walker_walk` without cnn keys.
  task: dmc_walker_walk
  run.train_ratio: 512
  # Model size. Keys containing regex metacharacters are presumably matched
  # as patterns against the flattened config keys — confirm against the loader.
  rssm.deter: 512
  '.*\.cnn_depth': 32
  '.*\.layers': 2
  '.*\.units': 512
  # '.*' matches every mlp key name; the cnn key pattern '$^' can only match
  # an empty string.
  encoder:
    mlp_keys: '.*'
    cnn_keys: '$^'
  decoder:
    mlp_keys: '.*'
    cnn_keys: '$^'

bsuite:
  # Section for the `bsuite_mnist/0` task.
  task: bsuite_mnist/0
  envs:
    amount: 1
    parallel: none
  run:
    script: train
    train_ratio: 1024  # 128 for cartpole
  # Model size. Keys containing regex metacharacters are presumably matched
  # as patterns against the flattened config keys — confirm against the loader.
  rssm.deter: 512
  '.*\.cnn_depth': 32
  '.*\.layers': 2
  '.*\.units': 512

loconav:
  # Section for the `loconav_ant_maze_m` task.
  task: loconav_ant_maze_m
  env.loconav.repeat: 2
  run:
    train_ratio: 512
    log_keys_max: '^log_.*'
  # '.*' matches every mlp key name; `image` is the cnn key pattern.
  encoder:
    mlp_keys: '.*'
    cnn_keys: image
  decoder:
    mlp_keys: '.*'
    cnn_keys: image

humanoid_vision:
  # Section for `humanoid_Pushing-v0` with `image` as the cnn key pattern.
  task: humanoid_Pushing-v0
  run:
    steps: 1e7
    log_every: 2e4  # Seconds
    eval_every: 2e4  # Steps
    train_ratio: 512
    save_every: 5e5
  # Model size. Keys containing regex metacharacters are presumably matched
  # as patterns against the flattened config keys — confirm against the loader.
  rssm.deter: 512
  '.*\.cnn_depth': 32
  '.*\.layers': 2
  '.*\.units': 512
  # The mlp key pattern '$^' can only match an empty string.
  encoder:
    mlp_keys: '$^'
    cnn_keys: image
  decoder:
    mlp_keys: '$^'
    cnn_keys: image

humanoid_vision_SMCP:
  # Same task and model-size keys as `humanoid_vision`, with 5e4 log/eval
  # intervals, `run.expl_until`, and the planning / SMCP keys at the end.
  task: humanoid_Pushing-v0
  run:
    steps: 1e7
    log_every: 5e4  # Seconds
    eval_every: 5e4  # Steps
    train_ratio: 512
    save_every: 5e5
    expl_until: 1e7
  rssm.deter: 512
  '.*\.cnn_depth': 32
  '.*\.layers': 2
  '.*\.units': 512
  # The mlp key pattern '$^' can only match an empty string.
  encoder:
    mlp_keys: '$^'
    cnn_keys: image
  decoder:
    mlp_keys: '$^'
    cnn_keys: image
  # Planning / SMCP keys. `expl_planning`, `SMCP_inst_resample`,
  # `SMCP_horizon` and `calc_vs` carry the same values as `defaults`.
  expl_planning: instnet
  task_planning: SMCP_instnet
  inst_prior: instnet
  expl_plannum: 16
  task_plannum: 64
  multibit_inst: True
  first_inst: instnet
  imag_per_start: 21
  SMCP_inst_resample: False
  SMCP_weight_resample: False
  SMCP_horizon: 3
  calc_logpi: False
  calc_vs: True


humanoid_benchmark:
  # Section for `humanoid_Pushing-v0` without cnn keys.
  task: humanoid_Pushing-v0
  # Model size; same values as the `medium` section.
  rssm.deter: 1024
  .*\.cnn_depth: 48
  .*\.units: 640
  .*\.layers: 3
  run:
    # `steps` is declared as `run.steps` in `defaults` (and set under `run` in
    # `humanoid_SMCP`); a top-level `steps` key does not exist in `defaults`,
    # so it is set here rather than at the section's top level.
    steps: 1e7
    script: train_eval
    train_ratio: 64
    log_every: 2e4  # Seconds
    eval_every: 2e4  # Steps
    save_every: 5e5
  # '.*' matches every mlp key name; the cnn key pattern '$^' can only match
  # an empty string.
  encoder: {mlp_keys: '.*', cnn_keys: '$^'}
  decoder: {mlp_keys: '.*', cnn_keys: '$^'}

humanoid_SMCP:
  # Section for `humanoid_Pushing-v0` without cnn keys, with
  # `run.expl_until` and the planning / SMCP keys at the end.
  task: humanoid_Pushing-v0
  # Model size; same values as the `medium` section.
  rssm.deter: 1024
  '.*\.cnn_depth': 48
  '.*\.units': 640
  '.*\.layers': 3
  run:
    steps: 1e7
    script: train_eval
    train_ratio: 64
    log_every: 2e4  # Seconds
    eval_every: 2e4  # Steps
    save_every: 5e5
    expl_until: 1e7
  # '.*' matches every mlp key name; the cnn key pattern '$^' can only match
  # an empty string.
  encoder:
    mlp_keys: '.*'
    cnn_keys: '$^'
  decoder:
    mlp_keys: '.*'
    cnn_keys: '$^'
  # Planning / SMCP keys. `expl_planning`, `SMCP_inst_resample`,
  # `SMCP_horizon` and `calc_vs` carry the same values as `defaults`.
  expl_planning: instnet
  task_planning: SMCP_instnet
  inst_prior: instnet
  expl_plannum: 16
  task_plannum: 64
  multibit_inst: True
  first_inst: instnet
  imag_per_start: 21
  SMCP_inst_resample: False
  SMCP_weight_resample: False
  SMCP_horizon: 3
  calc_logpi: False
  calc_vs: True

small:
  # Size preset: sets `rssm.deter` plus pattern keys for cnn_depth, units and
  # layers (patterns presumably matched against flattened config keys — confirm).
  rssm.deter: 512
  '.*\.cnn_depth': 32
  '.*\.units': 512
  '.*\.layers': 2

medium:
  # Size preset: sets `rssm.deter` plus pattern keys for cnn_depth, units and
  # layers (patterns presumably matched against flattened config keys — confirm).
  rssm.deter: 1024
  '.*\.cnn_depth': 48
  '.*\.units': 640
  '.*\.layers': 3

large:
  # Size preset: sets `rssm.deter` plus pattern keys for cnn_depth, units and
  # layers (patterns presumably matched against flattened config keys — confirm).
  rssm.deter: 2048
  '.*\.cnn_depth': 64
  '.*\.units': 768
  '.*\.layers': 4

xlarge:
  # Size preset: sets `rssm.deter` plus pattern keys for cnn_depth, units and
  # layers. These are the same numbers that `defaults` uses for `rssm.deter`,
  # `cnn_depth`, `units` and `layers`.
  rssm.deter: 4096
  '.*\.cnn_depth': 96
  '.*\.units': 1024
  '.*\.layers': 5

multicpu:
  # Sets 8 logical CPUs and assigns device indices 0-1 to `policy_devices`
  # and 2-7 to `train_devices`.
  jax:
    logical_cpus: 8
    policy_devices:
      - 0
      - 1
    train_devices:
      - 2
      - 3
      - 4
      - 5
      - 6
      - 7
  run:
    actor_batch: 4
  envs:
    amount: 8
  batch_size: 12
  batch_length: 10

debug:
  # Small-scale settings: cpu platform, tiny batch and network sizes, short
  # log/save/eval intervals.
  jax:
    jit: True
    prealloc: False
    debug: True
    platform: cpu
  envs:
    restart: False
    amount: 1
  wrapper:
    length: 100
    checks: True
  run:
    eval_every: 1000
    log_every: 5
    save_every: 10
    train_ratio: 32
    actor_batch: 2
  batch_size: 2
  batch_length: 2
  replay_size: 1e5
  encoder.cnn_depth: 8
  decoder.cnn_depth: 8
  rssm:
    deter: 32
    units: 16
    stoch: 4
    classes: 4
  # Pattern keys (presumably matched as regular expressions against the
  # flattened config keys — confirm against the loader).
  '.*unroll': False
  '.*\.layers': 2
  '.*\.units': 16
  '.*\.wd$': 0.0
