# Multipole Attention Code

## Installation

1. Create a conda environment
```
conda create --name multipoleattention python=3.9 -y
conda activate multipoleattention
```

2. Clone the repository, then from its root install PyTorch, the local copy of transformers, and this package
```
pip install torch torchvision torchaudio
cd transformers
pip install -e .
cd ..
pip install -e .
```

3. Install additional dependencies: FlashAttention (https://github.com/Dao-AILab/flash-attention)
```
pip install flash-attn==2.5.4 --no-build-isolation
```

4. Install GSM-Infinite dependencies

Follow the installation instructions in the `gsm_infinite` directory before running the GSM-Infinite evaluation.
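
Once the steps above have finished, a quick import check can confirm that the environment is usable. The sketch below is not part of the repository. It assumes the packages import as `torch`, `transformers`, and `flash_attn`, and that `python` resolves to the interpreter of the active conda environment.

```shell
# Report whether a Python package can be imported in the active environment.
# Assumption: `python` is the interpreter from the conda environment above.
check_module() {
    if python -c "import $1" >/dev/null 2>&1; then
        echo "OK       $1"
    else
        echo "MISSING  $1"
    fi
}

# Assumed import names for the packages installed in steps 2 and 3.
for module in torch transformers flash_attn; do
    check_module "$module"
done
```

Any line reported as `MISSING` points at the installation step to repeat.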

---


## Evaluation

`LongBench/run_centroid.sh` shows how to run the LongBench v2 evaluation.

`gsm_infinite/gsm-infinite/run_centroid_hybrid.sh` shows how to run the GSM-Infinite evaluation. On its own, the script only generates output predictions. To collect the final results, also launch a model for parsing the outputs with `vllm_serve.sh`, and set the `run_evaluation` argument in `config.sh` to True.
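
Because the evaluation flag is easy to overlook, a small pre-flight check can confirm it before a long run. This is a sketch under two assumptions that the repository may not match: that `config.sh` sets the flag with a plain shell assignment such as `run_evaluation=True`, and that `config.sh` sits next to `run_centroid_hybrid.sh`. Adjust the path and pattern to the actual file.

```shell
# Succeed if the given config file sets run_evaluation to True.
# Assumption: a plain shell assignment, optionally exported or quoted.
evaluation_enabled() {
    grep -Eq "^[[:space:]]*(export[[:space:]]+)?run_evaluation=[\"']?True" "$1"
}

# Hypothetical location; change it to wherever config.sh actually lives.
CONFIG="gsm_infinite/gsm-infinite/config.sh"

if [ ! -f "$CONFIG" ]; then
    echo "config file not found: $CONFIG"
elif evaluation_enabled "$CONFIG"; then
    echo "run_evaluation is True: launch vllm_serve.sh before running the script"
else
    echo "run_evaluation is not True: only output predictions will be generated"
fi
```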

---

## Kernel Implementation

`benchmark/benchmark_kernels.py` benchmarks the centroid lookup, centroid replacement, and sparse flash decoding kernels.

`benchmark/benchmark_kmeans.py` benchmarks the online clustering update.
