flax==0.5.2
gymnax==0.0.4
matplotlib==3.5.0
numpy==1.21.4
optax==0.1.2
tqdm==4.62.3
transformers==4.20.1
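# JAX is not pinned here. flax, optax and gymnax depend on it, so pip will pull in
# some jax/jaxlib version transitively. For reproducible installs, or for GPU/TPU
# support, install a jax/jaxlib build compatible with the versions above explicitly.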
