# Teacher-Student Learning Framework

A framework for experimenting with teacher-student learning paradigms in neural networks. This project allows you to train teacher models on separate tasks and then train a student model to learn from these teachers.

## Project Overview

This framework implements a teacher-student learning approach where:
1. A full model is trained on a complete dataset
2. Two teacher models are trained on separate subsets of the data (different tasks)
3. A student model with dual heads is trained to learn from both teachers

The project is designed to study knowledge transfer, catastrophic forgetting, and other phenomena in neural network learning.

## Project Structure

```
teacher_student/
├── config/              # Configuration templates
├── data/                # Dataset storage
├── documents/           # Documentation
├── experiments/         # Experiment configurations and results
│   ├── examples/        # Example experiment configurations
│   └── evaluations/     # Evaluation results
├── notebooks/           # Jupyter notebooks for analysis
├── scripts/             # Utility scripts
└── teacher_student/     # Core source code
    ├── cli.py           # Command-line interface
    ├── dataloader.py    # Dataset loading utilities
    ├── eval_sweeps.py   # Evaluation utilities
    ├── exp_control.py   # Experiment control and parameter sweeps
    ├── models.py        # Neural network model definitions
    ├── train_teacher_student.py  # Main training logic
    ├── training_routines.py      # Training procedures
    └── utils.py         # Utility functions
```

## Installation and Requirements

### Dependencies

- PyTorch
- PyTorch Lightning
- OmegaConf
- Pandas
- Papermill (for notebook execution)

### Setup

1. Clone the repository
2. Install the required dependencies, for example with `pip install torch pytorch-lightning omegaconf pandas papermill`
3. Ensure you have GPU support for PyTorch (recommended for faster training)

## Running Experiments

The framework provides a command-line interface for running experiments. Experiments are defined using configuration files in YAML format.

### Basic Usage

To run an experiment, use the `train` command:

```bash
python -m teacher_student.cli train --exp_id <experiment_id>
```

Here `<experiment_id>` is the path of the experiment folder relative to `experiments/`; for example, `examples/linear` refers to `experiments/examples/linear/`.

### Example: Linear Model on MNIST

The repository includes an example experiment with linear models on the MNIST dataset:

```bash
python -m teacher_student.cli train --exp_id examples/linear
```

This will:
1. Train a full linear model on the complete MNIST dataset
2. Train two teacher models on separate subsets of MNIST (digits 0-4 and 5-9)
3. Train a student model with dual heads to learn from both teachers
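The label split in step 2 can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the code in `dataloader.py`: `split_by_task` is a hypothetical helper, and remapping each task's digits to `0..4` is assumed from the teachers' `num_classes: 5` in the example configuration.

```python
from typing import Dict, List, Sequence, Tuple


def split_by_task(
    labels: Sequence[int], task_labels: Sequence[int]
) -> Tuple[List[int], List[int]]:
    """Return (indices, remapped_labels) for samples whose label belongs to the task.

    Hypothetical helper. Assumption: task labels are remapped to
    0..len(task_labels)-1 so that each teacher can use a 5-class output layer.
    """
    remap: Dict[int, int] = {orig: new for new, orig in enumerate(task_labels)}
    indices: List[int] = []
    new_labels: List[int] = []
    for i, y in enumerate(labels):
        if y in remap:
            indices.append(i)
            new_labels.append(remap[y])
    return indices, new_labels


# Toy label vector standing in for MNIST targets.
labels = [7, 2, 1, 0, 4, 1, 4, 9, 5, 9]
idx0, y0 = split_by_task(labels, [0, 1, 2, 3, 4])  # task 0: digits 0-4
idx1, y1 = split_by_task(labels, [5, 6, 7, 8, 9])  # task 1: digits 5-9
print(idx0, y0)  # [1, 2, 3, 4, 5, 6] [2, 1, 0, 4, 1, 4]
print(idx1, y1)  # [0, 7, 8, 9] [2, 4, 0, 4]
```

Every sample lands in exactly one task, so the two teachers see disjoint data, while the full model sees all ten digits.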

### Parameter Sweeps

The framework supports two types of parameter sweeps:

1. **Independent Sweep**: Vary one parameter at a time
2. **Combination Sweep**: Explore all combinations of parameter values

To run a sweep, use the `--sweep_mode` option:

```bash
python -m teacher_student.cli train --exp_id examples/linear --sweep_mode independent
```

or

```bash
python -m teacher_student.cli train --exp_id examples/linear --sweep_mode combination
```

The sweep parameters are defined in a `sweep.yaml` file in the experiment directory. For example, the linear example sweeps over:

```yaml
control:
  hidden_size: [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 2000]
  set_seed: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
```
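The difference between the two modes can be illustrated with a short sketch. It assumes that an independent sweep holds every other parameter at its value in the `control` block of the example `config.yaml` (`hidden_size: 10`, `hidden_size_2: 10`, `set_seed: 1`); the helpers `independent_sweep` and `combination_sweep` are hypothetical and are not the functions in `exp_control.py`.

```python
from itertools import product
from typing import Any, Dict, List


def independent_sweep(base: Dict[str, Any], grid: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
    """Vary one parameter at a time; all others keep their base (control) value."""
    runs: List[Dict[str, Any]] = []
    for name, values in grid.items():
        for value in values:
            runs.append({**base, name: value})
    return runs


def combination_sweep(base: Dict[str, Any], grid: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
    """Explore the Cartesian product of all parameter values."""
    names = list(grid)
    return [
        {**base, **dict(zip(names, values))}
        for values in product(*(grid[name] for name in names))
    ]


# Assumed base values, taken from the `control` block of config.yaml.
base = {"hidden_size": 10, "hidden_size_2": 10, "set_seed": 1}
# The same 20 hidden sizes and 10 seeds as in sweep.yaml.
grid = {
    "hidden_size": list(range(10, 101, 10)) + list(range(200, 1001, 100)) + [2000],
    "set_seed": list(range(1, 11)),
}

print(len(independent_sweep(base, grid)))  # 30
print(len(combination_sweep(base, grid)))  # 200
```

With 20 hidden sizes and 10 seeds, a combination sweep launches 20 × 10 = 200 runs, while an independent sweep launches 20 + 10 = 30 (the base point is visited once per swept parameter unless duplicates are removed).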

### Additional Options

- `--debug`: Enable debug mode
- `--cluster`: Enable cluster mode (reduces number of workers)
- `--scratch_path`: Specify a scratch path for saving results

## Experiment Configuration

Experiments are configured using YAML files. The configuration includes:

- Dataset and task definitions
- Model architectures and parameters
- Training parameters (epochs, batch size, learning rate, etc.)
- Control parameters for sweeps

Example configuration (`experiments/examples/linear/config.yaml`):

```yaml
dataset         : "MNIST" 
task0_labels    : [0, 1, 2, 3, 4]
task1_labels    : [5, 6, 7, 8, 9]

full_model  : "Linear"
full_kwargs :
  input_size      : 784
  hidden_size     : ${control.hidden_size}
  num_classes     : 10

teacher0_model  : "Linear"
teacher0_kwargs :
  input_size      : 784
  hidden_size     : ${control.hidden_size}
  num_classes     : 5

teacher1_model  : "Linear"
teacher1_kwargs :
  input_size      : 784
  hidden_size     : ${control.hidden_size}
  num_classes     : 5

student_model  : "DoubleHeadLinear"
student_kwargs :
  input_size      : 784
  hidden_size     : ${control.hidden_size}
  num_classes_per_head : 5

student_equal_heads: False

epochs_teacher  : 50
epochs_t0       : 50
epochs_t1       : 50
BATCH_SIZE      : 64
seed            : ${control.set_seed}
learning_rate   : 0.01
momentum        : 0.0
detect_anomaly_flag : False

control:
  hidden_size     : 10
  hidden_size_2   : 10
  set_seed        : 1
```

## Evaluation

The framework includes utilities for evaluating experiment results, particularly for parameter sweeps. The evaluation calculates metrics such as catastrophic forgetting (CF).

To evaluate a sweep experiment:

```python
from teacher_student.eval_sweeps import eval_sweeps

eval_sweeps(
    filepath="path/to/experiment/combinations",
    outpath="path/to/save/evaluation"
)
```
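Because every sweep point is repeated over several seeds, per-point metrics are usually averaged across seeds. The sketch below assumes the student is trained on task 0 and then on task 1, and uses one common definition of catastrophic forgetting: the drop in task-0 accuracy between the end of the task-0 phase and the end of the task-1 phase. The metric that `eval_sweeps` actually computes may differ; the helper names, the record fields (`task0_acc_before`, `task0_acc_after`), and the accuracy values are hypothetical, made up purely to exercise the code.

```python
from collections import defaultdict
from statistics import mean, stdev
from typing import Dict, List, Tuple


def catastrophic_forgetting(acc_before: float, acc_after: float) -> float:
    """Assumed definition: drop in task-0 accuracy after training on task 1."""
    return acc_before - acc_after


def aggregate_cf(runs: List[dict]) -> Dict[int, Tuple[float, float]]:
    """Group runs by hidden_size and return (mean, std) of CF across seeds."""
    by_size: Dict[int, List[float]] = defaultdict(list)
    for run in runs:
        cf = catastrophic_forgetting(run["task0_acc_before"], run["task0_acc_after"])
        by_size[run["hidden_size"]].append(cf)
    # stdev needs at least two samples; report 0.0 for a single seed.
    return {
        size: (mean(cfs), stdev(cfs) if len(cfs) > 1 else 0.0)
        for size, cfs in sorted(by_size.items())
    }


# Made-up records, not results from this framework.
runs = [
    {"hidden_size": 10, "set_seed": 1, "task0_acc_before": 0.90, "task0_acc_after": 0.50},
    {"hidden_size": 10, "set_seed": 2, "task0_acc_before": 0.80, "task0_acc_after": 0.60},
    {"hidden_size": 20, "set_seed": 1, "task0_acc_before": 0.90, "task0_acc_after": 0.70},
]
summary = aggregate_cf(runs)
print({k: (round(m, 3), round(s, 3)) for k, (m, s) in summary.items()})
# {10: (0.3, 0.141), 20: (0.2, 0.0)}
```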

## Available Models

The framework includes several model architectures:

- **Linear**: Simple linear model
- **SimpleMLP**: MLP with a single hidden layer and ReLU activation
- **TwoLayerMLP**: MLP with two hidden layers and ReLU activations
- **ThreeLayerMLP**: MLP with three hidden layers and ReLU activations
- **SimpleCNN**: Simple convolutional neural network
- **DoubleHead** variants (e.g. `DoubleHeadLinear`): models with two output heads, used for the student model
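As an illustration of the dual-head idea, here is a minimal PyTorch sketch whose constructor arguments mirror the `student_kwargs` keys in the example config. The class name `DoubleHeadLinearSketch` is hypothetical, and the architecture (one shared linear layer of width `hidden_size` feeding two linear heads) is an assumption; check `models.py` for the actual `DoubleHeadLinear`.

```python
from typing import Tuple

import torch
from torch import nn


class DoubleHeadLinearSketch(nn.Module):
    """Hypothetical sketch: shared linear trunk feeding two task-specific linear heads."""

    def __init__(self, input_size: int, hidden_size: int, num_classes_per_head: int):
        super().__init__()
        self.trunk = nn.Linear(input_size, hidden_size)  # shared by both tasks
        self.head0 = nn.Linear(hidden_size, num_classes_per_head)  # task 0 (digits 0-4)
        self.head1 = nn.Linear(hidden_size, num_classes_per_head)  # task 1 (digits 5-9)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Flatten images, e.g. (batch, 1, 28, 28) -> (batch, 784); no nonlinearity.
        h = self.trunk(x.flatten(start_dim=1))
        return self.head0(h), self.head1(h)


model = DoubleHeadLinearSketch(input_size=784, hidden_size=10, num_classes_per_head=5)
out0, out1 = model(torch.randn(64, 1, 28, 28))
print(out0.shape, out1.shape)  # torch.Size([64, 5]) torch.Size([64, 5])
```

Because both heads read the same hidden representation, training one head can shift the features the other relies on, which is one place where interference between the two tasks can appear.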

## Supported Datasets

Currently, the framework supports:

- **MNIST**: Handwritten digit recognition dataset

Additional datasets can be added by extending the `get_dataset` function in `dataloader.py`.

## License

[Specify license information here]

## Contributors

[List contributors here]
