defaults:

  # Experiment-wide scalar settings.
  # NOTE(review): the other top-level mappings in this file (latent, atari,
  # debug, ...) reuse these key paths, presumably as presets that a launcher
  # applies on top of `defaults` -- confirm against the config loader.
  # NOTE(review): `{timestamp}` looks like a placeholder filled in at launch
  # time -- confirm.
  logdir: ~/logdir/{timestamp}
  replica: 0
  replicas: 1
  method: name
  # Presets in this file set tasks such as atari_pong or dmc_walker_walk.
  task: dummy_disc
  seed: 0
  # Presets in this file switch this to train_eval or delayedtrain_delayedeval.
  script: train
  # The `multicpu` and `debug` presets override batch_size; `debug` also
  # overrides batch_length and report_length.
  batch_size: 16
  batch_length: 64
  report_length: 32
  consec_train: 1
  consec_report: 1
  replay_context: 1
  random_agent: False
  # Both clock settings default to the empty string.
  clock_addr: ''
  clock_port: ''
  ipv6: False
  errfile: False

  # Observation-delay settings. The presets `latent`, `extended_obs`,
  # `latest_obs` and `wait` in this file set `delayed: True` and pick a
  # `method`.
  delay:
    delayed: False
    obs_delay_step: True
    # One of: 'latent', 'extended_obs', 'latest_obs', 'wait'.
    method: 'latent'
    # NOTE(review): the key is spelled "privilaged" (sic). Renaming it would
    # also require changing whatever code reads this key.
    privilaged_decoder: False
    include_actions: True
    include_masks: False
    delayed_training: False
    # NOTE(review): presumably selects which `delay_<Name>_dist_kwargs`
    # mapping below is used -- confirm in the code that builds the
    # distribution.
    delay_distribution: Fixed
    maximum_delay: 5
    obs_window_size: 5
    delay_TruncatedGaussian_dist_kwargs:
      mu: 5
      scale: 1
      minimum: 0
      maximum: 5
    delay_Uniform_dist_kwargs:
      minimum: 0
      maximum: 5
    delay_UniformlyIncreasing_dist_kwargs:
      minimum: 0
      maximum: 5
    delay_UniformlyDecreasing_dist_kwargs:
      minimum: 0
      maximum: 5
    delay_Fixed_dist_kwargs:
      length: 5
    visualize_reconstruction: False

  # Logging settings. The `log_tb` preset in this file replaces `outputs`
  # with [jsonl, scope, tensorboard].
  logger:
    outputs: [jsonl, scope]
    # '|'-separated alternation of metric-name fragments.
    # NOTE(review): presumably used as a regex over metric keys -- confirm.
    filter: 'score|length|fps|ratio|train/loss/|train/rand/'
    timer: True
    fps: 15
    user: ''

  # Per-suite environment options. Presets in this file override single
  # entries via dotted paths (env.dmc.image, env.dmc.proprio,
  # env.metaworld.image, env.metaworld.proprio, env.loconav.repeat).
  # NOTE(review): the suite entry is presumably chosen from the prefix of
  # `task` (e.g. atari_pong -> atari) -- confirm in the env factory.
  env:
    atari:
      size: [96, 96]
      repeat: 4
      sticky: True
      gray: True
      actions: all
      lives: unused
      noops: 30
      autostart: False
      pooling: 2
      aggregate: max
      resize: pillow
      clip_reward: False
    procgen:
      size: [96, 96]
      resize: pillow
    crafter:
      size: [64, 64]
      logs: False
    atari100k:
      size: [64, 64]
      repeat: 4
      sticky: False
      gray: False
      actions: needed
      lives: unused
      noops: 30
      autostart: False
      resize: pillow
      clip_reward: False
    dmlab:
      size: [64, 64]
      repeat: 4
      episodic: True
      use_seed: True
    minecraft:
      size: [64, 64]
      break_speed: 100.0
      logs: False
      length: 36000
    dmc:
      size: [64, 64]
      repeat: 1
      proprio: True
      image: True
      camera: -1
    metaworld:
      size: [64, 64]
      repeat: 2
      proprio: True
      image: True
      camera: -1
      seed: 0
      length: 250
    loconav:
      size: [64, 64]
      repeat: 1
      camera: -1

    # Not tied to a single suite name; noise is off (0.0) by default.
    action_noise:
      noise_fraction: 0.0

  # Replay buffer settings. The `debug` preset in this file lowers
  # `replay.size` to 1e4.
  replay:
    size: 5e6
    online: True
    # Sampling fractions; only `uniform` is non-zero by default.
    fracs:
      uniform: 1.0
      priority: 0.0
      recency: 0.0
    prio:
      exponent: 0.8
      maxfrac: 0.5
      # NOTE(review): plain `inf` is a YAML string (the YAML float spelling
      # is `.inf`); presumably converted by the consumer -- confirm.
      initial: inf
      zero_on_sample: True
    priosignal: model
    recexp: 1.0
    chunksize: 1024
    save_replay: True

  # Run-loop settings. Most task presets in this file override `steps`,
  # `train_ratio` and sometimes `envs` / `eval_envs`.
  run:
    steps: 1e10
    duration: 0
    train_ratio: 32.0
    log_every: 120
    report_every: 300
    save_every: 900
    envs: 16
    eval_envs: 4
    eval_eps: 1
    report_batches: 1
    from_checkpoint: ''
    episode_timeout: 180
    # NOTE(review): `{auto}` looks like a placeholder resolved at launch
    # time -- confirm.
    actor_addr: 'localhost:{auto}'
    replay_addr: 'localhost:{auto}'
    logger_addr: 'localhost:{auto}'
    actor_batch: -1
    actor_threads: 1
    agent_process: False
    remote_replay: False
    remote_envs: False
    # Toggles for individual usage reporters.
    usage:
      psutil: True
      nvsmi: True
      gputil: False
      malloc: False
      gc: False
    # Defaults to True here, whereas `jax.debug` defaults to False.
    debug: True

  # JAX backend settings. In this file the `multicpu` preset overrides
  # mock_devices and both device lists, and the `debug` preset overrides
  # platform, debug and prealloc.
  jax:
    platform: cuda
    compute_dtype: bfloat16
    # NOTE(review): presumably lists of device indices -- confirm.
    policy_devices: [0]
    train_devices: [0]
    mock_devices: 0
    prealloc: True
    jit: True
    # Separate from `run.debug`, which defaults to True in this file.
    debug: False
    expect_devices: 0
    enable_policy: True
    coordinator_address: ''

  # Agent hyperparameters. The size* presets and the `debug` preset in this
  # file override entries here through pattern-style keys such as `.*\.units`.
  agent:
    # An earlier variant of loss_scales additionally carried `repval: 0.3`:
    # loss_scales: {rec: 1.0, rew: 1.0, con: 1.0, dyn: 1.0, rep: 0.1, policy: 1.0, value: 1.0, repval: 0.3}
    # NOTE(review): `repval_loss` below is True but no `repval` scale is
    # listed here -- confirm the agent does not look up loss_scales.repval.
    loss_scales:
      rec: 1.0
      rew: 1.0
      con: 1.0
      dyn: 1.0
      rep: 0.1
      policy: 1.0
      value: 1.0
    opt:
      lr: 4e-5
      agc: 0.3
      eps: 1e-20
      beta1: 0.9
      beta2: 0.999
      momentum: True
      wd: 0.0
      schedule: const
      warmup: 1000
      anneal: 0
    ac_grads: False
    # `typ` names the sibling mapping that holds the options (here `rssm`).
    dyn:
      typ: rssm
      rssm:
        deter: 8192
        hidden: 1024
        stoch: 32
        classes: 64
        act: silu
        norm: rms
        unimix: 0.01
        outscale: 1.0
        winit: trunc_normal_in
        imglayers: 2
        obslayers: 1
        dynlayers: 1
        absolute: False
        blocks: 8
        free_nats: 1.0
        particles: 1
    enc:
      typ: simple
      simple:
        depth: 64
        mults: [2, 3, 4, 4]
        layers: 3
        units: 1024
        act: silu
        norm: rms
        winit: trunc_normal_in
        symlog: True
        outer: False
        kernel: 5
        strided: False
        enc_key_prefix: ''
    dec:
      typ: simple
      simple:
        depth: 64
        mults: [2, 3, 4, 4]
        layers: 3
        units: 1024
        act: silu
        norm: rms
        outscale: 1.0
        winit: trunc_normal_in
        outer: False
        kernel: 5
        bspace: 8
        strided: False
        dec_key_prefix: ''
    rewhead:
      layers: 1
      units: 1024
      act: silu
      norm: rms
      output: symexp_twohot
      outscale: 0.0
      winit: trunc_normal_in
      bins: 255
    conhead:
      layers: 1
      units: 1024
      act: silu
      norm: rms
      output: binary
      outscale: 1.0
      winit: trunc_normal_in
    policy:
      layers: 3
      units: 1024
      act: silu
      norm: rms
      minstd: 0.1
      maxstd: 1.0
      outscale: 0.01
      unimix: 0.01
      winit: trunc_normal_in
    value:
      layers: 3
      units: 1024
      act: silu
      norm: rms
      output: symexp_twohot
      outscale: 0.0
      winit: trunc_normal_in
      bins: 255
    policy_dist_disc: categorical
    policy_dist_cont: bounded_normal
    imag_last: 0
    imag_length: 15
    horizon: 333
    contdisc: True
    imag_loss:
      slowtar: False
      lam: 0.95
      actent: 3e-4
      slowreg: 1.0
    repl_loss:
      slowtar: False
      lam: 0.95
      slowreg: 1.0
    slowvalue:
      rate: 0.02
      every: 1
    retnorm:
      impl: perc
      rate: 0.01
      limit: 1.0
      perclo: 5.0
      perchi: 95.0
      debias: False
    # `none` here is the plain string "none", not a YAML null.
    valnorm:
      impl: none
      rate: 0.01
      limit: 1e-8
    advnorm:
      impl: none
      rate: 0.01
      limit: 1e-8
    reward_grad: True
    repval_loss: True
    repval_grad: True
    report: True
    report_gradnorms: False

# Preset: turns delays on with the 'latent' method, sets
# delay.delayed_training, uses one train env and one eval env, and selects
# the delayedtrain_delayedeval script.
latent:
  delay.delayed: True
  # Same value as the default under `defaults.delay.method`.
  delay.method: 'latent'
  delay.delayed_training: True
  run.envs: 1
  run.eval_envs: 1
  script: delayedtrain_delayedeval

# Preset: turns delays on with the 'extended_obs' method and selects the
# train_eval script.
extended_obs:
  delay.delayed: True
  delay.method: 'extended_obs'
  # Both values below equal the defaults under `defaults.delay`.
  delay.include_actions: True
  delay.include_masks: False
  # Disabled alternative (decoder flag on, distinct enc/dec key prefixes):
  #   delay.privilaged_decoder: True
  #   agent.enc.simple.enc_key_prefix: 'encoder'
  #   agent.dec.simple.dec_key_prefix: 'decoder'
  # Active values; these also equal the entries under `defaults`.
  delay.privilaged_decoder: False
  agent.enc.simple.enc_key_prefix: ''
  agent.dec.simple.dec_key_prefix: ''
  script: train_eval
  
# Preset: turns delays on with the 'latest_obs' method and selects the
# train_eval script.
latest_obs:
  delay.delayed: True
  delay.method: 'latest_obs'
  script: train_eval

# Preset: turns delays on with the 'wait' method and selects the
# train_eval script.
wait:
  delay.delayed: True
  delay.method: 'wait'
  script: train_eval

# Model-size preset, also exposed as YAML anchor &size1m.
# NOTE(review): the keys look like regex patterns, presumably matched
# against flattened config paths ending in .rssm / .depth / .units -- confirm.
size1m: &size1m
  .*\.rssm:
    deter: 512
    hidden: 64
    classes: 4
  .*\.depth: 4
  .*\.units: 64

# Model-size preset, also exposed as YAML anchor &size7m.
# NOTE(review): the keys look like regex patterns, presumably matched
# against flattened config paths ending in .rssm / .depth / .units -- confirm.
size7m: &size7m
  .*\.rssm:
    deter: 512
    hidden: 64
    classes: 32
  .*\.depth: 4
  .*\.units: 512

# Model-size preset, also exposed as YAML anchor &size12m.
# NOTE(review): the keys look like regex patterns, presumably matched
# against flattened config paths ending in .rssm / .depth / .units -- confirm.
size12m: &size12m
  .*\.rssm:
    deter: 2048
    hidden: 256
    classes: 16
  .*\.depth: 16
  .*\.units: 256

# Model-size preset, also exposed as YAML anchor &size25m.
# NOTE(review): the keys look like regex patterns, presumably matched
# against flattened config paths ending in .rssm / .depth / .units -- confirm.
size25m: &size25m
  .*\.rssm:
    deter: 3072
    hidden: 384
    classes: 24
  .*\.depth: 24
  .*\.units: 384

# Model-size preset, also exposed as YAML anchor &size50m.
# NOTE(review): the keys look like regex patterns, presumably matched
# against flattened config paths ending in .rssm / .depth / .units -- confirm.
size50m: &size50m
  .*\.rssm:
    deter: 4096
    hidden: 512
    classes: 32
  .*\.depth: 32
  .*\.units: 512

# Model-size preset, also exposed as YAML anchor &size100m.
# NOTE(review): the keys look like regex patterns, presumably matched
# against flattened config paths ending in .rssm / .depth / .units -- confirm.
size100m: &size100m
  .*\.rssm:
    deter: 6144
    hidden: 768
    classes: 48
  .*\.depth: 48
  .*\.units: 768

# Model-size preset, also exposed as YAML anchor &size200m. Its values are
# the same as the rssm deter/hidden/classes, enc/dec depth and units entries
# under `defaults.agent` in this file.
# NOTE(review): the keys look like regex patterns, presumably matched
# against flattened config paths ending in .rssm / .depth / .units -- confirm.
size200m: &size200m
  .*\.rssm:
    deter: 8192
    hidden: 1024
    classes: 64
  .*\.depth: 64
  .*\.units: 1024

# Model-size preset, also exposed as YAML anchor &size400m.
# NOTE(review): the keys look like regex patterns, presumably matched
# against flattened config paths ending in .rssm / .depth / .units -- confirm.
size400m: &size400m
  .*\.rssm:
    deter: 12288
    hidden: 1536
    classes: 96
  .*\.depth: 96
  .*\.units: 1536

# Preset: selects the minecraft_diamond task; no other keys are overridden.
minecraft:
  task: minecraft_diamond

# Preset: DMLab task with a 2.6e7 step budget. train_ratio matches the
# numeric value of the default (32).
dmlab:
  task: dmlab_explore_goal_locations_small
  run:
    steps: 2.6e7
    train_ratio: 32

# Preset: Atari Pong with a 5.1e7 step budget. train_ratio matches the
# numeric value of the default (32).
atari:
  task: atari_pong
  run:
    steps: 5.1e7
    train_ratio: 32

# Preset: Procgen CoinRun with a 1.1e8 step budget and train_ratio 64.
procgen:
  task: procgen_coinrun
  run:
    steps: 1.1e8
    train_ratio: 64

# Preset: Atari100k Pong with a 1.1e5 step budget, a single environment and
# train_ratio 256.
atari100k:
  task: atari100k_pong
  run:
    steps: 1.1e5
    envs: 1
    train_ratio: 256

# Preset: Crafter (reward variant) with a 1.1e6 step budget, a single
# environment and train_ratio 512.
crafter:
  task: crafter_reward
  run:
    steps: 1.1e6
    envs: 1
    train_ratio: 512

# Preset: DMC walker_walk with env.dmc.image turned off (proprio stays at its
# default of True), a 1.1e6 step budget and train_ratio 1024.
dmc_proprio:
  task: dmc_walker_walk
  env.dmc.image: False
  run:
    steps: 1.1e6
    train_ratio: 1024

# Preset: gym Pendulum with a 1.1e6 step budget and train_ratio 1024.
# NOTE(review): `defaults.env` in this file has no `gym` entry -- confirm
# where options for gym tasks come from.
gym_proprio:
  task: gym_Pendulum
  # Disabled override (the same line is active in the dmc_proprio preset):
  # env.dmc.image: False
  run:
    steps: 1.1e6
    train_ratio: 1024

# Preset: DMC walker_walk with env.dmc.proprio turned off (image stays at its
# default of True), a 1.1e6 step budget and train_ratio 256.
dmc_vision:
  task: dmc_walker_walk
  env.dmc.proprio: False
  run:
    steps: 1.1e6
    train_ratio: 256
  
# Preset: Meta-World door-close task with one train env and one eval env.
# The trailing notes below record values the author tried or considered.
metaworld:
  task: metaworld_door-close-v2  # alternative noted: metaworld_box-close-v2
  run:
    train_ratio: 512  # TODO: left open by the author
    # Other values noted by the author: 1e5, 75000 (rebuttal), 5e5.
    # NOTE(review): author's note says the total is multiplied by
    # action_repeat -- confirm how steps are counted.
    steps: 25000
    # Other values noted by the author: 5000; also `eval_every: 1000`.
    # NOTE(review): author's note says the total is multiplied by
    # action_repeat -- confirm.
    report_every: 1000
    eval_eps: 10  # other value noted: 8
    # eval_initial: True  # author's note: does not matter much.
    envs: 1  # author's note: envs.amount: 1
    eval_envs: 1

# Preset: turns env.metaworld.proprio off; env.metaworld.image stays at its
# default of True. It does not set `task`.
metaworld_vision:
  env.metaworld.proprio: False

# Preset: turns env.metaworld.image off; env.metaworld.proprio stays at its
# default of True. It does not set `task`.
metaworld_proprio:
  env.metaworld.image: False

# Preset: the default logger outputs (jsonl, scope) plus tensorboard.
log_tb:
  logger.outputs: [jsonl, scope, tensorboard]

# Preset: bsuite mnist/0 with a single environment, save_every set to -1 and
# train_ratio 1024.
# NOTE(review): -1 presumably disables periodic saving -- confirm.
bsuite:
  task: bsuite_mnist/0
  run:
    envs: 1
    save_every: -1
    train_ratio: 1024

# Preset: LocoNav ant maze (m) with train_ratio 256.
loconav:
  task: loconav_ant_maze_m
  # Same value as the default under `defaults.env.loconav.repeat`.
  env.loconav.repeat: 1
  run.train_ratio: 256

# Preset: 8 mock devices, with entries 0-1 listed for the policy and entries
# 2-7 listed for training.
# NOTE(review): batch_size 12 is divisible by the 6 train devices, presumably
# on purpose -- confirm.
multicpu:
  batch_size: 12
  jax.mock_devices: 8
  jax.policy_devices: [0, 1]
  jax.train_devices: [2, 3, 4, 5, 6, 7]

# Preset: small and fast settings -- CPU platform, short batches, frequent
# logging/reporting/saving, a 1e4 replay size and tiny network sizes.
debug:
  batch_size: 8
  batch_length: 10
  report_length: 5
  jax:
    platform: cpu
    debug: True
    prealloc: False
  run:
    envs: 4
    report_every: 10
    log_every: 5
    save_every: 15
    train_ratio: 8
    # Same value as the default under `defaults.run.debug`.
    debug: True
  replay.size: 1e4
  # NOTE(review): the keys look like regex patterns, presumably matched
  # against every agent config path with the given suffix -- confirm.
  agent:
    .*\.bins: 5
    .*\.layers: 1
    .*\.units: 8
    .*\.stoch: 2
    .*\.classes: 4
    .*\.deter: 8
    .*\.hidden: 3
    .*\.blocks: 4
    .*\.depth: 2
