jax[cuda12]>=0.5.3
flax>=0.10.6
torch>=2.7.0
torchvision>=0.22.0
numpy>=2.2.5
scipy>=1.15.3
scikit-learn>=1.6.1
tqdm>=4.67.1
matplotlib>=3.10.3
pyyaml>=6.0.2
importlib-metadata>=8.7.0