---
name: SkyRL-tx-CPU

# Run on pushes to main and on pull requests, but only when the tx package
# or this workflow file itself changes. It can also be started manually.
on:  # yamllint disable-line rule:truthy
  push:
    branches:
    - main
    paths:
    - 'skyrl-tx/**'
    - '.github/workflows/cpu_skyrl_tx.yaml'
  pull_request:
    paths:
    - 'skyrl-tx/**'
    - '.github/workflows/cpu_skyrl_tx.yaml'
  workflow_dispatch:

# Least-privilege GITHUB_TOKEN: checkout only needs read access to the repo.
# Job results are reported as checks by GitHub Actions itself, independent
# of the token's `checks` scope, and no step in this workflow calls the
# Checks API. Re-add `checks: write` only if a step starts publishing its
# own check runs (e.g. a test-report action).
permissions:
  contents: read

# Runs of this workflow on the same ref share one concurrency group; when a
# newer commit arrives, the still-running older run is cancelled.
concurrency:
  cancel-in-progress: true
  group: 'skyrl-tx-${{ github.workflow }}-${{ github.ref }}'

jobs:
  skyrl_tx_tests:
    runs-on: ubuntu-latest
    # Bound the job instead of relying on GitHub's 360-minute default, so a
    # hung model download, training run or benchmark fails quickly. Raise this
    # if the suite legitimately grows past it.
    timeout-minutes: 60
    defaults:
      run:
        shell: bash
        working-directory: ./skyrl-tx
    steps:
    - uses: actions/checkout@v4
    # Every later step invokes `uv`/`uvx`, so they must be on PATH after this.
    - name: Install uv
      run: |
        curl -LsSf https://astral.sh/uv/install.sh | sh
    # - name: Check if reference docs are up to date
    #   run: |
    #     uv run --extra dev typer tx/run/main.py utils docs --name tx --output docs/reference.md && git diff --exit-code docs/reference.md
    # - name: Test docs
    #   run: |
    #     uv run --extra dev mkdocs build --strict
    - name: Run lint
      run: |
        uvx ruff check
    # - name: Run type checking
    #   run: |
    #     uv run --extra tinker --extra dev ty check
    - name: Run pytest
      run: |
        uv run --extra tinker --extra dev pytest --forked -s tests
    # Training outputs go to /tmp. The chat fine-tuning step below passes /tmp
    # as --load-checkpoint-path, so keep these steps in this order.
    # NOTE(review): --tp-size 2 on a CPU runner presumably relies on tx
    # exposing >= 2 JAX host devices (e.g. via XLA_FLAGS) — confirm.
    - name: Run a single training step
      run: |
        uv run tx train --model pcmoritz/qwen3-tiny-test --dataset mahiatlinux/TinyStories-GPT4-V2-50K-SUBSET --output-dir /tmp --batch-size 2 --max-steps 1 --optimizer-args '{"learning_rate": 0.002, "weight_decay": 0.1}'
        uv run tx train --model pcmoritz/qwen3-tiny-test --dataset mahiatlinux/TinyStories-GPT4-V2-50K-SUBSET --output-dir /tmp --batch-size 2 --max-steps 1 --optimizer-args '{"learning_rate": 0.002, "weight_decay": 0.1}' --tp-size 2
    - name: Run a single fine-tuning step on a chat dataset
      run: |
        uv run --with jinja2 tx train --model pcmoritz/qwen3-tiny-test --dataset smangrul/ultrachat-feedback-10k-chatml --output-dir /tmp --batch-size 2 --max-steps 1 --loader tx.loaders.chat --load-checkpoint-path /tmp
    # The tiny MoE model is first downloaded into /tmp/qwen3_moe, and that local
    # directory is then passed as --load-checkpoint-path for the training run.
    - name: Run a single fine-tuning step with Qwen3 MoE
      run: |
        uv run --with huggingface_hub hf download trl-internal-testing/tiny-Qwen3MoeForCausalLM --local-dir /tmp/qwen3_moe
        uv run --with jinja2 tx train --model trl-internal-testing/tiny-Qwen3MoeForCausalLM --dataset smangrul/ultrachat-feedback-10k-chatml --output-dir /tmp --batch-size 2 --max-steps 1 --loader tx.loaders.chat --load-checkpoint-path /tmp/qwen3_moe
    # wandb runs in offline mode with a dummy API key, so the tracker code path
    # is exercised without a real W&B account.
    # NOTE(review): the run name "Qwen3-8B" does not match the tiny test model
    # actually trained here.
    - name: Test experiment tracker integration
      run: |
        WANDB_MODE=offline WANDB_API_KEY=dummy uv run --with wandb tx train --model pcmoritz/qwen3-tiny-test --dataset mahiatlinux/TinyStories-GPT4-V2-50K-SUBSET --output-dir /tmp --batch-size 2 --max-steps 1 --tracker wandb --tracker-args '{"name": "Qwen3-8B", "project": "tx"}'
    - name: Run engine benchmarks
      run: |
        uv run --extra tinker --extra dev python benchmarks/benchmark_engine.py --base-model trl-internal-testing/tiny-Qwen3ForCausalLM --backend-config '{"max_lora_adapters": 3, "max_lora_rank": 1}' --num-warmup-steps 1 --num-steps 1 --num-requests 1 --seq-len 8 --sample-max-tokens 16
