# Package metadata consumed by Poetry.
[tool.poetry]
name = "retrievit"
version = "0.1.0"
authors = ["The author"]
readme = "README.md"
# TODO: `description` and `repository` are still empty placeholders.
description = ""
repository = ""
# src-layout: the importable package is looked up as `retrievit` under `src/`.
packages = [{ include = "retrievit", from = "src" }]

# Development-only tooling, grouped by purpose (not installed with the package itself).
[tool.poetry.group.dev.dependencies]
# Linting and formatting
ruff = "^0.11.12"
black = { version = "^22.1.0", extras = ["jupyter"] }
isort = "^5.10.1"
flake8 = "^7.0"
flake8-pytest-style = "^1.6.0"
wemake-python-styleguide = "^0.18.0"
# Type checking
basedpyright = "^1.11.0"
mypy = "^1.8.0"
types-requests = "^2.27.16"
# Testing
pytest = "^7.1.2"
pytest-cases = "^3.6.10"
pytest-cov = "^3.0.0"
# Task runner and git hooks
poethepoet = "^0.13.1"
pre-commit = "^3.0"
# Notebooks and debugging
ipykernel = "^6.13.0"
jupyterlab = "^3.3.1"
pudb = "^2022.1"

# Global poethepoet settings: tasks read their environment variables from `.env`.
# The tasks themselves are declared in the `[tool.poe.tasks.<name>]` tables, which
# create the parent `tasks` table implicitly.
[tool.poe]
envfile = ".env"

# `poe format`: formatting is delegated to the pre-commit hooks, run on every file.
[tool.poe.tasks.format]
cmd = "pre-commit run --all-files"
help = "Format using the pre-commit hooks"

# `poe typecheck`: runs mypy over the whole repository.
# NOTE(review): this file configures basedpyright but has no [tool.mypy] table;
# confirm mypy is still the intended checker.
[tool.poe.tasks.typecheck]
cmd = "mypy ."
help = "Check types with mypy"

# `poe lint`: runs flake8 over the whole repository.
# NOTE(review): this does not use the [tool.ruff] configuration; confirm flake8 is
# still the intended linter for this task.
[tool.poe.tasks.lint]
cmd = "flake8 ."
help = "Lint with flake8"

# `poe test`: everything not marked `slow`, with coverage measured over `src`.
[tool.poe.tasks.test]
cmd = "pytest --cov=src -m 'not slow'"
help = "Run the fast Python tests"

# `poe test-everything`: verbose run that writes a JUnit report (pytest.xml) and
# measures coverage over `src`.
# The marker expression still deselects tests marked `slow` or `multiprocessing`,
# so the help text says so instead of claiming that *all* tests run.
# NOTE(review): if the intent really is to run every test, drop the `-m` filter.
[tool.poe.tasks.test-everything]
help = "Run the tests (except those marked slow or multiprocessing) and get the coverage"
cmd = "pytest -v --junitxml=pytest.xml --cov=src -m 'not slow and not multiprocessing'"

[tool.poe.tasks.autoinstall-torch-cuda]
# Installs light-the-torch with pip, then uses it to upgrade torch and torchvision.
# Background: https://github.com/python-poetry/poetry/issues/2543
help = "Update torch to use the best CUDA version for your system"
shell = """
	python -m pip install light-the-torch && python -m light_the_torch install --upgrade torch torchvision
"""

[tool.poe.tasks.install-flash-attn]
# Installs flash-attn with pip (build isolation disabled), then prints whether the
# `flash_attn` module can be found, wrapped in ANSI colour codes.
# The `\\033` sequences are TOML-escaped so that Python receives `\033`.
shell = """
	pip install flash-attn --no-build-isolation
	python -c "import importlib.util; print(f'Flash attention is available: \\033[92m{importlib.util.find_spec('flash_attn') is not None}\\033[00m')"
"""
help = "Install flash attention"

[tool.poe.tasks.install-mamba-conv1d]
# Used instead of the generic `pip install mamba-ssm[causal-conv1d]`.
# Installs the prebuilt release wheels of mamba-ssm and causal-conv1d that match
# torch 2.7 and CUDA 12. Per the wheel filenames they target CPython 3.13 (cp313)
# on linux x86_64 only. Afterwards it prints whether `mamba_ssm` can be found.
shell = """
	pip install https://github.com/state-spaces/mamba/releases/download/v2.2.6.post3/mamba_ssm-2.2.6.post3+cu12torch2.7cxx11abiFALSE-cp313-cp313-linux_x86_64.whl
	pip install https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.5.4/causal_conv1d-1.5.4+cu12torch2.7cxx11abiFALSE-cp313-cp313-linux_x86_64.whl
	python -c "import importlib.util; print(f'Mamba-ssm is available: \\033[92m{importlib.util.find_spec('mamba_ssm') is not None}\\033[00m')"
"""
help = "Install Mamba-ssm with causal conv1d support"


# `poe outdated`: filters `poetry show --outdated` down to top-level packages.
# The inner pipeline turns every non-indented line of `poetry show --tree` into a
# `^<name>\s` pattern, which the outer grep reads through bash process substitution
# (hence the bash interpreter).
[tool.poe.tasks.outdated]
help = "Show all outdated top-level dependencies"
interpreter = "bash"
# Literal string: the regex backslashes reach the shell exactly as written.
shell = '''
	poetry show --outdated | grep --file=<(poetry show --tree | grep '^\w' | sed 's/^\([^ ]*\).*/^\1\\s/')
'''


# Runtime dependencies of the package.
[tool.poetry.dependencies]
python = "^3.13"
torch = "^2.7.1"
wandb = "^0.23.1"
transformers = "^4.57.3"
# pydantic 2.x ships no `dotenv` extra (it existed only in 1.x), so requesting it
# installed nothing and has been dropped.
# NOTE(review): if the code loads settings from `.env` via BaseSettings, that lives
# in the separate `pydantic-settings` package under pydantic 2 — add it if needed.
pydantic = "^2.10.6"
overrides = "^6.1.0"
numpy = "^2.3.3"
loguru = "^0.7.3"
datasets = "^4.4.1"
accelerate = "^1.12.0"
huggingface-hub = "^0.34.0"
scikit-learn = "^1.8.0"
matplotlib = "^3.10.8"
imageio = "^2.37.2"
seaborn = "^0.13.2"


# basedpyright (static type checker) configuration.
[tool.basedpyright]
# Paths that are not analysed at all.
exclude = [
	"storage",
	"configs",
	"wandb",
	"**/.*",
	"**/*_cache*",
	"**/python*/test/**",
]
# Files matched here have their diagnostics suppressed.
# The path must use the real package name, `retrievit` (see `packages` under
# [tool.poetry]); the previous spelling `retrievet` matched nothing.
ignore = ["src/retrievit/trainer/**"]
typeCheckingMode = "standard"
# Individual rule overrides on top of the "standard" mode.
reportMissingTypeStubs = false
reportUnknownMemberType = false
reportFunctionMemberAccess = "warning"
reportUnknownVariableType = false
reportUntypedFunctionDecorator = false
reportUnknownLambdaType = false
reportUnknownArgumentType = false
reportAny = false
reportImplicitOverride = false
reportMissingSuperCall = false
reportUnusedCallResult = false
reportCallIssue = false
reportArgumentType = false
reportIncompatibleMethodOverride = false
reportInvalidCast = false
reportPrivateLocalImportUsage = false
# Covered by ruff
reportPrivateUsage = false
reportUnusedImport = false
reportPrivateImportUsage = false
reportImplicitStringConcatenation = false
reportDeprecated = false
reportIncompatibleVariableOverride = false

[tool.pytest.ini_options]
# Silence every UserWarning, plus the DeprecationWarning whose message starts with
# "Deprecated call to `pkg_resources.declare_namespace".
filterwarnings = [
	"ignore::UserWarning",
	"ignore:Deprecated call to `pkg_resources.declare_namespace:DeprecationWarning",
]
# Directory searched for tests when pytest is invoked without explicit paths.
testpaths = ["tests"]

# Settings shared by the ruff linter and formatter.
[tool.ruff]
line-length = 99
# Must match the minimum interpreter declared in [tool.poetry.dependencies]
# (python = "^3.13"); it was previously pinned to the older py311.
target-version = "py313"
# Also apply fixes that ruff classifies as unsafe.
unsafe-fixes = true
# Root used to resolve first-party imports.
src = ["src"]

[tool.ruff.format]
# Also format the code examples embedded in docstrings.
docstring-code-format = true

# Linter rule selection: start from every rule and opt out explicitly.
# The deprecated `ignore-init-module-imports` option is no longer set here; it only
# restated its default (true) and ruff has marked it for removal.
[tool.ruff.lint]
# Enable every possible rule
select = ["ALL"]
ignore = [
	# Allow function call as argument default
	"B008",
	# Don't ask for docstring at top of module --- put it in the functions/classes
	"D100",
	# Do not check for docstring within __init__ method
	"D107",
	# Don't ask about line length, Black recommends using bugbear B950 instead
	"E501",
	# Disable because this project uses jaxtyping (https://github.com/google/jaxtyping/blob/main/FAQ.md#flake8-is-throwing-an-error)
	"F722",
	# Allow import to be uppercase, because torch.nn.functional as F
	"N812",
	# Allow asserts to be used because they're just convenient for type-narrowing. Type-narrowing
	# is more important than the possibility that someone is running python with -O (in optimized
	# mode).
	# https://stackoverflow.com/a/68429294
	"S101",
	# (The former "ANN1" entry for self/cls annotations is gone: ANN101 and ANN102 were removed
	# in ruff 0.8, so there is nothing left to ignore.)
	# Do not block using 'Any' type since it happens
	"ANN401",
	# Let the formatter handle trailing commas
	"COM",
	# Let logging use f-strings
	"G004",
	# Disable 'flake8-errmsg' because we assume users of this project can read tracebacks
	"EM",
	# Allow TODO comments
	"FIX002",
	# We don't need to care about creating separate exception classes for every single type of
	# error
	"TRY003",
	# Allow assigning variables before returning them
	"RET504",
	# Don't care about requiring an author name or issue link for a todo
	"TD002",
	"TD003",
	# Boolean-typed positional arguments (with or without a default) are fine
	"FBT001",
	"FBT002",
	# Disable flagging commented-out code because it gives false positives on shape comments
	"ERA001",
	# Things to ignore because ruff's formatter says so
	# https://docs.astral.sh/ruff/formatter/#conflicting-lint-rules
	"D206",
	"D300",
	"E111",
	"E114",
	"E117",
	"ISC001",
	"ISC002",
	"Q000",
	"Q001",
	"Q002",
	"Q003",
	"W191",
	# Allow the builtin open() instead of requiring Path.open()
	"PTH123",
	# Do not require ClassVar annotations on mutable class attributes
	"RUF012",
	# Allow print statements, I use them for debugging
	"T201",
]
unfixable = [
	# Do not remove unused variables
	"F841",
	# Do not auto-remove commented out code
	"ERA001",
]

[tool.ruff.lint.flake8-quotes]
# Preferred quote character for inline strings.
inline-quotes = "double"

[tool.ruff.lint.flake8-tidy-imports]
# Every relative import is banned; use absolute imports only.
ban-relative-imports = "all"

[tool.ruff.lint.flake8-type-checking]
# Subclasses of these bases are treated as needing their annotations at runtime.
runtime-evaluated-base-classes = ["pydantic.BaseModel", "pydantic.generics.GenericModel"]
# Imports from these modules are exempt from the type-checking-block rules.
exempt-modules = [
	"typing",
	"typing_extensions",
	"pydantic_numpy",
]

[tool.ruff.lint.isort]
# Merge `import x as y` members into a single from-import statement.
combine-as-imports = true
# This project's own import package (see `packages` under [tool.poetry]). The list
# previously named packages from a different project.
known-first-party = ["retrievit"]

[tool.ruff.lint.mccabe]
# Complexity threshold above which a function is reported.
max-complexity = 18

# Extra rule exemptions keyed by file glob.
[tool.ruff.lint.per-file-ignores]
"src/**/__init__.py" = ["D", "F401", "I002"]
"scripts/*" = ["INP001"]
"tests/*" = [
	"D",
	"S101",
	"INP001",
	"PLR2004",
	"FBT001",
	"SLF001",
]

[tool.ruff.lint.pydocstyle]
# Docstrings are checked against the Google convention.
convention = "google"

[tool.ruff.lint.pylint]
# Maximum number of arguments a function may take before it is reported.
max-args = 20


# PEP 517 build configuration: the package is built with poetry-core.
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
