# Run identity. `name` is reused as the wandb project and as the last path
# component of `output_dir`; `version` is reused as the wandb run name.
name: '3dgs-v3'
tags:
  - '3dgs'
description: ''
# Free-form run tag. If not specified, it will be set to version_{index}.
# NOTE(review): the tag mentions "layer4" and "tempurature50" (sic), but both
# transformers in system.gaussian use num_layers: 2 and no temperature
# setting is visible in this file — confirm the tag still matches the config.
version: '2048latent+max_sh+color+ssim+depth+anchor+transformer+layer4_dim64+tempurature50+100k+learn_anchor+pos+multi_layer'
output_dir: 'outputs/${name}'

# Shared sizes, pulled into the rest of the config via ${extras.<key>}.
extras:
  # Latent tokens: how many (system.gaussian.n_pts) and how wide
  # (system.gaussian.n_dim and both transformers' in_channels).
  n_tokens: 2048
  token_dim: 256
  # Point side: number of points handed to the data module (data.n_pts),
  # input width of every prediction head (in_dim) and their hidden width.
  n_pts: 8192
  point_dim: 256
  mlp_dim: 256

seed: 42
# Uncomment and point at a saved checkpoint to resume from it (presumably
# read by the training entry point, which is not part of this file).
# resume: outputs/3dgs-fix/p0gao83a/checkpoints/epoch=189-step=19000.ckpt

data:
  _target_: src.data.gsdataset.DataModule
  path: data/nerf/my_synthetic/data
  # Kept in sync with the shared size table.
  n_pts: ${extras.n_pts}
  params:
    _target_: src.data.dataparsers.dataset.DatasetParams
    # Both caches set to -1; presumably "no limit" — confirm in DatasetParams.
    val_max_num_images_to_cache: -1
    test_max_num_images_to_cache: -1

system:
  _target_: src.systems.general_gs.gs_split_system.GaussianSplattingSystem

  # Gaussian generator. Judging by the key names, the flow is presumably:
  # n_pts latent tokens of width n_dim -> query_transformer ->
  # fold_transformer -> one MLP head per attribute (head_config).
  # Confirm in src.models.gs.gaussian_model_split.GaussianModel.
  gaussian:
    _target_: src.models.gs.gaussian_model_split.GaussianModel
    n_pts: ${extras.n_tokens}
    n_dim: ${extras.token_dim}

    # 2 layers, 12 heads x 64 dims per head (12 x 64 = 768, presumably the
    # internal attention width) over token_dim-wide inputs.
    query_transformer:
      _target_: src.models.gs.anchor_transformers.Transformer1DModel
      in_channels: ${extras.token_dim}
      num_attention_heads: 12
      attention_head_dim: 64
      num_layers: 2
    # Alternative: skip the query stage entirely.
    # query_transformer:
    #   _target_: torch.nn.Identity

    # Same architecture settings as query_transformer.
    # NOTE(review): in_channels here is extras.token_dim, while every head in
    # head_config reads in_dim: extras.point_dim. Both are 256 today; if the
    # two sizes are ever changed independently, confirm which width this
    # stage actually emits before trusting the head in_dim values.
    fold_transformer:
      _target_: src.models.gs.anchor_transformers.Transformer1DModel
      in_channels: ${extras.token_dim}
      num_attention_heads: 12
      attention_head_dim: 64
      num_layers: 2
    # Alternative 1: LearableFold, whose own query transformer works at
    # point_dim and cross-attends with cross_attention_dim = token_dim.
    # fold_transformer:
    #   _target_: src.models.gs.foldnet.LearableFold
    #   n_folds: 10
    #   n_dim: ${extras.point_dim}
    #   query_transformer:
    #     _target_: src.models.blocks.attention1d.Transformer1DModel
    #     in_channels:  ${extras.point_dim}
    #     num_attention_heads: 12
    #     attention_head_dim: 32
    #     num_layers: 2
    #     cross_attention_dim: ${extras.token_dim}
    # Alternative 2: skip the fold stage entirely.
    # fold_transformer:
    #   _target_: torch.nn.Identity

    # One head per Gaussian attribute, presumably an MLP
    # in_dim -> hidden_dim -> out_dim. Output widths: opacity 1, scaling 3,
    # rotation 4 (presumably a quaternion), color 48 (= 3 x 16, presumably
    # RGB spherical-harmonic coefficients up to degree 3), xyz 3.
    head_config:
      opacity:
        in_dim: ${extras.point_dim}
        hidden_dim: ${extras.mlp_dim}
        out_dim: 1
      scaling:
        in_dim: ${extras.point_dim}
        hidden_dim: ${extras.mlp_dim}
        out_dim: 3
      rotation:
        in_dim: ${extras.point_dim}
        hidden_dim: ${extras.mlp_dim}
        out_dim: 4
      color:
        in_dim: ${extras.point_dim}
        hidden_dim: ${extras.mlp_dim}
        out_dim: 48
      xyz:
        in_dim: ${extras.point_dim}
        hidden_dim: ${extras.mlp_dim}
        out_dim: 3

  # Renderer for the predicted Gaussians (implementation in
  # vanilla_renderer_split). Trailing whitespace after the key removed.
  renderer:
    _target_: src.models.gs.renderers.vanilla_renderer_split.VanillaRenderer

  # Each entry compares output `pred` with target `target` using `loss_fn`,
  # scaled by `weight`. Entries marked used_for_optimization: false are
  # presumably computed/logged only and kept out of the training objective —
  # confirm in src.models.loss_fn.LossFN.
  loss_fn:
    _target_: src.models.loss_fn.LossFN
    loss_list:
      # L1 on the rendered image.
      - name: rgb_diff_loss
        pred: render
        target: render
        weight: 1
        loss_fn:
          _target_: torch.nn.L1Loss
      # SSIM term on the rendered image. Unlike dist/chamfer below, it does
      # not set used_for_optimization: false (that override is commented out).
      - name: ssim_metric
        pred: render
        target: render
        weight: 0.2
        # used_for_optimization: false
        loss_fn:
          _target_: src.models.loss_fn.SSIMLoss
      # L1 on depth.
      - name: depth_loss
        pred: depth
        target: depth
        weight: 1
        # used_for_optimization: false
        loss_fn:
          _target_: torch.nn.L1Loss
      # Coulomb term between Gaussian centres (means3D) and points.
      - name: dist_loss
        pred: means3D
        target: points
        weight: 1.0e-4
        used_for_optimization: false
        loss_fn:
          _target_: src.models.loss_fn.CoulombLoss
      # Chamfer distance between Gaussian centres (means3D) and points.
      - name: chamfer_loss
        pred: means3D
        target: points
        weight: 1.0
        used_for_optimization: false
        loss_fn:
          _target_: src.models.loss_fn.ChamferLoss


  # Passed to the system together with the keys above; presumably enables
  # saving validation outputs — confirm in GaussianSplattingSystem.
  # Canonical lowercase boolean (same value as the previous `True`).
  save_val_output: true

trainer:
  _target_: lightning.pytorch.trainer.Trainer
  default_root_dir: ${output_dir}

  # Hardware / distribution: a single GPU on a single node.
  accelerator: gpu
  devices: 1
  num_nodes: 1
  strategy: auto
  # precision: 16-mixed  # mixed precision for extra speed-up

  # Schedule: stop after max_steps training steps; validate every
  # check_val_every_n_epoch epochs; one sanity-check validation step.
  max_steps: 100000
  check_val_every_n_epoch: 40
  num_sanity_val_steps: 1

  # Logging and console output.
  log_every_n_steps: 10
  enable_progress_bar: true

callbacks:
  # Write a checkpoint every 20k training steps and keep all of them
  # (save_top_k: -1).
  model_checkpoint:
    _target_: lightning.pytorch.callbacks.ModelCheckpoint
    every_n_train_steps: 20000
    save_top_k: -1
  # Project callback; presumably exports the Gaussian point cloud.
  point_cloud_checkpoint:
    _target_: src.utils.callbacks.PointCloudCallback
  # Optional rich console progress bar:
  # rich_progress_bar:
  #   _target_: lightning.pytorch.callbacks.RichProgressBar

logger:
  # TensorBoard alternative (disabled):
  # tensorboard:
  #   _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
  #   save_dir: "${output_dir}"
  #   name: ""
  #   version: "${version}"
  #   sub_dir: "tb_logs"
  wandb:
    _target_: lightning.pytorch.loggers.wandb.WandbLogger
    project: '${name}'
    name: '${version}'
    # The bare "outputs" root, not ${output_dir}. The checkpoint path
    # outputs/3dgs-fix/p0gao83a/checkpoints/... referenced in this config
    # suggests files land under save_dir/<project>/<run id>/, so using
    # ${output_dir} here would nest the project name twice.
    save_dir: 'outputs'
    # To resume an existing wandb run, uncomment both lines with its run id:
    # id: p0gao83a
    # resume: 'must'