[project]
name = "jaxltl"
version = "0.1.0"
requires-python = ">= 3.12"
authors = []
# Runtime requirements (PEP 508), kept in alphabetical order.
dependencies = [
    "distrax>=0.1.7,<0.2",
    "equinox>=0.13.2,<0.14",
    "graphviz>=0.21,<0.22",
    "jax>=0.7.2,<0.8",
    # NOTE(review): upper-bounded by a dev pre-release specifier; confirm this
    # pin is still intended.
    "jraph<=0.0.6dev0",
    "optax>=0.2.6,<0.3",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

# Lets hatchling accept direct URL references (`pkg @ https://...`) in the
# dependency lists. NOTE(review): `project.dependencies` currently contains no
# such references; confirm this flag is still needed.
[tool.hatch.metadata]
allow-direct-references = true

[tool.pixi.workspace]
channels = ["conda-forge"]
platforms = ["linux-64"]

# Install this package itself (and thus its `project.dependencies`) in
# editable mode.
[tool.pixi.pypi-dependencies]
jaxltl = { path = ".", editable = true }

[tool.pixi.tasks]

# Conda-forge packages used alongside the library: tooling, notebooks,
# plotting, configuration and testing. Kept in alphabetical order.
[tool.pixi.dependencies]
hydra-core = ">=1.3.2,<2"
joblib = ">=1.5.2,<2"
jupyter = ">=1.1.1,<2"
matplotlib = ">=3.10.6,<4"
pandas = ">=2.3.3,<3"
pygame = ">=2.6.1,<3"
pytest = ">=8.4.2,<9"
python-dotenv = ">=1.1.0,<2"
ruff = ">=0.11.13,<0.12"
seaborn = ">=0.13.2,<0.14"
tqdm = ">=4.67.1,<5"

# Optional CUDA 12 feature, selected with `pixi run -e gpu ...`.
[tool.pixi.feature.gpu.system-requirements]
cuda = "12"

# Use the same jax bounds as `project.dependencies` so the GPU environment
# cannot resolve a jax release the project does not support; the `cuda12`
# extra adds the CUDA 12 backend for that same jax version.
[tool.pixi.feature.gpu.pypi-dependencies]
jax = { version = ">=0.7.2,<0.8", extras = ["cuda12"] }

# The `gpu` environment layers the `gpu` feature on top of the default one.
[tool.pixi.environments]
gpu = ["gpu"]

[tool.ruff.lint]
fixable = ["ALL"]
# Rules suppressed project-wide, sorted by code.
ignore = [
    "B008", # Do not perform function calls in argument defaults
    "B018", # Useless statement (used in notebooks)
    "D100", # Missing docstring in public module (only relevant if "D" rules are selected)
    "E402", # Module level import not at top of file (used in notebooks)
    "E501", # Line too long
    "E731", # Do not assign a lambda expression, use a def
]
select = [
    # pyflakes
    "F",
    # pycodestyle
    "E",
    "W",
    # flake8-builtins
    "A",
    # flake8-bugbear
    "B",
    # flake8-comprehensions
    "C4",
    # flake8-simplify
    "SIM",
    # flake8-unused-arguments
    "ARG",
    # pylint
    "PL",
    # flake8-tidy-imports
    "TID",
    # isort
    "I",
    # pep8-naming
    "N",
    # pyupgrade
    "UP",
]

# Raise the PLR0913 (too-many-arguments) threshold from ruff's default of 5.
[tool.ruff.lint.pylint]
max-args = 8
