clu==0.0.1-alpha.2
flax==0.2.2
ml_collections==0.1.0
numpy==1.18.5
tensorflow-cpu==2.3.1  # CPU-only TensorFlow, so it does not allocate GPU memory and leaves all of it for JAX.
tensorflow-datasets==4.0.1
tensorflow-probability==0.11.1
