# Experiment identity: run name, tags, free-form description and the
# version label that the logger below uses for its output sub-directory.
name: nips
tags:
  - svd
description: ''
# If this key is not specified, the version is set to version_{index}.
version: "svd_lgm+mix-adapter-rgb+plucker_inference"
output_dir: 'outputs/${name}'

# Shared values, referenced from other sections via ${extras.<key>}.
extras:
  # Used for both entries of data.test_dataset.img_wh.
  resolution: 512
  # src_views / target_views are not referenced elsewhere in this config as
  # shown; presumably read directly by code — TODO confirm.
  src_views: 1
  target_views: 6
  bg_color: white
  # Only referenced by the (commented-out) training dataset below.
  # NOTE(review): "virutalbuy" looks like a typo but appears to be the real
  # bucket hostname — verify before changing.
  root_dir: 'https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/aigc3d/objaverse/'
  instance_file: data/gobjaverse/gobjaverse_280k.json
  # Passed to trainer.num_nodes / trainer.devices.
  num_nodes: 1
  devices: 1
  # num_frames -> system.recon_model.num_frames;
  # video_frames -> data.test_dataset.num_frames.
  num_frames: 4
  video_frames: 8
  # Passed to system.mv_model.add_plucker.
  add_plucker: true
  # Passed to callbacks.model_checkpoint.bucket (quoted: contains ':').
  bucket: 'gcc:s3://wenhao/nips/svd_lgm_rgb/'
# Stage toggles read by the launch script (not shown here): training is
# disabled and testing is enabled for this config.
train: false
test: true

seed: 42
# Checkpoint path used by the launcher; relative to the working directory
# the job runs in — NOTE(review): confirm Hydra's chdir behaviour here.
resume: 'ckpt/epoch=2-step=19000.ckpt'
data:
  _target_: src.data.test_dataset.MultiViewDataModule
  # Training / validation datasets are disabled in this inference config.
  # Uncomment to re-enable them.
  # train_dataset:
  #   _target_: src.data.gobjaverse.MultiViewDataset
  #   root_dir: ${extras.root_dir}
  #   instance_file: ${extras.instance_file}
  #   bg_color: ${extras.bg_color}
  #   num_frames: ${extras.video_frames}
  #   img_wh:
  #     - ${extras.resolution}
  #     - ${extras.resolution}
  #   # num_samples: 10000
  #   repeat: 1
  # train_batch_size: 1
  # val_dataset:
  #   _target_: src.data.gobjaverse.MultiViewDataset
  #   root_dir: data/gobjaverse/data/
  #   instance_file: data/gobjaverse/test_list.json
  #   bg_color: ${extras.bg_color}
  #   num_frames: ${extras.video_frames}
  #   img_wh:
  #     - ${extras.resolution}
  #     - ${extras.resolution}
  #   num_samples: 50
  # val_batch_size: 1
  # Test-time dataset: frame count and image size are taken from `extras`.
  test_dataset:
    _target_: src.data.test_dataset.MultiViewDataset
    root_dir: data/test/images
    meta_file: data/test/meta.json
    bg_color: ${extras.bg_color}
    num_frames: ${extras.video_frames}
    # Both entries use the same value, so images are square.
    # Kept in block style: `${...}` contains `{`/`}`, which are flow
    # indicators and would break a `[...]` flow sequence.
    img_wh:
      - ${extras.resolution}
      - ${extras.resolution}
  test_batch_size: 1
  num_workers: 16
  pin_memory: true


system:
  _target_: src.systems.mv_diffusion.svd_lgm_rgb_system.SVDSystem
  lr: 1.0e-5
  # NOTE(review): machine-specific absolute path to a local Hugging Face cache
  # snapshot of stabilityai/stable-video-diffusion-img2vid; override per host.
  base_model_id: /mnt/petrelfs/wenhao1/.cache/huggingface/hub/models--stabilityai--stable-video-diffusion-img2vid/snapshots/0f2d55c1e358d608120344d3ea9c35fb5f2c31b3
  variant: fp16
  cfg: 0.1
  mv_model:
    _target_: src.models.unet.mv_unet.MVModel
    cond_encoder:
      _target_: src.models.unet.adaptor.Adapter_XL
      cin: 192  # 3 x 8 x 8
      # cin: 768  # 16 x 16 x 3
      # cin: 1024  # 16 x 16 x 4
      channels: [320, 640, 1280, 1280]
      sk: true
      use_conv: false
      ksize: 1
    add_plucker: ${extras.add_plucker}
    # When instantiated via hydra.utils.instantiate, `_partial_: true` yields
    # a functools.partial; the remaining arguments are supplied by the caller.
    _partial_: true

  recon_model:
    _target_: src.models.network.lgm.models.LGM
    num_frames: ${extras.num_frames}
    opt:
      _target_: src.models.network.lgm.options.Options
      input_size: 256
      up_channels: [1024, 1024, 512, 256, 128]
      up_attention: [true, true, true, false, false]
      splat_size: 128
      output_size: 512
      batch_size: 8
      # NOTE(review): num_views (8) differs from recon_model.num_frames
      # (extras.num_frames = 4) but matches extras.video_frames; confirm
      # this is intended.
      num_views: 8
      gradient_accumulation_steps: 1
      mixed_precision: fp16

trainer:
  _target_: lightning.pytorch.trainer.Trainer
  default_root_dir: ${output_dir}
  # Loop length, validation cadence and logging (used when fitting).
  max_steps: 1000000
  val_check_interval: 4000
  # check_val_every_n_epoch: 8
  accumulate_grad_batches: 4
  log_every_n_steps: 20
  num_sanity_val_steps: 1
  enable_progress_bar: true
  # DeepSpeed strategy, configured from an external JSON file (path is
  # relative to the working directory). Alternatives kept for reference:
  #   strategy: ddp_find_unused_parameters_true
  #   strategy: deepspeed_stage_1
  strategy:
    _target_: lightning.pytorch.strategies.DeepSpeedStrategy
    config: 'config.json'
  # Hardware layout comes from `extras`.
  # accelerator: gpu
  devices: ${extras.devices}
  num_nodes: ${extras.num_nodes}
  # Numeric precision; full fp32 alternative: precision: 32
  precision: '16-mixed'
  gradient_clip_val: 1

callbacks:
  model_checkpoint:
    _target_: callbacks.CustomModelCheckpoint
    # Presumably the remote destination for saved checkpoints — TODO confirm
    # against CustomModelCheckpoint.
    bucket: ${extras.bucket}
    # -1 keeps every checkpoint under Lightning's ModelCheckpoint semantics
    # (assuming CustomModelCheckpoint inherits them).
    save_top_k: -1
    # every_n_train_steps: 1
    every_n_train_steps: 1000

  # Disabled callbacks, kept for reference:
  # point_cloud_checkpoint:
  #   _target_: src.utils.callbacks.PointCloudCallback
  # rich_progress_bar:
  #   _target_: lightning.pytorch.callbacks.RichProgressBar


logger:
  # TensorBoard logger rooted at output_dir; the run's `version` label
  # selects the sub-directory, and an empty `name` adds no extra level.
  tensorboard:
    _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
    save_dir: '${output_dir}'
    name: ''
    version: '${version}'
    sub_dir: 'tb_logs'
  # Weights & Biases logger, currently disabled. The id/resume entries
  # target the existing W&B run hpxskp68.
  # wandb:
  #   _target_: lightning.pytorch.loggers.wandb.WandbLogger
  #   project: "${name}"
  #   save_dir: "outputs"
  #   name: "${version}"
  #   id: hpxskp68
  #   resume: 'must'