# Python dependencies for this project (pip requirements format).
# Install with: pip install -r <this file>
absl-py>=0.12.0
aqtp
chex>=0.0.7
clu>=0.0.3
datasets
einops>=0.3.0
flax>=0.6.4
ml-collections>=0.1.0
numpy>=1.19.5
pandas>=1.1.0
tensorflow-cpu>=2.4.0
tensorflow-datasets>=4.0.1
tensorflow-probability>=0.11.1
tensorflow-text>=2.9.0
torch
torchvision
scikit-learn
matplotlib
tqdm
augmax
optax
# The PyPI `dataclasses` package is a backport for Python 3.6 only; on 3.7+
# the module ships with the standard library and installing the backport can
# conflict with it, so only install it on interpreters that lack it.
dataclasses; python_version < "3.7"
# `argparse` is part of the standard library (Python 2.7 / 3.2+), so the PyPI
# backport is intentionally not listed here.
wandb
timm
transformers
# Extra page of JAX wheels for pip to search. Note that pip applies option
# lines such as --find-links to the whole requirements file, not only to the
# line that follows.
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# NOTE(review): newer jax releases name this extra `cuda12` (installable from
# PyPI without --find-links). pip only warns on an unknown extra and would then
# install CPU-only jax -- confirm which spelling the targeted jax version
# supports.
jax[cuda12_pip]
# Installed directly from GitHub; no ref is given, so this tracks the
# repository's default branch.
git+https://github.com/google/flaxformer.git