# Hydra defaults list: compose the sibling `default` config first, then apply
# this file (`_self_`) on top so the keys below take precedence.
defaults:
  - default
  - _self_

# Callable that Hydra instantiates for this config node; the remaining
# top-level keys in this file are passed to it as keyword arguments.
_target_: meds_torch.models.components.transformer_decoder.TransformerDecoderModel.initialize

# Decoder network built with x-transformers. The wrapper owns the token
# embedding table (sized by the data vocabulary) and wraps the attention stack.
# Every dropout setting below is driven by the single `model.backbone.dropout`
# value.
model:
  _target_: x_transformers.TransformerWrapper
  num_tokens: ${data.vocab_size}
  max_seq_len: ${model.max_seq_len}
  emb_dropout: ${model.backbone.dropout}
  # Absolute positional embeddings are turned off; position information comes
  # from the rotary embeddings enabled in `attn_layers` below.
  use_abs_pos_emb: false
  attn_layers:
    _target_: x_transformers.Decoder
    dim: ${model.token_dim}
    depth: ${model.backbone.n_layers}
    heads: ${model.backbone.nheads}
    layer_dropout: ${model.backbone.dropout}  # stochastic depth - dropout entire layer
    attn_dropout: ${model.backbone.dropout}  # dropout post-attention
    ff_dropout: ${model.backbone.dropout}  # feedforward dropout
    rotary_pos_emb: true
    # Request the flash-attention code path in x-transformers.
    attn_flash: true

# use x-transformers to embed tokens
token_emb: null
# Maximum sequence length, resolved from the top-level `model.max_seq_len`
# (the same value handed to the TransformerWrapper above).
# NOTE(review): presumably also bounds how many tokens are generated - confirm
# against TransformerDecoderModel.
max_seq_len: ${model.max_seq_len}
# stop sampling when this token is generated
# NOTE(review): `null` presumably disables EOS-based stopping - confirm.
eos_token_id: null
