<div align="center">

<h1><img src="figures/logo.png" height="34px" align="center"/>Locret</h1>

<p align="center">
<a href="TODO" target="_blank">Blog</a> |
<a href="TODO" target="_blank">Paper (arXiv)</a>
</p>

**A Lightweight Training-based KV Cache Compression Algorithm**
</div>

Locret is a lightweight, training-based KV cache eviction method. It achieves KV cache compression ratios of **20x** for Phi-3-mini-128K and **8x** for Llama-3.1-8B-instruct, enabling **128K+** long-context inference on **a single NVIDIA 4090 GPU**.
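To give a feel for what these ratios mean in memory terms, here is a rough back-of-the-envelope sketch. It assumes an fp16 cache, a 128K-token context, and layer/head shapes taken from the public model configs (these shapes are assumptions, not values read from this repository). It also treats the compression ratio as a plain division of the full cache size, which is a simplification of how an eviction method manages its cache budget.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, bytes_per_elem=2):
    """Bytes needed to store keys and values for one sequence.

    The leading factor 2 accounts for storing both K and V;
    bytes_per_elem=2 corresponds to fp16/bf16.
    """
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_elem


GIB = 1024 ** 3
SEQ_LEN = 128 * 1024  # 128K tokens

# Assumed shapes from the public model configs; "ratio" is the compression
# ratio quoted in this README.
MODELS = {
    "Phi-3-mini-128K": dict(num_layers=32, num_kv_heads=32, head_dim=96, ratio=20),
    "Llama-3.1-8B-instruct": dict(num_layers=32, num_kv_heads=8, head_dim=128, ratio=8),
}


def report(models=MODELS, seq_len=SEQ_LEN):
    """Return {model: (full cache in GiB, compressed cache in GiB)}."""
    rows = {}
    for name, cfg in models.items():
        full = kv_cache_bytes(cfg["num_layers"], cfg["num_kv_heads"], cfg["head_dim"], seq_len)
        rows[name] = (full / GIB, full / cfg["ratio"] / GIB)
    return rows


if __name__ == "__main__":
    for name, (full, compressed) in report().items():
        print(f"{name}: full {full:.1f} GiB, compressed {compressed:.1f} GiB")
```

Under these assumptions, the uncompressed cache at 128K tokens is about 48 GiB for Phi-3-mini-128K and 16 GiB for Llama-3.1-8B-instruct, which shrink to roughly 2.4 GiB and 2 GiB at the quoted ratios.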



## Design

### Overall Framework

![](figures/design.png)

### Locret Inference

![](figures/inference.png)

## Usage

Locret currently supports only Phi-3-mini-128K and Llama-3.1-8B-instruct. Support for more models is planned.

### Environment Setup

Run the following commands to set up the environment automatically.

```bash
cd locret
conda env create -f locret_env.yaml
conda activate locret_env
pip install -e .
```

Now you are all set!

### Training

First, enter the working directory:
```bash
cd locret/train
```

Then, generate the training dataset:
```bash
python data_gen.py --model_dir <model_dir>
```

Finally, start training the model:
```bash
python train.py --model_dir <model_dir>
```

All hyperparameters are set automatically to the values used in our paper. Make sure the path passed as `model_dir` identifies the model you use: for example, if the path contains `phi`, the training script detects the model accordingly.
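The path-based detection described above can be pictured roughly as follows. This is only an illustrative sketch, not the code in `train.py`: the function name `detect_model_family`, the returned labels, and the `llama` keyword are all hypothetical, and only the `phi` keyword is mentioned in this README.

```python
from pathlib import Path

# Hypothetical keyword -> family mapping. Only "phi" is confirmed by the
# README; "llama" and both labels are assumptions for illustration.
FAMILY_KEYWORDS = {
    "phi": "phi3",
    "llama": "llama3",
}


def detect_model_family(model_dir: str) -> str:
    """Guess the model family from a case-insensitive substring of the path."""
    lowered = str(Path(model_dir)).lower()
    for keyword, family in FAMILY_KEYWORDS.items():
        if keyword in lowered:
            return family
    raise ValueError(
        f"Cannot infer the model family from {model_dir!r}; "
        f"expected one of {sorted(FAMILY_KEYWORDS)} in the path."
    )
```

With a scheme like this, a directory named `/models/Phi-3-mini-128k-instruct` would be recognized, while a renamed directory such as `/models/my_checkpoint` would not, so it is safest to keep the model name in the path.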

After training, the checkpoint is saved at `locret/train/checkpoints/<model_name>`. You can either load the complete model saved with `save_pretrained` or load only the retaining heads. We provide a script, `locret/train/convert.py`, to convert checkpoints from safetensors format to PyTorch format.

### Inference

`example.py` provides an example that runs one entry from R.PassKey of $\infty$Bench. To run it, execute:
```bash
python example.py --model_dir <model_dir> # for saved full checkpoint, or
python example.py --model_dir <model_dir> --retaining_head_path <*.bin> # original model + saved retaining heads
```

For the other experiments in our paper, run the scripts in `benchmark/infinite_bench` and `benchmark/LEval-main`. Each script corresponds to one experimental setting, which is indicated by the script name.

## Development Roadmap

- [ ] Add support for MiniCPM-2.4B and MiniCPM-1.2B
- [ ] Add support for Qwen-2.5-1.5B and Qwen-2.5-1.5B


## Citation

Please cite our [paper](TODO) if you find our work valuable.

```
@article{anonymous2024locret,
  title={Locret: Accelerating Long-Context LLM Inference with Retaining Heads},
  author={Anonymous Authors},
  journal={arXiv preprint arXiv:TODO},
  year={2024}
}
```