## Requirements

This repository contains the code for the ICLR 2025 submission "Grounding Continuous Representations in Geometry: Equivariant Neural Fields."
It includes the code for all experiments on the CIFAR10 and OMBRIA datasets; code for the remaining experiments will be released during the review process.

To install requirements:

Create a conda environment:
```bash
conda create -n enf-iclr python=3.11
conda activate enf-iclr
```

Install the dependencies (PyTorch is installed CPU-only; JAX is installed against the local CUDA 12 installation):
```bash
conda install pytorch torchvision torchaudio cpuonly -c pytorch
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install optax flax orbax wandb ml_collections matplotlib tqdm
```
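To sanity-check the environment after installation, a small script along these lines can report missing packages and the devices JAX sees. This is a sketch, not part of the repository: `REQUIRED_MODULES`, `missing_modules` and `report` are hypothetical names, and the module list assumes the top-level import names of the packages installed above.

```python
import importlib.util

# Top-level import names of the packages installed above (assumed list).
REQUIRED_MODULES = [
    "torch", "torchvision", "torchaudio",
    "jax", "optax", "flax", "orbax",
    "wandb", "ml_collections", "matplotlib", "tqdm",
]


def missing_modules(names=REQUIRED_MODULES):
    """Return the names in `names` that cannot be found on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def report():
    """Print missing modules, or the JAX devices if everything is installed."""
    missing = missing_modules()
    if missing:
        print("Missing modules:", ", ".join(missing))
        return False
    import jax
    # With a working CUDA setup this should list GPU devices, not only the CPU.
    print("JAX devices:", jax.devices())
    return True
```

For example, save it as `check_env.py` (a hypothetical file name) and run `python -c "import check_env; check_env.report()"` inside the `enf-iclr` environment.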

## Data

All datasets are downloaded automatically. The dataset code is in `utils/datasets/__init__.py` and `utils/datasets/ombria.py`.

## Training

We list the commands for each experiment below; the corresponding hyperparameters are also reported in the appendix.

**CIFAR10**
The first command fits the ENF to CIFAR10 and the second fits the classifier. The bi-invariant is selected with `--config.nef.bi_invariant`, which accepts `R2`, `SE2` or `abs`.
```bash
export PYTHONPATH=. && python fit_enf_cifar.py --config.nef.bi_invariant=R2
export PYTHONPATH=. && python fit_class_cifar.py
```
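As a rough illustration of what the `R2` and `SE2` options compute, the NumPy sketch below follows the bi-invariants described in the ENF paper: for translations, the relative position $x - p$ between a coordinate $x$ and a latent position $p$; for roto-translations, that relative position expressed in the latent's own frame, $R_\theta^\top (x - p)$. It is not the repository's implementation: the function names and array shapes are hypothetical, and the `abs` option is not covered.

```python
import numpy as np


def bi_invariant_r2(x, p):
    """Translation bi-invariant: relative positions x - p.

    x: (N, 2) query coordinates, p: (M, 2) latent positions.
    Returns an (N, M, 2) array.
    """
    return x[:, None, :] - p[None, :, :]


def bi_invariant_se2(x, p, theta):
    """Roto-translation bi-invariant: R(theta)^T (x - p) for every latent.

    theta: (M,) latent orientations. Returns an (N, M, 2) array.
    """
    rel = bi_invariant_r2(x, p)
    c, s = np.cos(theta)[None, :], np.sin(theta)[None, :]
    # Multiply each relative vector by the transpose of the 2x2 rotation matrix.
    u = c * rel[..., 0] + s * rel[..., 1]
    v = -s * rel[..., 0] + c * rel[..., 1]
    return np.stack([u, v], axis=-1)


def transform(points, angle, shift):
    """Rotate (N, 2) points by `angle` about the origin, then translate by `shift`."""
    c, s = np.cos(angle), np.sin(angle)
    rot = np.array([[c, -s], [s, c]])
    return points @ rot.T + shift
```

Moving both the coordinates and the latents by the same roto-translation (and adding the rotation angle to each latent orientation) leaves `bi_invariant_se2` unchanged, whereas `bi_invariant_r2` is only unchanged under pure translations. This is the difference in equivariance between the `SE2` and `R2` settings.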

**OMBRIA**
First, we fit a backbone model to the OMBRIA dataset. Then, we fit the segmentation model.
```bash
export PYTHONPATH=. && python fit_enf_ombria.py
export PYTHONPATH=. && python fit_seg_ombria.py
```
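The two OMBRIA stages must run in order. If you prefer to launch them from Python, a small driver like the following runs each script with `PYTHONPATH` set to the repository root and stops if the backbone fit fails. This is a sketch: `build_command`, `run_pipeline` and `OMBRIA_STAGES` are hypothetical helpers, not part of the repository, and it assumes it is run from the repository root with the scripts accepting `--config.<key>=<value>` overrides as in the commands above.

```python
import os
import subprocess
import sys


def build_command(script, overrides=None):
    """Build the command for one script, turning a dict into --config.<key>=<value> flags."""
    flags = [f"--config.{key}={value}" for key, value in (overrides or {}).items()]
    return [sys.executable, script, *flags]


def run_pipeline(stages, repo_root="."):
    """Run (script, overrides) stages in order, stopping at the first failure."""
    env = dict(os.environ, PYTHONPATH=os.path.abspath(repo_root))
    for script, overrides in stages:
        # check=True raises CalledProcessError on a nonzero exit code,
        # so the segmentation stage never starts after a failed backbone fit.
        subprocess.run(build_command(script, overrides), env=env, cwd=repo_root, check=True)


# Backbone first, then segmentation, matching the order of the commands above.
OMBRIA_STAGES = [
    ("fit_enf_ombria.py", {}),
    ("fit_seg_ombria.py", {}),
]
```

Calling `run_pipeline(OMBRIA_STAGES)` is then equivalent to running the two shell commands one after the other.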

Subsampling experiments are enabled with `--config.dataset.mask_points=True`, and the amount of subsampling is set with `--config.dataset.num_mask_points`. For example, to run with 1/4 subsampling:
```bash
export PYTHONPATH=. && python fit_seg_ombria.py --config.dataset.mask_points=True --config.dataset.num_mask_points=0.25
```

Only the test set is subsampled; the training set is left unchanged.
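Since a neural field operates on (coordinate, value) pairs rather than on a fixed pixel grid, subsampling a test image amounts to keeping a random subset of its pixels. The NumPy sketch below shows that idea; `subsample_points` is a hypothetical helper, and it assumes `num_mask_points` is the fraction of pixels kept, which should be checked against `utils/datasets/ombria.py`.

```python
import numpy as np


def subsample_points(coords, values, fraction, seed=0):
    """Keep a random `fraction` of the (coordinate, value) pairs of one image.

    coords: (N, 2) pixel coordinates, values: (N, C) pixel values or labels.
    """
    if not 0.0 < fraction <= 1.0:
        raise ValueError("fraction must be in (0, 1]")
    n = coords.shape[0]
    n_keep = max(1, int(round(fraction * n)))
    rng = np.random.default_rng(seed)
    # Sample without replacement so no pixel is kept twice.
    idx = rng.choice(n, size=n_keep, replace=False)
    return coords[idx], values[idx]
```

Indexing `coords` and `values` with the same indices keeps every retained value paired with its coordinate, which is what the field needs at inference time.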

Resolution transfer experiments use the `--config.dataset.tr_resolution` and `--config.dataset.te_resolution` flags, which set the training and test resolution, respectively. For example, to train the model at 128x128 and test at 256x256:
```bash
export PYTHONPATH=. && python fit_seg_ombria.py --config.dataset.tr_resolution=128 --config.dataset.te_resolution=256
```
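Resolution transfer is possible because a neural field is a function of continuous coordinates: the same model can be queried on a 128x128 grid during training and on a denser 256x256 grid at test time. The sketch below builds such coordinate grids. It assumes coordinates normalized to [-1, 1] with pixel centers as sample points, which may differ from the convention in `utils/datasets/ombria.py`; `make_grid` is a hypothetical helper.

```python
import numpy as np


def make_grid(resolution):
    """Pixel-center coordinates of a resolution x resolution image in [-1, 1]^2.

    Returns an array of shape (resolution * resolution, 2).
    """
    # Offset by half a pixel so grids of any resolution cover the same square.
    ticks = (np.arange(resolution) + 0.5) / resolution * 2.0 - 1.0
    xx, yy = np.meshgrid(ticks, ticks, indexing="ij")
    return np.stack([xx.ravel(), yy.ravel()], axis=-1)
```

With pixel-center sampling, each 2x2 block of the 256x256 grid averages to the center of the corresponding 128x128 pixel, so both grids describe the same physical extent and only the sampling density changes.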