# Probe-stage configuration: a `probe_model` whose encoder is initialized via
# `previous_run_initializer` and which is run by a `classification_trainer`.
#
# NOTE(review): `???` looks like a required-value placeholder that must be
# filled in by whatever instantiates this file -- confirm with the config loader.

# NOTE(review): presumably selects a Weights & Biases configuration called
# `v4` -- confirm.
wandb: v4
stage_name: probe
name: ???
# Values referenced elsewhere in this file through `${vars.<key>}`.
vars:
  # Used as `trainer.max_epochs`.
  epochs: ???
  # Used as `stage_id` of the encoder's `previous_run_initializer`.
  stage_id: ???
  # Used as the first entry of the probe processor's `poolings` list.
  pooling: ???

# Config processors.
# NOTE(review): `probe_processor` presumably generates the probe heads that
# `model.heads: from_processor` refers to -- confirm against its implementation.
processors:
  - kind: probe_processor
    grid: dinov2
    probe_kind: linear_probe
    # Poolings to probe: the caller-supplied one plus a fixed
    # `concat_class_average` entry.
    # NOTE(review): if `vars.pooling` is itself `concat_class_average` the list
    # holds the same value twice -- verify the processor tolerates that.
    poolings:
      - ${vars.pooling}
      - concat_class_average

# Dataset definitions; both splits are pulled in from other YAML files through
# the `${yaml:...}` resolver.
# NOTE(review): the dotted key `template.vars.version` presumably overrides
# `vars.version` inside the included template -- confirm with the config loader.
datasets:
  # Training split, built from the `train_minaug` ImageNet template.
  train:
    template: ${yaml:datasets/imagenet/train_minaug}
    template.vars.version: imagenet1k
  # Test split, built from the `test_noaug` ImageNet template; this is the
  # split the accuracy callback points at via `dataset_key: test`.
  test:
    template: ${yaml:datasets/imagenet/test_noaug}
    template.vars.version: imagenet1k

# Probe model: an encoder restored from an earlier run plus heads that are not
# spelled out here (`heads: from_processor`).
model:
  kind: probe_model
  encoder:
    # Disabled explicit encoder definition, kept for reference.
    # NOTE(review): presumably unnecessary because `use_checkpoint_kwargs: true`
    # below takes the encoder settings from the checkpoint -- confirm.
#    kind: vit.vit
#    patch_size: 16
#    kwargs: ${select:large:${yaml:models/vit}}
    initializers:
      # Loads `contrastive_model.encoder` (model_info `ema=0.9999`) from the
      # `last` checkpoint of stage `stage3` of the run given by `vars.stage_id`.
      # NOTE(review): `ema=0.9999` presumably selects the EMA copy of the
      # weights -- confirm.
      - kind: previous_run_initializer
        stage_id: ${vars.stage_id}
        stage_name: stage3
        model_name: contrastive_model.encoder
        model_info: ema=0.9999
        checkpoint: last
        use_checkpoint_kwargs: true
  # NOTE(review): presumably filled in by the `probe_processor` declared under
  # `processors` -- confirm.
  heads: from_processor

# Training setup: classification trainer in bfloat16, run for `vars.epochs`
# epochs with logging every epoch.
trainer:
  kind: classification_trainer
  precision: bfloat16
  # NOTE(review): presumably the global batch size (across devices and/or
  # accumulation steps) -- confirm.
  effective_batch_size: 1024
  max_epochs: ${vars.epochs}
  log_every_n_epochs: 1
  callbacks:
    # Top-1 accuracy on the `test` dataset, every epoch.
    - kind: offline_accuracy_callback
      every_n_epochs: 1
      topk: [ 1 ]
      dataset_key: test
    # Saves the last model; according to the original author's note this also
    # saves the encoder. No interval is configured here.
    - kind: checkpoint_callback
