[build-system]
# PEP 517 backend. setuptools 61.0 is the first release that reads PEP 621 [project] metadata.
build-backend = "setuptools.build_meta"
requires = ["setuptools>=61.0"]

[project]
name = "example"
version = "0.1.0"
description = "Factored World Hypothesis experiments"
readme = "README.md"
requires-python = ">=3.12"
# PEP 508 specifiers, written without spaces and sorted within each group.
dependencies = [
    # experiment tooling
    "hydra-ray-launcher>=1.2.1",
    "ipykernel>=7.1.0",
    "nbformat>=5.10.4",
    "pydantic>=2.12.0",
    # Exact pins, presumably a known-compatible pair; confirm before bumping either one.
    "transformer-lens==2.15.4",
    "transformers==4.57.3",
    # fwh_core dependencies
    "altair>=5.0.0",
    "chex>=0.1.0",
    "equinox>=0.11.0",
    "hydra-core>=1.3.0",
    # NOTE(review): the gpu/cuda/mac extras require jax>=0.6.0; confirm whether this
    # 0.4.0 floor is still accurate for fwh_core or should be raised to match.
    "jax>=0.4.0",
    "matplotlib>=3.0.0",
    "mlflow>=2.0.0",
    "omegaconf>=2.3.0",
    "Pillow>=9.0.0",
    "plotly>=5.0.0",
    "scipy>=1.10.0",
    "torch>=2.0.0",
    "tqdm>=4.0.0",
]

[project.optional-dependencies]
dev = ["jaxtyping", "nbqa", "pyright", "pytest", "pytest-cov", "ruff"]
# Accelerator-specific JAX builds; `gpu` and `cuda` are identical aliases.
cuda = ["jax[cuda12]>=0.6.0"]
gpu = ["jax[cuda12]>=0.6.0"]
mac = ["jax[cpu]>=0.6.0"]
# These extras only repeat packages already in [project].dependencies, so they add nothing
# on top of a plain install. They are kept so existing `pip install ".[analysis]"` and
# `".[pytorch]"` commands keep working; keep their floors in sync with the core list.
analysis = ["ipykernel>=7.1.0", "nbformat>=5.10.4"]
pytorch = ["torch>=2.0.0"]

[tool.setuptools.packages.find]
# Trailing `*` is a glob, so sub-packages such as fwh_core.<sub> are included too.
include = [
    "experiments*",
    "fwh_core*",
]

[tool.ruff]
# Matches requires-python = ">=3.12" in [project].
target-version = "py312"
line-length = 120

[tool.ruff.lint]
# Rule families to enforce.
select = [
    "A",   # flake8-builtins
    "B",   # flake8-bugbear
    "D",   # pydocstyle: https://www.pydocstyle.org/en/stable/error_codes.html
    "E",   # pycodestyle
    "F",   # Pyflakes
    "I",   # isort
    "PT",  # flake8-pytest-style
    "SIM", # flake8-simplify
    "UP",  # pyupgrade
]
# Individual rules switched off within the families selected above.
ignore = [
    "D100",   # Missing docstring in public module
    "D105",   # Missing docstring in magic method
    "D107",   # Missing docstring in __init__
    "SIM108", # Use the ternary operator
]

[tool.ruff.lint.pydocstyle]
# Docstrings selected via the "D" rules are checked against the Google convention.
convention = "google"

[tool.ruff.lint.per-file-ignores]
# Docstring (D) rules are not enforced for notebooks or pytest modules.
# Keys are quoted because glob patterns contain characters outside the bare-key set.
"*.ipynb" = ["D"]
"test_*.py" = ["D"]

[tool.pyright]
# Pyright's own camelCase key; "standard" is the level between "basic" and "strict".
typeCheckingMode = "standard"
