# Antidistillation Sampling

This repository implements the techniques described in our paper "Antidistillation Sampling" for protecting language models from distillation attacks while preserving their utility.

## Overview

Frontier language models that return extended reasoning traces inadvertently yield rich token sequences that can facilitate model distillation (model stealing). Consequently, model owners may prefer token sampling strategies that limit the effectiveness of distillation. 

**Antidistillation sampling** provides this capability: by modifying the model's next-token probability distribution, it poisons reasoning traces so that they are less useful for distillation, without significantly degrading model utility.

## How It Works

The core idea is to adjust the teacher model's sampling distribution using gradients of a student model's loss on downstream tasks. By favoring tokens that would make the student perform worse if it learned from these traces, we generate data that remains high quality but resists distillation.

The method involves:
1. Computing gradients from the student model on a holdout set
2. Using these gradients in a finite difference approximation to modify the teacher's sampling distribution
3. Controlling the tradeoff between teacher performance and distillation resistance via a λ parameter
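
The steps above can be sketched for a single decoding step as follows. This is a minimal illustration in plain Python, not the repository's implementation: the function names, the toy three-token vocabulary, and the sign convention (a larger `delta` marks a token estimated to hurt the student more) are assumptions made for the example. It assumes the finite-difference term is formed from the student's next-token log-probabilities before and after a small parameter step of size `eps` along the saved gradient, and that this term is added to the temperature-scaled teacher log-probabilities with weight λ.

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax over a list of floats."""
    m = max(scores)
    lse = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - lse for s in scores]


def finite_difference_term(student_logp, perturbed_student_logp, eps):
    """Per-token finite difference of the student's log-probabilities.

    student_logp: log-probs from the unperturbed student.
    perturbed_student_logp: log-probs after stepping the student's
        parameters by eps along the saved gradient (assumed setup).
    """
    return [(b - a) / eps for a, b in zip(student_logp, perturbed_student_logp)]


def antidistill_probs(teacher_logits, delta, tau=1.0, lam=0.0):
    """Next-token distribution proportional to exp(log p / tau + lam * delta)."""
    teacher_logp = log_softmax(teacher_logits)
    adjusted = [lp / tau + lam * d for lp, d in zip(teacher_logp, delta)]
    return [math.exp(x) for x in log_softmax(adjusted)]


# Toy example: three-token vocabulary, made-up numbers for illustration only.
teacher_logits = [2.0, 1.0, 0.0]
delta = finite_difference_term([-1.0, -1.0, -1.0], [-1.0, -1.0, -0.9], eps=0.1)
plain = antidistill_probs(teacher_logits, delta, tau=1.0, lam=0.0)
poisoned = antidistill_probs(teacher_logits, delta, tau=1.0, lam=2.0)
```

With `lam=0.0` the function reduces to ordinary temperature sampling from the teacher; increasing `lam` shifts probability mass toward tokens with a larger finite-difference term.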

## Installation

1. Install uv: https://docs.astral.sh/uv/getting-started/installation/
2. Run `uv sync` to install the dependencies declared in the project's `pyproject.toml` (pinned by `uv.lock`, if present)
3. Install flash-attention: `uv add flash-attn --no-build-isolation`

## Project Structure

- `save_grad.py` - Computes and saves gradients from a student model
- `gentraces.py` - Generates reasoning traces with optional antidistillation sampling
- `distill.py` - Implements model distillation from the generated traces
- `utils.py` - Utility functions for dataset loading, initialization, etc.
- `Makefile` - Orchestrates the entire pipeline
- Configuration files:
  - `gen_config.yaml` - Configuration for trace generation
  - `train_config.yaml` - Configuration for model training/distillation
  - `acc_config.yaml` - Accelerate configuration for distributed training

## Pipeline

The full pipeline consists of five main steps:

1. **Generate holdout traces**: Create a dataset of reasoning traces from the teacher model on a holdout set.
2. **Compute student gradients**: Calculate gradients from a student model on the holdout traces.
3. **Generate training traces**: Create training traces using different tau/lambda parameters, with optional antidistillation sampling.
4. **Distill models**: Train student models on the generated traces.
5. **Evaluate models**: Evaluate the performance of the distilled models.

## Running the Pipeline

You can run the full pipeline with:

```bash
bash pipeline.sh
```

### Key Configuration Parameters

- **EXP_DIR**: Output directory for all experiment data
- **SEED**: Random seed for reproducibility
- **DATASET**: Dataset to use (e.g., gsm8k, hendrycks_math)
- **TAU_LAMS**: Tau/Lambda parameter pairs (temperature and antidistillation strength)
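
As an illustration of how such pairs might be consumed, the sketch below parses a list of temperature/strength pairs. The string format (space-separated `tau,lambda` items) and the function name `parse_tau_lams` are hypothetical; check `pipeline.sh` for the format the scripts actually expect.

```python
from typing import List, Tuple


def parse_tau_lams(spec: str) -> List[Tuple[float, float]]:
    """Parse a hypothetical 'tau,lam tau,lam ...' string into float pairs."""
    pairs = []
    for item in spec.split():
        parts = item.split(",")
        if len(parts) != 2:
            raise ValueError(f"expected 'tau,lam', got {item!r}")
        tau, lam = float(parts[0]), float(parts[1])
        if tau <= 0:
            # A sampling temperature must be strictly positive.
            raise ValueError(f"tau must be positive, got {tau}")
        pairs.append((tau, lam))
    return pairs


# Example sweep: one temperature, two antidistillation strengths.
sweep = parse_tau_lams("1.0,0.0 1.0,0.5")
```

Each resulting `(tau, lam)` pair would then correspond to one trace-generation run.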

## How to Modify the Pipeline

### Changing Configuration Files

1. **gen_config.yaml**: Modify to change trace generation parameters like:
   - Teacher model
   - Maximum token length
   - Temperature (tau)
   - Antidistillation strength (lambda)
   - Dataset settings

2. **train_config.yaml**: Modify to change training parameters like:
   - Student model
   - Learning rate
   - Batch size
   - Number of epochs
   - LoRA parameters

### Adding New Models or Datasets

1. To use a different model, set the `MODEL_NAME` parameter when running the pipeline.
2. To add a new dataset, implement a loading function in `utils.py` similar to `load_gsm8k` or `load_hendrycks_math_dataset`.
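
A new loader might look roughly like the sketch below. The function name `load_jsonl_dataset`, the JSONL input format, and the `question`/`answer` output keys are all assumptions for illustration; mirror the signature and return type of the existing loaders in `utils.py` rather than this example.

```python
import json
from pathlib import Path


def load_jsonl_dataset(path, question_key="question", answer_key="answer"):
    """Hypothetical loader: read a JSONL file into a list of question/answer dicts."""
    examples = []
    with Path(path).open(encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue  # skip blank lines
            record = json.loads(line)
            if question_key not in record or answer_key not in record:
                raise KeyError(f"line {line_no}: missing question or answer field")
            examples.append(
                {
                    "question": record[question_key],
                    # Normalize answers to strings so comparisons are uniform.
                    "answer": str(record[answer_key]),
                }
            )
    return examples
```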

## Experimental Results

Our paper demonstrates that antidistillation sampling is an effective defense against model distillation: with appropriate hyperparameters, the teacher model retains high performance while the effectiveness of distillation is significantly reduced.

The λ parameter controls this tradeoff between maintaining teacher accuracy and preventing successful distillation.
