# Top-level job settings for the vjepa_wm app.
# NOTE(review): nodes / tasks_per_node / cpus_per_task look like cluster
# resource requests — confirm against the launcher that reads this file.
app: vjepa_wm
nodes: 2
tasks_per_node: 8
cpus_per_task: 16
# Output folder. ${JEPAWM_LOGS} is a placeholder that plain YAML does not
# expand — presumably substituted by the config loader or the environment.
folder: ${JEPAWM_LOGS}/mz/step2_mz_state_head_dinovits_tra05_r224
data:
  # --- Dataset selection ---
  dataset_type: custom
  datasets: [PointMaze]
  # Same value as meta.seed.
  seed: 234
  img_size: 224
  # --- Validation ---
  validation:
    val_datasets: null
    num_frames_val: 8
    val_dataset_fpcs: [8]
  # --- DataLoader ---
  loader:
    batch_size: 8
    num_workers: 16
    pin_mem: true
    persistent_workers: true
  # --- Simulation dataset parameters (the `custom` sub-section) ---
  custom:
    num_hist: 3
    num_pred: 1
    state_skip: 1
    frameskip: 5
    action_skip: 1
    traj_subset: true
    normalize_action: true
data_aug:
  # Two 3-element rows.
  # NOTE(review): presumably per-channel mean then std — confirm in the
  # transform code that consumes this key.
  normalize:
    - [0.5, 0.5, 0.5]
    - [0.5, 0.5, 0.5]
logging:
  write_tag: jepa
  # wandb settings. Booleans are written in canonical lowercase so they
  # match the rest of the file and resolve as booleans under every YAML
  # schema, not only the lenient ones.
  wandb:
    use_wandb: true
    debug: false
    project: vjepa_wm
    disable_wandb_media: true
    log_media_locally: true
# Loss term weights. Only the L2 term has a non-zero weight in this config.
loss:
  cos_loss_weight: 0.0
  l1_loss_weight: 0.0
  l2_loss_weight: 1.0
  smooth_l1_loss_weight: 0.0
# Run-level switches and schedule settings.
meta:
  plan_only_eval_mode: false
  light_eval_only_mode: false
  quick_debug: false
  # Encoder is marked frozen; optimization.main_optimizer targets state_head.
  freeze_encoder: true
  load_checkpoint: true
  load_opt_scale_epoch: true
  read_checkpoint: null
  # Same value as data.seed.
  seed: 234
  eval_freq: 1
  light_eval_freq: 2000
  # NOTE(review): -1 presumably disables periodic saving — confirm in the trainer.
  save_every_freq: -1
  dtype: bfloat16
  # Both optional evaluations below are switched off in this config.
  data_traj_rollout_eval:
    do_data_traj_rollout_eval: false
  energy_landscape_eval:
    do_energy_landscape_eval: false
    energy_landscape_rollout_steps: 1
    energy_landscape_ctxt_window: 3
model:
  # --- Fields shared across sub-modules ---
  # Note: the plain scalar `none` used below is a string, not YAML null.
  grid_size: 16
  tubelet_size_enc: 1
  use_activation_checkpointing: false
  action_conditioning: none
  proprio_encoding: none
  num_frames_pred: 4
  # --- Visual encoder ---
  visual_encoder:
    enc_type: dino
    enc_version: dinov2_vits14
    pretrain_enc_path: null
    embed_dim: 384
    enc_use_rope: true
    enc_name: null
    use_sdpa_enc: null
    num_frames_enc: 64
    uniform_power: true
  # --- Action encoder ---
  action_encoder:
    action_tokens: 0
    action_emb_dim: 10
    act_mlp: false
    action_encoder_inpred: false
  # --- Proprio encoder ---
  proprio_encoder:
    proprio_tokens: 0
    proprio_emb_dim: 0
    prop_mlp: false
    proprio_encoder_inpred: false
  # --- Predictor (pred_type is the string `none` in this config) ---
  predictor:
    tubelet_size: 1
    pred_num_heads: 16
    pred_depth: 6
    pred_embed_dim: 384
    pred_type: none
  # --- VideoWM encoding ---
  wm_encoding:
    batchify_video: true
    dup_image: false
    normalize_reps: false
  # --- Rollout ---
  rollout_cfg:
    rollout_steps: 1
    train_rollout_prefixes: random
    rollout_stop_gradient: true
    # Linear schedule whose start and end are both 0.0.
    sampling_scheduler:
      type: linear
      start: 0.0
      end: 0.0
  # --- Attention windows ---
  attn:
    local_window_time: 8
    local_window_h: -1
    local_window_w: -1
  # --- Heads ---
  heads_cfg:
    architectures:
      # The only head declared here; embed_dim equals visual_encoder.embed_dim.
      state_head:
        kind: vit
        config:
          state_dim: 4
          embed_dim: 384
          decoder_embed_dim: 384
          depth: 6
          num_heads: 16
          mlp_ratio: 4.0
          num_views: 1
          use_activation_checkpointing: false
    pretrain_dec_path: null
optimization:
  main_optimizer: state_head
  train_heads: true
  heads:
    train_predictor: false
    train_heads_on_predictor: false
    # Optimizer and schedule settings for the state head.
    state_head:
      use_radamw: false
      betas: [0.9, 0.99]
      # Written with a decimal point on purpose: YAML 1.1 loaders such as
      # PyYAML resolve `1e-8` as a string, while `1.0e-8` is a float under
      # every loader (same form as final_lr below).
      eps: 1.0e-8
      ipe_scale: 1.25
      weight_decay: 0.1
      final_weight_decay: 0.1
      final_lr: 1.0e-6
      start_lr: 0.0
      ref_lr: 0.0005
      warmup: 2
      num_epochs: 40
      iterations_per_epoch: 1000
      clip_grad: 1
      mixed_precision: true
# Evaluation settings.
# NOTE(review): eval_cfg_paths has no inline value on its line; if nothing
# follows it, YAML reads it as null rather than an empty list — confirm the
# consumer accepts that.
evals:
  eval_cfg_paths:
