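# 256-node setup
#
# The peak batch is ~16M tokens (world_batch_size 4096 x block_size 4096 = 16,777,216 tokens),
# which is a batch size of 2 per accelerator, so there is no room for a batch-size ramp-up.
# This results in a total of only 59604 steps to reach 1T tokens, but each step will be long.
# NOTE(review): max_tokens below is set to 4T, not 1T -- confirm which token budget the step counts refer to.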

# Main settings: run identity, token budget, RNG seed and output location.
run_name: baseline
resume: true
# 4T tokens. Written without `_` digit separators on purpose: YAML 1.1 loaders accept
# underscores inside integers, but the YAML 1.2 core schema does not and would load
# the value as a string.
max_tokens: 4000000000000

seed: 233
out_dir: /FILE_LOCATION/recllm_runs/baseline/

# Model configuration: which model to build and the per-run overrides applied to it.
model_name: nebel-raven-3.5b # replaced
model_impl: dynamic
block_size: 4096
# Entries under model_overwrite are overrides for the selected model's settings.
model_overwrite:
  # This attention implementation assumes that the AMD fork of the repo is installed,
  # as described in the install guide.
  attn_impl: tridao
  use_fused_head: pytorch
  mup_model_scaling_factor: 1
  lockstep_n: true
  lockstep_k: true
  # Recurrence and backprop depth are both 1 with the non-recurrent sampling scheme.
  mean_recurrence: 1
  mean_backprop_depth: 1
  sampling_scheme: non-recurrent

# Training hyperparameters: batch size, optimizer settings and learning-rate schedule.
world_batch_size: 4096 # sequences per step; up to ~16M tokens at block_size 4096
batch_size_ramp: 0 # no room for a ramp-up
optimizer: ELLISAdam
optim_config:
  # lr and weight_decay are written with an explicit decimal point (4.0e-5, not 4e-5).
  # YAML 1.1 loaders such as PyYAML only resolve exponent notation to a float when the
  # mantissa contains a ".", so the dot-less form can reach the optimizer as the string
  # "4e-5". The dotted form is a float under every loader.
  lr: 4.0e-5 # let's be careful here
  weight_decay: 4.0e-5 # decoupled!
  betas:
    - 0.9
    - 0.95 # conservative
  # eps: 1.0e-8 # removed as a factor
  update_clipping: true
  atan_adam: true
  running_init: true
  tensor_wise_finite_check: false # true is a recipe for wasting compute
  tensor_wise_gradient_normalization: false
  decouple_wd: true

min_lr: 0
# Warmup and cooldown are each 4096 steps, i.e. ~6.8% of the 59604 steps quoted in the
# file header. Both are step counts and therefore depend on the world size.
warmup_steps: 4096 # ~6.8% warmup ##### world-size dependent #####
lr_schedule: trapezoid
cooldown_steps: 4096 # ~6.8% cooldown ##### world-size dependent #####
no_weight_decay_for_bias_and_norm_params: true

# Regularization
# Set to 0 for this run; presumably a weight of 0 disables the z-regularization term -- TODO confirm.
z_regularization: 0

# Implementation and backend: distribution strategy, compilation and failure handling.
fabric_strategy: simple-ddp # fsdp-hybrid2 broken
fabric:
  optim_sharding: true
micro_batch_size: 2 # matches the batch size of 2 per accelerator noted in the file header
compile_model: true
dynamo_ddp_config: python_reducer
gradient_checkpointing: true # not the worst with SAC-attn
loss_guardrail_active: false
skip_nonfinite_grads: false
# Turn this on for large-node setups.
# NOTE(review): the file header describes a 256-node setup, yet this is off -- confirm it is intentional.
fail_instead_of_recompile: false

# Logging, checkpointing and evaluation cadence.
logger_project: 'incite-arch-CLUSTER'
model_telemetry: true # telemetry on attention turned off for now
shape_watching_steps: 0
# Intervals below are in steps; the wall-clock figures are the author's estimates.
log_step_interval: 4 # every 30 secs
save_step_interval: 128 # 15 minutes ##### world-size dependent #####
eval_step_interval: 128 # 15 minutes? ##### world-size dependent #####
eval_iters: 1 # validation is distributed
save_n_min_before_job_done: 5 # maybe we're brave enough for a 5min interval
save_last_step: true

# Depths passed to partial-depth evaluation.
partial_depth_eval:
  - 1
  - 4
  - 8
  - 16
  - 64

# Data handling: block-size padding behaviour, dataset sources and tokenizer location.
all_block_size_tensors: true
pad_to_block_size: true
data_config:
  # One pqds-pure source for training, with an empty prefix and weight 1.
  train_data:
    - type: pqds-pure
      prefix: ''
      weight: 1
      data_dir: '/FILE_LOCATION/language_datasets/processed/recllm_project_v02_pqds_reshuffled_1/train'
  # Validation uses the "val" directory of the same dataset.
  val_data:
    - type: pqds-pure
      prefix: ''
      weight: 1
      data_dir: '/FILE_LOCATION/language_datasets/processed/recllm_project_v02_pqds_reshuffled_1/val'

tokenizer_path: '/FILE_LOCATION/language_models/tokenizers/MODEL_tokenizer_65k'
