# Hydra defaults list. `_self_` comes first, so this file is composed before
# the selected `task` group config.
defaults:
  - _self_
  - task: kitchen_lowdim_hdf5_abs # alt: kitchen_lowdim_abs

# Experiment name; interpolated into the logging name/tags and the Hydra
# output directories further down in this file.
name: train_diffusion_transformer_lowdim
# Dotted path of the workspace class to instantiate. The trailing comment
# lists the module-name suffixes the author noted (_type1_, _type2_).
_target_: diffusion_policy.workspace.pbrl_diffusion_transformer_lowdim_online_type1_workspace.PbrlDiffusionTransformerLowdimWorkspace # _type1_, _type2_
# NOTE(review): despite the `_dir` suffix, the value is a path to a single
# .ckpt file, not a directory.
checkpoint_dir: 'data/experiments/low_dim/kitchen/diffusion_policy_transformer/base_0.5/seed=42.ckpt'

# Values read from the selected task config through interpolation.
obs_dim: ${task.obs_dim}
action_dim: ${task.action_dim}
task_name: ${task.name}
# Free-form experiment label; used as one of the logging tags.
exp_name: "default"

# Shared top-level hyperparameters. Several of these (gamma, horizon,
# n_obs_steps, n_action_steps, obs_as_cond, pred_action_steps_only) are
# re-used inside `policy` via ${...} interpolation.
gamma: 0.997
horizon: 16
n_obs_steps: 4
n_action_steps: 8
n_latency_steps: 0
past_action_visible: false
keypoint_visible_rate: 1.0
# When true, policy.model.input_dim is action_dim and cond_dim is obs_dim.
obs_as_cond: true
pred_action_steps_only: false

# Policy configuration. `_target_` names the class to instantiate; the
# trailing comment lists the module-name prefixes the author noted for the
# available variants (cpl_, dpo_, ours_).
policy:
  _target_: diffusion_policy.policy.cpl_diffusion_transformer_lowdim_policy.DiffusionTransformerLowdimPolicy  # cpl_, dpo_, ours_

  # Network used by the policy.
  model:
    _target_: diffusion_policy.model.diffusion.transformer_for_diffusion.TransformerForDiffusion
    # input_dim is action_dim when obs_as_cond is true, otherwise
    # obs_dim + action_dim.
    input_dim: ${eval:'${action_dim} if ${obs_as_cond} else ${obs_dim} + ${action_dim}'}
    output_dim: ${policy.model.input_dim}
    horizon: ${horizon}
    n_obs_steps: ${n_obs_steps}
    # cond_dim is obs_dim when obs_as_cond is true, otherwise 0.
    cond_dim: ${eval:'${obs_dim} if ${obs_as_cond} else 0'}

    n_layer: 8
    n_head: 4
    n_emb: 768
    p_drop_emb: 0.0
    p_drop_attn: 0.1

    causal_attn: true
    time_as_cond: true  # if false, use a BERT-like encoder-only arch with time as input
    obs_as_cond: ${obs_as_cond}
    n_cond_layers: 0  # >0: use a transformer encoder for cond, otherwise use an MLP

  # Diffusion noise schedule.
  noise_scheduler:
    _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
    num_train_timesteps: 100
    beta_start: 0.0001
    beta_end: 0.02
    beta_schedule: squaredcos_cap_v2
    variance_type: fixed_small  # Yilun's paper uses fixed_small_log instead, but that easily causes NaN
    clip_sample: true  # required when predict_epsilon=False
    prediction_type: epsilon  # or sample

  gamma: ${gamma}
  # Per-variant values noted by the author:
  # cpl: 1.785e-06; dpo: 1.785e-06; ours: 3.57e-05
  beta: 1.785e-06
  # Both MAP settings are taken from training.map.
  use_map: ${training.map.use_map}
  map_ratio: ${training.map.map_ratio}
  bias_reg: 0.25
  train_time_samples: 1
  horizon: ${horizon}
  obs_dim: ${obs_dim}
  action_dim: ${action_dim}
  n_action_steps: ${n_action_steps}
  n_obs_steps: ${n_obs_steps}
  num_inference_steps: 100
  obs_as_cond: ${obs_as_cond}
  pred_action_steps_only: ${pred_action_steps_only}

  # scheduler.step params
  # predict_epsilon: True

# Settings for the EMA model class named by `_target_`; only used when
# training.use_ema is enabled (presumably — confirm in the workspace).
ema:
  _target_: diffusion_policy.model.diffusion.ema_model.EMAModel
  update_after_step: 0
  inv_gamma: 1.0 # NOTE(review): marked "change" by the author; the prior value is not recorded here
  power: 0.75
  min_value: 0.0
  max_value: 0.9999

# Training dataloader options.
dataloader:
  batch_size: 18
  num_workers: 1
  shuffle: true
  pin_memory: true
  persistent_workers: false

# Validation dataloader options (same values as `dataloader`).
# NOTE(review): shuffle is enabled for validation as well — confirm this is
# intentional.
val_dataloader:
  batch_size: 18
  num_workers: 1
  shuffle: true
  pin_memory: true
  persistent_workers: false

# Optimizer hyperparameters. No `_target_` is given here, so the optimizer
# class itself is chosen elsewhere (presumably by the workspace/policy code).
optimizer:
  learning_rate: 2.0e-5
  weight_decay: 1.0e-3
  betas: [0.9, 0.95]

# Training-loop settings.
training:
  device_cpu: "cpu"
  device_gpu: "cuda:0"
  seed: 42
  debug: false
  resume: true
  # optimization
  lr_scheduler: cosine
  # LR warmup is disabled (0 steps).
  # NOTE(review): the previous comment here read "Transformer needs LR
  # warmup", which contradicts a value of 0 — confirm that skipping warmup
  # is intentional for this run.
  lr_warmup_steps: 0
  num_epochs: 10
  # lr_end: 1.0e-8
  # power: 1.5
  gradient_accumulate_every: 1
  use_ema: true
  # training loop control
  # in epochs
  # NOTE(review): rollout_every and checkpoint_every (25) are larger than
  # num_epochs (10); confirm how rarely these are meant to trigger.
  rollout_every: 25
  checkpoint_every: 25
  val_every: 1
  sample_every: 5
  # steps per epoch
  max_train_steps: null
  max_val_steps: null
  # misc
  tqdm_interval_sec: 1.0
  # NOTE(review): both dataset entries point at the same .h5 file — confirm
  # this is intentional.
  dataset_1_dir: 'data/kitchen/Minari/normal/base_0.5/kitchen_data_0.5.h5'
  dataset_2_dir: 'data/kitchen/Minari/normal/base_0.5/kitchen_data_0.5.h5'
  online:
    num_groups: 1  # 4
    all_votes: 100
    reverse_ratio: 0.2
    reverse_rate: 0.5  # 0.25, 0.5, 0.75
    reverse_freq: 1
    update_history: false
  map:
    # Enable only together with the ours_ policy variant. policy._target_ in
    # this file selects the cpl_ variant (with the cpl beta), so this is off.
    use_map: false
    map_ratio: 0.25


# Experiment-tracker settings.
logging:
  project: diffusion_policy_debug
  resume: true
  mode: online
  # Timestamp + experiment name + task name.
  name: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
  tags: ["${name}", "${task_name}", "${exp_name}"]
  id: null
  group: null

# Checkpoint-saving settings.
checkpoint:
  # Top-k selection (presumably keeps the k best checkpoints ranked by
  # monitor_key, higher is better with mode=max — confirm in the workspace).
  topk:
    monitor_key: test_mean_score
    mode: max
    k: 5
    # Filename template; {epoch} and {test_mean_score} are format fields.
    format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
  save_last_ckpt: true
  save_last_snapshot: false

# Multi-run settings. run_dir uses the same template as hydra.run.dir, and
# wandb_name_base uses the same template as logging.name.
# NOTE(review): the consumer of this section is not visible in this file.
multi_run:
  run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
  wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}

# Hydra runtime settings: output directories are timestamped and suffixed
# with the experiment and task names.
hydra:
  job:
    # Use the experiment name as the override dirname.
    override_dirname: ${name}
  run:
    dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
  sweep:
    dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
    # One sub-directory per sweep job, named by the job number.
    subdir: ${hydra.job.num}
