# @package _global_
defaults:
  # Base configs this experiment is composed from: the CelebA pipeline and the
  # timm ConvNeXt-micro model config. The keys below override parts of them.
  # NOTE(review): there is no `_self_` entry, so the merge order of this file
  # relative to these entries follows Hydra's default — confirm for the Hydra
  # version in use.
  - /pipeline: celeba-all-2d
  - /model: timm/convnext_micro

# Model overrides on top of /model: timm/convnext_micro, swapping the conv
# layers for S4ND layers (see `layer._name_`).
model:
  drop_path_rate: 0.1
  # dropout: 0.  # layer in s4 uses this
  stem_type: new_s4nd_patch  # eg, patch, s4nd_patch, s4nd
  stem_channels: 32
  # l_max of the stem layer (only relevant if the stem uses S4). Two entries,
  # presumably one per spatial dimension. NOTE(review): the model's own
  # default appears to be stem_l_max=None — confirm in the model code.
  stem_l_max: [16, 16]
  downsample_type: s4nd  # eg, patch, s4nd, null (for strided conv)
  downsample_glu: true
  layer:  # null means use regular conv2d in convnext
    _name_: s4nd
    d_state: 64
    channels: 1
    bidirectional: true
    activation: null  # mimics convnext style
    final_act: null
    initializer: null
    weight_norm: false
    hyper_act: null
    # dropout: ${model.dropout}  # same as null
    init: fourier
    rank: 1
    trank: 1
    dt_min: 0.01
    dt_max: 1.0
    lr: 0.001
    n_ssm: 1
    deterministic: false  # Special C init
    l_max: null
    verbose: true
    linear: true
    return_state: false
    # fast_gate: false
  # One entry per stage. NOTE(review): presumably merged over `layer` for the
  # corresponding stage — confirm in the model code.
  stage_layers:
    - dt_min: 0.1
      dt_max: 1.0
    - dt_min: 0.1
      dt_max: 1.0
    - dt_min: 0.1
      dt_max: 1.0
    - dt_min: 0.1
      dt_max: 1.0
  # Settings for the stem layer. NOTE(review): presumably merged over `layer`
  # — confirm in the model code.
  stem_layer:
    dt_min: 0.1
    dt_max: 1.0
    init: fourier
  # Number of outputs; scored with the binary_accuracy metric under `task`.
  num_classes: 40
  # Interpolated from the dataset config key `dataset.__l_max` (defined
  # outside this file).
  img_size: ${dataset.__l_max}

# Trainer overrides: 16-bit precision on a single device.
trainer: {precision: 16, devices: 1}

task:
  torchmetrics: null  # explicitly set to null for this experiment
  metrics: [binary_accuracy]  # metrics listed by name

train:
  optimizer_param_grouping:
    # NOTE(review): by their names, these flags presumably control whether
    # bias and normalization parameters receive weight decay — confirm in the
    # optimizer setup code.
    bias_weight_decay: false
    normalization_weight_decay: false

# Explicitly set to null. NOTE(review): presumably disables an encoder/decoder
# configured by the pipeline defaults — confirm against celeba-all-2d.
encoder: null
decoder: null

# Data loader override: batch size.
loader: {batch_size: 512}
