# Compose the dataset config first. `_self_` listed last means every value in
# this file overrides the matching key from vggt_realbot_dataset.yaml (Hydra
# >= 1.1 default order, made explicit).
defaults:
  - vggt_realbot_dataset.yaml
  - _self_

# Experiment name; also used to build checkpoint.save_dir.
exp_name: vggt_realbot
# Input image size; 518 = 37 * 14, i.e. a whole number of patch_size patches.
img_size: 518
num_workers: 0
seed_value: 42
# NOTE(review): presumably gradient-accumulation steps per optimizer update —
# confirm against the trainer.
accum_steps: 3
patch_size: 14
limit_train_batches: 800
limit_val_batches: 0
# Views per sample; referenced by data.*.common_config and data.*.dataset.
view_num: 3

# Fix the aspect ratio to 480/640 = 0.75; the image size is 518 (img_size).
# Every per-split image count follows view_num, so a single override of
# view_num keeps train and val consistent.
data:
  train:
    common_config:
      fix_img_num: ${view_num}
      fix_aspect_ratio: 0.75
      img_nums:
        - ${view_num}
        - ${view_num}
    dataset:
      view_num: ${view_num}
  val:
    common_config:
      fix_img_num: ${view_num}
      fix_aspect_ratio: 0.75
      img_nums:
        - ${view_num}
        - ${view_num}
    dataset:
      view_num: ${view_num}
early_validation:
  # Run one short validation pass at training step 5, capped at 8 batches.
  enabled: true
  step: 5
  limit_batches: 8

logging:
  log_dir: logs
  log_visuals: true
  log_freq: 1
  log_level_primary: INFO
  log_level_secondary: WARNING
  all_ranks: false
  # TensorBoard output goes to <log_dir>/tensorboard.
  tensorboard_writer:
    _target_: train_utils.tb_writer.TensorBoardLogger
    path: ${logging.log_dir}/tensorboard

  # How often visuals are logged, per split.
  log_visual_frequency:
    train: 100
    val: 1

  # Prediction keys rendered as images; val additionally logs depth_conf.
  visuals_keys_to_log:
    train:
      keys_to_log:
        - depth
        - world_points
      modality: image
    val:
      keys_to_log:
        - depth
        - world_points
        - depth_conf
      modality: image

  visuals_per_batch_to_log: 4
  video_logging_fps: 10

  # Scalar metrics per split; train and val track the same keys.
  scalar_keys_to_log:
    train:
      keys_to_log:
        - loss_objective
        - loss_wrist
        - loss_wrist_FL
        - loss_projection
        - valid_track_points
        - total_track_points
        - depth_loss_count
        - uv_loss_sum
        - depth_loss_sum
    val:
      keys_to_log:
        - loss_objective
        - loss_wrist
        - loss_wrist_FL
        - loss_projection
        - valid_track_points
        - total_track_points
        - depth_loss_count
        - uv_loss_sum
        - depth_loss_sum

checkpoint:
  save_dir: logs/${exp_name}/ckpts
  save_freq: 5
  # Derived from save_dir so that overriding exp_name also moves the resume
  # path; with exp_name: vggt_realbot this is logs/vggt_realbot/ckpts/checkpoint.pt.
  resume_checkpoint_path: ${checkpoint.save_dir}/checkpoint.pt
  # NOTE(review): presumably forwarded as load_state_dict(strict=...), letting
  # the checkpoint load despite missing/unexpected keys — confirm.
  strict: false

# Active terms: the wrist loss (focal-length component only) and the
# projection loss. Camera, depth and point weights are 0.0; `track` is null.
loss:
  _target_: loss.MultitaskLoss
  camera:
    weight: 0.0
    loss_type: l1
  depth:
    weight: 0.0
    gradient_loss_fn: grad
    valid_range: 0.98
    mask_nan: true
    robust_loss: true
  point:
    weight: 0.0
    gradient_loss_fn: normal
    valid_range: 0.98
  track: null
  wrist:
    weight: 10.0
    loss_type: l1
    gamma: 0.6
    pose_encoding_type: absT_quaR_FoV
    # Translation and rotation are zero-weighted; only focal is supervised.
    weight_trans: 0.0
    weight_rot: 0.0
    weight_focal: 1.0
    mask_invalid: false
  projection:
    weight: 1.0
    depth_loss_weight: 1
    track_confidence_threshold: 0.1
    max_track_points: 1024

optim:
  param_group_modifiers: false

  optimizer:
    _target_: torch.optim.AdamW
    weight_decay: 0.05

  # Frozen modules: the whole aggregator plus the depth and point heads
  # (their pretrained weights are kept for visualization).
  # NOTE(review): model.enable_track is true, but track_head is neither listed
  # here nor matched by any optim.options.lr group — confirm how the optimizer
  # builder treats trainable parameters that no group covers.
  frozen_module_names:
    - "aggregator*"  # all aggregator-related modules
    - "depth_head*"
    - "point_head*"

  amp:
    enabled: true
    amp_dtype: bfloat16

  # Gradient clipping, applied only to the camera and wrist heads.
  gradient_clip:
    _target_: train_utils.gradient_clip.GradientClipper
    configs:
      - module_name: ["camera_head"]
        max_norm: 1.0
        norm_type: 2
      - module_name: ["wrist_head"]
        max_norm: 1.0
        norm_type: 2

  # Per-group learning rates: only wrist_head and camera_head get an LR group.
  # Each schedule warms up linearly over the first 5% of the schedule, then
  # cosine-decays over the remaining 95%.
  # LRs are written with an explicit mantissa (1.0e-8, not 1e-8): YAML 1.1
  # loaders such as plain PyYAML read `1e-8` as a string, not a float.
  options:
    lr:
      # Wrist head: peak LR 2e-5.
      - param_names: ["wrist_head*"]
        scheduler:
          _target_: fvcore.common.param_scheduler.CompositeParamScheduler
          schedulers:
            - _target_: fvcore.common.param_scheduler.LinearParamScheduler
              start_value: 1.0e-8
              end_value: 2.0e-5
            - _target_: fvcore.common.param_scheduler.CosineParamScheduler
              start_value: 2.0e-5
              end_value: 1.0e-8
          lengths: [0.05, 0.95]
          interval_scaling: ['rescaled', 'rescaled']

      # Camera head: peak LR 5e-6.
      - param_names: ["camera_head*"]
        scheduler:
          _target_: fvcore.common.param_scheduler.CompositeParamScheduler
          schedulers:
            - _target_: fvcore.common.param_scheduler.LinearParamScheduler
              start_value: 1.0e-9
              end_value: 5.0e-6
            - _target_: fvcore.common.param_scheduler.CosineParamScheduler
              start_value: 5.0e-6
              end_value: 1.0e-9
          lengths: [0.05, 0.95]
          interval_scaling: ['rescaled', 'rescaled']

    weight_decay:
      # One constant weight decay for all parameters.
      - scheduler:
          _target_: fvcore.common.param_scheduler.ConstantParamScheduler
          value: 0.05

# Maximum number of training epochs.
max_epochs: 100

# Keep every prediction head for inference/visualization. Only camera_head and
# wrist_head have learning-rate groups (optim.options.lr).
model:
  _target_: vggt.models.vggt.VGGT
  img_size: ${img_size}
  patch_size: ${patch_size}
  embed_dim: 1024
  enable_camera: true  # main camera pose prediction
  enable_depth: true   # depth prediction (kept for inference/visualization)
  enable_point: true   # point-cloud prediction (kept for inference/visualization)
  enable_track: true   # track prediction, enabled for evaluation
  enable_wrist: true   # wrist pose prediction

  # WristHead: use all tokens, aggregated by a multi-layer Transformer.
  wrist_head_config:
    token_aggregation: "attention"  # one of "attention", "mean", "weighted_mean"

    # Transformer token-aggregation settings.
    aggregation_num_layers: 3  # number of Transformer layers
    aggregation_num_heads: 8   # number of attention heads
    aggregation_dropout: 0.1   # dropout ratio

    # Legacy parameters.
    trunk_depth: 4
    num_heads: 16
    mlp_ratio: 4
    init_values: 0.01

  pretrained: "facebook/VGGT-1B"  # pretrained weights to load
  use_lora: false  # fine-tune directly, without LoRA
  # NOTE(review): presumably only read when use_lora is true — confirm.
  lora_rank: 16
  lora_alpha: 32

# Distributed training configuration (keys mirror torch
# DistributedDataParallel options — presumably forwarded to it; confirm).
distributed:
  backend: nccl
  # Use YAML `null` for "no override": a bare `None` loads as the string
  # "None", which an `is not None` check would treat as a real dtype.
  comms_dtype: null
  # NOTE(review): with this false, DDP errors if a trainable parameter receives
  # no gradient in a step (e.g. a head whose loss term is skipped) — confirm
  # every unfrozen head participates in the loss.
  find_unused_parameters: false
  timeout_mins: 30
  gradient_as_bucket_view: true
  bucket_cap_mb: 25
  broadcast_buffers: true

# CUDA / cuDNN backend flags.
cuda:
  cudnn_deterministic: false
  cudnn_benchmark: false
  allow_tf32: true