#jax>=0.4.0
#flax>=0.7.0
#torch>=1.9.0
#numpy>=1.19.0
#scipy>=1.7.0
#scikit-learn>=1.0.0
#tqdm>=4.62.0
#matplotlib>=3.4.0
#pyyaml>=5.4.0
#importlib-metadata>=6.0.0