jax[cuda12]==0.5.3
flax==0.10.0
numpy==1.26.3
tensorflow-probability==0.24.0
d4rl