# Main (runtime) dependencies.
# Order: the interpreter constraint first, then packages alphabetically,
# except that jax and jaxlib are kept adjacent.
[tool.poetry.dependencies]
python = "^3.10"
appdirs = "^1.4.4"
datasets = "^2.19.1"
huggingface-hub = "^0.23.0"
ipywidgets = "^8.1.3"
# jax and jaxlib are pinned to the same exact version; bump them together.
jax = { version = "=0.4.31", extras = ["tpu"] }
jaxlib = { version = "=0.4.31" }
jax-tqdm = "^0.2.1"
numba = "^0.60.0"
orbax = "^0.1.9"
penzai = "^0.1.0"
plotly-express = "^0.4.1"
scikit-learn = "^1.5.0"
sentencepiece = "^0.2.0"
tiktoken = "^0.6.0"
# Resolved from the "torch-cpu" package source declared in this file.
torch = { version = "^2.2.2+cpu", source = "torch-cpu" }
tqdm = "^4.66.2"

# Extra package sources consulted in addition to PyPI.
#   supplemental: searched only when no higher-priority source offers a
#                 compatible distribution.
#   explicit:     searched only for dependencies that name the source via
#                 `source = "..."`.

# NOTE(review): jax is pinned to an exact release (=0.4.31) in
# [tool.poetry.dependencies]; confirm this nightly index is still needed.
[[tool.poetry.source]]
name = "jax-nightly"
url = "https://storage.googleapis.com/jax-releases/jax_nightly_releases.html"
priority = "supplemental"

# NOTE(review): presumably provides the libtpu wheel required by jax's "tpu"
# extra; it must stay non-explicit because no dependency names it directly.
[[tool.poetry.source]]
name = "google-libtpu"
url = "https://storage.googleapis.com/jax-releases/libtpu_releases.html"
priority = "supplemental"

# Only `torch` opts into this index (source = "torch-cpu"), so it is marked
# explicit: no other package can ever be resolved from it by accident.
[[tool.poetry.source]]
name = "torch-cpu"
url = "https://download.pytorch.org/whl/cpu"
priority = "explicit"


# Development-only dependencies, sorted alphabetically.
# `ipywidgets` and `orbax` are intentionally not repeated here: both are
# already declared in [tool.poetry.dependencies], and the former duplicate
# `ipywidgets = "^8.1.2"` disagreed with the main constraint (^8.1.3).
[tool.poetry.group.dev.dependencies]
anthropic = "^0.26.1"
black = "^24.4.2"
circuitsvis = "^1.43.2"
# NOTE(review): `distribute` is the long-deprecated setuptools fork and 0.7.x
# is only a compatibility shim over setuptools; confirm it is still needed.
distribute = "^0.7.3"
flagembedding = "^1.2.10"
graphviz = "^0.20.3"
ipykernel = "^6.29.4"
isort = "^5.13.2"
jax-smi = "^1.0.3"
kagglehub = "^0.2.5"
langchain = "^0.2.1"
langchain-anthropic = "^0.1.13"
langchain-community = "^0.2.1"
matplotlib = "^3.8.4"
maturin = "^1.7.0"
more-itertools = "^10.2.0"
nbconvert = "^7.16.4"
optax = "^0.2.2"
pylops = "^2.2.0"
pytest = "^8.2.1"
python-dotenv = "^1.0.1"
replicate = "^0.26.0"
rich = "^13.7.1"
transformers = "^4.40.2"

# PEP 517 build configuration: the package is built with Poetry's backend.
[build-system]
# Lower bound taken from Poetry's documented template, so a pre-1.0
# poetry-core can never be selected as the build backend.
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
