# LycheeDecode: Accelerating Long-Context LLM Inference via Hybrid-Head Sparse Decoding

![Framework of LycheeDecode](assets/framework.png)

## Abstract
The proliferation of long-context large language models (LLMs) exposes a key bottleneck: the rapidly expanding key-value (KV) cache during decoding, which imposes heavy memory and latency costs. While recent approaches attempt to alleviate this by sharing a single set of crucial tokens across layers, such coarse-grained sharing undermines model performance by neglecting the functional diversity of attention heads. To address this, we propose LycheeDecode, an efficient decoding method centered on a fine-grained hybrid-head attention mechanism that employs a hardware-efficient top-$k$ selection strategy. Specifically, a novel HardKuma-based mechanism partitions attention heads into a small subset of retrieval heads, which dynamically identify crucial tokens, and a majority of sparse heads, which reuse those tokens for efficient computation. Through extensive experiments on leading models such as Llama3 and Qwen3 across diverse benchmarks for long-context understanding (e.g., LongBench, RULER) and complex reasoning (e.g., AIME24, OlympiadBench), we demonstrate that LycheeDecode achieves generative quality comparable to, and at times surpassing, the full-attention baseline. Crucially, it does so with up to a 2.7× speedup at a 128K context length. By preserving the functional diversity of attention heads, our fine-grained strategy overcomes the performance bottlenecks of existing methods, providing a validated pathway to long-context LLM inference that is both efficient and high-quality. The implementation code, kernels, and models will be made publicly available.
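The hybrid-head idea described above can be illustrated with a small, dependency-free Python sketch of one decoding step for a single layer. It is not the repository's implementation (which uses custom kernels and batched tensors): the function names, the per-head Kumaraswamy parameters `(a, b)`, the 0.5 gate threshold, and the rule of summing retrieval-head attention probabilities to pick the shared top-$k$ tokens are all illustrative assumptions. The HardKuma sampler follows the standard construction: draw from a Kumaraswamy distribution by inverse CDF, stretch the sample to an interval slightly wider than $[0, 1]$, and clamp it back, so a gate can land exactly on 0 or 1.

```python
import math
import random


def hardkuma_sample(a, b, rng, l=-0.1, r=1.1):
    """Sample a HardKuma gate in [0, 1] (hypothetical per-head parameters a, b > 0)."""
    u = rng.random()
    # Inverse CDF of Kumaraswamy(a, b), whose CDF is F(x) = 1 - (1 - x**a)**b.
    s = (1.0 - (1.0 - u) ** (1.0 / b)) ** (1.0 / a)
    # Stretch to (l, r), then clamp so that exact 0s and 1s have nonzero probability.
    return min(1.0, max(0.0, l + (r - l) * s))


def partition_heads(gate_params, rng, threshold=0.5):
    """Mark a head as a retrieval head (True) when its sampled gate exceeds the threshold."""
    return [hardkuma_sample(a, b, rng) > threshold for a, b in gate_params]


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attend(q, keys, values, idx):
    """Scaled dot-product attention of query q over the cached tokens listed in idx."""
    scale = 1.0 / math.sqrt(len(q))
    probs = softmax([scale * sum(x * y for x, y in zip(q, keys[t])) for t in idx])
    out = [sum(p * values[t][j] for p, t in zip(probs, idx)) for j in range(len(values[0]))]
    return out, probs


def hybrid_decode_step(queries, keys, values, is_retrieval, k):
    """One decoding step of one layer.

    queries[h] is head h's query; keys[h][t] and values[h][t] are head h's KV cache.
    Retrieval heads attend to every cached token and vote for the top-k tokens;
    sparse heads attend only to those k shared tokens.
    """
    if not any(is_retrieval):
        raise ValueError("need at least one retrieval head")
    n_tokens = len(keys[0])
    all_idx = list(range(n_tokens))
    votes = [0.0] * n_tokens
    outputs = [None] * len(queries)
    for h, q in enumerate(queries):
        if is_retrieval[h]:
            outputs[h], probs = attend(q, keys[h], values[h], all_idx)
            for t, p in zip(all_idx, probs):
                votes[t] += p  # assumption: sum attention probabilities across retrieval heads
    top_k = sorted(sorted(all_idx, key=lambda t: votes[t], reverse=True)[:k])
    for h, q in enumerate(queries):
        if not is_retrieval[h]:
            outputs[h], _ = attend(q, keys[h], values[h], top_k)
    return outputs, top_k


# Toy example: 2 heads, head dimension 2, 4 cached tokens.
heads = partition_heads([(50.0, 0.5), (0.5, 50.0)], random.Random(0))
queries = [[1.0, 0.0], [1.0, 0.0]]
keys = [[[0.0, 0.0], [5.0, 0.0], [0.0, 0.0], [4.0, 0.0]]] * 2
values = [[[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]] * 2
outputs, shared = hybrid_decode_step(queries, keys, values, heads, k=2)
print(heads, shared)  # [True, False] [1, 3]
```

In this sketch, each sparse head scores only $k$ cached tokens instead of the full context of $T$ tokens, so its per-step attention cost drops from $O(T)$ to $O(k)$; only the small set of retrieval heads pays the full cost.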

## Installation
```bash
conda create -yn lychee python=3.10
conda activate lychee

conda install -y nvidia/label/cuda-12.4.0::cuda-toolkit
conda install -y nvidia::cuda-cudart-dev
conda install -y pytorch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 pytorch-cuda=12.4 -c pytorch -c nvidia

pip install mkl==2024.0 transformers==4.56.1 datasets wandb matplotlib tilelang einops zstandard

pip install flash-attn==2.6.3 --no-build-isolation
```
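After installation, one way to confirm that the pinned versions actually resolved is a short stdlib-only Python check such as the one below. It is a convenience sketch, not a script shipped with the repository: `check_pins` and `PINNED` are hypothetical names, and the dictionary covers only the PyTorch, Transformers, and FlashAttention pins from the commands above. Packages installed through conda are detected only if they include standard Python distribution metadata, which is typical but not guaranteed.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the install commands above (hypothetical helper, not part of the repo).
PINNED = {
    "torch": "2.5.1",
    "torchvision": "0.20.1",
    "torchaudio": "2.5.1",
    "transformers": "4.56.1",
    "flash_attn": "2.6.3",
}


def check_pins(pins):
    """Return {package: installed version or None} for packages that do not match their pin."""
    mismatches = {}
    for name, wanted in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None  # package is not installed in this environment
        # PyTorch wheels may append a local tag such as "+cu124"; compare only the public part.
        if installed is None or installed.split("+")[0] != wanted:
            mismatches[name] = installed
    return mismatches


if __name__ == "__main__":
    problems = check_pins(PINNED)
    if problems:
        for name, got in problems.items():
            print(f"{name}: expected {PINNED[name]}, found {got}")
    else:
        print("All pinned packages match.")
```

Run it with `python` inside the activated `lychee` environment; any line it prints points to a package that should be reinstalled at the pinned version.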

## Train
```bash
cd ./src/train
# Train on the passkey retrieval dataset
bash train_kuma_multi_passkey.sh
# Train on the HotpotQA dataset
bash train_kuma_hotpotqa.sh
```

## Open Source Plan
We currently release only the source code of LycheeDecode, provided as supplementary material; the remaining code for the experiments will be released once it has been cleaned up and organized.
- [x] Source Code for LycheeDecode
- [ ] Code for Experiments
