# torchtitan Config.toml

# Job identification and output location.
[job]
dump_folder = "./outputs"  # relative path
description = "Llama2 360M training"

# Profiling is turned off for this run.
[profiling]
enable_profiling = false
# Presumably only consulted when enable_profiling = true — verify.
save_traces_folder = "profile_trace"
profile_freq = 100  # presumably in training steps — TODO confirm

# Metrics and experiment-tracking settings.
[metrics]
log_freq = 100
enable_color_printing = false
# TensorBoard logging; save_tb_folder is a relative folder name.
enable_tensorboard = true
save_tb_folder = "tb"
# Weights & Biases logging.
# NOTE(review): the label below says "standard LR", but [optimizer] sets
# lr = 1e-5, which the notes there call "smaller" (standard is listed as 1e-4),
# and [experimental] has enable_gws = true. Confirm the label is still accurate.
enable_wandb = true
wandb_comment = "vanilla standard LR"

# Model architecture and tokenizer.
[model]
name = "llama2"
flavor = "360M"
norm_type = "fused_rmsnorm"  # [layernorm, np_layernorm, rmsnorm, fused_rmsnorm]
# Relative path; judging by the path, a SentencePiece tokenizer model file.
tokenizer_path = "./torchtitan/datasets/tokenizer/sentencepiece/tokenizer.model"

# Optimizer selection and learning rate.
[optimizer]
name = "AdamW"
# Alternative optimizer (currently disabled):
# name = "Adam-mini"
# Learning-rate presets:
#   standard 1e-4
#   smaller  1e-5  (currently selected)
#   alt      3e-6
#   smallest 1e-6
lr = 1e-5

#   batch | seq_len | GPUs | tokens / step | steps(~300B tok) | steps(warmup)
# 1.   24 |    2048 |    8 |       393,216 |          750,000 |         2,500
# 2.   24 |    2048 |    4 |       196,608 |        1,500,000 |         5,000
# 3.   12 |    2048 |    8 |       196,608 |        1,500,000 |         5,000
# 4.   12 |    2048 |    4 |        98,304 |        3,000,000 |        10,000
# tokens / step = batch (per GPU) * seq_len * GPUs
# Stick to configuration 3 for portability.

# Run shape: matches configuration 3 in the table above
# (batch 12 | seq_len 2048 | 1,500,000 steps | 5,000 warmup steps).
[training]
dataset = "c4"
dataset_path = "/mnt/cephfs/dataset/c4"  # absolute, machine-specific path
seed = 42
batch_size = 12  # local batch size
seq_len = 2048
steps = 1_500_000
warmup_steps = 5_000
max_norm = 1.0  # grad norm clipping
compile = false
# Parallelism layout.
# NOTE(review): -1 presumably asks the framework to derive the shard degree
# from the number of available ranks — confirm against the torchtitan config docs.
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 1

# Experimental features.
[experimental]
context_parallel_degree = 1   # disabled
pipeline_parallel_degree = 1  # disabled
# NOTE(review): "gws" and "cod" are not expanded anywhere in this file; add a
# pointer to where these features are documented.
# The gws_* values presumably only take effect while enable_gws = true — verify.
enable_gws = true
gws_lambda = 1e-4
gws_init = 6.0
gws_target = 4.0
enable_cod = false

# Checkpointing is currently switched off.
[checkpoint]
enable_checkpoint = false
folder = "checkpoint/llama2_360M"
# NOTE(review): interval (50,000,000) is larger than [training] steps
# (1,500,000); if it counts training steps, no periodic checkpoint would fire
# during the run even with enable_checkpoint = true — confirm before enabling.
interval_type = "steps"
interval = 50_000_000
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled"  # one of: "disabled", "async", "async_with_pinned_mem"

# Activation checkpointing (AC).
[activation_checkpoint]
# One of: "none", "selective", "full".
# With Adam-mini, "none" can be used and will be faster.
mode = "none"
# Either an integer N (AC applied every N layers) or "op" (AC chosen by an
# ops policy).
selective_ac_option = "op"

# Float8 linear layers are turned off for this run.
[float8]
enable_float8_linear = false
