# @package _group_

common:
  # Built from the PWD environment variable, so the resolved path depends on
  # the directory the job is launched from (presumably the repository root --
  # confirm).
  user_dir: '${env:PWD}/examples/data2vec'

  # fp16 and loss-scale options.
  fp16: true
  fp16_no_flatten_grads: true
  min_loss_scale: 1e-6

  # Logging options.
  log_format: json
  log_interval: 200
  tensorboard_logdir: tb

checkpoint:
  # Epoch-based saving: interval of 5, with per-epoch checkpoints disabled.
  save_interval: 5
  no_epoch_checkpoints: true

  # Update-based saving: every 25000 updates, keep_interval_updates set to 1.
  save_interval_updates: 25000
  keep_interval_updates: 1

task:
  _name: mae_image_pretraining
  key: source
  rebuild_batches: true
  precompute_mask_config: {}

  # NOTE(review): both paths below are absolute, machine-specific locations
  # (the cache path looks tied to one user's scratch space). Override them
  # when running in a different environment.
  data: /datasets01/imagenet_full_size/061417/
  local_cache_path: /scratch/cache_abaevski/imagenet

dataset:
  # Batching options.
  batch_size: 8
  required_batch_size_multiple: 1

  # Data-loading options.
  num_workers: 10

  # Validation is switched off for this run.
  disable_validation: true
  skip_invalid_size_inputs_valid_test: true

distributed_training:
  ddp_backend: c10d
  # World size is fixed at 32 here; presumably it must match the number of
  # processes actually launched -- confirm for your cluster.
  distributed_world_size: 32

criterion:
  _name: model
  # Keys listed for logging. Order is kept as written, since it may be
  # reflected in the logged output.
  log_keys:
    - ema_decay
    - target_var
    - pred_var
    - model_norm
    - ema_norm
    - masked_pct

optimization:
  # Same value as optimizer.groups.default.lr_float (4e-4); change the two
  # together.
  lr:
    - 0.0004
  max_update: 500000
  clip_norm: 4
  debug_param_names: true

optimizer:
  _name: composite
  dynamic_groups: true
  groups:
    # Only one group, "default", is declared explicitly.
    default:
      # Same value as optimization.lr (0.0004); change the two together.
      lr_float: 4e-4
      lr_scheduler:
        _name: cosine
        warmup_updates: 50040
      optimizer:
        _name: adam
        weight_decay: 0.05
        adam_betas: [0.9, 0.95]

# The top-level scheduler is set to pass_through; the cosine schedule is
# declared per group under optimizer.groups.default.lr_scheduler (presumably
# the composite optimizer applies it -- confirm).
lr_scheduler: pass_through

model:
  _name: data2vec_multi
  supported_modality: IMAGE

  # Network size. embed_dim is repeated under modalities.image with the same
  # value (1280) -- presumably the two must stay in sync; confirm.
  depth: 32
  embed_dim: 1280
  num_heads: 16
  norm_eps: 1e-6

  # Every dropout-style option is set to 0.
  encoder_dropout: 0
  post_mlp_drop: 0
  attention_dropout: 0
  activation_dropout: 0

  # EMA options.
  ema_decay: 0.9998
  ema_end_decay: 1
  ema_anneal_end_step: 300000
  ema_encoder_only: false

  # Target / prediction options.
  average_top_k_layers: 24
  instance_norm_target_layer: true
  layer_norm_target_layer: false
  layer_norm_targets: true
  end_of_block_targets: false
  min_target_var: 0
  min_pred_var: 0

  clone_batch: 16
  cls_loss: 0.01

  modalities:
    image:
      embed_dim: 1280
      patch_size: 14
      prenet_depth: 0
      num_extra_tokens: 1
      init_extra_token_zero: false
      use_alibi_encoder: false
      ema_local_encoder: true

      # Masking options.
      inverse_mask: true
      mask_prob: 0.75
      mask_prob_adjust: 0.1
      mask_length: 3
      mask_noise_std: 0.01

      decoder:
        decoder_dim: 1024
        decoder_layers: 3
        decoder_kernel: 5
        decoder_groups: 16
        input_dropout: 0
        final_layer_norm: false