# Python dependencies, installable with `pip install -r <this file>`.
# Every entry is a minimum-version constraint only (`>=`); there are no upper
# bounds or exact pins, so a fresh install resolves to the newest releases that
# satisfy these floors.
# NOTE(review): how each package is used by the project is not visible from
# this file; the group labels below describe the packages themselves.
#
# Array / numerical stack: JAX, its compiled backend package, and NumPy.
jax>=0.4.6
jaxlib>=0.4.6
numpy>=1.24.0
# Plotting.
matplotlib>=3.7.0
# JAX-ecosystem libraries: E(3)-equivariant networks, neural-network modules,
# and gradient-based optimizers.
e3nn-jax>=0.17.0
flax>=0.7.0
optax>=0.1.7
# Dataclass-based command-line argument parsing.
simple-parsing>=0.1.0
# General machine-learning utilities.
scikit-learn>=1.2.0 