# Hydra defaults list. Entries are composed in order: the `default` config is
# loaded first and this file (`_self_`) is merged afterwards, so keys set here
# take precedence over the same keys in `default`.
defaults:
  - default
  - _self_

# Dotted path of the callable Hydra invokes when this node is instantiated.
# NOTE(review): presumably the `model` and `token_emb` entries in this file are
# handed to that callable as keyword arguments - confirm against its signature.
_target_: meds_torch.models.components.transformer_decoder.TransformerDecoderModel.initialize

# Backbone network: an x_transformers TransformerWrapper built around a Decoder
# stack. Every ${...} value is an interpolation resolved from elsewhere in the
# composed config, so no size or dropout rate is hard-coded in this file.
model:
  _target_: x_transformers.TransformerWrapper
  # Output width follows the token embedding width (model.token_dim) rather
  # than the vocabulary size that is passed as num_tokens.
  logits_dim: ${model.token_dim}
  num_tokens: ${data.vocab_size}
  max_seq_len: ${model.max_seq_len}
  emb_dropout: ${model.backbone.dropout}
  # Absolute position embeddings are switched off; rotary_pos_emb is enabled on
  # the attention layers below instead.
  use_abs_pos_emb: false
  attn_layers:
    _target_: x_transformers.Decoder
    dim: ${model.token_dim}
    depth: ${model.backbone.n_layers}
    heads: ${model.backbone.nheads}
    # The three dropout settings below all read the same configured rate
    # (model.backbone.dropout).
    layer_dropout: ${model.backbone.dropout}  # stochastic depth - dropout entire layer
    attn_dropout: ${model.backbone.dropout}  # dropout post-attention
    ff_dropout: ${model.backbone.dropout}  # feedforward dropout
    rotary_pos_emb: true
    attn_flash: true

# Token embedding is a pass-through (torch.nn.Identity): the data is assumed to
# be pre-embedded by the input_encoder.
token_emb:
  _target_: torch.nn.Identity
