# Multi-head attention settings.
attention:
  # 12 heads x 64 dims per head = 768, the same as hidden.size; presumably
  # the model requires num_heads * head_size == hidden.size (confirm).
  head_size: 64
  num_heads: 12
  # Dropout rate, presumably on the attention probabilities; 0.0 = none.
  probs_dropout: 0.0
  # Presumably normalizes queries/keys before the dot product (confirm).
  qk_norm: true
  # Attention backend name; confirm the consuming code accepts this value.
  implementation: "flash_attention_2"

# Input embedding settings.
embedding:
  # Embedding width; equals hidden.size and latent.dim (768) in this config,
  # and the hidden size of the bert-base-cased text encoder.
  dim: 768
  # Presumably the longest sequence the positional embeddings cover; longer
  # inputs would need truncation (confirm in the model code).
  max_position_embeddings: 128
  # Presumably the std-dev of the weight initializer (BERT-style) - confirm.
  initializer_range: 0.02

# Main model stack settings.
hidden:
  # Model width; equals attention.num_heads * attention.head_size
  # (12 * 64 = 768) - presumably required by the model, confirm.
  size: 768
  # Dropout rate for hidden activations (exact placement not visible here).
  dropout: 0.1
  # Number of stacked layers (presumably transformer blocks).
  num_layers: 12
  # Presumably the feed-forward expansion factor (inner width 4 * 768 = 3072)
  # - confirm against the model code.
  ff_mult: 4

# Latent representation settings.
latent:
  # Width of each latent vector; equal to hidden.size in this config.
  dim: 768
  # Number of latent vectors (presumably a fixed-size set the text is
  # encoded into - confirm against the model).
  num_latents: 16

# Normalization settings.
normalization:
  # Epsilon for numerical stability, presumably passed to LayerNorm.
  # Keep the ".0" in the literal: YAML 1.1 loaders such as PyYAML resolve a
  # dot-less exponent like 1e-5 to a string, not a float.
  layer_eps: 1.0e-5

# Text encoder and masking settings.
model:
  # Pretrained text encoder id (presumably a Hugging Face model name);
  # tokens.vocab_size / tokens.mask_token_id must match its tokenizer.
  text_encoder: "bert-base-cased"
  # Presumably keeps the text encoder's weights frozen during training.
  text_encoder_freeze_params: true
  # Masking probabilities for MLM-style corruption; 1.0 presumably masks
  # every token (confirm how the trainer samples from this list).
  mlm_probabilities: [1.0]
  # Presumably selects BERT-style masking rather than always substituting
  # the mask token - confirm against the trainer.
  bert_masking: true

# Tokenizer constants for model.text_encoder.
tokens:
  # These match the bert-base-cased tokenizer (28996 tokens, [MASK] = 103).
  # Update both together if model.text_encoder changes.
  vocab_size: 28996
  mask_token_id: 103

# Finetuning schedule.
finetuning:
  # Semantics not visible here; presumably an upper bound on a noise /
  # sampling standard deviation - confirm in the finetuning code.
  max_std: 0.5
  # Total number of finetuning iterations.
  training_iters: 1000
  logging:
    # Frequencies, presumably in iterations: log every step; save and
    # evaluate every 1000 steps. With training_iters: 1000 that means a
    # single save/eval at the end (confirm how the trainer counts steps).
    log_freq: 1
    save_freq: 1000
    eval_freq: 1000
