# @package _global_
# Hydra defaults list: the config groups composed into this experiment.
defaults:
  - /pipeline: imagenet
  - /model: timm/convnext_tiny
  # `override` swaps out a /scheduler choice that is already present in the
  # composition (presumably selected by the pipeline config; confirm).
  - override /scheduler: timm_cosine  # timm_cosine or plateau

dataset:
  # Two-element size list, reused below as model.img_size and model.layer.l_max.
  __l_max:
    - 224
    - 224

task:
  # Training loss. Two soft-target cross-entropy options exist (for mixup):
  #   - soft_cross_entropy: for PyTorch 1.10+; takes label_smoothing here.
  #   - timm_soft_cross_entropy: for PyTorch 1.9 and below; TIMM does not
  #     accept label_smoothing here, so add it to the TIMM mixup args instead.
  loss:
    _name_: soft_cross_entropy
    label_smoothing: 0.1

    # Alternative for PyTorch 1.9 and below (drop label_smoothing above if used):
    # _name_: timm_soft_cross_entropy
  # Separate loss entry for validation (presumably evaluated on hard labels; confirm).
  loss_val:
    _name_: cross_entropy

loader:
  # `batch_size` is treated as a per-device size by the
  # trainer.accumulate_grad_batches formula (global // devices // batch_size).
  batch_size: 160
  batch_size_eval: 160
  batch_size_test: 80
  num_workers: 12
  # Resolves to false when num_workers is 0 and to true otherwise, so it does
  # not have to be flipped by hand when num_workers changes.
  persistent_workers: '${eval:"${loader.num_workers} != 0"}'

trainer:
  max_epochs: 310
  precision: 16
  devices: 8
  # True only when RepeatAug is NOT in use (dataset.num_aug_repeats == 0).
  # With RepeatAug enabled this resolves to false (presumably so that the
  # dataset's own repeat-aug sampler is not replaced; confirm).
  use_distributed_sampler: '${eval:"${dataset.num_aug_repeats} == 0"}'
  # Gradient-accumulation steps used to approach train.global_batch_size:
  #   global_batch_size // devices // loader.batch_size, clamped to at least 1.
  # Without the clamp the integer division yields 0 whenever global_batch_size
  # is smaller than devices * loader.batch_size (as with the values in this
  # file: 128 // 8 // 160 == 0), which is not a valid accumulation count.
  # Assumes the `eval` resolver is Python's eval (as the `//` and `!=`
  # operators used in this file suggest), so the builtin `max` is available.
  accumulate_grad_batches: '${eval:"max(1, ${train.global_batch_size} // ${.devices} // ${loader.batch_size})"}'

train:
  seed: 1112
  # EMA setting; left at 0.0 here. If using EMA, 0.99996 is the suggested value.
  ema: 0.0
  optimizer_param_grouping:
    bias_weight_decay: false
    normalization_weight_decay: false
  remove_test_loader_in_eval: true
  # Effective batch size (handled with multiple GPUs and
  # trainer.accumulate_grad_batches).
  # NOTE(review): 128 is smaller than trainer.devices * loader.batch_size
  # (8 * 160 = 1280), so gradient accumulation cannot reach it; confirm the
  # intended value.
  global_batch_size: 128

optimizer:
  # Written in plain decimal form (same value as 4e-3) so that it is read as a
  # float even by YAML 1.1 loaders that require a dot in exponent notation.
  lr: 0.004
  weight_decay: 0.05

# Overrides applied on top of the scheduler chosen in the defaults list.
scheduler:
  # Warmup length (presumably in epochs for the timm cosine scheduler; confirm).
  warmup_t: 20

# Encoder is set to null; `passthrough` is kept below as the commented-out alternative.
# encoder: passthrough
encoder: null
# NOTE(review): `id` presumably selects an identity decoder; confirm against the decoder registry.
decoder: id

model:
  # Image size, taken from dataset.__l_max.
  img_size: '${dataset.__l_max}'
  drop_path_rate: 0.1
  # Stem downsample factor: 2 or 4.
  patch_size: 4
  # Only used for the s4nd stem currently.
  stem_channels: 32
  # Options: patch (regular convnext), s4nd_patch, new_s4nd_patch (best), s4nd.
  stem_type: new_s4nd_patch
  # l_max used in the stem (if using s4); the original note lists None as an
  # alternative value.
  stem_l_max:
    - 16
    - 16
  # e.g. s4nd, or null for a regular strided conv.
  downsample_type: s4nd
  downsample_act: false
  downsample_glu: true
  conv_mlp: false
  # Only used when conv_mlp is enabled. TODO: benchmark to make sure this is
  # faster / more memory efficient; weight decay also needs to be turned off.
  custom_ln: false
  # Layer settings; null here means use the regular conv2d in convnext.
  layer:
    _name_: s4nd
    d_state: 64
    channels: 1
    bidirectional: true
    # null mimics the convnext style.
    activation: null
    final_act: none
    initializer: null
    weight_norm: false
    dropout: 0
    # Uses model.tie_dropout when that key exists, otherwise null.
    tie_dropout: '${oc.select:model.tie_dropout,null}'
    init: fourier
    rank: 1
    trank: 1
    dt_min: 0.01
    dt_max: 1.0
    lr: 0.001
    # length_correction: true
    n_ssm: 1
    # Special C init.
    deterministic: false
    # Grab the dataset length if it exists; otherwise set to null and the
    # kernel will automatically resize.
    l_max: '${oc.select:dataset.__l_max,null}'
    verbose: true
    linear: true
    return_state: false
    bandlimit: null
    # 0 is for 2d, 1 for 1d or 3d (or other).
    contract_version: 0
  # Settings for the stem's layer (presumably applied on top of `layer`; confirm).
  stem_layer:
    dt_min: 0.1
    dt_max: 1.0
    init: fourier
  # Four identical entries (presumably one per ConvNeXt stage; confirm).
  stage_layers:
    - {dt_min: 0.1, dt_max: 1.0}
    - {dt_min: 0.1, dt_max: 1.0}
    - {dt_min: 0.1, dt_max: 1.0}
    - {dt_min: 0.1, dt_max: 1.0}
