# Optimisation and runtime settings for a training run.
# NOTE(review): the `${...}` values look like OmegaConf/Hydra interpolation — confirm the loader.
training:
  batch_size: 256
  lr: 0.001
  seed: 42
  val_every_step: 200
  lr_warmup_steps: 10000
  # Relative interpolation: takes its value from `num_steps` in this same section.
  lr_decay_until_steps: ${.num_steps}
  # Plain decimal on purpose: `1e-3` has no dot, so YAML 1.1 loaders (e.g. PyYAML)
  # would load it as the string "1e-3" instead of a float.
  lr_decay_factor: 0.001
  min_lr: 0.000001
  weight_decay: 0.1
  compile: false
  # YAML null, not the Python spelling `None` (which loads as the string "None").
  # Presumably null means "no gradient clipping" — confirm against the trainer.
  grad_clip_norm: null
  num_steps: 100000
  device: cuda
  amp_precision: bfloat16
  weight_precision: float32
  enable_mixed_precision: true

# Model architecture settings.
model:
  # Which model / layer implementation to use.
  name: mamba
  layer_type: mamba

  # Network size.
  n_layers: 2
  d_model: 128

  # Taken from the dataset section by interpolation, so both sections share one value.
  context_length: ${dataset.kwargs.context_length}
  vocab_size: ${dataset.kwargs.vocab_size}

  # Dropout and weight-decay switches.
  dropout: 0.0
  add_embedding_dropout: false
  weight_decay_on_embedding: false

  # Remaining architecture flags.
  tie_weights: false
  positive_and_negative: true

# Synthetic dataset selection and generation parameters.
dataset:
  name: form_language
  kwargs:
    synth_lang_type: modular_arithmetic_with_brackets
    seed: 42
    enable_mask: true

    # Also referenced from the model section via `${dataset.kwargs.*}`.
    vocab_size: 12
    context_length: 256

    # Sequence-length bounds at this level (3 to 40).
    min_sequence_length: 3
    max_sequence_length: 40

    # Per-split `count` values.
    count:
      train: 102400000
      validation: 8192
      test: 8192

    # Per-split settings: validation and test both use lengths 40 to 256,
    # i.e. at or beyond the 3 to 40 range above.
    # NOTE(review): presumably these override the bounds above for those splits — confirm.
    subpar:
      validation:
        min_sequence_length: 40
        max_sequence_length: 256
      test:
        min_sequence_length: 40
        max_sequence_length: 256
