# Don't Pay Attention (DPA)

This repository includes the code used to train models, run benchmarks, and create plots for the paper Don't Pay Attention.

All model checkpoints trained in this work will be released on Hugging Face, but they have been omitted from this submission to preserve anonymity.

## Requirements

This code was tested on Ubuntu 22.04 with Python 3.12 on NVIDIA A100, H100, and H200 GPUs. We recommend running the setup and code in a clean Python environment. Make sure the CUDA toolkit is installed correctly.

To install the dependencies, run `source setup.sh`.

Either log in to Weights & Biases if you want to log training metrics:
```bash
wandb login
```
or disable it:
```bash
wandb disabled
```
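As an alternative, the `wandb` client also reads the `WANDB_MODE` environment variable. You can use it to turn logging off for a single shell session without changing the persistent setting. This sketch assumes the training scripts use the standard `wandb` client:

```bash
# Disable Weights & Biases logging for this shell session only.
export WANDB_MODE=disabled

# Alternatively, keep metrics locally and upload them later with `wandb sync`:
# export WANDB_MODE=offline
```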

To save a Hugging Face-compatible Avey tokenizer, run `python make_hf_tokenizer.py`.

### Avey

Note: as mentioned above, the model checkpoint is not included in this submission.

```bash
export MODEL_NAME=avey
export MODEL_PATH=avey1-dpa-1.5B-100BT
```

### Mamba

Install the Mamba-specific dependencies:
```bash
sh mamba/setup.sh
```

Set the model name and path:
```bash
export MODEL_NAME=mamba
export MODEL_PATH=mamba-dpa-1.5B-100BT
```

### RWKV-7

```bash
source rwkv7/env.sh
export MODEL_NAME=rwkv7
export MODEL_PATH=rwkv7-dpa-1.5B-100BT
```

### Transformers

```bash
export MODEL_NAME=tpp
export MODEL_PATH=tpp-dpa-1.5B-100BT
```
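To switch between models without retyping the exports, you can wrap the four pairs of values listed above in a small shell function. `select_model` is a hypothetical helper and is not part of this repository. The Mamba and RWKV-7 setup steps above are still required. The directory check assumes `MODEL_PATH` is a local checkpoint directory, so drop it if the path is resolved some other way:

```bash
# Hypothetical helper: export MODEL_NAME and MODEL_PATH for one of the four
# models, using the names and paths listed in this README.
select_model() {
  case "$1" in
    avey)  MODEL_PATH=avey1-dpa-1.5B-100BT ;;
    mamba) MODEL_PATH=mamba-dpa-1.5B-100BT ;;
    rwkv7) MODEL_PATH=rwkv7-dpa-1.5B-100BT ;;
    tpp)   MODEL_PATH=tpp-dpa-1.5B-100BT ;;
    *)
      echo "unknown model: $1 (expected avey, mamba, rwkv7 or tpp)" >&2
      return 1
      ;;
  esac
  export MODEL_NAME="$1" MODEL_PATH
  # Checkpoints are not bundled, so warn if MODEL_PATH is not a local directory.
  if [ ! -d "$MODEL_PATH" ]; then
    echo "warning: '$MODEL_PATH' is not a local directory" >&2
  fi
}
```

For example, `select_model rwkv7` (after `source rwkv7/env.sh`) sets both variables for the RWKV-7 checkpoint.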

## Training

In `train.sh`, set `NUMBER_OF_GPUS` to the number of GPUs to use on a single node. Set `BATCH_SIZE` by starting at 1 and increasing it until your GPU runs out of memory, then keep the largest value that fits. Then run:

```bash
sh train.sh
```
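`NUMBER_OF_GPUS` should not exceed the number of GPUs on the node. The following sketch uses standard `nvidia-smi` options to check that count and to watch memory while you tune `BATCH_SIZE`. `NGPUS` is an illustrative variable name, and this README does not say whether `train.sh` could read it directly:

```bash
# Count the GPUs on this node: `nvidia-smi -L` prints one line per GPU.
# Note that it lists all physical GPUs and ignores CUDA_VISIBLE_DEVICES.
if command -v nvidia-smi >/dev/null 2>&1; then
  NGPUS=$(nvidia-smi -L | wc -l)
else
  NGPUS=0
fi
echo "GPUs on this node: $NGPUS"

# While tuning BATCH_SIZE, run this in a second terminal to watch memory use
# (refreshes every 5 seconds; press Ctrl-C to stop):
#   nvidia-smi --query-gpu=index,memory.used,memory.total --format=csv -l 5
```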

## Benchmarks

For the standard benchmarks reported in the paper, run:

```bash
sh eval.sh
```

For the RULER single needle-in-a-haystack (S-NIAH) tasks, run:
```bash
sh eval-long.sh
```
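To benchmark all four models in one pass, you can loop over the names and paths from the Requirements section. This sketch assumes that `eval.sh` and `eval-long.sh` read `MODEL_NAME` and `MODEL_PATH` from the environment, and that the Mamba and RWKV-7 setup steps have already been run in the same shell. The log file names are arbitrary:

```bash
# Each entry is NAME:PATH, taken from the Requirements section.
for spec in avey:avey1-dpa-1.5B-100BT \
            mamba:mamba-dpa-1.5B-100BT \
            rwkv7:rwkv7-dpa-1.5B-100BT \
            tpp:tpp-dpa-1.5B-100BT; do
  export MODEL_NAME="${spec%%:*}"   # text before the first ':'
  export MODEL_PATH="${spec#*:}"    # text after the first ':'
  for script in eval.sh eval-long.sh; do
    # Print to the console and keep a per-model, per-benchmark log,
    # e.g. eval-avey.log and eval-long-avey.log.
    sh "$script" 2>&1 | tee "${script%.sh}-${MODEL_NAME}.log"
  done
done
```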

## Plots

To plot the NIAH heatmap (Figure 1 in the paper), run:
```bash
sh plot-niah.sh
```

To plot time to first token (TTFT) versus context length (Figure 4 in the paper), run:

> [!IMPORTANT]
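> Make sure you have already completed the setup steps for Mamba (`sh mamba/setup.sh`) and RWKV-7 (`source rwkv7/env.sh`). Because `rwkv7/env.sh` is sourced, its effects apply only to the current shell, so run it in the same session as `plot_ttft.py`.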

```bash
python3 plot_ttft.py
```
