# Graph Neural Networks for Learning Equivariant Representations of Neural Networks

Official implementation for
<pre>
<b>Graph Neural Networks for Learning Equivariant Representations of Neural Networks</b>
Anonymous
<em>Submitted to ICLR 2024</em>
</pre>



Our codebase is based on [_Equivariant Architectures for Learning in Deep Weight Spaces_](https://arxiv.org/abs/2301.12780) by Aviv Navon, Aviv Shamsian, Idan Achituve, Ethan Fetaya, Gal Chechik, Haggai Maron.
<p align="center">
    <img src=assets/dws/sym.png  height="400">
</p>

## Setup environment

To run the experiments, first create a clean virtual environment and install the requirements:

```bash
conda create -n nns-are-graphs python=3.9
conda activate nns-are-graphs
conda install pytorch==2.0.1 torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
conda install pyg -c pyg
pip install hydra-core einops
```

Install the repo:

```bash
pip install -e .
```
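Optionally, you can run a quick sanity check that the packages installed above are importable. This is a minimal sketch and is not part of the repository. The script and the `check_environment` helper are hypothetical. It relies only on the standard library's `importlib.util.find_spec`, so it reports missing packages instead of crashing on the first failed import.

```python
"""Hypothetical sanity check (not part of the repo): are the README's requirements importable?"""
import importlib.util

# Import names of the requirements listed above
# (`pyg` is imported as `torch_geometric`, `hydra-core` as `hydra`).
REQUIRED_MODULES = ["torch", "torchvision", "torchaudio", "torch_geometric", "hydra", "einops"]


def check_environment(modules=REQUIRED_MODULES):
    """Return a mapping from module name to whether it can be found on the import path."""
    return {name: importlib.util.find_spec(name) is not None for name in modules}


if __name__ == "__main__":
    status = check_environment()
    for name, found in status.items():
        print(f"{name:16s} {'ok' if found else 'MISSING'}")
    if status["torch"]:
        import torch  # imported only when present, so the check itself never fails

        print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
```

Saved under any name (e.g. a hypothetical `check_env.py`), it can be run with `python check_env.py` inside the activated `nns-are-graphs` environment.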

## Introduction Notebook
An introductory notebook on INR classification with DWSNets, hosted in the DWSNets repository:
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AvivNavon/DWSNets/blob/main/notebooks/mnist-inr-classification.ipynb)

## Run experiment

To run a specific experiment, follow the instructions in the README file within each experiment folder.
Each README provides full instructions and details for downloading the data and reproducing the results reported in the paper.

## Datasets
**The datasets are available [here](https://www.dropbox.com/sh/56pakaxe58z29mq/AABtWNkRYroLYe_cE3c90DXVa?dl=0).**
- [MNIST INRs](https://www.dropbox.com/sh/56pakaxe58z29mq/AABtWNkRYroLYe_cE3c90DXVa?dl=0&preview=mnist-inrs.zip)

### MNIST

Download the data from the link available at https://github.com/AvivNavon/DWSNets and set it up following the instructions at https://github.com/AvivNavon/DWSNets/tree/main/experiments/mnist.

### CIFAR10

Download the data from the link at https://github.com/AllanYangZhou/nfn/tree/main/experiments#stylizing-sirens.


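## Running the script

First, change into the experiment directory:

```bash
cd experiments/mnist
```

Then run either the MNIST experiment, replacing the placeholders with the path to the data and to its statistics file:

```bash
python main.py data=mnist data.path=<data path> data.statistics_path=<stats file>
```

or the CIFAR10 experiment:

```bash
python main.py data=cifar
```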

## Notes

- [ ] Graph probe features work standalone if no normalization is applied
- [ ] With augmentation (and no normalization) training works, but is slower at the beginning (possibly also with a worse top accuracy)

- [ ] Using the standard weight-and-bias graph without normalization results in a flat training loss curve
- [ ] Apply a LayerNorm over all node and edge features (globally, separately for nodes and edges)

- [ ] Implement `extra_layer` functionality for CNNs with non-adaptive max pooling before linear layers
- [ ] Implement edges for `extra_layer`
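The notes above mention the standard weight-and-bias graph and a global normalization of node and edge features. For illustration only, here is a minimal dependency-free sketch under stated assumptions: each neuron of an MLP is a node whose feature is its bias (0 for input neurons), and each weight `W[l][i][j]` is the feature of the edge from neuron `j` in layer `l` to neuron `i` in layer `l + 1`. The helpers `mlp_to_graph` and `global_normalize` are hypothetical and are not the repository's implementation, which may differ (for example, by using a learned LayerNorm on tensors).

```python
"""Hypothetical sketch (not the repo's code): weight-and-bias graph of an MLP + global normalization."""
from math import sqrt


def mlp_to_graph(weights, biases):
    """weights[l] is a nested list of shape (out_l, in_l); biases[l] has length out_l."""
    layer_sizes = [len(weights[0][0])] + [len(w) for w in weights]
    offsets = [0]
    for size in layer_sizes:
        offsets.append(offsets[-1] + size)  # index of the first node of each layer

    # Node features: zero "bias" for input neurons, then the biases of each layer.
    node_feats = [[0.0] for _ in range(layer_sizes[0])]
    for b in biases:
        node_feats.extend([float(v)] for v in b)

    # Edge features: one edge per weight, pointing from layer l to layer l + 1.
    edge_index, edge_feats = [], []
    for l, w in enumerate(weights):
        for i, row in enumerate(w):  # i indexes neurons in layer l + 1
            for j, value in enumerate(row):  # j indexes neurons in layer l
                edge_index.append((offsets[l] + j, offsets[l + 1] + i))
                edge_feats.append([float(value)])
    return node_feats, edge_index, edge_feats


def global_normalize(feats, eps=1e-5):
    """Standardize using one mean/variance over ALL entries of ALL rows (no learned affine)."""
    flat = [x for row in feats for x in row]
    mean = sum(flat) / len(flat)
    var = sum((x - mean) ** 2 for x in flat) / len(flat)
    scale = 1.0 / sqrt(var + eps)
    return [[(x - mean) * scale for x in row] for row in feats]


if __name__ == "__main__":
    weights = [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [[7.0, 8.0, 9.0]]]  # a 2 -> 3 -> 1 MLP
    biases = [[0.1, 0.2, 0.3], [0.4]]
    nodes, edges, edge_feats = mlp_to_graph(weights, biases)
    # Nodes and edges are normalized separately, each with its own global statistics.
    nodes, edge_feats = global_normalize(nodes), global_normalize(edge_feats)
    print(len(nodes), "nodes,", len(edges), "edges")
```

Computing the statistics over the whole graph, rather than per node or per edge, keeps the relative scale between different weights and biases, which a per-row normalization of one-dimensional features would erase.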
