# torchtitan Config.toml
# Test configuration for 2D context parallelism with reduced memory usage

# Job-level settings: output location and a human-readable run description.
[job]
dump_folder = "/data" # NOTE(review): original note here was "full_exps" — presumably an alternative dump folder; confirm
description = "Llama 3 8B 2D context parallel test"

# Profiler and memory-snapshot settings; both are disabled in this config.
[profiling]
enable_profiling = false
# The trace and snapshot folders share one name. It appears to mirror this
# run's setup (seq_len = 256000, 4 Ulysses x 4 ring in [parallelism]);
# update it if those values change.
save_traces_folder = "cp_yunchang_256k_4u4r"
enable_memory_snapshot = false
save_memory_snapshot_folder = "cp_yunchang_256k_4u4r"
# Presumably measured in training steps — TODO confirm.
profile_freq = 25

# Metrics logging: Weights & Biases is enabled, TensorBoard is disabled.
[metrics]
# Presumably measured in training steps — TODO confirm.
log_freq = 10
enable_tensorboard = false
enable_wandb = true
# TensorBoard is disabled above, so this folder is presumably unused — TODO confirm.
save_tb_folder = "tb"

# Model selection, tokenizer location, and attention implementation.
[model]
name = "llama3"
flavor = "8B"
# Relative path; presumably resolved against the launch directory — TODO confirm.
tokenizer_path = "./assets/tokenizer/original/tokenizer.model"
# NOTE(review): the accepted values for attn_impl and ring_comm_heads are not
# documented in this file; confirm them against the config parser.
attn_impl = "yunchang_zigzag"
ring_comm_heads = "mha_kv"

# Optimizer hyperparameters (warmup is configured in [lr_scheduler]).
[optimizer]
name = "AdamW"
lr = 3e-4
eps = 1e-8

# Learning-rate schedule.
[lr_scheduler]
warmup_steps = 10  # LR scheduler warmup, in steps

# Training-loop settings for the long-context test run.
[training]
enable_cpu_offload = false # FSDP CPU offloading (analogous to ZeRO-3 offloading)
batch_size = 1
# NOTE(review): this comment previously read "Reduced from 20000 to save
# memory", which contradicts the value (256000 is larger than 20000). The
# "256k" in the [profiling] folder names suggests 256000 is intended — confirm.
seq_len = 256000  # sequence length in tokens
max_norm = 1.0  # gradient-norm clipping threshold
steps = 100  # reduced for testing
compile = false
dataset = "c4_test"
seed = 0  # fixed seed for reproducibility
# deterministic = false  # Enable deterministic algorithms for exact reproducibility
# Presumably measured in training steps — TODO confirm.
gc_freq = 50
chunked_loss = true
backend = "torch"

# Periodic evaluation settings; evaluation is currently disabled.
[evaluation]
enable_eval = false           # Periodic evaluation; disabled because it still needs fixing
eval_freq = 1              # Run evaluation every step (unit presumably training steps — TODO confirm)
max_new_tokens = 80         # Generate up to 80 new tokens
num_eval_samples = 5         # Use 5 fixed samples for evaluation


# Parallelism layout. Context parallelism is two-dimensional here: a Ulysses
# dimension and a ring dimension, each of degree 4.
[parallelism]
data_parallel_replicate_degree = 1
# NOTE(review): -1 presumably means "derive from the remaining ranks" — confirm.
data_parallel_shard_degree = -1
tensor_parallel_degree = 1
pipeline_parallel_degree = 1
context_parallel_ulysses_degree = 4  # Ulysses dimension
context_parallel_degree = 4  # Ring dimension
context_parallel_rotate_method = "allgather" # Method used for ring passing; alternative: "alltoall"

# Checkpointing is disabled in this config; the remaining keys presumably only
# take effect once enable_checkpoint is set to true — TODO confirm.
[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
# Presumably measured in training steps — TODO confirm.
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled"

# Activation checkpointing and activation offloading; both are off here.
[activation_checkpoint]
mode = "none"  # One of ["none", "selective", "full"]
offloading = "no" # One of ["no", "UAO"]; "UAO" is Unsloth's Async Activation Offloading
selective_ac_option = "op" # Selective "op" retains only the flash-attention output and recomputes everything else

# Float8 options; both FSDP float8 features are disabled in this config.
[float8]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
# NOTE(review): presumably modules whose fully qualified names match these
# entries are excluded from float8 conversion — confirm.
filter_fqns = ["output"] 