# Run name (left empty here) and the class path of the workspace for this config.
# NOTE(review): `_target_` is presumably resolved by Hydra-style instantiation — confirm.
name: ""
_target_: diffusion_policy.workspace.arp_workspace.ARPWorkspace
  
# ! ========================= Data ========================== 

# Task definition for `pusht_image`: dataset, rollout environment runner, and
# the observation/action shapes.
task:
  # Dataset class and its arguments.
  dataset:
    _target_: diffusion_policy.dataset.pusht_image_dataset.PushTImageDataset
    horizon: 16  # same value as policy.horizon
    max_train_episodes: 90
    pad_after: 7  # equals env_runner.n_action_steps - 1 (8 - 1); presumably derived — confirm
    pad_before: 1  # equals env_runner.n_obs_steps - 1 (2 - 1); presumably derived — confirm
    seed: 42
    val_ratio: 0.02
    zarr_path: data/pusht/pusht_cchi_v7_replay.zarr
  # Environment runner class and its arguments.
  env_runner:
    _target_: diffusion_policy.env_runner.pusht_image_runner.PushTImageRunner
    log_video: false
    fps: 10
    legacy_test: true
    max_steps: 300
    n_action_steps: 8  # same value as policy.n_action_steps
    n_envs: null
    n_obs_steps: 2  # same value as policy.n_obs_steps
    n_test: 50
    n_test_vis: 4
    n_train: 6
    n_train_vis: 2
    past_action: false
    test_start_seed: 100000
    train_start_seed: 0
  image_shape: [3, 96, 96]  # same value as shape_meta.obs.image.shape below
  name: pusht_image
  # Shapes and types of the action and of each observation key.
  # An identical copy lives under policy.shape_meta; keep the two in sync.
  shape_meta:
    action:
      shape: [2]
    obs:
      agent_pos:
        shape: [2]
        type: low_dim
      image:
        shape: [3, 96, 96]
        type: rgb
  
# Training dataloader settings (shuffling enabled).
# NOTE(review): the keys match torch DataLoader keyword names — presumably forwarded as-is; confirm.
dataloader:
  batch_size: 128  # 64 (presumably an earlier setting — confirm)
  num_workers: 8
  persistent_workers: false
  pin_memory: true
  shuffle: true

# Validation dataloader settings; same values as `dataloader` except that
# shuffling is disabled.
val_dataloader:
  batch_size: 128  # 64 (presumably an earlier setting — confirm)
  num_workers: 8
  persistent_workers: false
  pin_memory: true
  shuffle: false


# ! ========================= Training ========================== 


# Training-loop settings: device, LR schedule, epoch counts and the cadence of
# checkpointing / rollout / sampling / validation.
# NOTE(review): trailing `# N` comments presumably record earlier values — confirm.
training: 
  checkpoint_every: 2000 # 50
  debug: false
  device: cuda:0
  gradient_accumulate_every: 1
  lr_scheduler: cosine_with_restarts # change
  lr_num_cycles: 10 # change
  lr_warmup_steps: 500
  max_train_steps: null
  max_val_steps: null
  num_epochs: 12200 # 3050
  resume: false
  rollout_every: 50 # 50
  sample_every: 20 # 5
  seed: 42
  tqdm_interval_sec: 1.0
  use_ema: true
  val_every: 50 # 1
  # NOTE(review): stop_at_epoch (2500) is lower than num_epochs (12200);
  # presumably the run halts early while num_epochs sizes the schedule — confirm.
  stop_at_epoch: 2500
  add_full_horizon_eval: false

# EMA model class and its arguments (training.use_ema is set to true in this file).
ema:
  _target_: diffusion_policy.model.diffusion.ema_model.EMAModel
  inv_gamma: 1.0
  max_value: 0.9999
  min_value: 0.0
  power: 0.75
  update_after_step: 0

# Optimizer hyper-parameters. No `_target_` is given here, so the optimizer
# class is selected somewhere outside this file. Weight decay is specified
# separately for the observation encoder and the transformer.
optimizer: # TOCHANGE
  betas: [0.95, 0.999]
  lr: 0.0001
  obs_encoder_weight_decay: 1.0e-06
  transformer_weight_decay: 0.001

# Experiment-logging settings; run name is left empty and mode is `offline`.
# NOTE(review): the keys resemble wandb.init arguments — confirm against the workspace code.
logging:
  group: null
  id: null
  mode: offline
  name: ""
  project: iclr2025
  tags: []


# Checkpoint-saving settings.
checkpoint:
  save_last_ckpt: true
  save_last_snapshot: false
  # NOTE(review): presumably keeps the best `k` checkpoints ranked by
  # `monitor_key` in direction `mode` — confirm against the checkpoint manager.
  topk:
    # The placeholders use the same metric name as monitor_key (test_mean_score).
    format_str: epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt
    k: 5
    mode: max
    monitor_key: test_mean_score

  

# ! ========================= Model ==========================


# Policy (model) configuration: the policy class, observation/action shapes,
# image-crop settings, horizons, and the nested `arp_cfg` architecture block.
policy:
  _target_: diffusion_policy.policy.autoregressive_policy.ARPolicy

  # Identical copy of task.shape_meta; keep the two in sync.
  shape_meta:
    action:
      shape: [2]
    obs:
      agent_pos:  # will be ignored by the encoder
        shape: [2]
        type: low_dim
      image:
        shape: [3, 96, 96]
        type: rgb

  # The crop size is smaller than the 96x96 image shape declared above.
  crop_shape: [84, 84]
  eval_fixed_crop: true
  obs_encoder_group_norm: true

  # Same values as task.dataset.horizon and
  # task.env_runner.n_action_steps / n_obs_steps.
  horizon: 16
  n_action_steps: 8
  n_obs_steps: 2

  # Architecture / sampling settings for the policy.
  arp_cfg:
    plan_steps: 4
    plan_dict_size: 100  # not used
    plan_upscale_ratio: 8
    plan_corr_dim: -1
    reverse_plan: true

    plan_chunk_size: 4
    action_chunk_size: 16  # equals `horizon` above

    num_latents: 1
    layer_norm_every_block: false

    n_embd: 64  # or 96
    embd_pdrop: 0.1
    num_layers: 30

    # Per-layer settings; both attention kwargs blocks use the same dropout values.
    layer_cfg:
      n_head: 8
      mlp_ratio: 4.0
      AdaLN: true
      mlp_dropout: 0.1
      attn_kwargs:
        attn_pdrop: 0.1
        resid_pdrop: 0.1
      cond_attn_kwargs:
        attn_pdrop: 0.1
        resid_pdrop: 0.1

    sample: true
    augment_ratio: 0.0
    low_var_eval: false