[build-system]
# setuptools reads PEP 621 `[project]` metadata and `[tool.setuptools]`
# configuration from pyproject.toml only since 61.0.0; older releases would
# ignore the metadata declared below.
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "skyrl-train"
version = "0.2.0"
description = "skyrl-train"
authors = [
    {name = "NovaSkyAI", email = "novasky.berkeley@gmail.com"},
]
license = {text = "MIT"}
readme = "README.md"
# Only Python 3.12.x is supported; keep the version classifiers below in sync.
requires-python = "==3.12.*"
classifiers = [
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3 :: Only",
    "Programming Language :: Python :: 3.12",
    "License :: OSI Approved :: MIT License",
    "Operating System :: OS Independent",
]

# Dependencies shared by every install. Backend-specific stacks (inference
# engine, torch version, exact flash-attn pin) live in the extras under
# [project.optional-dependencies]; `flash-attn` is left unpinned here so each
# extra can pin its own version. `skyrl-gym` is resolved from the local
# ./skyrl-gym checkout by uv (see [tool.uv.sources]).
dependencies = [
    "loguru",
    "tqdm",
    "ninja",
    "tensorboard",
    "func_timeout",
    "transformers>=4.51.0",
    "hydra-core==1.3.2",
    "accelerate",
    "torchdata",
    "omegaconf",
    "ray==2.51.1",
    "peft",
    "debugpy==1.8.0",
    "hf_transfer",
    "wandb",
    "datasets==4.0.0",
    "tensordict",
    "jaxtyping",
    "skyrl-gym",
    "flash-attn",
    "polars",
    "s3fs",
    "fastapi",
    "uvicorn",
]

[tool.uv]
required-version = ">=0.8.10"
# Extras that pin mutually incompatible stacks must never be installed together.
# Declaring them as conflicts lets uv produce one universal lock by resolving
# each extra separately, instead of failing on the incompatible pins.
conflicts = [
    # The two inference engines: vllm pins torch==2.8.0, sglang pins torch==2.7.1.
    [{ extra = "vllm" }, { extra = "sglang" }],
    # flashrl pins torch==2.7.0, which matches neither engine above.
    [{ extra = "vllm" }, { extra = "flashrl" }, { extra = "sglang" }],
    # NOTE(review): the incompatibility behind this pair is not visible in this file — confirm.
    [{ extra = "flashrl" }, { extra = "miniswe" }],
    # mcore brings its own vllm==0.10.1.1 / torch==2.7.1 stack.
    [{ extra = "mcore" }, { extra = "vllm" }, { extra = "sglang" }, { extra = "flashrl" }],
]
# transformer-engine is built against the torch that is being installed (its
# build dependencies use `match-runtime` in [tool.uv.extra-build-dependencies]),
# so uv must not build these packages in an isolated environment.
no-build-isolation-package = [
    "transformer-engine-torch",
    "transformer-engine",
]

[tool.uv.extra-build-dependencies]
# These packages need torch while they are being built. `match-runtime = true`
# makes uv put the same torch version into the build environment as the one
# resolved for the runtime environment.
flash-attn = [{ requirement = "torch", match-runtime = true }]
transformer-engine = [{ requirement = "torch", match-runtime = true }, "build_tools"]
transformer-engine-torch = [{ requirement = "torch", match-runtime = true }, "build_tools"]

[tool.uv.extra-build-variables]
# Do not compile flash-attn's CUDA extension from source during the build
# (presumably so a prebuilt wheel is used instead — confirm for the pinned versions).
flash-attn = { FLASH_ATTENTION_SKIP_CUDA_BUILD = "TRUE" }

[tool.uv.sources]
skyrl-gym = { path = "./skyrl-gym", editable = true }
# torch and torchvision always come from the CUDA 12.8 PyTorch index.
torch = { index = "pytorch-cu128" }
torchvision = { index = "pytorch-cu128" }
# `flashinfer-jit-cache` ships precompiled kernels so the first run avoids a slow
# JIT compilation; it is only pulled in by the vllm extra.
flashinfer-jit-cache = { index = "flashinfer-cu128", marker = "extra == 'vllm'" }
# Each inference stack may need its own flashinfer build, so `flashinfer-python`
# is sourced per extra: mcore and sglang (both on torch 2.7) share this prebuilt
# wheel, while the remaining extras resolve `flashinfer-python` from the default index.
flashinfer-python = [
    { url = "https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl", marker = "extra == 'mcore' and extra != 'vllm'" },
    { url = "https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl", marker = "extra == 'sglang' and extra != 'mcore' and extra != 'vllm'" },
]

[project.optional-dependencies]
deepspeed = [
    "deepspeed==0.17.6",
]
dev = [
    "ruff==0.11.9",
    "black==24.10.0",
    "pytest>=6.2.5",
    "pytest-asyncio",
    "pre-commit",
    "litellm",
]
docs = [
    "sphinx>=7.0.0",
    "sphinx-rtd-theme>=2.0.0",
    "sphinx-autodoc-typehints>=1.25.0",
    "myst-parser>=2.0.0",
    "sphinx-copybutton>=0.5.0",
    "sphinx-autobuild>=2021.3.14",
]
# TODO(tgriggs): Add `sandboxes` here once available as a package.
sandboxes = [
    "litellm[proxy]>=1.67.5",
]
# Inference-engine extras. Each pins its own torch, so they are declared as
# conflicting in [tool.uv] and cannot be installed together.
vllm = [
    "vllm==0.11.0",
    "flash-attn==2.8.3",
    "torch==2.8.0",
    "flashinfer-python",
    "flashinfer-jit-cache",
    "torchvision",
]
sglang = [
    "sglang[srt,openai,torch_memory_saver]==0.4.8.post1",  # 0.4.9.post1 causes non-colocate weight broadcast to hang
    "flashinfer-python",
    "flash-attn==2.8.3",
    "torch==2.7.1",
    "torchvision",
]
mcore = [
    # Megatron-based stack with its own vllm/torch pins.
    #
    # flash-attn is pinned directly below (there is no flash-attn entry in
    # [tool.uv.sources]) to 2.7.4.post1, because TransformerEngine 2.5.0 was
    # documented as requiring flash-attn <= 2.7.4.post1.
    # NOTE(review): re-confirm that bound for the TransformerEngine 2.7.0 pin below.
    # Matching prebuilt wheel:
    # https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.7cxx11abiFALSE-cp312-cp312-linux_x86_64.whl
    #
    # Fallback if uv cannot build transformer-engine (e.g. on a single node):
    # build it manually first. Use the same version as the pin below so uv does
    # not try to rebuild it.
    #   uv pip install "torch==2.7.1"
    #   uv pip install "nvidia-cudnn-cu12>=9.3"
    #   export CUDNN_PATH="$(python -c 'import inspect, nvidia.cudnn as c, os; print(os.path.dirname(inspect.getfile(c)))')"
    #   export CPATH="$CUDNN_PATH/include:${CPATH:-}"
    #   export LD_LIBRARY_PATH="$CUDNN_PATH/lib:${LD_LIBRARY_PATH:-}"
    #   uv pip install --no-build-isolation "transformer_engine[pytorch]==2.7.0" --verbose
    "transformer-engine[pytorch]==2.7.0",
    "flash-attn==2.7.4.post1",
    "vllm==0.10.1.1",
    "torch==2.7.1",
    "flashinfer-python",
    "torchvision",
    "mbridge==0.15.1",
    "megatron-bridge==0.1.0rc4",
    "megatron-core==0.14.0",
]
flashrl = [
    # NOTE: Custom vLLM wheel must be installed separately.
    # See examples/flash_rl/README.md for installation instructions.
    "flash-attn==2.8.3",
    "torch==2.7.0",
    "flashinfer-python",
    "torchvision",
]
miniswe = [
    # NOTE (sumanthrh): Needs to be a commit after https://github.com/SWE-agent/mini-swe-agent/commit/4f5d445e99d13b5482478c23508bf2fbf7c0670c
    "mini-swe-agent>=1.12.0",
    "litellm",
]


# CUDA 12.8 PyTorch wheels. `explicit = true` keeps uv from consulting this
# index for anything except packages that name it in [tool.uv.sources].
[[tool.uv.index]]
name = "pytorch-cu128"
explicit = true
url = "https://download.pytorch.org/whl/cu128"

# CUDA 12.8 FlashInfer wheels; likewise only used where a source selects it.
[[tool.uv.index]]
name = "flashinfer-cu128"
explicit = true
url = "https://flashinfer.ai/whl/cu128"

[tool.setuptools]
include-package-data = true

[tool.setuptools.packages.find]
include = ["skyrl_train*"]

# `version` and `readme` are declared statically in [project]. setuptools only
# applies a [tool.setuptools.dynamic] entry to fields listed in `project.dynamic`,
# so such a table would be silently ignored here. To derive the version from
# `skyrl_train.__version__` instead, remove `version` from [project], add
# `dynamic = ["version"]` there, and add:
#   [tool.setuptools.dynamic]
#   version = {attr = "skyrl_train.__version__"}

# Ship default config files in the wheel
[tool.setuptools.package-data]
skyrl_train = ["config/**"]

[tool.pytest.ini_options]
testpaths = ["tests"]
# Verbose test names, and `-s` disables output capturing so prints show live.
addopts = "-v -s"



[tool.isort]
# Start from black-compatible defaults, at the same 120-column width that
# [tool.black] uses.
profile = "black"
line_length = 120
# Explicit wrapping options (vertical hanging indent with trailing commas).
multi_line_output = 3
force_grid_wrap = 0
use_parentheses = true
include_trailing_comma = true
# Always classify `wandb` as a third-party import.
known_third_party = "wandb"

[tool.black]
line-length = 120
# Format both .py sources and .pyi stubs.
include = '\.pyi?$'
# Extra directories to skip, in addition to black's default `exclude`.
# Because the pattern spans multiple lines, black compiles it as a verbose
# regex, so the `# directories` line inside it is a regex comment.
# NOTE(review): these entries appear to already be covered by black's default
# exclude list; kept for explicitness.
extend-exclude = '''
/(
  # directories
  \.eggs
  | \.git
  | \.hg
  | \.mypy_cache
  | \.tox
  | \.venv
  | build
  | dist
)/
'''

[tool.flake8]
# NOTE: flake8 does not read pyproject.toml on its own; this table only takes
# effect through a plugin such as Flake8-pyproject.
max-line-length = 120
max-doc-length = 120
extend-ignore = [
    # Default ignored errors by flake8
    "E121",
    "E123",
    "E126",
    "E226",
    "E24",
    "E704",
    # Checks that conflict with black's formatting
    "E203",  # whitespace before ':'
    "E231",  # missing whitespace after ','
    "E501",  # line too long
    "W503",  # line break before binary operator
    "W504",  # line break after binary operator
    "W505",  # doc line too long
    # Project preferences
    "F401",  # module imported but unused
    "E741",  # do not use variables named 'l', 'O', or 'I'
    "W605",  # invalid escape sequence 'x' (LaTeX within docs)
]

[tool.ruff.lint]
# F722 ("syntax error in forward annotation") misfires on jaxtyping's
# string-based shape annotations, so it is disabled project-wide.
ignore = ["F722"]
