# Copyright (c) 2021 Copyright holder of the paper "Test-Time Adaptation to Distribution Shifts by Confidence Maximization and Input Transformation" submitted to NeurIPS 2021 for review
# All rights reserved.

# NOTE(review): "Adaption" looks like a typo for "Adaptation"; the value is kept
# verbatim because changing it would change the run name.
run_name: 'Test-Time Adaption Data-Subsets'

# Backbone to adapt. Options (the same names used as keys of `models` below):
# ResNet50, ResNeXt50, MobileNetV2, DenseNet121, ResNet50-DeepAugmentAugmix
model_type: ResNet50

# ImageNet-C data: location plus the corruption types and severity levels used.
imagenet_c:
  data_dir: /home/ImageNet/imagenet_corrupted  # machine-specific; adjust locally
  corruptions: [gaussian_noise, glass_blur, frost, contrast]
  severities: [5]

# How the data subset is drawn.
data_split:
  # Train ratio; can be 0.1, 0.25, 0.5, 0.75 or 1.0.
  frac: 0.1
  # Split mode: `class` or `samples_per_class`.
  by: samples_per_class
  # Seed for the random split.
  random_state: 42

# Test-time confidence-maximization objective. Can be "soft_likelihood_ratio",
# "hard_likelihood_ratio", "pseudolabels" or "TENT".
confidence_maximization: "hard_likelihood_ratio"
# NOTE(review): presumably freezes the modules listed under
# `models.<model_type>.freeze` -- confirm against the training code.
freeze_top_layers: true  # set false for TENT
running_diversity_regularizer: true  # set false for TENT
# NOTE(review): presumably the running-average factor of the diversity
# regularizer -- confirm against the training code.
kappa: 0.9
# Presets (see the optimizer section for the matching SGD settings):
#   TENT:  confidence_maximization: "TENT", freeze_top_layers: false,
#          running_diversity_regularizer: false
#   TENT+: confidence_maximization: "TENT", freeze_top_layers: true,
#          running_diversity_regularizer: true

# Test-time adaptation schedule.
test_time_epochs: 5
test_time_batch_size: 64

# Which parameters are optimized at test time:
#   input_affine - the model augmentation module's parameters plus the model's
#                  channel-wise affine transformation parameters
#   affine       - only the model's channel-wise affine transformation parameters
#   all          - every model parameter, including the convolutional kernels
parameters_to_update: input_affine

# Settings for the model augmentation module. Its parameters are updated at
# test time when `parameters_to_update` is `input_affine`.
model_augmentation:
  enabled: true
  kernel_size: 3
  multiplicity: 6
  depth: 6
  affine: true
  normalization: true

# Test-time optimizer. For TENT and TENT+ use `type: SGD` with `lr: 0.00025`.
optimizer:
  type: Adam
  lr: 0.0006
  momentum: 0.9  # only used with SGD
  weight_decay: 0

# Model registry keyed by `model_type`. Each `model` is the path to an
# MLflow-packaged torchvision model; see the "setup_mlflow_model_and_h5dataset"
# folder for how to set these up.
# NOTE(review): `freeze` presumably lists the top modules frozen when
# `freeze_top_layers` is true -- confirm against the training code.
models:
  ResNet50:
    model: '<insert mlflow based torchvision model path here>'
    freeze: [layer4]

  MobileNetV2:
    model: '<insert mlflow based  torchvision model path here>'
    freeze: [features.16, features.17, features.18]

  ResNeXt50:
    model: '<insert mlflow based  torchvision model path here>'
    freeze: [layer4]

  DenseNet121:
    model: '<insert mlflow based  torchvision model path here>'
    freeze: [features.denseblock4, features.norm5]

  ResNet50-DeepAugmentAugmix:
    model: '<insert mlflow based robust resnet50 model path here>'
    freeze: [layer4]
