# Run-level metadata for the "stage3" stage of this experiment.
# NOTE(review): the meaning of `wandb: v4` is not visible in this file —
# presumably it selects a Weights & Biases profile/config; confirm.
wandb: v4
# Name of this stage; the initializers under `model` load from "stage2".
stage_name: stage3
# Run name. NOTE(review): "4heads" presumably refers to the four entries
# under model.heads (main, cls22, cls21, cls20); confirm.
name: l16-d2v2--4heads
# A single processor is registered; what stage3_processor does is defined
# outside this file.
processors:
  - kind: stage3_processor
# Shared values that other sections of this file reference via ${vars.<key>}.
vars:
  # Passed as `stage_id` to every previous_run_initializer under `model`
  # (each of which also sets stage_name: stage2).
  stage_id: xpge1tv6
  # Referenced by trainer.max_epochs.
  max_epochs: 20
  # Referenced by trainer.effective_batch_size.
  batch_size: 1024
  # Base optimizer settings; the encoder and every head use this mapping as
  # their `optim.template`.
  optim:
    kind: adamw
    lr: 4.0e-4
    betas: [ 0.9, 0.95 ]
    # Looks like a placeholder: every use of ${vars.optim} in this file sets
    # `template.weight_decay` (0.05 for the encoder, 1.0e-5 for the heads).
    weight_decay: TEMPLATE
    schedule:
      # Schedule is pulled from an external yaml template.
      # NOTE(review): "wupcos" presumably means warmup + cosine, with
      # end_epoch marking the end of the warmup phase — confirm against
      # schedules/wupcos_epoch.
      template: ${yaml:schedules/wupcos_epoch}
      template.vars.end_epoch: 4

# Datasets used by this stage. Each entry instantiates an external yaml
# template and sets its `version` variable; the templates themselves are
# defined outside this file.
datasets:
  # Training set.
  # NOTE(review): template name suggests BYOL-style augmentation with local
  # views; confirm against datasets/imagenet/train_byolaug_localviews.
  train:
    template: ${yaml:datasets/imagenet/train_byolaug_localviews}
    template.vars.version: imagenet1k
  # Referenced as `train_dataset_key` by the offline k-NN callback in `trainer`.
  train_noaug_10p:
    template: ${yaml:datasets/imagenet/train_noaug}
    template.vars.version: imagenet1k_10percent_simclrv2
  # Referenced as `test_dataset_key` by the offline k-NN callback in `trainer`.
  test:
    template: ${yaml:datasets/imagenet/test_noaug}
    template.vars.version: imagenet1k

# Model definition: one encoder plus several heads, each with its own
# optimizer overrides and an initializer that loads the stage2 checkpoint.
model:
  kind: contrastive_model
  encoder:
    # The architecture keys below are intentionally left commented out.
    # NOTE(review): with use_checkpoint_kwargs: true on the initializer, the
    # architecture is presumably restored from the checkpoint; confirm.
#    kind: vit.vit
#    patch_size: 16
#    kwargs: ${select:large:${yaml:models/vit}}
    optim:
      # Start from the shared optimizer settings and override per-encoder values.
      template: ${vars.optim}
      # Replaces the TEMPLATE placeholder from vars.optim.weight_decay.
      template.weight_decay: 0.05
      template.param_group_modifiers:
        # Layer-wise learning-rate decay with factor 0.65.
        - kind: layerwise_lr_decay_modifier
          decay: 0.65
    initializers:
      # Load the encoder from the last stage2 checkpoint of run ${vars.stage_id}.
      - kind: previous_run_initializer
        stage_id: ${vars.stage_id}
        stage_name: stage2
        model_name: contrastive_model.encoder
        checkpoint: last
        use_checkpoint_kwargs: true
  # Heads attached to the encoder; each one is initialized from a stage2 head.
  heads:
    # Note: this head is keyed `main` here but loads the stage2 head named
    # `nnclr_temp02_cls` (see model_name below).
    main:
      # Head definition is left commented out.
      # NOTE(review): presumably restored from the checkpoint because
      # use_checkpoint_kwargs is true; confirm.
#      kind: ssl.nnclr_head
#      kwargs: ${select:base:${yaml:models/nnclr}}
#      pooling:
#        kind: class_token
      # NOTE(review): topk is presumably the number of nearest neighbours
      # used by the head's queue; confirm against the head implementation.
      queue_kwargs:
        topk: 20
      optim:
        # Shared optimizer settings with a head-specific weight decay.
        template: ${vars.optim}
        template.weight_decay: 1.0e-5
      initializers:
        # Load from the last stage2 checkpoint of run ${vars.stage_id}.
        - kind: previous_run_initializer
          stage_id: ${vars.stage_id}
          stage_name: stage2
          model_name: contrastive_model.heads.nnclr_temp02_cls
          checkpoint: last
          use_checkpoint_kwargs: true
    # Head initialized from the stage2 head `nnclr_temp02_cls22`.
    cls22:
      # Head definition is left commented out.
      # NOTE(review): presumably restored from the checkpoint because
      # use_checkpoint_kwargs is true; confirm.
#      kind: ssl.nnclr_head
#      kwargs: ${select:base:${yaml:models/nnclr}}
#      pooling:
#        kind: class_token
      # NOTE(review): topk is presumably the number of nearest neighbours
      # used by the head's queue; confirm against the head implementation.
      queue_kwargs:
        topk: 20
      optim:
        # Shared optimizer settings with a head-specific weight decay.
        template: ${vars.optim}
        template.weight_decay: 1.0e-5
      initializers:
        # Load from the last stage2 checkpoint of run ${vars.stage_id}.
        - kind: previous_run_initializer
          stage_id: ${vars.stage_id}
          stage_name: stage2
          model_name: contrastive_model.heads.nnclr_temp02_cls22
          checkpoint: last
          use_checkpoint_kwargs: true
    # Head initialized from the stage2 head `nnclr_temp02_cls21`.
    cls21:
      # Head definition is left commented out.
      # NOTE(review): presumably restored from the checkpoint because
      # use_checkpoint_kwargs is true; confirm.
#      kind: ssl.nnclr_head
#      kwargs: ${select:base:${yaml:models/nnclr}}
#      pooling:
#        kind: class_token
      # NOTE(review): topk is presumably the number of nearest neighbours
      # used by the head's queue; confirm against the head implementation.
      queue_kwargs:
        topk: 20
      optim:
        # Shared optimizer settings with a head-specific weight decay.
        template: ${vars.optim}
        template.weight_decay: 1.0e-5
      initializers:
        # Load from the last stage2 checkpoint of run ${vars.stage_id}.
        - kind: previous_run_initializer
          stage_id: ${vars.stage_id}
          stage_name: stage2
          model_name: contrastive_model.heads.nnclr_temp02_cls21
          checkpoint: last
          use_checkpoint_kwargs: true
    # Head initialized from the stage2 head `nnclr_temp02_cls20`.
    cls20:
      # Head definition is left commented out.
      # NOTE(review): presumably restored from the checkpoint because
      # use_checkpoint_kwargs is true; confirm.
#      kind: ssl.nnclr_head
#      kwargs: ${select:base:${yaml:models/nnclr}}
#      pooling:
#        kind: class_token
      # NOTE(review): topk is presumably the number of nearest neighbours
      # used by the head's queue; confirm against the head implementation.
      queue_kwargs:
        topk: 20
      optim:
        # Shared optimizer settings with a head-specific weight decay.
        template: ${vars.optim}
        template.weight_decay: 1.0e-5
      initializers:
        # Load from the last stage2 checkpoint of run ${vars.stage_id}.
        - kind: previous_run_initializer
          stage_id: ${vars.stage_id}
          stage_name: stage2
          model_name: contrastive_model.heads.nnclr_temp02_cls20
          checkpoint: last
          use_checkpoint_kwargs: true

# Trainer settings: run length and batch size come from `vars`; callbacks
# handle checkpointing, EMA tracking of the encoder and k-NN evaluation.
trainer:
  kind: contrastive_trainer
  precision: bfloat16
  max_epochs: ${vars.max_epochs}
  effective_batch_size: ${vars.batch_size}
  log_every_n_epochs: 1
  callbacks:
    # Store models: checkpoint and encoder EMA every 5 epochs.
    - kind: checkpoint_callback
      every_n_epochs: 5
    - kind: ema_callback
      every_n_epochs: 5
      model_paths:
        - encoder
      target_factors:
        - 0.9999
    # Callbacks for resuming: run every epoch, with the regular weight/optim
    # saves switched off and only the save_latest_* flags switched on.
    - kind: checkpoint_callback
      every_n_epochs: 1
      save_weights: false
      save_optim: false
      save_latest_weights: true
      save_latest_optim: true
    - kind: ema_callback
      every_n_epochs: 1
      save_weights: false
      save_latest_weights: true
      model_paths:
        - encoder
      target_factors:
        - 0.9999
    # k-NN 10%: offline k-NN metrics every epoch, using the `train_noaug_10p`
    # and `test` entries from `datasets`.
    - kind: offline_knn_metrics_callback
      every_n_epochs: 1
      train_dataset_key: train_noaug_10p
      test_dataset_key: test
      # NOTE(review): `from_processor` is presumably filled in by the
      # stage3_processor listed under `processors`; confirm.
      extractors: from_processor