# All transforms are the standard ImageNet transforms, taken from
# https://github.com/pytorch/examples/blob/master/imagenet/main.py

# Training dataset: dataset class, its arguments, and the augmentation
# pipeline (listed in the order the transforms appear).
# NOTE(review): each `classname`/`args` pair presumably names a class and the
# keyword arguments it is built with -- confirm against the config loader.
train_dataset:
  # Anchored as `breeds_class` so the entries under test_datasets can alias it.
  classname: &breeds_class datasets.breeds.Breeds
  args:
    source: True
    target: False
    split: 'train'
    # Placeholder token, presumably substituted before this file is loaded
    # (TODO confirm); anchored as `breeds_name` for reuse by test_datasets.
    breeds_name: &breeds_name 'BREEDS_NAME_REPLACE'
  transforms:
    - classname: torchvision.transforms.RandomResizedCrop
      args:
        size: 224
    - classname: torchvision.transforms.RandomHorizontalFlip
    - classname: torchvision.transforms.ToTensor
    - classname: torchvision.transforms.Normalize
      args:
        # Per-channel statistics, anchored so default_test_transforms
        # normalizes with exactly the same values.
        mean: &norm_mean [0.485, 0.456, 0.406]
        std: &norm_std [0.229, 0.224, 0.225]

# Evaluation-time preprocessing: resize to 256, center-crop to 224, convert
# to a tensor, then normalize with the same mean/std as the training
# pipeline (via the `norm_mean` / `norm_std` aliases).
# NOTE(review): the entries under test_datasets declare no transforms of
# their own; presumably the loader falls back to this list -- confirm.
default_test_transforms:
  - classname: torchvision.transforms.Resize
    args:
      size: 256
  - classname: torchvision.transforms.CenterCrop
    args:
      size: 224
  - classname: torchvision.transforms.ToTensor
  - classname: torchvision.transforms.Normalize
    args:
      mean: *norm_mean
      std: *norm_std

# Evaluation datasets. Both entries alias the dataset class and breeds_name
# anchored under train_dataset and use the 'val' split; they differ only in
# the source/target flags.
# NOTE(review): max_test_examples presumably caps how many examples are
# evaluated per dataset -- confirm against the evaluation code.
test_datasets:
  # Validation split with source=True, target=False.
  - name: 'source_val'
    max_test_examples: 2000
    classname: *breeds_class
    args:
      source: True
      target: False
      split: 'val'
      breeds_name: *breeds_name
  # Validation split with source=False, target=True.
  - name: 'target_val'
    max_test_examples: 2000
    classname: *breeds_class
    args:
      source: False
      target: True
      split: 'val'
      breeds_name: *breeds_name

# Lists entries of test_datasets by their `name` ('source_val' is defined
# there). NOTE(review): presumably these datasets drive early stopping /
# best-checkpoint selection -- confirm against the training loop.
early_stop_dataset_names:
  - 'source_val'

# General run settings.
# NOTE(review): the units of log_interval and save_freq (steps vs. epochs)
# are not visible in this file -- confirm against the training loop.
log_interval: 5000
use_cuda: True
save_freq: 25
# Placeholder token, presumably substituted with an integer before this file
# is loaded (TODO confirm). Anchored as `epochs` so scheduler.args.T_max can
# alias the same value.
epochs: &epochs EPOCHS_REPLACE
batch_size: 96
num_workers: 2
save_all_checkpoints: False

# Placeholder token, presumably substituted with the number of classes for
# the chosen breeds_name before this file is loaded (TODO confirm).
num_classes: NUM_CLASSES_REPLACE

# Fine-tuning mode flags.
# NOTE(review): the exact effect of finetune and use_net_val_mode is not
# visible in this file -- confirm against the training code.
finetune: True
linear_probe: False # freezes all layers up to the last
use_net_val_mode: False

# Optimizer: SGD with momentum and weight decay.
# The `lr` anchor below is not aliased anywhere in the visible part of this
# file.
optimizer:
  classname: torch.optim.SGD
  args:
    lr: &lr 0.1
    momentum: 0.9
    weight_decay: 0.0001

# Learning-rate schedule: cosine annealing.
# T_max aliases the top-level `epochs` value, so the two always agree.
# NOTE(review): this presumably means the scheduler is stepped once per
# epoch -- confirm against the training loop.
scheduler:
  classname: torch.optim.lr_scheduler.CosineAnnealingLR
  args:
    T_max: *epochs

# Model: ResNet50 initialized from a pretrained checkpoint.
# NOTE(review): how pretrain_style selects the weight-loading path is not
# visible in this file -- confirm against models.imnet_resnet.
model:
  classname: models.imnet_resnet.ResNet50
  args:
    pretrained: True
    pretrain_style: 'swav'
    # Placeholder token, presumably substituted with a real checkpoint path
    # before this file is loaded (TODO confirm).
    checkpoint_path: 'CHECKPOINT_PATH_REPLACE'

# Training loss: cross-entropy; no `args` are given for it here.
criterion:
  classname: torch.nn.CrossEntropyLoss

