# Hydra defaults list. Order matters: `_self_` comes first, so this file is
# composed first and the task config listed after it can override its values.
defaults:
  - _self_
  - task: pbrl_realrobot_pick  # alternative task config: pbrl_realrobot_pick_slice

# Experiment name and the workspace class Hydra instantiates via `_target_`.
# Other workspace variants noted by the author: *_type1_*, *_type2_*.
name: train_diffusion_realrobot
_target_: diffusion_policy.workspace.pbrl_diffusion_realrobot_online_type1_workspace.PbrlDiffusionRealRobotWorkspace

# Despite the key name, this is a path to a checkpoint *file*, not a directory.
checkpoint_dir: "data/experiment/realrobot/latest.ckpt"
# Presumably the checkpoint entries to skip when loading (confirm in workspace).
# NOTE(review): this is a single string, not a list; if the workspace checks
# `key not in exclude_keys`, that is a substring test. Confirm before changing.
exclude_keys: "optimizer"

# Values pulled from the selected task config.
task_name: ${task.name}
shape_meta: ${task.shape_meta}
exp_name: "default"

# Shared hyperparameters. gamma, horizon, n_obs_steps, n_action_steps,
# n_latency_steps and obs_as_global_cond are interpolated into `policy` below.
gamma: 1
stride: 8
horizon: 16
n_obs_steps: 2
n_action_steps: 8
n_latency_steps: 0
dataset_obs_steps: ${n_obs_steps}
past_action_visible: false
keypoint_visible_rate: 1.0
obs_as_global_cond: true

# Diffusion policy, instantiated by Hydra from `_target_`.
# Other policy variants noted by the author: dpo_*, ours_*.
policy:
  _target_: diffusion_policy.policy.ours_diffusion_realrobot_policy.DiffusionRealRobotPolicy

  shape_meta: ${shape_meta}

  noise_scheduler:
    _target_: diffusers.schedulers.scheduling_ddim.DDIMScheduler
    num_train_timesteps: 100
    # With beta_schedule=squaredcos_cap_v2, diffusers derives the betas from a
    # cosine alpha-bar schedule, so beta_start/beta_end are not used by it.
    beta_start: 0.0001
    beta_end: 0.02
    # beta_schedule matters a lot; this is the best one found so far.
    beta_schedule: squaredcos_cap_v2
    clip_sample: true
    set_alpha_to_one: true
    steps_offset: 0
    prediction_type: epsilon  # alternative: sample

  obs_encoder:
    _target_: diffusion_policy.model.vision.realrobot_image_obs_encoder.RealRobotImageObsEncoder
    shape_meta: ${shape_meta}
    rgb_model:
      _target_: diffusion_policy.model.vision.model_getter.get_resnet
      name: resnet18
      weights: null  # presumably no pretrained weights (random init); confirm
    resize_shape: [1024, 1280]
    crop_shape: [922, 1152]  # must fit inside resize_shape
    random_crop: true
    use_group_norm: true
    share_rgb_model: false
    imagenet_norm: true

  beta: 0.0002
  # NOTE(review): training.map.map_ratio is 0.15 while this one is 1.0;
  # confirm which value the workspace/policy actually reads.
  map_ratio: 1.0
  bias_reg: 1.0
  gamma: ${gamma}
  horizon: ${horizon}
  # The policy's action count includes the latency steps.
  n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
  n_obs_steps: ${n_obs_steps}
  num_inference_steps: 100
  obs_as_global_cond: ${obs_as_global_cond}
  # crop_shape: null  (cropping is configured on obs_encoder instead)
  diffusion_step_embed_dim: 128
  down_dims: [512, 1024, 2048]
  kernel_size: 5
  n_groups: 8
  cond_predict_scale: true


# Exponential moving average of the model weights
# (presumably only active when training.use_ema is true; confirm in workspace).
ema:
  _target_: diffusion_policy.model.diffusion.ema_model.EMAModel
  # presumably shape the decay warm-up curve; see EMAModel
  inv_gamma: 1.0
  power: 0.75
  # lower/upper bounds on the decay value
  min_value: 0.0
  max_value: 0.9999
  update_after_step: 0

# Training data loader; keys mirror torch.utils.data.DataLoader kwargs.
# persistent_workers requires num_workers > 0.
dataloader:
  batch_size: 1
  num_workers: 8
  shuffle: true
  pin_memory: true
  persistent_workers: true

# Validation data loader; same keys, but without shuffling.
val_dataloader:
  batch_size: 1
  num_workers: 8
  shuffle: false
  pin_memory: true
  persistent_workers: true

# AdamW optimizer, built by Hydra from `_target_`.
optimizer:
  _target_: torch.optim.AdamW
  # Keep the decimal point in the mantissa (1.0e-5, not 1e-5): YAML 1.1
  # loaders such as PyYAML would otherwise read the value as a string.
  lr: 1.0e-5
  eps: 1.0e-8
  weight_decay: 1.0e-6
  betas:
    - 0.95
    - 0.999

training:
  device_cpu: "cpu"
  device_gpu: "cuda:0"
  seed: 42
  debug: false
  resume: true
  # optimization
  lr_scheduler: cosine
  # LR warmup steps; 0 disables warmup.
  lr_warmup_steps: 0
  num_epochs: 2
  # lr_end: 1.0e-8
  # power: 1.5
  gradient_norm_clip: 0.5
  gradient_accumulate_every: 1
  use_ema: true
  freeze_encoder: false
  # training loop control; all intervals are in epochs
  rollout_every: 1
  checkpoint_every: 1
  val_every: 1
  sample_every: 1
  # max steps per epoch (null presumably means no limit; confirm in workspace)
  max_train_steps: null
  max_val_steps: null
  # misc
  tqdm_interval_sec: 1.0
  dataset_1_dir: ${task.dataset_1.dataset_dir}
  dataset_2_dir: ${task.dataset_2.dataset_dir}
  online:
    num_groups: 4
    reverse_ratio: 0.2
    update_history: false
  map:
    use_map: true
    # NOTE(review): policy.map_ratio is 1.0 while this is 0.15; confirm
    # which value is actually consumed.
    map_ratio: 0.15


# Experiment-tracker run settings (presumably forwarded to wandb.init; confirm).
logging:
  project: diffusion_policy_debug
  resume: true
  mode: online
  name: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
  tags: ["${name}", "${task_name}", "${exp_name}"]
  id: null
  group: null

checkpoint:
  topk:
    # Keep the k checkpoints with the LOWEST training loss. A loss is
    # better when smaller, so the mode must be `min`; `max` would keep the
    # worst checkpoints and discard the best ones.
    monitor_key: train_loss
    mode: min
    k: 5
    format_str: 'epoch={epoch:04d}-train_loss={train_loss:.3f}.ckpt'
  save_last_ckpt: true
  save_last_snapshot: false

# Timestamped run directory and W&B name prefix
# (presumably used for multi-run launches).
multi_run:
  run_dir: "data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}"
  wandb_name_base: "${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}"

# Hydra's own settings: job naming plus timestamped output directories
# for single runs and sweeps.
hydra:
  job:
    override_dirname: "${name}"
  run:
    dir: "data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}"
  sweep:
    dir: "data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}"
    # one numbered subdirectory per sweep job
    subdir: ${hydra.job.num}
