# @package _global_

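# To execute this experiment, pass this file's name as the `experiment`
# override (the name below looks like a template leftover; substitute the
# actual file name):
# python run.py experiment=<this_file>.yaml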

# Config-group selections for this experiment. Each entry overrides the
# option that the base configuration picked for that group.
defaults:
  - override /trainer: default  # picked from 'configs/trainer/'
  - override /model: t2tvit
  - override /datamodule: imagenet
  - override /optimizer: adamw
  # No scheduler option is selected from the config group; the scheduler is
  # spelled out under `train.scheduler` further down in this file.
  - override /scheduler: null
  - override /callbacks:
      - default
      - ema
  - override /metrics:
      - acc
      - acctop5
  - override /logger: wandb

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

# Seed value for this run.
# NOTE(review): presumably used by the training entry point to seed the
# random number generators - confirm where it is consumed.
seed: 1111

# Trainer settings merged on top of the `default` trainer option chosen in
# the defaults list.
trainer:
  accelerator: gpu
  devices: 8
  num_nodes: 1
  # Derived from train.global_batch_size and the per-step batch processed by
  # all devices (devices * num_nodes * datamodule.batch_size). With the values
  # in this file: 512 and 8 * 1 * 64 = 512.
  # NOTE(review): `div_up` and `eval` are custom resolvers registered outside
  # this file; `div_up` is presumably a ceiling division - confirm.
  accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes} * ${datamodule.batch_size}}}
  # Sum of the scheduler's t_initial and the cooldown epochs
  # (300 + 10 with the values set under `train` in this file).
  max_epochs: ${eval:${train.scheduler.t_initial} + ${train.cooldown_epochs}}
  precision: 32  # I get NaN loss at epoch 1 with t2t_vit_7 and precision=16

# Overrides for the `imagenet` datamodule chosen in the defaults list.
datamodule:
  batch_size: 64  # Per GPU
  num_workers: 8  # Per GPU
  image_size: 224
  # Training-time transform pipeline, instantiated via timm.
  train_transforms:
    _target_: timm.data.create_transform
    input_size: ${datamodule.image_size}
    is_training: true
    # A spec starting with "rand" selects timm's RandAugment (not the
    # AutoAugment policies): magnitude 9, magnitude std 0.5,
    # increasing-severity ops.
    auto_augment: rand-m9-mstd0.5-inc1
    interpolation: random
    re_prob: 0.25  # Random erase prob
    re_mode: pixel  # Random erase mode
  # Validation transform pipeline; values taken from the model definition in
  # t2t_vit.py.
  val_transforms:
    _target_: timm.data.create_transform
    input_size: ${datamodule.image_size}
    interpolation: bicubic
    crop_pct: 0.9
  # Testing reuses the validation pipeline (relative interpolation to the
  # sibling key `val_transforms`).
  test_transforms: ${.val_transforms}

# Optimization settings for this experiment.
train:
  # Target batch size per optimizer step; trainer.accumulate_grad_batches is
  # derived from this value.
  global_batch_size: 512
  optimizer:
    # Exponent-form floats are written with a decimal point so that YAML 1.1
    # loaders (e.g. plain PyYAML) also read them as floats, not strings.
    lr: 1.0e-3
    weight_decay: 0.03
  # NOTE(review): presumably disables weight decay for bias and normalization
  # parameters - confirm against the param-grouping code.
  optimizer_param_grouping:
    bias_weight_decay: false
    normalization_weight_decay: false
  scheduler:
    _target_: src.optim.timm_lr_scheduler.TimmCosineLRScheduler
    t_initial: 300  # Also used to compute trainer.max_epochs
    lr_min: 1.0e-5
    warmup_lr_init: 1.0e-6
    warmup_t: 10
    cycle_limit: 1
  scheduler_interval: epoch
  cooldown_epochs: 10  # Added to scheduler.t_initial in trainer.max_epochs
  # Training loss applies label smoothing; the validation loss below does not.
  loss_fn:
    _target_: torch.nn.CrossEntropyLoss
    label_smoothing: 0.1
  loss_fn_val:
    _target_: torch.nn.CrossEntropyLoss
  mixup:
    _target_: src.datamodules.timm_mixup.TimmMixup
    mixup_alpha: 0.8
    cutmix_alpha: 1.0
    label_smoothing: 0.0  # We're using label smoothing from Pytorch's CrossEntropyLoss


# Override for the `ema` callback enabled in the defaults list; only its
# decay value is set here.
callbacks:
  ema:
    decay: 0.99996

