# Manually dispatched workflow that runs the slow test suite (RUN_SLOW) for a
# chosen branch inside a CUDA Docker image on a GPU runner.
name: Check running SLOW tests from a PR (only GPU)

on:
  workflow_dispatch:
    inputs:
      docker_image:
        default: 'diffusers/diffusers-pytorch-cuda'
        description: 'Name of the Docker image'
        required: true
      # NOTE: the branch is resolved in this repository; a branch that only
      # exists on a fork cannot be checked out by this workflow.
      branch:
        description: 'PR Branch to test on'
        required: true
      # Validated by the job before use: must start with tests/models or
      # tests/pipelines.
      test:
        description: 'Tests to run (e.g.: `tests/models`).'
        required: true

# Flag-like and numeric values are quoted so YAML never re-types them:
# GitHub exports env entries as strings anyway, but a plain `yes` is a boolean
# under YAML 1.1 tooling (yamllint `truthy`, PyYAML) and `8` is an integer.
env:
  DIFFUSERS_IS_CI: "yes"
  IS_GITHUB_CI: "1"
  HF_HOME: /mnt/cache
  OMP_NUM_THREADS: "8"
  MKL_NUM_THREADS: "8"
  PYTEST_TIMEOUT: "600"
  RUN_SLOW: "yes"

jobs:
  run_tests:
    name: "Run a test on our runner from a PR"
    runs-on:
      group: aws-g4dn-2xlarge
    # The checked-out branch is arbitrary code that gets installed and run, so
    # the job token is limited to read access on repository contents.
    permissions:
      contents: read
    container:
      image: ${{ github.event.inputs.docker_image }}
      # NOTE(review): docker parses a bare number in --gpus as a GPU count;
      # confirm GPU visibility on this runner or consider `--gpus device=0`.
      options: --gpus 0 --privileged --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/

    steps:
      # Restrict the `test` input to paths under tests/models or
      # tests/pipelines before anything is checked out or executed.
      - name: Validate test files input
        id: validate_test_files
        # `[[ ... ]]` is bash-only. Under a POSIX sh fallback each `[[` would
        # fail as "command not found" inside an `if` condition (which `-e`
        # ignores), silently skipping every check below, so pin the shell.
        shell: bash
        env:
          PY_TEST: ${{ github.event.inputs.test }}
        run: |
          if [[ ! "$PY_TEST" =~ ^tests/ ]]; then
            echo "Error: The input string must start with 'tests/'."
            exit 1
          fi

          if [[ ! "$PY_TEST" =~ ^tests/(models|pipelines) ]]; then
            echo "Error: The input string must contain either 'models' or 'pipelines' after 'tests/'."
            exit 1
          fi

          if [[ "$PY_TEST" == *";"* ]]; then
            echo "Error: The input string must not contain ';'."
            exit 1
          fi

          # Parent-directory segments would let e.g. tests/models/../../other
          # escape the allowed directories while still matching the prefix.
          if [[ "$PY_TEST" == *".."* ]]; then
            echo "Error: The input string must not contain '..'."
            exit 1
          fi
          echo "$PY_TEST"

      # A workflow_dispatch payload carries no `pull_request` object, so the
      # branch is always looked up in this repository (checkout's default).
      - name: Checkout PR branch
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event.inputs.branch }}
          # Keep the job token out of .git/config, where the branch's own code
          # (installed and executed in the steps below) could read it.
          persist-credentials: false

      - name: Install pytest
        run: |
          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
          # `export` only lasts for this step; GITHUB_PATH carries the venv's
          # bin directory into later steps so `pytest` resolves to this venv.
          echo "/opt/venv/bin" >> "$GITHUB_PATH"
          # Quoted so the shell cannot glob-expand the bracket expression.
          python -m uv pip install -e ".[quality,test]"
          python -m uv pip install peft

      # The input travels through the environment and is quoted, so pytest
      # receives it as one argument and the shell never parses its contents.
      - name: Run tests
        env:
          PY_TEST: ${{ github.event.inputs.test }}
        run: |
          pytest "$PY_TEST"
