# Shared top-level values. Other keys in this file pull them in through
# ${...} interpolation.
# max_seq_len is referenced by model.max_seq_len,
# tokenizer.kwargs.model_max_length, and train_loader.dataset.max_seq_len.
max_seq_len: 512
# global_seed is referenced by the top-level `seed` key below.
global_seed: 17
# Left empty in this file. model.pretrained_model_name_or_path and
# tokenizer.name both interpolate it, so it presumably has to be supplied
# before the config is usable.
model_name_or_path: #TODO

load_path:  # set via bash script to be absolute path to your sparse checkpoint
precision: amp_bf16

# Left empty in this file (see the TODO marker).
max_duration: # TODO
# NOTE(review): the eval_loader section of this file is commented out;
# confirm eval_interval / eval_first still have the intended effect.
eval_interval: 1ep
eval_first: false
# Same value as global_seed, via interpolation.
seed: ${global_seed}

# Left empty in this file (see the TODO marker).
global_train_batch_size: #TODO
# Reference values for device_train_microbatch_size recorded by the authors.
# for mpt-7b dense:
# 4 x A100_80GB = "device_train_microbatch_size: 12"
# 8 x A6000_48GB = "device_train_microbatch_size: 6"

# for mpt-7b sparse (with masks):
# 8 x A6000_48GB = "device_train_microbatch_size: 4"
# NOTE(review): the value configured below (16) is larger than every reference
# value listed above; confirm it fits the model and hardware actually in use.
device_train_microbatch_size: 16
device_eval_batch_size: 16

# Run Name
# Also interpolated by callbacks.had_hf_checkpointer.save_folder in this file.
run_name: # If left blank, will be read from env var $RUN_NAME

# Model definition. The checkpoint identifier and the sequence length are not
# repeated here; they are interpolated from the top-level keys of this file.
model:
  name: 'hf_causal_lm'
  pretrained: true
  pretrained_model_name_or_path: '${model_name_or_path}'
  max_seq_len: '${max_seq_len}'
  output_hidden_states: true
  dtype: 'bfloat16'

# Quantization settings, split into three groups (fwd, bwd1, bwd2) plus two
# global switches. Every had_* / quant_* entry is currently the string 'none';
# the values are quoted so they cannot be mistaken for a YAML null.
quant_config:
  fwd:
    had_x: 'none'  # options: none, right, left, both
    had_w: 'none'
    quant_x: 'none'  # options: none, global, block_32, row_wise, col_wise
    quant_w: 'none'
  bwd1:
    had_e: 'none'
    had_w: 'none'
    quant_e: 'none'
    quant_w: 'none'
  bwd2:
    had_e: 'none'
    had_x: 'none'
    quant_e: 'none'
    quant_x: 'none'
  simulate: true
  kernel: 'simulated'  # other options noted: kernel, fp8_pure, fp8_hey

# Tokenizer
# Uses the same identifier as the model and the shared max_seq_len, both taken
# from the top-level keys via interpolation.
tokenizer:
  name: '${model_name_or_path}'
  kwargs:
    model_max_length: '${max_seq_len}'

# Dataloaders
# Training data: the `train` split of the `gsm8k` dataset (config `main`),
# run through the preprocessing function named below.
train_loader:
  name: 'finetuning'
  dataset:
    hf_name: 'gsm8k'
    split: 'train'
    hf_kwargs:
      name: 'main'
    preprocessing_fn: 'preprocessing:gsm8k_preprocessing_function'
    max_seq_len: '${max_seq_len}'
    allow_pad_trimming: false
    decoder_only_format: true
    shuffle: true
    # Sequence packing is disabled. To enable it, profile this run's optimal
    # packing_ratio (it depends on GPU count, batch size, sequence length) with
    #   python llmfoundry/data/packing.py --yaml-path /path/to/this/yaml/ ...
    # and then uncomment:
    # packing_ratio: 13 # padding=0.36%, waste=0.79%
  # Loader-level options. Note that `shuffle` is set both here and under
  # `dataset` above.
  shuffle: true
  drop_last: false
  num_workers: 8
  pin_memory: false
  prefetch_factor: 2
  persistent_workers: true
  timeout: 0

# eval_loader:
#   name: finetuning
#   dataset:
#     hf_name: gsm8k
#     split: test
#     hf_kwargs:
#       name: main
#     preprocessing_fn: preprocessing:gsm8k_preprocessing_function
#     max_seq_len: ${max_seq_len}
#     allow_pad_trimming: false
#     decoder_only_format: true
#     # packing_ratio: 13
#     shuffle: false
#   drop_last: false
#   num_workers: 8
#   pin_memory: false
#   prefetch_factor: 2
#   persistent_workers: true
#   timeout: 0

# Optimization
# Learning-rate schedule: warmup length is given as the string '20ba' and the
# final multiplier alpha_f is 0.
scheduler:
  name: 'linear_decay_with_warmup'
  t_warmup: '20ba'
  alpha_f: 0

# Optimizer settings. `lr` is left empty in this file (see the TODO marker).
optimizer:
  name: 'decoupled_adamw'
  lr: # TODO
  betas: [0.9, 0.999]
  eps: 1.0e-8
  weight_decay: 0.0

# Gradient clipping is disabled for sparse training runs: there is no way to
# mask the gradients of pruned weights, so the global gradient norm would be
# incorrect.
# algorithms:
#   gradient_clipping:
#     clipping_type: norm
#     clipping_threshold: 1.0

# FSDP
fsdp_config:
  # Sharding and precision modes (string enums).
  sharding_strategy: 'FULL_SHARD'
  mixed_precision: 'FULL'

  # Activation checkpointing is on, in reentrant mode.
  activation_checkpointing: true
  activation_checkpointing_reentrant: true

  # Remaining toggles.
  forward_prefetch: true
  activation_cpu_offload: false
  limit_all_gathers: true
  verbose: false

# Logging
# Progress bar off, console logging on, with an interval of '1ba'.
progress_bar: false
log_to_console: true
console_log_interval: '1ba'

# Callbacks, keyed by name. An empty mapping ({}) means no options are set
# for that callback in this file.
callbacks:
  speed_monitor:
    window_size: 10
  lr_monitor: {}
  memory_monitor: {}
  runtime_estimator: {}
  had_hf_checkpointer:
    overwrite: true
    # Left empty in this file (see the TODO marker).
    precision: # TODO
    # NOTE(review): this interpolates the top-level run_name, which is blank
    # in this file; confirm run_name is populated before this path is
    # resolved, otherwise the folder name may not be what is intended.
    save_folder: checkpoints/${run_name}
    save_interval: 1dur

# Loggers, keyed by name. wandb is listed with no options set here.
loggers:
  wandb: {}

# # Checkpoint to local filesystem or remote object store
# save_interval: 1ep
# save_num_checkpoints_to_keep: 1  # Important, this cleans up checkpoints saved to DISK
# save_folder: output_dir/{run_name}/checkpoints
# save_overwrite: true