# @package __global__

# Hydra defaults list: override the option selected for the `/conditioner`
# config group with the option named `none`. Note that a plain `none` is a
# YAML string (an option name here), not a YAML null.
defaults:
  - override /conditioner: none

# NOTE(review): presumably selects which LM implementation is built; the value
# matches the `transformer_raw_lm` section of this file. Confirm against the consumer.
lm_model: transformer_raw_lm

# Settings for the `transformer_raw_lm` entry named by `lm_model` in this file.
# NOTE(review): the code consuming these keys is not visible from this file;
# notes marked "presumably" are inferred from key names and should be confirmed.
transformer_raw_lm:  # small size
  # NOTE(review): 768 / 12 heads = 64, presumably the per-head width -- confirm.
  dim: 768
  # The next three values are interpolations: they are not set here but are
  # resolved from `dataset.max_tr`, `dataset.tr_precision` and the top-level
  # `space_dim` of the composed configuration.
  max_tr: ${dataset.max_tr}
  tr_precision: ${dataset.tr_precision}
  space_dim: ${space_dim}
  num_heads: 12
  num_layers: 4
  # NOTE(review): presumably the feedforward width as a multiple of `dim` -- confirm.
  hidden_scale: 4
  dropout: 0.1
  activation: gelu
  norm_first: true        # use pre-norm instead of post-norm
  bias_ff: false            # use bias for the feedforward
  bias_attn: false          # use bias for the attention
  bias_proj: false          # use bias for the output projections
  # YAML null (no value given). NOTE(review): presumably means "no limit" -- confirm.
  past_context: null
  causal: true
  custom: false                 # use custom MHA implementation
  memory_efficient: true       # use flash attention
  attention_as_float32: false   # use float32 for the attention part,
                                # recommended at the moment when memory_efficient is true.
  # YAML null (no value given). NOTE(review): presumably disables layer scale -- confirm.
  layer_scale: null
  positional_embedding: sin     # positional embedding strategy (sin, rope, or sin_rope).
  xpos: false                   # apply xpos decay (rope only).
  # The value `none` below is the literal string "none" (one of the listed
  # method names), not a YAML null.
  checkpointing: none      # layer checkpointing method, can be none, torch, xformers_default.
                           # torch is the slowest but uses the least memory,
                           # xformers_default is somewhere in between.
  weight_init: gaussian     # weight initialization (null, gaussian or uniform)
  depthwise_init: current  # perform depthwise initialization (null, current, global)
  zero_bias_init: true # initialize bias to zero if bias in linears and
                        # if a weight_init method is used.
  norm: layer_norm             # normalization method to use in transformer.
  cross_attention: false
  qk_layer_norm: false
  qk_layer_norm_cross: false
  # YAML null (no value given). NOTE(review): presumably falls back to `dropout` -- confirm.
  attention_dropout: null
  kv_repeat: 1
  two_step_cfg: false          # whether to do true two-step CFG, which may (or may not) resolve some padding issues.
