# @package model

# Class instantiated for this config (Hydra `_target_`).
_target_: src.models.semantic.SemanticSegmentationModule

# Number of classes, resolved from the datamodule config.
num_classes: ${datamodule.num_classes}

sampling_loss: false

# Supported values: 'ce', 'wce', 'kl', 'ce_kl', 'wce_kl'.
loss_type: 'ce_kl'
weighted_loss: true

# Initialization of linear layers; null defaults to xavier_uniform
# initialization.
init_linear: null

# Initialization of the RPE layers; null defaults to xavier_uniform
# initialization.
init_rpe: null

# Weights for the multi-stage loss.
multi_stage_loss_lambdas: [1, 50]

# NOTE(review): presumably a multiplier applied to the learning rate of
# the transformer parameters - confirm in SemanticSegmentationModule.
transformer_lr_scale: 0.1

# NOTE(review): presumably the period, in steps, of an explicit garbage
# collection, with 0 disabling it - confirm in SemanticSegmentationModule.
gc_every_n_steps: 0

# Every N epochs, the model may store to disk its predictions for some
# tracked validation batch of interest. This assumes the validation
# dataloader is non-stochastic. Additionally, the model may store to
# disk its predictions for some or all of the test batches.

# Trigger the validation tracking every N epochs.
track_val_every_n_epoch: 10
# Index of the validation batch to track. If -1, all the validation
# batches will be tracked, every `track_val_every_n_epoch` epochs.
track_val_idx: null
# Index of the test batch to track. If -1, all the test batches will be
# tracked.
track_test_idx: null

# Optimizer factory. `_partial_` makes Hydra build a partial of the
# target, so the remaining arguments (e.g. the parameters to optimize)
# are supplied later by the caller.
optimizer:
  _target_: torch.optim.AdamW
  _partial_: true
  lr: 0.01
  # Written with an explicit mantissa dot so that YAML 1.1 loaders (e.g.
  # plain PyYAML) also read a float rather than the string '1e-4'.
  weight_decay: 1.0e-4

# Learning rate scheduler factory. `_partial_` makes Hydra build a
# partial of the target, so the remaining arguments (e.g. the optimizer)
# are supplied later by the caller.
scheduler:
  _target_: src.optim.CosineAnnealingLRWithWarmup
  _partial_: true
  # Resolves to `trainer.max_epochs` minus `num_warmup` declared below.
  T_max: ${eval:'${trainer.max_epochs} - ${model.scheduler.num_warmup}'}
  # Small floats are written with an explicit mantissa dot so that YAML
  # 1.1 loaders (e.g. plain PyYAML) also read floats rather than the
  # string '1e-6'.
  eta_min: 1.0e-6
  warmup_init_lr: 1.0e-6
  num_warmup: 20
  warmup_strategy: 'cos'

# Loss criterion. Targets whose label equals `datamodule.num_classes`
# are ignored by the cross-entropy loss.
# NOTE(review): presumably this index is the void/unlabeled class -
# confirm against the datamodule.
criterion:
  _target_: torch.nn.CrossEntropyLoss
  ignore_index: ${datamodule.num_classes}

# Parameters declared here to facilitate tuning configs. They are only
# used here for config interpolation and will/should actually end up in
# the ignored kwargs of the SemanticSegmentationModule.
_point_mlp: [32, 64, 128]  # point encoder layers
_node_mlp_out: 32  # size of level-1+ handcrafted node features after MLP, set to null to use the raw features directly
_h_edge_mlp_out: 32  # size of level-1+ handcrafted horizontal edge features after MLP, set to null to use the raw features directly
_v_edge_mlp_out: 32  # size of level-1+ handcrafted vertical edge features after MLP, set to null to use the raw features directly

# Feature dimensions derived from the datamodule and `model.net` flags.
# NOTE(review): the expressions below assume the `eval` resolver
# evaluates its argument as a Python expression and that the `use_*`
# flags are booleans (counted as 0/1 in the sums) - confirm where the
# resolver is registered.
#
# `_point_hf_dim`: 3 if `use_pos`, plus `num_hf_point`, plus 1 if
# `use_diameter_parent`.
# `_node_hf_dim`: `num_hf_segment` if `use_node_hf`, else 0.
# `_node_injection_dim`: 3 if `use_pos`, plus 1 if `use_diameter`, plus
# 1 if `use_diameter_parent`, plus `_node_mlp_out` when it is set and
# `use_node_hf` is enabled and `_node_hf_dim` > 0, or `_node_hf_dim`
# otherwise.
_point_hf_dim: ${eval:'${model.net.use_pos} * 3 + ${datamodule.num_hf_point} + ${model.net.use_diameter_parent}'}  # size of handcrafted level-0 node features (points)
_node_hf_dim: ${eval:'${model.net.use_node_hf} * ${datamodule.num_hf_segment}'}  # size of handcrafted level-1+ node features before node MLP
_node_injection_dim: ${eval:'${model.net.use_pos} * 3 + ${model.net.use_diameter} + ${model.net.use_diameter_parent} + (${model._node_mlp_out} if ${model._node_mlp_out} and ${model.net.use_node_hf} and ${model._node_hf_dim} > 0 else ${model._node_hf_dim})'}  # size of parent level-1+ node features for Stage injection input
_h_edge_hf_dim: ${datamodule.num_hf_edge}  # size of level-1+ handcrafted horizontal edge features
_v_edge_hf_dim: ${datamodule.num_hf_v_edge}  # size of level-1+ handcrafted vertical edge features

_down_dim: [64, 64, 64, 64]  # encoder stage dimensions
_up_dim: [64, 64, 64]  # decoder stage dimensions
_mlp_depth: 2  # default number of layers in all MLPs (i.e. MLP depth)

# `???` is the OmegaConf marker for a mandatory value: `net` must be
# provided by a composing config or an override. Interpolations in this
# file read `model.net.use_pos`, `model.net.use_diameter`,
# `model.net.use_diameter_parent` and `model.net.use_node_hf`, so the
# provided `net` is expected to define these keys.
net: ???
