# @package _global_

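# to execute this experiment, run the command below with this file's own
# experiment name (NOTE(review): `example_simple.yaml` looks like a placeholder
# and does not match this t2tvit/cifar10 experiment — confirm the real name):
# python run.py experiment=example_simple.yaml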

# Hydra defaults list: each `override` entry picks the option to use for the
# named config group in place of the one chosen by the primary config.
defaults:
  - override /trainer: default # choose trainer from 'configs/trainer/'
  - override /model: t2tvit
  - override /datamodule: cifar10
  - override /optimizer: adamw
  # no option is selected from the scheduler group; this file configures the
  # scheduler inline under `train.scheduler`
  - override /scheduler: null
  # list values select several options from the same group
  - override /callbacks: [default, ema]
  - override /metrics: [acc]
  - override /logger: wandb

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

# seed value for this experiment (presumably used for RNG seeding by the
# training entry point — confirm against the consumer)
seed: 1111

# Trainer overrides: single GPU on a single node, with two values derived from
# other config entries via the `div_up` and `eval` resolvers.
trainer:
  accelerator: gpu
  devices: 1
  num_nodes: 1
  # derived from train.global_batch_size and the per-step batch
  # (devices * num_nodes * datamodule.batch_size); `div_up` and `eval` are
  # resolvers registered outside this file — presumably ceiling division and
  # arithmetic evaluation (confirm). global_batch_size and batch_size are not
  # set in this file, so they must come from another config.
  accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes} * ${datamodule.batch_size}}}
  # scheduler length plus cooldown: t_initial (300) + cooldown_epochs (10) as
  # set under `train` in this file
  max_epochs: ${eval:${train.scheduler.t_initial} + ${train.cooldown_epochs}}
  # NOTE(review): the inline remark reports NaN loss with precision=16, yet 16
  # is the value configured here — confirm this setting is intended.
  precision: 16  # I get NaN loss at epoch 1 with t2t_vit_7 and precision=16


# Model overrides, merged on top of the `t2tvit` model config chosen in the
# defaults list.
model:
  # same value as datamodule.image_size in this file; keep the two in agreement
  img_size: 224
  drop_path_rate: 0.1

# Datamodule overrides, merged on top of the `cifar10` datamodule config chosen
# in the defaults list.
datamodule:
  # same value as model.img_size in this file; keep the two in agreement
  image_size: 224
  data_augmentation: autoaugment

# Training settings: optimizer hyperparameters, learning-rate schedule, loss
# functions and mixup. Values here are merged over the selected defaults.
train:
  # hyperparameters applied to the optimizer chosen in the defaults list (adamw)
  optimizer:
    lr: 5e-4
    weight_decay: 0.05
  # presumably controls whether weight decay is applied to bias and
  # normalization parameters (both turned off here) — confirm against consumer
  optimizer_param_grouping:
    bias_weight_decay: false
    normalization_weight_decay: false
  # scheduler is defined inline because the defaults list selects no scheduler
  # group option
  scheduler:
    _target_: src.optim.timm_lr_scheduler.TimmCosineLRScheduler
    # also read by trainer.max_epochs (t_initial + cooldown_epochs)
    t_initial: 300
    lr_min: 1e-5
    warmup_lr_init: 1e-6
    warmup_t: 5
    cycle_limit: 1
  scheduler_interval: epoch
  # also read by trainer.max_epochs (t_initial + cooldown_epochs)
  cooldown_epochs: 10
  # training loss: cross entropy with label smoothing
  loss_fn:
    _target_: torch.nn.CrossEntropyLoss
    label_smoothing: 0.1
  # validation loss: cross entropy with no label_smoothing argument set
  loss_fn_val:
    _target_: torch.nn.CrossEntropyLoss
  mixup:
    _target_: src.datamodules.timm_mixup.TimmMixup
    mixup_alpha: 0.5
    cutmix_alpha: 0.0
    label_smoothing: 0.0  # We're using label smoothing from Pytorch's CrossEntropyLoss

# Callback overrides. `ema` carries no settings here; its value is YAML null,
# written explicitly rather than as a bare key.
# NOTE(review): confirm whether null is meant to keep the settings from the
# `ema` callbacks option selected in the defaults list or to blank them out —
# a null merged over an inherited mapping may replace it.
callbacks:
  ema: null


