
# Job-level settings for the Flux debug run.
# NOTE(review): the other *_folder keys in this file are bare relative names;
# presumably they are resolved underneath dump_folder — confirm in the consumer.
[job]
dump_folder = "./outputs"
description = "Flux debug model"
print_args = false

# Profiler settings. Both enable_* flags are false in this config.
# NOTE(review): the folder/frequency keys presumably take effect only when the
# matching enable_* flag is true — confirm.
[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 10
enable_memory_snapshot = false
save_memory_snapshot_folder = "memory_snapshot"

# Metrics/logging settings. enable_tensorboard and enable_wandb are both false
# here, so neither reporting backend is switched on by this file.
[metrics]
log_freq = 1
disable_color_printing = false
enable_tensorboard = false
save_tb_folder = "tb"
enable_wandb = false

# Model selection: the "flux" model in its "flux-debug" flavor.
[model]
name = "flux"
flavor = "flux-debug"

# Optimizer choice and hyperparameters.
[optimizer]
name = "AdamW"
lr = 8e-4
eps = 1e-8

# Learning-rate schedule. Keep warmup_steps in sync with training.steps if the
# 10% ratio is meant to hold.
[lr_scheduler]
warmup_steps = 1  # 10% of training.steps (10)
decay_ratio = 0.0  # no decay, stay stable during training

# Training settings for a short debug run (10 steps on the "cc12m-test"
# dataset). lr_scheduler.warmup_steps is documented as 10% of `steps`.
[training]
local_batch_size = 4
max_norm = 2.0  # grad norm clipping
steps = 10
dataset = "cc12m-test"
classifier_free_guidance_prob = 0.447
img_size = 256

# Text encoders (T5 and CLIP) and the image autoencoder.
# NOTE(review): the two *_encoder values look like Hugging Face model IDs —
# confirm how the consumer resolves them.
[encoder]
t5_encoder = "google/t5-v1_1-xxl"
clip_encoder = "openai/clip-vit-large-patch14"
max_t5_encoding_len = 256
# Path to the autoencoder weights file (safetensors) used for images.
autoencoder_path = "torchtitan/experiments/flux/assets/autoencoder/ae.safetensors"

# Data-parallel layout.
# NOTE(review): -1 is presumably a sentinel meaning "derive the shard degree
# from the available devices" — confirm against the consumer.
[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1

# Dotted Python module path.
# NOTE(review): presumably imported to supply the Flux-specific config sections
# used in this file (e.g. [encoder], [validation], [inference]) — confirm.
[experimental]
custom_args_module = "torchtitan.experiments.flux.job_config"

# Activation checkpointing mode for this run.
[activation_checkpoint]
mode = "full"

# Checkpointing is disabled in this config (enable = false); the remaining keys
# are kept so it can be switched on without further edits.
[checkpoint]
enable = false
folder = "checkpoint"
interval = 10
last_save_model_only = false
export_dtype = "float32"
async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]

# Validation is disabled in this config (enable = false).
# Note that local_batch_size here (8) differs from training.local_batch_size (4).
[validation]
enable = false
dataset = "coco-validation"
freq = 5
local_batch_size = 8
steps = 48  # Recommended value with the current settings and world_size=8
# args for sampling images
enable_classifier_free_guidance = true
classifier_free_guidance_scale = 5.0
denoising_steps = 4
save_img_count = 1
save_img_folder = "img"
all_timesteps = false

# Inference settings: where generated images go, the prompt file to read, and
# the per-device batch size. save_img_folder here is a separate key from
# validation.save_img_folder.
[inference]
save_img_folder = "inference_results"
prompts_path = "./torchtitan/experiments/flux/inference/prompts.txt"
local_batch_size = 2
