# torchtitan Config.toml
# Test configuration for 2D (Ulysses x Ring) context parallelism, using CPU and activation offloading to reduce memory usage

[job]
dump_folder = "/data"
description = "Llama 3 8B 2D context parallel finetune fineweb"  # NOTE(review): mentions fineweb, but training.dataset is openchat_sharegpt4_dataset — confirm

[profiling]
enable_profiling = true
save_traces_folder = "cp_yunchang_256k_2u8r"  # NOTE(review): name says 256k, but training.seq_len is 524288 (512K) — confirm
profile_freq = 1000

[metrics]
log_freq = 100
enable_tensorboard = false
enable_wandb = true
# save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "8B"
tokenizer_path = "/data"
attn_impl = "yunchang_zigzag_skipkv"
ring_comm_heads = "gqa_kv"

[optimizer]
name = "AdamW"
lr = 3e-4
eps = 1e-8

[lr_scheduler]
warmup_steps = 250  # lr scheduler warm up

[training]
enable_cpu_offload = true  # FSDP CPU offloading (analogous to ZeRO-3 offloading)
batch_size = 1
seq_len = 524288  # 512K tokens (2^19) per sequence
max_norm = 1.0  # grad norm clipping
steps = 10000  # Reduced for testing
compile = false
dataset = "openchat/openchat_sharegpt4_dataset"
seed = 0  # Fixed seed for reproducibility
# deterministic = false  # Enable deterministic algorithms for exact reproducibility
gc_freq = 100

[evaluation]
enable_eval = true           # Enable periodic evaluation (NOTE: still needs fixing)
eval_freq = 500              # Run evaluation every 500 steps
num_eval_samples = 10        # Use 10 fixed samples for evaluation (the c4-test split used for validation has few 512K-token sequences)


[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 1
pipeline_parallel_degree = 1
context_parallel_ulysses_degree = 2  # Ulysses dimension
context_parallel_degree = 8  # Ring dimension
context_parallel_rotate_method = "allgather"  # Method used for ring passing (alternative: "alltoall")

[checkpoint]
enable_checkpoint = true
folder = "/data"
interval = 1000
model_weights_only = true
keep_latest_k = 5
export_dtype = "float32"
async_mode = "async"

[activation_checkpoint]
mode = "selective"  # One of "none", "selective", "full"
offloading = "UAO"  # One of "no", "UAO" (Unsloth's async activation offloading)
selective_ac_option = "op"  # Selective "op" AC retains only the flash-attention outputs and recomputes everything else

[float8]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = ["output"]