# @package _global_
# Hydra defaults list: composes the `imagenet` pipeline with the `vit/vit`
# model config, then overrides the model's layer group with `s4nd`.
defaults:
  - /pipeline: imagenet
  - /model: vit/vit
  - override /model/layer: s4nd

# Values set on top of the `vit/vit` model config chosen in the defaults list.
model:
  _name_: vit_b_16
  dropout: 0.0
  drop_path_rate: 0.1
  d_model: 768
  depth: 12
  expand: 4
  norm: layer
  layer_reps: 1
  # Both the class-token and positional-embedding flags are switched off.
  use_cls_token: false
  use_pos_embed: false

  # Settings for the model layer, which the defaults list switches to `s4nd`.
  layer:
    d_state: 64
    final_act: glu
    bidirectional: true
    channels: 2
    lr: 0.001
    n_ssm: 1
    # NOTE(review): per the inline comment below, 0 is the 2d setting, yet this
    # ImageNet config uses 1 — confirm that this is intentional.
    contract_version: 1  # 0 is for 2d, 1 for 1d or 3d (or other)

task:
  # Training loss: two alternative soft cross-entropy options (for mixup).
  loss:
    # Option 1 (active): soft_cross_entropy, for pytorch 1.10+. It takes
    # label_smoothing here.
    _name_: soft_cross_entropy
    label_smoothing: 0.1
    # Option 2 (commented out): timm_soft_cross_entropy, for pytorch 1.9 and
    # below. TIMM does not accept label_smoothing here; add it to the TIMM
    # mixup args instead.
    # _name_: timm_soft_cross_entropy
  # Separate loss entry using plain cross_entropy (presumably the validation
  # loss, judging by the key name — confirm).
  loss_val:
    _name_: cross_entropy

# Data loader settings.
# NOTE(review): trainer.devices is 8 in this file — confirm whether batch_size
# is per device or global before changing either value.
loader:
  batch_size: 128
  num_workers: 8

trainer:
  max_epochs: 310
  precision: 16
  devices: 8
  # The expression compares dataset.num_aug_repeats with 0, so this flag is
  # true only when num_aug_repeats is 0 and false for any other value, i.e.
  # it is turned OFF (not on) when augmentation repeats are configured.
  # Presumably RepeatAug supplies its own sampler in that case — confirm
  # against the dataset/sampler code.
  use_distributed_sampler: ${eval:"${dataset.num_aug_repeats} == 0"}

# Training-loop settings. Booleans use the canonical lowercase spelling, like
# every other boolean in this file.
train:
  seed: 1112
  ema: 0.99996
  # Both weight-decay flags are off (presumably bias and normalization
  # parameters are excluded from weight decay — confirm against the
  # parameter-grouping code).
  optimizer_param_grouping:
    bias_weight_decay: false
    normalization_weight_decay: false
  remove_test_loader_in_eval: true

# Optimizer hyperparameters. The lr here has the same value as model.layer.lr
# in this file (both 0.001).
optimizer:
  lr: 0.001
  weight_decay: 0.03

# Both entries are explicitly set to null (presumably to clear any
# encoder/decoder supplied by the composed pipeline config — confirm).
encoder: null
decoder: null
