# @package _global_
# Base configs composed into this experiment: one entry from the `pipeline`
# group and one from the `model` group.
# NOTE(review): whether the keys in this file override these entries depends
# on the Hydra version / `_self_` placement — confirm.
defaults:
  - /pipeline: celeba-all-2d
  - /model: timm/convnext_micro

# Model settings for this experiment; these share the `model` node with the
# `/model: timm/convnext_micro` entry selected in the defaults list.
model:
  drop_path_rate: 0.1
  # dropout: 0.  # left disabled; earlier note: "layer in s4 uses this"
  stem_type: patch
  stem_channels: 32
  # l_max for the stem. Earlier note: only relevant if the stem uses s4, and
  # the underlying argument appears to default to None.
  # NOTE(review): presumably one entry per image axis — confirm.
  stem_l_max:
    - 16
    - 16
  # Options listed in the earlier note: patch, s4nd, or null (strided conv).
  downsample_type: patch
  # One entry per stage, all explicitly null.
  stage_layers: [null, null, null, null]
  num_classes: 40
  # Interpolated from the `__l_max` key of the `dataset` node.
  img_size: ${dataset.__l_max}

# Trainer settings.
# NOTE(review): `precision: 16` presumably selects 16-bit (mixed) precision
# and `devices: 1` a single accelerator — confirm against the trainer class
# this config is passed to.
trainer:
  precision: 16
  devices: 1

train:
  # Both flags are off.
  # NOTE(review): presumably this excludes bias and normalization parameters
  # from weight decay — confirm against the optimizer grouping code.
  optimizer_param_grouping:
    bias_weight_decay: false
    normalization_weight_decay: false

# Both nodes are explicitly set to null.
# NOTE(review): presumably this clears encoder/decoder settings inherited
# from the pipeline config — confirm.
encoder: null
decoder: null

# Data loader settings.
# NOTE(review): unclear from here whether batch_size is per device or
# global — confirm.
loader:
  batch_size: 512
