# PEP 517/518 build configuration: the package is built by setuptools
# through its PEP 517 backend. setuptools 61+ is required to read the
# PEP 621 [project] table below; this project requires 68 or newer.
[build-system]
requires = ["setuptools>=68"]
build-backend = "setuptools.build_meta"

[project]
name = "ulee-repo"
version = "0.1.0"
description = "ULEE research code"
readme = "README.md"
requires-python = ">=3.10,<3.13"
# Runtime dependencies, sorted alphabetically. Every requirement has an
# upper bound so a new minor/major release cannot be picked up silently.
dependencies = [
    "chex >=0.1.88,<0.2",
    "distrax >=0.1.5,<0.2",
    "flax >=0.10.3,<0.11",
    # The cuda12 extra installs CUDA-enabled JAX wheels.
    # NOTE(review): per the JAX install guide these wheels are only
    # published for Linux, so installing this project on macOS/Windows will
    # fail; consider moving the extra into an optional group if CPU-only
    # installs are ever needed.
    "jax[cuda12] >=0.4.38,<0.5",
    "numpy >=2.2.1,<3.0",
    "optax >=0.2.4,<0.3",
    "orbax-checkpoint >=0.11.5,<0.12",
    "tqdm >=4.67.1,<5.0",
    "xminigrid >=0.9.1,<1.0",

    # Exact NVIDIA CUDA 12 runtime wheel pins, currently disabled.
    # NOTE(review): jax[cuda12] already pulls in nvidia-* wheels; these
    # pins were presumably kept to force specific CUDA library versions
    # if a mismatch resurfaces — confirm before re-enabling.
    # "nvidia-cuda-runtime-cu12==12.6.77",
    # "nvidia-cuda-cupti-cu12==12.6.80",
    # "nvidia-cuda-nvcc-cu12==12.6.85",
    # "nvidia-nvjitlink-cu12==12.6.85",
    # "nvidia-cublas-cu12==12.6.4.1",
    # "nvidia-cusolver-cu12==11.7.1.2",
    # "nvidia-cusparse-cu12==12.5.4.2",
    # "nvidia-cufft-cu12==11.3.0.4",
    # "nvidia-cudnn-cu12==9.6.0.74",
    # "nvidia-nccl-cu12==2.24.3",
]

[project.optional-dependencies]
# Tooling used only for interactive experiments (notebooks, plotting and
# run tracking); install with `pip install ".[experiments]"`.
experiments = [
    "ipykernel",
    "jupyter",
    "matplotlib",
    "wandb",
]

# "src layout": setuptools searches for importable packages under src/
# instead of the repository root.
[tool.setuptools.packages.find]
where = ["src"]