# Python dependencies in pip requirements-file format.

# JAX pinned to an exact version, installed with the `cuda11_cudnn805` extra.
# NOTE(review): the extra's name suggests a CUDA 11 / cuDNN 8.0.5 build of
# jaxlib -- confirm this matches the CUDA/cuDNN versions on the target machine.
jax[cuda11_cudnn805]==0.3.23

# Libraries pinned to exact versions.
# NOTE(review): presumably these pins were chosen to be compatible with the
# JAX version above -- re-check all of them together before bumping any one.
neural-tangents==0.6.1
hydra-core==1.2
flax==0.6.1
chex==0.1.5

# No version specifier on the next three packages: pip will install whichever
# version its resolver selects at install time, so installs may differ over time.
torch
torchvision
flatdict

# Additional location pip searches for package archives, on top of the default
# index. This is a global option: it applies to every requirement in this file,
# not only to lines that follow it.
# NOTE(review): presumably this is where the CUDA-enabled jaxlib wheel required
# by the jax extra above is found -- confirm the page still serves that build.
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
