# Python dependencies for a JAX/Flax project running on NVIDIA GPUs (CUDA 12).
# Install with: pip install -r requirements.txt

# pip is already satisfied in any environment running this file; this entry
# only has an effect when installing with `pip install -U -r requirements.txt`.
pip

# CUDA 12 compiler toolchain and cuDNN wheels.
# NOTE(review): jax[cuda12] is documented to install matching NVIDIA CUDA/cuDNN
# wheels itself (via jax-cuda12-plugin); these unpinned entries are likely
# redundant -- confirm before removing.
nvidia-cuda-nvcc-cu12
nvidia-cudnn-cu12

# JAX with the CUDA 12 GPU backend (Linux wheels).
jax[cuda12]
flax

# Checkpointing: provides the `orbax.checkpoint` module. The bare `orbax`
# distribution on PyPI is a deprecated legacy package that can conflict with
# orbax-checkpoint (which recent flax releases already depend on).
orbax-checkpoint

# Progress bars, plotting, tabular data, and experiment tracking.
tqdm
matplotlib
pandas
wandb