# MuSe: Multipole Semantic Attention -- Supplementary Code

This repository contains the code for reproducing the experiments in the paper
"MuSe: Multipole Semantic Attention for Long-Context Language Model Pretraining".

Note: The package is named `fma` after an earlier working title
("Fast Multipole Attention"). All `fma` references in the code correspond
to MuSe in the paper.

## Repository Structure

```
muse_supplementary/
├── fma/                        # Core implementation and training
│   ├── fma/                    # MuSe attention library
│   │   ├── pallas_retrieval.py # Main MuSe attention kernel (JAX Pallas)
│   │   ├── pallas_cluster.py   # Query clustering implementation
│   │   ├── fast_kmeans.py      # Fast k-means for cluster assignment
│   │   ├── moba_retrieval.py   # MoBA baseline implementation
│   │   ├── flash_attention.py  # Flash attention wrapper
│   │   └── cuda_kernels/       # Custom CUDA kernels for clustering
│   ├── examples/minigpt/       # Training scripts
│   │   ├── minigpt.py          # Base GPT model definition
│   │   ├── rope_attention.py   # RoPE attention module
│   │   ├── exp_cudnn_*.py      # CUDNN baseline training configs
│   │   ├── exp_fma_*.py        # MuSe training configs
│   │   └── exp_moba_*.py       # MoBA baseline training configs
│   └── pyproject.toml          # Package definition and dependencies
├── fma-flax-llama/             # Llama evaluation (loss-vs-position)
│   ├── fma_llama/              # Flax Llama 3.2 1B implementation
│   ├── scripts/
│   │   ├── compare_fma_configs.py  # Loss-vs-position evaluation
│   │   └── convert_weights.py      # HuggingFace -> Flax weight conversion
│   ├── configs/                # Model configuration files
│   └── pyproject.toml
└── long_stackv2/               # Dataset preparation scripts
    ├── *.py                    # Code dataset preparation (The Stack v2)
    ├── longmino_pdfs/          # Science/PDF dataset preparation
    └── gutenberg16384/         # PG-19 / Gutenberg dataset preparation
```

## Requirements

- Python >= 3.10
- JAX with CUDA support (`jax[cuda]>=0.4.20`)
- Flax (`flax>=0.8.0`)
- 8x NVIDIA A100 (80GB) GPUs for training
- NVIDIA CUTLASS (for CUDA kernel compilation; see below)

### Installation

Install the `fma` package in editable mode from the repository root:

```bash
cd fma
pip install -e .
```

This installs the MuSe attention library and all training dependencies
(JAX, Flax, Optax, Grain, etc.). See `fma/pyproject.toml` for the full
dependency list.

To exactly replicate our environment, use [uv](https://docs.astral.sh/uv/)
with the included lockfile:

```bash
cd fma
uv sync
```

For the Llama evaluation code, also install `fma-flax-llama` (again starting from the repository root):

```bash
cd fma-flax-llama
pip install -e .
```

Note: `fma-flax-llama` imports from the `fma` package, so `fma` must be
installed first.

### CUDA Kernel Compilation

The custom clustering kernels require NVIDIA CUTLASS headers.
Clone CUTLASS into the `cuda_kernels` directory before compiling:

```bash
cd fma/fma/cuda_kernels
git clone https://github.com/NVIDIA/cutlass.git

# Compile the clustering kernels
bash cluster.sh       # Head dimension 64
bash cluster_d128.sh  # Head dimension 128
bash assign_indices.sh
```

The compilation scripts invoke `nvcc` targeting compute capability `sm_80`.
The kernels rely on SM80 features (asynchronous copy, tensor cores), so they
require an Ampere or newer GPU (A100, RTX 30xx/40xx, H100, etc.). Other
Ampere-or-newer GPUs may work for smaller experiments but have not been
tested.

## Training

Training scripts are in `fma/examples/minigpt/`. Each `exp_*.py` file is a
self-contained training script that defines the model architecture, optimizer,
data loading, and training loop. Scripts are run directly with Python.

All training scripts expect pretokenized data in `fma/examples/minigpt/data/`.
See "Dataset Preparation" below.

### Hardware

Exact replication requires a node with 8x NVIDIA A100 (80GB) GPUs. The
training scripts configure a `(2, 4)` JAX device mesh for batch and model
parallelism. For other GPU counts, modify the mesh at the top of the
script:

```python
# batch_parallel * model_parallel must equal the number of GPUs (the default (2, 4) uses 8)
mesh = Mesh(mesh_utils.create_device_mesh((batch_parallel, model_parallel)), ('batch', 'model'))
```

Note that larger experiments (SCALE >= 6) may OOM if model sharding is
reduced.
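When adapting the mesh to a different GPU count, the two mesh axes must multiply to the number of devices. The helper below is a hypothetical sketch (`mesh_shape` is not part of the repository) that keeps model parallelism as close as possible to the default of 4, since reducing it increases per-GPU memory use. The JAX lines at the end are shown as comments so the helper itself runs without a GPU.

```python
def mesh_shape(num_devices: int, preferred_model_parallel: int = 4) -> tuple[int, int]:
    """Pick a (batch_parallel, model_parallel) mesh shape for `num_devices` GPUs.

    Hypothetical helper: uses the largest model-parallel size <= the preferred
    value that divides the device count, so the two axes multiply to num_devices.
    """
    if num_devices < 1:
        raise ValueError("need at least one device")
    model_parallel = min(preferred_model_parallel, num_devices)
    while num_devices % model_parallel != 0:
        model_parallel -= 1  # always terminates: 1 divides every device count
    return num_devices // model_parallel, model_parallel

# Inside a training script (requires JAX):
#   import jax
#   from jax.experimental import mesh_utils
#   from jax.sharding import Mesh
#   batch_parallel, model_parallel = mesh_shape(jax.device_count())
#   mesh = Mesh(mesh_utils.create_device_mesh((batch_parallel, model_parallel)), ('batch', 'model'))
```

For example, 8 GPUs give `(2, 4)`, 4 GPUs give `(1, 4)`, and 6 GPUs give `(2, 3)`. Depending on how the scripts shard parameters, the model dimensions may need to divide evenly by the model axis; a model axis of 3, for instance, does not divide `embed_dim = 2048`.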

### Model Scales

The SCALE parameter controls model size. The mapping from SCALE to
parameter count is:

| SCALE | Parameters | embed_dim | num_heads | num_layers |
|-------|-----------|-----------|-----------|------------|
| 3     | 96M       | 768       | 12        | 10         |
| 4     | 185M      | 1024      | 16        | 12         |
| 5     | 350M      | 1280      | 20        | 16         |
| 6     | 560M      | 1536      | 24        | 18         |
| 8     | 1B        | 2048      | 32        | 20         |

All models use sequence length 65536 (2^16) and batch size 2.
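The table follows a simple pattern: `embed_dim = 256 * SCALE` and `num_heads = 4 * SCALE`, so every scale uses 64-dimensional heads. As a sanity check, the parameter counts agree with the common estimate `12 * num_layers * embed_dim**2 + vocab * embed_dim`. That estimate assumes a GPT-style block with a 4x MLP, tied input/output embeddings, and a 32,768-token vocabulary, and it ignores norms and biases; these are assumptions for illustration, not values read from the training scripts.

```python
# Assumed vocabulary size for the custom "32k" BPE tokenizer (an assumption).
VOCAB = 32_768

# SCALE: (embed_dim, num_heads, num_layers), copied from the table above.
SCALES = {
    3: (768, 12, 10),
    4: (1024, 16, 12),
    5: (1280, 20, 16),
    6: (1536, 24, 18),
    8: (2048, 32, 20),
}

def approx_params(d: int, layers: int, vocab: int = VOCAB) -> int:
    # Per layer: 4*d^2 for the Q, K, V and output projections,
    # plus 8*d^2 for the two d x 4d MLP matrices; one shared embedding matrix.
    return 12 * layers * d * d + vocab * d

for scale, (d, heads, layers) in SCALES.items():
    print(f"SCALE={scale}: head_dim={d // heads}, ~{approx_params(d, layers) / 1e6:.0f}M params")
```

Under these assumptions the estimates come out near 96M, 185M, 357M, 560M, and 1.07B, i.e. within a few percent of the table.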

### CUDNN Baseline

The CUDNN baseline uses exact attention. For example, to train a 1B model
(SCALE=8) on the code dataset:

```bash
cd fma/examples/minigpt
python exp_cudnn_stack5678_scale8_v2.py
```

Key configuration (set at the top of the script):
- `SCALE = 8` -- controls model size (see table above)
- `maxlen = 2**16` -- sequence length (65536 tokens)
- `batch_size = 2` -- per-step batch size
- `use_fma_attention = False` -- uses exact CUDNN attention
- `dataset = "longstackv2_buckets_5678"` -- training dataset

Smaller models use corresponding scripts
(e.g., `exp_cudnn_stack5678_scale3_v2.py` for 96M parameters).

### MuSe Training

MuSe training scripts differ from the baseline only in the attention
configuration. For example, to train a 1B MuSe model:

```bash
cd fma/examples/minigpt
python exp_fma_stack5678_scale8_v2.py
```

The key differences from the baseline:
- `use_fma_attention = True`
- `fma_Q = 128` -- number of query clusters
- `fma_K = 128` -- number of key clusters
- `fma_num_retrievals = 8` -- number of retrieved blocks per query block
- `fma_blk_size = 2**13` -- block size (8192 tokens)
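Viewed as plain dictionaries (an illustration only; the scripts set these as module-level variables), a MuSe configuration is the baseline configuration with the attention settings overridden. The derived quantities at the end are simple arithmetic on the listed values.

```python
# Baseline settings listed for exp_cudnn_stack5678_scale8_v2.py
BASELINE = dict(
    SCALE=8,
    maxlen=2**16,
    batch_size=2,
    use_fma_attention=False,
    dataset="longstackv2_buckets_5678",
)

# Attention settings that differ in exp_fma_stack5678_scale8_v2.py
MUSE_OVERRIDES = dict(
    use_fma_attention=True,
    fma_Q=128,               # number of query clusters
    fma_K=128,               # number of key clusters
    fma_num_retrievals=8,    # retrieved blocks per query block
    fma_blk_size=2**13,      # block size in tokens
)

muse = {**BASELINE, **MUSE_OVERRIDES}

# Keys whose value differs from (or is absent in) the baseline.
changed = sorted(k for k in muse if BASELINE.get(k) != muse[k])
tokens_per_step = muse["batch_size"] * muse["maxlen"]          # 2 * 65536 = 131072
blocks_per_sequence = muse["maxlen"] // muse["fma_blk_size"]   # 65536 // 8192 = 8
```

Everything outside the attention settings (scale, sequence length, batch size, dataset) is shared, so loss curves from the two scripts can be compared directly.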

The naming convention `exp_fma_*_R8S1_*.py` encodes the MuSe configuration:
R8 = 8 retrievals, S1 = 1 sparsity level. "Stopped" in the name indicates
stopped-gradient training (used in final experiments).

### MoBA Baselines

MoBA baseline training scripts are also included (`exp_moba_*.py`).
These use the MoBA block-sparse attention mechanism from
`fma/moba_retrieval.py` as a comparison point. The naming convention
encodes the MoBA configuration: e.g., `exp_moba_*_B512SP2_*.py` means
block size 512, sparsity 2.

### Experiment Naming

Scripts follow the pattern `exp_{method}_{dataset}_{scale}_{variant}.py`:
- **method**: `cudnn` (baseline), `fma` (MuSe), `moba` (MoBA baseline)
- **dataset**: `stack5678` (code), `science2e16` (science/PDF)
- **scale**: `scale3` (96M) through `scale8` (1B)
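When sorting through many runs it can help to decode these names programmatically. The sketch below is a hypothetical helper (not part of the repository). It recognizes the tokens described above (`scaleN`, `R8S1`, `B512SP2`, and "Stopped"), but because the exact position of those tokens in real file names may vary, it scans every token after the dataset rather than relying on fixed positions.

```python
import re
from dataclasses import dataclass
from typing import Optional

SCALE_TO_PARAMS = {3: "96M", 4: "185M", 5: "350M", 6: "560M", 8: "1B"}

@dataclass
class ExperimentName:
    method: str                         # "cudnn", "fma" (MuSe), or "moba"
    dataset: str                        # e.g. "stack5678" or "science2e16"
    scale: Optional[int] = None         # from a "scaleN" token
    retrievals: Optional[int] = None    # from an "R<n>S<m>" token (MuSe)
    sparsity: Optional[int] = None      # from "R<n>S<m>" (MuSe) or "B<b>SP<s>" (MoBA)
    moba_block: Optional[int] = None    # from a "B<b>SP<s>" token (MoBA)
    stopped: bool = False               # "Stopped" anywhere in the name

def parse_experiment_name(filename: str) -> ExperimentName:
    stem = filename.removesuffix(".py")
    parts = stem.split("_")
    if len(parts) < 3 or parts[0] != "exp":
        raise ValueError(f"not an exp_* script name: {filename}")
    name = ExperimentName(method=parts[1], dataset=parts[2])
    for tok in parts[3:]:
        if m := re.fullmatch(r"scale(\d+)", tok):
            name.scale = int(m.group(1))
        elif m := re.search(r"R(\d+)S(\d+)", tok):
            name.retrievals, name.sparsity = int(m.group(1)), int(m.group(2))
        elif m := re.search(r"B(\d+)SP(\d+)", tok):
            name.moba_block, name.sparsity = int(m.group(1)), int(m.group(2))
    name.stopped = "stopped" in stem.lower()
    return name
```

For example, `parse_experiment_name("exp_cudnn_stack5678_scale8_v2.py")` yields method `cudnn`, dataset `stack5678`, and scale 8, which `SCALE_TO_PARAMS` maps to `"1B"`.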

### Container Environment

The training scripts were developed in a container environment with two
mount points:

- `/aim` -- Aim experiment tracking repository
- `/checkpoints` -- checkpoint storage

If you are not running in a container, modify these paths in the training
scripts:

1. **Checkpoints**: near the top of each `exp_*.py` script, change
   `checkpoint_save_path`:
   ```python
   checkpoint_save_path = f"/checkpoints/..."  # change to a local path
   ```

2. **Aim**: the `Run(repo="/aim", ...)` calls should be updated to point
   to your Aim repository (see below). Alternatively, comment out the
   Aim `Run(...)` calls to run without tracking.

### Experiment Tracking with Aim

Training scripts log metrics (loss, learning rate, throughput) to
[Aim](https://github.com/aimhubio/aim), an open-source experiment tracker.

To set up Aim:

```bash
pip install aim  # already included in fma dependencies

# Initialize a new Aim repository
mkdir -p /path/to/aim-repo
cd /path/to/aim-repo
aim init

# Launch the Aim UI (optional, for browsing results)
aim up
```

Then update the `Run(repo=...)` path in the training scripts to point to
your Aim repository. Aim is not required for training -- you can comment
out the Aim-related code and the scripts will still run and print metrics
to stdout.
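Instead of commenting Aim out by hand, one option is to make tracking switchable. The sketch below is a hypothetical helper, not code from the training scripts: `make_tracker` and the `AIM_REPO` environment variable are illustrative names. It logs through Aim's `Run.track` when `AIM_REPO` is set and otherwise prints to stdout, and it imports Aim lazily so the package is only needed when tracking is on.

```python
import os

def make_tracker(experiment: str):
    """Return a track(value, name, step) function.

    Hypothetical helper: logs to Aim when the AIM_REPO environment variable is
    set, and otherwise prints metrics to stdout so training runs without Aim.
    """
    repo = os.environ.get("AIM_REPO")
    if repo:
        from aim import Run  # imported lazily so Aim stays optional

        run = Run(repo=repo, experiment=experiment)

        def track(value, name: str, step: int) -> None:
            run.track(float(value), name=name, step=step)
    else:

        def track(value, name: str, step: int) -> None:
            print(f"step {step}: {name}={float(value):.4f}")

    return track
```

To use it, you would edit a script to call `track = make_tracker("muse_scale8")` once and then `track(loss, "loss", step)` inside the loop, in place of the direct `Run(repo="/aim", ...)` usage, and launch with `AIM_REPO=/path/to/aim-repo` to enable Aim.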

## Llama Evaluation (Loss vs Position)

This evaluates MuSe as a drop-in attention replacement on pretrained
Llama 3.2 1B, measuring per-position cross-entropy loss on PG-19.

### 1. Convert Llama Weights

Download Llama 3.2 1B from HuggingFace (requires access approval) and
convert to Flax format:

```bash
cd fma-flax-llama

python scripts/convert_weights.py \
    --model_name meta-llama/Llama-3.2-1B \
    --output_dir checkpoints/llama-3.2-1b-flax \
    --use_auth_token
```

### 2. Run Loss-vs-Position Comparison

The `compare_fma_configs.py` script evaluates standard attention and
various MuSe configurations on PG-19 test data:

```bash
python scripts/compare_fma_configs.py \
    --checkpoint_dir checkpoints/llama-3.2-1b-flax \
    --max_length 65536 \
    --num_samples 100 \
    --output loss_by_position.png
```

MuSe configurations are defined in the `FMA_CONFIGS` list at the top of the
script. Edit this list to test different block sizes, cluster counts, and
retrieval counts.
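The metric itself is the next-token cross-entropy at each position, averaged over samples. A minimal NumPy sketch is shown below; it is an illustration, not the script's implementation, and it assumes `targets[:, t]` already holds the token that follows position `t`. At 65536 positions with Llama 3's 128k vocabulary, full logits are far too large to hold for many samples at once, so in practice you would compute per-sample losses and accumulate them.

```python
import numpy as np

def per_position_loss(logits: np.ndarray, targets: np.ndarray) -> np.ndarray:
    """Mean next-token cross-entropy at each position.

    logits: [num_samples, seq_len, vocab] floats; targets: [num_samples, seq_len] ints.
    Returns a [seq_len] array averaged over samples.
    """
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Negative log-probability of the correct next token at each position.
    nll = -np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    return nll.mean(axis=0)

def bucket_mean(loss: np.ndarray, bucket: int) -> np.ndarray:
    """Average consecutive positions in windows of `bucket` tokens (smooths the curve for plotting)."""
    n = len(loss) // bucket * bucket
    return loss[:n].reshape(-1, bucket).mean(axis=1)
```

With uniform logits every position scores `log(vocab)`, which is a quick way to check the plumbing before loading real checkpoints.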

## Dataset Preparation

### Data Format

Training scripts expect pretokenized data in
`fma/examples/minigpt/data/{dataset_name}/` with the following files:

```
data/{dataset_name}/
├── train.bin        # Flat array of uint16 token IDs
├── test.bin         # Same format, for evaluation
└── tokenizer.json   # HuggingFace tokenizers JSON file
```

The `.bin` files are simply concatenated token IDs stored as uint16
(via `np.array(..., dtype=np.uint16).tofile(...)`). `tofile` writes in the
machine's native byte order, which is little-endian on x86-64 and typical
ARM systems. The training code memory-maps these files and slices them
into fixed-length windows of `maxlen` tokens. No explicit sequence
boundary information is needed at training time.
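A minimal reader for this format is sketched below. It is a hypothetical class, not the `MemMapDataset` used by the training scripts, and it assumes non-overlapping windows; the real class may sample windows differently. `np.memmap` maps the file lazily, so only the requested window is read from disk.

```python
import numpy as np

class TokenWindows:
    """Hypothetical sketch: serve fixed-length windows from a flat token file."""

    def __init__(self, path: str, maxlen: int, dtype=np.uint16):
        # Shape is inferred from the file size; nothing is loaded yet.
        self.tokens = np.memmap(path, dtype=dtype, mode="r")
        self.maxlen = maxlen

    def __len__(self) -> int:
        # Number of complete, non-overlapping windows (the tail is dropped).
        return len(self.tokens) // self.maxlen

    def __getitem__(self, i: int) -> np.ndarray:
        if not 0 <= i < len(self):
            raise IndexError(i)
        start = i * self.maxlen
        # Copy the window out of the memmap and widen to int32 for use as indices.
        return np.asarray(self.tokens[start:start + self.maxlen], dtype=np.int32)
```

For example, `TokenWindows("data/my_dataset/train.bin", maxlen=2**16)` would yield 65536-token windows, matching the `maxlen` used by all training scripts.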

**Any long-context corpus will work.** The specific dataset does not
matter for verifying that MuSe matches the baseline -- the key requirement
is that documents are long enough to exercise the far-field attention
(i.e., significantly longer than the block size of 8192 tokens).

To use your own data:

1. **Choose a tokenizer.** Any tokenizer will work. Our experiments use
   custom 32k BPE tokenizers trained with `long_stackv2/train_tokenizer.py`,
   but an off-the-shelf tokenizer (e.g., GPT-2, Llama) is fine. The
   tokenizer must be saved as a HuggingFace `tokenizers` JSON file
   (`tokenizer.json`). The training scripts auto-detect `vocab_size`
   from this file.

2. **Pretokenize your data** into the binary format:
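   ```python
   import numpy as np
   from tokenizers import Tokenizer

   tokenizer = Tokenizer.from_file("tokenizer.json")
   # uint16 only holds IDs below 65536; drop this check if you switch to uint32
   assert tokenizer.get_vocab_size() <= 2**16, "vocabulary too large for uint16"
   tokens = []
   for doc in your_documents:  # your_documents: any iterable of document strings
       tokens.extend(tokenizer.encode(doc).ids)
   np.array(tokens, dtype=np.uint16).tofile("train.bin")
   ```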
   The uint16 format supports vocabularies up to 65536 tokens. If your
   tokenizer has a larger vocabulary (e.g., Llama 3's 128k), change the
   dtype to `np.uint32` here and in the `MemMapDataset` class in the
   training script.

3. **Place the files** in `fma/examples/minigpt/data/{dataset_name}/`
   with `train.bin`, `test.bin`, and `tokenizer.json`.
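The steps above only write `train.bin`, but the training scripts also expect `test.bin`. The sketch below is a hypothetical helper (`write_splits` is not part of the repository) that writes both files in one pass. It assigns each document to a split by hashing its text, so the split is reproducible, and picks the token dtype from the vocabulary size. The encoder is passed in as a function so any tokenizer can be used.

```python
import hashlib
import os
from typing import Callable, Iterable

import numpy as np

def write_splits(
    docs: Iterable[str],
    encode: Callable[[str], list[int]],
    vocab_size: int,
    out_dir: str,
    test_fraction: float = 0.01,
) -> type:
    """Tokenize `docs` and write train.bin / test.bin into `out_dir`; returns the dtype used."""
    # uint16 holds IDs 0..65535; larger vocabularies need uint32 (and a matching reader).
    dtype = np.uint16 if vocab_size <= 2**16 else np.uint32
    splits: dict[str, list[int]] = {"train": [], "test": []}
    for doc in docs:
        # Deterministic split: the same document always lands in the same file.
        bucket = int(hashlib.sha256(doc.encode("utf-8")).hexdigest(), 16) % 10_000
        split = "test" if bucket < test_fraction * 10_000 else "train"
        splits[split].extend(encode(doc))
    os.makedirs(out_dir, exist_ok=True)
    for split, ids in splits.items():
        np.array(ids, dtype=dtype).tofile(os.path.join(out_dir, f"{split}.bin"))
    return dtype
```

With a HuggingFace tokenizer this might be called as `write_splits(docs, lambda s: tok.encode(s).ids, tok.get_vocab_size(), "fma/examples/minigpt/data/my_dataset")`, where `my_dataset` is a placeholder name. It keeps all token IDs in memory, so very large corpora would need to be written in chunks instead.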

### Reconstructing Our Datasets

The scripts in `long_stackv2/` reproduce the specific datasets used in
the paper. This is not necessary to verify the method -- any
long-context corpus in the format described above will work.

#### Code Dataset (The Stack v2)

The code training data is derived from
[The Stack v2](https://huggingface.co/datasets/bigcode/the-stack-v2-smol)
(BigCode), using long-context repository buckets (16k--256k+ tokens).
The pipeline is:

1. `make-stack-v2-smol-ids-bucketed.py` -- download and bucket
   repositories by total token count
2. `generate_blob_lists.py` -- generate lists of source files to download
3. `fetch_sample_files_s3.py` -- download source files from Software
   Heritage S3
4. `filter_english_files.py` -- filter to English-language source files
5. `construct_repo_sequences.py` -- assemble repository-level sequences
   with StarCoder formatting
6. `train_tokenizer.py` -- train a 32k BPE tokenizer
7. `prepare_pretokenized_dataset.py` -- tokenize and write to binary
   format (`bucket_XX_train.bin` as uint16)

The per-bucket binary files should then be concatenated into a single
`train.bin` and placed in `fma/examples/minigpt/data/longstackv2_buckets_5678/`.
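Because each per-bucket file is a headerless array of the same dtype (uint16, as written by step 7), concatenating the files byte-for-byte is the same as concatenating the token arrays. A minimal sketch follows; `concat_token_files` is a hypothetical helper, and it streams the files so no bucket has to fit in memory.

```python
import os
import shutil
from typing import Iterable

def concat_token_files(parts: Iterable[str], out_path: str, itemsize: int = 2) -> int:
    """Concatenate flat token files byte-for-byte; returns the total token count.

    itemsize is the bytes per token (2 for uint16, 4 for uint32).
    """
    total_bytes = 0
    with open(out_path, "wb") as out:
        for part in parts:
            size = os.path.getsize(part)
            if size % itemsize:
                raise ValueError(f"{part}: size {size} is not a multiple of {itemsize} bytes")
            with open(part, "rb") as f:
                shutil.copyfileobj(f, out)  # streamed copy, constant memory
            total_bytes += size
    return total_bytes // itemsize
```

Pass the per-bucket files in a fixed order (for example, a sorted list), and make sure the output path is not itself one of the inputs.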

#### Science Dataset

The science/PDF training data preparation scripts are in
`long_stackv2/longmino_pdfs/`. The pipeline downloads and parses
scientific PDFs, then pretokenizes them with
`longmino_pdfs/pretokenize_dataset.py` into the same binary format.
