# Package metadata and runtime dependencies (PEP 621).
[project]
name = "diffusion-co-design"
version = "0.1.0"
# NOTE(review): placeholder text; replace with a one-line summary of the project.
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.13"
dependencies = [
    "ipykernel>=6.29.5",
    "matplotlib>=3.10.0",
    "numpy>=2.2.2",
    # Provided by a uv workspace member (`workspace = true` in [tool.uv.sources]).
    "guided-diffusion",
    # NOTE(review): `>0.9` excludes 0.9.0 itself but admits 0.9.1 and later.
    # Every other specifier in this file uses `>=`; if 0.9.0 is acceptable this
    # was probably meant to be `>=0.9`. Confirm before changing.
    "tensordict>0.9,<0.10",
    "torchrl>0.9,<0.10",
    # On linux/win32 these two are taken from the pytorch-cu128 index
    # (see [tool.uv.sources]).
    "torch>=2.6.0",
    "torchvision>=0.21.0",
    "moviepy==1.0.3", # Temporary pin until wandb[media] works with newer moviepy releases
    "wandb[media]>=0.19.5",
    "hydra-core>=1.3.2",
    "pydantic>=2.10.6",
    "torch-geometric>=2.6.1",
    # torch_scatter and torch_cluster are NOT declared here; install them manually
    # after syncing, using wheels that match the installed torch + CUDA build.
    # The URLs below target torch 2.7.0 with CUDA 12.8. Adjust them if the resolved
    # torch differs (the constraint above only requires torch>=2.6.0).
    # uv pip install torch-scatter -f https://data.pyg.org/whl/torch-2.7.0+cu128.html --no-build-isolation
    # uv pip install torch-cluster -f https://data.pyg.org/whl/torch-2.7.0+cu128.html --no-build-isolation
    # Provided by a uv workspace member (`workspace = true` in [tool.uv.sources]).
    "wfcrl",
    "seaborn>=0.13.2",
    # Provided by a uv workspace member (`workspace = true` in [tool.uv.sources]).
    "segnn",
]

# The `rware` and `vmas` extras are mutually exclusive. The rware build used here
# depends on pyglet 2.1, which presumably clashes with the pyglet version that
# vmas[render] requires. The clash is declared under `conflicts` in [tool.uv],
# so install at most one of them, e.g. `uv sync --extra rware` or
# `uv sync --extra vmas`. Both packages are pulled from git via [tool.uv.sources].
[project.optional-dependencies]
rware = [
    "rware[pettingzoo, dev]",
]
vmas = [
    "vmas[render]",
]

# Development-only tooling (PEP 735 dependency groups). These are not part of
# the published package metadata; uv syncs the `dev` group by default.
[dependency-groups]
dev = [
    "expecttest>=0.3.0",
    "mypy>=1.14.1",
    "pytest>=8.3.4",
    "ruff>=0.9.3",
]

# PEP 517/518 build configuration for the root package.
[build-system]
build-backend = "hatchling.build"
requires = [
    "hatchling",
]

[tool.uv]
# The `rware` and `vmas` extras cannot be installed together. Declaring the
# conflict lets uv resolve each extra separately in the lockfile and reject
# requests that enable both at once.
conflicts = [
    [
        { extra = "rware" },
        { extra = "vmas" },
    ],
]

# All members share a single uv.lock. Dependencies marked `workspace = true` in
# [tool.uv.sources] are resolved from these local paths.
[tool.uv.workspace]
members = [
    ".",
    "experiments/*",
    "packages/guided_diffusion",
    "packages/segnn",
    "packages/wfcrl-env",
]

# PyTorch CUDA 12.8 wheel index. With `explicit = true` it is only consulted for
# packages that opt in via [tool.uv.sources] (torch and torchvision); everything
# else keeps resolving from the default index.
[[tool.uv.index]]
name = "pytorch-cu128"
explicit = true
url = "https://download.pytorch.org/whl/cu128"

[tool.uv.sources]
# CUDA 12.8 builds of torch/torchvision on Linux and Windows. On any other
# platform the marker does not match and the default index is used instead.
torch = [
    { index = "pytorch-cu128", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]
torchvision = [
    { index = "pytorch-cu128", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]

# Packages provided by members of this uv workspace.
diffusion_co_design = { workspace = true }
guided_diffusion = { workspace = true }
segnn = { workspace = true }
wfcrl = { workspace = true }

# Environment libraries behind the mutually exclusive `rware` / `vmas` extras,
# installed from git rather than PyPI.
# NOTE(review): the rware source has no branch/rev; only uv.lock pins its commit.
rware = { git = "ANON!" }
vmas = { git = "ANON!", branch = "dicode" }