# Base config(s) composed with this file. The `defaults` key together with the
# `_target_` entries below suggests Hydra-style composition — confirm against
# the launcher before renaming or reordering entries.
defaults:
  - vggt_video_dataset.yaml

# Run-level settings.
# `exp_name` is interpolated into `checkpoint.save_dir` further down.
exp_name: vggt_video_input
img_size: 518  # Square image size; must be divisible by patch_size=14 (518 = 14 × 37)
num_workers: 0  # Training workers
seed_value: 42
accum_steps: 3
patch_size: 14
limit_train_batches: 800
limit_val_batches: 400

# Video input mode configuration.
# When enabled, train_set = val_set = all frames of the videos, and other GT
# data (point cloud, depth, etc.) will be filled with zeros.
video_input:
  # Enable video input mode
  enabled: true
  # First input video path (ext1)
  ext1_video_path: "XXX/data_processed/droid_autolab_success_move_1k/ext1/AUTOLab_success_xxx_ext1.mp4"
  # Second input video path (ext2)
  ext2_video_path: "XXX/data_processed/droid_autolab_success_move_1k/ext2/AUTOLab_success_xxx_ext2.mp4"
  # GT wrist video path
  wrist_video_path: "XXX/data_processed/droid_autolab_success_move_1k/wrist/AUTOLab_success_xxx_wrist.mp4"

# Early validation configuration
early_validation:
  # Enable early validation
  enabled: true
  # Perform early validation at step 5
  step: 5
  # Limit batch count for early validation
  limit_batches: 50

# Logging: log levels, TensorBoard writer, and which visuals/scalars get logged.
logging:
  log_dir: logs
  # Enable visualization
  log_visuals: true
  log_freq: 1
  # Reduce debug noise
  log_level_primary: INFO
  log_level_secondary: WARNING
  all_ranks: false
  tensorboard_writer:
    _target_: train_utils.tb_writer.TensorBoardLogger
    # Resolves relative to `logging.log_dir` above
    path: ${logging.log_dir}/tensorboard

  # Visualization options
  log_visual_frequency:
    # Log visuals every 100 steps for training
    train: 100
    # Log visuals every epoch for validation
    val: 1

  visuals_keys_to_log:
    train:
      # Keep visualization
      keys_to_log:
        - "depth"
        - "world_points"
      modality: "image"
    val:
      # Keep visualization
      keys_to_log:
        - "depth"
        - "world_points"
        - "depth_conf"
      modality: "image"

  visuals_per_batch_to_log: 4
  # FPS for video logs
  video_logging_fps: 10

  # Scalars to log; the train and val lists below are identical.
  scalar_keys_to_log:
    train:
      keys_to_log:
        - loss_objective
        # Keep camera-related losses
        - loss_camera
        - loss_T
        - loss_R
        - loss_FL
        # Remove depth-related losses
        # - loss_conf_depth
        # - loss_reg_depth
        # - loss_grad_depth
        # Remove point-related losses
        # - loss_conf_point
        # - loss_reg_point
        # - loss_grad_point
        # Keep wrist-related losses
        - loss_wrist
        - loss_wrist_T
        - loss_wrist_R
        - loss_wrist_FL
        # Add projection loss log
        - loss_projection
        # Add valid track points count
        - valid_track_points
        # Add total track points count
        - total_track_points
        # Add depth loss count
        - depth_loss_count
        # Add UV loss sum
        - uv_loss_sum
        # Add depth loss sum
        - depth_loss_sum
    val:
      keys_to_log:
        - loss_objective
        # Keep camera-related losses
        - loss_camera
        - loss_T
        - loss_R
        - loss_FL
        # Remove depth-related losses
        # - loss_conf_depth
        # - loss_reg_depth
        # - loss_grad_depth
        # Remove point-related losses
        # - loss_conf_point
        # - loss_reg_point
        # - loss_grad_point
        # Keep wrist-related losses
        - loss_wrist
        - loss_wrist_T
        - loss_wrist_R
        - loss_wrist_FL
        # Add projection loss log
        - loss_projection
        # Add valid track points count
        - valid_track_points
        # Add total track points count
        - total_track_points
        # Add depth loss count
        - depth_loss_count
        # Add UV loss sum
        - uv_loss_sum
        # Add depth loss sum
        - depth_loss_sum

# Checkpoint saving and resuming.
checkpoint:
  # Output directory, derived from `exp_name`
  save_dir: logs/${exp_name}/ckpts
  save_freq: 5
  # Resumes from a checkpoint stored under a different experiment directory
  # (vggt_droid_camera_only_finetune), not under this run's `save_dir`.
  resume_checkpoint_path: logs/vggt_droid_camera_only_finetune/ckpts/checkpoint.pt
  # NOTE(review): presumably controls strict state-dict matching on load — confirm in the trainer.
  strict: false

# Multi-task loss: camera and depth terms are configured; point and track are null.
loss:
  _target_: loss.MultitaskLoss
  camera:
    weight: 5.0
    # The paper uses smooth l1 loss, but we found l1 loss is more stable than
    # smooth l1 and l2 loss.
    loss_type: "l1"
  # NOTE(review): the `video_input` section says other GT data (point cloud,
  # depth, etc.) is filled with zeros in video input mode — confirm the depth
  # term is masked out or intentionally active in that mode.
  depth:
    weight: 1.0
    gradient_loss_fn: "grad"
    valid_range: 0.98
  # If you want to enable point, replace `point: null` with the following config
  # point:
  #   weight: 1.0
  #   gradient_loss_fn: "normal"
  #   valid_range: 0.98
  point: null
  track: null




# Optimisation: optimizer, frozen modules, mixed precision, gradient clipping,
# and per-hyperparameter schedulers.
optim:
  param_group_modifiers: false

  optimizer:
    _target_: torch.optim.AdamW
    lr: 1e-4
    weight_decay: 0.05

  # example, freeze the aggregator
  frozen_module_names:
    - "*aggregator*"

  amp:
    enabled: true
    amp_dtype: bfloat16

  # One clipping entry per module group; every group uses max_norm 1.0 with
  # the 2-norm. Reduce max_norm if instabilities occur.
  gradient_clip:
    _target_: train_utils.gradient_clip.GradientClipper
    configs:
      - module_name: ["aggregator"]
        max_norm: 1.0
        norm_type: 2
      - module_name: ["depth"]
        max_norm: 1.0
        norm_type: 2
      - module_name: ["camera"]
        max_norm: 1.0
        norm_type: 2
      # Gradient clipping for point module
      - module_name: ["point"]
        max_norm: 1.0
        norm_type: 2
      # Gradient clipping for wrist module
      - module_name: ["wrist"]
        max_norm: 1.0
        norm_type: 2

  options:
    # Learning rate: linear ramp 1e-8 -> 1e-4 over the first 0.05 of the
    # schedule, then cosine 1e-4 -> 1e-8 over the remaining 0.95.
    lr:
      - scheduler:
          _target_: fvcore.common.param_scheduler.CompositeParamScheduler
          schedulers:
            - _target_: fvcore.common.param_scheduler.LinearParamScheduler
              start_value: 1e-8
              end_value: 1e-4
            - _target_: fvcore.common.param_scheduler.CosineParamScheduler
              start_value: 1e-4
              end_value: 1e-8
          lengths: [0.05, 0.95]
          interval_scaling: ["rescaled", "rescaled"]
    # Weight decay: held constant at the optimizer's initial value.
    weight_decay:
      - scheduler:
          _target_: fvcore.common.param_scheduler.ConstantParamScheduler
          value: 0.05




max_epochs: 100

# Model heads: camera, depth and point are enabled; track is disabled.
model:
  _target_: vggt.models.vggt.VGGT
  enable_camera: true
  enable_depth: true
  # Video input mode needs point_head to produce world_points
  enable_point: true
  enable_track: false


# Distributed training options.
distributed:
  # check https://docs.pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html for options
  backend: nccl
  # NOTE(review): unquoted `None` is loaded by YAML as the string "None", not
  # as null. Left unchanged here — confirm what the consumer expects before
  # switching it to `null`.
  comms_dtype: None
  find_unused_parameters: false
  timeout_mins: 30
  # Less memory used
  gradient_as_bucket_view: true
  bucket_cap_mb: 25
  broadcast_buffers: true

# CUDA backend toggles.
# NOTE(review): presumably forwarded to the matching torch.backends flags — confirm in the trainer.
cuda:
  cudnn_deterministic: false
  cudnn_benchmark: false
  allow_tf32: true