# AttentionPredictor: Temporal Patterns Matter for KV Cache Compression
![Overview of AttentionPredictor](figs/image.png)
Implementation of **AttentionPredictor**, the **first learning-based KV cache compression and critical token identification method with direct attention pattern prediction.**
Specifically, AttentionPredictor learns a lightweight, unified convolution model that dynamically captures spatiotemporal patterns in past attention scores and predicts the attention scores of the next token.
AttentionPredictor predicts attention scores accurately, and a single prediction model, which consumes negligible memory, is shared across all transformer layers.
By retaining most of the attention information, AttentionPredictor achieves **13x** KV cache compression and a **5.6x** speedup with comparable LLM performance, significantly outperforming state-of-the-art methods.
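The predict-then-select idea can be illustrated with a toy, pure-Python sketch. It is not the implementation in `src/llama_attention/attnpred_llama_attention.py`, and it makes several assumptions. The kernel weights are hand-picked rather than learned. The input is a small time-by-token matrix of past attention scores. `conv_predict` and `select_critical_tokens` are hypothetical helper names. As with most deep-learning "convolutions", the operation is actually a cross-correlation.

```python
from typing import List

Matrix = List[List[float]]


def conv_predict(history: Matrix, kernel: Matrix) -> List[float]:
    """Predict next-step attention scores from a (time x token) history.

    Hypothetical helper for illustration. history[t][j] is the attention
    that the query at decoding step t paid to key token j, with rows ordered
    oldest to newest. The kernel spans every history row and a few token
    positions, so sliding it along the token axis (zero-padded at the
    edges) yields one predicted score per token.
    """
    steps, n_tokens = len(history), len(history[0])
    k_rows, k_cols = len(kernel), len(kernel[0])
    if k_rows != steps:
        raise ValueError("kernel height must equal the history length")
    pad = k_cols // 2
    scores = []
    for j in range(n_tokens):
        s = 0.0
        for t in range(k_rows):
            for dj in range(k_cols):
                col = j + dj - pad
                if 0 <= col < n_tokens:  # zero padding outside the sequence
                    s += kernel[t][dj] * history[t][col]
        scores.append(s)
    return scores


def select_critical_tokens(scores: List[float], budget: int) -> List[int]:
    """Return the positions of the `budget` highest predicted scores, sorted.

    Hypothetical helper: only the KV entries at these positions would be
    kept (or loaded) for the next decoding step.
    """
    ranked = sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)
    return sorted(ranked[:budget])


# Toy attention history: 3 past decoding steps over 6 key tokens.
history = [
    [0.5, 0.1, 0.1, 0.1, 0.1, 0.1],
    [0.5, 0.0, 0.3, 0.0, 0.1, 0.1],
    [0.4, 0.0, 0.0, 0.4, 0.1, 0.1],
]
# Hand-picked (not learned) weights: recent steps count more, and the 0.1
# in the last row's left column lets attention on token j-1 "shift" onto j.
kernel = [
    [0.0, 0.1, 0.0],
    [0.0, 0.3, 0.0],
    [0.1, 0.5, 0.0],
]
scores = conv_predict(history, kernel)
print([round(s, 2) for s in scores])  # [0.4, 0.05, 0.1, 0.21, 0.13, 0.1]
print(select_critical_tokens(scores, budget=3))  # [0, 3, 4]
```

Reusing the same `kernel` on every layer's history loosely mirrors sharing one predictor across all transformer layers. Keeping 3 of 6 tokens here is only a 2x reduction chosen for readability, and it says nothing about the compression ratios reported above.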

## Quick Start

### Requirements
- PyTorch
- FlashAttention-2
- Transformers >= 4.44.0

### Supported LLMs
- LongChat: [lmsys/longchat-7b-v1.5-32k](https://huggingface.co/lmsys/longchat-7b-v1.5-32k)
- LLaMA-3.1: [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct)

## AttentionPredictor

The implementation of AttentionPredictor is in `src/llama_attention/attnpred_llama_attention.py`.

The pretrained predictor models are in `model/`.

## Experiments
To evaluate on the LongBench dataset:
```bash
cd src/evaluation/LongBench
bash eval.sh # generate answers
bash metrics.sh # calculate scores
```


To evaluate on the GSM8K dataset:
```bash
cd src/evaluation/gsm8k
bash run.sh 
```



