# test yaml

# Run identity: model selector string and seed value.
model_name: "flashi2v"
seed: 1024

# Output location for this run (relative path).
output_dir: "output/flashi2v_1_3b"

# Training length and parallelism/memory switches.
# Booleans are written as canonical lowercase true/false so every YAML
# loader (1.1 and 1.2) resolves them to real booleans.
# NOTE(review): how fsdp_size and cp_size combine into the required number of
# ranks is decided by the consuming trainer -- confirm before changing either.
training_iteration: 1000000
# ddp_size is left disabled (commented out) in this config.
# ddp_size: 1
fsdp_size: 8
cp_size: 2
use_context_parallel: true
reshard_after_forward: false
gradient_checkpointing: true
gradient_accumulation_steps: 1
init_max_grad_norm: 1.0
# Logging / checkpoint cadence; presumably counted in training iterations
# -- TODO confirm against the trainer.
log_interval: 1
save_interval: 1000
weight_dtype: "bf16"
ema_decay: 0.9999
ema_update_interval: 1
save_with_dcp_api: true

# Experiment-tracking (wandb) settings: project name and experiment name.
# Both are set to the same string in this config.
wandb_config:
  project_name: "flashi2v"
  exp_name: "flashi2v"

# Model hyperparameters plus the checkpoint path given by
# pretrained_model_dir_or_checkpoint (here a .safetensors file, relative path).
# NOTE(review): text_len (512) has the same value as
# text_encoder_config.text_len and
# data_config.dataset_config.tokenizer_max_length; presumably these must stay
# in sync -- confirm against the model code.
model_config:
  dim: 1536
  ffn_dim: 8960
  freq_dim: 256
  in_dim: 16
  num_heads: 12
  num_layers: 30
  out_dim: 16
  text_len: 512
  # Two-element list; presumably a [min, max] range -- TODO confirm.
  low_freq_energy_ratio: [0.05, 0.95]
  # Canonical lowercase boolean (was `True`).
  fft_return_abs: true
  pretrained_model_dir_or_checkpoint: "Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors"

# Scheduler selection by name, plus two feature switches (both enabled).
# Booleans use canonical lowercase true/false.
scheduler_config:
  scheduler_name: "flashi2v_flow_matching"
  use_dynamic_shifting: true
  use_logitnorm_time_sampling: true

# VAE checkpoint path (relative) and the dtype string configured for it.
# NOTE(review): dtype is "fp32" here while the top-level weight_dtype is
# "bf16"; looks intentional, but confirm against the VAE loader.
vae_config:
  vae_path: "Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"
  dtype: "fp32"

# Text encoder settings: text length, checkpoint path (relative), and the
# use_fsdp switch (enabled). Boolean uses canonical lowercase true/false.
# NOTE(review): text_len matches model_config.text_len (512); presumably they
# must agree -- confirm.
text_encoder_config:
  text_len: 512
  checkpoint_path: "Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth"
  use_fsdp: true

# Data loading settings: loader options, dataset selection with its nested
# dataset_config, and the sampler/collator selected by name.
# Booleans use canonical lowercase true/false.
data_config:
  # NOTE(review): presumably a per-rank batch size -- TODO confirm.
  batch_size: 1
  num_workers: 16
  pin_memory: false
  drop_last: true
  shuffle: true
  dataset_name: "i2v_random"
  # Options for the dataset selected by dataset_name above.
  dataset_config:
    text_tokenizer_path: "Wan2.1-T2V-1.3B/google/umt5-xxl"
    sample_height: 480
    sample_width: 832
    sample_num_frames: 49
    # Same value (512) as model_config.text_len and
    # text_encoder_config.text_len.
    tokenizer_max_length: 512
    return_prompt_mask: true
  sampler_name: "stateful_distributed"
  collator_name: "flashi2v"
  
# Optimizer hyperparameters: learning rate (0.00002, i.e. 2e-5) and weight decay.
# NOTE(review): no optimizer type is set in this section; presumably the
# trainer picks one by default -- confirm.
optimizer_config:
  lr: 0.00002
  weight_decay: 0.01
