Keywords: token sparsity, token pruning, sparsity, pruning, language models
TL;DR: TokenButler uses a lightweight predictor to identify which tokens in LLM memory matter for generation at each step, beating existing methods in accuracy and speed.
Abstract: Large Language Models (LLMs) rely on the Key-Value (KV) Cache to store token history, enabling efficient autoregressive decoding.
As the KV-Cache grows, it becomes a major memory and computation bottleneck. However, there is an opportunity to alleviate this bottleneck: prior research has shown that only a small subset of tokens contributes meaningfully to each decoding step.
A key challenge in finding these _critical tokens_ is that they are dynamic and heavily query-dependent.
Existing methods either risk quality by evicting tokens permanently, or retain the full KV-Cache but rely on retrieving chunks (pages) of tokens at generation time, failing on dense, context-rich tasks.
Additionally, many existing KV-Cache sparsity methods rely on inaccurate proxies for token importance.
To address these limitations, we introduce **TokenButler**, a high-granularity, query-aware predictor that learns to identify these critical tokens.
By training a lightweight predictor with less than $1.2\%$ parameter overhead, TokenButler prioritizes tokens based on their contextual, predicted importance.
This improves perplexity & downstream accuracy by up to $8\%$ relative to state-of-the-art (SoTA) methods for estimating token importance. We evaluate TokenButler on a novel synthetic small-context co-referential retrieval task, demonstrating near-oracle accuracy. Furthermore, we show that TokenButler minimizes the gap to oracle throughput and outperforms prior methods by up to $3\times$. Code, models, dataset, and benchmarks [are available](https://anonymous.4open.science/r/TokenButler-EAEF).
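The abstract describes the mechanism only at a high level. As an illustration of the general idea, the minimal PyTorch sketch below shows query-aware token scoring over a cached KV history followed by top-k sparse attention; it is not TokenButler's actual predictor architecture, and `TokenImportancePredictor`, its MLP scorer, and all sizes are assumptions made for the example.

```python
# Minimal sketch of query-aware critical-token selection (illustrative only;
# not TokenButler's actual predictor). A tiny scorer estimates each cached
# token's importance for the current query, and attention then runs over only
# the top-k tokens instead of the full KV-Cache.
import torch
import torch.nn.functional as F


class TokenImportancePredictor(torch.nn.Module):
    """Hypothetical lightweight scorer: maps (query, key) pairs to scores."""

    def __init__(self, head_dim: int, hidden: int = 32):
        super().__init__()
        # A tiny MLP keeps parameter overhead small relative to the base model.
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(2 * head_dim, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, 1),
        )

    def forward(self, query: torch.Tensor, keys: torch.Tensor) -> torch.Tensor:
        # query: (d,), keys: (T, d) -> one importance score per cached token: (T,)
        q = query.expand(keys.shape[0], -1)
        return self.mlp(torch.cat([q, keys], dim=-1)).squeeze(-1)


def sparse_attention(query, keys, values, predictor, k: int = 64):
    """Attend over only the k cached tokens the predictor deems critical."""
    scores = predictor(query, keys)                 # (T,) predicted importance
    k = min(k, keys.shape[0])
    idx = scores.topk(k).indices                    # indices of critical tokens
    k_sel, v_sel = keys[idx], values[idx]           # (k, d) sparse KV subset
    logits = k_sel @ query / keys.shape[-1] ** 0.5  # (k,) scaled dot products
    return F.softmax(logits, dim=-1) @ v_sel        # (d,) attention output


# Toy usage: 512 cached tokens, head dimension 64, keep 32 critical tokens.
d, T = 64, 512
pred = TokenImportancePredictor(d)
q, K, V = torch.randn(d), torch.randn(T, d), torch.randn(T, d)
out = sparse_attention(q, K, V, pred, k=32)
```

Because the full KV-Cache is retained and only the attention computation is sparsified, a scheme like this avoids the permanent quality loss of eviction-based methods while operating at single-token rather than page granularity.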
Primary Area: generative models
Submission Number: 14135