# Fine-tuning of the SNGP ViT-16 model pretrained on ImageNet-21K.
# Fine-tuning dataset and distributional shift:
# Diabetic Retinopathy Detection, both Country and Severity Shifts.
# Grid search.
# Sweep identity: the sweep name and the project/entity it is filed under.
name: 'vit16-finetune-sngp-grid'
project: 'vit16-finetune-sngp-grid'
entity: 'uncertainty-baselines'

# Training script launched for each run, and the search strategy.
program: 'baselines/diabetic_retinopathy_detection/jax_finetune_sngp.py'
method: 'grid'

parameters:
  # Entries with a single `value` are fixed; entries with `values` list the
  # candidates to be searched over.
  config:
    value: 'baselines/diabetic_retinopathy_detection/experiments/config/imagenet21k_vit_base16_sngp_finetune.py'
  config.output_dir:
    value: 'gs://ub-vit-tuning/vit16-finetune-sngp-grid'
  config.distribution_shift:
    values:
      - 'aptos'
      - 'severity'
  config.use_wandb:
    value: true
  config.batch_size:
    values:
      - 128
  # Each entry is a string holding a "(total_steps, warmup_steps)" pair.
  config.total_and_warmup_steps:
    values:
      - '(5_000, 250)'
      - '(10_000, 500)'
      - '(15_000, 750)'
      - '(20_000, 1000)'
  config.lr.base:
    values:
      - 0.03
      - 0.01
      - 0.003
      - 0.001
  config.lr.decay_type:
    values:
      - 'cosine'
      - 'linear'
  config.seed:
    values:
      - 0
      - 1
      - 2
      - 3
      - 4
      - 5

  # GP layer settings.
  config.gp_layer.ridge_penalty:
    values:
      - 0.0001
      - 0.001
      - 0.01
  config.gp_layer.covmat_momentum:
    # Default in paper: 0.999
    # NOTE(review): -1 looks like a sentinel value -- confirm its meaning
    # against the GP layer implementation.
    values:
      - -1
      - 0.999
      - 0.9999
  config.gp_layer.mean_field_factor:
    # The lower range of these values (up to 0.5) is based on
    # baselines/jft/experiments/vit_base16_sngp_finetune_cifar_10_and_100.py
    # NOTE(review): -1.0 is presumably a sentinel here as well -- confirm.
    values:
      - -1.0
      - 0.1
      - 0.5
      - 1
      - 10
      - 25
