# Distillation Robustifies Unlearning

(prepared as NeurIPS 2025 supplementary material)

This repository contains the implementation for research on improving the robustness of machine unlearning through distillation techniques. The codebase supports experiments across multiple domains (language, arithmetic, and WMDP) using various unlearning methods.

## Overview

The implementation includes:

- **Multiple unlearning methods**: GradDiff, MaxEnt, and RMU
- **Distillation techniques**: Standard distillation and partial distillation (UNDO) with configurable mixing parameters
- **Multi-domain evaluation**: Language tasks (English/Korean), arithmetic operations, and the WMDP benchmark
- **Relearning attacks**: Fine-tuning unlearned models on forget-domain data to test whether the removed capabilities can be recovered

## Project Structure

```
├── run_*.py                        # Main experiment runners
├── code/
│   ├── tools/                      # Core implementations
│   │   ├── pretrain.py             # Model pretraining
│   │   ├── distill.py              # Standard distillation
│   │   ├── partial_distill_*.py    # Partial distillation implementations
│   │   ├── relearn_*.py            # Relearning attack implementations
│   │   └── unlearn_*/              # Unlearning method implementations
│   │       ├── graddiff.py         # Gradient Difference unlearning
│   │       ├── maxent.py           # Maximum Entropy unlearning
│   │       └── rmu.py              # Representation Misdirection for Unlearning
│   ├── utils/                      # Utility functions
│   │   ├── loss_functions.py       # Loss function implementations
│   │   ├── validation_functions.py # Evaluation metrics
│   │   ├── process_datasets.py     # Data processing utilities
│   │   └── paths.py                # Path configurations
│   └── prepare_*/                  # Data and model preparation
├── collector.py                    # Code collection utility
└── wmdp_question_extraction.py     # WMDP data processing
```

## Setup and Installation

### Installation

1. **Clone the repository**:
```bash
git clone <repository-url>
cd distillation-robustifies-unlearning
```

2. **Install dependencies**:
```bash
pip install uv
uv venv
source .venv/bin/activate
uv sync
```

3. **Set up authentication tokens**:
Create a `tokens/` directory and add:
- `hf_token.txt` - Hugging Face token
- `wandb_token.txt` - Weights & Biases token


4. **Configure paths**:
Edit `code/utils/paths.py` to set your data and model directories:
```python
CACHE_DIR = '/your/cache/directory'
DATASET_DIR = '/your/dataset/directory'
MODEL_DIR = '/your/model/directory'
WMDP_MODEL_DIR = '/your/wmdp/model/directory'
```
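
The token files from step 3 should each hold a single token. A minimal sketch of reading one, assuming the token is the only content of the file; `read_token` is a hypothetical helper, not part of this repository:

```python
from pathlib import Path


def read_token(name: str, token_dir: str = "tokens") -> str:
    """Hypothetical helper: read a token file such as tokens/hf_token.txt."""
    path = Path(token_dir) / name
    if not path.is_file():
        raise FileNotFoundError(f"Missing token file: {path}")
    # Strip the trailing newline that most editors add
    token = path.read_text(encoding="utf-8").strip()
    if not token:
        raise ValueError(f"Token file is empty: {path}")
    return token
```

The returned string can then be passed to `huggingface_hub.login(token=...)` or `wandb.login(key=...)`.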

## Quick Start

### 1. Prepare Models
```bash
# Create smaller Gemma models for experiments
python code/prepare_models/reduce_gemma.py

# Download base Gemma model for WMDP experiments
python code/prepare_models/download_gemma.py
```

### 2. Prepare Data
```bash
# Download and prepare language datasets
python code/prepare_data/download_english_and_korean.py
python code/prepare_data/prepare.py

# Generate arithmetic datasets
python code/prepare_data/download_arithmetic.py
```

### 3. Run Basic Experiments

**Language Domain**:
```bash
# Pretrain a bilingual model
python run_pretrain_language.py

# Unlearn Korean while retaining English
python run_unlearn_language.py

# Distill the unlearned model
python run_distill_language.py

# Test robustness with relearning
python run_relearn_language.py
```

**Arithmetic Domain**:
```bash
# Pretrain on all arithmetic operations
python run_pretrain_arithmetic.py

# Unlearn multiplication/division, retain addition/subtraction
python run_unlearn_arithmetic.py

# Apply distillation
python run_distill_arithmetic.py

# Test with relearning attacks
python run_relearn_arithmetic.py
```

## Key Concepts

### Unlearning Methods

1. **GradDiff**: Performs gradient ascent on forget data and gradient descent on retain data
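2. **MaxEnt**: Pushes the model's output distribution on forget data toward the uniform distribution
3. **RMU**: Steers the model's hidden activations on forget data toward a scaled random vector while keeping activations on retain data close to those of the original model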

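To make the three objectives concrete, here is a per-token sketch in plain Python. It is illustrative only and not the code in `code/tools/unlearn_*/`: real implementations operate on batched tensors, MaxEnt is written here as the KL divergence from the uniform distribution (one common formulation, usually paired with a retain-set loss that is omitted here), and the RMU inputs `h_*` stand for hidden activations from a single chosen layer. The function names and the `forget_weight`, `retain_weight`, and `scale` parameters are hypothetical.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target):
    """Negative log-probability of the target token."""
    return -math.log(softmax(logits)[target])


def graddiff_loss(forget_logits, forget_target, retain_logits, retain_target, forget_weight=1.0):
    # Minimizing this descends on the retain loss and ascends on the forget loss
    return cross_entropy(retain_logits, retain_target) - forget_weight * cross_entropy(
        forget_logits, forget_target
    )


def maxent_forget_loss(forget_logits):
    # KL(uniform || p): zero exactly when the predicted distribution p is uniform
    p = softmax(forget_logits)
    k = len(p)
    return sum((1.0 / k) * math.log((1.0 / k) / pi) for pi in p)


def random_control_vector(dim, scale, seed=0):
    # RMU-style target: a random unit vector multiplied by a steering coefficient
    rng = random.Random(seed)
    u = [rng.random() for _ in range(dim)]
    norm = math.sqrt(sum(x * x for x in u))
    return [scale * x / norm for x in u]


def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def rmu_loss(h_forget, h_retain, h_retain_frozen, control_vec, retain_weight=1.0):
    # Push forget activations toward the control vector, and keep retain
    # activations close to those of a frozen copy of the original model
    return mse(h_forget, control_vec) + retain_weight * mse(h_retain, h_retain_frozen)
```
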
### Distillation Techniques

1. **Standard Distillation**: A student model is trained to match the teacher's output distribution by minimizing the KL divergence between them
2. **Partial Distillation (UNDO)**: Mixes original and unlearned teacher models with configurable α and β parameters

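For standard distillation, the per-token objective can be sketched as the forward KL divergence from the teacher's distribution to the student's. This is a plain-Python illustration, not the code in `code/tools/distill.py`; the `temperature` parameter and the function names are assumptions, and the repository may use a different KL direction or temperature.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled, numerically stable softmax."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_kl(teacher_logits, student_logits, temperature=1.0):
    """Forward KL(teacher || student) over one token's vocabulary distribution."""
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    # Terms with zero teacher probability contribute nothing to the sum
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)
```

A higher temperature flattens both distributions, so the student also receives signal from the teacher's lower-ranked tokens.
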
## Advanced Usage

### Partial Distillation Sweeps

Run systematic sweeps over α and β parameters:

```bash
# Language domain partial distillation
python run_partial_distill_language.py --run_all

# Arithmetic domain with specific parameters
python run_partial_distill_arithmetic.py --setup gemma-2-0.3B_MaxEnt --alpha 0.5 --beta 0.1
```
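
To sweep a custom grid with the per-run flags, a small driver can build one invocation per (α, β) pair. This sketch only assumes the `--setup`, `--alpha`, and `--beta` flags shown above; the helper name and any grid values are hypothetical.

```python
import itertools


def sweep_commands(setup, alphas, betas, script="run_partial_distill_arithmetic.py"):
    """Build one command (as an argument list) per (alpha, beta) pair."""
    commands = []
    for alpha, beta in itertools.product(alphas, betas):
        commands.append([
            "python", script,
            "--setup", setup,
            "--alpha", str(alpha),
            "--beta", str(beta),
        ])
    return commands
```

Each command can then be executed with `subprocess.run(cmd, check=True)`, which stops the sweep at the first failing run.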

### Custom Evaluation

Implement custom evaluation functions in `code/utils/validation_functions.py`:

```python
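def custom_eval_fn(model, eval_data, accelerator):
    # Your evaluation logic here: run `model` on `eval_data` and compute metrics.
    # Return a dict mapping metric names to numeric values.
    metric1, metric2 = 0.0, 0.0  # placeholders
    return {"metric1": metric1, "metric2": metric2}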
```

### WMDP Domain

For experiments on the WMDP (Weapons of Mass Destruction Proxy) benchmark:

```bash
# Unlearn dangerous knowledge
python run_unlearn_wmdp.py

# Apply partial distillation
python run_partial_distill_wmdp.py

# Test relearning resistance
python run_relearn_wmdp.py
```

## Monitoring and Logging

The project integrates with Weights & Biases for experiment tracking:

- Set `use_wandb=True` in experiment configurations
- Metrics are automatically logged during training
- Local records are also saved for offline analysis

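The logging pattern above (W&B when enabled, local records always) can be sketched as follows. The function name, the JSON Lines record format, and the file layout are assumptions, not the repository's logging code. `wandb.log` requires an active run created with `wandb.init(...)`.

```python
import json
import time
from pathlib import Path


def log_metrics(metrics, step, record_path, use_wandb=False):
    """Hypothetical logger: optionally send metrics to W&B, always append a local record."""
    if use_wandb:
        import wandb  # imported only when W&B logging is enabled

        wandb.log(metrics, step=step)
    record = {"step": step, "time": time.time(), **metrics}
    path = Path(record_path)
    path.parent.mkdir(parents=True, exist_ok=True)
    # One JSON object per line, so records can be appended during training
    with path.open("a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")
```
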
## Data Formats

### Language Data
```json
{"text": "Your text content here"}
```

### Arithmetic Data
```json
{"text": "15 + 23 = 38"}
```
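
Following the split in the Quick Start (forget multiplication/division, retain addition/subtraction), this sketch generates records in the format above, partitions them by operator, and splits each into a prompt and an answer for exact-match scoring. The `*` and `/` symbols, the operand range, and all helper names are assumptions; `code/prepare_data/download_arithmetic.py` defines the actual data.

```python
import random

OPS = {
    "+": lambda a, b: a + b,
    "-": lambda a, b: a - b,
    "*": lambda a, b: a * b,
}
FORGET_OPS = {"*", "/"}


def make_example(op, rng, max_operand=99):
    a, b = rng.randint(0, max_operand), rng.randint(1, max_operand)
    if op == "/":
        # Build division from a product so the answer is always an integer
        return {"text": f"{a * b} / {b} = {a}"}
    return {"text": f"{a} {op} {b} = {OPS[op](a, b)}"}


def split_forget_retain(records):
    """Partition records into forget (* and /) and retain (+ and -) sets."""
    forget, retain = [], []
    for r in records:
        op = r["text"].split()[1]
        (forget if op in FORGET_OPS else retain).append(r)
    return forget, retain


def parse(text):
    """Split 'a op b = c' into the prompt 'a op b =' and the answer 'c'."""
    prompt, answer = text.rsplit("=", 1)
    return prompt + "=", answer.strip()
```
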

### QA Data
```json
{"qa": {"question": "What is 2+2?", "answer": "4"}}
```
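
A minimal loader that checks records against the three formats above, assuming the data is stored as JSON Lines (one object per line); `load_jsonl` is a hypothetical helper, not part of this repository:

```python
import json


def load_jsonl(lines):
    """Parse JSON Lines records and check they match one of the formats above."""
    records = []
    for i, line in enumerate(lines, start=1):
        line = line.strip()
        if not line:
            continue  # tolerate blank lines
        record = json.loads(line)
        if "text" in record:
            if not isinstance(record["text"], str):
                raise ValueError(f"line {i}: 'text' must be a string")
        elif "qa" in record:
            if not {"question", "answer"} <= set(record["qa"]):
                raise ValueError(f"line {i}: 'qa' needs 'question' and 'answer'")
        else:
            raise ValueError(f"line {i}: expected a 'text' or 'qa' field")
        records.append(record)
    return records
```

Because file objects iterate line by line, `load_jsonl(open(path, encoding="utf-8"))` works directly on a dataset file.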