# Training probabilistic models with R2-G2 gradients

This repository contains code to train Bayesian neural networks (BNNs) and variational autoencoders (VAEs) with the R2-G2 estimator (i.e., Rao-Blackwellised Reparameterisation gradients).

## Environment Setup
Create and activate a Python 3.10.6 virtualenv in a terminal. If you want PyTorch with CUDA support, uncomment the relevant lines in `requirements.txt` before installing.

To install with pip:

```
pip install -r requirements.txt
```
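After installing, a quick sanity check can confirm the interpreter version and whether the installed PyTorch build can see a GPU. This is a minimal sketch that is not part of the repository; the `environment_report` helper is hypothetical, and it only uses the standard library plus `torch.cuda.is_available()`.

```python
import importlib.util
import sys


def environment_report():
    """Return basic facts about the active interpreter and PyTorch install."""
    report = {
        "python": ".".join(map(str, sys.version_info[:3])),
        "python_3_10": sys.version_info[:2] == (3, 10),
        "torch_installed": importlib.util.find_spec("torch") is not None,
        "cuda_available": False,
    }
    if report["torch_installed"]:
        import torch  # only imported if PyTorch is actually installed

        report["cuda_available"] = torch.cuda.is_available()
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```

If `cuda_available` is `False` on a GPU machine, the CPU-only PyTorch wheels were probably installed, which suggests the CUDA lines in `requirements.txt` were left commented out.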

## Directory Setup for Omniglot dataset
Set up a data folder at the root of this repository with the following structure:
```
/data/omniglot_npy/
```
Download the train and test splits of Omniglot used by Burda et al. (2016), available at:
https://github.com/yburda/iwae/raw/master/datasets/OMNIGLOT/chardata.mat

Save the splits as `.npy` files in the above subfolder with the following names:

* omniglot_xs_train.npy
* omniglot_ys_train.npy
* omniglot_xs_test.npy
* omniglot_ys_test.npy
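One way to produce these files is sketched below. It assumes `chardata.mat` has been downloaded to the working directory, that SciPy is available to read it, and that the arrays are stored under the keys `data`, `target`, `testdata` and `testtarget` with one example per column. These key names, the choice of `target` as the label array, and the helper `convert_omniglot` are all assumptions: inspect `scipy.io.loadmat("chardata.mat").keys()` and check what layout the experiment code expects (e.g. image orientation or one-hot labels) before relying on it.

```python
import os

import numpy as np

# ASSUMED mapping from output file name to key inside chardata.mat.
MAT_KEYS = {
    "omniglot_xs_train.npy": "data",
    "omniglot_ys_train.npy": "target",
    "omniglot_xs_test.npy": "testdata",
    "omniglot_ys_test.npy": "testtarget",
}


def convert_omniglot(mat, out_dir="data/omniglot_npy"):
    """Hypothetical helper: write one .npy file per split and return the paths."""
    os.makedirs(out_dir, exist_ok=True)
    written = []
    for filename, key in MAT_KEYS.items():
        # Transpose so rows index examples (assumes one example per column in the .mat file).
        array = np.asarray(mat[key]).T
        if filename.startswith("omniglot_xs"):
            array = array.astype(np.float32)  # store pixel intensities as float32
        path = os.path.join(out_dir, filename)
        np.save(path, array)
        written.append(path)
    return written


if __name__ == "__main__":
    import scipy.io  # only needed to read the downloaded .mat file

    convert_omniglot(scipy.io.loadmat("chardata.mat"))
```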

## Experiment run setup
`cd` to the root of this repository before running any of the experiments below.
```
cd ~/<path_to_directory_containing_this_repo>/r2g2
```

## BNN examples
Run a Bayesian MLP on MNIST with R2-G2:
```
python -m experiments.bnn_mnist --seed 123 --reparam "r2g2" --fwd "local" --batch_size_train 80 --batch_size_test 10000
```

Run a Bayesian MLP on MNIST with LRT:
```
python -m experiments.bnn_mnist --seed 123 --reparam "lrt" --batch_size_train 80 --batch_size_test 10000
```

Run a Bayesian MLP on MNIST with RT:
```
python -m experiments.bnn_mnist --seed 123 --reparam "rt" --batch_size_train 80 --batch_size_test 10000
```
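The `--reparam` values `rt` and `lrt` presumably select the standard reparameterisation trick and the local reparameterisation trick (Kingma et al., 2015). The sketch below contrasts the two for a single mean-field Gaussian linear layer with weight means `mu` and standard deviations `sigma`: RT samples a weight matrix and multiplies, while LRT samples each pre-activation directly from its Gaussian marginal. This is an illustration under that assumption, not the repository's implementation; it does not show the R2-G2 estimator itself, and the function names are hypothetical.

```python
import numpy as np


def rt_preactivations(x, mu, sigma, rng):
    """RT: sample W = mu + sigma * eps once, then compute x @ W for the whole batch."""
    eps = rng.standard_normal(mu.shape)
    weights = mu + sigma * eps
    return x @ weights


def lrt_preactivations(x, mu, sigma, rng):
    """LRT: sample pre-activations from N(x @ mu, (x**2) @ sigma**2) directly."""
    mean = x @ mu
    var = (x ** 2) @ (sigma ** 2)
    eps = rng.standard_normal(mean.shape)  # independent noise per example and unit
    return mean + np.sqrt(var) * eps


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal((4, 3))  # batch of 4 inputs with 3 features
    mu = rng.standard_normal((3, 2))
    sigma = np.full((3, 2), 0.5)
    print(rt_preactivations(x, mu, sigma, rng))
    print(lrt_preactivations(x, mu, sigma, rng))
```

For each example, both functions produce pre-activations with mean `x @ mu` and variance `(x ** 2) @ (sigma ** 2)`. The difference is that RT as written shares one weight sample across the minibatch, whereas LRT draws fresh noise for every example and output unit.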

Run a Bayesian CNN on CIFAR-10 with R2-G2:
```
python -m experiments.bnn_cifar10 --seed 123 --reparam "r2g2" --batch_size_train 80 --batch_size_test 10000
```

Run a Bayesian CNN on CIFAR-10 with RT:
```
python -m experiments.bnn_cifar10 --seed 123 --reparam "rt" --batch_size_train 80 --batch_size_test 10000
```

## VAE examples
Run the one-layer VAE experiment with R2-G2:
```
python -m experiments.1vae --dataset "mnist" --seed 123 --reparam "r2g2" --batch_size_train 80 --batch_size_test 20
```

Run the one-layer VAE experiment with RT:
```
python -m experiments.1vae --dataset "mnist" --seed 123 --reparam "rt" --batch_size_train 80 --batch_size_test 20
```

Run the two-layer VAE experiment with R2-G2:
```
python -m experiments.2vae --dataset "mnist" --seed 123 --reparam "r2g2" --batch_size_train 80 --batch_size_test 20
```

Run the two-layer VAE experiment with RT:
```
python -m experiments.2vae --dataset "mnist" --seed 123 --reparam "rt" --batch_size_train 80 --batch_size_test 20
```

Run the three-layer VAE experiment with R2-G2:
```
python -m experiments.3vae --dataset "mnist" --seed 123 --reparam "r2g2" --batch_size_train 80 --batch_size_test 20
```

Run the three-layer VAE experiment with RT:
```
python -m experiments.3vae --dataset "mnist" --seed 123 --reparam "rt" --batch_size_train 80 --batch_size_test 20
```
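To run the six VAE commands above in sequence, a small driver script can build them programmatically. The flags and values are copied from the commands in this README; the script itself and the `vae_commands` helper are hypothetical. It uses `sys.executable` instead of a bare `python` so the experiments run under the active virtualenv, and it should be launched from the repository root.

```python
import itertools
import shlex
import subprocess
import sys


def vae_commands(seed=123, dataset="mnist"):
    """Build the README's VAE commands for every depth (1-3) and estimator."""
    commands = []
    for depth, reparam in itertools.product((1, 2, 3), ("r2g2", "rt")):
        commands.append([
            sys.executable, "-m", f"experiments.{depth}vae",
            "--dataset", dataset,
            "--seed", str(seed),
            "--reparam", reparam,
            "--batch_size_train", "80",
            "--batch_size_test", "20",
        ])
    return commands


if __name__ == "__main__":
    for cmd in vae_commands():
        print(shlex.join(cmd))  # echo the command before running it
        subprocess.run(cmd, check=True)  # stop on the first failing experiment
```

The order matches the README (R2-G2 then RT for each depth); change `seed` to repeat the sweep with a different random seed.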

## Acknowledgements
This repository contains code from the following sources:

* https://github.com/CW-Huang/CP-Flow
