# Python dependencies in pip requirements-file format (one requirement per line).
#
# Every entry below uses ">=", i.e. it declares a minimum version only. No upper
# bounds or exact pins are set, so a fresh install may pick up much newer
# releases than the floors listed here.
# NOTE(review): if reproducible environments matter, consider a separate
# constraints/lock file rather than editing these floors — confirm with owners.

# NOTE(review): jax and jaxlib share the same floor (0.3.15); presumably they are
# meant to be bumped together — confirm when changing either one.
jax>=0.3.15
jaxlib>=0.3.15
torch>=1.10.1
torchvision>=0.11.2
numpy>=1.21.6
matplotlib>=3.5.3
scipy>=1.7.3
# NOTE(review): flax and optax themselves depend on jax; verify that their
# minimum versions are compatible with the jax/jaxlib floor above.
flax>=0.6.0
optax>=0.1.3
hydra-core>=1.1.0
