Keywords: Attention Mask Efficient Representation, Efficient Attention Computation, Long context, IO complexity, GPUs, LLMs
Abstract: The computational and memory demands of vanilla attention scale quadratically with the sequence length $N$, posing significant challenges for processing long sequences in Transformer models. FlashAttention alleviates these challenges by eliminating the $\mathcal{O}(N^2)$ memory dependency and reducing attention latency through IO-aware memory optimizations. However, it natively supports only a limited set of attention mask types and does not accommodate more complex masking requirements. Previous approaches resort to dense masks with $\mathcal{O}(N^2)$ memory complexity, which is inefficient. In this paper, we propose \ours{}, an extension of FlashAttention that introduces a column-wise sparse representation of attention masks. This representation efficiently covers a wide range of mask types and facilitates the development of optimized kernel implementations. With it, \ours{} achieves linear memory complexity $\mathcal{O}(N)$, making it suitable for modeling long-context sequences. Moreover, the representation enables kernel optimizations that exploit sparsity in the attention mask to skip unnecessary computations without sacrificing computational accuracy, resulting in higher computational efficiency. We evaluate \ours{} on LLM fine-tuning and alignment training tasks such as SFT, LoRA, DPO, and RM. \ours{} achieves significant throughput improvements, with end-to-end speedups ranging from 1.65x to 3.22x compared to the existing FlashAttention dense-mask method. Additionally, our kernel-level comparisons demonstrate that \ours{} surpasses the latest counterpart, FlexAttention, by 12.1\% to 60.7\% in terms of kernel TFLOPs/s, achieving 37.8\% to 62.3\% of the theoretical maximum FLOPs/s on the A100 GPU.
The code is open-sourced on PaddlePaddle\footnote{\url{https://github.com/PaddlePaddle/Paddle}} and integrated into PaddleNLP\footnote{\url{https://github.com/PaddlePaddle/PaddleNLP}}, supporting models with over 100 billion parameters for contexts extending up to 128K tokens.
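To make the idea of a column-wise sparse mask concrete, the following is a minimal sketch in plain Python. It is not the paper's kernel or its actual data layout: the names `lt_start`, `lt_end`, `causal_document_columns`, `to_dense`, and `tile_fully_masked` are hypothetical, and the example assumes a causal mask over packed documents, where each key column only needs one contiguous interval of masked query rows below the diagonal. Two length-$N$ vectors then replace the $N \times N$ dense mask, and a simple conservative check shows how a tile could be skipped when it is entirely masked.

```python
def causal_document_columns(doc_lens):
    """Build an assumed column-wise mask for packed documents.

    For key column j, query rows in [lt_start[j], lt_end[j]) are masked
    below the diagonal. Rows above the diagonal are masked by causality.
    Storage is two vectors of length N instead of an N x N matrix.
    """
    n = sum(doc_lens)
    lt_start, lt_end = [], []
    doc_end = 0
    for length in doc_lens:
        doc_end += length
        # Every key in this document is hidden from queries of later documents.
        lt_start.extend([doc_end] * length)
        lt_end.extend([n] * length)
    return lt_start, lt_end


def to_dense(lt_start, lt_end):
    """Expand to a dense boolean mask (True = may attend), for checking only."""
    n = len(lt_start)
    return [
        [j <= i and not (lt_start[j] <= i < lt_end[j]) for j in range(n)]
        for i in range(n)
    ]


def tile_fully_masked(lt_start, lt_end, row_lo, row_hi, col_lo, col_hi):
    """Conservative test: True only if the tile is certainly fully masked.

    The tile covers query rows [row_lo, row_hi) and key columns [col_lo, col_hi).
    """
    # Case 1: the whole tile lies strictly above the diagonal.
    if row_hi - 1 < col_lo:
        return True
    # Case 2: every column's masked interval covers the whole row range.
    return (max(lt_start[col_lo:col_hi]) <= row_lo
            and min(lt_end[col_lo:col_hi]) >= row_hi)


if __name__ == "__main__":
    starts, ends = causal_document_columns([3, 2])
    for row in to_dense(starts, ends):
        print("".join("1" if ok else "0" for ok in row))
    # A kernel could skip the tile where the second document queries the first.
    print(tile_fully_masked(starts, ends, 3, 5, 0, 3))
```

The tile check is deliberately conservative: it may report `False` for a tile that is in fact fully masked for mixed reasons, which costs some wasted work but never changes the result.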
Supplementary Material: zip
Primary Area: infrastructure, software libraries, hardware, systems, etc.
Code Of Ethics: I acknowledge that I and all co-authors of this work have read and commit to adhering to the ICLR Code of Ethics.
Submission Guidelines: I certify that this submission complies with the submission instructions as described on https://iclr.cc/Conferences/2025/AuthorGuide.
Anonymous Url: I certify that there is no URL (e.g., github page) that could be used to find authors’ identity.
No Acknowledgement Section: I certify that there is no acknowledgement section in this submission for double blind review.
Submission Number: 9666