# @package _global_

# Run name: `name` is assembled from `add_to_name` through interpolation, so
# overriding `add_to_name` alone is enough to relabel a run.
add_to_name: 'run'
name: 'cifar100_${add_to_name}'
# Number of target classes; interpolated into `model` and `training_loop`.
num_classes: 100


# Dataset class and on-disk root shared by `cifar_train_args` and
# `cifar_val_args`.
cifar_dataset_target: 'torchvision.datasets.CIFAR100'
cifar_dataset_path: 'data/cifar100'
# Forwarded to the data module as `data.config.train_data_pct`.
# NOTE(review): presumably the share of the training set to use; confirm the
# expected scale (1 vs. 100) in llid.data.cifar.CIFAR.
train_data_pct: 1

# Training split of `${cifar_dataset_target}` (train=true) with an
# augmentation pipeline followed by tensor conversion and normalization.
cifar_train_args:
  _target_: '${cifar_dataset_target}'
  root: '${cifar_dataset_path}'
  download: true
  train: true
  transform:
    _target_: torchvision.transforms.Compose
    _args_:
      # Compose receives one positional argument, and that argument is the
      # list of transforms, hence the sequence nested inside a sequence.
      -
        - _target_: torchvision.transforms.RandomCrop
          _args_:
            - 32
          padding: 4
        - _target_: torchvision.transforms.RandomHorizontalFlip
        - _target_: torchvision.transforms.RandomRotation
          _args_:
            - 15
        - _target_: torchvision.transforms.ToTensor
        - _target_: torchvision.transforms.Normalize
          _args_:
            # Positional arguments of Normalize: mean, then std (one value
            # per channel).
            - [0.5070751592371323, 0.48654887331495095, 0.4409178433670343]
            - [0.2673342858792401, 0.2564384629170883, 0.27615047132568404]


# Validation split of `${cifar_dataset_target}` (train=false). No random
# transforms: only tensor conversion and normalization.
cifar_val_args:
  _target_: '${cifar_dataset_target}'
  root: '${cifar_dataset_path}'
  download: true
  train: false
  transform:
    _target_: torchvision.transforms.Compose
    _args_:
      # Compose receives one positional argument, and that argument is the
      # list of transforms, hence the sequence nested inside a sequence.
      -
        - _target_: torchvision.transforms.ToTensor
        - _target_: torchvision.transforms.Normalize
          _args_:
            # Positional arguments of Normalize: mean, then std (one value
            # per channel).
            - [0.5070751592371323, 0.48654887331495095, 0.4409178433670343]
            - [0.2673342858792401, 0.2564384629170883, 0.27615047132568404]

# Worker count interpolated into both `train_dataloader` and `val_dataloader`.
num_workers: 24

# Forwarded to `training_loop.config.alternating_reg_loss`.
alternating_reg_loss: false

# Training DataLoader arguments. The batch size lives in its own top-level key
# so it can be overridden without touching the rest of the mapping.
train_batch_size: 128
train_dataloader:
  _target_: torch.utils.data.DataLoader
  batch_size: '${train_batch_size}'
  num_workers: '${num_workers}'
  shuffle: true
  pin_memory: true
  # DataLoader rejects persistent_workers=True when num_workers is 0, so keep
  # `num_workers` positive while this stays enabled.
  persistent_workers: true

# Validation DataLoader arguments; same layout as the training loader but
# without shuffling.
val_batch_size: 128
val_dataloader:
  _target_: torch.utils.data.DataLoader
  batch_size: '${val_batch_size}'
  num_workers: '${num_workers}'
  shuffle: false
  pin_memory: true
  # DataLoader rejects persistent_workers=True when num_workers is 0, so keep
  # `num_workers` positive while this stays enabled.
  persistent_workers: true

# Data module. `_recursive_: false` tells Hydra not to instantiate the nested
# `_target_` entries under `config`; they are handed over as plain config.
# NOTE(review): presumably llid.data.cifar.CIFAR builds the datasets and
# loaders from them itself; confirm in that class.
data:
  _target_: llid.data.cifar.CIFAR
  _recursive_: false
  config:
    cifar_train_args: '${cifar_train_args}'
    cifar_val_args: '${cifar_val_args}'
    train_dataloader: '${train_dataloader}'
    val_dataloader: '${val_dataloader}'
    train_data_pct: '${train_data_pct}'

# Metric containers. Each list target receives its metric objects as
# positional arguments (`_args_`), one entry per metric.
epoch_metrics:
  _target_: llid.utils.metrics.EpochMetricsList
  _args_:
    - _target_: llid.utils.metrics.EpochClassAccuracy

metrics:
  _target_: llid.utils.metrics.MetricsList
  _args_:
    - _target_: llid.utils.metrics.Accuracy

# Regularizer used by the training loop; this config selects
# `NoRegularization`.
regs:
  _target_: llid.utils.regs.NoRegularization

# Dropout settings, kept at top level so they can be overridden directly and
# then interpolated into the model arguments.
fc_dropout: 0
block_dropout: 0

# Network definition; every value below is passed to the `resnet18` target as
# a keyword argument.
model:
  _target_: llid.models.resnet.resnet18
  num_blocks: [2, 2, 2, 2]
  num_classes: '${num_classes}'
  fc_dropout: '${fc_dropout}'
  block_dropout: '${block_dropout}'


# Starting learning rate; interpolated into `training_loop.config.lr` and
# `optimizer.lr`.
initial_lr: 0.1
# Not referenced elsewhere in this file; empty by default.
id_estimation_layers: []

# Step schedule: the learning rate is multiplied by `gamma` at each milestone.
# No optimizer is given here.
# NOTE(review): presumably the training loop supplies it; confirm.
cifar100_lr_scheduler:
  _target_: torch.optim.lr_scheduler.MultiStepLR
  milestones:
    - 60
    - 120
    - 160
  gamma: 0.2

# Training loop. `_recursive_: false` tells Hydra not to instantiate the nested
# `_target_` entries under `config` (model, regs, metrics, optimizer,
# scheduler); they are handed to the target as plain config.
training_loop:
  _target_: llid.training_loops.image_classification.ImageClassificationTL
  _recursive_: false
  config:
    # Components, each defined as a top-level key of this file.
    model: '${model}'
    regs: '${regs}'
    metrics: '${metrics}'
    epoch_metrics: '${epoch_metrics}'

    # Regularization switches.
    alternating_reg_loss: '${alternating_reg_loss}'
    l1_norm_coef: 0

    # Optimization settings.
    warmup_steps: 0
    lr: '${initial_lr}'
    optimizer: '${optimizer}'
    lr_scheduler: '${cifar100_lr_scheduler}'

    # Loss settings.
    class_weights: '${class_weights}'
    loss_weight_alpha: '${loss_weight_alpha}'
    num_classes: '${num_classes}'

# Not referenced elsewhere in this file.
norm_loss: false
# Interpolated into `optimizer.weight_decay`. Written with an explicit decimal
# point: YAML 1.1 resolvers such as PyYAML's default only recognise
# exponent-form floats that contain a '.', and would load a bare `5e-4` as the
# string "5e-4" rather than the number 0.0005.
weight_decay: 5.0e-4

# Forwarded to `training_loop.config.loss_weight_alpha`.
loss_weight_alpha: 1
# Forwarded to `training_loop.config.class_weights`. YAML has no `None`
# literal: a plain `None` loads as the string "None", so the absent value is
# spelled `null`, which reaches Python as None.
# NOTE(review): confirm the training loop tests `is None` and does not compare
# against the string "None".
class_weights: null

# SGD with momentum; learning rate and weight decay come from the top-level
# `initial_lr` and `weight_decay` keys. No parameters are listed here.
# NOTE(review): presumably the training loop passes the model parameters when
# it instantiates this; confirm.
optimizer:
  _target_: torch.optim.SGD
  lr: '${initial_lr}'
  momentum: 0.9
  weight_decay: '${weight_decay}'