[project]
name = 'fma-llama'
version = '0.1.0'
description = 'Llama 3.2 1B in Flax with FMA attention approximation'
readme = 'README.md'
requires-python = '>=3.10'
# Runtime requirements as PEP 508 specifier strings.
dependencies = [
    # JAX, installed with its `cuda` extra.
    'jax[cuda]>=0.4.20',
    # NOTE(review): the jax[cuda] extra presumably already pulls in a matching
    # jaxlib, which would make this floor redundant — confirm before removing.
    'jaxlib>=0.4.20',
    'flax>=0.8.0',
    'optax>=0.1.9',
    # Loads Llama weights and the tokenizer.
    'transformers>=4.40.0',
    # Loads pretrained weights.
    'safetensors>=0.4.0',
    'numpy>=1.24.0',
    # Alternative tokenizer, if needed.
    'tiktoken>=0.5.0',
    # Converts weights from PyTorch.
    'torch>=2.0.0',
    'tqdm>=4.66.0',
    # Loads PG-19. Only an upper bound is set.
    # NOTE(review): the <4.0.0 cap looks deliberate; document the reason
    # (e.g. a loader feature dropped in 4.0) before lifting it — confirm.
    'datasets<4.0.0',
    # Used for testing.
    # NOTE(review): this is a runtime dependency; if only test code imports
    # chex, it could move to the `dev` extra — confirm.
    'chex>=0.1.85',
    # Plots loss curves.
    'matplotlib>=3.7.0',
    # Curve fitting in loss analysis.
    'scipy>=1.11.0',
    'nvtx>=0.2.14',
    'einshape>=1.0',
    'jax-triton>=0.3.0',
]

[project.optional-dependencies]
# Development tooling, installed via the `dev` extra:
# pytest (plus pytest-xdist), black, and ruff.
dev = [
    'pytest>=7.4.0',
    'pytest-xdist>=3.3.0',
    'black>=23.0.0',
    'ruff>=0.1.0',
]

[build-system]
# Build the package with the Hatchling backend.
requires = ['hatchling']
build-backend = 'hatchling.build'

[tool.black]
# Keep line-length and target-version in sync with [tool.ruff].
line-length = 100
target-version = ["py310"]

[tool.ruff]
# Keep line-length and target-version in sync with [tool.black].
line-length = 100
target-version = 'py310'
