# @package _global_
# Hydra defaults list: the config groups composed into this experiment.
# Entry order is significant, so keep it as is.
defaults:
  - /pipeline: imagenet
  - /model: timm/convnext_tiny
  # Scheduler choice: timm_cosine (selected here) or plateau.
  - override /scheduler: timm_cosine

# Dataset overrides.
dataset:
  # Two-element size; model.img_size interpolates this value.
  __l_max:
    - 224
    - 224

# Loss selection (`loss` and `loss_val`).
task:
  loss:
    # soft_cross_entropy is the loss to pair with mixup; it currently relies
    # on PyTorch's cross entropy.
    _name_: soft_cross_entropy
    # Label smoothing is accepted by PyTorch from version 1.10 onward.
    label_smoothing: 0.1
  loss_val:
    _name_: cross_entropy

# Data loader settings.
loader:
  batch_size: 512
  # Eval and test batch sizes; each defaults to the train batch_size.
  batch_size_eval: 512
  batch_size_test: 256
  num_workers: 12
  # Derived flag: the expression is False when num_workers is 0, True otherwise.
  persistent_workers: '${eval:"${loader.num_workers} != 0"}'

# Trainer settings.
trainer:
  max_epochs: 310
  precision: 16
  devices: 1
  # True only when dataset.num_aug_repeats is 0. Presumably a non-zero value
  # means RepeatAug is in use, in which case this evaluates to False.
  replace_sampler_ddp: '${eval:"${dataset.num_aug_repeats} == 0"}'
  # Accumulation steps = train.global_batch_size // devices // loader.batch_size.
  # Clamped to a minimum of 1: the floor division alone yields 0 whenever
  # train.global_batch_size < devices * loader.batch_size, and 0 is not a
  # usable accumulation count. The expression is quoted so the comma inside
  # max(...) is not treated as a resolver-argument separator.
  accumulate_grad_batches: '${eval:"max(1, ${train.global_batch_size} // ${trainer.devices} // ${loader.batch_size})"}'

# Training-loop options.
train:
  seed: 1112
  # 0.0 here; when EMA is in use, set this to 0.99996.
  ema: 0.0
  optimizer_param_grouping:
    bias_weight_decay: false
    normalization_weight_decay: false
  remove_test_loader_in_eval: true
  # Effective batch size, reached through multiple GPUs and
  # trainer.accumulate_grad_batches.
  # NOTE(review): 128 is below trainer.devices * loader.batch_size (1 * 512)
  # as set in this file — confirm the intended value.
  global_batch_size: 128

# Optimizer hyperparameters.
optimizer:
  # Written with an explicit decimal point: YAML 1.1 loaders require one to
  # recognise a float and would otherwise read the value as the string "4e-3".
  lr: 4.0e-3
  weight_decay: 0.05

# Overrides for the scheduler selected in the defaults list (timm_cosine).
scheduler:
  # NOTE(review): presumably the warmup length; confirm its units against the
  # scheduler implementation.
  warmup_t: 20

# Encoder/decoder selection. The commented-out line below keeps `passthrough`
# as an alternative encoder value.
# encoder: passthrough
encoder: null
decoder: id

# Model overrides applied on top of the timm/convnext_tiny defaults.
model:
  # Image size follows the dataset's __l_max setting.
  img_size: '${dataset.__l_max}'
  drop_path_rate: 0.1
  # dropout: 0.  # per the original note, only the s4 layer uses this
  # Stem downsample factor: 2 or 4.
  patch_size: 4
  # Options listed by the author: patch, s4nd_path, s4nd.
  # NOTE(review): "s4nd_path" may be a typo for "s4nd_patch" — confirm.
  stem_type: patch
  # Options: patch, s4nd, or null (null selects the strided conv).
  downsample_type: null
  # One entry per stage; null means the regular conv2d in convnext is used.
  stage_layers:
    - null
    - null
    - null
    - null