# User-supplied placeholders: replace each value before launching a run.

# Directory containing the official checkpoints; interpolated into load_path.
ckpt_dir: DIRECTORY_CONTAINS_OFFICIAL_CHECKPOINTS
# Checkpoint step to start from; interpolated into run_name and load_path.
# Choose between [40000, 113000, 339000].
ckpt_step: CHECKPOINT_STEP
# Path to the dataset (placeholder; replace with a real path).
dataset_path: PATH_TO_DATASET


# Run identity and random seed.
dry_run: false
seed: 6198
# Derived from ckpt_step, e.g. OLMo-7B-40000 when ckpt_step is 40000.
run_name: OLMo-7B-${ckpt_step}

# Model architecture hyperparameters, grouped by topic.
model:
  # Size
  d_model: 4096
  n_layers: 32
  n_heads: 32
  mlp_hidden_size: 22016
  max_sequence_length: 2048

  # Block structure and activation
  block_type: sequential
  activation_type: swiglu
  include_bias: false
  weight_tying: false

  # Attention
  flash_attention: true
  multi_query_attention: false
  attention_dropout: 0.0
  attention_layer_norm: false
  attention_layer_norm_with_affine: false

  # Positional encoding: rope is on, alibi is off
  rope: true
  alibi: false

  # Layer norm
  layer_norm_type: default
  layer_norm_with_affine: false
  bias_for_layer_norm: false

  # Remaining dropout rates (both 0.0)
  residual_dropout: 0.0
  embedding_dropout: 0.0

  # Vocabulary and special tokens. embedding_size exceeds vocab_size by 24
  # entries (presumably padding for alignment; confirm against the model code).
  vocab_size: 50280
  embedding_size: 50304
  eos_token_id: 50279
  pad_token_id: 1

  # Initialization
  init_device: meta
  init_fn: mitchell

# Compilation settings are left unset (null).
# NOTE(review): presumably controls torch.compile; confirm against the trainer.
compile: null

# Optimizer hyperparameters.
optimizer:
  name: adamw
  learning_rate: 3.0e-4
  betas: [0.9, 0.95]
  weight_decay: 0.1
  # NOTE(review): presumably the step interval for logging optimizer
  # metrics; confirm against the trainer.
  metrics_log_interval: 10

# Learning-rate schedule and gradient-clipping warmup.
scheduler:
  name: linear_with_warmup
  # NOTE(review): alpha_f is presumably the final LR as a fraction of the
  # peak, and t_warmup is presumably counted in steps; confirm both.
  alpha_f: 0.1
  t_warmup: 5000
  grad_clip_warmup_factor: 10.0
  grad_clip_warmup_steps: 1000

# Tokenizer definition file and truncation side.
tokenizer:
  truncate_direction: right
  identifier: tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json

# Checkpoint output location. SCRATCH_DIR is read from the environment; the
# literal "no_exist" is substituted when the variable is unset.
save_folder: ${oc.env:SCRATCH_DIR,no_exist}/checkpoints/${run_name}
save_overwrite: true
remote_save_folder: null

# Sharded checkpoints (best for restarts)
save_num_checkpoints_to_keep: -1
save_interval: 1000

# Unsharded checkpoints (for final storage); no interval is set.
save_num_unsharded_checkpoints_to_keep: -1
save_interval_unsharded: null

# Starting checkpoint and evaluation-on-load switch.
eval_on_load: true
# Resolves to <ckpt_dir>/step<ckpt_step>-unsharded.
load_path: ${ckpt_dir}/step${ckpt_step}-unsharded

# Auxiliary input files under SCRATCH_DIR ("no_exist" is used when unset).
probe_dataset: ${oc.env:SCRATCH_DIR,no_exist}/fictional_knowledge/fictional_knowledge_paraphrased.json
inject_indices_map: ${oc.env:SCRATCH_DIR,no_exist}/analysis/inject_indices_map/7b-0.pkl

# Training length and batch sizing; no time limit is set.
time_limit: null
max_duration: 2e12T  # 2T tokens
device_train_microbatch_size: 1
global_train_batch_size: 2048

# FSDP wrapping and precision, followed by the top-level training precision.
fsdp:
  precision: mixed
  wrapping_strategy: by_block

precision: amp_bf16

# Gradient-norm limits; the ratio-based limit is unset.
max_grad_norm_ratio: null
max_grad_norm: 1.0

# Speed monitor settings.
# NOTE(review): window_size is presumably the number of recent steps that
# throughput is averaged over; confirm against the trainer.
speed_monitor:
  window_size: 20

# Evaluation cadence and batch sizing; the evaluator list is empty.
evaluators: []
eval_interval: 1
device_eval_batch_size: 2
# NOTE(review): -1 presumably means "evaluate on every batch"; confirm.
eval_subset_num_batches: -1

# Data shuffling is turned off.
# NOTE(review): presumably so examples are consumed in a fixed order; confirm.
data_shuffling: false

# Data loading settings.
data:
  pad_direction: right
  num_workers: 16
  drop_last: true
  pin_memory: true
  prefetch_factor: 1
  persistent_workers: true
  timeout: 0
  paths:
    # Interpolate the top-level dataset_path. A bare "$dataset_path" is not
    # interpolation syntax and would be taken as a literal string.
    - ${dataset_path}