# Self-Proving Models

Self-Proving Models prove the correctness of their outputs to a verifier using an Interactive Proof System.
This repository includes tools to train these models, specifically for the Greatest Common Divisor (GCD) problem,
based on the theory described in our paper.

This repository provides:
- A straightforward framework for training a Self-Proving GPT.
- Data generation scripts for the GCD problem.
- Scripts for reproducing experiments from the paper.


## Setup

1. Create a new conda environment (recommended):
```bash
conda create -n spm python=3.12
conda activate spm
```
2. Install the package:
```bash
cd self-proving-models
pip install -e .
```

## Data
You can download pre-generated data ([anonymized link](https://zenodo.org/records/13855544)) and extract it to the `data/` directory.
To generate this data yourself, run:
```bash
python spm/data/generate_data.py
```

This should take about 95 minutes, and will populate the `data/` directory with Transcripts and Annotated Transcripts
of interactions between an honest prover and the verifier.

Transcript datasets are named according to the following convention:
```
TL_{UPPER_BOUND}_m{NUM_TRAIN_SAMPLES}_b{BASE_OF_REPRESENTATION}
```
For Annotated Transcripts, `TL` is replaced with `ATL{ANNOTATION_LENGTH}`.
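
As a rough illustration of the naming convention above, the following sketch parses a dataset name into its fields. The helper `parse_dataset_name` and the `DatasetName` class are hypothetical and not part of this repository. The sketch assumes that the numeric fields use scientific notation as in `TL_1e4_m1e7_b210`, and the name `ATL7_1e4_m1e7_b210` is only an example derived from the convention.

```python
import re
from dataclasses import dataclass
from typing import Optional

# Hypothetical helper for illustration; not part of the repository.
_NAME_RE = re.compile(
    r"^(?:ATL(?P<annot>\d+)|TL)_(?P<ub>[^_]+)_m(?P<m>[^_]+)_b(?P<base>\d+)$"
)


@dataclass(frozen=True)
class DatasetName:
    annotation_length: Optional[int]  # None for plain Transcripts (TL)
    upper_bound: int
    num_train_samples: int
    base: int


def parse_dataset_name(name: str) -> DatasetName:
    """Split a dataset name into the fields of the naming convention."""
    match = _NAME_RE.match(name)
    if match is None:
        raise ValueError(f"not a recognized dataset name: {name!r}")
    annot = match.group("annot")
    return DatasetName(
        annotation_length=int(annot) if annot is not None else None,
        # Assumption: sizes are written like "1e4", so parse via float.
        upper_bound=int(float(match.group("ub"))),
        num_train_samples=int(float(match.group("m"))),
        base=int(match.group("base")),
    )


if __name__ == "__main__":
    print(parse_dataset_name("TL_1e4_m1e7_b210"))
    print(parse_dataset_name("ATL7_1e4_m1e7_b210"))
```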

## Training
Once `data/` is populated, you can train a Self-Proving GPT via Transcript Learning:
```bash
python spm/train.py --data DATASET_NAME
```
where `DATASET_NAME` is the name of the dataset you want to use.
### Example
To train on about 10 million samples with an upper bound of 10,000 encoded in base 210:
```bash
python spm/train.py --data TL_1e4_m1e7_b210
```
### Useful arguments
- `--help`: Show all arguments.
- `--device DEVICE`: Specify the device to train on (`cpu` or `cuda`).
- `--epochs EPOCHS`: Number of epochs to train. Each epoch looks at a number of samples equal to the dataset size.
- `--data DATA, -d DATA`: Name of the dataset to use.
- `--seed SEED`: Set the random seed.
- `--save_iters SAVE_ITERS [SAVE_ITERS ...]`: Save the model at these iterations. Use `-1` for the last iteration, or `None` to disable saving.
  Checkpoints are saved to `models/` as
  `{N_LAYERS}x{N_HEAD}x{DIM_EMBED}x{N_ITERATIONS}_{DATASET_NAME}_iter{ITERATION}.pt`.
- `--load_ckpt LOAD_CKPT`: Load a model from a checkpoint.
  Specify the checkpoint name as described above, not the full path.
- `--wandb`: Enable tracking with [Weights & Biases](https://wandb.ai/).
  Use `--wandb_proj WANDB_PROJ` to specify the project name.

## Reproducing results from the paper
Once you [obtain data](#data), you can train models on the datasets to reproduce the experimental
section of the paper.

### Annotation length
To reproduce the ablation on the annotation length, run
```bash
./runs/annot_len.sh
```
Logs of these runs are saved to `logs/`. Figure 2 can then be generated with
```bash
./figs/annotation.py        # Fig 2
```

### Reinforcement Learning from Verifier Feedback
Note that in the [above ablation](#annotation-length), T=0 corresponds to GPT+TL and T=7 corresponds to GPT+Annotated TL in Table 2 of the paper.

To obtain the third row, first train a model with TL for 10k iterations:
```bash
python spm/train.py --device=cuda --dropout=0 --eval_batch_size=512 --warmup_iters=0 --epochs=1 --beta1=0.733 --learning_rate=0.0007 --batch_size=1024 --decay_lr=10 --grad_clip=2 --n_embd=256 --n_head=8 --n_layer=8 --seed=0 --data=TL_1e4_m1e7_b210 --eval_interval=10000 --log_interval=500 --save_iters 10000
```

You now have a low-verifiability base model. Tune it with RLVF with the following hyperparameters:
```bash
python spm/train.py --rlvf=annotated --load_ckpt=8x8x256-TL_1e4_m1e7_b210_iter10000 --batch_size=2048 --beta1=0.7080340973835836 --decay_lr=10 --device=cuda --epochs=1000 --eval_batch_size=1024 --eval_interval=10 --grad_clip=0 --learning_rate=0.00012878882950276307  --log_interval=1 --n_embd=256 --n_head=8 --n_layer=8 --seed=0 --temperature=1.6097135118487995 --warmup_iters=0
```

Note that in the above, `--rlvf=annotated` indicates that a transcript should be considered "accepted" only when the entire annotated transcript was generated correctly. However, since the base model was not trained with annotations, this is the same as `--rlvf=transcript`.

### Base of representation
The paper shows that the number of unique primes in the base of representation
determines the Verifiability of the model. This ablation requires generating many different datasets (one for each base).
For convenience, there is a script that first samples a random base with a given number of unique primes in its
factorization, then trains a model and deletes the dataset:
```bash
python spm/train_diff_bases.py --num_unique_primes NUM_UNIQUE_PRIMES --seed SEED
```
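
To make "number of unique primes in the base" concrete, here is a minimal sketch that counts the distinct prime factors of a base and samples a base with a requested count. The functions `unique_prime_factors` and `sample_base`, the candidate range, and the uniform sampling are assumptions made for illustration; `train_diff_bases.py` may sample bases differently.

```python
import random


def unique_prime_factors(n: int) -> list[int]:
    """Return the distinct prime factors of n in increasing order (trial division)."""
    if n < 2:
        raise ValueError("n must be at least 2")
    factors = []
    p = 2
    while p * p <= n:
        if n % p == 0:
            factors.append(p)
            while n % p == 0:  # strip every power of p
                n //= p
        p += 1
    if n > 1:  # whatever remains is a single prime factor
        factors.append(n)
    return factors


def sample_base(num_unique_primes: int, rng: random.Random,
                low: int = 2, high: int = 1000) -> int:
    """Hypothetical sampler: pick a base in [low, high] uniformly among
    those with exactly `num_unique_primes` distinct prime factors."""
    candidates = [b for b in range(low, high + 1)
                  if len(unique_prime_factors(b)) == num_unique_primes]
    if not candidates:
        raise ValueError("no base in range has that many unique primes")
    return rng.choice(candidates)


if __name__ == "__main__":
    print(unique_prime_factors(210))  # 210 = 2 * 3 * 5 * 7
    print(sample_base(3, random.Random(0)))
```

For instance, the base 210 used in the training example factors as 2 · 3 · 5 · 7, so it has four unique primes, while base 16 has only one.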

Run this script for twenty seeds.

*Tip: You can use a [WandB sweep](https://docs.wandb.ai/guides/sweeps)
to easily schedule runs from different machines, and then aggregate the results onto your local machine
for plotting.
Use `wandb sweep runs/diff_bases.yaml` to create the sweep.*

Each run will be logged to `logs/diff_bases/`.
You can then generate the figure with
```bash
./figs/diff_bases.py  # Fig 3
```

## Acknowledgements
This codebase adapts Andrej Karpathy's [nanoGPT](https://www.github.com/karpathy/nanoGPT) as its GPT implementation.
The model can be found in `spm/gpt/`.

## Package structure

**Root `spm/`**
- `train.py`: Main entry point for training models.
- `train_diff_bases.py`: Alternative entry point for training models on datasets with different bases of representation.
                         Generates and cleans up datasets automatically.              
- `utils.py`: Common utilities (e.g. implementation of the extended Euclidean algorithm).
- `__init__.py`: Common paths and constants.
- `systematic_models.py`: Systematic models for the GCD problem, useful for testing.
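
For readers unfamiliar with the extended Euclidean algorithm mentioned for `utils.py`, the sketch below shows a textbook iterative version together with a check of its output. It is not the repository's implementation: the names `extended_gcd` and `verify` are hypothetical, and the inputs are assumed to be positive integers. The check relies on a standard fact: if `d` divides both `a` and `b`, and `a*x + b*y == d` for some integers `x` and `y`, then `d` is the greatest common divisor.

```python
def extended_gcd(a: int, b: int) -> tuple[int, int, int]:
    """Return (d, x, y) with d = gcd(a, b) and a*x + b*y == d.

    Textbook iterative extended Euclidean algorithm, for positive a and b.
    """
    old_r, r = a, b
    old_x, x = 1, 0
    old_y, y = 0, 1
    while r != 0:
        q = old_r // r
        old_r, r = r, old_r - q * r
        old_x, x = x, old_x - q * x
        old_y, y = y, old_y - q * y
    return old_r, old_x, old_y


def verify(a: int, b: int, d: int, x: int, y: int) -> bool:
    """Accept the claim gcd(a, b) == d given Bezout coefficients (x, y)."""
    if d <= 0:
        return False
    # d is a common divisor, and every common divisor divides a*x + b*y == d.
    return a % d == 0 and b % d == 0 and a * x + b * y == d


if __name__ == "__main__":
    a, b = 240, 46
    d, x, y = extended_gcd(a, b)
    print(d, x, y, verify(a, b, d, x, y))
```

Note that `verify` needs only a few multiplications and remainder operations, whereas producing the coefficients requires running the full algorithm.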

**Data `spm/data/`**
- `generate_data.py`: Generates datasets for training.
- `samples.py`: `Samples` represents a dataset of input-output sequences for the GCD problem.
                `Transcripts` add a proof (Bézout coefficients) to the samples.
                `AnnotatedTranscripts` add a proof and its annotation.
- `samplers.py`: Samplers for generating `Samples`, `Transcripts`, and `AnnotatedTranscripts`.
- `str_repr.py`: A string representation of data samples (encoded in a given base).
- `tensor_repr.py`: A tensor representation of data samples. Uses `str_repr` to encode samples as strings, and
                    handles delimiters and padding. Contains utility methods for encoding, decoding, and saving.
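
To illustrate what encoding a sample "in a given base" means, the following sketch converts a non-negative integer to its list of base-`b` digits and back. This is generic positional notation only: the functions `to_digits` and `from_digits` are hypothetical, and the actual string format in `str_repr.py` (for example, how signs and delimiters are handled) may differ.

```python
def to_digits(n: int, base: int) -> list[int]:
    """Return the digits of n in the given base, most significant first."""
    if n < 0 or base < 2:
        raise ValueError("expected n >= 0 and base >= 2")
    if n == 0:
        return [0]
    digits = []
    while n > 0:
        n, digit = divmod(n, base)
        digits.append(digit)
    return digits[::-1]


def from_digits(digits: list[int], base: int) -> int:
    """Inverse of to_digits: rebuild the integer from its digits."""
    value = 0
    for digit in digits:
        if not 0 <= digit < base:
            raise ValueError(f"digit {digit} out of range for base {base}")
        value = value * base + digit
    return value


if __name__ == "__main__":
    digits = to_digits(9999, 210)
    print(digits)                    # [47, 129], since 47 * 210 + 129 == 9999
    print(from_digits(digits, 210))  # 9999
```

With an upper bound of 10,000 and base 210, every input fits in two such digits, since 210² = 44,100.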

**Model `spm/gpt/`**
- `model.py`: Defines the GPT model architecture. An object-oriented adaptation of [nanoGPT](https://www.github.com/karpathy/nanoGPT).
- `trainer.py`: Trainer for GPT models. Handles training, evaluation, and logging.
- `config.py`: Config for a trainer.
- `rlvf_trainer.py`: Specialized trainer for RLVF models. (Work in progress)
