# Run-level metadata for this experiment config.
# NOTE(review): `wandb: v4` presumably selects a Weights & Biases
# setup/profile by name — confirm against the config loader.
wandb: v4
# Stage and run names; both are plain strings identifying this run.
stage_name: decoder-per-layer
name: b16og-decoder-per-layer

# Dataset definitions, keyed by split name. Each split references a template
# via `${yaml:...}` and sets template variables through `template.vars.*`.
# NOTE(review): presumably the loader resolves `${yaml:<path>}` to the
# contents of that YAML file and substitutes the vars — confirm.
datasets:
  # Training split: `train_minaug` template, version imagenet1k.
  train:
    template: ${yaml:datasets/imagenet/train_minaug}
    template.vars.version: imagenet1k
    # NOTE(review): presumably the lower bound of a random-crop scale
    # range — verify against the train_minaug template.
    template.vars.min_scale: 0.2
  # Test split: `test_noaug` template, version imagenet1k. This key is
  # referenced by name from the offline_loss_callback (`dataset_key: test`).
  test:
    template: ${yaml:datasets/imagenet/test_noaug}
    template.vars.version: imagenet1k

# Model definition: an encoder initialized from a checkpoint file plus a
# decoder that carries its own optimizer settings. No `optim` section is
# given for the encoder in this file.
model:
  kind: mae_decoder_per_block
  encoder:
    # The explicit encoder definition below is disabled.
    # NOTE(review): presumably kind/patch_size/kwargs are taken from the
    # checkpoint instead because of `use_checkpoint_kwargs: true` — confirm.
    # NOTE(review): the disabled kwargs select `large`, while the weights
    # file is `mae_base16.pth`; fix the selector before re-enabling.
#    kind: vit.vit
#    patch_size: 16
#    kwargs: ${select:large:${yaml:models/vit}}
    initializers:
      - kind: pretrained_initializer
        weights_file: mae_base16.pth
        use_checkpoint_kwargs: true
  decoder:
    kind: vit.mae_decoder_per_block
    depth: 2
    # Picks the `default` entry out of the referenced models/mae_decoder YAML.
    kwargs: ${select:default:${yaml:models/mae_decoder}}
    # Optimizer for the decoder: AdamW with the values listed below.
    optim:
      kind: adamw
      lr: 1.5e-4
      betas: [ 0.9, 0.95 ]
      weight_decay: 0.05
      schedule:
        template: ${yaml:schedules/wupcos_epoch}
        # NOTE(review): end_epoch is 5 while trainer.max_epochs is 20;
        # presumably this marks the end of the warmup phase rather than
        # of the whole schedule — confirm in schedules/wupcos_epoch.
        template.vars.end_epoch: 5

# Trainer settings: precision, run length, batch size, masking, loss,
# logging interval and callbacks.
trainer:
  kind: mae_decoder_per_block_trainer
  # NOTE(review): backup_precision is presumably the fallback used when
  # bfloat16 is unavailable on the device — confirm in the trainer.
  precision: bfloat16
  backup_precision: float16
  max_epochs: 20
  # NOTE(review): presumably the global batch size across all devices and
  # accumulation steps — confirm.
  effective_batch_size: 4096
  # Mask generator used for training; it is also reused as the template
  # for the offline_loss_callback's mask generator further down.
  mask_generator:
    kind: random_mask_generator
    mask_ratio: 0.75
  # MAE loss wrapping an MSE loss, with pixel normalization enabled.
  loss_function:
    kind: mae_loss
    normalize_pixels: true
    loss_fn:
      kind: mse_loss
  log_every_n_epochs: 1
  callbacks:
    - kind: checkpoint_callback
    # Loss callback on the `test` dataset entry, configured for every epoch.
    - kind: offline_loss_callback
      every_n_epochs: 1
      dataset_key: test
      forward_kwargs:
        # Reuses trainer.mask_generator as a template and adds a seed of 0.
        # NOTE(review): presumably so the evaluation masks are identical
        # on every run — confirm.
        mask_generator:
          template: ${trainer.mask_generator}
          seed: 0