# PM-KVQ: Progressive Mixed-precision KV Cache Quantization for Long-CoT LLMs

## Installation

1. Create a new conda environment.

   ```bash
   conda create -n pm_kvq python=3.10
   conda activate pm_kvq
   ```

2. Install the dependencies listed in `requirements.txt` with pip.

   ```bash
   pip install -r requirements.txt
   ```

3. Install `pm_kvq` from source.

   ```bash
   pip install -e .
   ```

4. To run the RotateKV baseline, also install `fast-hadamard-transform` from [Dao-AILab/fast-hadamard-transform](https://github.com/Dao-AILab/fast-hadamard-transform).

## Apply PM-KVQ

### Block-wise Memory Allocation

1. Profile the quantization sensitivity of the KV cache in each transformer block.

   ```bash
   python scripts/get_sensitivity.py \
   --model_path /PATH/TO/MODEL \
   --dataset_path /PATH/TO/CALIBRATION/DATASET \
   --n_samples 512 \
   --seq_len 2048 \
   --effective_len 8192 \
   --save_path /PATH/TO/SAVE/SENSITIVITY
   ```

2. Assign a memory budget to each transformer block. The value of `--memory_budget` is specified in megabytes (MB).

   ```bash
   python scripts/allocate_memory.py \
   --sensitivity_path /PATH/TO/SENSITIVITY \
   --memory_budget 1024 \
   --fbit_choices 4,2 \
   --hidden_size ${HIDDEN_DIMENSION_OF_MODEL} \
   --max_len 32768 \
   --save_path /PATH/TO/SAVE/MEMORY/BUDGET
   ```
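
To get a feel for a reasonable `--memory_budget`, the following back-of-the-envelope sketch estimates the KV cache size of a single transformer block at a uniform bit-width. It assumes that keys and values each store `hidden_size` numbers per token (models with grouped-query attention store fewer), ignores quantization metadata such as scales and zero-points, and treats 1 MB as 2^20 bytes. `kv_cache_mb` is a hypothetical helper, not part of `pm_kvq`, and `allocate_memory.py` may account for memory differently.

```python
def kv_cache_mb(hidden_size: int, max_len: int, bits: int) -> float:
    """Approximate KV cache size of one transformer block in MB (2**20 bytes)."""
    # Assumption: keys and values each hold `hidden_size` numbers per token.
    n_values = 2 * hidden_size * max_len
    return n_values * bits / 8 / 2**20


if __name__ == "__main__":
    hidden_size = 1024  # hypothetical per-token KV width, not a real model config
    for bits in (16, 4, 2):
        size = kv_cache_mb(hidden_size, max_len=32768, bits=bits)
        print(f"{bits:>2}-bit: {size:.1f} MB per block")
```

Multiplying the per-block figure by the number of blocks gives a rough upper and lower bound for the total budget at the bit-widths passed to `--fbit_choices`.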

### Calibration with Positional Interpolation

1. Calculate the maximum magnitude of the Key cache.

   ```bash
   python scripts/get_max_keys.py \
   --model_path /PATH/TO/MODEL \
   --dataset_path /PATH/TO/CALIBRATION/DATASET \
   --n_samples 512 \
   --seq_len 2048 \
   --effective_len 8192 \
   --save_path /PATH/TO/SAVE/MAX/KEYS
   ```

2. Search for the optimal reparameterization factor.

   ```bash
   python scripts/search_rep_scales.py \
   --model_path /PATH/TO/MODEL \
   --dataset_path /PATH/TO/CALIBRATION/DATASET \
   --n_samples 512 \
   --seq_len 2048 \
   --effective_len 8192 \
   --max_keys_path /PATH/TO/MAX/KEYS \
   --k_bits 4 \
   --v_bits 4 \
   --save_path /PATH/TO/SAVE/REP/SCALES
   ```
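
Both calibration scripts take `--seq_len` and `--effective_len`. The sketch below illustrates the general idea of positional interpolation, under the assumption that the position indices of a `seq_len`-token calibration sample are stretched evenly so that they span `effective_len` positions. The helper `interpolated_positions` is hypothetical, and the scripts may implement the mapping differently.

```python
def interpolated_positions(seq_len: int, effective_len: int) -> list[float]:
    """Spread `seq_len` position indices evenly over [0, effective_len)."""
    # Assumption: a constant stride of effective_len / seq_len between tokens.
    scale = effective_len / seq_len
    return [i * scale for i in range(seq_len)]


positions = interpolated_positions(seq_len=2048, effective_len=8192)
print(len(positions), positions[:4], positions[-1])
```

With the values used above, each calibration token advances the position index by 4, so a 2048-token sample covers positions up to 8188.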

### Quantization and Evaluation

1. Evaluate the quantized model and save its responses to a `.jsonl` file. Use the `--start` and `--end` options to specify the range of problem indices to evaluate. So that all responses can be judged together, save the response files for different problems in the same directory.

   ```bash
   python scripts/evaluation.py \
   --model_path /PATH/TO/MODEL \
   --output_path /PATH/TO/SAVE/MODEL/RESPONSES \
   --benchmark aime \
   --version 2024 \
   --start 0 \
   --end 30 \
   --n_responses 16 \
   --method pm-kvq \
   --backend fake \
   --rep_scales /PATH/TO/REP/SCALES \
   --kv_budgets /PATH/TO/MEMORY/BUDGET \
   --n_sink_token 1 \
   --n_sink_token_bits 16 \
   --n_window_token 128 \
   --n_window_token_bits 16 \
   --n_init_kv_bits 16
   ```

2. Judge the responses and calculate the evaluation metrics.

   ```bash
   python scripts/judge.py \
   --benchmark aime \
   --version 2024 \
   --responses_dir /PATH/TO/MODEL/RESPONSES
   ```
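
Because `--start` and `--end` select a range of problem indices, the evaluation can be split across several jobs whose response files are then judged together. The sketch below computes contiguous index ranges for a given number of jobs. It assumes that `--end` is exclusive, which the example `--start 0 --end 30` suggests but does not confirm; `split_ranges` is a hypothetical helper, and the `...` in the printed command stands for the remaining flags shown above.

```python
def split_ranges(n_problems: int, n_jobs: int) -> list[tuple[int, int]]:
    """Split [0, n_problems) into n_jobs contiguous (start, end) ranges."""
    base, extra = divmod(n_problems, n_jobs)
    ranges, start = [], 0
    for job in range(n_jobs):
        # The first `extra` jobs take one additional problem each.
        end = start + base + (1 if job < extra else 0)
        ranges.append((start, end))
        start = end
    return ranges


for start, end in split_ranges(n_problems=30, n_jobs=4):
    print(f"python scripts/evaluation.py ... --start {start} --end {end}")
```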