# ===>loading from cfgs/shapenetpart/pointnext-s.yaml
# Number of params: 0.9817 M
# test input size: ((torch.Size([1, 2048, 3]), torch.Size([1, 3, 2048])))
# Batches npoints Params.(M)      GFLOPs
# 64      2048     0.982   4.52
# Model definition: encoder_args, decoder_args and cls_args.
# The inline remarks are the original author's tuning notes.
model:
  NAME: BasePartSeg
  encoder_args:
    NAME: PointNextEncoder
    # Author's note: [1, 1, 1, 2, 1] is better, but that is not the main
    # focus of the paper.
    blocks: [1, 1, 1, 1, 1]
    strides: [1, 2, 2, 2, 2]
    width: 32
    # Author's note: 7 works better than 4 or 6.
    in_channels: 7
    # Author's note: 3 works better than 2.
    sa_layers: 3
    sa_use_res: true
    radius: 0.1
    radius_scaling: 2.5
    # Author's note: will not improve performance.
    nsample: 32
    expansion: 4
    aggr_args:
      feature_type: dp_fj
    reduction: max
    group_args:
      NAME: ballquery
      normalize_dp: true
    conv_args:
      order: conv-norm-act
    act_args:
      # Author's note: leakyrelu makes training unstable.
      act: relu
    norm_args:
      # Author's note: ln makes training unstable.
      norm: bn
  decoder_args:
    NAME: PointNextPartDecoder
    cls_map: curvenet
  cls_args:
    NAME: SegHead
    # Author's note: append the global feature to each point feature.
    # Quoted so it is read as one string, not a list.
    global_feat: 'max,avg'
    num_classes: 50
    # Left unset (null) in this file.
    in_channels: null
    norm_args:
      norm: bn


# ---------------------------------------------------------------------------- #
# Training cfgs
# ---------------------------------------------------------------------------- #
# Learning rate; min_lr is left unset (null) in this file.
lr: 0.001
min_lr: null
optimizer:
  NAME: adamw
  # Author's note: this value gave the best results.
  weight_decay: 1.0e-4

criterion_args:
  NAME: Poly1FocalLoss

# Scheduler.
# NOTE(review): presumably the learning rate is scaled by decay_rate at each
# epoch listed in decay_epochs - confirm against the scheduler implementation.
epochs: 300
sched: multistep
decay_epochs:
  - 210
  - 270
decay_rate: 0.1
warmup_epochs: 0

# Transform names for the train and val splits, plus their kwargs.
datatransforms:
  train:
    - PointsToTensor
    - PointCloudScaling
    - PointCloudCenterAndNormalize
    - PointCloudJitter
    - ChromaticDropGPU
  val:
    - PointsToTensor
    - PointCloudCenterAndNormalize
  # NOTE(review): which transform reads which kwarg is not visible here.
  # No rotation transform is listed above, so confirm `angle` is still used.
  kwargs:
    jitter_sigma: 0.001
    jitter_clip: 0.005
    scale: [0.8, 1.2]
    gravity_dim: 1
    angle: [0, 1.0, 0]