# Top-level run settings.
run_name: OLMo-7B  # also referenced as ${run_name} by the commented-out swanlab section
seed: 6198
dry_run: false

# Comment out the swanlab section if you don't wish to use swanlab for logging.
# Otherwise, update it according to your account info and its setup on the cluster.
#swanlab:
#  name: ${run_name}
#  project: olmo-medium
#  group: OLMo-7B

# Model architecture and initialization settings.
model:
  d_model: 4096
  n_heads: 32
  n_layers: 32
  mlp_hidden_size: 16384 # default: 4 * d_model
  weight_tying: false
  # Positional scheme: rope is enabled, alibi is disabled.
  alibi: false
  rope: true
  flash_attention: false
  attention_dropout: 0.0
  attention_layer_norm: true
  multi_query_attention: false
  include_bias: false
  block_type: sequential
  layer_norm_type: default
  layer_norm_with_affine: true
  bias_for_layer_norm: false
  attention_layer_norm_with_affine: true
  activation_type: gelu
  # All dropout rates in this section are set to 0.0.
  residual_dropout: 0.0
  embedding_dropout: 0.0
  max_sequence_length: 2048
  vocab_size: 32100           # depends on the tokenizer
  embedding_size: 32128  # 251 * 128; presumably vocab_size padded up to a multiple of 128 — confirm
  # NOTE(review): the special-token ids below presumably also depend on the
  # tokenizer (see the `tokenizer` section) — confirm when changing it.
  eos_token_id: 1
  pad_token_id: 0
  init_device: cuda
  init_fn: mitchell

# NOTE(review): the inline comment below warns that compile causes instability
# on AMD GPUs, yet a compile section is configured here. Confirm whether it
# should be disabled (e.g. `compile: null`) when running on AMD hardware.
compile: # causes instability on AMD GPUs
  fullgraph: false

# Activation checkpointing mode.
activation_checkpointing: whole_layer

# Optimizer settings.
optimizer:
  name: adamw
  learning_rate: 3.0e-4
  weight_decay: 0.1
  # Two-element list of beta coefficients; written as an indented block
  # sequence, the same list style used by `evaluators` in this file.
  betas:
    - 0.9
    - 0.95
  metrics_log_interval: 100

# Learning-rate schedule and gradient-clipping warmup.
# NOTE(review): t_warmup and grad_clip_warmup_steps are both 1000, while
# max_duration in this file is 515. If all three are measured in steps, the
# run ends before warmup completes — confirm this is intended.
scheduler:
  name: cosine_with_warmup
  t_warmup: 1000
  alpha_f: 0.1
  grad_clip_warmup_steps: 1000
  grad_clip_warmup_factor: 10.0

# Tokenizer selection. model.vocab_size is documented in this file as
# depending on the tokenizer, so keep the two in agreement.
tokenizer:
  truncate_direction: "right"
  identifier: "t5-base"

# Checkpoint saving and loading.
# save_folder is interpolated from the CHECKPOINTS_PATH environment variable
# (OmegaConf `oc.env` resolver).
save_folder: ${oc.env:CHECKPOINTS_PATH}
remote_save_folder: null
save_overwrite: false
# Sharded checkpoints (best for restarts)
save_interval: 100  # also reused by eval_interval via ${save_interval}
save_num_checkpoints_to_keep: -1
# Unsharded checkpoints (for final storage)
save_interval_unsharded: null
save_num_unsharded_checkpoints_to_keep: -1

# No checkpoint path is set here (null).
load_path: null

# Training length and batch sizing.
# Token budget check: 128 * 2048 * 515 = 135,004,160 tokens (~135M).
max_duration: 515  # 135M tokens (global_train_batch_size * max_sequence_length * max_duration)
global_train_batch_size: 128 # 4 GPUs * 32
device_train_microbatch_size: 32
time_limit: null

# Numeric precision mode for training.
precision: amp_bf16

# Distributed training strategy, followed by the per-strategy option groups.
distributed_strategy: fsdp

fsdp:
  sharding_strategy: FULL_SHARD
  wrapping_strategy: by_block
  precision: mixed

# NOTE(review): presumably only consulted when distributed_strategy is ddp —
# confirm.
ddp:
  grad_sync_mode: batch

# Gradient-norm limits.
max_grad_norm_ratio: null
max_grad_norm: 1.0

# Throughput monitoring.
speed_monitor:
  window_size: 20

# Evaluation cadence and batch sizing. Both interpolated values below follow
# other keys in this file, so changing those keys changes these too.
eval_interval: ${save_interval}
eval_subset_num_batches: 5 # limit how many batches to evaluate on (set to -1 for all batches)
device_eval_batch_size: ${device_train_microbatch_size}

# List of evaluators; every active entry is a `downstream` task identified by
# its label. Entries that are commented out state their reason inline.
evaluators:

  ##########################
  # Downstream evaluations #
  ##########################
  - label: piqa
    type: downstream

  - label: hellaswag
    type: downstream

  - label: winogrande
    type: downstream

  - label: openbook_qa
    type: downstream

  # - label: boolq  # requires implementation of the pmi_dc matrix
  #   type: downstream

  - label: sciq
    type: downstream

  - label: arc_easy
    type: downstream

  # - label: arc_challenge  # requires implementation of the pmi_dc matrix
  #   type: downstream

  - label: copa
    type: downstream

  - label: rte
    type: downstream

  - label: commitment_bank
    type: downstream

  - label: mrpc
    type: downstream

  - label: sst2
    type: downstream

# Data loading settings.
data:
  pad_direction: right
  num_workers: 2
  drop_last: true
  pin_memory: true
  prefetch_factor: 1
  persistent_workers: true
  timeout: 0
  # NOTE(review): the values under custom_dataset look like template
  # placeholders ("<module>.<class>", "value1", ...) — replace them with the
  # real dataset class path and its arguments before running.
  custom_dataset:
    name: "<module>.<class>"
    args:
      arg1: "value1"
      arg2: "value2"
      arg3: "value3"
    collate_config:
      input_id_field: "token_ids" # The field in the dataset that contains the input ids