defaults:
  - vggt_calvin_datasets.yaml

exp_name: vggt_calvin
img_size: 518  # Square image size; must be divisible by patch_size=14 (518 = 14 × 37)
num_workers: 0  # Training dataloader workers
seed_value: 42
accum_steps: 3
patch_size: 14
limit_train_batches: 800
limit_val_batches: 50
view_num: 1

# Number of views is controlled by view_num; the data pipeline and model support an arbitrary number of views.

# Early validation: a short, batch-limited validation pass at a fixed training step
early_validation:
  enabled: true  # Toggle the early validation pass
  step: 5  # Step at which early validation runs
  limit_batches: 50  # Maximum number of batches for the early validation pass

logging:
  log_dir: logs
  log_visuals: true  # Enable visualization logging
  log_freq: 1
  log_level_primary: INFO  # INFO keeps debug noise out of the primary log
  log_level_secondary: WARNING
  all_ranks: false
  tensorboard_writer:
    _target_: train_utils.tb_writer.TensorBoardLogger
    path: ${logging.log_dir}/tensorboard

  # How often visualizations are logged, per phase
  log_visual_frequency:
    train: 100  # Log training visualizations every 100 steps
    val: 1  # Log validation visualizations every epoch

  # Prediction outputs rendered as images, per phase
  visuals_keys_to_log:
    train:
      keys_to_log: ["depth", "world_points"]
      modality: "image"
    val:
      keys_to_log: ["depth", "world_points", "depth_conf"]
      modality: "image"

  visuals_per_batch_to_log: 4
  video_logging_fps: 10  # Frame rate for logged videos

  # Scalar metrics to log. The train and val lists are currently identical;
  # keep them in sync unless a metric is meant for only one phase.
  scalar_keys_to_log:
    train:
      keys_to_log:
        - loss_objective
        - loss_camera  # Camera losses
        - loss_FL
        - loss_wrist  # Wrist losses
        - loss_wrist_FL
        - loss_projection  # Projection loss
        - valid_track_points  # Valid track point count
        - total_track_points  # Total track point count
        - depth_loss_count  # Depth loss count
        - uv_loss_sum  # UV loss sum
        - depth_loss_sum  # Depth loss sum
    val:
      keys_to_log:
        - loss_objective
        - loss_camera  # Camera losses
        - loss_FL
        - loss_wrist  # Wrist losses
        - loss_wrist_FL
        - loss_projection  # Projection loss
        - valid_track_points  # Valid track point count
        - total_track_points  # Total track point count
        - depth_loss_count  # Depth loss count
        - uv_loss_sum  # UV loss sum
        - depth_loss_sum  # Depth loss sum

checkpoint:
  save_dir: logs/${exp_name}/ckpts
  save_freq: 5
  # Derived from save_dir so the resume path follows exp_name overrides instead
  # of always pointing at the vggt_calvin run. To resume from a different
  # experiment, override this key with an explicit path.
  resume_checkpoint_path: ${checkpoint.save_dir}/checkpoint.pt
  strict: false

# Loss configuration: camera, wrist and projection losses are active; depth and
# point losses are disabled with zero weights, and the track loss is off.
loss:
  _target_: loss.MultitaskLoss
  camera:
    weight: 5.0
    loss_type: "l1"

  # Depth loss disabled (weight 0); the model still produces depth for visualization
  depth:
    weight: 0.0
    gradient_loss_fn: "grad"
    valid_range: 0.98
    mask_nan: true
    robust_loss: true

  # Point loss disabled (weight 0); the model still produces points for visualization
  point:
    weight: 0.0
    gradient_loss_fn: "normal"
    valid_range: 0.98

  track: null  # Track loss disabled

  # Wrist pose supervision. weight_trans and weight_rot are 0.0, so presumably
  # only the focal component contributes (assuming components are combined by
  # these weights — confirm in loss.MultitaskLoss).
  wrist:
    weight: 10.0  # Overall wrist loss weight
    loss_type: "l1"
    gamma: 0.6
    pose_encoding_type: "absT_quaR_FoV"
    weight_trans: 0.0  # Translation component weight
    weight_rot: 0.0  # Rotation component weight
    weight_focal: 1.0  # Focal component weight
    mask_invalid: false  # Whether to mask invalid wrist pose samples (currently off)

  # Projection loss
  projection:
    weight: 1.0  # Projection loss weight
    depth_loss_weight: 1  # Weight of the depth term applied when Z is negative
    track_confidence_threshold: 0.2  # Confidence threshold for tracked points
    max_track_points: 1024  # Maximum number of tracked points

optim:
  param_group_modifiers: false

  optimizer:
    _target_: torch.optim.AdamW
    weight_decay: 0.05

  # Frozen modules: the aggregator plus the depth and point heads. The heads stay
  # in the model for visualization but receive no parameter updates.
  frozen_module_names:
    - "aggregator*"
    - "depth_head*"
    - "point_head*"

  amp:
    enabled: true
    amp_dtype: bfloat16

  # Gradient clipping is configured only for the camera and wrist heads
  gradient_clip:
    _target_: train_utils.gradient_clip.GradientClipper
    configs:
      - module_name: ["camera_head"]
        max_norm: 1.0
        norm_type: 2
      - module_name: ["wrist_head"]
        max_norm: 1.0
        norm_type: 2

  # Per-module learning-rate schedules for the camera and wrist heads:
  # linear warmup over the first 5% of training, then cosine decay.
  # Scientific-notation values carry a decimal point (1.0e-8, not 1e-8) because
  # YAML 1.1 loaders such as plain PyYAML read `1e-8` as a string.
  # NOTE(review): the model enables track prediction (enable_track), but no
  # track-head pattern appears in frozen_module_names or in the param_names
  # below — confirm those parameters are frozen or grouped as intended.
  options:
    lr:
      # Wrist head: peak LR 2e-5
      - param_names: ["wrist_head*"]
        scheduler:
          _target_: fvcore.common.param_scheduler.CompositeParamScheduler
          schedulers:
            - _target_: fvcore.common.param_scheduler.LinearParamScheduler
              start_value: 1.0e-8
              end_value: 2.0e-5
            - _target_: fvcore.common.param_scheduler.CosineParamScheduler
              start_value: 2.0e-5
              end_value: 1.0e-8
          lengths: [0.05, 0.95]
          interval_scaling: ['rescaled', 'rescaled']

      # Camera head: peak LR 5e-6
      - param_names: ["camera_head*"]
        scheduler:
          _target_: fvcore.common.param_scheduler.CompositeParamScheduler
          schedulers:
            - _target_: fvcore.common.param_scheduler.LinearParamScheduler
              start_value: 1.0e-9
              end_value: 5.0e-6
            - _target_: fvcore.common.param_scheduler.CosineParamScheduler
              start_value: 5.0e-6
              end_value: 1.0e-9
          lengths: [0.05, 0.95]
          interval_scaling: ['rescaled', 'rescaled']

    weight_decay:
      # Constant weight decay for all parameters
      - scheduler:
          _target_: fvcore.common.param_scheduler.ConstantParamScheduler
          value: 0.05

# Maximum number of training epochs
max_epochs: 100

# Model: every prediction head stays enabled for inference and visualization.
# Which parameters actually train is set by optim.frozen_module_names, the
# optim LR groups (camera and wrist heads) and the loss weights.
model:
  _target_: vggt.models.vggt.VGGT
  img_size: ${img_size}
  patch_size: ${patch_size}
  embed_dim: 1024
  enable_camera: true  # Main camera pose prediction
  enable_depth: true  # Depth prediction (inference and visualization)
  enable_point: true  # Point cloud prediction (inference and visualization)
  enable_track: true  # Track prediction, enabled for evaluation
  enable_wrist: true  # Wrist pose prediction

  # WristHead: uses all tokens, aggregated by a multi-layer Transformer
  wrist_head_config:
    token_aggregation: "attention"  # One of "attention", "mean", "weighted_mean"

    # Transformer token aggregation
    aggregation_num_layers: 3  # Number of Transformer layers
    aggregation_num_heads: 8  # Number of attention heads
    aggregation_dropout: 0.1  # Dropout rate

    # Pre-existing WristHead parameters
    trunk_depth: 4
    num_heads: 16
    mlp_ratio: 4
    init_values: 0.01

  pretrained: "facebook/VGGT-1B"  # Pretrained model to load
  use_lora: false  # No LoRA; fine-tune directly
  # LoRA hyperparameters; presumably unused while use_lora is false
  lora_rank: 16
  lora_alpha: 32

# Distributed training
distributed:
  backend: nccl
  # `None` is not a YAML null (it loads as the string "None"); `null` yields an
  # actual None.
  comms_dtype: null
  # NOTE(review): several modules are frozen and only some heads feed the loss;
  # confirm every trainable parameter gets a gradient each step, otherwise DDP
  # requires find_unused_parameters: true.
  find_unused_parameters: false
  timeout_mins: 30
  gradient_as_bucket_view: true
  bucket_cap_mb: 25
  broadcast_buffers: true

# CUDA backend configuration
cuda:
  cudnn_deterministic: false
  cudnn_benchmark: false
  allow_tf32: true