# @package _global_

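# To execute this experiment, run the command below.
# NOTE(review): `example_simple.yaml` looks like a template placeholder;
# substitute this file's own experiment name if it differs.
# python run.py experiment=example_simple.yaml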

# Hydra defaults list: each `override` entry swaps the option selected for a
# config group. The list is kept in its original order on purpose.
defaults:
  - override /trainer: default # choose trainer from 'configs/trainer/'
  - override /model: mixers-cifar
  # `null` here means no option is selected for this nested model group.
  - override /model/channel_mlp_cfg: null
  # Disabled counterpart for the token-MLP group; left commented out.
  # - override /model/token_mlp_cfg: null
  - override /datamodule: cifar10
  - override /optimizer: adamw
  # No scheduler option is selected here; one is configured inline under
  # `train.scheduler` further down in this file.
  - override /scheduler: null
  - override /callbacks: default
  - override /metrics: [acc]
  - override /logger: wandb

# All parameters below are merged with the parameters from the default
# configurations set above, so only the parameters specified here are overwritten.

# Seed value for the run.
# NOTE(review): where and how it is applied is decided by the training entry
# point, which is not visible in this file — confirm there.
seed: 1111

# Trainer settings layered on top of the `default` trainer option chosen in the
# defaults list.
trainer:
  accelerator: gpu
  devices: 1
  # Derived from train.global_batch_size and the product of trainer.devices and
  # datamodule.batch_size, combined through the project's `div_up` resolver
  # (presumably a ceiling division — confirm against the resolver definition).
  # Folded scalar: the two lines below are joined with a single space.
  accumulate_grad_batches: >-
    ${div_up:${train.global_batch_size},
    ${eval:${trainer.devices} * ${datamodule.batch_size}}}
  # Sum of the scheduler's t_initial and the cooldown epochs.
  max_epochs: '${eval:${train.scheduler.t_initial} + ${train.cooldown_epochs}}'
  precision: 16
  gradient_clip_val: 1.0

# Settings for the `cifar10` datamodule option chosen in the defaults list.
datamodule:
  # Per GPU. Also read by the trainer.accumulate_grad_batches interpolation.
  batch_size: 1024
  # Per GPU.
  num_workers: 4
  # Same value as model.img_size in this file.
  image_size: 32
  data_augmentation: autoaugment

# Settings for the `mixers-cifar` model option chosen in the defaults list.
model:
  # NOTE(review): duplicates datamodule.image_size (both 32 in this file); the
  # two are set independently, so change them together.
  img_size: 32
  # Presumably stochastic-depth and dropout probabilities — confirm against
  # the model implementation.
  drop_path_rate: 0.1
  drop_rate: 0.0

# Training-loop settings: optimizer, LR schedule, losses and mixup.
train:
  # Read by the trainer.accumulate_grad_batches interpolation.
  global_batch_size: 1024
  optimizer:
    # Written with an explicit decimal point so that YAML 1.1 loaders such as
    # plain PyYAML also resolve it as a float rather than a string.
    lr: 5.0e-4
    weight_decay: 0.1
  optimizer_param_grouping:
    # Canonical lowercase booleans.
    bias_weight_decay: false
    normalization_weight_decay: false
  scheduler:
    _target_: src.optim.timm_lr_scheduler.TimmCosineLRScheduler
    # Read by the trainer.max_epochs interpolation.
    t_initial: 300
    lr_min: 5.0e-6
    warmup_lr_init: 5.0e-5
    warmup_t: 5
    cycle_limit: 1
  # NOTE(review): presumably the scheduler is stepped once per epoch, which
  # would make t_initial and warmup_t epoch counts — confirm in the training
  # code.
  scheduler_interval: epoch
  # Added to scheduler.t_initial by the trainer.max_epochs interpolation.
  cooldown_epochs: 10
  # The loss_fn entry sets label smoothing; loss_fn_val does not.
  loss_fn:
    _target_: torch.nn.CrossEntropyLoss
    label_smoothing: 0.1
  loss_fn_val:
    _target_: torch.nn.CrossEntropyLoss
  mixup:
    _target_: src.datamodules.timm_mixup.TimmMixup
    mixup_alpha: 0.5
    label_smoothing: 0.0

