# Run identity and launch resources.
# NOTE(review): nodes / tasks_per_node / cpus_per_task read like cluster
# launcher (e.g. SLURM) requests, i.e. 4 nodes x 8 tasks = 32 processes;
# confirm against the launcher that consumes them.
app: vjepa_wm
nodes: 4
tasks_per_node: 8
cpus_per_task: 16
# Run output directory. ${JEPAWM_LOGS} is presumably substituted from the
# environment by the config loader; verify.
folder: ${JEPAWM_LOGS}/droid_final_sweep/droid_vitgopen
data:
  # Training dataset selection: a single DROID source. datasets_weights stays
  # null (presumably ignored or uniform when only one dataset is listed).
  dataset_type: custom
  datasets: [DROID]
  datasets_weights: null
  # NOTE(review): same value as meta.seed; keep the two in sync.
  seed: 234
  img_size: 256
  # Validation configuration.
  # NOTE(review): the flat val_* keys and the val_datasets_1 group carry the
  # same settings (Franka_hf, batch size 4, fpcs 8, exterior_image_2_left,
  # no drop_last); presumably different code paths read each form, so keep
  # them in sync when editing either one.
  validation:
    val_datasets:
      - Franka_hf
    num_frames_val: 5
    val_dataset_batch_size: 4
    val_dataset_drop_last: false
    val_dataset_fpcs: [8]
    val_dataset_camera_views: ["exterior_image_2_left"]
    val_viz_rank0_loader: true
    val_datasets_1:
      names:
        - Franka_hf
      batch_size: 4
      drop_last: false
      fps: 4
      fpcs:
        - 8
      camera_views:
        - exterior_image_2_left
  # DataLoader configuration.
  # NOTE(review): batch_size is presumably per process; with 4 nodes x 8 tasks
  # that would be a global batch of 256. Confirm in the training loop.
  # num_workers matches cpus_per_task (16).
  loader:
    batch_size: 8
    num_workers: 16
    pin_mem: true
    persistent_workers: true
  # Custom dataset parameters (read because data.dataset_type is `custom`;
  # verify). Skips of 1 presumably keep every frame / action / state.
  # Entries set to null are left unset; presumably the loader falls back to
  # its own defaults for them. Verify.
  custom:
    split_ratio: null
    frameskip: 1
    action_skip: 1
    state_skip: 1
    normalize_action: false
    traj_subset: true
    filter_first_episodes: null
    filter_tasks: null
    num_hist: null
    num_pred: null
    with_reward: null
    custom_teleop_dset: null
  # DROID-specific parameters.
  droid:
    camera_frame: null
    camera_views: [left_mp4_path]
    droid_to_rcasa_action_format: 1
    rcasa_to_droid_action_format: null
    fps: 4
    dataset_fpcs: [8]
    # Glob-style patterns selecting the episode files, one per line so that
    # single entries can be commented out.
    # For counterfactual runs, comment out every pattern except liftcup_v0.
    # Keep the patterns quoted: a plain scalar starting with `*` is parsed
    # as a YAML alias.
    mpk_manifest_patterns:
      # pick
      - '**/pick/liftcup_v0/run_0001/episode.h5'
      - '**/pick/pickandplaceredcube_v0/run_0001/episode.h5'
      - '**/pick/pickcube_v0/run_0001/episode.h5'
      - '**/pick/pickpen_v0/run_0001/episode.h5'
      - '**/pick/pickupcup_v0/run_0001/episode.h5'
      - '**/pick/reachcup_v0/run_0001/episode.h5'
      - '**/pick/reachliftcup_v0/run_0001/episode.h5'
      - '**/pick/reachliftcup_v1/run_0001/episode.h5'
      # push
      - '**/push/brownboxpush_v0/run_0001/episode.h5'
      - '**/push/push_various_objects/blue_bowl/episode.h5'
      - '**/push/push_various_objects/blue_box/episode.h5'
      - '**/push/push_various_objects/cap/episode.h5'
      # NOTE(review): "pengiun" looks like a typo of "penguin". It is left
      # as-is because it must match the on-disk directory name; verify.
      - '**/push/push_various_objects/pengiun_plush/episode.h5'
      # folding
      - '**/folding/foldjacketsleeve_v0/run_0001/episode.h5'
      - '**/folding/foldjacketsleeve_v1/run_0001/episode.h5'
# Data augmentation. The boolean augmentations are off and reprob is 0; the
# resize ranges below have equal endpoints.
data_aug:
  auto_augment: false
  random_horizontal_flip: false
  motion_shift: false
  # Presumably [min, max] ranges; equal endpoints pin the aspect ratio at 1.0
  # and the scale at 1.777.
  # NOTE(review): 1.777 is presumably ~16/9; confirm intent.
  random_resize_aspect_ratio:
    - 1.0
    - 1.0
  random_resize_scale:
    - 1.777
    - 1.777
  reprob: 0.0
  # Per-channel RGB [mean, std]: the standard ImageNet statistics.
  normalize: [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]]
# Logging. Weights & Biases logging is on (project vjepa_wm). Media upload to
# W&B is disabled (disable_wandb_media) and media is written locally instead
# (log_media_locally).
logging:
  write_tag: jepa
  wandb:
    use_wandb: true
    debug: false
    project: vjepa_wm
    disable_wandb_media: true
    log_media_locally: true
# Loss weights: only the L1 term is active (weight 1.0); every other term is
# weighted 0 and the proprio loss is disabled.
loss:
  cos_loss_weight: 0.0
  l1_loss_weight: 1.0
  l2_loss_weight: 0.0
  smooth_l1_loss_weight: 0.0
  proprio_loss: false
meta:
  # Run-mode switches, all off here.
  # NOTE(review): each *_mode flag presumably turns the run into an
  # evaluation-only job; confirm.
  plan_only_eval_mode: false
  light_eval_only_mode: false
  unroll_decode_eval_only_mode: false
  quick_debug: false
  freeze_encoder: true
  load_checkpoint: true
  load_opt_scale_epoch: true
  read_checkpoint: null
  # Explicit null: a bare `pretrained_path:` also loads as null but reads like
  # a forgotten value.
  pretrained_path: null
  # Same value as data.seed; keep the two in sync.
  seed: 234
  eval_freq: 1
  light_eval_freq: 200
  save_every_freq: 1
  dtype: bfloat16
  # Rollout eval on dataset trajectories: 6 rollout steps with a context
  # window of 8, and ground-truth decoding enabled.
  data_traj_rollout_eval:
    do_data_traj_rollout_eval: true
    data_traj_eval_rollout_steps: 6
    data_traj_decode_gt: true
    data_traj_eval_ctxt_window: 8
  # Energy-landscape eval: 1 rollout step with a context window of 3.
  energy_landscape_eval:
    do_energy_landscape_eval: true
    energy_landscape_rollout_steps: 1
    energy_landscape_ctxt_window: 3
model:
  # Shared fields.
  # NOTE(review): grid_size 16 with data.img_size 256 presumably implies
  # 16-pixel patches. num_frames_pred (512) equals
  # visual_encoder.num_frames_enc (512) and is presumably an upper bound
  # rather than the clip length. Confirm both.
  grid_size: 16
  tubelet_size_enc: 1
  use_activation_checkpointing: false
  action_conditioning: token
  proprio_encoding: token
  num_frames_pred: 512
  # Visual encoder config: V-JEPA encoder (enc_version v2_open,
  # enc_name vit_giant_xformers).
  # pretrain_enc_path / pretrain_enc_ckpt_key presumably select the checkpoint
  # file and the state-dict entry (`encoder`) to load.
  # meta.freeze_encoder is true for this run.
  # embed_dim (1408) equals heads_cfg.architectures.image_head.config.embed_dim;
  # presumably the two must agree.
  visual_encoder:
    enc_type: vjepa
    enc_version: v2_open
    pretrain_enc_path: ${JEPAWM_OSSCKPT}/vjepa2_opensource/vjepa2_vit_giant.pth
    pretrain_enc_ckpt_key: encoder
    embed_dim: 1408
    enc_use_rope: true
    enc_name: vit_giant_xformers
    use_sdpa_enc: true
    num_frames_enc: 512
    uniform_power: true
  # Action encoder config (model.action_conditioning is `token`).
  # NOTE(review): action_emb_dim 0 presumably means "no separate embedding
  # width"; confirm in the model code.
  action_encoder:
    action_tokens: 1
    action_emb_dim: 0
    act_mlp: false
    action_encoder_inpred: true
  # Proprio encoder config (model.proprio_encoding is `token`); mirrors the
  # action encoder settings.
  proprio_encoder:
    proprio_tokens: 1
    proprio_emb_dim: 0
    prop_mlp: false
    proprio_encoder_inpred: true
  # Predictor config: pred_type vjepa2_ac, depth 24, 16 heads, width 1024
  # (head dim 1024 / 16 = 64), with RoPE enabled.
  predictor:
    tubelet_size: 1
    pred_num_heads: 16
    pred_depth: 24
    pred_embed_dim: 1024
    pred_use_extrinsics: false
    pred_type: vjepa2_ac
    act_pred_projector: null
    use_SiLU: null
    use_rope: true
  # Options for the VideoWM encoding() step.
  # NOTE(review): proprio_rollout_mode use_ground_truth presumably feeds
  # ground-truth proprio during rollouts instead of predicted proprio; confirm.
  wm_encoding:
    batchify_video: true
    dup_image: true
    normalize_reps: true
    proprio_rollout_mode: use_ground_truth
  # Rollout config: 2 sequential rollout steps during training with a context
  # window of 8. sampling_scheduler is linear from 0. to 0., i.e. a constant
  # value of 0.
  rollout_cfg:
    rollout_steps: 2
    train_rollout_prefixes: first
    rollout_stop_gradient: false
    ctxt_window_train_rollout: 8
    do_parallel_rollout: null
    do_sequential_rollout: true
    prepend_gt: null
    sampling_scheduler:
      type: linear
      start: 0.
      end: 0.
  # Attention config (passed as cfgs_attn_pattern to init_video_model).
  # local_window_time (8) equals ctxt_window_train_rollout.
  # NOTE(review): -1 for h/w presumably disables the spatial window; confirm.
  attn:
    local_window_time: 8
    local_window_h: -1
    local_window_w: -1
  # Heads config: one ViT image head and no state head (state_head paths are
  # null). config.embed_dim (1408) equals visual_encoder.embed_dim.
  # pretrain_dec_path.image_head presumably points at the decoder checkpoint
  # to load. optimization.train_heads is false, so the head is presumably
  # kept frozen; confirm.
  heads_cfg:
    architectures:
      image_head:
        kind: vit
        config:
          patch_size: 8
          in_chans: 3
          img_size: [256, 256]
          embed_dim: 1408
          decoder_embed_dim: 1024
          depth: 12
          num_heads: 16
          mlp_ratio: 4.0
          num_views: 1
          use_activation_checkpointing: false
    pretrain_dec_path:
      state_head: null
      image_head: ${JEPAWM_LOGS}/vm2m/opensource_decs/step2_lpips_vm2m_vjepa2vitgopen_vgxf_vitldec_dup_256_vjtransf_norm_bs4_4n/jepa-latest.pth.tar
    new_path_heads:
      state_head: null
      image_head: true
optimization:
  # A single optimizer over the transition_model parameter group; the decoder
  # heads are not trained (train_heads: false).
  main_optimizer: transition_model
  train_heads: false
  transition_model:
    # 300 iterations/epoch x 315 epochs.
    # NOTE(review): warmup and anneal_steps are presumably in epochs too;
    # confirm.
    iterations_per_epoch: 300
    ipe_scale: 1.0
    clip_grad: 10.0
    use_radamw: false
    betas: [0.9, 0.999]
    eps: 1.0e-8
    # weight_decay == final_weight_decay, so weight decay stays constant.
    weight_decay: 0.04
    final_weight_decay: 0.04
    num_epochs: 315
    warmup: 15
    anneal_steps: 15
    # WSD schedule (presumably warmup-stable-decay): start_lr -> ref_lr during
    # warmup, then annealed towards final_lr.
    use_wsd_schedule: true
    start_lr: 0.000075
    ref_lr: 0.000425
    final_lr: 0.0
    mixed_precision: true
evals:
  separate: true
  decode: true
  eval_episodes: 64  # 64 for DROID, 32 for rcasa
  nodes: 1  # 1 for DROID, 2 for rcasa
  low_pri: true
  obs: rgb_state
  alpha: 0
  override_cfgs_data: true  # always true
  override_datasets: true  # false to eval droid on rcasa
  dump_eval_configs: false
  wrapper_kwargs:
    ctxt_window: 2
    proprio_mode: compute_new_pose  # predict_proprio | compute_new_pose
  # Planning-eval config files to run, grouped by environment and (presumably)
  # planning method (gd / cem / ng / adam).
  # Commented-out entries are indented exactly like the active ones, so an
  # entry can be re-enabled just by deleting its `# `. A `- ...` line at a
  # different indent would not parse as a sibling list item.
  eval_cfg_paths:
    # Robocasa
    # - configs/online_plan_evals/rcasa_custom/gd/place_L1_gd_sourcedset_H3_nas1_maxnorm005_scaleact_repeat5_fskip5_max60_ctxt2.yaml
    # - configs/online_plan_evals/rcasa_custom/gd/place_L2_gd_sourcedset_H3_nas1_maxnorm005_scaleact_repeat5_fskip5_max60_ctxt2.yaml
    # - configs/online_plan_evals/rcasa_custom/gd/reach_L1_gd_sourcedset_H3_nas1_maxnorm005_scaleact_repeat5_fskip5_max60_ctxt2.yaml
    # - configs/online_plan_evals/rcasa_custom/gd/reach_L2_gd_sourcedset_H3_nas1_maxnorm005_scaleact_repeat5_fskip5_max60_ctxt2.yaml

    # - configs/online_plan_evals/rcasa_custom/reach_L1_cem_sourcedset_H3_nas1_maxnorm005_scaleact_repeat5_fskip5_max60_ctxt2.yaml
    # - configs/online_plan_evals/rcasa_custom/reach_L2_cem_sourcedset_H3_nas1_maxnorm005_scaleact_repeat5_fskip5_max60_ctxt2.yaml
    # - configs/online_plan_evals/rcasa_custom/place_L1_cem_sourcedset_H3_nas1_maxnorm005_scaleact_repeat5_fskip5_max60_ctxt2.yaml
    # - configs/online_plan_evals/rcasa_custom/place_L2_cem_sourcedset_H3_nas1_maxnorm005_scaleact_repeat5_fskip5_max60_ctxt2.yaml

    # - configs/online_plan_evals/rcasa_custom/ng/reach_L1_ng_sourcedset_H3_nas1_maxnorm005_scaleact_repeat5_fskip5_max60_ctxt2.yaml
    # - configs/online_plan_evals/rcasa_custom/ng/reach_L2_ng_sourcedset_H3_nas1_maxnorm005_scaleact_repeat5_fskip5_max60_ctxt2.yaml
    # - configs/online_plan_evals/rcasa_custom/ng/place_L1_ng_sourcedset_H3_nas1_maxnorm005_scaleact_repeat5_fskip5_max60_ctxt2.yaml
    # - configs/online_plan_evals/rcasa_custom/ng/place_L2_ng_sourcedset_H3_nas1_maxnorm005_scaleact_repeat5_fskip5_max60_ctxt2.yaml

    # - configs/online_plan_evals/rcasa_custom/adam/reach_L1_adam_sourcedset_H3_nas1_maxnorm005_scaleact_repeat5_fskip5_max60_ctxt2.yaml
    # - configs/online_plan_evals/rcasa_custom/adam/reach_L2_adam_sourcedset_H3_nas1_maxnorm005_scaleact_repeat5_fskip5_max60_ctxt2.yaml
    # - configs/online_plan_evals/rcasa_custom/adam/place_L1_adam_sourcedset_H3_nas1_maxnorm005_scaleact_repeat5_fskip5_max60_ctxt2.yaml
    # - configs/online_plan_evals/rcasa_custom/adam/place_L2_adam_sourcedset_H3_nas1_maxnorm005_scaleact_repeat5_fskip5_max60_ctxt2.yaml
    # DROID
    - configs/online_plan_evals/droid/ng/droid_L2_ng_sourcedset_H3_nas3_maxnorm01_ctxt2_gH3.yaml
    - configs/online_plan_evals/droid/droid_L2_cem_sourcedset_H3_nas3_maxnorm01_ctxt2_gH3.yaml
    - configs/online_plan_evals/droid/ng/droid_L1_ng_sourcedset_H3_nas3_maxnorm01_ctxt2_gH3.yaml
    - configs/online_plan_evals/droid/droid_L1_cem_sourcedset_H3_nas3_maxnorm01_ctxt2_gH3.yaml
    - configs/online_plan_evals/droid/gd/droid_L1_gd_sourcedset_H3_nas3_maxnorm01_ctxt2_gH3.yaml
    - configs/online_plan_evals/droid/gd/droid_L2_gd_sourcedset_H3_nas3_maxnorm01_ctxt2_gH3.yaml
    - configs/online_plan_evals/droid/adam/droid_L1_adam_sourcedset_H3_nas3_maxnorm01_ctxt2_gH3.yaml
    - configs/online_plan_evals/droid/adam/droid_L2_adam_sourcedset_H3_nas3_maxnorm01_ctxt2_gH3.yaml
# Settings for the unroll-and-decode evaluation.
# wrapper_kwargs carry the same values as evals.wrapper_kwargs (ctxt_window 2,
# compute_new_pose).
# NOTE(review): specific_video_path is presumably only read when
# specific_video is true; confirm.
unroll_decode_evals:
  wrapper_kwargs:
    ctxt_window: 2
    proprio_mode: compute_new_pose
  specific_video: false
  specific_video_path: app/vjepa2_ac/offline_eval/franka_example_traj.npz
  play_in_reverse: false
  save_decoding_only: true
  repeat_hardcode_act: 1
  obs: rgb_state
