wandb
clu
datasets
distrax
grain
matplotlib
seaborn
tensorflow
tensorflow-datasets
tf-keras
transformers
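# JAX with NVIDIA CUDA (GPU) support.
# pip treats --find-links as a global option: it applies to every requirement
# in this file, not only to the line that follows it.
# NOTE(review): current JAX install docs use the "jax[cuda12]" extra from PyPI
# (no --find-links needed); confirm "jax[cuda]" still resolves for the JAX version in use.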
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jax[cuda]
