S2-ATTENTION: HARDWARE-AWARE CONTEXT SHARDING AMONG ATTENTION HEADS

Published: 05 Mar 2025, Last Modified: 05 Mar 2025
Venue: SLLM
License: CC BY 4.0
Track: long paper (up to 9 pages)
Keywords: Hardware Innovation for Sparsity; kernel; efficient training; efficient inference
TL;DR: The paper pinpoints why existing sparse attention methods are rarely used in production and presents an optimized sparse attention kernel, as well as insights on how to train and serve such models.
Abstract: Sparse attention, which selectively attends to a subset of tokens in the context, has been an established approach to enhance the efficiency of Transformers. However, its theoretical reduction in FLOPs has rarely translated into wall-clock speed-ups over dense attention counterparts, mainly due to the lack of hardware-level optimizations like FlashAttention (Dao, 2023). Meanwhile, it remains unclear whether sparse attention can maintain model quality at the scale of today’s large language models (LLMs), and how this can be achieved. This paper presents Sparsely-Sharded Attention (S2-ATTENTION), an optimized Triton kernel library providing a variety of customizable sparse attention implementations for both training and inference. S2-ATTENTION allows customizing attention patterns at the per-head, per-context-range level. The fresh insights from S2-ATTENTION inspire a novel sparse attention architecture, the Head-Heterogeneous Strided Transformer (HHST), that meets several desiderata we find crucial for achieving both practical efficiency gains and strong accuracy on downstream tasks. For higher sparsity, HHST shards the context heterogeneously across attention heads, where each head attends to a different subset of tokens while collectively covering the whole. We evaluate HHST by pretraining 1.3B- and 7B-parameter models. For attention computation, HHST with S2-ATTENTION achieves 8.8× and 15.9× wall-clock attention speedups, as well as 2.8× and 2.5× reductions in training time, compared to a dense attention baseline implemented with FlashAttention-2. Moreover, HHST’s downstream task performance is on par with dense attention, and it achieves perfect retrieval accuracy at a 128K context length at the 7B scale. At inference, our 7B HHST achieves a 4.5× speed-up over its dense counterpart in vLLM. S2-ATTENTION is released with easy-to-customize APIs for direct usage in Megatron and vLLM.
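
To make the head-heterogeneous sharding idea concrete, the following is a minimal PyTorch sketch, not the released S2-ATTENTION Triton kernels or their API. It builds per-head boolean masks in which each head attends to a different strided shard of the context plus a local window, so the heads collectively cover all tokens. The function and parameter names (`hhst_style_masks`, `num_shards`, `local_window`) and the exact strided pattern are illustrative assumptions, not the paper's specification.

```python
# Illustrative sketch only: plain PyTorch masking of a head-heterogeneous
# strided pattern. The real S2-ATTENTION library implements such patterns as
# optimized Triton kernels; names and the pattern here are assumptions.
import torch


def hhst_style_masks(seq_len: int, num_heads: int, num_shards: int, local_window: int) -> torch.Tensor:
    """Return a [num_heads, seq_len, seq_len] boolean mask (True = may attend).

    Each head attends causally to (a) a local window and (b) a strided shard of
    earlier tokens selected by its head index, so different heads see different
    context shards while all shards are covered collectively.
    """
    i = torch.arange(seq_len).view(-1, 1)   # query positions
    j = torch.arange(seq_len).view(1, -1)   # key positions
    causal = j <= i
    local = causal & (i - j < local_window)
    masks = []
    for h in range(num_heads):
        shard = causal & ((j % num_shards) == (h % num_shards))
        masks.append(local | shard)
    return torch.stack(masks, dim=0)


def sparse_attention(q, k, v, masks):
    """q, k, v: [batch, heads, seq, dim]; masks: [heads, seq, seq]."""
    scores = torch.einsum("bhqd,bhkd->bhqk", q, k) / q.shape[-1] ** 0.5
    scores = scores.masked_fill(~masks.unsqueeze(0), float("-inf"))
    return torch.einsum("bhqk,bhkd->bhqd", scores.softmax(dim=-1), v)


if __name__ == "__main__":
    B, H, S, D = 1, 8, 256, 64
    q, k, v = (torch.randn(B, H, S, D) for _ in range(3))
    masks = hhst_style_masks(S, H, num_shards=8, local_window=32)
    out = sparse_attention(q, k, v, masks)
    print(out.shape)  # torch.Size([1, 8, 256, 64])
```

Note that this sketch materializes dense masks, so it illustrates only the attention pattern, not the wall-clock gains; those come from the hardware-aware Triton kernels described in the paper.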
Anonymization: This submission has been anonymized for double-blind review via the removal of identifying information such as names, affiliations, and identifying URLs.
Presenter: ~Xihui_Lin1
Format: Yes, the presenting author will attend in person if this work is accepted to the workshop.
Funding: Yes, the presenting author of this submission falls under ICLR’s funding aims, and funding would significantly impact their ability to attend the workshop in person.
Submission Number: 18
