# Test configuration for Wan2.1 text-to-video (T2V-1.3B) training.

# Run identity and output location.
model_name: 'wan_t2v'
seed: 1024

output_dir: 'output/t2v'  # relative path

# Training-loop, parallelism, saving and EMA settings.
training_iteration: 1000000
# ddp_size is commented out; TODO confirm how the trainer sizes the
# data-parallel group when this key is absent.
# ddp_size: 1
fsdp_size: 8
# NOTE(review): cp_size presumably only takes effect when
# use_context_parallel is true — confirm against the trainer.
cp_size: 2
use_context_parallel: true
reshard_after_forward: false
gradient_checkpointing: true
gradient_accumulation_steps: 1
init_max_grad_norm: 1.0
log_interval: 1
# NOTE(review): save_interval is tiny relative to training_iteration;
# presumably deliberate for this test config — confirm before real runs.
save_interval: 6
weight_dtype: 'bf16'
ema_decay: 0.9999
ema_update_interval: 1
save_with_dcp_api: true

# Experiment-tracking names (presumably Weights & Biases, per the key name).
wandb_config:
  project_name: 't2v'
  exp_name: 't2v'

model_config:
  # Relative path; the trailing slash marks a directory.
  pretrained_model_dir_or_checkpoint: 'Wan2.1-T2V-1.3B/'

scheduler_config:
  scheduler_name: 'flow_matching'
  # Canonical lowercase booleans (same parsed value as the former `True`).
  use_dynamic_shifting: true
  use_logitnorm_time_sampling: true

vae_config:
  # VAE weights file inside the pretrained model directory.
  vae_path: 'Wan2.1-T2V-1.3B/Wan2.1_VAE.pth'
  dtype: 'fp32'  # full precision, unlike the top-level weight_dtype

text_encoder_config:
  # NOTE(review): same value as data_config.dataset_config
  # .tokenizer_max_length; presumably these must stay in sync — confirm.
  text_len: 512
  checkpoint_path: 'Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth'
  use_fsdp: true

data_config:
  batch_size: 1
  num_workers: 16
  # Canonical lowercase booleans (same parsed values as before).
  pin_memory: false
  drop_last: true
  shuffle: true
  dataset_name: 't2v_random'
  dataset_config:
    text_tokenizer_path: 'Wan2.1-T2V-1.3B/google/umt5-xxl'
    sample_height: 480
    sample_width: 832
    sample_num_frames: 49
    tokenizer_max_length: 512
    return_prompt_mask: true
  sampler_name: 'stateful_distributed'
  collator_name: 'wan_t2v'

optimizer_config:
  # Optimizer hyperparameters; lr is 2e-5 written in fixed-point form.
  lr: 0.00002
  weight_decay: 0.01
