# torchtitan Config.toml
# NOTE: this TOML config is a preset for training Llama 3 405B on 128 H100 GPUs.

# Job description and dump folder.
[job]
description = "Llama 3 405B training"
dump_folder = "./outputs"

# Profiler settings.
[profiling]
enable_profiling = true
profile_freq = 100
# Relative path. NOTE(review): presumably resolved under job.dump_folder - confirm.
save_traces_folder = "profile_trace"

# Metrics logging settings.
[metrics]
enable_tensorboard = true
save_tb_folder = "tb"
# NOTE(review): presumably counted in training steps - confirm.
log_freq = 10

# Model selection and tokenizer location.
[model]
name = "llama3"
flavor = "405B"
tokenizer_path = "./torchtitan/datasets/tokenizer/tiktoken/tokenizer.model"
# Options: layernorm / np_layernorm / rmsnorm / fused_rmsnorm
norm_type = "rmsnorm"

# Optimizer choice and learning rate.
[optimizer]
lr = 8.0e-5
name = "AdamW"

# Training settings: batch shape, schedule, parallelism, compilation, and data.
[training]
# Batch shape.
seq_len = 8192
batch_size = 2
# Schedule: warmup is 600 of 3000 steps, i.e. the usual 20% of the train steps.
steps = 3000
warmup_steps = 600  # lr scheduler warm up
max_norm = 1.0  # grad norm clipping
# Parallelism degrees.
tensor_parallel_degree = 8  # 8-way TP
data_parallel_replicate_degree = 1
# NOTE(review): -1 is presumably a "derive automatically" sentinel - confirm.
data_parallel_shard_degree = -1
# NOTE(review): presumably toggles model compilation - confirm.
compile = true
# Dataset name and location. The path is absolute; adjust it if the data lives elsewhere.
dataset = "c4"
dataset_path = "/mnt/cephfs/dataset/c4"

# Experimental parallelism features.
[experimental]
enable_async_tensor_parallel = true
# NOTE(review): a degree of 1 presumably leaves these two dimensions unused - confirm.
pipeline_parallel_degree = 1
context_parallel_degree = 1

# Checkpointing settings. enable_checkpoint is false in this preset.
# NOTE(review): the remaining keys presumably only matter once it is enabled - confirm.
[checkpoint]
enable_checkpoint = false
folder = "checkpoint/llama3_405B"
interval = 500
interval_type = "steps"
# Options: "disabled", "async", "async_with_pinned_mem"
async_mode = "disabled"
export_dtype = "float32"
model_weights_only = false

# Activation checkpointing mode.
[activation_checkpoint]
# Options: "none", "selective", "full"
mode = "full"

# Float8 settings.
[float8]
enable_float8_linear = true
# FSDP-related float8 options.
precompute_float8_dynamic_scale_for_fsdp = true
enable_fsdp_float8_all_gather = true
