flake8==3.9.0
pytest==6.2.3
mypy==0.812
torch==1.8.1
einops==0.3.0
jax>=0.2.16
jaxlib>=0.1.68
flax>=0.2.2