Keywords: LLM inference, KV cache compression
TL;DR: Inference-aware attention variants utilize modern hardware efficiently for fast decoding
Abstract: For large batches and long contexts, LLM decoding is bottlenecked by loading the KV cache from high-bandwidth memory, which raises per-token latency, while its sequential nature limits parallelism. We redesign attention to perform more computation per byte of memory transfer, maximizing hardware efficiency without sacrificing parallel scalability. We first present \textit{Grouped-Tied Attention} (GTA), which merges and reuses key and value states to reduce memory traffic without affecting quality. Next, we introduce \textit{Grouped Latent Attention} (GLA), a parallel-friendly latent attention enhanced with low-level optimizations for fast decoding at high quality. Experiments show that GTA matches Grouped Query Attention (GQA) quality while using roughly half the KV cache, and GLA matches Multi-head Latent Attention (MLA) yet shards more easily. Our optimized GLA kernel is up to $2\times$ faster than FlashMLA in speculative decoding once the query length exceeds one.
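The abstract gives no implementation details, but a minimal illustrative sketch of the "tied KV with grouped queries" idea may help convey why the cache shrinks: if a single cached tensor is reused as both keys and values and shared across a group of query heads (GQA-style), the decode-time cache holds one state per KV head instead of separate K and V. This is an assumption-heavy sketch in PyTorch, not the authors' actual GTA formulation; all names and shapes are hypothetical.

```python
import torch
import torch.nn.functional as F

# Illustrative sketch only: a single cached tensor serves as both keys and values,
# so the cache stores one state per KV head (roughly half of GQA's separate K and V).

def tied_grouped_attention(q, kv, n_kv_heads):
    """
    q:  (batch, n_q_heads, q_len, head_dim)   -- query states
    kv: (batch, n_kv_heads, kv_len, head_dim) -- single tied key/value cache
    Each group of n_q_heads // n_kv_heads query heads shares one tied KV head.
    """
    b, n_q_heads, q_len, d = q.shape
    group = n_q_heads // n_kv_heads
    # Expand the tied cache across the query heads in each group (GQA-style sharing).
    kv_exp = kv.repeat_interleave(group, dim=1)        # (b, n_q_heads, kv_len, d)
    scores = q @ kv_exp.transpose(-2, -1) / d ** 0.5   # reuse kv as keys
    probs = F.softmax(scores, dim=-1)
    return probs @ kv_exp                              # reuse kv as values

# Toy usage: 8 query heads share 2 tied KV heads during single-token decode.
q = torch.randn(1, 8, 1, 64)
kv = torch.randn(1, 2, 128, 64)
out = tied_grouped_attention(q, kv, n_kv_heads=2)
print(out.shape)  # torch.Size([1, 8, 1, 64])
```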
Submission Number: 160