# SinkQ: Accurate KV Cache Quantization with Dynamic Sink Tracking

### Setup

To install the required packages:

```bash
conda create -n sinkq python=3.10
conda activate sinkq
pip install --upgrade pip  # enable PEP 660 support
pip install -e .
```

Then install our CUDA implementation:

```bash
cd quant && pip install -e .
```

### Example

Load a model with SinkQ (e.g., Llama-2-7b):

```python
# LLaMA model with SinkQ
import torch
import os
from models.llama_sinkq import LlamaForCausalLM_SinkQ
from transformers import LlamaConfig, AutoTokenizer
config = LlamaConfig.from_pretrained("meta-llama/Llama-2-7b-hf")

# The uppercase names below are placeholders; set them to your own values.
config.k_bits = K_BITS  # key cache bit width; 2-bit and 4-bit are currently supported
config.v_bits = V_BITS  # value cache bit width; 2-bit and 4-bit are currently supported
config.sink_num = SINK_NUM
config.sink_max_size = SINK_MAX_SIZE
config.group_size = GROUP_SIZE
config.residual_length = RESIDUAL_LENGTH  # number of most recent tokens kept in fp16
CACHE_DIR = PATH_TO_YOUR_SAVE_DIR

model = LlamaForCausalLM_SinkQ.from_pretrained(
    pretrained_model_name_or_path='meta-llama/Llama-2-7b-hf',
    config=config,
    cache_dir=CACHE_DIR,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",
)

tokenizer = AutoTokenizer.from_pretrained(
    'meta-llama/Llama-2-7b-hf',
    use_fast=False,
    trust_remote_code=True,
    tokenizer_type='llama',
)

# Inference
# e.g., model.generate(...)
```
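The `k_bits`/`v_bits` and `group_size` settings in the example control how coarsely cache values are quantized. As a rough intuition only, the sketch below applies generic asymmetric min/max quantization to fixed-size groups of a flat list in pure Python. It is not SinkQ's CUDA kernel, and the function names (`quantize_group`, `quantize_grouped`, `dequantize_grouped`) are hypothetical; how SinkQ lays out groups across keys and values is not described in this README.

```python
def quantize_group(values, bits):
    """Asymmetric min/max quantization of one group to `bits`-bit integer codes."""
    levels = (1 << bits) - 1          # 3 for 2-bit, 15 for 4-bit
    lo, hi = min(values), max(values)
    scale = (hi - lo) / levels if hi > lo else 1.0
    # Map each value to the nearest code and clamp to the valid range.
    codes = [min(levels, max(0, round((v - lo) / scale))) for v in values]
    return codes, scale, lo


def quantize_grouped(values, bits, group_size):
    """Quantize a flat list group by group; each group keeps its own scale and zero point."""
    if len(values) % group_size != 0:
        raise ValueError("len(values) must be a multiple of group_size")
    return [
        quantize_group(values[start:start + group_size], bits)
        for start in range(0, len(values), group_size)
    ]


def dequantize_grouped(groups):
    """Reconstruct approximate floats from (codes, scale, zero_point) triples."""
    out = []
    for codes, scale, zero in groups:
        out.extend(c * scale + zero for c in codes)
    return out


# Two groups with very different ranges: per-group scales keep both accurate.
example = [0.0, 0.1, 0.2, 0.3, 10.0, 20.0, 30.0, 40.0]
packed = quantize_grouped(example, bits=2, group_size=4)
restored = dequantize_grouped(packed)
```

A smaller group size means more scale and zero-point metadata but a tighter range per group, which is the usual trade-off behind a setting like `group_size`.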

#### GSM8K example
We use GSM8K as an example to show how to use SinkQ. See [example.py](./example.py):

```bash
python example.py
```
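To try a single prompt without running the full script, a helper along the following lines may be enough. It is a minimal sketch, not taken from `example.py`: the name `generate_text` is hypothetical, and it assumes `model` and `tokenizer` were loaded as in the example above and that the SinkQ model exposes the standard Hugging Face `generate` interface.

```python
def generate_text(model, tokenizer, prompt, max_new_tokens=64):
    """Greedy-decode a continuation of `prompt` and return only the newly generated text."""
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs.input_ids.to(model.device)
    output_ids = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=False,  # greedy decoding, so repeated runs are comparable
    )
    # Drop the prompt tokens so only the continuation is decoded.
    new_tokens = output_ids[0][input_ids.shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)


# Usage (illustrative question, not drawn from the GSM8K dataset):
# answer = generate_text(model, tokenizer, "Question: What is 12 * 7?\nAnswer:")
# print(answer)
```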
