# Model Arithmetic

This repo contains the code for model arithmetic, a comprehensive framework where arithmetic formulas express combinations of LMs and classifiers, thereby biasing the generated text towards or away from desired attributes.

To install model arithmetic, run

```sh
python -m pip install -e .
```

## LM Evaluation Harness

Model arithmetic is compatible with the [LM Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness). To run benchmarks from the harness, install the package as described [on its GitHub page](https://github.com/EleutherAI/lm-evaluation-harness). An example of how to use our tool with the harness is shown in `scripts/evaluate_lm_eval.py`.

## Reproducing results

To reproduce the results presented in our paper, *Controlled Text Generation via Language Model Arithmetic*, we advise running the code in the exact environment we used (an Nvidia H100 80GB GPU on a Linux machine). To set up this environment, install [Conda](https://docs.conda.io/projects/miniconda/en/latest/) and run

```sh
conda create -n model_arithmetic python=3.10
conda activate model_arithmetic
python -m pip install -r requirements.txt
python -m pip install -e .
```

Additionally, install the spaCy `en_core_web_sm` package:
```sh
python -m spacy download en_core_web_sm
```
You also need to download the following datasets and place them in the `data/datasets` folder:
- [Alpaca Data](https://github.com/tloen/alpaca-lora/blob/main/alpaca_data.json)
- [Jigsaw Toxicity Dataset](https://www.kaggle.com/c/jigsaw-unintended-bias-in-toxicity-classification/data) (download the `all_data.csv` file and extract it into the `data/datasets` folder)
- [Politically Incorrect 4chan Messages](https://zenodo.org/record/3606810) (unzip the file and place it in the top level of the `data/datasets` folder)
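As an optional sanity check before running the pipeline, the sketch below tests whether the dataset files whose names appear above are present. `check_datasets` is a hypothetical helper, not part of this repo. It assumes the Alpaca file keeps the name `alpaca_data.json` from its download URL, and it does not check the 4chan dataset because the names of its extracted files are not listed here.

```sh
# Hypothetical helper (not part of this repo): check that the dataset files
# named in this README exist in the given folder (default: data/datasets).
check_datasets() {
  local dir="${1:-data/datasets}"
  local missing=0
  # Assumed file names: alpaca_data.json (as in the download URL) and
  # all_data.csv (from the Jigsaw download). The 4chan files are not checked.
  for f in alpaca_data.json all_data.csv; do
    if [ ! -f "$dir/$f" ]; then
      echo "missing: $dir/$f" >&2
      missing=1
    fi
  done
  # Exit status 0 if everything was found, 1 otherwise.
  return "$missing"
}
```

For example, `check_datasets && echo "datasets found"` prints a confirmation only when both files are in place, and lists any missing file on stderr otherwise.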

Next, place an API key for the [Perspective API](https://perspectiveapi.com/) in the file `src/.env` as follows:

```sh
PERSPECTIVE_API_KEY="[YOUR API KEY]"
```
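If you prefer not to edit the file by hand, the following sketch writes the key in the format shown above. `write_perspective_env` is a hypothetical helper, not part of this repo. Note that it overwrites any existing `src/.env`, so merge by hand if that file already holds other variables.

```sh
# Hypothetical helper (not part of this repo): write the Perspective API key
# to an env file (default: src/.env). Overwrites the file if it already exists.
write_perspective_env() {
  local key="$1"
  local env_file="${2:-src/.env}"
  mkdir -p "$(dirname "$env_file")"
  # Restrictive umask so a newly created file is readable only by its owner.
  (umask 077 && printf 'PERSPECTIVE_API_KEY="%s"\n' "$key" > "$env_file")
  # Also tighten permissions in case the file already existed.
  chmod 600 "$env_file"
}
```

In bash, reading the key with `read -rs KEY` and then calling `write_perspective_env "$KEY"` keeps the key out of your shell history.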

Finally, you can reproduce the results using

```sh
bash scripts/main.sh
```
This preprocesses the data, finetunes a toxicity classifier, and reproduces the results of all sections of the paper. The results are written as CSV files to the `processed` folder, and the figures to the `plots` folder.

Note that part of our preprocessing code, specifically the part used for finetuning the classifier, was lost. Running the pipeline as is may therefore produce slightly different numbers for results that involve the finetuned classifier. To address this, you can adjust `scripts/main.sh` so that its first line is

```sh
python scripts/preprocess.py --reproduction
```

This manual fix selects exactly the samples we used, in the exact order we used them.
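The change to `scripts/main.sh` can also be scripted. The sketch below is a hypothetical helper, `enable_reproduction_mode`, not part of this repo. It assumes the first line of the script is currently the plain call `python scripts/preprocess.py`: it appends `--reproduction` only in that case and otherwise asks you to edit the line by hand. It relies on GNU `sed -i`, which matches the Linux setup described above.

```sh
# Hypothetical helper (not part of this repo): add --reproduction to the
# first line of the pipeline script (default: scripts/main.sh).
enable_reproduction_mode() {
  local script="${1:-scripts/main.sh}"
  local target='python scripts/preprocess.py --reproduction'
  # Rewrite line 1 only if it is exactly the plain preprocessing call
  # (assumed original form). Running this twice leaves the file unchanged.
  sed -i '1s|^python scripts/preprocess\.py$|python scripts/preprocess.py --reproduction|' "$script"
  # Confirm that line 1 now matches the required command exactly.
  if head -n 1 "$script" | grep -qxF "$target"; then
    echo "first line of $script: $target"
  else
    echo "line 1 of $script was not changed; edit it by hand" >&2
    return 1
  fi
}
```

After `enable_reproduction_mode` succeeds, `bash scripts/main.sh` runs the pipeline with the fixed sample selection.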