# Fine-tuning of the BatchEnsemble ViT-L/32 model pretrained on ImageNet-21K.
# Fine-tuning dataset and distributional shift:
# Diabetic Retinopathy Detection, Severity Shift.

# Try a number of changes:
# Smaller LR (previous 0.005 lower bound, now 0.0001)
# Smaller minibatch size (64 and 128)
# More regularization (dropout on attention / linear layers, higher weight decay)
# Other seeds for upstream pre-training model

# Sweep identity: the sweep name doubles as the project name.
name: "vit32-finetune-severity-batchensemble-focused-1"
project: "vit32-finetune-severity-batchensemble-focused-1"
entity: "uncertainty-baselines"

# Training script launched for each run, and the search strategy
# (random sampling of the parameter space defined under `parameters`).
program: "baselines/diabetic_retinopathy_detection/jax_finetune_batchensemble.py"
method: "random"

# Search space. `value` pins a setting for every run; `values` lists discrete
# choices; `distribution` + `min`/`max` describes a continuous range.
# NOTE(review): dotted keys such as `config.lr.base` are presumably mapped onto
# fields of the training config by the launched program - confirm there.
parameters:
  # --- Fixed settings -------------------------------------------------------
  config:
    value: "baselines/diabetic_retinopathy_detection/experiments/config/vit_l32_be_finetune.py"
  config.output_dir:
    value: "gs://ub-vit-tuning/vit32-finetune-severity-batchensemble-focused-1"

  # Upstream checkpoints to start from (sub-directories 1, 2 and 3 of the same
  # pre-training job).
  config.model_init:
    values:
      - "gs://ub-checkpoints/ImageNet21k_BE-L32/baselines-jft-0209_205214/1/checkpoint.npz"
      - "gs://ub-checkpoints/ImageNet21k_BE-L32/baselines-jft-0209_205214/2/checkpoint.npz"
      - "gs://ub-checkpoints/ImageNet21k_BE-L32/baselines-jft-0209_205214/3/checkpoint.npz"

  config.distribution_shift:
    value: "severity"
  config.use_wandb:
    value: true

  # --- Optimisation schedule ------------------------------------------------
  config.batch_size:
    values:
      - 64
      - 128
  # Each entry is a quoted "(total_steps, warmup_steps)" pair, kept as a string.
  config.total_and_warmup_steps:
    values:
      - "(12_500, 5000)"
      - "(20_000, 7500)"
      - "(50_000, 10000)"
  config.lr.base:
    distribution: log_uniform
    # Samples lie between exp(min) and exp(max), so the bounds are given as
    # natural logs: ln(0.0001) and ln(0.02).
    min: -9.210340372
    max: -3.9120230054
  config.lr.decay_type:
    values:
      - "linear"
      - "cosine"

  # --- Regularisation -------------------------------------------------------
  config.weight_decay:
    distribution: log_uniform
    # Equivalent of hyper.loguniform('l2', hyper.interval(1e-6, 2e-4)):
    # the bounds are ln(1e-6) and ln(2e-4).
    min: -13.815510558
    max: -8.5171931914
  config.grad_clip_norm:
    values:
      - 2.5
  config.seed:
    values:
      - 0
      - 1
      - 2
      - 3
      - 4

  # --- Batch Ensemble parameters --------------------------------------------
  config.fast_weight_lr_multiplier:
    values:
      - 0.5
      - 1.0
      - 2.0
  config.model.transformer.random_sign_init:
    values:
      - -0.5
      - 0.5
