"""
Fine-tuning of the SNGP ViT-16 model pretrained on ImageNet-21K.
Fine-tuning dataset and distributional shift:
Diabetic Retinopathy Detection, Country Shift (APTOS dataset).
"""
name: vit16-finetune-aptos-sngp
program: baselines/diabetic_retinopathy_detection/jax_finetune_sngp.py
method: random
project: vit16-finetune-aptos-sngp
entity: uncertainty-baselines

parameters:
  # --- Fixed settings (identical for every run) ---
  output_dir:
    value: 'gs://ub-vit-tuning/vit16-finetune-aptos-sngp'
  distribution_shift:
    value: 'aptos'
  use_wandb:
    value: true

  # --- Training schedule ---
  batch_size:
    values: [64, 128]
  total_steps:
    values: [12500, 20000, 50000]

  # --- Learning rate ---
  lr_base:
    # `min`/`max` are natural logs of the bounds: sampled values lie
    # between exp(min) = 0.005 and exp(max) = 0.05.
    distribution: log_uniform
    min: -5.2983173665
    max: -2.9957322736
  lr_warmup_steps:
    values: [5000, 7500, 10000]
  lr_decay_type:
    values: ['linear']

  # --- Regularization ---
  weight_decay:
    # Log-uniform over [1e-6, 2e-4]; bounds are given as natural logs.
    # Mirrors: hyper.loguniform('l2', hyper.interval(1e-6, 2e-4))
    distribution: log_uniform
    min: -13.815510558
    max: -8.5171931914
  grad_clip_norm:
    values: [2.5]

  # --- Random seeds ---
  seed:
    values: [0, 1, 2, 3, 4]

  # SNGP hyperparameters.
  # Paper: https://arxiv.org/pdf/2006.10108.pdf
  # Tutorial: https://www.tensorflow.org/tutorials/understanding/sngp
  sngp_ridge_penalty:
    # `min`/`max` are natural logs of the bounds: sampled values lie
    # between exp(min) = 0.0001 and exp(max) = 0.01.
    # Default value in the paper: 0.001.
    distribution: log_uniform
    min: -9.210340372
    max: -4.605170186

  sngp_covmat_momentum:
    # Default value in the paper: 0.999.
    # NOTE(review): -1 is presumably a sentinel handled by the training
    # script rather than a real momentum - confirm its meaning there.
    values: [-1, 0.1, 0.25, 0.5, 0.9, 0.99, 0.999, 0.9999]

  sngp_mean_field_factor:
    # The lower part of this range (values up to 0.5) is based on
    # baselines/jft/experiments/vit_base16_sngp_finetune_cifar_10_and_100.py
    # NOTE(review): -1 is presumably a sentinel handled by the training
    # script rather than a real factor - confirm its meaning there.
    values: [-1., 0.1, 0.15, 0.2, 0.25, 0.3, 0.5, 1, 5, 10, 25]
