# Run identification / logging settings.
wandb: v4
# NOTE(review): the run name contains a double hyphen ("--") — confirm this is intentional.
name: mae-twob14--vpt
stage_name: '1p'

datasets:
  # Training split: built from the train_minaug_smooth dataset template,
  # using the imagenet1k_msn1percent version of the data.
  train:
    template: '${yaml:datasets/imagenet/train_minaug_smooth}'
    template.vars.version: imagenet1k_msn1percent
  # Evaluation split: built from the test_noaug dataset template on imagenet1k.
  test:
    template: '${yaml:datasets/imagenet/test_noaug}'
    template.vars.version: imagenet1k

model:
  # No explicit architecture kwargs are given here; presumably they come from
  # the checkpoint via `use_checkpoint_kwargs: true` on the initializer below.
  # Explicit alternative, kept for reference:
  # patch_size: 16
  # kwargs: ${select:large:${yaml:models/vit}}
  kind: vit.vit_vpt
  # Prompt-token settings for the vit_vpt model.
  num_prompt_tokens: 8
  # NOTE(review): confirm 192 is the prompt dim vit_vpt expects for this checkpoint.
  prompt_token_dim: 192
  mode: classifier
  pooling: {kind: class_token}
  freezers:
    # end_percent: 1.0 presumably freezes all ViT blocks (plus the last norm),
    # leaving only the non-frozen parameters trainable — verify in vit_block_freezer.
    - kind: vit_block_freezer
      end_percent: 1.0
      freeze_last_norm: true
  optim:
    kind: adamw
    lr: 0.001
    betas:
      - 0.9
      - 0.999
    weight_decay: 0.0
    # Two-phase LR schedule: linear increase up to epoch 5, then a cosine
    # decrease (presumably for the remaining epochs, as no end is given).
    schedule:
      - end_epoch: 5
        schedule:
          kind: linear_increasing_schedule
          exclude_first: true
          exclude_last: true
      - schedule:
          kind: cosine_decreasing_schedule
          exclude_last: true
  initializers:
    - kind: pretrained_initializer
      weights_file: mae_twob14.pt
      use_checkpoint_kwargs: true

trainer:
  kind: classification_trainer
  max_epochs: 50
  effective_batch_size: 128
  precision: bfloat16
  log_every_n_epochs: 1
  # Keep this order: best_checkpoint_callback presumably selects checkpoints by
  # the accuracy that offline_accuracy_callback computes on the `test` dataset.
  callbacks:
    - kind: checkpoint_callback
    - kind: offline_accuracy_callback
      dataset_key: test
      every_n_epochs: 1
    - kind: best_checkpoint_callback
      metric_key: 'accuracy1/test/main'
      every_n_epochs: 1