# Instantiates a (non-huggingface) scriptable encoder-based LM with BERT as baseline
# Version 5 of changes to BERT architecture

# These are the Hugging Face BERT parameters
# `architectures` is a list holding a single entry, the model class name.
# NOTE(review): presumably used to pick the model implementation - confirm
# against the code that consumes this config.
architectures:
  - ScriptableMaskedLM

# Core model dimensions.
# `hidden_size` and `hidden_dropout_prob` are also referenced further down in
# this file through the ${arch.hidden_size} and ${arch.hidden_dropout_prob}
# interpolations (embedding.embedding_dim, classification_head.classifier_dropout),
# so changing them here changes those values as well.
num_transformer_layers: 16
hidden_size: 768
intermed_size: 3072
hidden_dropout_prob: 0.1

# Normalization layer and activation function selection.
norm: LayerNorm
# Written with an explicit mantissa ("1.0e-12" rather than "1e-12"): YAML 1.1
# resolvers such as plain PyYAML only treat exponent notation as a float when
# a decimal point is present, and would otherwise load the string "1e-12".
# The numeric value is unchanged.
norm_eps: 1.0e-12
norm_scheme: pre # can be "pre", "post", "sandwich"
nonlin: GELUglu

# Output head, loss and memory/fusion toggles.
tie_weights: true # share the input and output embedding weights
sparse_prediction: true # predict on masked tokens only
decoder_bias: false # add a bias term in the decoding step
loss: cross-entropy
z_loss_factor: 0
gradient_checkpointing: false
layer_fusion: true # fuse the residual structure of each transformer layer

# Token / positional embedding settings.
embedding:
  vocab_size: null # left empty on purpose; will be populated automatically
  pos_embedding: scaled-sinusoidal
  dropout_prob: 0.1 # corresponds to hidden_dropout_prob in BERT
  pad_token_id: 0
  # Maximum sequence length the positional embedding is instantiated for.
  max_seq_length: 128
  # May be smaller than the hidden size (the ALBERT trick); defaults to it.
  embedding_dim: ${arch.hidden_size}
  normalization: true

# Self-attention settings.
attention:
  type: flash-attention-impl
  causal_attention: false
  num_attention_heads: 12
  dropout_prob: 0.1
  skip_output_projection: false
  qkv_bias: false

  rotary_embedding: false
  # Whether the operation over the sequence (e.g. the softmax in normal
  # attention) is always cast to fp32.
  seq_op_in_fp32: false
  sequence_op: torch-softmax # can be normalization
  # Only used when type=fourier-hybrid, to denote the self-attention layers:
  # hybrid_layers: [10, 11]
  high_level_fusion: false
  low_level_fusion: true

# Parameter initialization settings.
# NOTE(review): `std` is presumably the standard deviation used by the
# `normal` init type - confirm against the initialization code.
init:
  type: normal
  std: 0.02

# Very experimental options:
ffn_layer_frequency: 1 # an FFN layer in every layer
deepnorm_scaling: false
# Only possible when embedding_dim equals hidden_size.
skip_head_transform: true
# Null disables it; set to a non-null value to dynamically drop layers.
layer_drop_theta: null
use_bias: false # learn biases on all dense layers
final_norm: true # add a final norm layer before the end
recurrent_layers: null
layer_macro_type: transformer # can also be FLASH

# Downstream settings:
# Left null here; can be filled in automatically for downstream tasks.
num_labels: null
classification_head:
  pooler: avg
  include_ff_layer: true
  head_dim: 1024
  nonlin: Tanh
  # Interpolated from hidden_dropout_prob via the `arch` config node.
  classifier_dropout: ${arch.hidden_dropout_prob}
