# @package _global_

# to execute this experiment run:
# python train.py experiment=semantic/dales_11g

# This configuration allows training SPT on a single 11G GPU, with a
# training procedure comparable with the default
# experiment/semantic/dales configuration.
# Among the multiple ways of reducing memory impact, we choose here to
#   - divide the dataset into smaller tiles (facilitates preprocessing
#     and inference on smaller GPUs)
#   - reduce the number of samples in each batch (facilitates training
#     on smaller GPUs)
# To keep the total number of training steps consistent with the default
# configuration, while keeping informative gradient despite the smaller
# batches, we use gradient accumulation and reduce the number of epochs.
# DISCLAIMER: the tiling procedure may increase the preprocessing time
# (more raw data reading steps), and slightly reduce model performance
# (less diversity in the spherical samples).

# Hydra defaults list: each `override` entry replaces the option selected
# for that config group. The keys further down in this file are then merged
# on top of the selected configs.
defaults:
  - override /datamodule: semantic/dales.yaml
  - override /model: semantic/spt-2.yaml
  - override /trainer: gpu.yaml

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

# Datamodule overrides that lower GPU memory usage.
datamodule:
  # Split each cloud into xy_tiling²=25 tiles, based on a regular XY grid.
  # Reduces preprocessing- and inference-time GPU memory.
  xy_tiling: 5
  # 2 spherical samples in each batch instead of 4.
  # Reduces train-time GPU memory.
  sample_graph_k: 2

callbacks:
  gradient_accumulator:
    # Accumulate gradients over 2 batches before each optimizer step, to
    # make up for the reduced batch size.
    # NOTE(review): the key is presumably the epoch from which the
    # accumulation factor applies (as in Lightning's
    # GradientAccumulationScheduler) - confirm against the callback config.
    scheduling:
      0: 2

# Epoch budget chosen to keep the same number of optimizer steps as the
# default configuration.
trainer:
  # 25/9x more tiles per epoch and 2-step gradient accumulation
  # -> epochs * 2 * 9 / 25.
  # NOTE(review): 288 = 400 * 2 * 9 / 25, so the default presumably uses
  # 400 epochs and 9 tiles per cloud - confirm against the default config.
  max_epochs: 288

# Optimizer hyperparameters for this experiment.
model:
  optimizer:
    lr: 0.01
    # Written with an explicit decimal point in the mantissa: a plain
    # `1e-4` is resolved as a string (not a float) by YAML 1.1 loaders such
    # as PyYAML, whereas `1.0e-4` is a float under every loader.
    weight_decay: 1.0e-4

# Project and name passed to the wandb logger.
# NOTE(review): the name does not mention the 11G variant - confirm that
# sharing it with other runs is intended.
logger:
  wandb:
    project: 'spt_dales'
    name: 'SPT-64'
