# Core PyTorch ecosystem
# torchvision and torchaudio are pinned to the releases built against
# torch 2.2.1; left unpinned, pip has to backtrack through newer releases
# (each of which pins a different torch) before it finds a compatible one.
torch==2.2.1
torchvision==0.17.1
torchaudio==2.2.1

# PyTorch Geometric stack
# The compiled extensions below are served from the PyG wheel index for
# torch 2.2.1 + CUDA 12.1. The index is declared here so that a plain
# `pip install -r requirements.txt` picks it up; no extra -f flag is needed
# on the command line.
--find-links https://data.pyg.org/whl/torch-2.2.1+cu121.html
torch-scatter
torch-sparse
torch-cluster
torch-spline-conv
torch-geometric

# Utilities
torch-summary
Cython
# torch is pinned to 2.2.1 in this file; those wheels were built against the
# NumPy 1.x ABI, so NumPy 2 must be excluded or the torch<->NumPy bridge
# fails to initialize at import time.
numpy<2
scikit-learn
networkx
scipy
notebook
pytest
eagerpy
seaborn
icecream

# JAX (with CUDA 12 pip wheels)
# The JAX CUDA release index is declared here so that a plain
# `pip install -r requirements.txt` can locate the CUDA-enabled wheels;
# no extra -f flag is needed on the command line.
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jax[cuda12_pip]

# Optimal Transport Tools (PyPI package `ott-jax`)
# NOTE(review): unpinned, and presumably relies on the JAX install listed in
# this file — confirm the resolved version is compatible with it.
ott-jax
