ogbench
jax[cuda12] >= 0.4.26
flax >= 0.8.4
distrax >= 0.1.5
ml_collections
matplotlib
moviepy
wandb
opencv-python
