# Fine-tuning of the deterministic ViT-16 model pretrained on ImageNet-21K.
# Fine-tuning dataset and distributional shift:
# Diabetic Retinopathy Detection, Country Shift (APTOS dataset).
#
# NOTE: keep this header as `#` comments. A Python-style triple-quoted
# docstring is not a YAML comment; it is parsed as quoted scalars and
# prevents the document from loading as a single mapping.

# Sweep identity and the training script launched for every trial.
name: vit16-finetune-aptos-deterministic
program: baselines/diabetic_retinopathy_detection/jax_finetune_deterministic.py
# Hyperparameters are drawn independently at random for each trial.
method: random
project: vit16-finetune-aptos-deterministic
entity: uncertainty-baselines

# Search space for the sweep. `value` pins a single constant, `values`
# lists discrete choices, and `distribution` + `min`/`max` defines a
# continuous range.
parameters:
  # Fixed settings shared by every trial.
  output_dir:
    value: 'gs://ub-vit-tuning/vit16-finetune-aptos-deterministic'
  distribution_shift:
    value: 'aptos'
  use_wandb:
    value: true

  # Tuned settings. The search space reuses the sweep configuration from
  # the focused tuning of the deterministic ResNet-50 model.
  batch_size:
    values: [64, 128]
  total_steps:
    values: [10000, 20000, 50000]
  lr_base:
    # `log_uniform` takes its bounds in natural-log space: the sampled
    # value lies between exp(min) and exp(max). The bounds below are
    # therefore ln(0.03) and ln(0.5), mirroring
    # hyper.loguniform('base_learning_rate', hyper.interval(0.03, 0.5)).
    distribution: log_uniform
    min: -3.5065578973
    max: -0.6931471806
  lr_warmup_steps:
    values: [500, 1000, 5000]
  decay_type:
    values: ['cosine', 'linear']
  weight_decay:
    # Bounds are ln(1e-6) and ln(2e-4), mirroring
    # hyper.loguniform('l2', hyper.interval(1e-6, 2e-4)).
    distribution: log_uniform
    min: -13.815510558
    max: -8.5171931914
  grad_clip_norm:
    values: [1, 2, 5]
  seed:
    values: [0, 1, 2, 3, 4]