
# Job-level settings for this run.
[job]
# NOTE(review): relative path; presumably the root under which the other
# relative output folders named in this file are created — confirm.
dump_folder = "./outputs"
description = "Flux-schnell model"
print_args = false

# Profiling options. Both switches (enable_profiling, enable_memory_snapshot)
# are off in this config, so the folder/frequency values below presumably
# have no effect until one of them is turned on — confirm.
[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 100  # presumably counted in training steps — confirm
enable_memory_snapshot = false
save_memory_snapshot_folder = "memory_snapshot"

# Metrics and logging options. Both the TensorBoard and wandb backends are
# disabled in this config.
[metrics]
log_freq = 100  # presumably counted in training steps — confirm
disable_color_printing = false
enable_tensorboard = false
save_tb_folder = "tb"  # presumably only used when enable_tensorboard = true — confirm
enable_wandb = false

# Model selection: the "flux" model in its "flux-schnell" flavor.
[model]
name = "flux"
flavor = "flux-schnell"

# Optimizer choice and hyperparameters.
[optimizer]
name = "AdamW"
lr = 1e-4  # learning rate
eps = 1e-8

# Learning-rate schedule: warm-up only, with decay turned off.
[lr_scheduler]
# NOTE(review): the stated guideline is 20% of the train steps, but 3_000 is
# 10% of training.steps (30_000) — confirm which one is intended.
warmup_steps = 3_000  # lr scheduler warm up, normally 20% of the train steps
decay_ratio = 0.0  # no decay

# Training-loop settings.
[training]
local_batch_size = 64  # presumably per rank, not global — confirm
max_norm = 1.0  # grad norm clipping
steps = 30_000
dataset = "cc12m-wds"
# NOTE(review): semantics not shown in this file; presumably a probability in
# [0, 1] applied per sample for classifier-free guidance — confirm.
classifier_free_guidance_prob = 0.447
img_size = 256  # presumably pixels — confirm

# Text encoders and image autoencoder used alongside the model.
[encoder]
# NOTE(review): the two encoder names look like model-hub identifiers
# ("org/name"); how they are resolved is not shown here — confirm.
t5_encoder = "google/t5-v1_1-xxl"
clip_encoder = "openai/clip-vit-large-patch14"
max_t5_encoding_len = 256
# Relative path to a .safetensors file; presumably resolved from the
# repository root — confirm.
autoencoder_path = "torchtitan/experiments/flux/assets/autoencoder/ae.safetensors"  # Autoencoder to use for images


# Data-parallel layout.
[parallelism]
data_parallel_replicate_degree = 1
# NOTE(review): -1 is presumably a sentinel meaning "derive the shard degree
# from the available world size" — confirm against the consumer.
data_parallel_shard_degree = -1

# Experimental hooks.
[experimental]
# Dotted module path; presumably imported to supply the extra, flux-specific
# config fields used in this file — confirm.
custom_args_module = "torchtitan.experiments.flux.job_config"

# Activation checkpointing policy.
[activation_checkpoint]
mode = "full"  # NOTE(review): other accepted values are not listed in this file

# Checkpointing options. Checkpointing is disabled in this config
# (enable = false); the remaining keys presumably apply only once it is
# enabled — confirm.
[checkpoint]
enable = false
folder = "checkpoint"
interval = 1_000  # presumably counted in training steps — confirm
last_save_model_only = true
export_dtype = "float32"
async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]

# Validation options. Validation is disabled in this config (enable = false);
# the remaining keys presumably apply only once it is enabled — confirm.
[validation]
enable = false
dataset = "coco-validation"
local_batch_size = 64
steps = 6  # Recommended value with the current settings and world_size=8
freq = 1000  # presumably counted in training steps — confirm
enable_classifier_free_guidance = true
classifier_free_guidance_scale = 5.0
denoising_steps = 50
save_img_count = 50
save_img_folder = "img"
all_timesteps = false
