# Knowledge Distillation through Geometry-Aware Representational Alignment

This repository contains the code and data used in the paper "Knowledge Distillation through Geometry-Aware Representational Alignment".

## Overview

This project provides a framework for knowledge distillation that goes beyond the standard KL-divergence objective by adding geometric similarity measures between teacher and student representations.

## Repository Structure

```
feature-distillation/
├── train/                          # Training scripts
│   ├── distill_glue.py             # GLUE task distillation
│   ├── distillation.py             # Core distillation framework
│   ├── similarity_measures.py      # Similarity measure implementations
│   ├── inst_tuning.py              # Instruction tuning with distillation
│   ├── run_glue.py                 # Standard GLUE training (baseline)
│   └── download_glue.py            # GLUE dataset downloader
├── eval/                           # Evaluation scripts
│   ├── inst_tuning_nat.py          # Natural Instructions evaluation
│   ├── inst_tuning_self_inst.py    # Self-Instruct evaluation
│   └── inst_tuning_unnat.py        # Unnatural Instructions evaluation
├── synthetic/                      # Synthetic experiments
│   └── Synthetic experiments.ipynb # Jupyter notebook for synthetic experiments
├── requirements.txt                # Python dependencies
└── README.md                       # This file
```

## Installation

1. Clone the repository:
```bash
git clone git@github.com:x-labs-xyz/feature-distillation.git
cd feature-distillation
```

2. Install dependencies:
```bash
pip install -r requirements.txt
```

## Quick Start

### GLUE Task Distillation

Train a distilled model on GLUE tasks:

```bash
python train/distill_glue.py \
    --task_name sst2 \
    --student_model bert-base-uncased \
    --teacher_model bert-large-uncased \
    --similarity linear \
    --gamma 0.6 \
    --output_dir ./output/sst2_distilled \
    --do_train \
    --do_eval \
    --per_device_train_batch_size 16 \
    --per_device_eval_batch_size 16 \
    --num_train_epochs 3 \
    --learning_rate 2e-5
```

### Instruction Tuning with Distillation

Train a model for instruction following:

```bash
python train/inst_tuning.py shape
```

The positional argument selects the loss function. Available options: `shape`, `k_frob`, `cka`, `k_frob_only`.

### Evaluation

Evaluate a trained model on instruction datasets:

```bash
# Natural Instructions
python eval/inst_tuning_nat.py /path/to/model

# Self-Instruct
python eval/inst_tuning_self_inst.py /path/to/model

# Unnatural Instructions
python eval/inst_tuning_unnat.py /path/to/model
```

## Core Components

### Distillation Framework (`train/distillation.py`)

The core distillation components:

- **DistilModel**: Wrapper that combines teacher and student models
- **DistillationLoss**: Loss function combining KL divergence and similarity measures
- **DistilTrainer**: Custom trainer for distillation training

### Similarity Measures (`train/similarity_measures.py`)

Available similarity measures:

- **Procrustes**: SVD-based Procrustes alignment with/without whitening
- **CKA**: Centered Kernel Alignment with a linear or Gaussian kernel
- **Euclidean**: Mean Squared Error with dimension matching using a linear projection
- **Energy**: Energy distance metric with iterative optimization
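For intuition about what the Procrustes and CKA measures compute, here is a minimal PyTorch sketch of two textbook formulations: linear CKA, and the squared orthogonal Procrustes distance between centered, unit-Frobenius-norm representations, where the minimum over orthogonal maps reduces to `2 - 2 * ||X^T Y||_*` (the nuclear norm is the sum of singular values, hence "SVD-based"). This is not the implementation in `similarity_measures.py`: the function names are hypothetical, whitening and the Gaussian-kernel CKA variant are omitted, and when the widths differ the Procrustes expression corresponds to zero-padding the narrower representation.

```python
import torch


def _center(X: torch.Tensor) -> torch.Tensor:
    # Subtract the per-feature mean over the n examples (rows).
    return X - X.mean(dim=0, keepdim=True)


def linear_cka(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """Linear CKA between X (n, d_x) and Y (n, d_y); 1.0 means identical geometry."""
    Xc, Yc = _center(X), _center(Y)
    cross = torch.linalg.matrix_norm(Yc.T @ Xc) ** 2  # ||Y^T X||_F^2
    norm_x = torch.linalg.matrix_norm(Xc.T @ Xc)      # ||X^T X||_F
    norm_y = torch.linalg.matrix_norm(Yc.T @ Yc)      # ||Y^T Y||_F
    return cross / (norm_x * norm_y)


def procrustes_distance(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """Squared orthogonal Procrustes distance after centering and unit-norm scaling.

    min over orthogonal Q of ||X - Y Q||_F^2 = 2 - 2 * ||X^T Y||_*
    """
    Xc, Yc = _center(X), _center(Y)
    Xc = Xc / torch.linalg.matrix_norm(Xc)  # unit Frobenius norm
    Yc = Yc / torch.linalg.matrix_norm(Yc)
    nuclear = torch.linalg.svdvals(Xc.T @ Yc).sum()  # sum of singular values
    return 2.0 - 2.0 * nuclear
```

Both functions take row-aligned representations of the same `n` inputs. Used as training losses, one would typically minimize `1 - linear_cka(X, Y)` or `procrustes_distance(X, Y)`.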

### Key Parameters

- **gamma**: Weight balancing similarity loss vs KL divergence (0-1)
- **similarity**: Similarity measure type (`linear`, `cosine`, `cka`, `euclidean`)
- **align_match**: Layer alignment configuration `[[student_layers], [teacher_layers]]`
- **include_targets**: Whether to include target labels in loss computation
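How `gamma` trades off the two terms can be pictured with the sketch below. It assumes a convex combination in which `gamma` weights the representational term, `(1 - gamma) * KL + gamma * similarity_loss`, and a KL(teacher || student) over softmax outputs without temperature scaling. The actual weighting, KL direction, temperature, and handling of `include_targets` in `DistillationLoss` may differ; `kd_kl` and `combined_loss` are hypothetical names.

```python
import torch
import torch.nn.functional as F


def kd_kl(student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
    # KL(teacher || student) over the class dimension, averaged over the batch.
    return F.kl_div(
        F.log_softmax(student_logits, dim=-1),
        F.softmax(teacher_logits, dim=-1),
        reduction="batchmean",
    )


def combined_loss(student_logits, teacher_logits, similarity_loss, gamma=0.6):
    # Hypothetical convex combination: gamma weights the representational term.
    if not 0.0 <= gamma <= 1.0:
        raise ValueError("gamma must lie in [0, 1]")
    kl = kd_kl(student_logits, teacher_logits)
    return (1.0 - gamma) * kl + gamma * similarity_loss
```

Under this assumed convention, `gamma = 0` recovers plain logit distillation and `gamma = 1` trains on the similarity term alone.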

## Datasets

### GLUE Tasks
Download GLUE datasets:
```bash
python train/download_glue.py --data_dir ./glue_data --tasks all
```

### Instruction Datasets
The framework supports:
- **Dolly-15k**: For instruction tuning
- **Natural Instructions**: For evaluation
- **Self-Instruct**: For evaluation
- **Unnatural Instructions**: For evaluation

## Reproducibility

### Environment Setup
- Python 3.8+
- PyTorch 2.2.0
- Transformers 4.44.0
- CUDA support recommended

### Random Seeds
All scripts use fixed random seeds for reproducibility:
- Training: Seeds 10, 19, 42, 69, 99
- Evaluation: Multiple seeds for robust assessment
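A typical way to fix every random source before each run is sketched below. `seed_everything` is a hypothetical helper and not necessarily what the scripts call internally; Transformers' `set_seed` covers the same sources.

```python
import random

import numpy as np
import torch

TRAIN_SEEDS = [10, 19, 42, 69, 99]  # training seeds listed above


def seed_everything(seed: int) -> None:
    # Seed Python, NumPy and PyTorch (CPU and, if present, all CUDA devices).
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is unavailable
```

Calling `seed_everything(seed)` for each entry of `TRAIN_SEEDS` before launching a run reproduces the five-seed protocol, provided data order, hardware, and library versions are held fixed; bitwise-identical GPU results may additionally require `torch.use_deterministic_algorithms(True)`.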

### Model Checkpoints
Models are saved periodically during training:
- Every 500 steps during instruction tuning
- End of each epoch
- Best model based on validation metrics

## Advanced Usage

### Custom Similarity Measures

Add a new similarity measure to `similarity_measures.py` by subclassing `torch.nn.Module` and implementing `forward(X, Y)`, where `X` and `Y` are the two representations being compared:

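```python
import torch
import torch.nn.functional as F


class CustomSimilarity(torch.nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        # Initialize parameters (e.g. learnable projections) here

    def forward(self, X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
        # X, Y: row-aligned representations of the same n inputs; widths may differ.
        # Placeholder: cosine similarity between the two (n x n) Gram matrices.
        # Replace with your own similarity computation.
        similarity_value = F.cosine_similarity((X @ X.T).flatten(), (Y @ Y.T).flatten(), dim=0)
        return similarity_value
```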

### Layer Alignment

Configure custom layer alignments:

```python
# Align student layers [3,6,9] with teacher layers [6,12,18]
align_match = [[3, 6, 9], [6, 12, 18]]
```
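One way to turn an `align_match` configuration into concrete layer pairs is sketched below. It assumes Hugging Face-style `hidden_states` tuples (returned with `output_hidden_states=True`, where index 0 is the embedding output), so the indices above would select transformer layers 3, 6, 9 of the student and 6, 12, 18 of the teacher. Whether the repository counts layers the same way is an assumption, and `paired_hidden_states` is a hypothetical helper.

```python
from typing import List, Sequence, Tuple

import torch


def paired_hidden_states(
    student_hidden: Sequence[torch.Tensor],
    teacher_hidden: Sequence[torch.Tensor],
    align_match: List[List[int]],
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    """Return (student, teacher) hidden-state pairs selected by align_match."""
    student_layers, teacher_layers = align_match
    if len(student_layers) != len(teacher_layers):
        raise ValueError("align_match needs the same number of student and teacher layers")
    return [
        (student_hidden[s], teacher_hidden[t])
        for s, t in zip(student_layers, teacher_layers)
    ]
```

Each pair can then be passed to a similarity measure, typically after flattening `(batch, seq_len, hidden)` to `(batch * seq_len, hidden)`, with the per-pair losses summed or averaged.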

### Multi-GPU Training

The framework supports multi-GPU training by placing the teacher and student on separate GPUs:
- Teacher model on GPU 1
- Student model on GPU 0
- Automatic gradient synchronization
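The teacher/student device split can be sketched as follows: the student lives on `cuda:0`, the teacher on `cuda:1`, the teacher runs under `torch.no_grad()`, and its outputs are copied to the student's device before the loss is computed, so only the student receives gradients. The helper names are hypothetical, the sketch falls back to a single GPU (or the CPU) when fewer devices are available, and the repository's own placement logic may differ.

```python
import torch


def pick_devices():
    # Student on cuda:0 and teacher on cuda:1 when two GPUs exist; otherwise share a device.
    n = torch.cuda.device_count()
    student_dev = torch.device("cuda:0" if n >= 1 else "cpu")
    teacher_dev = torch.device("cuda:1" if n >= 2 else student_dev)
    return student_dev, teacher_dev


def distill_step(student, teacher, batch, student_dev, teacher_dev):
    # Teacher forward builds no autograd graph; its outputs move to the student's device.
    with torch.no_grad():
        teacher_out = teacher(batch.to(teacher_dev)).to(student_dev)
    student_out = student(batch.to(student_dev))
    return student_out, teacher_out
```

In practice the teacher would also be put in `eval()` mode (and optionally frozen with `requires_grad_(False)`) so that dropout does not perturb its targets.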

## Evaluation Metrics

### GLUE Tasks
- Accuracy for classification tasks
- Pearson/Spearman correlation for regression
- Combined score for multi-metric tasks
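These metrics can be computed with SciPy and NumPy as in the following sketch. Here `combined_score` is assumed to be the unweighted mean of a task's metrics, following the convention of the Hugging Face `run_glue.py` example; the helper names are hypothetical and the repository's evaluation code may aggregate differently.

```python
import numpy as np
from scipy.stats import pearsonr, spearmanr


def with_combined_score(metrics: dict) -> dict:
    # Assumed convention: for multi-metric tasks, combined_score is the plain mean.
    if len(metrics) > 1:
        metrics = {**metrics, "combined_score": float(np.mean(list(metrics.values())))}
    return metrics


def accuracy(preds, labels) -> dict:
    # Classification tasks: fraction of exact matches.
    return {"accuracy": float(np.mean(np.asarray(preds) == np.asarray(labels)))}


def regression_metrics(preds, labels) -> dict:
    # Regression tasks (e.g. STS-B): Pearson and Spearman correlation, plus their mean.
    return with_combined_score({
        "pearson": float(pearsonr(preds, labels)[0]),
        "spearmanr": float(spearmanr(preds, labels)[0]),
    })
```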

### Instruction Tasks
- ROUGE-1, ROUGE-2, ROUGE-L
- Average across multiple random seeds
- Stemming enabled for robust evaluation

## Troubleshooting

### Common Issues

1. **Out of Memory**: Reduce batch size or use gradient accumulation
2. **CUDA Errors**: Ensure compatible PyTorch/CUDA versions
3. **Dataset Download**: Check internet connection for GLUE downloads

### Performance Tips

- Use mixed precision training (`--fp16`)
- Enable gradient checkpointing for large models
- Use data parallelism for multi-GPU setups

## License

This project is licensed under the Apache 2.0 License - see the LICENSE file for details.

