# OBCache: Optimal Brain KV Cache Pruning for Efficient Long-Context LLM Inference

A unified KV-cache implementation for LLMs, built on [HF's cache utilities](https://huggingface.co/docs/transformers/en/kv_cache#generate-with-cache).

## Install Dependencies
```bash
pip install transformers==4.47.0
```

## Usage

Using `OBCache` takes three steps:

1. Load a model and the tokenizer.
2. Apply a one-line monkey patch to enable `OBCache`.
3. Configure the desired cache eviction policy and pass it to the `generate` function.

Below is an inference example that runs Llama-3.1-8B-Instruct with `OBCache` (a minimal example is also available in `obc/example_gen.py`):

```python
import torch

from monkey_patch.utils import enable_optimal_brain_kv, enable_optimal_brain_kv_flashattn2
from cache_utils import *
from utils import load_kv_cache, load_model_and_tokenizer, seed_everything


# model_name = "Qwen/Qwen2-0.5B-Instruct"
model_name = "meta-llama/Llama-3.1-8B-Instruct"
model, tokenizer = load_model_and_tokenizer(
    model_name_or_path=model_name,
    precision="bf16",
    flash_attn=False
)

# switch model to eval mode for inference
model.eval()
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print("model and tokenizer loaded")

# enable OBCache for the model
enable_optimal_brain_kv(model, 'llama' if 'llama' in model_name.lower() else 'qwen2')
print("OBCache enabled")

prompt = "A special magic number is hidden within the following text. Make sure to memorize it. I will quiz you about the number afterwards.\nJuly 2010What hard liquor, cigarettes, heroin, and crack have in common is that they're all more concentrated forms of less addictive predecessors. Most if not all the things we describe as addictive are. And the scary thing is, the process that created them is accelerating.We wouldn't want to stop it. It's the same process that cures diseases: technological progress. One of the special magic numbers for worried-colt is: 8930103. Technological progress means making things do more of what we want. When the thing we want is something we want to want, we consider technological progress good. If some new technique makes solar cells x% more efficient, that seems strictly better. \nWhat is the special magic number for worried-colt mentioned in the provided text?"
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": prompt}
]
# apply the model's chat template (required for instruct models such as Llama-3.1-Instruct and Qwen2-Instruct)
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
past_key_values = load_kv_cache(
    method = "h2o", # Alternative methods: "obc_value_p2", "obc_key_p2", "obc_value_key_p2"
    num_recent = 16, num_heavy = 240,
    # recent_ratio=0.01, heavy_ratio=0.19,
    decode_evict=True
)
print("Cache Eviction Method:\n", past_key_values)

seed_everything(42)

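# the chat template already inserts special tokens (e.g. BOS), so don't add them twice;
# .to() moves both input_ids and attention_mask onto the model's device
model_inputs = tokenizer([prompt], return_tensors="pt", add_special_tokens=False).to(model.device)
generated_ids = model.generate(
    model_inputs.input_ids,
    attention_mask=model_inputs.attention_mask,
    max_new_tokens=30,
    past_key_values=past_key_values,
    do_sample=False,  # greedy decoding
    temperature=1.0,
)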
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(response)
```
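For intuition about the `"h2o"` budget used above, here is a minimal pure-Python sketch of a Heavy-Hitter Oracle (H2O)-style selection rule. It assumes the usual H2O semantics: always keep the `num_recent` newest tokens, and among the older tokens keep the `num_heavy` with the largest accumulated attention. Under that assumption, the example configuration caps each head's cache at roughly 16 + 240 = 256 tokens. The function `h2o_keep_indices` is hypothetical and not part of this repository; the actual cache classes in `cache_utils` may track scores and break ties differently.

```python
def h2o_keep_indices(acc_attn, num_recent, num_heavy):
    """Hypothetical H2O-style eviction sketch (not this repo's API).

    acc_attn[j] is the attention mass token j has accumulated so far.
    Returns the sorted token positions that stay in the cache.
    """
    seq_len = len(acc_attn)
    if seq_len <= num_recent + num_heavy:
        # under budget: nothing needs to be evicted
        return list(range(seq_len))
    # the most recent tokens are always kept
    recent = set(range(seq_len - num_recent, seq_len))
    # rank the remaining (older) tokens by accumulated attention, highest first
    candidates = [j for j in range(seq_len) if j not in recent]
    heavy = sorted(candidates, key=lambda j: acc_attn[j], reverse=True)[:num_heavy]
    return sorted(recent | set(heavy))
```

For example, with `acc_attn = [0.9, 0.1, 0.5, 0.05, 0.3, 0.2, 0.4, 0.6]`, `num_recent=2`, and `num_heavy=2`, the sketch keeps positions `[0, 2, 6, 7]`: the two newest tokens plus the two older tokens with the most attention mass.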

To switch the eviction strategy, change the `method` string passed to `load_kv_cache`. The example above uses the `"h2o"` (Heavy-Hitter Oracle) policy; the OBCache variants are:

- `"obc_value_p2"`: isolated value pruning (based on the L2 norm).
- `"obc_key_p2"`: isolated key pruning.
- `"obc_value_key_p2"`: joint pruning of keys and values.
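The value-pruning variant can be motivated with a short calculation. Assume (our illustration, not necessarily the exact OBCache score) that pruning token j's value means zeroing `v_j` while leaving the attention weights `A` untouched. The attention output `o_i = sum_j A[i][j] * v_j` then changes by `-A[i][j] * v_j`, so the squared L2 error summed over queries is `sum_i A[i][j]**2 * ||v_j||**2`. The hypothetical helpers `value_saliency` and `keep_top_k` below compute that score and keep the highest-scoring tokens; key and joint pruning must also account for how the softmax weights shift, which this sketch does not model.

```python
def value_saliency(attn, values):
    """Illustrative per-token score for pruning value vectors in isolation.

    attn[i][j]: attention weight of query i on token j.
    values[j]:  value vector of token j (list of floats).
    Zeroing v_j perturbs output o_i by -attn[i][j] * v_j, so the squared L2
    error summed over queries is sum_i attn[i][j]**2 * ||v_j||**2.
    """
    sq_norms = [sum(x * x for x in v) for v in values]
    return [
        sum(row[j] ** 2 for row in attn) * sq_norms[j]
        for j in range(len(values))
    ]


def keep_top_k(scores, k):
    """Keep the k highest-scoring token positions, returned in original order."""
    order = sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)
    return sorted(order[:k])
```

For instance, with `attn = [[0.5, 0.5, 0.0], [0.2, 0.3, 0.5]]` and `values = [[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]]`, the scores are approximately `[0.29, 1.36, 0.5]`, and `keep_top_k(scores, 2)` returns `[1, 2]`: token 0 is evicted even though it receives the most attention from the first query, because its value vector has the smallest norm.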