# Instantiates a (non-huggingface) scriptable janus-type RNN, right now with all tested bells-and-whistles

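# Top-level architecture parameters.
# NOTE(review): this header used to read "huggingface bert parameters", but the first line of
# this file describes a non-huggingface model - confirm which is meant.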
# Name of the model variant this config instantiates.
model_type: ScriptableCrammedJanus

# Model size.
num_transformer_layers: 8 # number of layers
state_dim: 3584 # NOTE(review): presumably the width of the state - confirm against the model code

# Normalization: scheme, layer type and epsilon.
norm_scheme: shaped
norm: LayerNorm
norm_eps: 1e-12

# Nonlinearity selection.
nonlin: GELUglu
sub_normalization: False # Sub-normalization in attn and ffn blocks

# Embedding tying, bias and final-normalization switches.
tie_weights: True # Tie input/output embedding
decoder_bias: False # Whether to include a bias in the decoding step
use_bias: True # Whether to learn biases on all dense layers
final_norm: True # Required: crashes without this stability improvement
force_normalized_state: True # Open question from the author: is the last normalization learnable?

# Loss and objective layout.
loss: cross-entropy
objective_layout: autoregressive # nothing else implemented so far

# Configuration of the ffn block.
ffn_block:
  structure: stack-sideways-transformer
  intermed_multiplier: 4
  hidden_dropout_prob: 0.0

  # settings only relevant for structure=state-attention:
  # NOTE(review): structure above is stack-sideways-transformer, not state-attention -
  # confirm whether the settings below are also read for the current structure.
  qkv_bias: True
  proj_bias: True
  num_chunks_in_sequence: 16
  num_read_write_heads: 8
  run_causal_heads: False
  positional_info: True
  garbage_collect_state: False
  num_blocks_to_accumulate: 0 # Can be any number of embedding chunks that will be added to the state; this is N^2 attention again :>
  gradient_checkpointing: False
  workspace: ${arch.ffn_block.num_chunks_in_sequence} # interpolated from num_chunks_in_sequence above; only used if block in structure, can be smaller than num_chunks_in_sequence

# Configuration of the prediction head.
head:
  structure: chunked # dense-nonlin-norm (NOTE(review): looks like an alternative value - confirm)
  nonlin: GELU
  norm: LayerNorm
  norm_eps: 1e-12
  use_bias: True
  include_attn_in_chunked_heads: True # only valid for chunked heads
  num_chunked_heads: 4 # only valid for chunked heads
  intermed_multiplier: 4

# Weights of the individual objective terms and related options.
objective:
  historian_weight: 1.0
  predictor_weight: 1.0
  present_historian_weight: 1.0
  present_predictor_weight: 1.0
  rscale_correction: False

  antiquarian_weight: 0.0
  antiquarian_range: ${data.seq_length} # maximal range within which a previous state may be looked up; set to -1 to encompass all previous states
  historian_loss_fn: MSE

# Embedding configuration.
embedding:
  vocab_size: null # will be populated automatically
  pos_embedding: null # unset; NOTE(review): confirm that null is the intended value
  embedding_dim: 512
  normalization: True
  stable_low_precision: False
  max_seq_length: ${data.seq_length} # legacy position, do not use


# Sequence length and position information.
max_seq_length: ${data.seq_length} # max seq length during training (not always used)
position_information: learned # one of: none, learned, simple

# Initialization settings: scheme name and std value.
init:
  type: deepnorm-straight
  std: 0.02

# Experimental options:
# NOTE(review): the meaning of these values is defined by the model code, which is not visible here.
state_corruption: 0.0
eos_state_reset: True
state_init: unit

# Set dynamically (null here, filled in at runtime):
eos_token_id: null
