# Version 4 of the changes to the BERT training hyperparameters.
# Optimizes the MLM rate for torch.compile, includes the improved weight decay limitation, and finally switches to a relative batch-size ramp.

# Name identifying this set of training hyperparameters.
name: bert-o4

# Per-group selections: `adam` for `optim`, `disabled` for `optim_mod`.
# NOTE(review): presumably a Hydra defaults list, in which case entry order matters - confirm against the config loader.
defaults:
  - optim: adam
  - optim_mod: disabled

# Optimizer values set by this config.
# NOTE(review): presumably these override the `optim: adam` group selected in `defaults` - confirm.
optim:
  # NOTE(review): `1e-3` has no decimal point; plain YAML 1.1 loaders (e.g. PyYAML) read it as a string.
  # Confirm that the loader used for this file resolves it to a float.
  lr: 1e-3
  weight_decay: 0.01

# Keys selecting the parameters that are exempt from weight decay.
# NOTE(review): presumably matched against parameter names - confirm how the consumer compares them.
limited_decay_keys: [bias, LayerNorm.bias, LayerNorm.weight, norm] # no weight decay for these layers

# Steps and scheduler:
# Warmup and cooldown are both set to 0 steps in this config.
warmup_steps: 0
cooldown_steps: 0
# NOTE(review): `900_000` uses YAML 1.1 digit separators; a strict YAML 1.2 core-schema loader would read it as a string - confirm the loader.
steps: 900_000 # counted in microbatch steps; this is an upper limit that is usually never reached.
scheduler: budget-triangle2

# Training settings:
# Target batch size that the ramp below works up to.
batch_size: 8192
batch_size_ramp: 0.60 # as a fraction of the total budget (not percentage points), i.e. training reaches the target 8192 at 60% of the total budget.

# NOTE(review): whether 0.5 is a norm or a value threshold is not visible in this file - confirm in the consumer.
gradient_clipping: 0.5
# NOTE(review): presumably `False` deviates from the default BERT behavior described in the inline comment - confirm.
pretrain_in_train_mode: False # default BERT trains with dropout layers enabled in pretrain

# Pretraining objective settings.
objective:
  name: masked-lm
  # The masking probability is chosen so that mbs * seq_length * prob is a power of two.
  mlm_probability: 0.25 # leads to mbs * seq_length * prob = 4096 masked entries, for mbs=128, seq=128
  # Use 0.25 for mbs=64 and 0.16666 for mbs=96 (both give 2048 masked entries at seq=128);
  # keep at least this many sixes so that the count rounds up to 2048.
  # General rule: take the masked-entry count at a 0.15 rate, round it up to the next power of two (f),
  # then divide by mbs * seq to get the probability back:
  #   f = lambda x: pow(2, math.ceil(math.log(x)/math.log(2)))
  #   mlm_probability = f(mbs * seq * 0.15) / mbs / seq
  # NOTE(review): presumably the BERT mask/random/keep replacement scheme for selected tokens - confirm in the objective implementation.
  use_80_20_rule: True
  disable_mlm: False
  token_drop: 0.0
# Top-level flag (not part of `objective`); disabled in this config.
reverse_dataset_order: False

# Value is taken from a `budget` key via `${...}` interpolation.
# NOTE(review): presumably this resolves to the root-level `budget` of the composed config, with this file nested under a group;
# if this file were loaded at the root, the key would reference itself - confirm.
budget: ${budget}
