# Run identity, output location and values shared through ${extras.*}.
name: nerf
tags:
  - lvis96v
description: ''
# When left null, the run version falls back to version_{index}.
version: null
output_dir: 'outputs/${name}'

seed: 42
resume: null
extras:
  # Reused for dataset img_wh, feature volume_size, rendering resolution
  # and N_samples.
  resolution: 32
  src_views: 5
  target_views: 3

data:
  _target_: src.data.multiview.MultiViewDataModule

  # Training split: repeat is forwarded to MultiViewDataset.
  train_dataset:
    _target_: src.data.multiview.MultiViewDataset
    root_dir: data/demo/
    src_views: ${extras.src_views}
    target_views: ${extras.target_views}
    bg_color: white
    img_wh:
      - ${extras.resolution}
      - ${extras.resolution}
    include_src: true
    relative_pose: true
    caption_path: data/demo/caption.txt
    # num_samples: 10
    repeat: 100
  train_batch_size: 1

  # Validation split: limited to num_samples items.
  val_dataset:
    _target_: src.data.multiview.MultiViewDataset
    root_dir: data/demo/
    src_views: ${extras.src_views}
    target_views: ${extras.target_views}
    bg_color: white
    img_wh:
      - ${extras.resolution}
      - ${extras.resolution}
    include_src: true
    relative_pose: true
    caption_path: data/demo/caption.txt
    num_samples: 2
  val_batch_size: 1

  # Test split: same source data, no num_samples / repeat override.
  test_dataset:
    _target_: src.data.multiview.MultiViewDataset
    root_dir: data/demo/
    src_views: ${extras.src_views}
    target_views: ${extras.target_views}
    bg_color: white
    img_wh:
      - ${extras.resolution}
      - ${extras.resolution}
    include_src: true
    relative_pose: true
    caption_path: data/demo/caption.txt
  test_batch_size: 1

  # DataLoader settings shared by all splits (assumed; confirm in datamodule).
  num_workers: 1
  pin_memory: true

system:
  _target_: src.systems.generalize_nerf_system.GeneralizeNERFSystem
  nerf:
    _target_: src.models.nerf.volume_nerfacc_nerf.VolumeNERF
    # Active feature representation: a feature volume sized by extras.resolution.
    feature_space:
      _target_: src.models.nerf.mlp.Volume
      volume_size: ${extras.resolution}
      volume_dim: 64
    # Alternative triplane feature representation:
    # feature_space:
    #   _target_: src.models.nerf.mlp.Triplane
    #   plane_size: ${extras.resolution}
    #   plane_dim: 64
    resolution: ${extras.resolution}  # rendering resolution
    N_samples: ${extras.resolution}
    image_transformer:
      _target_: src.models.network.encoder.CrossViewEncoder
      in_channels: 3
      channels: 768
      num_layers: 8
    query_transformer:
      _target_: src.models.blocks.attention1d.Transformer1DModel
      in_channels: 64
      num_attention_heads: 12
      attention_head_dim: 64
      num_layers: 8
      # NOTE(review): presumably must match image_transformer.channels (768);
      # confirm before changing either value.
      cross_attention_dim: 768
    implicit_network:
      _target_: src.models.nerf.mlp.ImplicitNetwork
      feature_vector_size: 64
      d_in: 3
      d_out: 1
      # dims: [256, 256, 256, 256, 256, 256, 256]
      dims: [128, 128, 128, 128, 128, 128, 128]
      geometric_init: true
      skip_in: []
      weight_norm: true
      multires: 6
    rendering_network:
      _target_: src.models.nerf.mlp.RenderingNetwork
      add_normals: true
      feature_vector_size: 64
      d_in: 3
      d_out: 3
      # dims: [256, 256]
      dims: [128, 128]
      weight_norm: true
      multires_view: 6
    # Alternative density model:
    # density_network:
    #   _target_: src.models.nerf.density.SimpleDensity
    #   noise_std: 0.0
    density_network:
      _target_: src.models.nerf.density.LaplaceDensity
      params_init:
        beta: 0.1
      beta_min: 0.0001
  criterion:
    _target_: src.models.nerf.loss.Criterion
    # One weight per active entry in `criterions` (RGB, Depth, Mask),
    # presumably paired by position. The previous 4th weight (0.1) belongs to
    # the disabled EikonalLoss; append it again when re-enabling that loss.
    weights: [1, 1, 1]
    criterions:
      - _target_: src.models.nerf.loss.RGBLoss
      - _target_: src.models.nerf.loss.DepthLoss
      - _target_: src.models.nerf.loss.MaskLoss
      # - _target_: src.models.nerf.loss.EikonalLoss

trainer:
  _target_: lightning.pytorch.trainer.Trainer
  default_root_dir: ${output_dir}
  max_steps: 200001
  # val_check_interval: 2000
  check_val_every_n_epoch: 1
  log_every_n_steps: 10
  num_sanity_val_steps: 1
  enable_progress_bar: true
  strategy: ddp_find_unused_parameters_true
  # Trainer documents this as an int; 1 disables gradient accumulation.
  accumulate_grad_batches: 1
  gradient_clip_val: 1.0
  accelerator: gpu
  devices: 1
  num_nodes: 1
  precision: "16-mixed"  # mixed precision for extra speed-up

callbacks:
  # Periodic checkpointing; save_top_k: -1 keeps every saved checkpoint.
  model_checkpoint:
    _target_: lightning.pytorch.callbacks.ModelCheckpoint
    every_n_train_steps: 50000
    save_top_k: -1
  # rich_progress_bar:
  #   _target_: lightning.pytorch.callbacks.RichProgressBar

logger:
  tensorboard:
    _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
    save_dir: '${output_dir}'
    # Empty (not null) so no extra name directory is inserted under save_dir.
    name: ''
    # Mirrors the top-level `version` key, which defaults to null.
    version: '${version}'
    sub_dir: tb_logs
  # wandb:
  #   _target_: lightning.pytorch.loggers.wandb.WandbLogger
  #   project: "${name}"
  #   save_dir: "outputs"
  #   name: "${version}"