# Linear-probe stage: DINOv2 `dinov2_vitg14` encoder evaluated on CIFAR-10.
# NOTE(review): `wandb: v4` presumably selects a W&B config/profile known to
# the loader — confirm.
wandb: v4
stage_name: probe-cifar10
name: dinov2-giga14
# Shared values, referenced elsewhere in this file as ${vars.<key>}.
vars:
  # Training length; consumed by trainer.max_epochs.
  epochs: 50
  # First entry of the probe processor's `poolings` list.
  pooling: class_token

processors:
  # NOTE(review): presumably configures one linear probe per pooling listed
  # below (model.heads is `from_processor`) — confirm in the framework.
  - kind: probe_processor
    grid: dinov2
    probe_kind: linear_probe
    # Keep this a block sequence: inside a flow sequence `[...]` the plain
    # scalar `${vars.pooling}` would break, since `{` and `}` are flow
    # indicators.
    poolings:
      - ${vars.pooling}
      - concat_class_average

datasets:
  # Training data: CIFAR-10 train split with random augmentation.
  # Pipeline: random-resized crop to 224 (bicubic, scale range 0.08-1.0)
  # -> random horizontal flip -> `kd_imagenet_norm`.
  train:
    kind: cifar10
    split: train
    sample_wrappers:
      - kind: x_transform_wrapper
        transform:
          - kind: kd_random_resized_crop
            size: 224
            scale: [0.08, 1.0]
            interpolation: bicubic
          - kind: kd_random_horizontal_flip
          - kind: kd_imagenet_norm
  # Evaluation data: CIFAR-10 test split, no random transforms.
  # Pipeline: bicubic resize to 224 -> `kd_imagenet_norm`.
  test:
    kind: cifar10
    split: test
    sample_wrappers:
      - kind: x_transform_wrapper
        transform:
          - kind: kd_resize
            size: 224
            interpolation: bicubic
          - kind: kd_imagenet_norm

model:
  kind: probe_model
  # Encoder fetched via the torch-hub wrapper: repo facebookresearch/dinov2,
  # entrypoint dinov2_vitg14.
  # NOTE(review): presumably wraps torch.hub.load, so it needs network access
  # or a populated hub cache — confirm.
  encoder:
    kind: torch_hub_model
    repo: facebookresearch/dinov2
    model: dinov2_vitg14
  # NOTE(review): heads presumably come from the probe_processor defined in
  # `processors` — confirm.
  heads: from_processor

trainer:
  kind: classification_trainer
  precision: bfloat16
  # NOTE(review): "effective" suggests the framework may reach this size via
  # gradient accumulation — confirm.
  effective_batch_size: 1024
  max_epochs: ${vars.epochs}
  log_every_n_epochs: 1
  callbacks:
    # Top-1 accuracy on the `test` dataset, computed every epoch.
    - kind: offline_accuracy_callback
      every_n_epochs: 1
      topk: [1]
      dataset_key: test
