
# Job-level settings for this run.
[job]
# Relative path. NOTE(review): presumably the root under which the other
# *_folder values in this file are created — confirm against the consumer.
dump_folder = "./outputs"
# Free-form label; matches model.flavor ("flux-schnell") below.
description = "Flux-schnell model"
print_args = false

# Profiling options. Both enable_* flags are false in this config.
# NOTE(review): the folder and frequency values presumably only take effect
# once the matching flag is switched on — confirm.
[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
# NOTE(review): unit is presumably training steps — confirm.
profile_freq = 10
enable_memory_snapshot = false
save_memory_snapshot_folder = "memory_snapshot"

# Metrics/logging options. TensorBoard and W&B are both disabled here.
[metrics]
# NOTE(review): unit is presumably training steps — confirm.
log_freq = 100
disable_color_printing = false
enable_tensorboard = false
# NOTE(review): presumably only used when enable_tensorboard = true — confirm.
save_tb_folder = "tb"
enable_wandb = false

# Model selection: a model family name plus a flavor within that family.
[model]
name = "flux"
# Same variant named in job.description.
flavor = "flux-schnell"

# Optimizer choice and hyperparameters.
[optimizer]
name = "AdamW"
# Learning rate. NOTE(review): presumably the peak value reached after the
# warm-up configured in [lr_scheduler] — confirm.
lr = 1e-4
eps = 1e-8

# Learning-rate schedule: warm-up length and decay portion.
[lr_scheduler]
warmup_steps = 3_000  # lr scheduler warm up; the usual guideline is 20% of the train steps, but this is 10% of training.steps (30_000)
decay_ratio = 0.0  # no decay

# Core training-loop settings.
[training]
# NOTE(review): unclear from this file whether this is per-device or global — confirm.
batch_size = 4
seq_len = 512
max_norm = 1.0  # grad norm clipping
steps = 30_000
compile = false
dataset = "cc12m-wds"
# NOTE(review): the key is spelled "classifer" (not "classifier"). Keys are
# matched by exact name, so do not correct the spelling here unless the code
# that reads this key (see experimental.custom_args_module) is changed too.
classifer_free_guidance_prob = 0.1
# NOTE(review): presumably pixels per image side — confirm.
img_size = 256

# Text encoders and the image autoencoder.
[encoder]
# NOTE(review): these two values look like Hugging Face Hub model IDs — confirm
# how the consumer resolves them (hub download vs. local path).
t5_encoder = "google/t5-v1_1-xxl"
clip_encoder = "openai/clip-vit-large-patch14"
max_t5_encoding_len = 256
# Relative path. NOTE(review): presumably resolved from the directory the job
# is launched in — confirm.
autoencoder_path = "torchtitan/experiments/flux/assets/autoencoder/ae.safetensors"  # Autoencoder to use for image

# Evaluation / sampling settings.
# NOTE(review): two keys below are spelled "classifer" (not "classifier"),
# consistent with training.classifer_free_guidance_prob. Keys are matched by
# exact name, so keep the spelling in sync with the code that reads them.
[eval]
enable_classifer_free_guidance = true
classifer_free_guidance_scale = 5.0
denoising_steps = 50
save_img_folder = "img"
# NOTE(review): unit is presumably training steps — confirm.
eval_freq = 1000

# Data-parallel layout.
[parallelism]
data_parallel_replicate_degree = 1
# NOTE(review): -1 is presumably a sentinel meaning "derive from the number of
# available devices" — confirm against the consumer before relying on it.
data_parallel_shard_degree = -1

# Experimental hooks.
[experimental]
# Dotted module path. NOTE(review): presumably the module that declares the
# extra options used in this file (e.g. [encoder], [eval]) — confirm.
custom_args_module = "torchtitan.experiments.flux.flux_argparser"

# Activation checkpointing.
[activation_checkpoint]
# NOTE(review): the other accepted values are not listed in this file — check
# the consumer before changing.
mode = "full"

# Checkpoint saving. Disabled in this config (enable_checkpoint = false).
# NOTE(review): the remaining keys presumably only take effect once it is
# enabled — confirm.
[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
# NOTE(review): unit is presumably training steps — confirm.
interval = 1_000
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]
