# Experiment identity. `name` is reused below for the output directory and
# the W&B project; `version` is reused as the W&B run name.
name: svd-lrm-v1
tags:
  - svd
description: ""
# if not specified, will be set to version_{index}
version: "van_svd"
# Resolves to outputs/svd-lrm-v1 for this config.
output_dir: "outputs/${name}"

# Shared values, referenced from the dataset configs below through
# `${extras.*}` interpolation so the train/val/test splits stay in sync.
extras:
  # Used for both entries of every dataset's `img_wh`.
  resolution: 384
  src_views: 1
  target_views: 6
  bg_color: white
  root_dir: data/nerf/my_synthetic
  # Caption file; this path sits inside `root_dir`.
  caption_path: data/nerf/my_synthetic/caption.txt

# NOTE(review): presumably the global random seed applied by the training
# entry point — the consumer is not visible in this file; confirm.
seed: 42
# Disabled example of a `resume` entry pointing at a saved checkpoint file.
# resume: outputs/3dgs-fix/p0gao83a/checkpoints/epoch=189-step=19000.ckpt
# Data module configuration. Each `_target_` names the class to build
# (Hydra `_target_` convention). The three dataset splits share their
# settings through `${extras.*}` interpolation and differ only in the
# sampling keys (`repeat` for train, `num_samples` for val).
data:
  _target_: src.data.multiview_svd.MultiViewDataModule

  # Training split.
  train_dataset:
    _target_: src.data.multiview_svd.MultiViewDataset
    root_dir: ${extras.root_dir}
    src_views: ${extras.src_views}
    target_views: ${extras.target_views}
    bg_color: ${extras.bg_color}
    # Two entries, both set to the shared resolution (square images).
    # NOTE(review): the key name suggests (width, height) order — confirm.
    img_wh:
      - ${extras.resolution}
      - ${extras.resolution}
    include_src: true
    relative_pose: true
    caption_path: ${extras.caption_path}
    # num_samples: 10
    # NOTE(review): presumably repeats the dataset this many times per
    # epoch — confirm against MultiViewDataset.
    repeat: 100
  train_batch_size: 1

  # Validation split: same settings as training, without `repeat`.
  val_dataset:
    _target_: src.data.multiview_svd.MultiViewDataset
    root_dir: ${extras.root_dir}
    src_views: ${extras.src_views}
    target_views: ${extras.target_views}
    bg_color: ${extras.bg_color}
    img_wh:
      - ${extras.resolution}
      - ${extras.resolution}
    include_src: true
    relative_pose: true
    caption_path: ${extras.caption_path}
    # NOTE(review): presumably caps validation at two samples — confirm.
    num_samples: 2
  val_batch_size: 1

  # Test split: no `num_samples` or `repeat` override.
  test_dataset:
    _target_: src.data.multiview_svd.MultiViewDataset
    root_dir: ${extras.root_dir}
    src_views: ${extras.src_views}
    target_views: ${extras.target_views}
    bg_color: ${extras.bg_color}
    img_wh:
      - ${extras.resolution}
      - ${extras.resolution}
    include_src: true
    relative_pose: true
    caption_path: ${extras.caption_path}
  test_batch_size: 1

  # Loader settings passed to the data module.
  num_workers: 32
  pin_memory: true


# Training system configuration, built from `_target_`.
system:
  _target_: src.systems.mv_diffusion.svd_lrm_system.SVDSystem
  # Written with a decimal point and a signed exponent so YAML 1.1 loaders
  # (e.g. PyYAML) parse it as a float rather than a string.
  lr: 1.0e-5
  base_model_id: stabilityai/stable-video-diffusion-img2vid
  variant: fp16
  # NOTE(review): meaning of `cfg` is not visible in this file — presumably
  # a classifier-free-guidance related setting; confirm in SVDSystem.
  cfg: 0.2
  mv_model:
    _target_: src.models.unet.mv_unet.MVModel
    # Hydra `_partial_`: hand the system a partially-applied constructor
    # instead of an already-built MVModel instance.
    _partial_: true


# Lightning Trainer settings, built from `_target_`.
trainer:
  _target_: lightning.pytorch.trainer.Trainer
  default_root_dir: ${output_dir}

  # Hardware and numerics.
  accelerator: gpu
  devices: 1
  num_nodes: 1
  strategy: auto
  # mixed precision for extra speed-up
  precision: "16-mixed"

  # Training schedule.
  max_steps: 100000
  check_val_every_n_epoch: 5
  num_sanity_val_steps: 1
  gradient_clip_val: 1.0

  # Reporting.
  log_every_n_steps: 10
  enable_progress_bar: true

# Callback objects, keyed by a descriptive name; only `model_checkpoint`
# is active.
callbacks:
  model_checkpoint:
    _target_: lightning.pytorch.callbacks.ModelCheckpoint
    # -1 keeps every saved checkpoint instead of only the best k
    # (Lightning ModelCheckpoint convention).
    save_top_k: -1
    # Save a checkpoint every 20000 training steps.
    every_n_train_steps: 20000
  # Disabled callbacks, kept for reference:
  # point_cloud_checkpoint:
  #   _target_: src.utils.callbacks.PointCloudCallback
  # rich_progress_bar:
  #   _target_: lightning.pytorch.callbacks.RichProgressBar

# Logger objects; only the Weights & Biases logger is active.
logger:
  # Disabled TensorBoard logger, kept for reference:
  # tensorboard:
  #   _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
  #   save_dir: "${output_dir}"
  #   name: ""
  #   version: "${version}"
  #   sub_dir: "tb_logs"
  wandb:
    _target_: lightning.pytorch.loggers.wandb.WandbLogger
    # The W&B project is the experiment `name`; the run is labelled with
    # the top-level `version`.
    project: "${name}"
    save_dir: "outputs"
    name: "${version}"
    # To continue an existing W&B run, uncomment and set its run id:
    # id: p0gao83a
    # resume: 'must'