# @package _global_
# Hydra defaults list: entries are composed in the order written, so keep the
# order intact. The two `override` entries swap the option of a config group
# that is already selected elsewhere in the composition (the model's layer
# becomes `hyena`, the scheduler becomes `cosine_warmup_timm`).
defaults:
  - /pipeline: cifar-2d
  - /model: vit
  - override /model/layer: hyena
  - override /scheduler: cosine_warmup_timm

# Dataset overrides. Because this file is packaged as `_global_` (see its first
# line), these keys are merged at the root of the composed config.
dataset:
  # NOTE(review): presumably selects how the 2D image is ordered/flattened;
  # the exact meaning is defined by the dataset code — confirm.
  permute: 2d_transpose
  augment: true

task:
  # Training loss. There are two alternative soft-target cross-entropy
  # implementations (soft targets are needed for mixup); exactly one `_name_`
  # should be active.
  loss:
    # Option 1 (active): soft_cross_entropy, for PyTorch 1.10+.
    # It accepts label_smoothing directly here.
    _name_: soft_cross_entropy
    label_smoothing: 0.1

    # Option 2 (disabled): timm_soft_cross_entropy, for PyTorch 1.9 and below.
    # TIMM does not accept label_smoothing here; when switching to it, pass
    # label smoothing through the TIMM mixup args instead.
    # _name_: timm_soft_cross_entropy
  # Separate `loss_val` entry: plain cross_entropy, with no label smoothing
  # configured in this file.
  loss_val:
    _name_: cross_entropy

# Data-loader settings.
loader:
  batch_size: 50
  num_workers: 8
  # Derived from num_workers through interpolation: the expression is false
  # exactly when num_workers is 0, so no manual edit is needed in that case.
  # Assumes the project registers an `eval` resolver (not defined in this
  # file) — TODO confirm.
  persistent_workers: ${eval:"${loader.num_workers} != 0"}

# Trainer settings (presumably forwarded to a PyTorch Lightning Trainer —
# confirm against the training entry point).
trainer:
  max_epochs: 100
  precision: 16
  devices: 1
  # NOTE(review): the inline note says this should be true only when RepeatAug
  # is used, yet it is true here — confirm RepeatAug is actually enabled, or
  # that true is intended regardless.
  replace_sampler_ddp: true  # only true if using RepeatAug

# Training-loop settings.
train:
  seed: 1112
  # EMA value; 0.0 means EMA is not used. If using, set to 0.99996.
  ema: 0.0
  # NOTE(review): presumably exempts bias and normalization parameters from
  # optimizer.weight_decay — confirm against the optimizer setup code.
  optimizer_param_grouping:
    bias_weight_decay: false
    normalization_weight_decay: false
  remove_test_loader_in_eval: true
  # Effective batch size (handled with multiple GPUs and
  # accumulate_grad_batches). It currently equals loader.batch_size (50) with
  # trainer.devices set to 1; revisit this value when changing either of them.
  global_batch_size: 50

# Optimizer overrides; only the learning rate and weight decay are set here.
# The optimizer type itself is not chosen in this file.
optimizer:
  lr: 0.001
  weight_decay: 0.03

# NOTE(review): presumably "no input encoder" (null) and an identity decoder
# (`id`), leaving input/output handling to the model itself — confirm against
# the project's encoder/decoder registry.
encoder: null
decoder: id

# Model settings applied on top of the `/model: vit` choice from the defaults
# list.
model:
  _name_: vit_b_16
  dropout: 0.0
  drop_path_rate: 0.1
  # Width and depth are scaled down from the defaults noted inline.
  d_model: 64 # default 768
  depth: 4 # default 12
  expand: 4 # default 4
  norm: layer
  layer_reps: 1
  img_size: 32
  num_classes: 10
  # With img_size 32 this gives 16 patches per side (256 in total), assuming
  # square, non-overlapping patches — TODO confirm.
  patch_size: 2
  use_pos_embed: false
  use_cls_token: false
  track_norms: false
  # Settings for the layer selected by `override /model/layer: hyena` in the
  # defaults list; only the two keys below are set explicitly in this file.
  layer:
    num_heads: 1
    return_state: true  # needed to hook up to other layers nicely
  # Disabled reference list of further layer options. Stripping the leading
  # `# ` from a line places that key under `layer` above.
  # NOTE(review): values shown are not necessarily the layer's defaults.
  #   _name_: hyena1d
  #   d_state: 64  # order
  #   l_max: 1024  # Grab dataset length imagenet
  #   emb_dim: 33  # for num bands in pos enc
  #   head_dim: 1 # currently unused
  #   inner_factor: 1 # width multiplier
  #   num_blks: 1 # fft blocks
  #   groups: 192 # groups of the input conv
  #   fused_bias_fc: false
  #   # dropout: ${..dropout}
  #   kernel_cls: ImplicitTimeMemoryKernel
  #   jit_kernel: false
  #   in_kernel_cls: null # kernel_cls for the input conv, if None uses nn.Conv1d
  #   local_order: 5 # kernel size of input conv (if using explicit parametrization)
  #   verbose: true
  #   # later add these to FF
  #   activation: gelu # activation in between SS and FF
  #   ln: false # Extra normalization
  #   postact: null # activation after FF
  #   initializer: null # initializer on FF
  #   weight_norm: false # weight normalization on FF
  #   track_norms: false
  #   lr: 1e-4
  #   doublegate: false
  #   inner_norm: false
  #   num_inner_mlps: 1