# @package _global_

# Compose this config on top of the `cifar100` config; the keys below supply
# the CIFAR-10-specific values.
# NOTE(review): this relies on the values in this file taking precedence over
# those from `cifar100` -- confirm the composition order for the Hydra version
# in use.
defaults:
  - cifar100

# Run/experiment name; `add_to_name` is interpolated from a key defined outside
# this file.
name: cifar10_${add_to_name}
# CIFAR-10 has 10 target classes.
num_classes: 10

# Dataset class used as `_target_` by both cifar_train_args and cifar_val_args.
cifar_dataset_target: torchvision.datasets.CIFAR10
# Passed as the `root` argument of the dataset in both splits.
cifar_dataset_path: data/cifar10

# Cosine-annealing learning-rate schedule, referenced below as
# training_loop.config.lr_scheduler.
# `max_epochs` is interpolated from a key defined outside this file.
# NOTE(review): T_max counts scheduler steps, so this presumably assumes the
# scheduler is stepped once per epoch -- confirm in the training loop.
cifar10_lr_scheduler:
  _target_: torch.optim.lr_scheduler.CosineAnnealingLR
  T_max: ${max_epochs}

# Training-loop specification for image classification.
# `_recursive_: false` tells Hydra's instantiate not to build nested `_target_`
# nodes, so the entries under `config` reach ImageClassificationTL as plain
# config nodes.
# NOTE(review): presumably the training loop instantiates the optimizer and
# scheduler itself -- confirm in ImageClassificationTL.
training_loop:
  _target_: llid.training_loops.image_classification.ImageClassificationTL
  _recursive_: false
  config:
    # Interpolated from keys defined outside this file.
    model: ${model}
    regs: ${regs}
    metrics: ${metrics}
    epoch_metrics: ${epoch_metrics}

    # NOTE(review): 0 presumably disables learning-rate warmup -- confirm.
    warmup_steps: 0

    lr: ${initial_lr}

    optimizer: ${optimizer}
    # CIFAR-10 scheduler declared in this file (cifar10_lr_scheduler).
    lr_scheduler: ${cifar10_lr_scheduler}

# Dataset spec for the training split (`train: true`).
# Transform pipeline: random 32x32 crop with 4 px padding, random horizontal
# flip, tensor conversion, then per-channel normalization.
cifar_train_args:
  _target_: ${cifar_dataset_target}
  root: ${cifar_dataset_path}
  download: true
  train: true
  transform:
    _target_: torchvision.transforms.Compose
    _args_:
      # Compose takes a single positional argument, the list of transforms,
      # hence the nested sequence: outer item = that argument, inner items =
      # the individual transforms.
      - - _target_: torchvision.transforms.RandomCrop
          _args_: [32]
          padding: 4
        - _target_: torchvision.transforms.RandomHorizontalFlip
        - _target_: torchvision.transforms.ToTensor
        - _target_: torchvision.transforms.Normalize
          _args_:
            # Positional arguments of Normalize: per-channel mean, then std.
            # NOTE(review): commonly used CIFAR-10 constants; verify against
            # the dataset if exact statistics matter.
            - [0.4914, 0.4822, 0.4465]
            - [0.2023, 0.1994, 0.2010]


# Dataset spec for the validation split (`train: false`, i.e. the torchvision
# CIFAR10 test set).
# No augmentation: only tensor conversion and per-channel normalization.
cifar_val_args:
  _target_: ${cifar_dataset_target}
  root: ${cifar_dataset_path}
  download: true
  train: false
  transform:
    _target_: torchvision.transforms.Compose
    _args_:
      # Compose takes a single positional argument, the list of transforms,
      # hence the nested sequence.
      - - _target_: torchvision.transforms.ToTensor
        - _target_: torchvision.transforms.Normalize
          _args_:
            # Positional arguments of Normalize: per-channel mean, then std.
            # These literals match the ones used in cifar_train_args.
            - [0.4914, 0.4822, 0.4465]
            - [0.2023, 0.1994, 0.2010]