[project]
name = "e2e_sae"
version = "1.0.0"
description = "Repo for training sparse autoencoders end-to-end"
readme = "README.md"
requires-python = ">=3.11"
# Runtime requirements, one PEP 508 string per line, sorted alphabetically.
# Every entry uses a compatible-release (~=) specifier.
dependencies = [
    "automated-interpretability~=0.0.3",
    "einops~=0.7.0",
    "fire~=0.5.0",
    "ipykernel~=6.29.0",
    "jaxtyping~=0.2.25",
    "matplotlib~=3.5.3",
    "pydantic~=2.0",
    "pytest~=8.1.2",
    "python-dotenv~=1.0.1",
    "seaborn~=0.13.2",
    "statsmodels~=0.14.2",
    "tenacity~=8.2.3",
    "torch~=2.2.0",
    "torchvision~=0.17.0",
    "tqdm~=4.66.1",
    "transformer-lens~=1.14.0",
    "umap-learn~=0.5.6",
    "wandb~=0.16.2",
    "zstandard~=0.22.0",
]

[project.urls]
# NOTE(review): the URL below is an anonymized placeholder; replace it with the
# real repository address before publishing.
repository = "https://github.com/ANONYMIZED/"

[project.optional-dependencies]
# Developer tooling, installed with the `dev` extra (e.g. `pip install -e ".[dev]"`).
# Sorted alphabetically.
dev = [
    "pre-commit~=3.6.0",
    "pyright~=1.1.362",
    "ruff~=0.1.14",
]

[build-system]
# setuptools reads package metadata from the PEP 621 [project] table only from
# version 61.0 onward; older versions would ignore it. The lower bound makes
# that requirement explicit for non-isolated builds.
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"

[tool.setuptools]
# Explicit package list (automatic discovery is not used): a new subpackage
# must be added here to be included in the built distribution.
packages = [
    "e2e_sae",
    "e2e_sae.models",
    "e2e_sae.scripts",
]

[tool.ruff]
# Settings shared by the linter and the formatter.
line-length = 100
# Apply available autofixes whenever the linter runs.
fix = true

# All linter-specific options live under `lint`; the top-level equivalents
# (`ignore`, `[tool.ruff.isort]`) are deprecated in newer Ruff releases.
[tool.ruff.lint]
select = [
    # pycodestyle
    "E",
    # Pyflakes
    "F",
    # pyupgrade
    "UP",
    # flake8-bugbear
    "B",
    # flake8-simplify
    "SIM",
    # isort
    "I",
]
ignore = [
    # Incompatible with jaxtyping
    "F722",
]

[tool.ruff.lint.isort]
# Always sort `wandb` imports with the third-party group.
# NOTE(review): presumably needed because a local `wandb/` run directory would
# otherwise be detected as first-party; confirm.
known-third-party = ["wandb"]

[tool.ruff.format]
# Enable reformatting of code snippets in docstrings.
docstring-code-format = true

[tool.pyright]
# Directories that pyright type-checks.
include = ["e2e_sae", "tests"]

# Infer precise element types for list/dict/set literals.
strictListInference = true
strictDictionaryInference = true
strictSetInference = true

# Diagnostics enabled explicitly in this project.
reportFunctionMemberAccess = true
reportUnknownParameterType = true
reportIncompatibleMethodOverride = true
reportIncompatibleVariableOverride = true
# Was `reportInconsistentConstructorType`, which is not a pyright setting, so
# the check was never enabled; `reportInconsistentConstructor` is the documented name.
reportInconsistentConstructor = true
reportOverlappingOverload = true
reportConstantRedefinition = true
reportPropertyTypeMismatch = true
reportMissingTypeArgument = true
reportUnnecessaryCast = true
reportUnnecessaryComparison = true
reportUnnecessaryContains = true
reportUnusedExpression = true
reportMatchNotExhaustive = true
reportShadowedImports = true

# Diagnostics deliberately disabled.
reportImportCycles = false
reportPrivateImportUsage = false

[tool.pytest.ini_options]
# Warning filters use the format `action:message:category:module:lineno`.
filterwarnings = [
    # Silence the DeprecationWarning raised at line 59 of the `fire` module.
    # Upstream fix: https://github.com/google/python-fire/pull/447
    # NOTE(review): the filter is tied to an exact line number; re-check it when
    # the `fire` pin changes.
    "ignore::DeprecationWarning:fire:59",
]