# torchtitan Config.toml
# NOTE: this TOML config is a preset for 64 A100 GPUs
# (8 data-parallel replicas x 8 FSDP shards; see [parallelism]).

[job]
description = "Llama 3 1B training"
# Root output directory; the relative folders configured below (profiling
# traces, tensorboard, checkpoints) presumably resolve under it — confirm
# against the config loader.
dump_folder = "./outputs"

[profiling]
# Trace capture is switched off; the keys after this one presumably only
# matter once enable_profiling is set to true.
enable_profiling = false
profile_freq = 100  # steps between captured traces
save_traces_folder = "profile_trace"

[metrics]
log_freq = 10  # steps between metric log lines
disable_color_printing = false

# Both logging backends are active: TensorBoard and Weights & Biases.
enable_tensorboard = true
save_tb_folder = "tb"
enable_wandb = true

[model]
name = "llama3"
# Project-specific flavor; presumably registered by this project's model
# code — confirm it exists in the llama3 flavor table.
flavor = "1B_compressed"
# NOTE(review): a GPT-2 tokenizer is paired with a llama3 model here; confirm
# the flavor's vocabulary size matches this tokenizer.
tokenizer_path = "./assets/tokenizer/gpt2"
# Uncomment to enable float8 training (tuned by the [float8] table below).
# converters = ["float8"]

[optimizer]
name = "AdamW"
lr = 4e-4  # peak LR, reached after [lr_scheduler] warmup
eps = 1e-8
# NOTE(review): 0.8 is well below the usual AdamW beta2 range (~0.95-0.999);
# confirm this is intentional for this experiment.
beta2 = 0.8

[lr_scheduler]
# Warm up for the first 4000 steps, then follow a cosine decay toward a floor
# of min_lr_factor x optimizer.lr (presumably 0.1 x 4e-4 = 4e-5).
warmup_steps = 4000
decay_type = "cosine"
min_lr_factor = 0.1

[training]
dataset = "fineweb-edu"
seed = 123456
compile = true

# 4 sequences per rank x 64 data-parallel ranks = 256 sequences per step,
# so presumably no gradient accumulation is needed on the 64-GPU preset.
local_batch_size = 4
global_batch_size = 256
seq_len = 2048

# 256 x 2048 = 524_288 tokens/step; 2_384_186 steps ~= 1.25e12 tokens total.
steps = 2_384_186

[gradient_clipping]
# Project-specific clipping section. The code that consumes these keys is not
# visible from this file, so the descriptions below are hedged.
method = "vanilla"
scope = "global"
max_norm = 1.0
max_norm_last_layer = 1.0

# Adaptive / spike-detection settings. Presumably alpha is the decay of a
# running norm statistic and z_thresh the z-score that flags a spike —
# confirm against the clipping implementation.
mode = "zscore"
clip_option = "adaptive_scaling"
clip_factor = 1.0
alpha = 0.97
z_thresh = 2.5
eps = 1e-6
warmup_steps = 25  # independent of [lr_scheduler].warmup_steps
skip_update_on_spike = false

[parallelism]
# Device mesh: 8 replicas, with -1 assigning the remaining ranks to FSDP
# sharding (8 x 8 on the 64-GPU preset). TP, PP and CP are disabled.
data_parallel_replicate_degree = 8
data_parallel_shard_degree = -1
tensor_parallel_degree = 1
pipeline_parallel_degree = 1
context_parallel_degree = 1
# Presumably inert while pipeline_parallel_degree = 1.
pipeline_parallel_microbatch_size = 1
pipeline_parallel_schedule = "GPipe"

# PowerSGD low-rank gradient compression.
enable_powersgd = true
powersgd_matrix_approximation_rank = 128
# If this maps onto PyTorch's PowerSGDState, 2 is the minimum start iteration
# allowed when error feedback or warm start is enabled.
powersgd_start_iter = 2
# NOTE(review): PyTorch's PowerSGD compresses a matrix only when
# (rows + cols) * rank * min_compression_rate < rows * cols. A rate below 1
# therefore allows a "compressed" form larger than the original gradient.
# Confirm that 0.5 is intended.
powersgd_min_compression_rate = 0.5
powersgd_use_error_feedback = true
powersgd_warm_start = true
powersgd_random_seed = 123456
powersgd_error_feedback_reset_frequency = 100

[checkpoint]
enable_checkpoint = true
folder = "checkpoint"
# Save every 100 steps, retaining only the 10 most recent checkpoints.
interval = 100
keep_latest_k = 10
# The final save keeps model weights only, exported as float32.
last_save_model_only = true
export_dtype = "float32"
# One of "disabled", "async", "async_with_pinned_mem".
async_mode = "disabled"

[activation_checkpoint]
# One of "none", "selective", "full". The option below is presumably only
# read when mode = "selective".
mode = "none"
# Selective AC policy: "op" checkpoints according to the op-level policy; a
# positive integer string such as "2" checkpoints every 2nd layer.
selective_ac_option = "op"

[float8]
# Presumably only relevant when "float8" is listed in model.converters,
# which is currently commented out in [model].
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = ["output"]  # module FQNs excluded from float8 conversion
