# Version 4 of changes to BERT training hyperparameters.
# Optimizes the MLM rate for torch.compile, includes an improved weight decay limitation,
# and finally switches to a relative batch-size (bs) ramp.

# Identifier of this training configuration.
name: cramming-o4

# NOTE(review): presumably a Hydra-style defaults list that selects the `adam`
# option for `optim` and `disabled` for `optim_mod` — confirm against the
# surrounding config tree. Entry order is left untouched.
defaults:
  - optim: adam
  - optim_mod: disabled

# Optimizer hyperparameters.
optim:
  # Written with an explicit mantissa dot: YAML 1.1 loaders such as PyYAML only
  # recognise floats that contain a ".", so a bare `1e-3` would load as a string.
  lr: 1.0e-3
  weight_decay: 0.01

# Layers matching these keys receive no weight decay.
limited_decay_keys:
  - bias
  - LayerNorm.bias
  - LayerNorm.weight
  - norm

# Steps and learning-rate schedule:
# NOTE(review): the fractional warmup/cooldown values presumably denote a share
# of the overall training run rather than absolute step counts — confirm in
# the scheduler implementation.
warmup_steps: 0.1
cooldown_steps: 0.1
# Microbatch steps (12 million). This is an upper limit that is usually never
# reached. Written without `_` digit separators, which only YAML 1.1 loaders
# accept inside an integer; YAML 1.2 loaders would read them as a string.
steps: 12000000
scheduler: budget-constant

# Training setting:
# The full sequence is used as input to the model.
stream_depth: "${data.seq_length}"
batch_size: 8192
# Relative batch-size ramp.
batch_size_ramp: 0.60

gradient_clipping: 0.5
# Default BERT pretraining runs with dropout layers enabled.
pretrain_in_train_mode: true
reverse_dataset_order: false

# Both values are interpolated from the `budget` / `overall_budget` keys
# defined elsewhere in the composed configuration.
budget: "${budget}"
overall_budget: "${overall_budget}"

# Settings for loading a previously saved checkpoint.
arch_modifications: null
# Checkpoint name: either "latest" or a reference to a specific checkpoint
# in a subfolder.
checkpoint: latest
# Path for caches of datasets and tokenizers.
path: "${impl.path}"
