# Experiment identity: `name` feeds output_dir and the W&B project,
# `version` is used as the W&B run name.
name: 3dgs-v1
tags: ["3dgs"]
description: ""
# Run label summarising the setup. The point-count suffix is interpolated
# from extras.n_pts so the label cannot drift from the actual setting.
# If not specified, will be set to version_{index}.
version: 'max_sh+opacity_random_init+color+ssim+depth+${extras.n_pts}'
output_dir: "outputs/${name}"

# Shared values referenced elsewhere via ${extras.*}.
# In this file only n_pts is referenced (data.n_pts, system.gaussian.n_pts);
# the *_dim entries are presumably consumed by other configs/code — confirm.
extras:
  n_pts: 8192
  token_dim: 256
  point_dim: 256
  mlp_dim: 256

# Seed value (presumably passed to the project's RNG seeding helper — confirm).
seed: 42

data:
  _target_: src.data.gsdataset.DataModule
  path: 'data/nerf/my_synthetic/data'
  # Same point count as the Gaussian model, via the shared extras value.
  n_pts: ${extras.n_pts}
  params:
    _target_: src.data.dataparsers.dataset.DatasetParams
    # NOTE(review): -1 presumably means "no cap on cached images" — confirm
    # against DatasetParams.
    val_max_num_images_to_cache: -1
    test_max_num_images_to_cache: -1

system:
  _target_: src.systems.general_gs.gs_fix_system_van.GaussianSplattingSystem
  gaussian:
    _target_: src.models.gs.gaussian_model_van.GaussianModel
    n_pts: ${extras.n_pts}
    # Active and max SH degree are both 3 — presumably SH starts at its
    # maximum degree (the "max_sh" tag in the run label); confirm in the model.
    active_sh_degree: 3
    max_sh_degree: 3

  renderer:
    _target_: src.models.gs.renderers.vanilla_renderer_van.VanillaRenderer

  # Each entry gives a name, a weight, `pred`/`target` keys, and the loss
  # module to instantiate. Entries marked `used_for_optimization: false`
  # are presumably computed for logging only — confirm in LossFN.
  loss_fn:
    _target_: src.models.loss_fn.LossFN
    loss_list:
      - name: rgb_diff_loss
        weight: 1
        pred: render
        target: render
        loss_fn:
          _target_: torch.nn.L1Loss
      - name: ssim_metric
        weight: 0.2
        pred: render
        target: render
        # used_for_optimization: false
        loss_fn:
          _target_: src.models.loss_fn.SSIMLoss
      - name: depth_loss
        weight: 1
        pred: depth
        target: depth
        # used_for_optimization: false
        loss_fn:
          _target_: torch.nn.L1Loss
      - name: dist_loss
        weight: 1.0e-4
        pred: means3D
        target: points
        used_for_optimization: false
        loss_fn:
          _target_: src.models.loss_fn.CoulombLoss
      - name: chamfer_loss
        weight: 1.0
        pred: means3D
        target: points
        used_for_optimization: false
        loss_fn:
          _target_: src.models.loss_fn.ChamferLoss

  save_val_output: true

trainer:
  _target_: lightning.pytorch.trainer.Trainer
  default_root_dir: ${output_dir}
  # Hardware / distribution.
  accelerator: gpu
  devices: 1
  num_nodes: 1
  strategy: auto
  # precision: 16-mixed # mixed precision for extra speed-up
  # Schedule and validation cadence.
  max_steps: 20000
  check_val_every_n_epoch: 20
  num_sanity_val_steps: 1
  # Logging / console output.
  log_every_n_steps: 10
  enable_progress_bar: true

callbacks:
  # Keep every checkpoint (save_top_k: -1), one every 1000 training steps.
  model_checkpoint:
    _target_: lightning.pytorch.callbacks.ModelCheckpoint
    save_top_k: -1
    every_n_train_steps: 1000
  point_cloud_checkpoint:
    _target_: src.utils.callbacks.PointCloudCallback
  # rich_progress_bar:
  #   _target_: lightning.pytorch.callbacks.RichProgressBar

logger:
  # Alternative TensorBoard logger, kept for reference:
  # tensorboard:
  #   _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
  #   save_dir: "${output_dir}"
  #   name: ""
  #   version: "${version}"
  #   sub_dir: "tb_logs"
  wandb:
    _target_: lightning.pytorch.loggers.wandb.WandbLogger
    project: '${name}'
    name: '${version}'
    # NOTE(review): save_dir is the bare 'outputs' root, not ${output_dir}.
    # ModelCheckpoint without a dirpath derives its directory from the first
    # logger's save_dir/name/version, so changing this also moves
    # checkpoints — confirm before editing.
    save_dir: 'outputs'
    # To resume an existing W&B run, uncomment and set its run id:
    # id: p0gao83a
    # resume: 'must'