# Model configuration: identity of the model to build, its size settings,
# and the tokenizer it is paired with.
model:

  name: MAMBA
  class_name: MAMBA

  # NOTE(review): this section is called `transformer` although the model is
  # MAMBA — confirm these keys are what the MAMBA class actually reads.
  transformer:
    d_head: 64  # presumably the per-head width — TODO confirm
    n_head: 12
    n_layer: 12

  tokenizer:
    model_max_length: 512
    tokenizer_type: 'wordlevel'  # 'gpt2' | 'wordlevel'
    vocab_file: 'vocab.json'  # omitted if tokenizer_type is 'gpt2'

# Training run configuration: precision/compile switches, data loading,
# dataset selection, schedule, optimizer, checkpointing, output and W&B logging.
training:

  # general settings
  use_tf32: true
  use_amp: true
  amp_dtype: 'bf16'  # presumably only consulted while use_amp is true — confirm
  torch_compile: false
  grad_accum_steps: 1
  grad_clip_norm: 1.0

  # data loader
  batch_size_per_gpu: 16
  num_workers: 4
  num_threads: 32
  # NOTE(review): if this is forwarded to torch.utils.data.DataLoader, the
  # value is counted per worker (4 workers x 32 = 128 batches buffered) —
  # confirm that much host memory use is intended.
  prefetch_factor: 32

  # dataset
  path: '[censored]/childes-pretrain'
  is_train_all: true  # merge all splits and train, ignoring train_split below
  # Splits merged for training; only used when is_train_all is false.
  # NOTE(review): these names do not look like splits of the dataset at `path`
  # — verify before ever setting is_train_all to false.
  train_split: ['ACL', 'NAACL', 'EMNLP']

  # schedule
  training_cycle_unit: 'steps'  # 'epochs' | 'steps'
  training_cycle_value: 20000  # length of the run, in training_cycle_unit
  seed: 142

  # optimizer
  lr: 0.0004
  # beta1: 0.9
  # beta2: 0.99
  weight_decay: 0.4
  warmup_steps: 2000
  reset_lr: false
  reset_weight_decay: false

  # checkpoint
  checkpoint_every: 500
  checkpoint_dir: '/scratch/[censored]_root/[censored]2/[censored]/trabank-dev/experiments/checkpoints/childes_s142_mamba2_12layer_final'

  # output
  output_dir: '/scratch/[censored]_root/[censored]2/[censored]/trabank-dev/experiments/output/childes_s142_mamba2_12layer_final'

  # wandb
  # Relative path: presumably resolved against the launch directory — confirm.
  api_key_path: './wandb_api_keys.yaml'
  wandb_project: 'trabank'
  wandb_exp_name: 'childes-mamba-seed142-12layer_final'
  wandb_log_every: 1
  wandb_offline: false
