# Run identity: `name` feeds output_dir below and the wandb project name.
name: test
tags:
  - lvis96v
description: ''
# If not specified, will be set to version_{index}
# (NOTE(review): that defaulting happens outside this file).
version: null
output_dir: 'outputs/${name}'

seed: 42
# NOTE(review): presumably a checkpoint path to resume from; null starts fresh.
resume: null

data:
  _target_: src.data.multiview.MultiViewDataModule
  # The train/val/test datasets share root_dir, num_views, bg_color, img_wh
  # and k_near_views; they differ only in sample_views_mode (random for
  # training, lay4 for validation/test) and in which caption split they read.
  train_dataset:
    _target_: src.data.multiview.MultiViewDataset
    root_dir: /mnt/pfs/data/render_lvis_hzh
    num_views: 16
    bg_color: white
    img_wh: [256, 256]
    k_near_views: 16
    sample_views_mode: random
    caption_path: /mnt/pfs/users/huangzehuan/project/mvgen/data/lvis96v/caption_train.txt
  train_batch_size: 1
  val_dataset:
    _target_: src.data.multiview.MultiViewDataset
    root_dir: /mnt/pfs/data/render_lvis_hzh
    num_views: 16
    bg_color: white
    img_wh: [256, 256]
    k_near_views: 16
    sample_views_mode: lay4
    caption_path: /mnt/pfs/users/huangzehuan/project/mvgen/data/lvis96v/caption_val.txt
  val_batch_size: 1
  test_dataset:
    _target_: src.data.multiview.MultiViewDataset
    root_dir: /mnt/pfs/data/render_lvis_hzh
    num_views: 16
    bg_color: white
    img_wh: [256, 256]
    k_near_views: 16
    sample_views_mode: lay4
    caption_path: /mnt/pfs/users/huangzehuan/project/mvgen/data/lvis96v/caption_test.txt
  test_batch_size: 1
  # NOTE(review): 64 loader workers for batch size 1 is heavy; confirm it
  # matches the host's CPU budget.
  num_workers: 64
  # Canonical lowercase boolean, consistent with the rest of this file.
  pin_memory: true

model:
  _target_: src.systems.i2mv_system.I2MVSystem
  model:
    _target_: src.models.mv_model.MVModel
    base_model_id: "bennyguo/zero123-xl-diffusers"
    variant: "fp16_ema"
    insert_stages: ["mid", "up"]
    insert_up_layers: [0, 1, 2, 3]
    use_residual: false
  # Written with a decimal point on purpose: a bare `1e-5` does not match the
  # YAML 1.1 float pattern, so plain PyYAML loads it as the string "1e-5".
  # `1.0e-5` is a float under both YAML 1.1 and 1.2 loaders.
  lr: 1.0e-5
  # NOTE(review): presumably the condition-drop probability for
  # classifier-free guidance; confirm against I2MVSystem.
  cfg: 0.1
  report_to: wandb

trainer:
  _target_: lightning.pytorch.trainer.Trainer
  # Default location for logs/weights when no other path is supplied.
  default_root_dir: ${output_dir}
  # Run length and evaluation/logging cadence.
  max_steps: 20001
  val_check_interval: 0.5  # a float is a fraction of a training epoch
  log_every_n_steps: 10
  num_sanity_val_steps: 1
  enable_progress_bar: true
  # Hardware layout: one GPU on one node, launched with the DDP strategy.
  strategy: ddp
  accelerator: gpu
  devices: 1
  num_nodes: 1
  # Mixed precision for extra speed-up.
  precision: '16-mixed'

callbacks:
  # Save a checkpoint every 5000 training steps and keep all of them
  # (save_top_k: -1 disables pruning).
  model_checkpoint:
    _target_: lightning.pytorch.callbacks.ModelCheckpoint
    save_top_k: -1
    every_n_train_steps: 5000
  # Rich-based progress bar in place of the default one.
  rich_progress_bar: {_target_: lightning.pytorch.callbacks.RichProgressBar}

logger:
  # Alternative TensorBoard logger, kept disabled for reference:
  # tensorboard:
  #   _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
  #   save_dir: "${output_dir}"
  #   name: ""
  #   version: "${version}"
  #   sub_dir: "tb_logs"
  # Active logger: Weights & Biases. The project is taken from the run `name`
  # and the run is labelled with `version`.
  # NOTE(review): save_dir is the literal 'outputs' rather than ${output_dir},
  # unlike the disabled TensorBoard entry above; confirm this is intended.
  wandb:
    _target_: lightning.pytorch.loggers.wandb.WandbLogger
    project: '${name}'
    save_dir: 'outputs'
    name: '${version}'