# Requirements for vmoe.
absl-py>=0.12.0
cachetools>=5.0.0
chex>=0.0.7
clu>=0.0.6
jax>=0.2.25
flax>=0.3.6
ml-collections==0.1.0
numpy>=1.19.5
optax>=0.1.0
scipy>=1.4.0
# tensorflow-cpu is used so that all GPU memory is left available to JAX.
# This requirement is superseded by the `tensorflow-cpu<2.16` pin further below.
# tensorflow-cpu>=2.4.0
tfds-nightly
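# Exact pins for the JAX ecosystem (jax[cuda12] selects the CUDA 12 build).
# These further constrain the minimum versions listed above.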
jax[cuda12]==0.4.28
flax==0.8.4
optax==0.2.2

# Version caps needed for asyncio / cloud_tpu compatibility.
keras<3.0.0
tensorflow-cpu<2.16

# Replacement for code ported from tf.contrib.image (tf.contrib no longer exists in TF 2.x).
tensorflow-addons==0.23.0

# Dependencies for the Vision Transformer code (installed from its git repository).
tensorflow-text>=2.9.0
git+https://github.com/google-research/vision_transformer.git
