[build-system]
# setuptools reads package metadata from the `[project]` table only from
# release 61.0.0 onward; older releases would ignore it.
requires = ["setuptools>=61", "wheel"]
build-backend = "setuptools.build_meta"

[project]
# Core package metadata (PEP 621).
name = "skyrl-train"
version = "0.3.1"
description = "skyrl-train"
readme = "README.md"
license = { text = "MIT" }
authors = [{ name = "NovaSkyAI", email = "novasky.berkeley@gmail.com" }]
# Only the Python 3.12 series is accepted.
requires-python = "==3.12.*"
classifiers = [
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: MIT License",
    "Operating System :: OS Independent",
]

# Runtime dependencies installed with every variant of the package.
# torch and the inference engines (vllm, sglang) are not listed here; they are
# pulled in through the extras under [project.optional-dependencies].
dependencies = [
    "loguru",
    "tqdm",
    "ninja",
    "tensorboard",
    "func_timeout",
    "transformers>=4.51.0",
    "hydra-core==1.3.2",
    "accelerate",
    "torchdata",
    "omegaconf",
    "ray==2.51.1",
    "peft",
    "debugpy==1.8.0",
    "hf_transfer",
    "wandb",
    "datasets==4.0.0",
    "tensordict",
    "jaxtyping",
    # Resolved from the local ./skyrl-gym checkout via [tool.uv.sources].
    "skyrl-gym",
    # Unpinned here; the vllm/sglang/mcore/flashrl extras pin the exact version.
    "flash-attn",
    "polars",
    "s3fs",
    "fastapi",
    "uvicorn",
    "pybind11",
    "setuptools",
    "litellm>=1.79.2",
    "google-cloud-storage>=3.7.0",
]

[tool.uv]
# Minimum uv release required to work with this project.
required-version = ">=0.8.10"
# Sets of extras that uv must treat as mutually exclusive when resolving.
# NOTE(review): the first two sets are subsets of the last one; confirm whether
# they are still needed or are redundant.
conflicts = [
    [
        { extra = "vllm" },
        { extra = "sglang" },
    ],
    [
        { extra = "vllm" },
        { extra = "flashrl" },
        { extra = "sglang" },
    ],
    [
        { extra = "flashrl" },
        { extra = "miniswe" },
    ],
    [
        { extra = "mcore" },
        { extra = "vllm" },
        { extra = "sglang" },
        { extra = "flashrl" },
    ]
]
# Packages that uv builds without build isolation.
no-build-isolation-package = [
    "transformer-engine-torch",
    "transformer-engine",
    "nv-grouped-gemm",
]
# Overrides applied on top of the requirements declared by dependencies.
# The `sys_platform == 'never'` marker matches no real platform, which appears
# intended to exclude those packages from resolution -- TODO confirm.
# The last two entries force the same versions that the `mcore` extra requests.
override-dependencies = [
    "nvidia-resiliency-ext; sys_platform == 'never'",
    "mamba-ssm; sys_platform == 'never'",
    "causal-conv1d; sys_platform == 'never'",
    "transformer-engine[pytorch]==2.9.0",
    "megatron-core==0.15.0"
]
[tool.uv.extra-build-dependencies]
# Additional packages made available while building these distributions.
# `match-runtime = true` asks uv to build against the same torch version that
# is resolved for the runtime environment.
flash-attn = [
    { requirement = "torch", match-runtime = true },
]
transformer-engine = [
    { requirement = "torch", match-runtime = true },
    "build_tools",
    "ninja",
]
transformer-engine-torch = [
    { requirement = "torch", match-runtime = true },
    "build_tools",
    "ninja",
]

[tool.uv.extra-build-variables]
# Environment variable set while building flash-attn.
# Presumably makes its build skip compiling the CUDA extension -- verify
# against flash-attn's setup.py.
flash-attn.FLASH_ATTENTION_SKIP_CUDA_BUILD = "TRUE"

[tool.uv.sources]
# Local checkout of the gym package, installed in editable mode.
skyrl-gym = { path = "./skyrl-gym" , editable = true }
# torch and torchvision are always fetched from the `pytorch-cu128` index.
torch = { index = "pytorch-cu128" }
torchvision = { index = "pytorch-cu128" }
# `flashinfer-jit-cache` is used to avoid slow JIT compilation on first run.
# Inference engines may pin different compatible flashinfer versions, so the
# flashinfer sources are selected per extra through markers.
flashinfer-jit-cache = { index = "flashinfer-cu128", marker = "extra == 'vllm' or extra == 'mcore'" }
# With the sglang extra (and neither mcore nor vllm), flashinfer-python comes
# from a fixed prebuilt wheel.
flashinfer-python = [
    { url = "https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl", marker = "extra == 'sglang' and extra != 'mcore' and extra != 'vllm'" }
]
# NOTE(review): the `mcore` extra requests megatron-bridge through a direct git
# URL at tag v0.2.0, while this source pins a specific commit; confirm which
# one is intended to take effect.
megatron-bridge = {git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge", rev = "953aabf75c0500180dc14a6a76cf9e7e7c4baec7"}


[project.optional-dependencies]
# DeepSpeed, pinned to a single release.
deepspeed = [
    "deepspeed==0.17.6",
]
# Linting, formatting and test tooling.
dev = [
    "ruff==0.11.9",
    "black==24.10.0",
    "pytest>=6.2.5",
    "pytest-asyncio",
    "pre-commit",
    "litellm",
]
# Sphinx-based documentation toolchain.
docs = [
    "sphinx>=7.0.0",
    "sphinx-rtd-theme>=2.0.0",
    "sphinx-autodoc-typehints>=1.25.0",
    "myst-parser>=2.0.0",
    "sphinx-copybutton>=0.5.0",
    "sphinx-autobuild>=2021.3.14",
]
# TODO(tgriggs): Add `sandboxes` here once available as a package.
sandboxes = [
    "litellm[proxy]>=1.67.5",
]
# vLLM together with the torch / flash-attn / flashinfer versions it is used with.
vllm = [
    "vllm==0.11.0",
    "flash-attn==2.8.3",
    "torch==2.8.0",
    "flashinfer-python",
    "flashinfer-jit-cache",
    "torchvision",
]
# SGLang together with the torch / flash-attn / flashinfer versions it is used with.
sglang = [
    # 0.4.9.post1 causes non-colocate weight broadcast to hang
    "sglang[srt,openai,torch_memory_saver]==0.4.8.post1",
    "flashinfer-python",
    "flash-attn==2.8.3",
    "torch==2.7.1",
    "torchvision",
]
# Megatron-Core stack (transformer-engine, megatron-core, megatron-bridge) plus vLLM.
mcore = [
    "transformer-engine[pytorch]==2.9.0",
    "flash-attn==2.8.1",
    "vllm==0.11.0",
    "torch==2.8.0",
    "flashinfer-python==0.5.2",
    "torchvision",
    "megatron-bridge @ git+https://github.com/NVIDIA-NeMo/Megatron-Bridge.git@v0.2.0",
    "megatron-core==0.15.0",
    "flashinfer-jit-cache==0.5.2",
    "nvidia-modelopt",
]
flashrl = [
    # NOTE: Custom vLLM wheel must be installed separately.
    # See examples/flash_rl/README.md for installation instructions.
    "flash-attn==2.8.3",
    "torch==2.7.0",
    "flashinfer-python",
    "torchvision",
]
miniswe = [
    # NOTE (sumanthrh): Needs to be a commit after https://github.com/SWE-agent/mini-swe-agent/commit/4f5d445e99d13b5482478c23508bf2fbf7c0670c
    "mini-swe-agent>=1.12.0",
    "litellm",
]


[[tool.uv.index]]
# Index referenced by the torch / torchvision entries in [tool.uv.sources].
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
# explicit: only packages pinned to this index via [tool.uv.sources] use it.
explicit = true

[[tool.uv.index]]
# Index referenced by the flashinfer-jit-cache entry in [tool.uv.sources].
name = "flashinfer-cu128"
url = "https://flashinfer.ai/whl/cu128"
# explicit: only packages pinned to this index via [tool.uv.sources] use it.
explicit = true

[tool.setuptools]
# Include the data files declared for the package when building distributions.
include-package-data = true
# Package discovery: only pick up packages whose names match `skyrl_train*`.
packages.find.include = ["skyrl_train*"]

[tool.setuptools.dynamic]
# NOTE(review): `[project]` sets `version` and `readme` statically and declares
# no `dynamic` list, so these directives do not appear to be in effect --
# confirm whether the static values or these are the intended source of truth.
version = {attr = "skyrl_train.__version__"}
readme = {file = ["README.md"]}

# Bundle the default configuration files into the built wheel.
[tool.setuptools.package-data]
skyrl_train = ["config/**"]

[tool.pytest.ini_options]
# -v: verbose test names; -s: do not capture stdout/stderr.
addopts = "-v -s"
testpaths = [
    "tests",
]
# Custom markers are registered here because pytest reads its pyproject.toml
# settings from this table, not from a root-level `[pytest]` table.
markers = [
    "vllm",
    "sglang",
    "integrations",
    "megatron",
]



[tool.isort]
# Start from isort's black-compatible profile; the keys below spell out the
# wrapping style explicitly.
profile = "black"
line_length = 120
# 3 = vertical hanging indent
multi_line_output = 3
use_parentheses = true
include_trailing_comma = true
force_grid_wrap = 0
known_third_party = "wandb"

[tool.black]
# Same limit as `line_length` under [tool.isort].
line-length = 120
# Regex selecting the files to format: names ending in .py or .pyi.
include = '\.pyi?$'
# Regex of additional paths to skip, on top of black's default excludes.
extend-exclude = '''
/(
  # directories
  \.eggs
  | \.git
  | \.hg
  | \.mypy_cache
  | \.tox
  | \.venv
  | build
  | dist
)/
'''

# NOTE(review): flake8 does not read pyproject.toml on its own; presumably a
# plugin such as Flake8-pyproject is expected to load this table -- confirm.
[tool.flake8]
# Code and docstring/comment lines share the 120-column limit.
max-line-length = 120
max-doc-length = 120
extend-ignore = [
    # Default ignored errors by flake8
    "E121", "E123", "E126", "E226", "E24", "E704",
    # F401 module imported but unused
    "F401",
    # E203 whitespace before ':' (conflict with black)
    "E203",
    # E231 missing whitespace after ',' (conflict with black)
    "E231",
    # E501 line too long (conflict with black)
    "E501",
    # E741 do not use variables named 'l', 'O', or 'I'
    "E741",
    # W503 line break before binary operator (conflict with black)
    "W503",
    # W504 line break after binary operator (conflict with black)
    "W504",
    # W505 doc line too long (conflict with black)
    "W505",
    # W605 invalid escape sequence 'x' (conflict with latex within docs)
    "W605",
]

[tool.ruff.lint]
# F722 (syntax error in annotation) is ignored because it doesn't play well
# with jaxtyping.
ignore = ["F722"]

# NOTE(review): pytest reads pyproject.toml configuration from
# `[tool.pytest.ini_options]`; a root-level `[pytest]` table is the pytest.ini /
# tox.ini spelling and is not consulted here, so this list has no effect unless
# the same markers are registered under `[tool.pytest.ini_options]`.
[pytest]
markers = [
    "vllm",
    "sglang",
    "integrations",
    "megatron",
]
