---
# Model architecture and tokenizer settings.
model:

  # NOTE(review): the public GPT-2 base model has 12 layers, but n_layer
  # below is 4 (matching the '4layer' suffix in the training run names).
  # 'gpt2-base' is presumably a family/config identifier; confirm this is
  # intentional.
  name: gpt2-base
  class_name: GPT2LMHeadModel

  transformer:
    # 64 * 12 = 768; presumably the hidden size is derived from
    # d_head * n_head by the consumer (TODO confirm).
    d_head: 64
    n_head: 12
    n_layer: 4

  tokenizer:
    model_max_length: 512
    tokenizer_type: 'wordlevel'  # 'gpt2' | 'wordlevel'
    # Required for the 'wordlevel' tokenizer; may be left out when
    # tokenizer_type is 'gpt2'.
    vocab_file: 'vocab.json'

training:

  # general settings
  use_tf32: true
  use_amp: true
  amp_dtype: 'bf16'  # presumably only used when use_amp is true
  torch_compile: false
  grad_accum_steps: 1
  grad_clip_norm: 1.0

  # data loader
  batch_size_per_gpu: 16
  num_workers: 4
  num_threads: 32
  # NOTE(review): if this is passed to torch DataLoader, prefetch_factor is
  # per worker, i.e. up to num_workers * 32 = 128 batches buffered; confirm
  # the memory cost is intended.
  prefetch_factor: 32

  # dataset
  path: '[name censored]/childes-pretrain'
  # When true, all splits are merged for training and train_split is ignored.
  is_train_all: true
  # Splits to merge for training; only consulted when is_train_all is false.
  # NOTE(review): [''] is a list holding one empty string, not an empty list;
  # replace it with real split names before setting is_train_all to false.
  train_split: ['']

  # training length: training_cycle_value units of training_cycle_unit
  training_cycle_unit: 'steps'  # 'epochs' | 'steps'
  training_cycle_value: 20000
  # Keep the 's442' suffix in checkpoint_dir, output_dir and wandb_exp_name
  # in sync with this value.
  seed: 442

  # optimizer / learning-rate schedule
  lr: 0.00005
  # Commented-out keys are not set here; presumably the trainer falls back
  # to its own defaults (TODO confirm).
  # beta1: 0.9
  # beta2: 0.99
  # weight_decay: 0.01
  warmup_steps: 1000
  reset_lr: false
  reset_weight_decay: false

  # checkpoint
  checkpoint_every: 500
  checkpoint_dir: '/scratch/[name censored]/[name censored]/[name censored]/experiments/checkpoints/childes_warmup_s442_4layer_pure/'

  # output
  output_dir: '/scratch/[name censored]/[name censored]/[name censored]/experiments/output/childes_warmup_s442_4layer_pure/'

  # wandb
  # Relative path, presumably resolved against the working directory
  # (TODO confirm); keep the key file itself out of version control.
  api_key_path: './wandb_api_keys.yaml'
  wandb_project: '[name censored]'
  # NOTE(review): mixes '-' and '_' separators, unlike the directory names above.
  wandb_exp_name: 'childes-warmup-s442-4layer_pure'
  wandb_log_every: 1
  wandb_offline: false

