ml_collections>=0.1.1
optax>=0.1.1
dm-haiku>=0.0.6
jaxline>=0.0.5
tensorflow>=2.8.0
tensorflow_datasets>=4.5.2
