# Run identity. `name` is interpolated into `output_dir` below.
name: test
tags:
  - lvis96v
description: ''
# Leave as null to have the version assigned automatically as version_{index}.
version: null
output_dir: 'outputs/${name}'

# NOTE(review): presumably the global RNG seed and an optional checkpoint to
# resume from (null = fresh run) -- confirm against the training entry point.
seed: 42
resume: null

# Shared values that other sections pull in through ${extras.*}.
extras:
  k_near_views: 8

# Data module with three dataset splits, all reading data/3d-data-example/.
# The splits are configured identically except for `sample_views_mode`:
# train uses `random`, val and test use `lay4`.
data:
  _target_: src.data.multiview.MultiViewDataModule

  train_dataset:
    _target_: src.data.multiview.MultiViewDataset
    root_dir: 'data/3d-data-example/'
    caption_path: 'data/3d-data-example/caption.txt'
    num_views: 16
    k_near_views: '${extras.k_near_views}'
    sample_views_mode: random
    img_wh: [256, 256]
    bg_color: white
  train_batch_size: 1

  val_dataset:
    _target_: src.data.multiview.MultiViewDataset
    root_dir: 'data/3d-data-example/'
    caption_path: 'data/3d-data-example/caption.txt'
    num_views: 16
    k_near_views: '${extras.k_near_views}'
    sample_views_mode: lay4
    img_wh: [256, 256]
    bg_color: white
  val_batch_size: 1

  test_dataset:
    _target_: src.data.multiview.MultiViewDataset
    root_dir: 'data/3d-data-example/'
    caption_path: 'data/3d-data-example/caption.txt'
    num_views: 16
    k_near_views: '${extras.k_near_views}'
    sample_views_mode: lay4
    img_wh: [256, 256]
    bg_color: white
  test_batch_size: 1

  # NOTE(review): 64 workers alongside a batch size of 1 is a lot -- confirm
  # the target machine has the CPU cores and memory for it.
  num_workers: 64
  pin_memory: true

# Top-level system object; it carries the nested `model` config below.
# NOTE(review): `base_model_id` looks like a local path to pretrained
# diffusers weights and `variant` the weight-file flavor -- confirm.
system:
  _target_: src.systems.i2mv3d_system.I2MV3DSystem
  base_model_id: 'pretrain/zero123-xl-diffusers'
  variant: 'fp16_ema'
  model:
    _target_: src.models.triplane_nerf.TriplaneNERF
    plane_size: 64
    plane_num: 64
    query_transformer:
      _target_: src.models.blocks.attention1d.Transformer1DModel
      # NOTE(review): in_dim equals plane_num (both 64) -- presumably they
      # must stay in sync; confirm before changing either one.
      in_dim: 64
      num_layers: 6
      num_attention_heads: 8
      attention_head_dim: 32


# Keyword arguments for lightning.pytorch.trainer.Trainer.
trainer:
  _target_: lightning.pytorch.trainer.Trainer
  default_root_dir: ${output_dir}
  # Training stops after this many optimizer steps.
  max_steps: 20001
  # val_check_interval: 2000
  check_val_every_n_epoch: 2
  log_every_n_steps: 10
  num_sanity_val_steps: 1
  enable_progress_bar: true
  strategy: ddp_find_unused_parameters_true
  # Must be an integer: Trainer declares this argument as `int`, and a bare
  # `1.0` in YAML is loaded as a float.
  accumulate_grad_batches: 1
  gradient_clip_val: 1.0
  accelerator: gpu
  devices: 1
  num_nodes: 1
  precision: 16-mixed  # mixed precision for extra speed-up

callbacks:
  # Step-based checkpointing; save_top_k = -1 keeps every checkpoint written.
  model_checkpoint:
    _target_: lightning.pytorch.callbacks.ModelCheckpoint
    save_top_k: -1
    # Keep this at or below trainer.max_steps (20001). The previous value of
    # 50000 was never reached, and setting a step interval disables the
    # default per-epoch save, so the run finished without any checkpoint.
    every_n_train_steps: 5000
  # rich_progress_bar:
  #   _target_: lightning.pytorch.callbacks.RichProgressBar

# Loggers. Only TensorBoard is enabled; it writes under ${output_dir} and
# takes its version from the top-level `version` key. The Weights & Biases
# entry is kept commented out for reference.
logger:
  tensorboard:
    _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
    save_dir: '${output_dir}'
    name: ''
    version: '${version}'
    sub_dir: 'tb_logs'
  # wandb:
  #   _target_: lightning.pytorch.loggers.wandb.WandbLogger
  #   project: "${name}"
  #   save_dir: "outputs"
  #   name: "${version}"