
## Installation

The package can be installed with pip. If you plan to run computations on a GPU, first install JAX with CUDA 12 support:

```bash
pip install "jax[cuda12]"
```

Then install the package from the repository root:

```bash
python -m pip install .
```
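After installing, you can check which backend JAX picked up. The following is a minimal sketch and is not part of the package. It uses only `jax.default_backend()` and `jax.devices()`, and the helper name `jax_backend_summary` is hypothetical. With a working CUDA install, the backend should be reported as `gpu`. Otherwise JAX falls back to `cpu`.

```python
import importlib.util


def jax_backend_summary() -> str:
    """Describe the active JAX backend, or report that JAX is missing.

    Hypothetical helper for a post-install sanity check.
    """
    if importlib.util.find_spec("jax") is None:
        return "jax not installed"
    import jax  # imported lazily so the check itself never crashes

    # default_backend() names the platform, e.g. "cpu" or "gpu".
    return f"{jax.default_backend()}: {jax.devices()}"


print(jax_backend_summary())
```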

## Dependencies

This package is built on JAX. It also uses the normalized mutual information (NMI) clustering metric and the principal component analysis (PCA) decomposition from scikit-learn.
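For reference, the sketch below shows how these two scikit-learn components are typically called. It is not taken from this package. The toy data and the "cluster by the sign of the first principal component" step are illustrative assumptions. `PCA.fit_transform` projects the data onto its leading components. `normalized_mutual_info_score` compares two labelings and returns a value in [0, 1].

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.metrics import normalized_mutual_info_score

# Toy data (illustrative only): two well-separated Gaussian blobs in 10 dimensions.
rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(0.0, 1.0, (50, 10)), rng.normal(5.0, 1.0, (50, 10))])
true_labels = np.repeat([0, 1], 50)

# Project onto the first two principal components (e.g. for plotting).
x_2d = PCA(n_components=2).fit_transform(x)

# Hypothetical clustering step: split points by the sign of the first component.
pred_labels = (x_2d[:, 0] > 0).astype(int)

# NMI ignores how cluster ids are numbered, so swapped labels still score highly.
nmi = normalized_mutual_info_score(true_labels, pred_labels)
print(x_2d.shape, round(nmi, 3))
```

Because NMI is invariant to permutations of the cluster ids, the arbitrary sign of the principal component does not affect the score.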


## Acknowledgements

(Suppressed for anonymous submission.)
