# Baseline configuration. The named presets later in this file (for example
# `debug`, `atari`, `multicpu`) set a subset of these same keys.
defaults:
  # Top-level run settings.
  logdir: ~/logdir/{timestamp}
  replica: 0
  replicas: 1
  method: name
  task: dummy_disc
  seed: 0
  script: train

  # Batch and sequence sizes; the `debug` preset lowers the first three.
  batch_size: 16
  batch_length: 64
  report_length: 32
  consec_train: 1
  consec_report: 1
  replay_context: 1

  # Remaining top-level switches and addresses (empty strings by default).
  random_agent: False
  clock_addr: ''
  clock_port: ''
  ipv6: False
  errfile: False

  # Logger settings.
  logger:
    outputs:
      - jsonl
      - scope
    # Pipe-separated alternatives. NOTE(review): presumably a regex selecting
    # which metric keys get logged — confirm against the logger code.
    filter: 'score|length|fps|ratio|train/loss/|train/rand/'
    timer: True
    fps: 15
    user: ''

  # Per-suite environment options, one mapping per suite. The suite keys match
  # the prefixes of `task` names used by the presets below (e.g. `atari_pong`,
  # `dmc_walker_walk`).
  env:
    atari:
      size: [96, 96]
      repeat: 4
      sticky: True
      gray: True
      actions: all
      lives: unused
      noops: 30
      autostart: False
      pooling: 2
      aggregate: max
      resize: pillow
      clip_reward: False
    procgen:
      size: [96, 96]
      resize: pillow
    crafter:
      size: [64, 64]
      logs: False
    atari100k:
      size: [64, 64]
      repeat: 4
      sticky: False
      gray: False
      actions: needed
      lives: unused
      noops: 30
      autostart: False
      resize: pillow
      clip_reward: False
    dmlab:
      size: [64, 64]
      repeat: 4
      episodic: True
      use_seed: True
    minecraft:
      size: [64, 64]
      break_speed: 100.0
      logs: False
      length: 36000
    # The `dmc_proprio` / `dmc_vision` presets each switch off one of
    # `image` / `proprio`.
    dmc:
      size: [64, 64]
      repeat: 1
      proprio: True
      image: True
      camera: -1
    loconav:
      size: [64, 64]
      repeat: 1
      camera: -1

  # Replay buffer settings.
  replay:
    # Written with an explicit exponent sign so that YAML 1.1 loaders also
    # resolve it as a float (without the sign they read it as a string).
    size: 4.5e+6
    online: True
    fracs: {uniform: 1.0, priority: 0.0, recency: 0.0}
    # NOTE(review): plain `inf` resolves to a string in YAML (the float
    # spelling is `.inf`); left unchanged — confirm the consumer converts it.
    prio: {exponent: 0.8, maxfrac: 0.5, initial: inf, zero_on_sample: True}
    priosignal: model
    recexp: 1.0
    chunksize: 1024

  # Run-loop settings. Presets below override `steps`, `envs`, `train_ratio`
  # and (in `debug`) the `*_every` intervals.
  run:
    # Explicit mantissa dot and exponent sign: resolves as a float under both
    # YAML 1.1 and YAML 1.2 loaders (bare `1e10` is a string under YAML 1.1).
    steps: 1.0e+10
    duration: 0
    train_ratio: 32.0
    # NOTE(review): the units of the `*_every` intervals are not stated in
    # this file — confirm against the run script.
    log_every: 120
    report_every: 300
    save_every: 1800
    ckpt_keep: 10
    envs: 16
    eval_envs: 4
    eval_eps: 1
    report_batches: 1
    from_checkpoint: ''
    episode_timeout: 180
    # `{auto}` is a literal placeholder in these strings. NOTE(review):
    # presumably substituted with a port at launch — confirm.
    actor_addr: 'localhost:{auto}'
    replay_addr: 'localhost:{auto}'
    logger_addr: 'localhost:{auto}'
    actor_batch: -1
    actor_threads: 1
    agent_process: False
    remote_replay: False
    remote_envs: False
    usage: {psutil: True, nvsmi: True, gputil: False, malloc: False, gc: False}
    debug: True

  # JAX backend settings; the `debug` and `multicpu` presets override several
  # of these keys.
  jax:
    profiler: False
    platform: cuda
    compute_dtype: bfloat16
    # Device index lists; `multicpu` replaces both with longer lists.
    policy_devices:
      - 0
    train_devices:
      - 0
    mock_devices: 0
    prealloc: True
    jit: True
    debug: False
    expect_devices: 0
    enable_policy: True
    coordinator_address: ''

  # Agent hyperparameters: loss weights, optimizer, network heads, and the
  # FPCT / FlowMAP additions. Exponent floats in this section are written with
  # an explicit mantissa dot and exponent sign (e.g. `4.0e-5`) so that YAML 1.1
  # loaders resolve them as floats too; bare `4e-5` is a string there.
  agent:

    # Per-loss weights. Note that `con` and `cons` are two distinct keys.
    loss_scales:
      rec: 1.0
      rew: 1.0
      con: 1.0
      dyn: 1.0
      rep: 0.1
      policy: 1.0
      value: 1.0
      repval: 0.3
      flow: 0.03
      cons: 0.01

    opt: {lr: 4.0e-5, agc: 0.3, eps: 1.0e-20, beta1: 0.9, beta2: 0.999, momentum: True, wd: 0.0, schedule: const, warmup: 1000, anneal: 0}
    ac_grads: False

    # In `dyn`, `enc` and `dec`, the value of `typ` equals the name of the
    # sibling mapping. NOTE(review): presumably `typ` selects which sibling is
    # used — confirm against the agent code.
    dyn:
      typ: rssm
      rssm: {deter: 8192, hidden: 1024, stoch: 32, classes: 64, act: silu, norm: rms, unimix: 0.01, outscale: 1.0, winit: trunc_normal_in, imglayers: 2, obslayers: 1, dynlayers: 1, absolute: False, blocks: 8, free_nats: 1.0}

    enc:
      typ: simple
      simple: {depth: 64, mults: [2, 3, 4, 4], layers: 3, units: 1024, act: silu, norm: rms, winit: trunc_normal_in, symlog: True, outer: False, kernel: 5, strided: False}

    dec:
      typ: simple
      simple: {depth: 64, mults: [2, 3, 4, 4], layers: 3, units: 1024, act: silu, norm: rms, outscale: 1.0, winit: trunc_normal_in, outer: False, kernel: 5, bspace: 8, strided: False}

    rewhead: {layers: 1, units: 1024, act: silu, norm: rms, output: symexp_twohot, outscale: 0.0, winit: trunc_normal_in, bins: 255}
    conhead: {layers: 1, units: 1024, act: silu, norm: rms, output: binary, outscale: 1.0, winit: trunc_normal_in}

    # =========================
    # FPCT: shared representation head
    # =========================
    # NOTE(review): `repr_dim` is not matched by the `.*\.units` pattern used
    # in the size and `debug` presets, so it stays at 1024 under those presets
    # — confirm this is intended.
    repr_dim: 1024
    reprhead: {layers: 2, units: 1024, act: silu, norm: rms, output: mse, outscale: 1.0, winit: trunc_normal_in}

    policy: {layers: 3, units: 1024, act: silu, norm: rms, minstd: 0.1, maxstd: 1.0, outscale: 0.01, unimix: 0.01, winit: trunc_normal_in}
    value: {layers: 3, units: 1024, act: silu, norm: rms, output: symexp_twohot, outscale: 0.0, winit: trunc_normal_in, bins: 255}

    # =========================
    # FlowMAP: flow head
    # =========================
    flowhead: {layers: 2, units: 1024, act: silu, norm: rms, output: mse, outscale: 1.0, winit: trunc_normal_in}

    # FlowMAP hyperparameters.
    # Boolean spelled `False` to match every other boolean in this file.
    flow_grads: False
    flow_huber_delta: 1.0
    flow_topk: 64
    flow_early_steps: 3
    flow_value_bonus: 1.0
    # Numerically 1/15 rounded to seven decimals. NOTE(review): which quantity
    # of 15 this is tied to is not stated here — confirm.
    flow_dt: 0.0666667

    policy_dist_disc: categorical
    policy_dist_cont: bounded_normal
    imag_last: 0
    imag_length: 15
    horizon: 333
    contdisc: True
    imag_loss: {slowtar: False, lam: 0.95, actent: 3.0e-4, slowreg: 1.0}
    repl_loss: {slowtar: False, lam: 0.95, slowreg: 1.0}
    slowvalue: {rate: 0.02, every: 1}
    retnorm: {impl: perc, rate: 0.01, limit: 1.0, perclo: 5.0, perchi: 95.0, debias: False}
    # `impl: none` is the plain string "none" (YAML null is spelled `null`).
    valnorm: {impl: none, rate: 0.01, limit: 1.0e-8}
    advnorm: {impl: none, rate: 0.01, limit: 1.0e-8}
    reward_grad: True
    repval_loss: True
    repval_grad: True
    report: True
    report_gradnorms: False

# Model-size presets. Each one sets the same three pattern-style keys and
# defines a YAML anchor; in the visible part of this file only `*size1m` is
# referenced (by `dmc_proprio`).
# In every preset below: deter = 8 * hidden, classes = depth = hidden / 16,
# and units = hidden. The `size200m` values equal those in `defaults.agent`.
# NOTE(review): the keys look like regexes matched against flattened config
# paths — confirm against the config loader.
size1m: &size1m
  .*\.rssm:
    deter: 512
    hidden: 64
    classes: 4
  .*\.depth: 4
  .*\.units: 64

size12m: &size12m
  .*\.rssm:
    deter: 2048
    hidden: 256
    classes: 16
  .*\.depth: 16
  .*\.units: 256

size25m: &size25m
  .*\.rssm:
    deter: 3072
    hidden: 384
    classes: 24
  .*\.depth: 24
  .*\.units: 384

size50m: &size50m
  .*\.rssm:
    deter: 4096
    hidden: 512
    classes: 32
  .*\.depth: 32
  .*\.units: 512

size100m: &size100m
  .*\.rssm:
    deter: 6144
    hidden: 768
    classes: 48
  .*\.depth: 48
  .*\.units: 768

size200m: &size200m
  .*\.rssm:
    deter: 8192
    hidden: 1024
    classes: 64
  .*\.depth: 64
  .*\.units: 1024

size400m: &size400m
  .*\.rssm:
    deter: 12288
    hidden: 1536
    classes: 96
  .*\.depth: 96
  .*\.units: 1536

# Task presets. Each sets `task` and, where present, a few `run` / `env` keys.
# Step counts are written with an explicit exponent sign (e.g. `2.6e+7`) so
# that YAML 1.1 loaders also resolve them as floats rather than strings.
minecraft:
  task: minecraft_diamond

dmlab:
  task: dmlab_explore_goal_locations_small
  run: {steps: 2.6e+7, train_ratio: 32}

atari:
  task: atari_pong
  run: {steps: 5.1e+7, train_ratio: 32}

procgen:
  task: procgen_coinrun
  run: {steps: 1.1e+8, train_ratio: 64}

atari100k:
  task: atari100k_pong
  run: {steps: 1.1e+5, envs: 1, train_ratio: 256}

crafter:
  task: crafter_reward
  run: {steps: 1.1e+6, envs: 1, train_ratio: 512}

dmc_proprio:
  # Merge key: pulls in the three pattern keys defined under `&size1m`.
  <<: *size1m
  task: dmc_walker_walk
  env.dmc.image: False
  run: {steps: 1.1e+6, train_ratio: 1024}

dmc_vision:
  task: dmc_walker_walk
  env.dmc.proprio: False
  run: {steps: 1.1e+6, train_ratio: 256}

bsuite:
  task: bsuite_mnist/0
  run: {envs: 1, train_ratio: 1024}

loconav:
  task: loconav_ant_maze_m
  # Same value as `defaults.env.loconav.repeat`.
  env.loconav.repeat: 1
  # Dotted-key form, unlike the nested `run: {...}` used by the presets above.
  run.train_ratio: 256

# Multi-device preset: sets eight mock devices and lists device indices 0-1
# under `policy_devices` and 2-7 under `train_devices`.
multicpu:
  batch_size: 12
  jax.mock_devices: 8
  jax.policy_devices:
    - 0
    - 1
  jax.train_devices:
    - 2
    - 3
    - 4
    - 5
    - 6
    - 7

# Debug preset: small batches, CPU platform, short logging/saving intervals,
# and tiny network sizes set through pattern-style keys under `agent`.
debug:
  batch_size: 8
  batch_length: 10
  report_length: 5
  jax: {platform: cpu, debug: True, prealloc: False}
  run: {envs: 4, report_every: 10, log_every: 5, save_every: 15, train_ratio: 8, debug: True}
  # Explicit mantissa dot and exponent sign: resolves as a float under both
  # YAML 1.1 and YAML 1.2 loaders (bare `1e4` is a string under YAML 1.1).
  replay.size: 1.0e+4
  agent:
    # NOTE(review): these keys look like regexes matched against nested
    # config paths — confirm against the config loader.
    .*\.bins: 5
    .*\.layers: 1
    .*\.units: 8
    .*\.stoch: 2
    .*\.classes: 4
    .*\.deter: 8
    .*\.hidden: 3
    .*\.blocks: 4
    .*\.depth: 2
