# requirements are aligned with the nvcr.io/nvidia/jax:23.10-py3 image
# jax and jaxlib are held to the same version window
jax>=0.4.16.0,<=0.4.25
jaxlib>=0.4.16.0,<=0.4.25
# exact pins
flax==0.7.4
chex==0.1.84
optax==0.1.7
dotmap==1.3.30
safetensors==0.4.2
blinker==1.4.0
# atari specific
gym==0.23.1
envpool==0.8.4
gymnax==0.0.5
# less version-sensitive libs: unpinned, or bounded on one side only
wandb
pytest
pygame
seaborn
numpy>=1.26.1
hydra-core>=1.3.2
omegaconf>=2.3.0
matplotlib>=3.8.3
pillow>=10.2.0
pettingzoo>=1.24.3
tqdm>=4.66.0
# scipy is capped from above rather than floored
scipy<=1.12