scipy
jax == 0.2.24
neural-tangents == 0.3.8