# kvlinc
Repository for the submitted paper **KVLinC: KV Cache Quantization with Hadamard Rotation and Linear Correction**. This repository is built on top of [LolCats](https://github.com/HazyResearch/lolcats).
## Getting started
### Setup dependencies

Dependencies are listed in `environment.yaml`; adjust the PyTorch CUDA version there if needed. Conda is used to set up the environment:
```bash
conda env create -f environment.yaml
conda activate kvlinc-env
```
Then install Flash Attention using the command below:
```bash
pip install flash-attn --no-build-isolation
```
### Experiment and model configs
Each model and experiment is defined by config files (`.yaml`) in `./configs`:
- Files under `./configs/experiments/` set the dataset and training hyperparameters (used to train the linear correction adapters).
- Files under `./configs/models/` set the model setup (pretrained LLM, attention configuration).
- We also provide configs for baseline KV cache quantization techniques: [KIVI](https://arxiv.org/abs/2402.02750), [QuaRot](https://arxiv.org/abs/2404.00456), [ResQ](https://arxiv.org/pdf/2412.14363), and [GEAR](https://arxiv.org/abs/2403.05527). These baselines can be evaluated by selecting the corresponding config files.
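
To find the names to pass to `--model_config` or `--distill_config`, a small helper like the hypothetical `list_configs` below can list what is available. It is a sketch that assumes config names are paths relative to `./configs/models/` (or `./configs/experiments/`) with the `.yaml` extension dropped. This matches `--model_config qwen_3_1_7b_base/kvlinc` in the sample commands, but the lookup logic in the scripts themselves is not confirmed here.

```bash
# Hypothetical helper (not part of this repository): print config names
# relative to a config root, without the .yaml suffix.
# Assumption: --model_config / --distill_config values follow this layout.
list_configs() {
  local root="${1:-./configs/models}"
  root="${root%/}"  # drop a trailing slash so the prefix strip below works
  find "$root" -type f -name '*.yaml' | while IFS= read -r f; do
    rel="${f#"$root"/}"           # strip the "<root>/" prefix
    printf '%s\n' "${rel%.yaml}"  # strip the .yaml suffix
  done | sort
}
```

For example, `list_configs ./configs/experiments` would list candidate values for `--distill_config` under the same assumption.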

The scripts should download pretrained models from Hugging Face automatically; change `cache_dir` in the model config to the directory where you want the weights saved.
For example:

```yaml
name: qwen3
model:
  pretrained_model_name_or_path: "qwen/Qwen3-1.7B-Base"
  cache_dir: "/home/HF_cache/" # Set this to where you want to save checkpoint weights 
  return_dict: true
  load_in_8bit: false
  load_in_4bit: false
  device_map: cpu
  low_cpu_mem_usage: true
  torch_dtype: bfloat16
  attn_implementation: flash_attention_2
  rope_theta: 1000000
```
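
If you use several model configs, a hypothetical helper like `set_cache_dir` below can point every `cache_dir` entry at the same directory. It is a sketch that assumes GNU `sed` (for in-place `-i`) and that each config sets `cache_dir` on a single line, as in the example above. It replaces the whole value, including any trailing comment.

```bash
# Hypothetical helper (not part of this repository): rewrite the cache_dir
# value in every .yaml file under a config root. Requires GNU sed (-i).
# The directory is inserted into a sed replacement, so it must not
# contain '|', '&' or backslashes.
set_cache_dir() {
  local dir="$1"
  local root="${2:-./configs/models}"
  # Keep the key and its indentation (\1), replace everything after it.
  find "$root" -type f -name '*.yaml' -exec \
    sed -i -E "s|^([[:space:]]*cache_dir:).*|\1 \"${dir}\"|" {} +
}
```

For example, running `set_cache_dir /data/hf_cache/` from the repository root would rewrite the example line above to `cache_dir: "/data/hf_cache/"`.
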
### Sample commands
Scripts to train the linear correction adapters and evaluate the model are provided in `run.sh` and `run_eval.sh`.

#### Qwen3-1.7B-Base
To train the adapters, run:
```bash
torchrun --nproc_per_node=4 --master_port=24552 train_linc.py \
  --model_config qwen_3_1_7b_base/kvlinc \
  --distill_config distill_alpaca \
  --lk_zero_init \
  --verbose \
  --seed 0 \
  --huggingface_token hf_<insert your token here>
```
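
To avoid pasting your access token into the command line (and your shell history), the training command can be wrapped in a hypothetical function that reads the token from an environment variable. `HF_TOKEN`, `NGPUS`, `SEED`, and `DRY_RUN` are names chosen for this sketch only, not variables the repository's scripts read; the `train_linc.py` flags are the ones from the command above.

```bash
# Hypothetical wrapper (not part of this repository) around the training
# command above. Reads the token from HF_TOKEN instead of the command line;
# NGPUS, SEED and DRY_RUN are optional variables used only by this sketch.
train_kvlinc() {
  if [ -z "${HF_TOKEN:-}" ]; then
    echo "set HF_TOKEN to your Hugging Face access token" >&2
    return 1
  fi
  local -a cmd=(
    torchrun --nproc_per_node="${NGPUS:-4}" --master_port=24552 train_linc.py
    --model_config qwen_3_1_7b_base/kvlinc
    --distill_config distill_alpaca
    --lk_zero_init
    --verbose
    --seed "${SEED:-0}"
    --huggingface_token "$HF_TOKEN"
  )
  if [ "${DRY_RUN:-0}" = 1 ]; then
    printf '%q ' "${cmd[@]}"; echo  # print the command instead of running it
  else
    "${cmd[@]}"
  fi
}
```

For example, after `read -rs HF_TOKEN && export HF_TOKEN`, `DRY_RUN=1 train_kvlinc` prints the command it would run, and `NGPUS=8 train_kvlinc` launches it on eight GPUs.
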
To evaluate on GSM8K, run:
```bash
torchrun --nproc_per_node=1 --master_port=24539 eval.py \
  --model_config qwen_3_1_7b_base/kvlinc \
  --load_checkpoint ./checkpoints/qwen_3_1_7b_base/kvlinc/seq_len_3072/feature_dim_128/dl-d=distill_alpaca-s=0-se=0-lzi=1_distill_500.pt \
  --verbose \
  --seed 1 \
  --huggingface_token hf_<insert your token here> \
  --tasks "gsm8k"
```
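
Rather than typing out the full checkpoint path, a hypothetical helper like the one below can pick the most recently written `.pt` file under a checkpoint directory and pass it to `eval.py`. It is a sketch that assumes GNU `find` (for `-printf`) and that the newest checkpoint is the one you want to evaluate; the `eval.py` flags are the ones from the command above, and the token is read from a `HF_TOKEN` variable chosen for this sketch.

```bash
# Hypothetical helpers (not part of this repository). Requires GNU find (-printf).
# latest_checkpoint DIR: print the most recently modified .pt file under DIR.
latest_checkpoint() {
  find "$1" -type f -name '*.pt' -printf '%T@ %p\n' \
    | sort -n | tail -n 1 | cut -d ' ' -f 2-
}

# eval_latest [DIR]: run the GSM8K evaluation above on the newest checkpoint.
eval_latest() {
  local dir="${1:-./checkpoints/qwen_3_1_7b_base/kvlinc}"
  local ckpt
  ckpt="$(latest_checkpoint "$dir")"
  if [ -z "$ckpt" ]; then echo "no .pt checkpoint under $dir" >&2; return 1; fi
  if [ -z "${HF_TOKEN:-}" ]; then echo "set HF_TOKEN" >&2; return 1; fi
  torchrun --nproc_per_node=1 --master_port=24539 eval.py \
    --model_config qwen_3_1_7b_base/kvlinc \
    --load_checkpoint "$ckpt" \
    --verbose \
    --seed 1 \
    --huggingface_token "$HF_TOKEN" \
    --tasks "gsm8k"
}
```

For example, after `read -rs HF_TOKEN && export HF_TOKEN`, running `eval_latest` evaluates the newest checkpoint under `./checkpoints/qwen_3_1_7b_base/kvlinc`.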



