# NOTE: This TOML config is a preset for 64 A100 GPUs.

[job]
# Human-readable label for this run.
description = "Llama 3 70B training"
# Root directory for run outputs. The relative folder names used by other
# tables (traces, TensorBoard, checkpoints) are presumably resolved under
# it — confirm against the trainer.
dump_folder = "./outputs"

[profiling]
# Profiling is switched on in this preset.
enable_profiling = true
profile_freq = 100  # capture interval; presumably in training steps — confirm
save_traces_folder = "profile_trace"

[metrics]
log_freq = 10  # logging interval; presumably in training steps
# TensorBoard output is enabled and written to the folder below.
enable_tensorboard = true
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "70B"
# Local directory holding the Hugging Face assets for Llama 3.1 70B.
hf_assets_path = "./assets/hf/Llama-3.1-70B"
# Uncomment to enable float8; its options are in the [quantize.linear.float8]
# table. NOTE(review): confirm the converter name this version expects.
# converters = ["float8"]

[optimizer]
name = "AdamW"
# Base learning rate.
lr = 1.5e-4
eps = 1e-8  # epsilon term passed to AdamW

[lr_scheduler]
# Learning-rate warmup length in steps. Rule of thumb: about 20% of the total
# training steps (here 200 of the 1000 set in [training]).
warmup_steps = 200

[training]
dataset = "c4"
steps = 1000
# Batch size per rank (presumably per data-parallel rank), in sequences of
# seq_len tokens.
local_batch_size = 8
seq_len = 8192
# Gradient-norm clipping threshold.
max_norm = 1.0

[parallelism]
# 8-way tensor parallelism. On the 64-GPU target this leaves 8 ranks for
# data parallelism, assuming the degrees multiply to the world size.
tensor_parallel_degree = 8
# -1 presumably means "fill with the remaining ranks" — confirm against the
# trainer's parallelism documentation.
data_parallel_shard_degree = -1
data_parallel_replicate_degree = 1
pipeline_parallel_degree = 1
context_parallel_degree = 1

[checkpoint]
enable = false
folder = "checkpoint"
interval = 500  # save interval; presumably in training steps
# Presumably the final save keeps model weights only, exported in this dtype.
last_save_model_only = true
export_dtype = "float32"
# One of: "disabled", "async", "async_with_pinned_mem".
async_mode = "disabled"

[activation_checkpoint]
# "full" selects full activation checkpointing, which recomputes activations
# in the backward pass instead of keeping them in memory.
mode = "full"

[compile]
# Compilation is off in this preset; when enabled, it is applied to the
# components listed below.
enable = false
components = ["model", "loss"]

[quantize.linear.float8]
# Presumably only consulted when a float8 converter is enabled under [model].
# Modules matching these FQNs are presumably excluded from float8 conversion.
filter_fqns = ["output"]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false

[validation]
enable = false
dataset = "c4_validation"
freq = 500
# Recommended value for c4_validation with world-size=8 and seq_len=8192.
# NOTE(review): this preset targets 64 GPUs with 8-way TP (presumably 8
# data-parallel ranks); re-check this value if the layout changes.
steps = 1200
