# Hydra defaults list: `_self_` is the only entry, so this file does not pull
# in any other config group.
defaults:
  - _self_
# Hydra's per-run output directory follows the `log_dir` key of this config.
hydra:
  run:
    dir: ${log_dir}
# Dotted path of the training agent class (presumably built through
# hydra.utils.instantiate/get_class by the launcher -- confirm there).
_target_: src.agent.train.TrainAgent

# Output directory: $VLA_LOG_DIR/train/<name>/<YYYY-MM-DD>_<HH-MM>_<seed>
log_dir: ${oc.env:VLA_LOG_DIR}/train/${name}/${now:%Y-%m-%d}_${now:%H-%M}_${seed}
# Run name: <dataset_mix>_<split>_tp<horizon_steps>_<flow_sampling>
name: ${data.train.dataset_mix}_${data.train.split}_tp${horizon_steps}_${flow_sampling}
device: cuda
n_nodes: 1
seed: 42
# Pretrained PaliGemma directory, resolved under $TRANSFORMERS_CACHE.
pretrained_model_path: ${oc.env:TRANSFORMERS_CACHE}/paligemma-3b-pt-224
load_pretrained_weights: true
# Unset (null) by default.
resume_checkpoint_path: null
train_vlm: true
# One of: adaLN, adaLN-Zero, or null (unset).
action_expert_adaptive_mode: null
use_torch_compile: true
use_bf16: true
use_amp: true
# `quantize`, `lora`, `lora_r` and `lora_dropout` are interpolated into the
# model sections (mixture / vision / vision_projector / joint).
quantize: false
lora: false
lora_r: 32
lora_dropout: 0.0
debug: false

# EMA settings (disabled by default).
use_ema: false
ema_decay: 0.99
# Starts at the same update at which checkpoint saving starts.
ema_start: ${save_model_start}
ema_freq: 1
ema_device: cuda
# SWA settings (disabled by default).
use_swa: false
# Five times (1550000 // global_batch_size) updates.
swa_start: ${eval:'1550000 // ${global_batch_size} * 5'}
# One quarter of (1550000 // global_batch_size) updates.
swa_freq: ${eval:'1550000 // ${global_batch_size} // 4'}
swa_device: cpu

data:
  # Validation data; still applying image randomization.
  val:
    # null => full val split
    split: null
    shuffle_buffer_size: 10000
  # Training data.
  train:
    dataset_mix: bridge
    split: train
    # Dataset root, taken from $VLA_DATA_DIR (trailing slash kept).
    data_path: ${oc.env:VLA_DATA_DIR}/
    # Observation window and action horizon follow the top-level
    # `cond_steps` / `horizon_steps` keys.
    window_size: ${cond_steps}
    action_horizon: ${horizon_steps}
    skip_unlabeled: true
    load_proprio: true
    shuffle_buffer_size: 200000
    num_parallel_calls: 100
    traj_transform_threads: 10
    traj_read_threads: 10

# Weights & Biases run identification.
wandb:
  # Entity is read from the VLA_WANDB_ENTITY environment variable.
  entity: ${oc.env:VLA_WANDB_ENTITY}
  project: open-pi-zero
  # Run label: launch time (HH-MM-SS) followed by the run name.
  run: ${now:%H-%M-%S}_${name}

log_freq: 16
n_epochs: 30  # provided ckpts were about 12 epochs
# The bridge dataset has 2195527 transitions in total, but (text-)unlabeled
# episodes are skipped, which is roughly 30% of the data -- hence the
# ~1550000 figure used below. (1550000 // global_batch_size) is therefore the
# approximate number of updates per epoch.
n_updates: ${eval:'1550000 // ${global_batch_size} * ${n_epochs}'}
# Save every (1550000 // global_batch_size) updates ...
save_model_freq: ${eval:'1550000 // ${global_batch_size} * 1'}
# ... starting after five times that many updates.
save_model_start: ${eval:'1550000 // ${global_batch_size} * 5'}
# NOTE(review): the original note here reads "per_device_batch_size" --
# presumably the frequency is counted in per-device batches rather than in
# global-batch updates; confirm against the trainer.
eval_freq: 8000
eval_size: 1024
eval_thresholds: [0.05, 0.1, 0.2, 0.3, 0.5]

# NOTE(review): global_batch_size is larger than per_device_batch_size, so
# the trainer presumably reaches it via multiple devices and/or gradient
# accumulation -- confirm in the training code.
global_batch_size: 1024
per_device_batch_size: 32
# NOTE: exponent literals without a decimal point (5e-5, 1e-8) are read as
# floats by OmegaConf's YAML loader, but plain PyYAML (YAML 1.1) would load
# them as strings. Keep loading this file through Hydra/OmegaConf.
action_lr: 5e-5
vlm_lr: 5e-5
# Separate learning-rate schedules for the action and VLM parameter groups;
# both currently use identical values.
action_lr_scheduler:
  first_cycle_steps: 10000000  # basically no decaying
  min_lr: 1e-8
  warmup_steps: 200  # a bit of warmup
vlm_lr_scheduler:
  first_cycle_steps: 10000000
  min_lr: 1e-8
  warmup_steps: 200
action_weight_decay: 0
vlm_weight_decay: 0
max_grad_norm: 1.0

# `flow_sampling` also appears in the run name.
flow_sampling: beta
num_inference_steps: 10
final_action_clip_value: 1.0  # data normalized in [-1,1]

# `cond_steps` feeds data.train.window_size; `horizon_steps` feeds
# data.train.action_horizon and the run name.
cond_steps: 1
horizon_steps: 4
action_dim: 7  # EEF_POS
proprio_dim: 7  # POS_EULER
max_seq_len: 276  # fixed 256 for image + max 20 for text
# Original note: "instead of truncating to longest".
# NOTE(review): presumably this means pad to max length rather than to the
# longest sequence in the batch -- confirm against the tokenizer call.
tokenizer_padding: max_length
max_image_text_tokens: ${max_seq_len}

# Per-expert settings; the whole node is interpolated into joint.config.mixture.
mixture:
  # gemma
  vlm:
    hidden_size: 2048
    intermediate_size: 16384
    use_final_norm: false
    cache: true
    use_quantize: ${quantize}
    use_lora: ${lora}
    # not applicable for gemma
    adaptive_mode: null
    # 10000 in gemma
    rope_theta: 10000.0
  proprio:
    hidden_size: 1024
    intermediate_size: 4096
    # technically no, but sharing weights with action anyway
    use_final_norm: true
    cache: true
    use_quantize: false
    use_lora: false
    adaptive_mode: ${action_expert_adaptive_mode}
    rope_theta: ${action_expert_rope_theta}
  action:
    hidden_size: 1024
    intermediate_size: 4096
    use_final_norm: true
    cache: false
    use_quantize: false
    use_lora: false
    adaptive_mode: ${action_expert_adaptive_mode}
    rope_theta: ${action_expert_rope_theta}
# only applicable if using adaptive
time_hidden_size: 256
time_max_period: 100.0
# since action/proprio seq_len is pretty small
action_expert_rope_theta: 100.0

# Fixed token constants -- not meant to be tuned.
# `pad_token_id` is also interpolated into joint.config.pad_token_id.
image_token_index: 257152
vocab_size: 257216
pad_token_id: 0

# Vision encoder: target class plus its config.
vision:
  _target_: src.model.paligemma.siglip.SiglipVisionModel
  config:
    hidden_size: 1152  # siglip
    intermediate_size: 4304
    num_hidden_layers: 27
    num_attention_heads: 16
    num_channels: 3
    image_size: 224
    patch_size: 14
    layer_norm_eps: 1e-6
    attention_dropout: 0.0
    # Consistent with (image_size / patch_size)^2 = (224 / 14)^2 = 256.
    num_image_tokens: 256
    # LoRA hyper-parameters follow the top-level lora_r / lora_dropout keys.
    lora:
      r: ${lora_r}
      dropout: ${lora_dropout}
  # Follow the top-level quantize / lora switches.
  use_quantize: ${quantize}
  use_lora: ${lora}

# Multi-modal projector: target class plus its config.
vision_projector:
  _target_: src.model.paligemma.siglip.PaliGemmaMultiModalProjector
  config:
    vision_config:
      # Same value as vision.config.hidden_size.
      hidden_size: 1152
      # Same value as mixture.vlm.hidden_size.
      projection_dim: 2048
    # LoRA hyper-parameters follow the top-level lora_r / lora_dropout keys.
    lora:
      r: ${lora_r}
      dropout: ${lora_dropout}
  # Follow the top-level quantize / lora switches.
  use_quantize: ${quantize}
  use_lora: ${lora}

# Joint model: target class plus its config.
joint:
  _target_: src.model.vla.joint_model.JointModel
  config:
    # Values forwarded from the top-level keys of the same names.
    action_expert_adaptive_mode: ${action_expert_adaptive_mode}
    time_hidden_size: ${time_hidden_size}
    mixture: ${mixture}
    lora:
      r: ${lora_r}
      dropout: ${lora_dropout}
    # Transformer / attention hyper-parameters set directly here.
    num_hidden_layers: 18
    num_attention_heads: 8
    num_key_value_heads: 1
    head_dim: 256
    rms_norm_eps: 1e-6
    attention_bias: false
    attention_dropout: 0.0
    pad_token_id: ${pad_token_id}
