# Composition list (presumably a Hydra `defaults` list): values from
# default_dataset.yaml are combined with the keys defined in this file.
# NOTE(review): the merge order between that file and this one depends on the
# Hydra version and on whether `_self_` is listed — confirm before relying on
# one side overriding the other.
defaults:
  - default_dataset.yaml

# --- Experiment identity / reproducibility ---
# exp_name is interpolated into checkpoint.save_dir.
exp_name: exp001
seed_value: 42

# --- Input geometry ---
# Both values are forwarded to data.train / data.val common_config.
# 518 is an exact multiple of the patch size (37 * 14).
img_size: 518
patch_size: 14

# --- Loader and loop sizing ---
# num_workers is forwarded to data.train and data.val.
num_workers: 0
# Presumably the number of gradient-accumulation steps — confirm in the trainer.
accum_steps: 3
# Presumably caps on batches per train / val epoch — confirm in the trainer.
limit_train_batches: 800
limit_val_batches: 400

data:
  # TODO: the data configuration is still more complicated than it should be;
  # refactor it again when time allows.
  train:
    _target_: data.dynamic_dataloader.DynamicTorchDataset
    num_workers: ${num_workers}
    common_config:
      img_size: ${img_size}
      patch_size: ${patch_size}
      # NOTE(review): debug / repeat_batch read like debugging switches, and
      # repeat_batch is set for train only — confirm they are intended here.
      debug: true
      repeat_batch: true
    dataset:
      _target_: data.composed_dataset.ComposedDataset
      dataset_configs:
        - _target_: data.datasets.co3d.Co3dDataset
          split: train
          # Alternative dataset location (disabled):
          # CO3D_DIR: /checkpoint/repligen/shared/datasets/co3d/
          # CO3D_ANNOTATION_DIR: /checkpoint/repligen/jianyuan/datasets/co3d_anno
          # Active paths point at the `small_set` copy.
          CO3D_DIR: /fsx-repligen/jianyuan/transfer_buffer/small_set/co3d/
          CO3D_ANNOTATION_DIR: /fsx-repligen/jianyuan/transfer_buffer/small_set/co3d_anno
  val:
    _target_: data.dynamic_dataloader.DynamicTorchDataset
    num_workers: ${num_workers}
    common_config:
      img_size: ${img_size}
      patch_size: ${patch_size}
      debug: true
    dataset:
      _target_: data.composed_dataset.ComposedDataset
      dataset_configs:
        # Same dataset class and directories as train; only the split differs.
        - _target_: data.datasets.co3d.Co3dDataset
          split: test
          # Alternative dataset location (disabled):
          # CO3D_DIR: /checkpoint/repligen/shared/datasets/co3d/
          # CO3D_ANNOTATION_DIR: /checkpoint/repligen/jianyuan/datasets/co3d_anno
          CO3D_DIR: /fsx-repligen/jianyuan/transfer_buffer/small_set/co3d/
          CO3D_ANNOTATION_DIR: /fsx-repligen/jianyuan/transfer_buffer/small_set/co3d_anno


logging:
  # NOTE(review): log_dir does not include ${exp_name}, unlike
  # checkpoint.save_dir, so TensorBoard output is shared across experiments
  # unless overridden — confirm this is intended.
  log_dir: logs
  log_visuals: false
  log_freq: 1
  log_level_primary: DEBUG
  log_level_secondary: WARNING
  all_ranks: false
  tensorboard_writer:
    _target_: train_utils.tb_writer.TensorBoardLogger
    # Resolves relative to log_dir above.
    path: ${logging.log_dir}/tensorboard
  # Scalar names to log per phase; train and val currently list the same keys.
  scalar_keys_to_log:
    train:
      keys_to_log:
        - loss_objective
        - loss_camera
        - loss_T
        - loss_R
        - loss_FL
        - loss_conf_depth
        - loss_reg_depth
        - loss_grad_depth
    val:
      keys_to_log:
        - loss_objective
        - loss_camera
        - loss_T
        - loss_R
        - loss_FL
        - loss_conf_depth
        - loss_reg_depth
        - loss_grad_depth



checkpoint:
  # Checkpoints are grouped per experiment through ${exp_name}.
  save_dir: logs/${exp_name}/ckpts
  save_freq: 5
  # Absolute, cluster-specific path; override it when running elsewhere.
  resume_checkpoint_path: /fsx-repligen/jianyuan/transfer_buffer/ckpts/model.pt
  # Presumably selects non-strict state-dict loading — confirm in the trainer.
  strict: false


loss:
  _target_: loss.MultitaskLoss
  camera:
    weight: 5.0
    # The paper uses smooth l1 loss, but we found l1 loss is more stable than
    # smooth l1 and l2 loss.
    loss_type: 'l1'
  depth:
    weight: 1.0
    gradient_loss_fn: 'grad'
    valid_range: 0.98
  # The point term is disabled. To enable it, replace `point: null` with:
  #   point:
  #     weight: 1.0
  #     gradient_loss_fn: "normal"
  #     valid_range: 0.98
  point: null
  # The track term is disabled as well.
  track: null




# Optimizer, freezing, mixed precision, gradient clipping and schedules.
# Floats in scientific notation are written with an explicit decimal point
# (1.0e-4, not 1e-4): the dot-less form does not match the YAML 1.1 float
# pattern and is read as a string by loaders such as stock PyYAML.
optim:
  param_group_modifiers: false

  optimizer:
    _target_: torch.optim.AdamW
    lr: 1.0e-4
    weight_decay: 0.05

  frozen_module_names:
    - "*aggregator*"  # example, freeze the aggregator

  amp:
    enabled: true
    amp_dtype: bfloat16

  gradient_clip:
    _target_: train_utils.gradient_clip.GradientClipper
    # One entry per module group; all three use the same max_norm / norm_type.
    configs:
      - module_name: ["aggregator"]
        max_norm: 1.0  # feel free to reduce this if you see instabilities
        norm_type: 2
      - module_name: ["depth"]
        max_norm: 1.0  # feel free to reduce this if you see instabilities
        norm_type: 2
      - module_name: ["camera"]
        max_norm: 1.0  # feel free to reduce this if you see instabilities
        norm_type: 2

  options:
    lr:
      - scheduler:
          _target_: fvcore.common.param_scheduler.CompositeParamScheduler
          # Linear ramp 1.0e-8 -> 1.0e-4 over the first 5% of the schedule,
          # then cosine 1.0e-4 -> 1.0e-8 over the remaining 95%.
          schedulers:
            - _target_: fvcore.common.param_scheduler.LinearParamScheduler
              start_value: 1.0e-8
              end_value: 1.0e-4
            - _target_: fvcore.common.param_scheduler.CosineParamScheduler
              start_value: 1.0e-4
              end_value: 1.0e-8
          lengths: [0.05, 0.95]
          interval_scaling: ['rescaled', 'rescaled']
    weight_decay:
      - scheduler:
          _target_: fvcore.common.param_scheduler.ConstantParamScheduler
          # Same value as optimizer.weight_decay above.
          value: 0.05




# Presumably the total number of training epochs — the consumer is not visible
# in this file; confirm in the trainer.
max_epochs: 100

model:
  _target_: vggt.models.vggt.VGGT
  # Enabled heads mirror the loss section of this file: camera and depth have
  # loss settings, while point and track are null there and disabled here.
  enable_camera: true
  enable_depth: true
  enable_point: false
  enable_track: false


distributed:
  # check https://docs.pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html for options
  backend: nccl
  # YAML has no `None` literal: a bare None loads as the string "None".
  # `null` is used so the value is really empty, matching `point: null` /
  # `track: null` in the loss section.
  # NOTE(review): confirm the trainer tests this value with `is None` rather
  # than comparing against the string "None".
  comms_dtype: null
  find_unused_parameters: false
  timeout_mins: 30
  gradient_as_bucket_view: true  # Less memory used
  bucket_cap_mb: 25
  broadcast_buffers: true

# CUDA / cuDNN backend switches (2-space indent, consistent with the rest of
# this file).
cuda:
  cudnn_deterministic: false
  cudnn_benchmark: false
  allow_tf32: true