# Bidirectional Predictive Coding

## Overview
This repository contains the implementation and experiments presented in our paper, which proposes a novel computational framework called bidirectional predictive coding.

## Key Contributions
- We propose bidirectional predictive coding (bPC), a biologically plausible model of visual perception that employs both generative and discriminative inference, which arise naturally from minimising a single energy function.

- We show that bPC performs as well as its purely discriminative or generative counterparts on both supervised classification and unsupervised representation learning tasks, and outperforms previous hybrid predictive coding models.

- We explain the superior performance of bPC on both tasks by comparing it to unidirectional models, showing that it learns an energy landscape that facilitates the detection of out-of-distribution data points.

- We further show that bPC outperforms other predictive coding (PC) models on two biologically relevant tasks, multimodal learning and inference with occluded visual scenes, indicating its potential as a more faithful model of visual inference in the brain.


## Structure of the repository
The repository includes:
- `train_mlp_*` and `train_cnn_*` contain the code to train all the models considered in our experiments.
- `figure3/`, `figure4/`, `figure5/`, `figure7_top`, `SM_C/`, `SM_D/`, `SM_E/`, `SM_G/`, and `SM_I/` contain the code to rerun the hyperparameter tuning and retrain the models with optimal parameters reported in our experiments.
- `figure6_*.py` and `figure7_*.py` contain the code to recreate figures 6 and 7.
- `requirements.txt` contains the Python dependencies of this repository.
- `utils_pcax/` contains utility functions.
- `pcax/` contains the PCX library to simulate PC models in JAX.
- `inception_score/` contains code to measure the inception score of generated images.
- `pytorch_fid/` contains the code to measure the FID of generated images.

## Model names in the code
The code follows a slightly different naming convention from the paper. Each line below gives the name used in the paper, followed by the corresponding name used in the code:
- bPC = bPC
- genPC = dPC (down-PC)
- discPC = uPC (up-PC)
- genBP = dBP
- discBP = uBP
- hybridPC = hybrid-dPC
- hybridBP = hybridBP
- AE = AE 
- agPC = agPC
- shared bPC = s-bPC
- decay discPC = d-uPC
- bimodal genPC = mm (multi-modal)
- bimodal bPC = 3-bPC (3-layered bPC)

The settings required to run each model can be found in the configuration files used to train it.
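
When going from a model named in the paper to its name in the code (for instance, when looking for its configuration file), the mapping above can be kept as a small lookup. The following is a hypothetical helper, not part of the repository, and it assumes that in each pair above the left-hand name is the paper's and the right-hand name is the code's:

```bash
#!/usr/bin/env bash
# Hypothetical helper (not part of the repository): translate a model name
# used in the paper into the name used in the code, per the list above.
code_name() {
  case "$1" in
    bPC|agPC|AE|hybridBP) echo "$1" ;;   # same name in paper and code
    genPC)           echo "dPC" ;;
    discPC)          echo "uPC" ;;
    genBP)           echo "dBP" ;;
    discBP)          echo "uBP" ;;
    hybridPC)        echo "hybrid-dPC" ;;
    "shared bPC")    echo "s-bPC" ;;
    "decay discPC")  echo "d-uPC" ;;
    "bimodal genPC") echo "mm" ;;
    "bimodal bPC")   echo "3-bPC" ;;
    *) echo "unknown model: $1" >&2; return 1 ;;
  esac
}
```

For example, `code_name "bimodal bPC"` prints `3-bPC`.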


## Usage
Follow these steps to clone the repository and set up the required Python libraries:

```bash
git clone tbd
cd bPC
```
Next, follow the JAX installation steps described at https://github.com/liukidar/pcx and https://docs.jax.dev/en/latest/quickstart.html.
The PCX library is not compatible with JAX versions above v0.4.33; we ran our experiments with JAX v0.4.28. To install this version of JAX for CUDA 12:
```bash
conda create -n bpcx python=3.10
conda activate bpcx
pip install jax==0.4.28 jaxlib==0.4.28+cuda12.cudnn89 --no-deps -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install -r requirements.txt
```
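
Because PCX is stated not to work with JAX versions above v0.4.33, it can be worth checking the environment before training. The script below is an illustrative sketch, not part of the repository; it assumes that `python` resolves to the interpreter of the `bpcx` environment and that `sort -V` (GNU coreutils) is available for version comparison:

```bash
#!/usr/bin/env bash
# Illustrative check (not part of the repository): confirm that the JAX
# version in the active environment does not exceed the one PCX supports.
MAX_JAX="0.4.33"

# Exit status 0 if version "$1" <= MAX_JAX. `sort -V` orders version strings
# numerically, so the smaller of the two ends up on the first line.
jax_version_ok() {
  [ "$(printf '%s\n%s\n' "$1" "$MAX_JAX" | sort -V | head -n1)" = "$1" ]
}

if installed=$(python -c 'import jax; print(jax.__version__)' 2>/dev/null); then
  if jax_version_ok "$installed"; then
    echo "JAX $installed is compatible with PCX"
    # List visible devices; a working CUDA install should show GPU devices.
    python -c 'import jax; print(jax.devices())'
  else
    echo "JAX $installed is newer than $MAX_JAX; reinstall jax==0.4.28" >&2
  fi
else
  echo "JAX could not be imported in this environment" >&2
fi
```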

To retrain the models with the optimal parameters over 5 seeds, create a W&B sweep from the corresponding configuration file and start an agent:
```bash
wandb sweep path_to_config.yaml # returns sweep_id and wandb agent command
wandb agent user_name/project_name/sweep_id
```
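
The two steps above can also be chained so that the sweep ID does not have to be copied by hand. This is a sketch rather than a script shipped with the repository: it assumes that `wandb sweep` reports a line of the form `wandb: Run sweep agent with: wandb agent <entity>/<project>/<sweep_id>` (possibly on stderr, so both streams are captured), and the function names are hypothetical:

```bash
#!/usr/bin/env bash
# Sketch (not part of the repository): create a W&B sweep and start an agent
# for it without copying the sweep ID by hand.

# Read captured `wandb sweep` output on stdin and print the
# "<entity>/<project>/<sweep_id>" path from its "wandb agent ..." line.
extract_sweep_path() {
  grep -o 'wandb agent [^ ]*' | head -n1 | sed 's/^wandb agent //'
}

# Usage: run_sweep path_to_config.yaml
run_sweep() {
  local config="$1" out sweep_path
  # Capture stdout and stderr, since the CLI may log to either.
  out=$(wandb sweep "$config" 2>&1) || { echo "wandb sweep failed" >&2; return 1; }
  sweep_path=$(printf '%s\n' "$out" | extract_sweep_path)
  if [ -z "$sweep_path" ]; then
    echo "could not find the sweep ID in the wandb output" >&2
    return 1
  fi
  wandb agent "$sweep_path"
}
```

With these functions defined, `run_sweep path_to_config.yaml` replaces the two commands above.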

To regenerate figure 6 or 7, run the corresponding script, e.g.:
```bash
python figure6_middle.py
```
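
To regenerate every panel of figures 6 and 7 at once, the scripts can be run in a loop. A minimal sketch, assuming it is executed from the repository root and that each `figure6_*.py` and `figure7_*.py` script runs standalone like `figure6_middle.py`; `run_figure_scripts` and the `PYTHON` variable are illustrative names, not part of the repository:

```bash
#!/usr/bin/env bash
# Sketch (not part of the repository): run every figure 6 and figure 7
# script from the repository root and report any that fail.
run_figure_scripts() {
  local python="${PYTHON:-python}" failed=0 script
  shopt -s nullglob   # an unmatched pattern expands to nothing, not itself
  for script in figure6_*.py figure7_*.py; do
    echo "Running $script"
    "$python" "$script" || { echo "FAILED: $script" >&2; failed=1; }
  done
  shopt -u nullglob
  return "$failed"
}
```

Calling `run_figure_scripts` keeps going after a failing script and returns a non-zero status if any of them failed.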

## Citation
If you find our work useful, please cite it as follows:

```bibtex
@article{TBD,
	author = {TBD},
	title = {Bidirectional Predictive coding},
	year = {2025},
	doi = {},
	publisher = {},
	URL = {},
}
```

## Contact
For any questions about the project, please contact name at @email.

## Code Acknowledgements
This repository builds on the following repositories and code bases:
- https://github.com/liukidar/pcx
- https://github.com/sundyCoder/IS_MS_SS
- https://github.com/mseitzer/pytorch-fid

The code also relies heavily on Weights & Biases (https://github.com/wandb/wandb).
