[build-system]
# Package metadata lives in the PEP 621 `[project]` table, which setuptools
# only reads from version 61.0.0 onwards; older releases do not consume it.
requires = ["setuptools>=61", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "skyrl-train"
version = "0.1.0"
description = "skyrl-train"
authors = [
    {name = "NovaSkyAI", email = "novasky.berkeley@gmail.com"}
]
license = {text = "MIT"}
readme = "README.md"
# Only Python 3.12 is accepted. The flash-attn wheel referenced in
# [tool.uv.sources] is a cp312 build, so widening this range needs a different wheel.
requires-python = "==3.12.*"
classifiers = [
    "Programming Language :: Python :: 3",
    # Keep in sync with `requires-python` above.
    "Programming Language :: Python :: 3.12",
    "License :: OSI Approved :: MIT License",
    "Operating System :: OS Independent",
]

dependencies = [
    # Runtime requirements installed with the base package (no extras).
    "loguru",
    "tqdm",
    "tensorboard",
    "func_timeout",
    "transformers>=4.51.0",
    "hydra-core==1.3.2",
    "accelerate",
    "torchdata",
    "omegaconf",
    "ray==2.48.0",
    "peft",
    "debugpy==1.8.0",
    "hf_transfer",
    "wandb",
    "datasets",
    "tensordict",
    "jaxtyping",
    # When installing with uv, redirected to the local ./skyrl-gym path
    # (editable) by [tool.uv.sources].
    "skyrl-gym",
    # When installing with uv, redirected to a pinned prebuilt wheel URL
    # by [tool.uv.sources].
    "flash-attn",
]

[tool.uv]
# Declare the `vllm` and `sglang` extras as mutually exclusive (they pin
# different torch versions), so uv resolves each one separately.
conflicts = [[{ extra = "vllm" }, { extra = "sglang" }]]

[tool.uv.sources]
skyrl-gym = { path = "./skyrl-gym", editable = true }
torch = { index = "pytorch-cu128" }
torchvision = { index = "pytorch-cu128" }
flash-attn = { url = "https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.0.post2/flash_attn-2.8.0.post2+cu12torch2.7cxx11abiFALSE-cp312-cp312-linux_x86_64.whl" }
# NOTE (sumanthrh): We explicitly use a prebuilt flashinfer wheel, referenced by direct URL.
# The wheels on PyPI don't come with pre-compiled kernels, so the package would JIT-compile them at runtime, which is slow.
# Additionally, different inference engines may pin different compatible flashinfer versions, so each of the
# vllm/sglang extras gets its own entry here and can be pinned independently.
flashinfer-python = [
    { url = "https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl", marker = "extra =='vllm'" },
    { url = "https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl", marker = "extra == 'sglang' and extra != 'vllm'" },
]
 
[project.optional-dependencies]
deepspeed = [
    "deepspeed==0.16.5",
]
# Linting, formatting and test tooling.
dev = [
    "ruff==0.11.9",
    "black==24.10.0",
    "pytest>=6.2.5",
    "pytest-asyncio",
    "pre-commit",
]
# Sphinx documentation toolchain.
docs = [
    "sphinx>=7.0.0",
    "sphinx-rtd-theme>=2.0.0",
    "sphinx-autodoc-typehints>=1.25.0",
    "myst-parser>=2.0.0",
    "sphinx-copybutton>=0.5.0",
    "sphinx-autobuild>=2021.3.14",
]
# vLLM inference engine; declared as conflicting with `sglang` under [tool.uv].
vllm = [
    "vllm==0.9.2",
    "torch==2.7.0",
    "flashinfer-python",
    "torchvision",
]
# SGLang inference engine; declared as conflicting with `vllm` under [tool.uv].
sglang = [
    # 0.4.9.post1 causes non-colocate weight broadcast to hang
    "sglang[srt,openai,torch_memory_saver]==0.4.8.post1",
    "flashinfer-python",
    "torch==2.7.1",
    "torchvision",
]


[[tool.uv.index]]
# PyTorch's CUDA 12.8 wheel index. `explicit = true` restricts it to packages
# that name this index in [tool.uv.sources] (torch and torchvision here);
# everything else keeps resolving from the default index.
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true

[tool.setuptools.packages.find]
# Limit package discovery to `skyrl_train` and anything matching the
# `skyrl_train*` glob, so other top-level directories are not packaged.
include = ["skyrl_train*"]

[tool.setuptools.dynamic]
# NOTE(review): `version` and `readme` are already set statically in [project],
# and [project] declares no `dynamic` list, so these two entries look inert —
# setuptools presumably only consults this table for fields listed in
# `project.dynamic`. Confirm, then either drop this table or move the fields
# to `dynamic` in [project].
version = {attr = "skyrl_train.__version__"}
readme = {file = ["README.md"]}

[tool.pytest.ini_options]
# -v: list each test by name; -s: disable output capture so prints/logs stream live.
addopts = "-v -s"
testpaths = ["tests"]



[tool.isort]
# Import sorting configured to agree with black's formatting.
# NOTE(review): multi_line_output, include_trailing_comma, force_grid_wrap and
# use_parentheses appear to repeat what `profile = "black"` already sets —
# confirm before removing any of them.
profile = "black"
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
# Same limit as [tool.black] line-length and [tool.flake8] max-line-length.
line_length = 120
# Always classify `wandb` imports as third-party.
known_third_party = "wandb"

[tool.black]
# Same limit as [tool.isort] line_length and [tool.flake8] max-line-length.
line-length = 120
# Format only .py and .pyi files.
include = '\.pyi?$'
# Regex of directories to skip, applied in addition to black's default
# excludes. The text between the triple quotes is the pattern itself.
extend-exclude = '''
/(
  # directories
  \.eggs
  | \.git
  | \.hg
  | \.mypy_cache
  | \.tox
  | \.venv
  | build
  | dist
)/
'''

[tool.flake8]
# NOTE(review): flake8 does not read pyproject.toml on its own, so this table
# presumably relies on a plugin such as Flake8-pyproject. flake8 is also not
# listed in the `dev` extra — confirm this section is still in use.
# Same limit as [tool.black] line-length and [tool.isort] line_length.
max-line-length = 120
max-doc-length = 120
extend-ignore = [
    # Default ignored errors by flake8
    "E121", "E123", "E126", "E226", "E24", "E704",
    # F401 module imported but unused
    "F401",
    # E203 whitespace before ':' (conflict with black)
    "E203",
    # E231 missing whitespace after ',' (conflict with black)
    "E231",
    # E501 line too long (conflict with black)
    "E501",
    # E741 do not use variables named 'l', 'O', or 'I'
    "E741",
    # W503 line break before binary operator (conflict with black)
    "W503",
    # W504 line break after binary operator (conflict with black)
    "W504",
    # W505 doc line too long (conflict with black)
    "W505",
    # W605 invalid escape sequence 'x' (conflict with latex within docs)
    "W605",
]

[tool.ruff.lint]
# F722 (syntax error in annotation) is ignored because it doesn't play well
# with jaxtyping annotations.
ignore = ["F722"]
