# @package _global_

# to execute this experiment run:
# python train.py experiment=panoptic/scannet_11g

# This configuration allows training SPT on a single 11G GPU, with a
# training procedure comparable with the default
# experiment/panoptic/scannet configuration.
# Among the multiple ways of reducing memory impact, we choose here to
#   - reduce the number of samples in each batch (facilitates training
#     on smaller GPUs)
# To keep the total number of training steps consistent with the default
# configuration, while keeping informative gradient despite the smaller
# batches, we use gradient accumulation and keep the number of epochs
# unchanged.
# DISCLAIMER: if a tiling procedure is used, it may increase the
# preprocessing time (more raw data reading steps), and slightly reduce
# model performance (less diversity in the spherical samples). Note that
# this file itself does not override any tiling option.

# Hydra defaults list: selects the config-group options this experiment
# builds on. Each entry replaces the option chosen for that group
# (datamodule, model, trainer); the top-level keys further down then
# overwrite individual values of the composed configuration.
defaults:
  - override /datamodule: panoptic/scannet.yaml
  - override /model: panoptic/spt-2.yaml
  - override /trainer: gpu.yaml

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

# Memory reduction: use fewer samples per batch than the default
# configuration. The gradient accumulation configured under
# callbacks.gradient_accumulator is there to make up for the smaller
# batches.
datamodule:
  dataloader:
    batch_size: 1

callbacks:
  # Checkpoint period: the larger of the validation period and the
  # partition period. With the values set in this file (2 and 4) this is
  # expected to resolve to 4 -- assuming the `eval` resolver evaluates the
  # quoted Python expression as written (resolver is defined elsewhere).
  model_checkpoint:
    every_n_epochs: "${eval:'max(${trainer.check_val_every_n_epoch}, ${model.partition_every_n_epoch})'}"

  # NOTE(review): presumably relaxes the requirement that the monitored
  # metric be available at every check -- confirm against the callback
  # class configured for early_stopping.
  early_stopping:
    strict: false

  # Accumulate gradients over 4 batches to make up for the reduced batch
  # size. The key 0 is presumably the epoch from which this factor
  # applies -- confirm against the callback class.
  gradient_accumulator:
    scheduling:
      0: 4

trainer:
  # Smaller batches are offset by gradient accumulation, so the epoch
  # budget does not need to change to preserve the number of steps.
  max_epochs: 100  # same number of steps -> epochs unchanged
  # Validation period, in epochs. Also one of the two inputs of
  # callbacks.model_checkpoint.every_n_epochs.
  check_val_every_n_epoch: 2

# Model overrides for this experiment.
model:
  optimizer:
    lr: 0.01
    # Written with an explicit decimal point so that YAML 1.1 loaders
    # (e.g. plain PyYAML) also read it as a float rather than a string.
    weight_decay: 1.0e-4

  scheduler:
    # NOTE(review): unit is not visible from this file (presumably
    # epochs) -- confirm against the scheduler config.
    num_warmup: 2

  # Underscore-prefixed keys: presumably shared values that the model
  # config interpolates into the network definition -- verify against
  # the model/panoptic/spt-2 config.
  _node_mlp_out: 64
  _h_edge_mlp_out: 64
  _down_dim: [128, 128, 128, 128]
  _up_dim: [128, 128, 128]
  net:
    no_ffn: false
    down_ffn_ratio: 1
    down_num_heads: 32

  partitioner:
    regularization: 20
    # Explicit decimal point for the same portability reason as
    # optimizer.weight_decay.
    x_weight: 5.0e-2
    cutoff: 300

  edge_affinity_loss_lambda: 10

  # Also read by callbacks.model_checkpoint.every_n_epochs, which takes
  # the max of this value and trainer.check_val_every_n_epoch.
  partition_every_n_epoch: 4

# Weights & Biases logger: project and run name for this experiment.
logger:
  wandb:
    project: 'spt_scannet'
    name: 'SPT-128'

# Name of the logged metric on which model selection is based.
optimized_metric: 'val/pq'
