# Hydra defaults list. `_self_` is listed first, so the group configs that
# follow it are merged on top of the values declared in this file.
defaults:
  - _self_  # all below configs will override this conf.yaml
  - arch: wbvima
  - task: ???  # no default task; one must be selected by the caller

# Run name assembled from the hyperparameters below. Inside a double-quoted
# YAML scalar a trailing backslash escapes the line break and the leading
# whitespace of the next line is dropped, so the pieces are joined with no
# separator between them.
# NOTE(review): `optional_`, `_optional`, `__optional` and `scientific` are
# custom resolvers that are not defined in this file -- presumably registered
# by the training code; `task_name` is likewise not defined here. Confirm.
run_name: "${optional_:${prefix}}\
  ${arch_name}\
  _${task_name}\
  _lr${scientific:${lr},1}\
  _wd${scientific:${wd}}\
  _b${bs}\
  ${__optional:${suffix}}\
  ${_optional:${postsuffix}}"
exp_root_dir: ???  # mandatory; must be supplied by the caller
arch_name: ???  # filled by arch

# ====== main cfg ======
# Top-level knobs. The module, data_module and trainer sections of this file
# read them through ${...} interpolation, so override them here.
seed: -1  # NOTE(review): -1 presumably means "no fixed seed" -- confirm with the consumer
gpus: 1  # forwarded to trainer.devices
use_cosine_lr: true
# Exponent-form floats carry an explicit decimal point so that YAML 1.1
# loaders (e.g. stock PyYAML) also read them as floats rather than strings.
lr: 3.0e-5
lr_warmup_steps: 100000
lr_cosine_steps: 1500000
lr_cosine_min: 5.0e-6
lr_layer_decay: 1.0
wd: 0.0  # forwarded to module.weight_decay
bs: 256  # forwarded to data_module.batch_size
vbs: ${bs}  # validation batch size; follows bs unless overridden
data_dir: ???  # mandatory; forwarded to data_module.data_path
eval_interval: 10  # forwarded to trainer.check_val_every_n_epoch
rollout_eval: false
# ------ logging ------
use_wandb: true
wandb_project: ???  # mandatory
wandb_run_name: ${run_name}

# ------ common ------
# Names of the action parts, followed by the per-key dims (they sum to 21).
action_keys:
  - mobile_base
  - torso
  - left_arm
  - left_gripper
  - right_arm
  - right_gripper
action_key_dims:
  mobile_base: 3
  torso: 4
  left_arm: 6
  left_gripper: 1
  right_arm: 6
  right_gripper: 1

# ------ module ------
# Each ??? below is a mandatory value that this file does not provide.
# NOTE(review): per the inline note _target_ comes from the selected arch
# config; `policy` presumably does too -- confirm.
module:
  _target_: ???  # filled by arch
  # ====== policy ======
  policy: ???
  # ====== learning ======
  # Mirrors of the top-level hyperparameters, resolved by interpolation.
  lr: ${lr}
  use_cosine_lr: ${use_cosine_lr}
  lr_warmup_steps: ${lr_warmup_steps}
  lr_cosine_steps: ${lr_cosine_steps}
  lr_cosine_min: ${lr_cosine_min}
  lr_layer_decay: ${lr_layer_decay}
  weight_decay: ${wd}  # the top-level key is named `wd`
  action_keys: ${action_keys}

# Data module settings; most values mirror top-level keys via interpolation.
data_module:
  _target_: ???  # mandatory; not set in this file
  data_path: ${data_dir}
  # NOTE(review): pcd_downsample_points is not defined at the top level of
  # this file -- presumably provided by the arch or task config; confirm.
  pcd_downsample_points: ${pcd_downsample_points}
  batch_size: ${bs}
  val_batch_size: ${vbs}
  val_split_ratio: 0.1
  seed: ${seed}
  dataloader_num_workers: 4

# Trainer settings.
# NOTE(review): `cls` presumably names the class that the training entry point
# builds from the remaining keys -- confirm against the consumer.
trainer:
  cls: pytorch_lightning.Trainer
  accelerator: gpu
  devices: ${gpus}
  precision: bf16-mixed
  benchmark: true  # turns on cudnn.benchmark
  accumulate_grad_batches: 1
  num_sanity_val_steps: 0
  max_epochs: 99999999
  val_check_interval: null
  check_val_every_n_epoch: ${eval_interval}
  gradient_clip_val: 1.0
  # List of ModelCheckpoint argument sets. This key is popped from the trainer
  # config and its contents are sent to ModelCheckpoint as args (presumably one
  # callback per list entry -- confirm).
  checkpoint:
    # Entry monitoring "train/loss".
    - filename: 'epoch{epoch}-train_loss{train/loss:.5f}'
      # A training metric, so the checkpoint is written when the training epoch ends.
      save_on_train_epoch_end: true
      save_top_k: 20
      save_last: true
      monitor: train/loss
      mode: min
      # Keeps the slash in the metric name from creating a subfolder.
      auto_insert_metric_name: false
    # Entry monitoring "val/l1".
    - filename: 'epoch{epoch}-val_l1_{val/l1:.5f}'
      # NOTE(review): eval_type only appears on this entry; its consumer is
      # not visible from this file.
      eval_type: static
      save_top_k: 140
      save_last: true
      monitor: val/l1
      mode: min
      # Keeps the slash in the metric name from creating a subfolder.
      auto_insert_metric_name: false
  callbacks:
    - cls: LearningRateMonitor
      logging_interval: step
    - cls: RichModelSummary

# ------------- Global cfgs for enlight.LightningTrainer ---------------


# ------------- Resume training ---------------
resume:
  # NOTE(review): this default points at a specific local checkpoint file; the
  # original trailing comment names `null` as the alternative value. Confirm
  # that shipping this path as the default is intended.
  ckpt_path: './datasets/epoch559-val_l1_0.29266.pth'
  # When true, every state is resumed: optimizer, amp and lightning callbacks.
  full_state: false
  strict: true

# ------------- Testing ---------------
test:
  ckpt_path: null  # no test checkpoint is selected by default

# ----------------------------
# Optional pieces referenced by run_name; all three default to null.
prefix: null
suffix: null
postsuffix: null

# Hydra runtime settings.
hydra:
  job:
    chdir: true  # change the working directory to the run dir configured below
  run:
    dir: "."  # the run dir is the current directory
  output_subdir: null  # null disables Hydra's config-dump subdirectory (".hydra" by default)