jax>=0.4.0
flax>=0.7.0
optax>=0.1.0
diffrax>=0.4.0
ott-jax>=0.4.0
numpy>=1.20.0
scipy>=1.7.0
scikit-learn>=1.0.0
POT>=0.8.0
pandas>=1.3.0
