############################################################################################
# NOTE: This uses the MPT _architecture_, but not the HuggingFace pretrained model!        #
# Use this YAML to train directly from a Composer checkpoint.                              #
# This is not the right YAML if you are trying to finetune a HuggingFace pretrained model. #
############################################################################################

# Shared values; other keys in this file pull them in via `${variables.<name>}`
variables:
  global_seed: 17  # consumed by the top-level `seed`
  max_seq_len: 2048  # consumed by the model, tokenizer and dataloader sections

# Top-level sequence length, taken from the shared value above
max_seq_len: ${variables.max_seq_len}

# Run Name
run_name:  # If left blank, will be read from env var $RUN_NAME

# Model
# These must match pretraining (this YAML trains from an existing checkpoint,
# so the architecture below has to agree with it)
model:
  name: mpt_causal_lm
  init_device: meta
  d_model: 2048
  # d_head = d_model / n_heads = 2048 / 16 = 128
  n_heads: 16  # Modified 24->16 so that d_head == 128 to satisfy FlashAttention
  n_layers: 24
  expansion_ratio: 4
  max_seq_len: ${variables.max_seq_len}
  vocab_size: 50368
  attn_config:
    attn_impl: flash
    # Set this to `true` if using `train_loader.dataset.packing_ratio` below
    attn_uses_sequence_id: false

# Tokenizer
tokenizer:
  name: EleutherAI/gpt-neox-20b
  kwargs:
    # Same shared sequence length that the model section uses
    model_max_length: ${variables.max_seq_len}

# Local data to load into huggingface datasets
# Anchored as `hf_dataset` so the dataloader `dataset` sections can merge it in with `<<`
dataset: &hf_dataset
  hf_name: json
  hf_kwargs:
    # Placeholder path: replace with your own data directory
    data_dir: /path/to/data/dir/
  # Placeholder import path: replace with your own preprocessing function
  preprocessing_fn: my.import.path:my_preprocessing_fn

# Dataloaders
# Anchored as `train_loader`; `eval_loader` merges this mapping in with `<<`
train_loader: &train_loader
  name: finetuning
  dataset:
    # Start from the shared `hf_dataset` settings, then add the train-specific keys
    <<: *hf_dataset
    split: train
    max_seq_len: ${variables.max_seq_len}
    allow_pad_trimming: false
    decoder_only_format: true
    shuffle: true
    # # Use packing_ratio: 'auto' to automatically profile and select the highest observed packing ratio with
    # # zero waste. In practice, this may result in > 0 waste because profiling is done on only a portion
    # # of the dataset.
    # # Or use `python llmfoundry/scripts/misc/profile_packing.py --yaml-path /path/to/this/yaml/ ...`
    # # to profile this run's optimal packing_ratio as it depends on GPU count,
    # # batch size, sequence length
    # packing_ratio:
  drop_last: true
  num_workers: 8
  pin_memory: false
  prefetch_factor: 2
  persistent_workers: true
  timeout: 0

# Evaluation dataloader: inherits the loader-level settings (name, workers,
# drop_last, ...) from `train_loader`.
eval_loader:
  <<: *train_loader
  # NOTE: `<<` is a shallow merge. Defining `dataset` here replaces
  # `train_loader.dataset` wholesale, so every dataset key the eval loader
  # needs must be restated below rather than assumed to be inherited.
  dataset:
    <<: *hf_dataset
    split: validation
    max_seq_len: ${variables.max_seq_len}
    allow_pad_trimming: false
    decoder_only_format: true
    # Keep evaluation order fixed
    shuffle: false

# Optimization
scheduler:
  name: linear_decay_with_warmup  # linear no warmup is HF default which dolly used
  t_warmup: 0ba  # zero-length warmup, i.e. the "no warmup" case described above
  alpha_f: 0

optimizer:
  # Hyperparameters copy the HF defaults so that the run replicates dolly
  name: decoupled_adamw
  lr: 1.0e-5
  betas: [0.9, 0.999]
  eps: 1.0e-8
  weight_decay: 0

# Training algorithms applied on top of the optimizer
algorithms:
  # Clip gradients by norm at a threshold of 1.0
  gradient_clipping:
    clipping_type: norm
    clipping_threshold: 1.0

# Training duration, evaluation cadence and global batch size
# NOTE(review): the `ep` / `ba` suffixes are presumably epochs / batches — confirm against the trainer docs
max_duration: 1ep
eval_interval: 50ba
eval_first: false
# NOTE(review): -1 presumably means "use every eval batch" — confirm
eval_subset_num_batches: -1
global_train_batch_size: 128

# System
seed: ${variables.global_seed}  # shared seed defined under `variables`
device_eval_batch_size: 4
device_train_microbatch_size: 4
# Alternative to the fixed microbatch size above (use one or the other):
# device_train_microbatch_size: auto
precision: amp_bf16

# FSDP
# Sharding, mixed-precision and activation-memory options passed as `fsdp_config`
fsdp_config:
  sharding_strategy: FULL_SHARD
  mixed_precision: PURE
  # Activation checkpointing and CPU offload are both switched off here
  activation_checkpointing: false
  activation_checkpointing_reentrant: false
  activation_cpu_offload: false
  limit_all_gathers: true

# Logging
# Progress bar is disabled; console logging is enabled with an interval of `1ba`
progress_bar: false
log_to_console: true
console_log_interval: 1ba

# Callbacks; an empty mapping (`{}`) means no arguments are supplied in this file
callbacks:
  speed_monitor:
    window_size: 10
  lr_monitor: {}
  memory_monitor: {}
  runtime_estimator: {}

# loggers:
#   wandb: {}

# Checkpoint to local filesystem or remote object store
# save_interval: 2000ba
# save_num_checkpoints_to_keep: 1  # Important, this cleans up checkpoints saved to DISK
# save_folder: ./{run_name}/checkpoints
# save_folder: s3://my-bucket/my-folder/{run_name}/checkpoints

# Load from remote object store
# REPLACE THE BELOW with your own checkpoint!
load_path: oci://my-bucket/my-folder/mpt-1b/checkpoints/some_checkpoint.pt
load_weights_only: true  # Only load the weights, not the optimizer state, LR schedule, etc.
