[build-system]
# The `[project]` table below (PEP 621 metadata) is only understood by
# setuptools 61.0 and later, so require at least that version for the build.
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "tabpfn"
version = "2.0.0"
description = "TabPFN: Foundation model for tabular data"
readme = "README.md"
requires-python = ">=3.9"
license = { file = "LICENSE" }
authors = [
  { name = "Noah Hollmann", email = "noah.hollmann@charite.de" },
  { name = "Samuel Müller", email = "muellesa@cs.uni-freiburg.de" },
  { name = "Lennart Purucker" },
  { name = "Arjun Krishnakumar" },
  { name = "Max Körfer" },
  { name = "Shi Bin Hoo" },
  { name = "Robin Tibor Schirrmeister" },
  { name = "Frank Hutter", email = "fh@cs.uni-freiburg.de" },
  # Huge thanks to code refactoring contributor Eddie
  { name = "Eddie Bergman" },
]
classifiers = [
  "Intended Audience :: Science/Research",
  "Intended Audience :: Developers",
  # "License", TODO: Add license - currently can't be licensed!
  "Programming Language :: Python",
  "Topic :: Software Development",
  "Topic :: Scientific/Engineering",
  "Operating System :: POSIX",
  "Operating System :: Unix",
  "Operating System :: MacOS",
  "Programming Language :: Python :: 3",
  "Programming Language :: Python :: 3.9",
  "Programming Language :: Python :: 3.10",
  "Programming Language :: Python :: 3.11",
]
# Runtime requirements installed together with the package.
dependencies = [
  "torch>=2.1",
  "scikit-learn>=1",
  "typing_extensions",
  "scipy",
  "pandas",
  "einops",
  "huggingface-hub",
]

# Project links published as part of the package metadata.
[project.urls]
documentation = "https://priorlabs.ai/docs"
source = "https://github.com/priorlabs/tabpfn"

[project.optional-dependencies]
# Everything needed to lint, test and build the docs: `pip install -e ".[dev]"`.
dev = [
  # Lint/format
  "pre-commit",
  "ruff",
  "mypy",
  # Test
  # Floor matches `minversion = "8.0"` in [tool.pytest.ini_options]; an older
  # pytest would install fine but then refuse to run the test suite.
  "pytest>=8.0",
  # Docs
  "mkdocs",
  "mkdocs-material",
  "mkdocs-autorefs",
  "mkdocs-gen-files",
  "mkdocs-literate-nav",
  "mkdocs-glightbox",
  "mkdocstrings[python]",
  "markdown-exec[ansi]",
  "mike",
  "black",  # This allows mkdocstrings to format signatures in the docs
]

[tool.pytest.ini_options]
minversion = "8.0"
# Where the tests are located
testpaths = ["tests"]
# Always report the 10 slowest tests and use very verbose output.
addopts = "--durations=10 -vv"

# Prevents user error of an empty `parametrize` of a test
empty_parameter_set_mark = "xfail"
xfail_strict = true

# Logging: no live log output on the console, capture level is DEBUG.
log_cli = false
log_level = "DEBUG"

# https://github.com/astral-sh/ruff
[tool.ruff]
# Same lower bound as `requires-python` in [project].
target-version = "py39"
src = ["src", "tests", "examples"]
line-length = 88
output-format = "full"

[tool.ruff.lint]
# Extend what ruff is allowed to fix, even if it may break code.
# This is okay given we use it all the time and it ensures
# better practices. It would be dangerous if used for the first
# time on an established project.
extend-safe-fixes = ["ALL"]

# Allow unused variables when underscore-prefixed.
dummy-variable-rgx = '^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$'

# Rule families that are enabled. Commented-out entries are families that were
# deliberately left disabled, with the reason next to them.
select = [
  "A",
  # "ANN", # Handled by mypy
  "ARG",
  "B",
  "BLE",
  "COM",
  "C4",
  "D",
  # "DTZ",  # One day I should know how to utilize timezones and dates...
  "E",
  # "EXE", Meh
  "ERA",
  "F",
  "FBT",
  "I",
  # "ISC",  # Favours implicit string concatenation
  "INP",
  # "INT", # I don't understand this one
  "N",
  "NPY",
  "PD",
  "PLC",
  "PLE",
  "PLR",
  "PLW",
  "PIE",
  "PT",
  "PTH",
  # "PYI", # Specific to .pyi files for type stubs
  "Q",
  "PGH004",
  "RET",
  "RUF",
  "C90",
  "S",
  # "SLF",    # Private member accessed (sure, it's python)
  "SIM",
  # "TRY", # Good in principle, would take a lot of work to satisfy
  "T10",
  "T20",
  "TID",
  # flake8-type-checking: "TC" is the current prefix (formerly "TCH"); the
  # `ignore` list already uses the new-style code "TC003".
  "TC",
  "UP",
  "W",
  "YTT",
]

# Individual rules switched off for the whole project, each with a short
# description of what the rule would otherwise report.
ignore = [
  "D104",    # Missing docstring in public package
  "D105",    # Missing docstring in magic method
  "D203",    # 1 blank line required before class docstring
  "D205",    # 1 blank line required between summary line and description
  "D401",    # First line of docstring should be in imperative mood
  "N806",    # Variable X in function should be lowercase
  "E731",    # Do not assign a lambda expression, use a def
  "A002",    # Function argument shadowing a builtin
  "A003",    # Class attribute shadowing a builtin
  "S101",    # Use of assert detected.
  "W292",    # No newline at end of file
  "PLC1901", # "" can be simplified to be falsey
  "TC003",   # Move stdlib import into TYPE_CHECKING
  "PLR2004", # Magic numbers, gets in the way a lot
  "PLR0915", # Too many statements
  "N803",    # Argument name `X` should be lowercase
  "N802",    # Function name should be lowercase
  # These tend to be lightweight and confuse pyright
  # NOTE(review): no entries follow the comment above - confirm whether some
  # rules were meant to be listed here or the comment is a leftover.
]

# Directories ruff should skip: VCS metadata, tool caches, virtual
# environments, build artefacts and the docs folder.
exclude = [
  ".bzr",
  ".direnv",
  ".eggs",
  ".git",
  ".hg",
  ".mypy_cache",
  ".nox",
  ".pants.d",
  ".ruff_cache",
  ".svn",
  ".tox",
  ".venv",
  "__pypackages__",
  "_build",
  "buck-out",
  "build",
  "dist",
  "node_modules",
  "venv",
  "docs",
]

# Rules that are switched off only for files matching the given glob pattern.
[tool.ruff.lint.per-file-ignores]
"tests/*.py" = [
  "S101",
  "D101",
  "D102",
  "D103",
  "ANN001",
  "ANN201",
  "FBT001",
  "D100",
  "PLR2004",
  "PD901",   # X is a bad variable name. (pandas)
  "TC",      # flake8-type-checking (formerly spelled "TCH")
  "N803",
  "C901",    # Too complex
]
# I002 is the "missing required import" rule driven by `required-imports`
# in [tool.ruff.lint.isort].
"__init__.py" = ["I002"]
"examples/*" = ["INP001", "I002", "E741", "D101", "D103", "T20", "D415", "ERA001", "E402", "E501"]
"docs/*" = ["INP001"]
# TODO(eddiebergman):
"src/tabpfn/model/*.py" = [
  # Documentation
  "D100",
  "D101",
  "D102",
  "D103",
  "D107",
]
# TODO(eddiebergman): There's a lot of typing and ruff problems detected here
"src/tabpfn/model/multi_head_attention.py" = [
]
"src/tabpfn/model/encoders.py" = [
  "PT018",
  "ARG002",
  "E501",
  "ERA001",
  "F821",
  "FBT001",
  "FBT002",
  "A001",
]
"src/tabpfn/*.py" = [
  "D107",
]


[tool.ruff.lint.isort]
# Import every file is required to contain.
required-imports = ["from __future__ import annotations"]
no-lines-before = ["future"]

# How packages are assigned to import sections.
extra-standard-library = ["typing_extensions"]
known-third-party = ["sklearn"]
known-first-party = ["tabpfn"]

# Formatting of `import ... as ...` statements.
combine-as-imports = true
force-wrap-aliases = true

[tool.ruff.lint.pydocstyle]
# Docstrings are checked against the Google convention.
convention = "google"

[tool.ruff.lint.pylint]
# Changed from default of 5
max-args = 10

[tool.mypy]
python_version = "3.9"
# What to check when mypy is run without explicit targets. `files` takes
# filesystem paths; the `packages` option expects importable package names,
# which "src/tabpfn" is not.
files = ["src/tabpfn", "tests"]

show_error_codes = true

warn_unused_configs = true # warn about unused [tool.mypy] lines

follow_imports = "normal"      # Type check top level api code we use from imports
ignore_missing_imports = false # prefer explicit ignores

disallow_untyped_defs = true       # All functions must have types
disallow_untyped_decorators = true # ... even decorators
disallow_incomplete_defs = true    # ...all types
allow_redefinition = true          # Allow redefining types within a scope

no_implicit_optional = true
check_untyped_defs = true

warn_return_any = true


# Relaxed typing requirements for the test suite.
[[tool.mypy.overrides]]
module = ["tests.*"]
# Sometimes we just want to ignore verbose types
disallow_untyped_defs = false
disallow_incomplete_defs = false
# Test decorators are not properly typed
disallow_untyped_decorators = false
disable_error_code = ["var-annotated"]

# TODO(eddiebergman): Too much to deal with right now
# All mypy errors are suppressed for the modules listed here.
[[tool.mypy.overrides]]
module = [
  "tabpfn.model.multi_head_attention",
  "tabpfn.model.encoders",
]
ignore_errors = true

# Explicit opt-outs from the global `ignore_missing_imports = false`: mypy
# should not report missing imports for these packages.
[[tool.mypy.overrides]]
module = [
  "sklearn.*",
  "matplotlib.*",
  "einops.*",
  "networkx.*",
  "scipy.*",
  "pandas.*",
  "huggingface_hub.*",
  "joblib.*",
  "torch.*",
  "kditransform.*",
]
ignore_missing_imports = true

# TODO: We don't necessarily need this
[tool.pyright]
include = ["src", "tests"]
pythonVersion = "3.9"
typeCheckingMode = "strict"

# Checks and inference options that are switched on.
strictListInference = true
strictSetInference = true
reportMissingSuperCall = true
reportOverlappingOverload = true
reportIncompatibleVariableOverride = true
reportIncompatibleMethodOverride = true
reportInvalidTypeVarUse = true
reportCallInDefaultInitializer = true
reportImplicitOverride = true

# Checks and inference options that are switched off.
strictDictionaryInference = false
reportImportCycles = false
reportMissingTypeArgument = false
reportUnknownMemberType = false
reportUnknownParameterType = false
reportUnknownVariableType = false
reportUnknownArgumentType = false
reportUnknownLambdaType = false
reportPrivateUsage = false
reportUnnecessaryCast = false
reportUnusedFunction = false
reportMissingTypeStubs = false
reportPrivateImportUsage = false
reportUnnecessaryComparison = false
reportConstantRedefinition = false
reportUntypedFunctionDecorator = false
