# torchtitan Config.toml

[job]
# Run metadata: a human-readable label plus the root output directory.
description = "Llama2 1B training"
dump_folder = "./outputs"

[profiling]
# Profiling is switched off; the frequency and trace folder below are kept
# so it can be re-enabled by flipping a single flag.
enable_profiling = false
profile_freq = 100
save_traces_folder = "profile_trace"

[metrics]
log_freq = 100
enable_tensorboard = true
save_tb_folder = "tb"
enable_color_printing = false
enable_wandb = true
# NOTE(review): this label says "standard LR", but [optimizer] sets lr = 1e-5,
# which the optimizer notes list as the "smaller" preset (standard = 1e-4).
# Confirm which one is intended before launching, so the run is not mislabelled.
wandb_comment = "vanilla standard LR"

[model]
name = "llama2"
flavor = "1B"
tokenizer_path = "./torchtitan/datasets/tokenizer/sentencepiece/tokenizer.model"
# Choices: layernorm, np_layernorm, rmsnorm, fused_rmsnorm
norm_type = "fused_rmsnorm"

[optimizer]
# Alternative optimizer: name = "Adam-mini"
name = "AdamW"
# Learning-rate presets:
#   standard 1e-4 | smaller 1e-5 | alt 3e-6 | smallest 1e-6
lr = 1e-5  # the "smaller" preset

#   batch | seq_len | GPUs | tokens / step | steps(~300B tok) | steps(warmup)
# 1.    8 |    2048 |    4 |        65,536 |        4,577,636 |        15,258
# 2.    4 |    2048 |    8 |        65,536 |        4,577,636 |        15,258
# 3.    8 |    2048 |    8 |       131,072 |        2,288,818 |         7,629
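# tokens/step = batch * seq_len * GPUs; warmup steps = total steps / 300.
# Configs 1 and 2 are the portable choices (config 1 for 4x A100, config 2 for
# other 8-GPU machines); config 3 targets 8x A100 for the 300B-token scaling run.
# The [training] settings below follow config 3 with 2.1M total steps and 7k
# warmup steps, i.e. ~275B tokens in total.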

[training]
# Matches config 3 of the table above when run on 8 GPUs:
# 8 seqs/GPU * 2048 tokens * 8 GPUs = 131,072 tokens/step,
# and 2,100,000 steps * 131,072 tokens/step ~= 275B tokens.
batch_size = 8  # local batch size
seq_len = 2048
max_norm = 1.0  # grad norm clipping
steps = 2_100_000
warmup_steps = 7_000  # steps // 300
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 1
compile = false
dataset = "c4"
dataset_path = "/mnt/cephfs/dataset/c4"
seed = 42

[experimental]
# Parallelism: a degree of 1 leaves context/pipeline parallelism disabled.
context_parallel_degree = 1
pipeline_parallel_degree = 1
# GWS settings (project-specific; semantics live in the consuming code).
enable_gws = true
gws_lambda = 1e-4
gws_init = 8.0
gws_target = 6.0
# Other experimental switches, all currently off.
is_diffq = false
enable_cod = false
quant_no_noise = false

[checkpoint]
enable_checkpoint = false
folder = "checkpoint/llama2_1B"
interval_type = "steps"
# NOTE(review): 50M is far beyond [training].steps (2,100,000), so presumably
# no interval checkpoint would be saved even if enable_checkpoint were turned
# on; lower this if periodic checkpoints are wanted.
interval = 50_000_000
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
# One of "none", "selective", "full". "none" skips activation recomputation,
# which is faster; the original note says it is viable with Adam-mini.
# NOTE(review): [optimizer] currently uses AdamW, not Adam-mini; confirm that
# "none" still fits in memory for this setup.
mode = "none"
# Selective AC policy: "op" = checkpoint based on the ops policy;
# a positive integer string N = checkpoint every N-th layer.
selective_ac_option = "op"

[float8]
# Float8 linear layers: off.
enable_float8_linear = false
