# ScaleKV: Memory-Efficient Visual Autoregressive Modeling with Scale-Aware KV Cache Compression

<div align="center">
<img src="assets/teaser.png" width="100%">
</div>

## 💡 Introduction
We propose Scale-Aware KV Cache (ScaleKV), a novel KV cache compression framework tailored to VAR's next-scale prediction paradigm. ScaleKV builds on two key observations: cache demands vary across transformer layers, and attention patterns differ across scales. Based on these insights, we categorize transformer layers into two functional groups, termed drafters and refiners, and apply a cache management strategy suited to each role. ScaleKV then optimizes multi-scale inference by identifying each layer's role at every scale, so that cache allocation matches the computational demands of each layer. On Infinity-8B, ScaleKV achieves a 10x memory reduction, from 85 GB to 8.5 GB, with negligible quality degradation: the GenEval score remains at 0.79 and the DPG score decreases only marginally, from 86.61 to 86.49.
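To make the drafter/refiner idea more concrete, here is a minimal Python sketch of role-aware cache budgeting. It is an illustration under stated assumptions, not ScaleKV's implementation: the per-layer `dispersion` score used to pick drafters, the `drafter_share` split, and the attention-score top-k token selection are all assumptions, and `allocate_budgets` and `select_tokens` are hypothetical helpers.

```python
from typing import Dict, List


def allocate_budgets(
    dispersion: List[float], total_budget: int, num_drafters: int, drafter_share: float
) -> Dict[int, int]:
    """Split a total KV cache budget (in tokens) between drafter and refiner layers.

    Assumption (hypothetical criterion): layers whose attention is most dispersed
    over the cached context act as drafters and get a larger per-layer budget;
    the remaining layers act as refiners and get a smaller one.
    """
    num_layers = len(dispersion)
    if not 0 < num_drafters < num_layers:
        raise ValueError("need at least one drafter and one refiner")
    # Rank layers by dispersion, highest first; the top `num_drafters` become drafters.
    order = sorted(range(num_layers), key=lambda i: dispersion[i], reverse=True)
    drafters = order[:num_drafters]
    refiners = order[num_drafters:]
    drafter_total = round(total_budget * drafter_share)
    refiner_total = total_budget - drafter_total
    budgets: Dict[int, int] = {}
    # Floor division may leave a few tokens of the total budget unallocated.
    for i in drafters:
        budgets[i] = drafter_total // len(drafters)
    for i in refiners:
        budgets[i] = refiner_total // len(refiners)
    return budgets


def select_tokens(scores: List[float], budget: int) -> List[int]:
    """Keep the `budget` cached tokens with the highest scores, returned in cache order."""
    top = sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)[:budget]
    return sorted(top)


# Hypothetical 4-layer model with a 100-token total budget.
budgets = allocate_budgets([0.9, 0.1, 0.2, 0.8], total_budget=100, num_drafters=2, drafter_share=0.7)
print(budgets)  # {0: 35, 3: 35, 2: 15, 1: 15}
print(select_tokens([0.1, 0.5, 0.3, 0.9], budget=2))  # [1, 3]
```

Because the README states that each layer's function is identified at every scale, a scale-aware variant would recompute the dispersion scores and budgets per scale rather than once for the whole generation.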

<div align="center">
<img src="assets/method.png" width="100%">
</div>
<div align="center">
<img src="assets/overview.png" width="100%">
</div>


## 🔧 Installation

### Requirements
```bash
pip install -r requirements.txt
```

### Model Checkpoints
Ensure `huggingface_hub` is installed, then change into the `Infinity` directory:
```bash
pip install -U huggingface_hub
cd Infinity
```

Download Infinity-8B model checkpoint:
```bash
huggingface-cli download FoundationVision/Infinity --include "infinity_8b_weights/**" --local-dir ./weights/infinity_8b_weights
```

Download flan-t5-xl:
```bash
huggingface-cli download google/flan-t5-xl --local-dir ./weights/flan-t5-xl
```

Download the VAE checkpoint from the link below:
- `VAE(8B)`: [infinity_vae_d56_f8_14_patchify.pth](https://huggingface.co/FoundationVision/Infinity/resolve/main/infinity_vae_d56_f8_14_patchify.pth?download=true)


## ⚡ Sample & Evaluations
### Sampling 5,000 images with COCO captions

Sample images with the full (uncompressed) model to serve as ground truth:
```bash
torchrun --nproc_per_node=$N_GPUS scripts/sample_ori.py
```

Sample images with the ScaleKV-compressed model (10% KV cache):
```bash
torchrun --nproc_per_node=$N_GPUS scripts/sample_kv.py
```
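For intuition about why a 10% cache budget translates into roughly a tenfold memory saving, the common estimate of KV cache size is 2 (keys and values) × batch size × layers × cached tokens × KV heads × head dimension × bytes per element. The sketch below applies that formula; `kv_cache_bytes` is a hypothetical helper, the configuration values are placeholders rather than Infinity-8B's actual architecture, and it assumes every layer has the same shape, so keeping 10% of the cached tokens overall shrinks the cache tenfold however the budget is distributed across layers.

```python
def kv_cache_bytes(
    num_layers: int,
    num_tokens: int,
    num_kv_heads: int,
    head_dim: int,
    bytes_per_elem: int = 2,  # 2 bytes per element for fp16/bf16
    batch_size: int = 1,
) -> int:
    """Estimate KV cache size in bytes; the leading factor 2 covers keys and values."""
    return 2 * batch_size * num_layers * num_tokens * num_kv_heads * head_dim * bytes_per_elem


# Hypothetical configuration (placeholder values, not Infinity-8B's).
full = kv_cache_bytes(num_layers=32, num_tokens=10_000, num_kv_heads=16, head_dim=128)
compressed = kv_cache_bytes(num_layers=32, num_tokens=1_000, num_kv_heads=16, head_dim=128)
print(full)                # 2621440000
print(full // compressed)  # 10
```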

Once all images are generated, compute PSNR, LPIPS, and FID between the ground-truth and ScaleKV samples with:
```bash
python compute_metrics.py --input_root0 samples/gt --input_root1 samples/scalekv
```
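PSNR compares each compressed sample with its ground-truth counterpart pixel by pixel, using PSNR = 10 · log10(MAX² / MSE). The minimal pure-Python sketch below implements that standard definition, assuming 8-bit images flattened into equal-length sequences; it is not necessarily how `compute_metrics.py` computes the metric. LPIPS and FID depend on pretrained networks, so they are not sketched here.

```python
import math
from typing import Sequence


def psnr(reference: Sequence[float], test: Sequence[float], max_val: float = 255.0) -> float:
    """Peak signal-to-noise ratio in dB between two equal-length pixel sequences."""
    if len(reference) != len(test) or len(reference) == 0:
        raise ValueError("inputs must be non-empty and the same length")
    # Mean squared error over all pixel values.
    mse = sum((a - b) ** 2 for a, b in zip(reference, test)) / len(reference)
    if mse == 0:
        return math.inf  # identical images
    return 10 * math.log10(max_val ** 2 / mse)


print(psnr([0, 0, 0, 0], [255, 255, 255, 255]))  # 0.0
print(round(psnr([10, 20, 30, 40], [11, 19, 31, 39]), 2))  # 48.13
```

Higher PSNR means the ScaleKV samples are closer to the full-model samples; identical images give infinite PSNR.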

## 📚 Key Results
<div align="center">
<img src="assets/coco.png" width="100%">
</div>

<div align="center">
<img src="assets/bench.png" width="100%">
</div>

<div align="center">
<img src="assets/compare.png" width="100%">
</div>

<div align="center">
<img src="assets/speed.png" width="100%">
</div>

<div align="center">
<img src="assets/mem.png" width="100%">
</div>