# torchtitan Config.toml
# NOTE: this TOML config is a preset for 64 A100 GPUs; the batch-size and
# parallelism settings below were presumably chosen for that world size.

[job]
# Human-readable label for this run.
description = "Llama 3 1B training"
# Root output directory; the relative folders set elsewhere in this file
# (profile_trace, tb, checkpoint) are presumably resolved under it.
dump_folder = "./outputs"

[profiling]
# Profiling is off; the two keys below presumably only take effect once it
# is enabled.
enable_profiling = false
profile_freq = 100  # presumably in training steps
save_traces_folder = "profile_trace"

[metrics]
log_freq = 10  # presumably in training steps
disable_color_printing = false

# Both TensorBoard and Weights & Biases logging are enabled.
enable_tensorboard = true
save_tb_folder = "tb"
enable_wandb = true

[model]
name = "llama3"
flavor = "1B_compressed_vanilla"
# NOTE(review): a GPT-2 tokenizer is paired with the llama3 architecture;
# confirm the "1B_compressed_vanilla" flavor's vocab size matches it.
tokenizer_path = "./assets/tokenizer/gpt2"
# Float8 conversion is disabled; uncomment to enable (options live in [float8]).
# converters = ["float8"]

[optimizer]
name = "AdamW"
# Presumably the peak learning rate, shaped over training by [lr_scheduler].
lr = 8e-3
eps = 1.0e-8
# Uncomment to override AdamW's beta2; when omitted, the framework default
# presumably applies.
# beta2 = 0.8

[lr_scheduler]
# Warm up for the first 2_000 steps (20% of the 10_000 training steps set in
# [training]), then decay linearly. min_lr_factor = 0.0 presumably lets the
# learning rate decay all the way to zero.
warmup_steps = 2_000
decay_type = "linear"
min_lr_factor = 0.0

[training]
dataset = "fineweb"
seq_len = 1024
compile = false
# Per-rank micro-batch vs. optimizer-step batch (presumably both in sequences).
# On the 64-GPU preset with TP = PP = CP = 1, 64 * 8 = 512 sequences per
# micro-step, so 2048 presumably implies 4 gradient-accumulation steps.
local_batch_size = 8
global_batch_size = 2048
# 2048 * 1024 tokens per step over 10_000 steps is roughly 21B tokens.
steps = 10_000

[gradient_clipping]
# NOTE(review): this table mixes plain norm clipping (method, max_norm) with
# what look like adaptive z-score spike-detection settings (alpha, z_thresh,
# mode, clip_option, ...). With method = "vanilla", the adaptive keys may be
# unused; confirm against the consumer.
method = "vanilla"
scope = "global"
max_norm = 1.0
max_norm_last_layer = 1.0
# Presumably the EMA smoothing factor and the z-score spike threshold.
alpha = 0.97
z_thresh = 2.5
# NOTE(review): max_grad_norm repeats max_norm's value; confirm which key the
# active method actually reads.
max_grad_norm = 1.0
eps = 1e-6
# Separate from [lr_scheduler].warmup_steps; presumably the number of steps
# before spike statistics are trusted.
warmup_steps = 25
mode = "zscore"
clip_option = "adaptive_scaling"
clip_factor = 1.0
skip_update_on_spike = false

[parallelism]
# Data parallelism: replicate degree 8 x shard degree -1. A value of -1
# presumably means "infer from world size", i.e. 64 / 8 = 8 on the 64-GPU
# preset. Tensor, pipeline and context parallel degrees are all 1.
data_parallel_replicate_degree = 8
data_parallel_shard_degree = -1
tensor_parallel_degree = 1
pipeline_parallel_degree = 1
context_parallel_degree = 1
# Presumably only relevant when pipeline_parallel_degree > 1.
pipeline_parallel_microbatch_size = 1
pipeline_parallel_schedule = "GPipe"

# PowerSGD gradient compression is disabled; the remaining powersgd_* keys
# presumably take effect only when enable_powersgd = true.
enable_powersgd = false
powersgd_matrix_approximation_rank = 128
powersgd_start_iter = 2
powersgd_min_compression_rate = 0.5
powersgd_use_error_feedback = false
powersgd_warm_start = true
powersgd_random_seed = 123456

[checkpoint]
# Checkpointing is disabled; the settings below presumably apply only once
# it is enabled.
enable_checkpoint = false
folder = "checkpoint"
interval = 500  # presumably in training steps
last_save_model_only = true
export_dtype = "float32"
# One of "disabled", "async", "async_with_pinned_mem".
async_mode = "disabled"

[activation_checkpoint]
# One of "none", "selective", "full".
mode = "full"
# Option for "selective" mode (presumably ignored while mode = "full"):
# a positive integer N given as a string (e.g. "2") applies activation
# checkpointing every N layers; "op" applies it based on an ops policy.
selective_ac_option = "op"

[float8]
# NOTE(review): float8 is not active as configured, because the
# converters = ["float8"] line in [model] is commented out. These options
# presumably have no effect until it is enabled.
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = ["output"]  # presumably modules excluded from float8 conversion
