# Test-time training on nearest neighbors (TTT-NN)

The repository contains code to:

1. Train the embedding model
2. Build the Pile nearest neighbor index
3. Run distributed servers on top of the index
4. Query the servers
5. Evaluate test-time training with nearest neighbors on the Pile
6. Run baselines

## Necessary files

To evaluate TTT-NN, you ultimately need the following directory structure:

```
indexes/
  roberta-large/
    00.jsonl.index
    01.jsonl.index
    ...
    29.jsonl.index

models/
  roberta-large-pile-lr2e-5-bs16-8gpu/
    checkpoint-1700000/

pile/
  train/
    00.jsonl
    01.jsonl
    ...
    29.jsonl
  val.jsonl
  test.jsonl

servers/
  addresses.txt
```
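Before launching anything, it can help to confirm that the layout above is in place. The following is a minimal sketch; `missing_paths` is a hypothetical helper (not part of the repository) that hard-codes the 30 shards `00` to `29` and the paths listed above.

```python
from pathlib import Path


# Hypothetical helper: returns the paths from the expected layout that are
# missing under `root`. Assumes 30 Pile shards named 00 .. 29.
def missing_paths(root="."):
    root = Path(root)
    shards = [f"{i:02d}" for i in range(30)]
    expected = (
        [root / "indexes" / "roberta-large" / f"{s}.jsonl.index" for s in shards]
        + [root / "pile" / "train" / f"{s}.jsonl" for s in shards]
        + [
            root / "models" / "roberta-large-pile-lr2e-5-bs16-8gpu" / "checkpoint-1700000",
            root / "pile" / "val.jsonl",
            root / "pile" / "test.jsonl",
            root / "servers" / "addresses.txt",
        ]
    )
    # Keep only the entries that do not exist on disk yet.
    return [p for p in expected if not p.exists()]
```

Calling `missing_paths()` from the repository root returns an empty list once everything is in place. Note that `servers/addresses.txt` only appears after the servers have been started.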

Download the dataset [here](https://the-eye.eu/public/AI/pile/) and place the files in the `pile/` subdirectory.

## Train the embedding model

See `code/trainer_lm.py`. This is a standard HuggingFace training setup.
This code was used to produce the model checkpoint `checkpoint-1700000` in the `models/` directory.
The model was trained for approximately one month on 8 A100 GPUs, making one pass over the data.

Make sure to have the checkpoint `models/roberta-large-pile-lr2e-5-bs16-8gpu/checkpoint-1700000` before you proceed.

## Building the index

Use the code in `code/build_database.py` to build a nearest neighbor index on top of the Pile dataset. This is a time-consuming operation.
Specify `--data_file` to build the index for a given data file:

```
/usr/bin/python3 code/build_database.py \
                 --data_file pile/train/00.jsonl \
                 --output_dir indexes/roberta-large
```

Run this command once for each training data file (`00.jsonl` through `29.jsonl`).
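The per-file runs can be scripted. This is a minimal sketch that reuses only the two flags shown above and assumes it is run from the repository root; `build_commands` and `build_all` are hypothetical names.

```python
import subprocess


# Hypothetical helper: one build_database.py invocation per training shard.
def build_commands(python="/usr/bin/python3", num_shards=30):
    return [
        [python, "code/build_database.py",
         "--data_file", f"pile/train/{i:02d}.jsonl",
         "--output_dir", "indexes/roberta-large"]
        for i in range(num_shards)
    ]


def build_all(dry_run=True):
    for cmd in build_commands():
        if dry_run:
            # Only print the command so it can be inspected or dispatched elsewhere.
            print(" ".join(cmd))
        else:
            # Runs shards one after another; check=True stops at the first failure.
            subprocess.run(cmd, check=True)
```

Because each build is slow, running shards sequentially on one machine may be impractical. The dry-run output can instead be split across several machines or jobs.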

## Running the Pile server

The following command launches a server with 6 replicas, each serving one split of the data file. It appends 6 IP addresses and ports to the file given by `--address_path`.

```
python3 code/pile_server.py \
        --address_path servers/addresses.txt \
        --data_file pile/train/00.jsonl \
        --num_servers 6
```

To serve all Pile data files, start one server per data file.
We recommend starting 30 servers with 6 replicas each, for a total of 180 running instances.
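The recommendation above can be scripted as well. The following sketch starts one `pile_server.py` process per shard in the background. It assumes that all 30 shards are served from the machine running it and that the replica flag is `--num_servers`. `server_commands` and `launch_all` are hypothetical names.

```python
import subprocess


# Hypothetical helper: one pile_server.py invocation per training shard.
# 30 shards x 6 replicas = 180 server instances writing to addresses.txt.
def server_commands(num_shards=30, replicas=6):
    return [
        ["python3", "code/pile_server.py",
         "--address_path", "servers/addresses.txt",
         "--data_file", f"pile/train/{i:02d}.jsonl",
         "--num_servers", str(replicas)]
        for i in range(num_shards)
    ]


def launch_all():
    # Popen returns immediately, so all servers start concurrently.
    return [subprocess.Popen(cmd) for cmd in server_commands()]
```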

Make sure servers are up and running before launching evaluation.
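One way to check readiness is to try a TCP connection to every registered address. This sketch assumes that `servers/addresses.txt` holds one `host:port` entry per line. That format is an assumption, so check `code/pile_server.py` for what it actually writes. `read_addresses` and `unreachable` are hypothetical helpers.

```python
import socket


# Assumes one "host:port" entry per line; blank lines are skipped.
def read_addresses(path):
    addresses = []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line:
                host, port = line.rsplit(":", 1)
                addresses.append((host, int(port)))
    return addresses


# A server counts as up if a TCP connection to its port succeeds.
def unreachable(addresses, timeout=2.0):
    down = []
    for host, port in addresses:
        try:
            with socket.create_connection((host, port), timeout=timeout):
                pass
        except OSError:
            down.append((host, port))
    return down
```

With the recommended setup, `read_addresses("servers/addresses.txt")` should return 180 entries, and `unreachable(...)` should return an empty list before evaluation starts.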

## Using the Pile client

Use `code/pile_client.py` to query the servers. Specify `--address_path` to indicate which servers to query: the client queries every server listed in that file. It then builds a local nearest neighbors structure to find the nearest neighbors among all the retrieved results.
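The merge step can be pictured as follows. In this simplified sketch, each server returns its own candidates as `(embedding, text)` pairs. The client ranks the union of all candidates by squared L2 distance to the query embedding and keeps the global top `k`. Brute-force search stands in for whatever local structure `pile_client.py` actually builds, and `merge_neighbors` is a hypothetical name.

```python
import heapq


def squared_l2(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


# Rank the union of all per-server candidates by distance to the query
# and return the texts of the k closest ones, nearest first.
def merge_neighbors(query, per_server_results, k):
    candidates = (item for results in per_server_results for item in results)
    best = heapq.nsmallest(k, candidates, key=lambda item: squared_l2(query, item[0]))
    return [text for _, text in best]
```

For example, with `server_a = [([0.0, 1.0], "a1"), ([5.0, 5.0], "a2")]` and `server_b = [([0.1, 0.9], "b1"), ([3.0, 3.0], "b2")]`, the call `merge_neighbors([0.0, 1.0], [server_a, server_b], k=2)` returns `["a1", "b1"]`.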

The client can be used standalone, but it is also called from the evaluation code.

## Running test time training with nearest neighbors

To evaluate on GPT-Neo with default parameters:

```
/usr/bin/python3 code/eval_tttlm.py \
                 --address_path servers/addresses.txt \
                 --results_dir results/
```
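Conceptually, test-time training on nearest neighbors fine-tunes the model on each test instance's retrieved neighbors, scores that instance, and then resets the weights. The toy below illustrates that loop and is not the repository's code. The "model" is a single parameter `w` with squared-error loss `(w - x) ** 2`, and `retrieve` is a hypothetical stand-in for the Pile client.

```python
def loss(w, x):
    return (w - x) ** 2


def sgd_step(w, x, lr):
    # Gradient of (w - x)^2 with respect to w is 2 * (w - x).
    return w - lr * 2 * (w - x)


def ttt_nn_eval(base_w, test_set, retrieve, lr=0.1):
    losses = []
    for x in test_set:
        w = base_w  # start every test instance from the original weights
        for neighbor in retrieve(x):
            w = sgd_step(w, neighbor, lr)  # one gradient step per neighbor
        losses.append(loss(w, x))  # score the instance, then discard the update
    return losses
```

With `base_w=0.0`, `lr=0.25`, and two neighbors equal to the test value, the loss on `1.0` drops from `1.0` without training to `0.0625`.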

To evaluate on GPT2:

```
/usr/bin/python3 code/eval_tttlm.py \
                 --model gpt2 \
                 --tokenizer gpt2 \
                 --embedding_model_checkpoint models/roberta-large-pile-lr2e-5-bs16-8gpu/checkpoint-1700000 \
                 --max_length 1024 \
                 --stride 1024 \
                 --learning_rate 2e-5 \
                 --address_path servers/addresses.txt \
                 --results_dir results/
```

Replace `gpt2` with `gpt2-large` in `--model` and `--tokenizer` to evaluate on GPT2-Large.
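The `--max_length` and `--stride` flags in the GPT2 command follow the usual sliding-window convention for scoring long documents. This sketch assumes `eval_tttlm.py` uses that convention, which is not confirmed here. Each window holds at most `max_length` tokens, a new window starts every `stride` tokens, and only tokens that no earlier window has scored count toward the loss. `windows` is a hypothetical helper.

```python
# Returns (begin, end, scored) for each window over a document of num_tokens
# tokens; `scored` is the number of new tokens whose loss this window adds.
def windows(num_tokens, max_length, stride):
    spans = []
    prev_end = 0
    for begin in range(0, num_tokens, stride):
        end = min(begin + max_length, num_tokens)
        spans.append((begin, end, end - prev_end))
        prev_end = end
        if end == num_tokens:
            break
    return spans
```

With `stride == max_length == 1024`, as in the GPT2 command, the windows do not overlap: a 2500-token document yields `[(0, 1024, 1024), (1024, 2048, 1024), (2048, 2500, 452)]`. A smaller stride gives each scored token more preceding context at the cost of more forward passes.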

The evaluation code requires EleutherAI's `lm_eval` package (the LM Evaluation Harness).

## Run baselines

* See `code/baseline_context.py` for the in-context baseline.
* See `code/baseline_interpolation.py` for the interpolation baseline.
* Run `code/eval_tttlm.py` with the option `--dynamic_eval` for the dynamic evaluation baseline.
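Retrieval interpolation baselines usually mix the language model's next-token distribution with one derived from retrieved neighbors, as in kNN-LM: `p = lam * p_knn + (1 - lam) * p_lm`. Whether `baseline_interpolation.py` computes exactly this is an assumption. The sketch only illustrates the mixing rule, using dictionaries as toy distributions. `interpolate` and `nll` are hypothetical names.

```python
import math


# Mix two next-token distributions given as {token: probability} dicts.
# Tokens missing from one distribution are treated as having probability 0 there.
def interpolate(p_lm, p_knn, lam):
    vocab = set(p_lm) | set(p_knn)
    return {t: lam * p_knn.get(t, 0.0) + (1 - lam) * p_lm.get(t, 0.0) for t in vocab}


# Negative log-likelihood of the observed token under distribution p.
def nll(p, token):
    return -math.log(p[token])
```

Mixing `{"a": 0.5, "b": 0.5}` with `{"a": 1.0}` at `lam=0.5` gives `a` probability `0.75`, which lowers its negative log-likelihood from about 0.693 to about 0.288.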
