jax[cuda12]==0.4.28
flax==0.8.4
optax==0.2.4
distrax==0.1.5
numpy==1.24.1
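# pip applies index options to the whole file regardless of position; this extra
# index supplies the CUDA 11.3 (+cu113) torch/torchvision wheels pinned below.
# NOTE(review): jax above targets CUDA 12 while torch targets CUDA 11.3 — confirm
# both runtimes coexist in the target environment.
--extra-index-url https://download.pytorch.org/whl/cu113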
torch==1.12.0+cu113
torchvision==0.13.0+cu113
wandb
wandb[media]
moviepy<2
tqdm
gym==0.23.1
gym-minigrid
mjrl @ git+https://github.com/aravindr93/mjrl@master
d4rl==1.1
Cython==0.29.33
ogbench
shapely
matplotlib==3.7.5
hydra-core
einops
scikit-learn
pyrallis
ipympl
faiss-gpu
ml_collections





