# @package __global__
# Hydra defaults list. `_self_` is listed first, so the entries after it are
# composed on top of the values defined in this file.
defaults:
  - _self_
  # prefer this group to set model scale instead of transformer_lm keys directly
  - /model/lm/model_scale: base

# Name of the language model to use; the value matches the `transformer_lm` section in this file.
# NOTE(review): presumably selects that section's hyper-parameters — confirm against the consumer.
lm_model: transformer_lm

# Codebook pattern settings; `modeling` names the pattern to use (here: parallel).
# NOTE(review): the accepted pattern names are not visible from this file — confirm in the consumer.
codebooks_pattern:
  modeling: parallel

# Hyper-parameters of the transformer language model.
# To change the model scale, prefer the `model_scale` group from the defaults list
# over editing the keys below directly.
transformer_lm:
  dim: 512
  num_heads: 8
  num_layers: 8
  hidden_scale: 4
  n_q: 8                   # number of streams to model
  card: 1024               # NOTE(review): presumably the cardinality of each stream - confirm
  dropout: 0.              # parsed as the float 0.0
  emb_lr: null
  activation: gelu
  norm_first: false        # use pre-norm instead of post-norm
  bias_ff: true            # use bias for the feedforward
  bias_attn: true          # use bias for the attention
  bias_proj: true          # use bias for the output projections
  past_context: null
  causal: true
  custom: false                 # use the custom multi-head attention (MHA) implementation
  memory_efficient: false       # use flash attention
  attention_as_float32: false   # use float32 for the attention part;
                                # currently recommended when memory_efficient is true.
  layer_scale: null
  positional_embedding: sin     # positional embedding strategy (sin, rope, or sin_rope).
  xpos: false                   # apply xpos decay (rope only).
  checkpointing: none      # layer checkpointing method: none, torch, or xformers_default.
                           # `none` here is the plain string "none", not YAML null.
                           # torch is the slowest but uses the least memory,
                           # xformers_default is somewhere in between.
  weight_init: null     # weight initialization (null, gaussian or uniform)
  depthwise_init: null  # perform depthwise initialization (null, current, global)
  zero_bias_init: false # initialize biases to zero; only applies when the linears
                        # have a bias and a weight_init method is used.
  norm: layer_norm             # normalization method to use in transformer.
  cross_attention: false
  qk_layer_norm: false
  qk_layer_norm_cross: false
  attention_dropout: null
  kv_repeat: 1
  # whether to perform true two-step CFG (presumably classifier-free guidance - confirm);
  # may resolve some padding issues.
  two_step_cfg: false
