Keywords: Diffusion Model; Video Generation; Efficient Attention; Sparse Attention
TL;DR: Trainable sparse attention for video diffusion models.
Abstract: Scaling video diffusion transformers (DiTs) is limited by their quadratic 3D attention, even though most of the attention mass concentrates on a small subset of positions. We turn this observation into VSA, a trainable, hardware-efficient sparse attention that replaces full attention at both training and inference. In VSA, a lightweight coarse stage pools tokens into tiles and identifies high-weight critical tokens; a fine stage computes token-level attention only inside those tiles, subject to a block-wise computation layout that ensures hardware efficiency. This yields a single differentiable kernel that trains end-to-end, requires no post-hoc profiling, and sustains 85\% of FlashAttention3's MFU. We perform a large sweep of ablation studies and scaling-law experiments by pretraining DiTs from 60M to 1.4B parameters. VSA reaches a Pareto point that cuts training FLOPs by 2.53$\times$ with no drop in diffusion loss. Retrofitting the open-source Wan2.1-1.3B model with VSA speeds up attention by 6$\times$ and lowers end-to-end generation time from 31s to 18s with comparable quality, while for the 14B model, end-to-end generation time is reduced from 1274s to 576s.
Furthermore, we present a preliminary study of Sparse-Distill, the first method to enable sparse attention and distillation concurrently, achieving a 50.9$\times$ speedup for Wan2.1-1.3B while maintaining quality.
These results establish trainable sparse attention as a practical alternative to full attention and a key enabler for further scaling of video diffusion models. Code is available at https://github.com/hao-ai-lab/FastVideo.
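For illustration, a minimal PyTorch sketch of the coarse-to-fine selection described in the abstract; the function name, tile size, and top-k tile selection here are assumptions for clarity and do not reflect the paper's fused, hardware-efficient block-sparse kernel:

```python
# Minimal sketch of coarse-to-fine block-sparse attention (illustrative only;
# tile_size, topk_tiles, and the per-tile loop are assumptions, not VSA's kernel).
import torch
import torch.nn.functional as F

def coarse_to_fine_attention(q, k, v, tile_size=64, topk_tiles=8):
    # q, k, v: (batch, heads, seq_len, head_dim); seq_len divisible by tile_size.
    B, H, S, D = q.shape
    T = S // tile_size
    topk_tiles = min(topk_tiles, T)

    # Coarse stage: mean-pool tokens into tiles and score tile-to-tile attention.
    q_blocks = q.view(B, H, T, tile_size, D)
    k_blocks = k.view(B, H, T, tile_size, D)
    v_blocks = v.view(B, H, T, tile_size, D)
    tile_scores = q_blocks.mean(3) @ k_blocks.mean(3).transpose(-1, -2) / D**0.5
    critical = tile_scores.topk(topk_tiles, dim=-1).indices  # (B, H, T, topk)

    # Fine stage: token-level attention restricted to the selected key tiles.
    out = torch.zeros_like(q)
    for t in range(T):  # a real kernel runs this as one block-sparse pass, not a loop
        idx = critical[:, :, t]                                   # (B, H, topk)
        gather = idx[..., None, None].expand(-1, -1, -1, tile_size, D)
        k_sel = torch.gather(k_blocks, 2, gather).flatten(2, 3)   # (B, H, topk*tile, D)
        v_sel = torch.gather(v_blocks, 2, gather).flatten(2, 3)
        out.view(B, H, T, tile_size, D)[:, :, t] = F.scaled_dot_product_attention(
            q_blocks[:, :, t], k_sel, v_sel
        )
    return out
```

Because both stages are ordinary tensor operations, gradients flow through the whole computation, which is what allows this style of sparse attention to be trained end-to-end rather than applied only at inference.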
Primary Area: Deep learning (e.g., architectures, generative models, optimization for deep networks, foundation models, LLMs)
Submission Number: 4396