matplotlib==3.5.1
autograd==1.3
termcolor==1.1.0
wandb==0.12.9
# jax[cpu]==0.2.27
flax==0.4.0
optax==0.1.0
seaborn==0.11.2
pytest==7.1.2
tqdm==4.64.0
