# Experiment identity. `name` is reused for output_dir and as the wandb project.
name: 3dgs-v5
tags:
  - '3dgs'
description: ''
# Run label; if not specified, it will be set to version_{index}.
version: 'sdf'
output_dir: 'outputs/${name}'

# Shared hyper-parameters, referenced elsewhere via ${extras.<key>} interpolation.
extras:
  token_dim: 256  # latent token dim (gaussian.n_dim and transformer in_channels)
  # Per-point feature dim read by every head_config entry (in_dim).
  # NOTE(review): the active fold_transformer runs at token_dim, so this presumably
  # has to stay equal to token_dim unless the LearableFold variant is re-enabled.
  point_dim: 256
  mlp_dim: 256  # hidden dim of the prediction heads
  n_tokens: 1024  # number of latent tokens; passed to system.gaussian.n_pts
  # NOTE(review): not referenced anywhere in this file — confirm it is consumed elsewhere.
  n_pts: 8192
  resolution: 384  # used for both entries of every dataset's img_wh
  src_views: 1
  target_views: 6
  bg_color: black
  root_dir: data/nerf/my_synthetic
  caption_path: data/nerf/my_synthetic/caption.txt
  return_pc: true  # used for cd loss (chamfer/EMD targets `pc_fps`)

# Global random seed (presumably passed to seed_everything by the entry script — verify).
seed: 42
# Optional checkpoint to resume from (presumably read by the entry script — verify).
# resume: outputs/3dgs-fix/p0gao83a/checkpoints/epoch=189-step=19000.ckpt
# Data module: train/val/test MultiViewDataset instances sharing the ${extras.*} settings.
data:
  _target_: src.data.multiview.MultiViewDataModule
  train_dataset:
    _target_: src.data.multiview.MultiViewDataset
    root_dir: ${extras.root_dir}
    src_views: ${extras.src_views}
    target_views: ${extras.target_views}
    bg_color: ${extras.bg_color}
    img_wh:  # square: both entries come from extras.resolution
      - ${extras.resolution}
      - ${extras.resolution}
    include_src: true
    relative_pose: true
    caption_path: ${extras.caption_path}
    # num_samples: 10  # presumably caps the dataset size (val uses 2) — handy for debugging
    return_pc: ${extras.return_pc}
    repeat: 100
  train_batch_size: 12
  val_dataset:
    _target_: src.data.multiview.MultiViewDataset
    root_dir: ${extras.root_dir}
    src_views: ${extras.src_views}
    target_views: ${extras.target_views}
    bg_color: ${extras.bg_color}
    img_wh:
      - ${extras.resolution}
      - ${extras.resolution}
    include_src: true
    relative_pose: true
    caption_path: ${extras.caption_path}
    return_pc: ${extras.return_pc}
    num_samples: 2
  val_batch_size: 1
  test_dataset:
    _target_: src.data.multiview.MultiViewDataset
    root_dir: ${extras.root_dir}
    src_views: ${extras.src_views}
    target_views: ${extras.target_views}
    bg_color: ${extras.bg_color}
    img_wh:
      - ${extras.resolution}
      - ${extras.resolution}
    include_src: true
    relative_pose: true
    caption_path: ${extras.caption_path}
    return_pc: ${extras.return_pc}
  test_batch_size: 1
  num_workers: 32
  pin_memory: true

# Model, renderer and loss configuration for the Gaussian-splatting system.
system:
  _target_: src.systems.general_gs.gs_gen_system.GaussianSplattingSystem
  gaussian:
    _target_: src.models.gs.gaussian_model_gen.GaussianModel
    # NOTE(review): reads extras.n_tokens (1024), not extras.n_pts (8192) — confirm intended.
    n_pts: ${extras.n_tokens}
    n_dim: ${extras.token_dim}
    query_transformer:
      _target_: src.models.gs.anchor_transformers.Transformer1DModel
      in_channels: ${extras.token_dim}
      num_attention_heads: 12
      attention_head_dim: 64
      # NOTE(review): 768 presumably matches the hidden size of encoder.model_name
      # (facebook/dino-vitb16, a ViT-B/16); keep in sync if the backbone changes.
      cross_attention_dim: 768
      num_layers: 8

    # Alternative: disable the query transformer.
    # query_transformer:
    #   _target_: torch.nn.Identity

    fold_transformer:
      _target_: src.models.gs.anchor_transformers.Transformer1DModel
      in_channels: ${extras.token_dim}
      num_attention_heads: 12
      attention_head_dim: 64
      num_layers: 2

    encoder:
      _target_: src.models.gs.backbone.DinoWrapper
      model_name: facebook/dino-vitb16
      freeze: false

    # Alternative fold net (LearableFold) operating at ${extras.point_dim}.
    # fold_transformer:
    #   _target_: src.models.gs.foldnet.LearableFold
    #   n_folds: 10
    #   n_dim: ${extras.point_dim}
    #   query_transformer:
    #     _target_: src.models.blocks.attention1d.Transformer1DModel
    #     in_channels: ${extras.point_dim}
    #     num_attention_heads: 12
    #     attention_head_dim: 32
    #     num_layers: 2
    #     cross_attention_dim: ${extras.token_dim}

    # Alternative: disable the fold net.
    # fold_transformer:
    #   _target_: torch.nn.Identity

    # Per-attribute prediction heads.
    # NOTE(review): in_dim uses point_dim while the active fold_transformer runs at
    # token_dim; both are 256 today — confirm which one the heads should follow.
    head_config:
      opacity:
        in_dim: ${extras.point_dim}
        hidden_dim: ${extras.mlp_dim}
        out_dim: 1
      scaling:
        in_dim: ${extras.point_dim}
        hidden_dim: ${extras.mlp_dim}
        out_dim: 3
      rotation:
        in_dim: ${extras.point_dim}
        hidden_dim: ${extras.mlp_dim}
        out_dim: 4
      color:
        in_dim: ${extras.point_dim}
        hidden_dim: ${extras.mlp_dim}
        # out_dim: 48  # presumably 16 SH coefficients x 3 channels; 3 = plain RGB
        out_dim: 3
      xyz:
        in_dim: ${extras.point_dim}
        hidden_dim: ${extras.mlp_dim}
        out_dim: 3

  renderer:
    _target_: src.models.gs.renderers.vanilla_renderer_gen.VanillaRenderer

  # Weighted loss terms; `pred` / `target` name the entries each term compares.
  loss_fn:
    _target_: src.models.loss_fn.LossFN
    loss_list:
      - name: rgb_diff_loss
        weight: 1
        pred: render
        target: render
        loss_fn:
          _target_: torch.nn.L1Loss
      # NOTE(review): named *_metric but weighted like a loss — confirm it should be optimized.
      - name: ssim_metric
        weight: 0.2
        pred: render
        target: render
        # used_for_optimization: false
        loss_fn:
          _target_: src.models.loss_fn.SSIMLoss
      - name: lpips_loss
        weight: 0.2
        pred: render
        target: render
        # used_for_optimization: false
        loss_fn:
          _target_: src.models.loss_fn.LPIPSLoss
      - name: depth_loss
        weight: 1
        pred: depth
        target: depth
        # used_for_optimization: false
        loss_fn:
          _target_: torch.nn.L1Loss
      # - name: dist_loss
      #   weight: 1.0e-4
      #   pred: means3D
      #   target: points
      #   used_for_optimization: false
      #   loss_fn:
      #     _target_: src.models.loss_fn.CoulombLoss
      - name: chamfer_loss
        weight: 1.0
        pred: means3D
        target: pc_fps
        # used_for_optimization: false
        loss_fn:
          _target_: src.models.loss_fn.ChamferLoss
      - name: emd_loss
        weight: 1.0
        pred: means3D
        target: pc_fps
        # used_for_optimization: false
        loss_fn:
          _target_: src.models.loss_fn.EMDLoss
      # - name: infochamfer_loss
      #   weight: 1.0
      #   pred: means3D
      #   target: pc_fps
      #   # used_for_optimization: false
      #   loss_fn:
      #     _target_: src.models.loss_fn.InfoChamferLoss

  save_val_output: true

# Lightning Trainer arguments.
trainer:
  _target_: lightning.pytorch.trainer.Trainer
  default_root_dir: ${output_dir}
  max_steps: 100000  # stop after this many optimizer steps
  check_val_every_n_epoch: 5  # validate every 5 epochs
  log_every_n_steps: 10
  num_sanity_val_steps: 1  # one validation batch before training starts
  enable_progress_bar: true
  strategy: auto
  accelerator: gpu
  devices: 1  # single GPU ...
  num_nodes: 1  # ... on a single node
  precision: '16-mixed'  # mixed precision for extra speed-up
  gradient_clip_val: 1.0

# Lightning callbacks, each given by a Hydra-style `_target_`
# (presumably instantiated in the order listed — verify before reordering).
callbacks:
  model_checkpoint:
    _target_: lightning.pytorch.callbacks.ModelCheckpoint
    save_top_k: -1  # -1 keeps every saved checkpoint
    every_n_train_steps: 20000  # save every 20k training steps
  point_cloud_checkpoint:
    _target_: src.utils.callbacks.PointCloudCallback
  # Optional rich progress bar (disabled):
  # rich_progress_bar:
  #   _target_: lightning.pytorch.callbacks.RichProgressBar

logger:
  # TensorBoard alternative (disabled):
  # tensorboard:
  #   _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
  #   save_dir: "${output_dir}"
  #   name: ""
  #   version: "${version}"
  #   sub_dir: "tb_logs"
  wandb:
    _target_: lightning.pytorch.loggers.wandb.WandbLogger
    # save_dir + project reproduce output_dir (outputs/${name}).
    project: '${name}'
    save_dir: 'outputs'
    name: '${version}'
    # To attach to an existing wandb run, uncomment and set its id:
    # id: p0gao83a
    # resume: 'must'