# Activation-Based Preference Optimization (APO)

This repository contains the code for **Low-Resource Preference Adaptation of LLMs via Activation-Based Label Propagation**, a method for improving preference optimization in language models by learning a probe on model activations to predict human preferences, then using the probe's predictions to automatically relabel training data.

## Overview

APO works by:
1. Training a linear (logistic) or MLP probe on model activations, using a small labeled dataset, to predict which response humans prefer
2. Using the trained probe to annotate large unlabeled preference datasets
3. Training a language model using preference optimization (DPO, KTO, CPO, or IPO) with the probe-labeled data
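
The first two steps can be illustrated with a small, self-contained sketch in plain Python. This is not the repository's implementation (that lives in `probe_training.py` and `probes.py`). It assumes each preference pair has already been reduced to a feature vector, here a hypothetical difference between the activations of the two responses, and it fits a logistic probe by full-batch gradient descent on toy data. The helper names `train_logistic_probe` and `probe_prob` are hypothetical.

```python
import math
import random


def probe_prob(weights, bias, features):
    """Probe's probability that response A is preferred over response B."""
    z = bias + sum(w * x for w, x in zip(weights, features))
    return 1.0 / (1.0 + math.exp(-z))


def train_logistic_probe(data, lr=0.5, epochs=200):
    """Fit a logistic-regression probe with full-batch gradient descent.

    data: list of (features, label) with label 1 if A was preferred, else 0.
    Features are assumed (hypothetically) to be act(A) - act(B).
    """
    dim = len(data[0][0])
    weights, bias = [0.0] * dim, 0.0
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * dim, 0.0
        for features, label in data:
            # Derivative of the log loss with respect to the logit z.
            err = probe_prob(weights, bias, features) - label
            for i, x in enumerate(features):
                grad_w[i] += err * x
            grad_b += err
        n = len(data)
        weights = [w - lr * g / n for w, g in zip(weights, grad_w)]
        bias -= lr * grad_b / n
    return weights, bias


# Step 1: a small labeled set. In these toy features only the first
# coordinate carries the preference signal; the second is pure noise.
rng = random.Random(0)
labelled = []
for _ in range(200):
    label = rng.randint(0, 1)
    signal = 1.0 if label else -1.0
    labelled.append(([signal + rng.gauss(0, 0.5), rng.gauss(0, 1.0)], label))

weights, bias = train_logistic_probe(labelled)

# Step 2: annotate unlabeled pairs with the probe's prediction
# (1 = keep A as chosen, 0 = prefer B). Step 3 (preference optimization
# on these labels) is out of scope for this sketch.
unlabelled = [[1.2, 0.3], [-0.9, -0.1]]
labels = [1 if probe_prob(weights, bias, f) >= 0.5 else 0 for f in unlabelled]
print(labels)  # [1, 0]
```

In the actual pipeline the features come from the layers listed in `--probe-layers`; the toy vectors here only stand in for them.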

This approach can improve preference optimization in low-resource settings, where only a small amount of human preference data is available.

## Installation

```bash
pip install -r requirements.txt
```

### Requirements

- Python 3.10+ (tested on 3.12.3)
- PyTorch 2.0+ (tested on 2.9.1)
- CUDA-capable GPU (recommended)

## Quick Start

### Basic Usage

Train a model with APO using DPO on the HH-RLHF dataset:

```bash
# Note: all experiments were run with --no-probe-filter-length-outliers
# (i.e., with length-outlier filtering disabled).
python main.py \
  --model-name meta-llama/Llama-3.2-1B \
  --po-dataset Anthropic/hh-rlhf \
  --po-method dpo \
  --probe-type logistic \
  --no-probe-filter-length-outliers \
  --output-dir ./apo_output
```

### With SFT Before Preference Optimization

```bash
python main.py \
  --do-sft \
  --sft-dataset tatsu-lab/alpaca \
  --sft-max-samples 1000 \
  --po-dataset Anthropic/hh-rlhf \
  --output-dir ./apo_output
```

### Comparing Against Baselines

```bash
# Compare probe labels vs original human labels (default)
python main.py --baseline original

# Compare probe labels vs random labels
python main.py --baseline random --flip-probability 0.5

# Compare probe-trained model vs SFT-only model
python main.py --baseline sft --do-sft
```
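
With `--baseline random`, the comparison model is trained on randomly corrupted labels. The sketch below assumes that `--flip-probability` is the per-pair probability of swapping `chosen` and `rejected`; the function `flip_labels` and the dictionary keys are hypothetical and may not match the repository's code.

```python
import random


def flip_labels(pairs, flip_probability, seed=0):
    """Return a copy of `pairs` in which each chosen/rejected pair is swapped
    independently with probability `flip_probability`."""
    rng = random.Random(seed)
    flipped = []
    for pair in pairs:
        if rng.random() < flip_probability:
            # Swap the preference direction without mutating the input dict.
            pair = {**pair, "chosen": pair["rejected"], "rejected": pair["chosen"]}
        flipped.append(pair)
    return flipped


pairs = [{"prompt": f"q{i}", "chosen": f"good{i}", "rejected": f"bad{i}"} for i in range(1000)]
noisy = flip_labels(pairs, flip_probability=0.5)
swapped = sum(p["chosen"].startswith("bad") for p in noisy)
print(f"{swapped} of {len(noisy)} pairs swapped")
```

Under this reading, a flip probability of 0.5 makes every label uninformative, which is what makes it a "random labels" baseline.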

## Configuration Options

### Model Settings

| Argument | Default | Description |
|----------|---------|-------------|
| `--model-name` | `meta-llama/Llama-3.2-1B` | HuggingFace model identifier |
| `--use-4bit` | `False` | Enable 4-bit quantization |

### Probe Settings

| Argument | Default | Description |
|----------|---------|-------------|
| `--probe-layers` | `[8, 12, 16]` | Model layers to extract activations from |
| `--probe-subset-size` | `1000` | Number of samples for probe training |
| `--probe-type` | `logistic` | Probe type: `logistic` or `mlp` |
| `--probe-confidence-threshold` | `0.0` | Minimum confidence for relabeling (0.0 = relabel all) |
| `--probe-filter-length-outliers` | `True` | Filter samples with extreme length ratios (disable with `--no-probe-filter-length-outliers`) |
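
The exact length-outlier criterion is not documented here. As an illustration only, the sketch below assumes a hypothetical rule that keeps a pair when the longer response is at most `max_ratio` times the length of the shorter one, measured in characters; both `filter_length_outliers` and the default of 3.0 are assumptions.

```python
def filter_length_outliers(pairs, max_ratio=3.0):
    """Keep pairs whose chosen/rejected character lengths differ by at most max_ratio."""
    kept = []
    for pair in pairs:
        a, b = len(pair["chosen"]), len(pair["rejected"])
        shorter, longer = min(a, b), max(a, b)
        # Empty responses have an undefined ratio, so they are dropped.
        if shorter > 0 and longer / shorter <= max_ratio:
            kept.append(pair)
    return kept


pairs = [
    {"chosen": "Paris is the capital of France.", "rejected": "It is Lyon."},
    {"chosen": "Yes.", "rejected": "Yes, and here is a very long explanation of why."},
]
print(len(filter_length_outliers(pairs)))  # 1
```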

### Preference Optimization Settings

| Argument | Default | Description |
|----------|---------|-------------|
| `--po-method` | `dpo` | Optimization method: `dpo`, `kto`, `cpo`, `ipo` |
| `--po-dataset` | `Anthropic/hh-rlhf` | Preference dataset |
| `--po-max-samples` | `5000` | Maximum training samples |
| `--beta` | `0.1` | KL penalty coefficient |
| `--learning-rate` | `2e-5` | Learning rate |

### Training Options

| Argument | Description |
|----------|-------------|
| `--train-probe-only` | Only train the model on probe-labeled data (skip the baseline) |
| `--train-baseline-only` | Only train the baseline model (skip the probe-labeled model) |

### Evaluation Settings

| Argument | Default | Description |
|----------|---------|-------------|
| `--baseline` | `original` | Baseline for comparison: `original`, `random`, `sft` |
| `--eval-samples` | `100` | Number of evaluation samples |
| `--judge-model` | `Qwen/Qwen3-4B` | LLM judge model |
| `--enable-checkpoint-eval` | `False` | Enable checkpoint-based evaluation |
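
Pairwise LLM-judge comparisons are commonly summarized as a win rate. The sketch below assumes each evaluation sample yields one of three hypothetical verdict strings (`"probe"`, `"baseline"`, `"tie"`) and counts ties as half a win; the scoring in `evaluation.py` may differ, and the verdict counts are toy values.

```python
from collections import Counter


def win_rate(verdicts):
    """Fraction of comparisons won by the probe-trained model, counting ties as half."""
    counts = Counter(verdicts)
    unknown = set(counts) - {"probe", "baseline", "tie"}
    if unknown:
        raise ValueError(f"unexpected verdicts: {unknown}")
    total = sum(counts.values())
    return (counts["probe"] + 0.5 * counts["tie"]) / total if total else float("nan")


# Toy verdicts for 100 evaluation samples (the --eval-samples default).
verdicts = ["probe"] * 55 + ["baseline"] * 35 + ["tie"] * 10
print(win_rate(verdicts))  # 0.6
```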

## Supported Datasets

- **HH-RLHF** (`Anthropic/hh-rlhf`): Human preference pairs
- **UltraFeedback** (`HuggingFaceH4/ultrafeedback_binarized`): AI-feedback (GPT-4-rated) preferences, binarized into pairs
- **Nectar** (`berkeley-nest/Nectar`): Rankings over multiple responses per prompt
- **PRISM** (`HannahRoseKirk/prism-alignment`): Demographic-aware preferences
- **AfriSenti** (`afrisenti`): African language sentiment classification
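
HH-RLHF stores each example as two full dialogue transcripts (`chosen` and `rejected`) that share the conversation up to the final assistant turn. A common preprocessing step splits off the shared prompt at the last `\n\nAssistant:` marker. The sketch below shows that step; `split_hh_example` is a hypothetical helper, not necessarily what `dataset_utils.py` does.

```python
def split_hh_example(example):
    """Split an HH-RLHF record into prompt / chosen / rejected strings."""
    marker = "\n\nAssistant:"
    idx = example["chosen"].rfind(marker)
    if idx == -1:
        raise ValueError("no assistant turn found")
    # The prompt is everything up to and including the final assistant marker.
    prompt = example["chosen"][: idx + len(marker)]
    if not example["rejected"].startswith(prompt):
        raise ValueError("chosen and rejected do not share a prompt")
    return {
        "prompt": prompt,
        "chosen": example["chosen"][len(prompt):],
        "rejected": example["rejected"][len(prompt):],
    }


record = {
    "chosen": "\n\nHuman: What is 2+2?\n\nAssistant: 4.",
    "rejected": "\n\nHuman: What is 2+2?\n\nAssistant: I don't know.",
}
print(split_hh_example(record))
```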

### Multi-language Support (AfriSenti)

```bash
python main.py \
  --po-dataset afrisenti \
  --po-dataset-language amh \
  --probe-dataset afrisenti \
  --probe-dataset-language amh
```

Supported languages: `amh` (Amharic), `dz` (Algerian Arabic), `ha` (Hausa), `ig` (Igbo), and others.

## Project Structure

```
APO/
├── main.py                 # Main pipeline orchestration
├── config.py               # Configuration dataclass
├── dataset_utils.py        # Dataset loading and formatting
├── probe_training.py       # Probe training and dataset relabeling
├── probes.py               # Probe model definitions
├── activation_extractor.py # Activation extraction utilities
├── model_utils.py          # Model layer access utilities
├── training.py             # SFT and preference optimization training
├── evaluation.py           # LLM-as-a-judge and ground truth evaluation
├── eval.py                 # Checkpoint-based evaluation
├── callbacks.py            # Training callbacks
├── wandb_utils.py          # Weights & Biases integration
├── probe_analysis.py       # Probe analysis and visualization
└── requirements.txt        # Dependencies
```

## Output Structure

```
apo_output/
├── relabeled_data.json         # Probe-relabeled dataset
├── results.json                # Final evaluation results
├── checkpoint_results.json     # Checkpoint evaluation (if enabled)
├── dpo_probe/                  # Probe-trained model checkpoints
└── dpo_original/               # Baseline model checkpoints
```

## Weights & Biases Integration

Enable experiment tracking with:

```bash
python main.py --use-wandb --wandb-project my-project
```

Disable with `--no-wandb`.

## Advanced Usage

### Checkpoint-based Evaluation

Evaluate intermediate checkpoints at fractions of the training run (here 25%, 50%, 75%, and 100%):

```bash
python main.py \
  --enable-checkpoint-eval \
  --checkpoint-intervals 0.25 0.5 0.75 1.0 \
  --checkpoint-eval-samples 30
```
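
Assuming each `--checkpoint-intervals` value is a fraction of the total number of optimizer steps, the fractions map to concrete steps as in the sketch below. The helper `intervals_to_steps` and its round-up rule are hypothetical.

```python
import math


def intervals_to_steps(intervals, total_steps):
    """Map training-progress fractions in (0, 1] to distinct optimizer steps (at least step 1)."""
    steps = set()
    for fraction in intervals:
        if not 0.0 < fraction <= 1.0:
            raise ValueError(f"interval must be in (0, 1], got {fraction}")
        # Round up so that a small fraction never maps to step 0.
        steps.add(max(1, math.ceil(fraction * total_steps)))
    return sorted(steps)


print(intervals_to_steps([0.25, 0.5, 0.75, 1.0], total_steps=1000))  # [250, 500, 750, 1000]
```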

### Using a Separate Probe Dataset

Train the probe on a different dataset than the PO training:

```bash
python main.py \
  --probe-dataset Anthropic/hh-rlhf \
  --po-dataset HuggingFaceH4/ultrafeedback_binarized
```

### Confidence-based Relabeling

Only relabel samples where the probe's confidence meets a minimum threshold (the default of `0.0` relabels every sample):

```bash
python main.py --probe-confidence-threshold 0.2
```
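
This section does not specify how confidence is computed or what happens to samples below the threshold. One plausible reading is sketched below. It uses a hypothetical definition of confidence, `2 * |p - 0.5|`, where `p` is the probe's probability that the original `chosen` response is preferred. Pairs below the threshold keep their original human label, and the rest take the probe's label. The function `relabel` is hypothetical.

```python
def relabel(pairs, probe_probs, threshold=0.0):
    """Apply probe labels to pairs where the probe is confident enough.

    probe_probs[i] is the probe's probability that pairs[i]["chosen"] is preferred.
    Confidence is assumed to be 2 * |p - 0.5|, so it lies in [0, 1].
    """
    if len(pairs) != len(probe_probs):
        raise ValueError("pairs and probe_probs must have the same length")
    out = []
    for pair, p in zip(pairs, probe_probs):
        confidence = 2.0 * abs(p - 0.5)
        if confidence >= threshold and p < 0.5:
            # The probe prefers the originally rejected response: swap them.
            pair = {**pair, "chosen": pair["rejected"], "rejected": pair["chosen"]}
        out.append(pair)
    return out


pairs = [{"chosen": "A1", "rejected": "B1"}, {"chosen": "A2", "rejected": "B2"}]
probs = [0.45, 0.1]  # confidences of about 0.1 and 0.8
print([p["chosen"] for p in relabel(pairs, probs, threshold=0.2)])  # ['A1', 'B2']
```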

## Citation

If you use this code in your research, please cite our paper:

```bibtex

```

## License

This project is released for research purposes.
