Equivariant Strong Lottery Tickets
====

This repository contains the code for reproducing the experiments
that accompany the ICLR 2023 submission titled:  
**A General Framework For Proving The Equivariant Strong Lottery Ticket Hypothesis**

The code trains neural networks with equivariant components
on supervised tasks defined by various datasets (`main_train.py`),
and prunes overparameterized versions of the pretrained networks so that they approximate the target networks (`main_prune.py`).
The pruning method implements the procedure described in **Lemma 1**
of the paper.

The code is compatible with Python 3.8; the required packages are listed in `requirements.txt`.

Third-party software
---

The pruning procedure relies on third-party software to solve the mixed-integer programs (MIPs)
that encode the Subset-Sum problems. Users with a valid license can optionally install
the commercial solver Gurobi and its Python interface. Otherwise,
the open-source OR-Tools optimization suite is listed in the `requirements.txt`
file. The solver is selected with
the `--solver` command line argument of `main_prune.py`.
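
To illustrate the kind of problem the solver is asked to handle, here is a minimal sketch of an approximate Subset-Sum search: given a list of candidate values, find a 0/1 mask whose selected values sum to within a tolerance of a target. This is not the repository's implementation. The function `find_subset`, its signature, and the example values are hypothetical, and brute-force enumeration stands in for the MIP solver (Gurobi or OR-Tools), so it is only practical for a handful of candidates. The tolerance `eps` is assumed to play a role similar to the `--eps` argument of `main_prune.py`.

```python
import itertools


def find_subset(candidates, target, eps):
    """Brute-force approximate Subset-Sum (illustrative stand-in for a MIP solver).

    Returns (mask, error): mask is a tuple of 0/1 flags selecting candidates
    whose sum is closest to target, or None if no subset is within eps.
    """
    best_mask, best_err = None, float("inf")
    # Enumerate all 2^n binary masks; feasible only for small n.
    for mask in itertools.product((0, 1), repeat=len(candidates)):
        total = sum(m * c for m, c in zip(mask, candidates))
        err = abs(total - target)
        if err < best_err:
            best_mask, best_err = mask, err
    if best_err > eps:
        return None, best_err
    return best_mask, best_err


# Hypothetical example: approximate a single target value from four candidates.
candidates = [0.5, -0.2, 0.3, 0.07]
mask, err = find_subset(candidates, target=0.37, eps=0.01)
print("mask:", mask, "error:", err)
```

A MIP solver replaces the exhaustive loop with a search over binary decision variables, which is what makes larger instances tractable.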

We also use external third-party repositories, available on GitHub,
which are listed in the `submodules` directory.
Python modules from these open-source repositories have been modified for our needs
and symbolically linked to the appropriate places in our codebase.
Specifically, we have replicated and adapted:
 - The file [QUVA-Lab/e2cnn](https://github.com/QUVA-Lab/e2cnn)`/e2cnn/nn/modules/r2_conv/r2convolution.py` at `equislt/r2convolution.py`, modified to support the equivariant basis we describe in **Section 4.2**.
 - The repository [QUVA-Lab/e2cnn_experiments](https://github.com/QUVA-Lab/e2cnn_experiments) in `submodules/e2cnn_experiments` at commit `0c8f275be0361367c52d2d268471ac32f39fe3f3`, to help us define and reuse the code for RotMNIST and FlipRotMNIST datasets.
 - The repository [Haggaim/InvariantGraphNetworks](https://github.com/Haggaim/InvariantGraphNetworks) in `submodules/equivariant_graphs` at commit `f4be7cdecf8e054dbcdf8e5a78881bd3f2c16486`, to reuse code for k-order permutation equivariant graph neural networks.

Repository structure
---

- `main_train.py`: Main script for training a model with equivariant layers; supports various architectures and datasets.  
- `main_prune.py`: Main file that accepts a path to an experiment executed by `main_train.py` via `--target_net_dir`, creates an overparameterized version of the target pretrained network and prunes it to approximate the target.  
- `requirements.txt`: Python packages required to run this code.  
- `README.md`: The current document.  
- `submodules/`: Third-party open-source codebases adapted to our needs.  
- `equislt/`: Python package that implements the functionality for our experiments in a modular way.  
  * `equislt/data/`: Definition of data loaders for various datasets.
  * `equislt/args.py`: Command line argument handling.
  * `equislt/find_subset_gurobi.py`: Solving the Subset-Sum MIP problem with Gurobi.
  * `equislt/find_subset_ortools.py`: Solving the Subset-Sum MIP problem with OR-Tools.
  * `equislt/graph_equivariant.py`: Architecture for the k-order GNN experiments.
  * `equislt/r2convolution.py`: Architecture for the E(2)-CNN experiments.
  * `equislt/methods/`: Implementations of the training and pruning methods for the three architectures considered: Graph Convolutional NNs (`.../gcn.py`), k-order equivariant Graph NNs (`.../ign.py`), and E(2)-CNNs (`.../e2cnn.py`).

Example use
---

For training:
```bash
python main_train.py --dataset=PROTEINS --data_dir=$DATASET_DIR --method=ign \
  --save_checkpoint --checkpoint_dir=$TARGET_NET_DIR --seed=$RANDOM \
  --devices=1 --accelerator=gpu --gpus=0 \
  --lr=1e-3 --weight_decay=5e-4 --max_epochs=200
```

For pruning:
```bash
python main_prune.py --dataset=PROTEINS --data_dir=$DATASET_DIR --method=ign \
  --num_workers=$NUM_WORKERS --num_threads=1 --overparam_factor=5.0 --seed=$RANDOM \
  --target_net_dir=$TARGET_NET_DIR --solver=ortools --eps=0.01
```
