# Minimal modifications to the architecture
# Modernized version of bert-c5

# These are the Hugging Face BERT parameters
# Single-entry list naming the model architecture to build.
# NOTE(review): presumably mirrors the `architectures` field of Hugging Face configs - confirm.
architectures:
  - ScriptableCrammedBERT

# Overall model size.
num_transformer_layers: 12
hidden_size: 768 # referenced elsewhere in this file as ${arch.hidden_size}
intermed_size: 3072 # 4 x hidden_size; presumably the inner width of the FFN block - confirm
hidden_dropout_prob: 0.1

# Normalization and activation settings.
norm: LayerNorm
# Written with an explicit mantissa dot on purpose: YAML 1.1 loaders (e.g. plain
# PyYAML) only recognize floats that contain a ".", so a bare 1e-12 would be
# loaded as the string "1e-12" instead of a number.
norm_eps: 1.0e-12
norm_scheme: pre
nonlin: GELU

tie_weights: true # Share (tie) the input and output embedding
decoder_bias: true # Whether the decoding step includes a bias

# Training objective settings.
# sparse_prediction is an interpolation of train.objective.mlm_probability, so it can
# only be resolved inside a composed config that also provides that `train` entry.
sparse_prediction: ${train.objective.mlm_probability} # Whether to predict only on masked tokens, and how many there will be
loss: cross-entropy
objective_layout: MLM # can also be SCRIPT

# Embedding layer settings.
embedding:
  vocab_size: null # intentionally unset here; will be populated automatically
  pos_embedding: learned
  dropout_prob: 0.1 # same value as hidden_dropout_prob in BERT
  pad_token_id: 0
  max_seq_length: 128 # longest sequence the positional embedding is instantiated for
  embedding_dim: ${arch.hidden_size} # must equal hidden_size for crammedBERT
  normalization: true
  stable_low_precision: false

# Attention settings.
attention:
  type: self-attention # "pytorch" also works here
  causal_attention: false
  num_attention_heads: 12
  dropout_prob: 0.1
  skip_output_projection: false
  qkv_bias: true

  rotary_embedding: false
  seq_op_in_fp32: false # whether the operation over the sequence (e.g. the softmax in normal attn) is always cast to fp32
  sequence_op: torch-softmax # can also be a normalization

# Weight initialization settings.
# NOTE(review): presumably a normal distribution with standard deviation `std` - confirm against the model code.
init:
  type: normal
  std: 0.02

# Experimental options:
ffn_layer_frequency: 1 # 1 = an FFN layer in every layer
skip_head_transform: false # Skipping is only possible if embedding_dim=hidden_size
use_bias: true # Whether biases are learned on all dense layers
final_norm: false # Whether to add a final norm layer before the end

# Downstream settings:
num_labels: null # intentionally unset here; can be filled in automatically for downstream
# Settings for the downstream classification head.
classification_head:
  pooler: zero_index
  include_ff_layer: true
  head_dim: 768 # literal value (not an interpolation); equals hidden_size in this file
  nonlin: Tanh
  classifier_dropout: 0.1
