# Run identification for this evaluation config.
# NOTE(review): `wandb: v4` looks like the name of a W&B settings preset
# rather than a version number — confirm against the config loader.
wandb: v4
# Stage this config belongs to.
stage_name: knn
# NOTE(review): the suffix presumably encodes the model variant ("noln",
# "ema"); the double dash looks intentional — confirm the naming scheme.
name: ijepa-h16res448--noln-ema
# Processors configured for this stage (a single per-layer kNN processor).
processors:
  - kind: knn_per_layer_processor
    # NOTE(review): presumably must equal the number of transformer blocks of
    # the loaded checkpoint — confirm before reusing with another model.
    num_blocks: 32
    # NOTE(review): presumably tells the processor that the token sequence
    # has no class token to skip — confirm against the processor.
    has_class_token: false

# Dataset splits, addressed by key ("train" / "test") elsewhere in this file.
# Both splits use the same version and the same resize/crop resolutions.
# NOTE(review): the dotted `template.vars.*` keys presumably override
# variables of the referenced template — confirm against the config loader.
datasets:
  train:
    template: ${yaml:datasets/imagenet/train_noaug}
    template.vars.version: imagenet1k
    # Resize to 512, then crop to 448 (same values as the test split).
    template.vars.resize_resolution: 512
    template.vars.crop_resolution: 448
  test:
    template: ${yaml:datasets/imagenet/test_noaug}
    template.vars.version: imagenet1k
    # Keep identical to the train split so both are preprocessed the same way.
    template.vars.resize_resolution: 512
    template.vars.crop_resolution: 448

# Model under evaluation: a frozen ViT initialized from a pretrained file.
model:
  kind: vit.vit
  # Disabled architecture overrides, kept for reference.
  # NOTE(review): presumably unnecessary because `use_checkpoint_kwargs` is
  # true below. The disabled `kwargs` line selects the "large" preset while
  # the weights file names a "huge" checkpoint — verify before re-enabling.
  # patch_size: 16
  # kwargs: ${select:large:${yaml:models/vit}}
  mode: classifier
  pooling:
    # NOTE(review): presumably averages the patch tokens — confirm.
    kind: mean_patch
  is_frozen: true
  initializers:
    - kind: pretrained_initializer
      weights_file: ijepa_huge16res448.pth.tar
      # NOTE(review): presumably builds the model from hyperparameters stored
      # with the checkpoint instead of values given here — confirm.
      use_checkpoint_kwargs: true
# Trainer configuration with a single offline kNN metrics callback.
trainer:
  kind: classification_trainer
  precision: bfloat16
  # NOTE(review): presumably the fallback when bfloat16 is unavailable on
  # the device — confirm against the trainer.
  backup_precision: float16
  effective_batch_size: 256
  # NOTE(review): with zero epochs, presumably no training steps run and
  # only the callback below produces results — confirm.
  max_epochs: 0
  log_every_n_epochs: 1
  callbacks:
    - kind: offline_knn_metrics_callback
      every_n_epochs: 1
      # These keys refer to the entries defined under `datasets`.
      train_dataset_key: train
      test_dataset_key: test
      # NOTE(review): presumably takes the feature extractors from the
      # top-level `processors` entry — confirm against the callback.
      extractors: from_processor