---
# Job launch settings: 2 nodes x 8 tasks per node, 10 CPUs per task.
app: vjepa_wm
nodes: 2
tasks_per_node: 8
cpus_per_task: 10
# Run output folder. The run name appears to encode the settings below
# (4 frames, frameskip 5, action skip 1, 224px, dino_wm predictor depth 6,
# no proprio, 1 rollout step, 2 nodes).
# NOTE(review): ${JEPAWM_LOGS} is not expanded by YAML itself; it relies on the
# consuming launcher/config loader — confirm it is defined in the environment.
folder: ${JEPAWM_LOGS}/mz_sweep/mz_4f_fsk5_ask1_r224_pred_dino_wm_depth6_noprop_repro_1roll_save_2n
data:
  # Which dataset(s) to train on, how to weight them, and the input resolution.
  dataset_type: custom
  datasets: [PointMaze]
  datasets_weights: null
  seed: 234
  img_size: 224
  # Validation loaders (no validation datasets are listed in this config).
  validation:
    val_datasets: []
    val_datasets_1: null
    num_frames_val: 8
    val_dataset_batch_size: 4
    val_dataset_drop_last: false
    val_dataset_fpcs: [8]
    val_dataset_camera_views: null
    val_viz_rank0_loader: false
  # Training DataLoader settings.
  # NOTE(review): num_workers (16) is larger than the top-level cpus_per_task
  # (10); this may oversubscribe CPUs per task — confirm it is intended.
  loader:
    batch_size: 8
    num_workers: 16
    pin_mem: true
    persistent_workers: true
  # Options for the `custom` dataset type selected above.
  custom:
    split_ratio: 0.9
    frameskip: 5
    action_skip: 1
    state_skip: 1
    normalize_action: true
    traj_subset: true
    filter_first_episodes: null
    filter_tasks: null
    # num_hist + num_pred = 4, presumably matching model.num_frames_pred.
    num_hist: 3
    num_pred: 1
    with_reward: false
    custom_teleop_dset: null
  # Droid-only options (presumably ignored for dataset_type: custom).
  droid:
    camera_frame: false
    camera_views: [2]
    droid_to_rcasa_action_format: 1
    rcasa_to_droid_action_format: false
    fps: 4
    dataset_fpcs: [8]
    mpk_manifest_patterns: []
data_aug:
  # Augmentation is effectively off: every toggle is false, both resize ranges
  # are [1.0, 1.0] (presumably [min, max], i.e. no random resize), and reprob is 0.
  auto_augment: false
  random_horizontal_flip: false
  motion_shift: false
  random_resize_aspect_ratio: [1.0, 1.0]
  random_resize_scale: [1.0, 1.0]
  reprob: 0.0
  # Presumably per-channel [mean, std] for input normalization — verify in the consumer.
  normalize: [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]
logging:
  # Tag used when writing logs/checkpoints for this run.
  write_tag: jepa
  # Weights & Biases settings: logging is on, media upload to W&B is disabled,
  # and (per the flag name) media is logged locally instead.
  wandb:
    use_wandb: true
    debug: false
    project: vjepa_wm
    disable_wandb_media: true
    log_media_locally: true
loss:
  # Loss term weights: only the L2 term is active; all other terms have weight 0.
  cos_loss_weight: 0.0
  l1_loss_weight: 0.0
  l2_loss_weight: 1.0
  smooth_l1_loss_weight: 0.0
meta:
  # Eval-only / debug switches (all disabled here).
  plan_only_eval_mode: false
  light_eval_only_mode: false
  quick_debug: false
  freeze_encoder: true
  # Checkpoint loading.
  # NOTE(review): read_checkpoint is null, so the checkpoint location is
  # presumably derived from `folder` — confirm in the checkpoint loader.
  load_checkpoint: true
  load_opt_scale_epoch: true
  read_checkpoint: null
  # Explicit null: a bare `pretrained_path:` also parses as null but reads like
  # a forgotten value.
  pretrained_path: null
  seed: 234
  eval_freq: 1
  light_eval_freq: 300
  save_every_freq: 1
  dtype: bfloat16
  # Rollout evaluation on dataset trajectories; the context window (3) equals
  # data.custom.num_hist.
  data_traj_rollout_eval:
    do_data_traj_rollout_eval: true
    data_traj_eval_rollout_steps: 6
    data_traj_decode_gt: true
    data_traj_eval_ctxt_window: 3
  # Energy-landscape evaluation (disabled).
  energy_landscape_eval:
    do_energy_landscape_eval: false
    energy_landscape_rollout_steps: 1
    energy_landscape_ctxt_window: 3
model:
  # Shared fields
  # grid_size 16 = img_size 224 / patch size 14 of dinov2_vits14.
  grid_size: 16
  tubelet_size_enc: 1
  use_activation_checkpointing: false
  action_conditioning: feature
  proprio_encoding: feature
  num_frames_pred: 4
  # Visual encoder config (DINOv2 ViT-S/14; embed_dim 384 is the ViT-S width).
  # NOTE(review): pretrain_enc_path is null, so pretrain_enc_ckpt_key is
  # presumably unused for this encoder — confirm.
  visual_encoder:
    enc_type: dino
    enc_version: dinov2_vits14
    pretrain_enc_path: null
    pretrain_enc_ckpt_key: target_encoder
    embed_dim: 384
    enc_use_rope: null
    enc_name: null
    use_sdpa_enc: null
    num_frames_enc: null
    uniform_power: true
  # Action encoder config
  action_encoder:
    action_tokens: 0
    action_emb_dim: 10
    act_mlp: false
    action_encoder_inpred: false
  # Proprio encoder config (zero tokens and zero embedding dim, i.e. "noprop").
  # NOTE(review): proprio_encoding above is still `feature`; confirm the
  # consumer treats proprio_emb_dim: 0 as disabling proprio input.
  proprio_encoder:
    proprio_tokens: 0
    proprio_emb_dim: 0
    prop_mlp: false
    proprio_encoder_inpred: false
  # Predictor config (384 / 16 heads = 24 dims per head).
  predictor:
    tubelet_size: 1
    pred_num_heads: 16
    pred_depth: 6
    pred_embed_dim: 384
    pred_use_extrinsics: false
    pred_type: dino_wm
    act_pred_projector: false
    use_SiLU: false
    use_rope: false
    use_sdpa: true
  # VideoWM encoding()
  wm_encoding:
    batchify_video: true
    dup_image: false
    normalize_reps: false
  # Rollout config (single sequential rollout step during training).
  rollout_cfg:
    rollout_steps: 1
    train_rollout_prefixes: random
    rollout_stop_gradient: true
    ctxt_window_train_rollout: 3
    do_parallel_rollout: false
    do_sequential_rollout: true
    prepend_gt: false
    sampling_scheduler:
      type: linear
      start: 0.
      end: 0.
  # Attention config (passed as cfgs_attn_pattern to init_video_model)
  attn:
    local_window_time: 3
    local_window_h: -1
    local_window_w: -1
  # Decoder heads config (optional - for visualization only)
  # To enable, copy heads_cfg from your trained decoder config in configs/vjepa_wm/vm2m/open_source_decs/
  # NOTE(review): architectures is empty while new_path_heads enables both
  # heads — confirm the consumer skips heads that have no architecture.
  heads_cfg:
    architectures: {}
    # No pretrained decoder heads are loaded. To load them, replace `null` with
    # a mapping such as:
    #   pretrain_dec_path:
    #     state_head: ${JEPAWM_LOGS}/mz/step2_mz_state_head_dinovits_r224/jepa-latest.pth.tar
    #     image_head: ${JEPAWM_LOGS}/mz/step2_mz_vitsdec_dinovitsenc_224_beta0.95/jepa-latest.pth.tar
    pretrain_dec_path: null
    new_path_heads:
      state_head: true
      image_head: true
optimization:
  # The transition model (predictor) is the optimized module; decoder heads are not trained.
  main_optimizer: transition_model
  train_heads: false
  transition_model:
    iterations_per_epoch: null
    ipe_scale: 1.0
    clip_grad: 1
    # AdamW-style settings. Keep a decimal point in exponent floats: YAML 1.1
    # resolvers such as PyYAML read `5e-4` (no dot) as a string.
    use_radamw: false
    betas: [0.9, 0.999]
    eps: 1.0e-8
    # Weight decay goes from weight_decay to final_weight_decay (presumably
    # scheduled over training — confirm in the scheduler).
    weight_decay: 1.0e-7
    final_weight_decay: 1.0e-6
    num_epochs: 50
    # Constant learning rate: no warmup, and start == ref == final.
    warmup: 0
    start_lr: 5.0e-4
    ref_lr: 5.0e-4
    final_lr: 5.0e-4
    mixed_precision: true
evals:
  # Planning-evaluation launch settings.
  dump_eval_configs: false
  decode: false
  eval_episodes: 96
  nodes: 1
  low_pri: false
  obs: rgb
  alpha: 0
  # Planning eval configs to run. Each name encodes the cost (L1/L2), the planner
  # (gd/ng/cem), random-state sources, H6, nas6 and ctxt2.
  # NOTE(review): the cem configs sit directly under mz/ while the ng configs are
  # under mz/ng/ — confirm those paths exist.
  eval_cfg_paths:
    # rand state
    # - configs/online_plan_evals/mz/gd/mz_L1_gd_sourcerandstate_H6_nas6_ctxt2.yaml
    # - configs/online_plan_evals/mz/gd/mz_L2_gd_sourcerandstate_H6_nas6_ctxt2.yaml
    - configs/online_plan_evals/mz/ng/mz_L2_ng_sourcerandstate_H6_nas6_ctxt2.yaml
    - configs/online_plan_evals/mz/mz_L2_cem_sourcerandstate_H6_nas6_ctxt2.yaml
    - configs/online_plan_evals/mz/ng/mz_L1_ng_sourcerandstate_H6_nas6_ctxt2.yaml
    - configs/online_plan_evals/mz/mz_L1_cem_sourcerandstate_H6_nas6_ctxt2.yaml
