# Version 4 of changes to BERT training hyperparameters
# Optimizes MLM rate for torch.compile, includes improved weight decay limitation, and is finally updated to a relative batch-size ramp

# Identifier for this training configuration.
name: cramming-o4

# Config-group selections this file builds on
# (presumably a Hydra-style defaults list — confirm against the loader).
defaults:
  - optim: adam
  - optim_mod: disabled

# Optimizer hyperparameters
# (presumably merged over the `optim: adam` selection in `defaults` — confirm).
optim:
  # Written with an explicit decimal point: YAML 1.1 loaders such as plain
  # PyYAML only recognise floats that contain a ".", so a bare `1e-3` would
  # load as the string "1e-3" unless the consumer patches the float resolver.
  lr: 1.0e-3
  weight_decay: 0.01

# No weight decay for these layers (matched by these parameter-name keys).
limited_decay_keys: [bias, LayerNorm.bias, LayerNorm.weight, norm]

# steps:
# NOTE(review): warmup/cooldown are fractional values here — presumably a share
# of the run rather than absolute step counts; confirm against the scheduler.
warmup_steps: 0.1
cooldown_steps: 0.1
steps: 4_000_000 # these are microbatch steps. This is an upper limit that is usually never reached
scheduler: budget-constant

# Training settings:
# NOTE(review): the inline comment on stream_depth does not obviously match a
# value of 2 — confirm what the depth controls.
stream_depth: 2 # Train one token at a time
batch_size: 16384
batch_size_ramp: 0.60 # fractional value; TODO confirm what the ramp is relative to

gradient_clipping: 0.5
# Booleans use the canonical lowercase spelling (`true` / `false`).
pretrain_in_train_mode: true # default BERT trains with dropout layers enabled in pretrain
reverse_dataset_order: false

# `${budget}` is interpolation syntax: the value is taken from a `budget` key
# defined elsewhere (presumably the top-level config — confirm).
budget: ${budget}
