# CUDA-enabled jaxlib wheels (such as the `+cuda11.cudnn86` build pinned
# below) are published on this JAX release page rather than on PyPI;
# `-f` (--find-links) makes pip search this page for matching archives.
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# torch-cpu is sufficient since we only use it for data loading
# (the GPU work is done in JAX). This extra index serves the CPU-only
# `+cpu` builds of torch / torchvision / torchaudio pinned below.
--extra-index-url https://download.pytorch.org/whl/cpu

# Exact version pins for reproducible installs (except where noted).
# NOTE(review): optax==0.2.0 (below) is a much later release than the
# chex / jax pins here -- confirm pip resolves this set without conflicts.
chex==0.1.7
ConfigArgParse==1.7
einops==0.6.1
flax==0.6.10
flaxmodels==0.1.3
# jax and jaxlib are kept at the same version. The `+cuda11.cudnn86` local
# version tag selects the CUDA 11 build from the find-links page above.
jax==0.4.11
jaxlib==0.4.11+cuda11.cudnn86
jaxtyping==0.2.20
ml-dtypes==0.1.0
numpy==1.24.1
# NVIDIA's pip-distributed CUDA 11 libraries, presumably the ones the
# `cuda11` jaxlib build loads at runtime (as in JAX's pip-CUDA install).
# Upgrade them together with jaxlib.
# NOTE(review): jaxlib is tagged cudnn86 while cuDNN 8.9 is pinned here;
# presumably this relies on cuDNN 8.x minor-version compatibility -- confirm.
nvidia-cublas-cu11==11.11.3.6
nvidia-cuda-cupti-cu11==11.8.87
nvidia-cuda-nvcc-cu11==11.8.89
nvidia-cuda-runtime-cu11==11.8.89
nvidia-cudnn-cu11==8.9.2.26
nvidia-cufft-cu11==10.9.0.58
nvidia-cusolver-cu11==11.4.1.48
nvidia-cusparse-cu11==11.7.5.86
opt-einsum==3.3.0
optax==0.2.0
orbax-checkpoint==0.2.4
scipy==1.10.1
timm==0.9.6
# CPU-only PyTorch family, served by the --extra-index-url above.
torch==2.0.1+cpu
torchaudio==2.0.2+cpu
torchmetrics==1.2.0
torchvision==0.15.2+cpu
tqdm==4.65.0
transformers==4.46.3
Pillow==10.0.0
# NOTE(review): wandb is the only unpinned requirement; pin a version to
# keep installs reproducible.
wandb
# Installed directly from GitHub at the `0.1.0` git ref.
git+https://github.com/alebeck/chunkax.git@0.1.0
