name: Check running SLOW tests from a PR (only GPU)

# Manually triggered via workflow_dispatch: the dispatcher chooses the Docker
# image, the PR to check out, and the test path to run.
on:
  workflow_dispatch:
    inputs:
      docker_image:
        default: 'diffusers/diffusers-pytorch-cuda'
        description: 'Name of the Docker image'
        required: true
        type: string
      pr_number:
        description: 'PR number to test on'
        required: true
        # Typed as a number so the dispatch form rejects non-numeric input;
        # the value is interpolated into a git ref (refs/pull/<n>/head).
        type: number
      test:
        description: 'Tests to run (e.g.: `tests/models`).'
        required: true
        type: string

# Environment shared by every step of the workflow. Values are quoted so that
# generic YAML 1.1 tooling does not re-type `yes` as a boolean or the thread
# counts / timeout as integers; environment variables are strings.
env:
  DIFFUSERS_IS_CI: "yes"
  IS_GITHUB_CI: "1"
  HF_HOME: /mnt/cache
  OMP_NUM_THREADS: "8"
  MKL_NUM_THREADS: "8"
  PYTEST_TIMEOUT: "600"
  RUN_SLOW: "yes"

jobs:
  run_tests:
    name: "Run a test on our runner from a PR"
    runs-on:
      group: aws-g4dn-2xlarge
    # The PR's code is executed by this job, so restrict the job token to the
    # minimum needed to check out the repository.
    permissions:
      contents: read
    # NOTE(review): the image comes from the dispatcher's input and runs with
    # --privileged on a self-hosted runner; consider restricting allowed images.
    container:
      image: ${{ inputs.docker_image }}
      options: --gpus all --privileged --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/

    steps:
      # Reject malformed inputs before any PR code is checked out or executed.
      - name: Validate test files input
        id: validate_test_files
        env:
          PY_TEST: ${{ inputs.test }}
          PR_NUMBER: ${{ inputs.pr_number }}
        run: |
          # The PR number is interpolated into a git ref, so allow digits only.
          if [[ ! "$PR_NUMBER" =~ ^[0-9]+$ ]]; then
            echo "Error: The PR number must contain only digits." >&2
            exit 1
          fi

          if [[ ! "$PY_TEST" =~ ^tests/ ]]; then
            echo "Error: The input string must start with 'tests/'." >&2
            exit 1
          fi

          # Require a complete directory name (followed by '/' or end of input),
          # so that e.g. 'tests/modelsX' is not accepted.
          allowed_dirs='^tests/(models|pipelines|lora)(/|$)'
          if [[ ! "$PY_TEST" =~ $allowed_dirs ]]; then
            echo "Error: The input string must contain either 'models', 'pipelines', or 'lora' after 'tests/'." >&2
            exit 1
          fi

          # '..' would let the path escape the allowed directories
          # (e.g. 'tests/models/../../examples').
          if [[ "$PY_TEST" == *".."* ]]; then
            echo "Error: The input string must not contain '..'." >&2
            exit 1
          fi

          if [[ "$PY_TEST" == *";"* ]]; then
            echo "Error: The input string must not contain ';'." >&2
            exit 1
          fi
          echo "$PY_TEST"
        shell: bash -e {0}

      - name: Checkout PR branch
        uses: actions/checkout@v4
        with:
          ref: refs/pull/${{ inputs.pr_number }}/head
          # PR code runs in later steps; do not leave the job token in .git/config.
          persist-credentials: false

      - name: Install dependencies
        run: |
          uv pip install -e ".[quality]"
          uv pip install peft

      - name: Run tests
        env:
          PY_TEST: ${{ inputs.test }}
        run: |
          # Passed through an env var and quoted, so the input reaches pytest as
          # a single argument and is never parsed by the shell.
          pytest "$PY_TEST"
