jax>=0.2.8
jaxlib>=0.1.59
flax>=0.3.5
optax>=0.0.8
-f XXXX
torch==1.9.0+cpu 
-f XXXX
torchvision==0.10.0+cpu