jax==0.4.28
matplotlib==3.5.2
numpy==1.21.5
scikit_learn==1.0.2
