# Hydra defaults list. `_self_` is listed first, so the values in this file
# are composed before the `task` group selection that follows it.
defaults:
  - _self_
  # Selects the `real_pusht_image` option of the `task` config group; its
  # values are read below through `${task.*}` interpolations.
  - task: real_pusht_image

# Identifier of this experiment config. It is interpolated into the logging
# name/tags, `multi_run` paths and the Hydra output directories in this file.
name: train_ibc_dfo_hybrid
# Dotted import path of the workspace class (Hydra `_target_` convention).
# NOTE(review): presumably instantiated by the training entry point — confirm.
_target_: diffusion_policy.workspace.train_ibc_dfo_hybrid_workspace.TrainIbcDfoHybridWorkspace

# Top-level aliases for values defined by the selected task config.
task_name: ${task.name}
shape_meta: ${task.shape_meta}
# Free-form experiment label; within this file it is referenced only by
# `logging.tags`.
exp_name: "default"

# Shared sequence-length settings. `horizon`, `n_obs_steps` and
# `n_action_steps` are forwarded to `policy` below through interpolation.
horizon: 2
n_obs_steps: 2
n_action_steps: 1
# Not referenced elsewhere in this file.
# NOTE(review): presumably consumed by the task config — confirm.
n_latency_steps: 1
# Always equal to `n_obs_steps` (resolved by interpolation).
dataset_obs_steps: ${n_obs_steps}
# The two keys below are not referenced elsewhere in this file.
# NOTE(review): presumably read by the task/dataset config — confirm.
past_action_visible: False
keypoint_visible_rate: 1.0

# Policy section: `_target_` names the policy class and the remaining keys
# sit alongside it.
# NOTE(review): presumably passed as constructor keyword arguments via Hydra
# instantiation — confirm against the workspace class.
policy:
  _target_: diffusion_policy.policy.ibc_dfo_hybrid_image_policy.IbcDfoHybridImagePolicy

  # Resolved from the top-level `shape_meta`, which itself comes from the task.
  shape_meta: ${shape_meta}

  # Sequence lengths are taken from the top-level settings of the same name.
  horizon: ${horizon}
  n_action_steps: ${n_action_steps}
  n_obs_steps: ${n_obs_steps}
  dropout: 0.1
  # NOTE(review): presumably the number of negative samples used per training
  # example — confirm in the policy class.
  train_n_neg: 256
  # NOTE(review): presumably the number of sampling iterations and the number
  # of candidate samples used at prediction time — confirm in the policy class.
  pred_n_iter: 3
  pred_n_samples: 1024
  # The meaning of the two flags below is not evident from this file; see the
  # policy class before changing them.
  kevin_inference: False
  andy_train: False
  obs_encoder_group_norm: True
  eval_fixed_crop: True
  # Crop size as [crop_height, crop_width]. Per the inline note this is 90% of
  # a 320x240 (width x height) frame: 240 * 0.9 = 216 and 320 * 0.9 = 288.
  crop_shape: [216, 288] # ch, cw 320x240 90%

# Settings for the training data loader.
# NOTE(review): key names match `torch.utils.data.DataLoader` keyword
# arguments; presumably forwarded to it unchanged — confirm in the workspace.
dataloader:
  batch_size: 128
  num_workers: 8
  shuffle: True
  pin_memory: True
  persistent_workers: False

# Settings for the validation data loader. Compared with `dataloader` above,
# it uses a single worker and does not shuffle.
# NOTE(review): key names match `torch.utils.data.DataLoader` keyword
# arguments; presumably forwarded to it unchanged — confirm in the workspace.
val_dataloader:
  batch_size: 128
  num_workers: 1
  shuffle: False
  pin_memory: True
  persistent_workers: False

# Optimizer section: `_target_` selects `torch.optim.AdamW`; the other keys
# are AdamW hyperparameters (learning rate, beta coefficients, epsilon and
# weight decay).
# NOTE(review): the model parameters are not listed here, so they are
# presumably supplied by the workspace at instantiation time — confirm.
optimizer:
  _target_: torch.optim.AdamW
  lr: 1.0e-4
  betas: [0.95, 0.999]
  eps: 1.0e-8
  weight_decay: 1.0e-6

# Training-loop settings: device, seeding, optimization schedule, how often
# periodic work runs, and optional per-epoch step caps.
training:
  device: "cuda:0"
  seed: 42
  debug: False
  # NOTE(review): presumably resumes from an existing checkpoint in the output
  # directory when one is present — confirm in the workspace.
  resume: True
  # optimization
  lr_scheduler: cosine
  lr_warmup_steps: 500
  num_epochs: 1000
  gradient_accumulate_every: 1
  # training loop control
  # in epochs
  # (the `*_every` intervals below are expressed in epochs)
  rollout_every: 50
  checkpoint_every: 5
  val_every: 1
  sample_every: 5
  sample_max_batch: 128
  # steps per epoch
  # NOTE(review): `null` presumably means "no cap" — confirm in the workspace.
  max_train_steps: null
  max_val_steps: null
  # misc
  tqdm_interval_sec: 1.0

# Experiment-tracking settings.
# NOTE(review): the keys (project, resume, mode, name, tags, id, group) and the
# `wandb_name_base` entry under `multi_run` suggest these are passed to
# Weights & Biases — confirm in the workspace.
logging:
  project: diffusion_policy_debug
  resume: True
  mode: online
  # Run name: timestamp from the `now` resolver, then config and task names.
  name: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
  # Tags are quoted so each interpolation stays a separate string element of
  # the flow sequence.
  tags: ["${name}", "${task_name}", "${exp_name}"]
  id: null
  group: null

# Checkpoint retention settings.
checkpoint:
  # Top-k selection: monitors `train_action_mse_error` with mode `min` and
  # k = 5.
  # NOTE(review): presumably keeps the 5 checkpoints with the lowest monitored
  # value — confirm in the checkpoint manager.
  topk:
    monitor_key: train_action_mse_error
    mode: min
    k: 5
    # Checkpoint file-name template. The placeholders use Python format-spec
    # syntax: `epoch` zero-padded to 4 digits, the metric with 3 decimals.
    format_str: 'epoch={epoch:04d}-train_action_mse_error={train_action_mse_error:.3f}.ckpt'
  save_last_ckpt: True
  save_last_snapshot: False

# Settings for multi-run launches.
# NOTE(review): nothing in this file reads these keys; presumably used by an
# external launcher script — confirm.
multi_run:
  # Same path pattern as `hydra.run.dir` below.
  run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
  # Same pattern as `logging.name` above.
  wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}

# Hydra runtime configuration: where run outputs are written.
hydra:
  job:
    # Overrides Hydra's generated override directory name with the config name.
    override_dirname: ${name}
  run:
    # Output directory for a single run: date folder, then time_name_task.
    dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
  sweep:
    # Base directory for multirun sweeps; same pattern as `run.dir`.
    dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
    # Each sweep job gets its own subdirectory named by its job number.
    subdir: ${hydra.job.num}
