# Run-level settings. `name` is interpolated into `output_dir` below, and
# `version` / `output_dir` are reused by trainer.default_root_dir and
# logger.tensorboard through ${...} interpolation.
name: learnable_3dgs
tags: ["lvis96v"]
description: ""
version: null # if not specified, will be set to version_{index}
# Resolves to "outputs/<name>".
output_dir: "outputs/${name}"

# NOTE(review): presumably the global RNG seed and a checkpoint path to resume
# from (null = start fresh) — confirm against the training entry point.
seed: 42
resume: null

# Data module class, dataset path, and the parameters object handed to it.
data:
  _target_: src.data.gsdataset.DataModule
  path: data/my_synthetic/lego
  params:
    _target_: src.data.dataparsers.dataset.DatasetParams
    # NOTE(review): -1 presumably means "no limit on cached images" — confirm
    # against DatasetParams.
    val_max_num_images_to_cache: -1
    test_max_num_images_to_cache: -1

# Settings for the GaussianSplattingSystem: the Gaussian model (feature_space
# plus gs_head), the renderer, and the validation-output switch.
system:
  _target_: src.systems.general_gs.generalize_gs_system.GaussianSplattingSystem
  gaussian:
    _target_: src.models.gs.learnable_gaussian_model.GaussianModel
    feature_space:
      _target_: src.models.gs.mlp.Points
      point_num: 50000
      point_dim: 128
    gs_head:
      _target_: src.models.gs.mlp.GSHEAD
      # The five heads share in_dim / hidden_dim / num_layers and differ only
      # in out_dim. Every in_dim equals feature_space.point_dim (128).
      # NOTE(review): presumably in_dim must track point_dim — confirm in GSHEAD.
      configs:
        xyz: {in_dim: 128, hidden_dim: 128, out_dim: 3, num_layers: 3}
        scaling: {in_dim: 128, hidden_dim: 128, out_dim: 3, num_layers: 3}
        rotation: {in_dim: 128, hidden_dim: 128, out_dim: 4, num_layers: 3}
        opacity: {in_dim: 128, hidden_dim: 128, out_dim: 1, num_layers: 3}
        rgb: {in_dim: 128, hidden_dim: 128, out_dim: 3, num_layers: 3}

  renderer:
    _target_: src.models.gs.renderers.GeneralRenderer
  save_val_output: true

# Settings for the class named in _target_ (the Lightning Trainer).
# NOTE(review): presumably passed as constructor keyword arguments by the
# config loader — confirm against the training entry point.
trainer:
  _target_: lightning.pytorch.trainer.Trainer
  default_root_dir: "${output_dir}"

  # Device placement.
  accelerator: gpu
  devices: 1
  num_nodes: 1
  strategy: auto

  # Training length and validation cadence.
  max_steps: 20000
  check_val_every_n_epoch: 10
  num_sanity_val_steps: 1

  # Logging and console output.
  log_every_n_steps: 10
  enable_progress_bar: true
  # precision: 16-mixed # mixed precision for extra speed-up

# Training callbacks. Only model_checkpoint is active; the other entries are
# commented out.
callbacks:
  model_checkpoint:
    _target_: lightning.pytorch.callbacks.ModelCheckpoint
    # Per Lightning's ModelCheckpoint docs, save_top_k: -1 keeps every
    # checkpoint. One is written every 1000 training steps, so up to 20 over
    # trainer.max_steps (20000).
    save_top_k: -1
    every_n_train_steps: 1000
  # Disabled: uncomment both lines of a pair to enable that callback.
  # point_cloud_checkpoint:
    # _target_: src.utils.callbacks.PointCloudCallback
  # rich_progress_bar:
    # _target_: lightning.pytorch.callbacks.RichProgressBar

# Experiment loggers. Only tensorboard is active; the wandb entry is commented
# out.
logger:
  tensorboard:
    _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
    # save_dir and version reuse the top-level output_dir / version values
    # through ${...} interpolation.
    save_dir: "${output_dir}"
    name: ""
    version: "${version}"
    sub_dir: "tb_logs"
  # Disabled: uncomment the whole wandb block to enable it.
  # wandb:
  #   _target_: lightning.pytorch.loggers.wandb.WandbLogger
  #   project: "${name}"
  #   save_dir: "outputs"
  #   name: "${version}"