# Python dependencies for this project, in pip requirements-file format.
#
# Extra location pip searches for distribution links (pip's -f / --find-links option).
# NOTE(review): presumably this page serves the TPU runtime pulled in by the jax[tpu] extra — confirm.
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
#
# jax (with its "tpu" extra) and jaxlib are pinned to the same exact version; bump them together.
jax[tpu]==0.4.13
jaxlib==0.4.13
#
# flax is also pinned to an exact version.
# NOTE(review): 0.5.3 appears to predate the jax 0.4.13 pin — confirm this pairing is intentional and compatible.
flax==0.5.3
#
# The remaining packages carry no version specifier, so the installed version is
# whatever pip's resolver selects at install time.
numpy
matplotlib
tensorflow
tensorflow_datasets
pillow
ml_collections
clu
jax-smi
tensorboardx