# Run header: logging selector, run name and stage name.
# NOTE(review): `v4` presumably names a W&B logging profile defined elsewhere - confirm.
wandb: v4
# The name mirrors settings found in this file: stage id (3duek90m), max_epochs (e50),
# lr (lr1e3), the "minaug" train template and the split3 dataset version.
# NOTE(review): "sd0" presumably stands for seed 0 or drop_path_rate 0.0 - confirm.
name: 3duek90m--e50-lr1e3-sd0-minaug--split3
stage_name: inat18-5shot
# Shared values; other sections of this file read them through ${vars.<name>}.
vars:
  # Read by model.pooling.kind.
  pooling: class_token
  # Read by the previous_run_initializer entry under model.initializers.
  stage_id: 3duek90m

  # Read by trainer.max_epochs.
  max_epochs: 50
  # Read by model.optim.lr. Keep the decimal point and the signed exponent:
  # some YAML 1.1 loaders treat shorter forms such as `1e-3` as a string.
  lr: 1.0e-3
  # Read by trainer.effective_batch_size.
  batch_size: 256
  # Read by model.drop_path_rate.
  drop_path_rate: 0.0

# Dataset definitions. Each entry pulls in a template through ${yaml:...} and then sets
# that template's `version` variable with the dotted `template.vars.version` key.
datasets:
  # Training set built from the "train_rescale_minaug_smooth" template.
  train:
    template: ${yaml:datasets/inat/train_rescale_minaug_smooth}
    template.vars.version: inat18_5shot_split3
  # Set referenced by `dataset_key: test` in the trainer callbacks.
  # NOTE(review): version is `inat18` here but `inat18_5shot_split3` for train -
  # presumably evaluation runs on the full iNat18 test data; confirm.
  test:
    template: ${yaml:datasets/inat/test_noaug}
    template.vars.version: inat18

# Model, optimizer and weight-initialization settings.
model:
# The three architecture keys below are disabled.
# NOTE(review): presumably the architecture comes from the loaded checkpoint instead
# (see `use_checkpoint_kwargs: true` in the initializer) - confirm before re-enabling.
#  kind: vit.vit
#  patch_size: 16
#  kwargs: ${select:large:${yaml:models/vit}}
  pos_embed_is_learnable: true
  # Taken from vars.drop_path_rate (0.0 in this file).
  drop_path_rate: ${vars.drop_path_rate}
  drop_path_decay: true
  mode: classifier
  pooling:
    # Taken from vars.pooling (class_token in this file).
    kind: ${vars.pooling}
  optim:
    kind: adamw
    # Taken from vars.lr (1.0e-3 in this file).
    lr: ${vars.lr}
    betas: [ 0.9, 0.999 ]
    weight_decay: 0.05
    # Two schedule entries in order: a linear increase with `end_epoch: 5`, then a
    # cosine decrease with no `end_epoch`.
    # NOTE(review): presumably a 5-epoch warmup followed by cosine decay over the
    # remaining epochs; the exact meaning of exclude_first / exclude_last is defined
    # by the schedule implementation - confirm.
    schedule:
      - schedule:
          kind: linear_increasing_schedule
          exclude_first: true
          exclude_last: true
        end_epoch: 5
      - schedule:
          kind: cosine_decreasing_schedule
          exclude_last: true
    param_group_modifiers:
      # NOTE(review): presumably scales the learning rate per layer by a factor of
      # 0.75 per step of depth - confirm against the modifier implementation.
      - kind: layerwise_lr_decay_modifier
        decay: 0.75
  initializers:
    # Initializes from an earlier run: stage id from vars.stage_id, stage "stage3",
    # model "contrastive_model.encoder", checkpoint "last".
    # NOTE(review): `model_info: ema=0.9999` presumably selects the EMA copy of the
    # encoder weights - confirm.
    - kind: previous_run_initializer
      stage_id: ${vars.stage_id}
      stage_name: stage3
      model_name: contrastive_model.encoder
      model_info: ema=0.9999
      checkpoint: last
      use_checkpoint_kwargs: true

# Training loop settings and callbacks.
trainer:
  kind: classification_trainer
  precision: bfloat16
  # NOTE(review): presumably used when bfloat16 is unavailable on the device - confirm.
  backup_precision: float16
  # Taken from vars.max_epochs (50 in this file).
  max_epochs: ${vars.max_epochs}
  # Taken from vars.batch_size (256 in this file).
  effective_batch_size: ${vars.batch_size}
  log_every_n_epochs: 1
  callbacks:
    - kind: checkpoint_callback
    # Runs every epoch on the dataset registered as `test` under `datasets`.
    - kind: offline_accuracy_callback
      every_n_epochs: 1
      dataset_key: test
    # Runs every epoch and tracks the metric `accuracy1/test/main`.
    # NOTE(review): presumably that metric is produced by the offline_accuracy_callback
    # above, so the key must stay in sync with its dataset_key - confirm.
    - kind: best_checkpoint_callback
      every_n_epochs: 1
      metric_key: accuracy1/test/main