# Instantiates a (non-huggingface) scriptable decoder-based LM
# This inherits architecture changes from the crammed-bert project

# Model type selector.
model_type: ScriptableCrammedDepthRecurrent

# Depth-recurrence settings.
layers_in_recurrent_block: 4
maximal_recurrence: 4
# Left unset on purpose: half of maximal_recurrence is used if not given, minimal is 1.
max_backprop: null
# Follows the training value by default; could be set to think longer.
maximal_recurrence_in_eval: ${arch.maximal_recurrence}

# Model widths.
hidden_size: 768
intermed_size: 3072

# Recurrent state handling.
input_injection_type: add
initial_hidden_randomized: true
# "embed": the state is initialized randomly, like the embedding.
state_init: embed


# Normalization layer settings.
norm: LayerNorm
# Written with an explicit mantissa dot: a bare "1e-12" is resolved as a
# string (not a float) by YAML 1.1 loaders such as plain PyYAML. The numeric
# value is unchanged.
norm_eps: 1.0e-12
norm_scheme: post # can be "pre", "post"

nonlin: GELUglu
# Sub-normalization in attn and ffn blocks.
sub_normalization: false

# Tie input/output embedding.
tie_weights: false
# Whether to include a bias in the decoding step.
decoder_bias: false
# Whether to learn biases on all dense layers.
use_bias: false
# Add a final norm layer before the end.
final_norm: true
head: ffn

objective_layout: TBPTT

embedding:
  # Left empty here; will be populated automatically.
  vocab_size: null
  pos_embedding: learned
  # Max seq length that the positional embedding is instantiated for.
  max_seq_length: ${data.seq_length}
  # Has to be this value (the hidden size) for crammedBERT.
  embedding_dim: ${arch.hidden_size}
  normalization: true
  stable_low_precision: false
  max_recycle_len: 100

attention:
  # NOTE(review): the previous note read 'also works with "pytorch"', which is
  # the value already set here. The alternative is presumably "self-attention"
  # (the hand-made version mentioned on sub_normalization) - confirm.
  type: pytorch
  # Previous note: "for flash" (presumably flash attention - confirm).
  num_attention_heads: 16
  skip_output_projection: false
  qkv_bias: false
  bias_in_proj: false
  # For randomised PEs (NOT IMPLEMENTED FOR ALL).
  max_length: 0

  rotary_embedding: false
  # Whether to always cast the operation over the sequence into fp32
  # (e.g. the softmax in normal attn).
  seq_op_in_fp32: false
  # Can be normalization.
  sequence_op: torch-softmax
  # Follows the global setting but could be turned off separately.
  # Is only used if type=self-attention (i.e. the hand-made version).
  sub_normalization: ${arch.sub_normalization}

# Initialization settings (presumably for the model weights - confirm).
init:
  type: deepnorm-straight
  std: 0.02 # only used if type=normal, so it has no effect with the type set above

# The next two settings are only active during TBPTT.
throttle: false
alpha: 1.0

mask_before_equals: false
# Try to compile the static block, no matter what the global compile setting is set to.
local_compilation: true
loss_reduction: mean
# Use the old delete method, i.e. calculate the delete mask at runtime.
use_old_delete_function: false
# Replace delete with pad in all cases.
replace_all_del_with_pad: false
# Forward only model with skip.
forward_only_model_with_skip: false
