# Symmetries in PAC-Bayes Learning

**Anonymous authors**
*(Paper under review at AISTATS 2026)*
[Link to paper — *to be added after review*]

---

## Overview

This repository accompanies the paper **“Symmetries in PAC-Bayes Learning”**, which investigates how incorporating group symmetries into neural architectures can *tighten PAC-Bayes generalization bounds*. Specifically, we show that equivariant models — by respecting underlying symmetries in the data — achieve smaller Kullback–Leibler (KL) divergences between posterior and prior distributions, thereby yielding *sharper McAllester-style PAC-Bayes bounds*.

Our experiments focus on the **Rotated MNIST** and **Rotated+Translated MNIST** benchmarks, where we compare:

* A **Baseline CNN** (standard convolutional neural network)
* An **Equivariant CNN (G-CNN)** implemented using **e2cnn**, enforcing rotation and translation symmetries

Across both datasets, the equivariant models consistently demonstrate smaller KL divergences, validating the theoretical claim that respecting symmetries leads to tighter PAC-Bayes bounds.

---

## Repository Structure

The repository is organized as follows:

```
Symmetries_in_PAC_Bayes/
├── rotated_mnist/            # pre-generated dataset splits and trained models
│   ├── baseline_cnn.pt 
│   ├── equivariant_cnn.pt 
│   ├── histogram_overlay.png
│   ├── prior_mu_baseline.pt 
│   ├── prior_mu_equivariant.pt
│   └── rotated_examples.png 
├── rotated_translated_mnist/ # pre-generated dataset splits and trained models
│   ├── baseline_cnn.pt 
│   ├── equivariant_cnn.pt 
│   ├── histogram_overlay.png
│   ├── prior_mu_baseline.pt 
│   ├── prior_mu_equivariant.pt
│   └── rotated_examples.png 
├── dataset.py                # dataset utilities
├── generate_rotated_mnist.py # creates rotated MNIST with train/val/prior/test splits
├── generate_rotated_translated_mnist.py # same for rotation+translation
├── models.py                 # CNN and G-CNN (e2cnn) architectures
├── pacbayes_utils.py         # KL computation, Gibbs risk, PAC-Bayes bound
├── train_baseline.py         # train baseline CNN
├── train_equivariant_cnn.py  # train equivariant CNN
├── prior_mu.py               # training the data-dependent Gaussian prior
├── plot.py                   # visualization utilities (PAC-Bayes histograms)
```

Pretrained weights and datasets are provided for reproducibility.

---

## Experimental Design

### Datasets and Splits

Each dataset (rotated and rotated+translated MNIST) is generated with four independent splits:

* **train**: model training
* **val**: early stopping and hyperparameter tuning
* **prior**: used to fit a data-dependent Gaussian prior
* **test**: final evaluation and PAC-Bayes bound estimation

Transformations are sampled uniformly from the following ranges:

* Rotation angles: $[-\pi/2, \pi/2]$ (i.e., no upside-down digits)
* Translations (second dataset only): small random shifts along both axes

The script `generate_rotated_mnist.py` creates these splits and stores them as `.pt` tensors; `generate_rotated_translated_mnist.py` does the same for the rotated+translated dataset.
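As a rough illustration of the split-and-sample logic described above, the sketch below partitions example indices into four disjoint splits and draws one rotation angle per example uniformly from $[-\pi/2, \pi/2]$. The function names, split fractions, and seeds are assumptions made for illustration; they are not taken from `generate_rotated_mnist.py`.

```python
import math
import random


def make_splits(n_examples, fractions, seed=0):
    """Partition indices 0..n_examples-1 into disjoint, named splits.

    `fractions` maps split names to proportions (hypothetical values);
    the last split receives whatever indices remain.
    """
    rng = random.Random(seed)
    indices = list(range(n_examples))
    rng.shuffle(indices)

    names = list(fractions)
    splits, start = {}, 0
    for i, name in enumerate(names):
        if i == len(names) - 1:
            end = n_examples
        else:
            end = start + int(fractions[name] * n_examples)
        splits[name] = indices[start:end]
        start = end
    return splits


def sample_angles(n, seed=0):
    """Draw n rotation angles (radians) uniformly from [-pi/2, pi/2]."""
    rng = random.Random(seed)
    return [rng.uniform(-math.pi / 2, math.pi / 2) for _ in range(n)]


# Hypothetical proportions, chosen only for this example.
splits = make_splits(
    1000, {"train": 0.6, "val": 0.1, "prior": 0.1, "test": 0.2}, seed=0
)
angles = sample_angles(len(splits["train"]), seed=1)
```

The sampled angles could then be applied to images with, for example, `torchvision.transforms.functional.rotate`, which expects degrees rather than radians.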

---

### Priors and Posteriors

To keep the prior independent of the data on which the bound is evaluated (as PAC-Bayes analysis requires):

* **Prior**: A *data-dependent Gaussian prior* is learned using the held-out *prior set*.
  The mean $\mu_P$ is obtained from a model trained solely on the prior set.
  The variance $\sigma_P^2$ is estimated from the empirical variability of the learned weights.

* **Posterior**: Approximated by a *variational Gaussian* with mean $\mu_Q$ and variance $\sigma_Q^2$, trained via variational inference to minimize the PAC-Bayes (McAllester) bound.

The KL divergence $D_{\mathrm{KL}}(Q \,\|\, P)$ is computed analytically for diagonal Gaussians.
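For reference, the closed form for two diagonal Gaussians over $d$ weights is

$$
D_{\mathrm{KL}}(Q \,\|\, P) = \frac{1}{2} \sum_{i=1}^{d} \left[ \ln\frac{\sigma_{P,i}^2}{\sigma_{Q,i}^2} + \frac{\sigma_{Q,i}^2 + (\mu_{Q,i} - \mu_{P,i})^2}{\sigma_{P,i}^2} - 1 \right].
$$

A minimal sketch of this computation in plain Python follows. The function name and the use of lists are assumptions for readability; `pacbayes_utils.py` presumably operates on tensors and may be organized differently.

```python
import math


def kl_diag_gaussians(mu_q, var_q, mu_p, var_p):
    """KL(Q || P) for diagonal Gaussians given per-weight means and variances."""
    total = 0.0
    for mq, vq, mp, vp in zip(mu_q, var_q, mu_p, var_p):
        # Per-coordinate term of the closed-form expression.
        total += math.log(vp / vq) + (vq + (mq - mp) ** 2) / vp - 1.0
    return 0.5 * total


# Identical distributions have zero divergence; moving the posterior mean
# away from the prior mean increases it.
kl_same = kl_diag_gaussians([0.0, 0.0], [1.0, 1.0], [0.0, 0.0], [1.0, 1.0])
kl_shifted = kl_diag_gaussians([1.0, 0.0], [1.0, 1.0], [0.0, 0.0], [1.0, 1.0])
```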

---

### PAC-Bayes Bound (McAllester)

We employ the **McAllester bound** for the Gibbs classifier:

$$
R(Q) \le \hat{R}_S(Q) + \sqrt{\frac{D_{\mathrm{KL}}(Q \,\|\, P) + \ln\frac{2\sqrt{n}}{\delta}}{2(n - 1)}}
$$

where

* $\hat{R}_S(Q)$ is the empirical Gibbs risk estimated by Monte Carlo sampling,
* $n$ is the number of test examples,
* $\delta$ is the confidence parameter (typically $10^{-5}$).

The experiments demonstrate that for the same dataset and architecture depth, the equivariant model yields a smaller KL term and a tighter bound.
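To make the role of the KL term concrete, here is a small sketch that evaluates the right-hand side of the bound above. The function name is an assumption, and the numeric inputs are made up for illustration; they are not results from the paper.

```python
import math


def mcallester_bound(emp_risk, kl, n, delta):
    """Right-hand side of the McAllester-style bound stated above."""
    if n < 2:
        raise ValueError("n must be at least 2")
    if not 0.0 < delta < 1.0:
        raise ValueError("delta must lie in (0, 1)")
    complexity = math.sqrt(
        (kl + math.log(2.0 * math.sqrt(n) / delta)) / (2.0 * (n - 1))
    )
    return emp_risk + complexity


# Illustrative (made-up) inputs: same empirical risk, two different KL values.
bound_large_kl = mcallester_bound(emp_risk=0.05, kl=500.0, n=10_000, delta=1e-5)
bound_small_kl = mcallester_bound(emp_risk=0.05, kl=100.0, n=10_000, delta=1e-5)
print(bound_large_kl, bound_small_kl)
```

With the empirical risk, $n$, and $\delta$ held fixed, the bound is monotonically increasing in the KL term, which is why a smaller divergence translates directly into a tighter guarantee.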

---

## Reproducing the Main Figure

The main figure (`histogram_overlay.png`) visualizes:

* The empirical distribution of errors under the **posterior** for both models
* The corresponding **PAC-Bayes bound** derived from the McAllester inequality

To reproduce this visualization (without retraining), follow these steps:

1. **Generate the datasets** by running:

   ```bash
   python generate_rotated_mnist.py
   python generate_rotated_translated_mnist.py
   ```

   These scripts load the MNIST dataset and apply the rotations and translations as described earlier.

2. **Plot the results** by running:

   ```bash
   python plot.py
   ```

This script:

1. Loads the precomputed results (posterior samples and bound components)
2. Plots the error distribution histogram for the baseline and equivariant models
3. Overlays the analytical PAC-Bayes bound as a vertical line
4. Saves the figure as `histogram_overlay.png` in the corresponding dataset directory (`rotated_mnist/` or `rotated_translated_mnist/`).
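An error distribution of this kind can be obtained by repeatedly drawing weights from the Gaussian posterior and recording the error of each draw; the mean of those errors is the Monte Carlo estimate of the Gibbs risk. The sketch below shows the idea on a toy linear classifier. The classifier, the function names, and the data are all assumptions for illustration and do not reflect how `plot.py` or `pacbayes_utils.py` are implemented.

```python
import random


def sample_posterior_errors(mu, sigma, inputs, labels, n_samples=100, seed=0):
    """Error rate of a toy linear classifier for each posterior weight draw."""
    rng = random.Random(seed)
    errors = []
    for _ in range(n_samples):
        # One draw w ~ N(mu, diag(sigma^2)).
        w = [rng.gauss(m, s) for m, s in zip(mu, sigma)]
        wrong = 0
        for x, y in zip(inputs, labels):
            score = sum(wi * xi for wi, xi in zip(w, x))
            pred = 1 if score >= 0 else 0
            wrong += int(pred != y)
        errors.append(wrong / len(labels))
    return errors


def gibbs_risk(errors):
    """Monte Carlo estimate of the Gibbs risk: the mean error over draws."""
    return sum(errors) / len(errors)


# Toy data: one feature, label 1 when the feature is positive.
inputs = [[1.0], [2.0], [-1.0], [-3.0]]
labels = [1, 1, 0, 0]
errors = sample_posterior_errors([1.0], [0.5], inputs, labels, n_samples=200)
print(gibbs_risk(errors))
```

The list `errors` is what a histogram would be built from, with the bound drawn as a vertical line for comparison.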

---

## Training from Scratch

To regenerate the results rather than use the stored weights, follow these steps:

### 1. Generate Datasets

```bash
python generate_rotated_mnist.py
python generate_rotated_translated_mnist.py
```

### 2. Train the Prior

Train a small CNN on the prior split:

```bash
python prior_mu.py --dataset rotated_mnist
python prior_mu.py --dataset rotated_translated_mnist
```

This creates `prior_mu_baseline.pt` and `prior_mu_equivariant.pt` in each dataset directory.

### 3. Train Models

Train the baseline and equivariant models:

```bash
python train_baseline.py --dataset rotated_mnist
python train_equivariant_cnn.py --dataset rotated_mnist
```

### 4. Plot Results

Finally, reproduce the histogram:

```bash
python plot.py
```

---

## Dependencies

This project depends on the following:

* Python ≥ 3.10
* PyTorch ≥ 2.0
* torchvision
* e2cnn
* numpy
* matplotlib


---

## Key Results

| Dataset                  | Model             | KL term (lower is better) | PAC-Bayes bound |
| ------------------------ | ----------------- | ------------------------: | --------------: |
| Rotated MNIST            | Baseline CNN      |                    Higher |          Looser |
| Rotated MNIST            | Equivariant G-CNN |                     Lower |         Tighter |
| Rotated+Translated MNIST | Baseline CNN      |                    Higher |          Looser |
| Rotated+Translated MNIST | Equivariant G-CNN |                     Lower |         Tighter |

These trends support the theoretical claim: *incorporating symmetry reduces the complexity penalty and improves the PAC-Bayes bound.*

---

## Citation

```
Anonymous Authors (2026). Symmetries in PAC-Bayes Learning. Under review at the 29th International Conference on Artificial Intelligence and Statistics (AISTATS 2026).
```


---

## License and Reproducibility

All code and datasets are released for **academic research** purposes only.
Seed values are fixed for reproducibility. The pretrained weights included in `rotated_mnist/` and `rotated_translated_mnist/` ensure that all results, figures, and bounds can be reproduced deterministically using the provided scripts.
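
How the seeds are fixed inside the provided scripts is not shown in this README. A common pattern, given here only as a hedged sketch, is a single helper that seeds every random number generator in use. The helper name and default seed are assumptions; NumPy and PyTorch are seeded only if they are installed.

```python
import random


def seed_everything(seed=0):
    """Seed Python's RNG and, when available, NumPy and PyTorch."""
    random.seed(seed)
    try:
        import numpy as np

        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch

        torch.manual_seed(seed)
        # Trade some speed for repeatable cuDNN convolution algorithms.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    except ImportError:
        pass


seed_everything(0)
```

Even with fixed seeds, some GPU operations can remain nondeterministic, so small numerical differences across hardware are possible.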
