# Workflow title as shown in the Actions UI.
name: Self-hosted runner (benchmark)

# Triggers: pushes to `main`, plus pull-request activity of the listed types.
# `labeled` is listed so that adding a label to an existing PR triggers the
# workflow.
on:
  push:
    branches:
      - main
  pull_request:
    types:
      - opened
      - labeled
      - reopened
      - synchronize

# One concurrency group per workflow + PR head branch. When `github.head_ref`
# is empty (non-pull-request events) the group falls back to the unique
# `github.run_id`. A newer run in the same group cancels the one in progress.
concurrency:
  group: "${{ github.workflow }}-${{ github.head_ref || github.run_id }}"
  cancel-in-progress: true

# Environment variables shared by every job and step in this workflow.
env:
  # HF_HOME is the root directory for Hugging Face caches (see the
  # huggingface_hub environment-variable documentation).
  HF_HOME: "/mnt/cache"

jobs:
  # Runs once per runner group listed in the matrix, inside a GPU container.
  benchmark:
    name: Benchmark
    # Gate: pull requests must carry the `run-benchmark` label; pushes must
    # target `main`.
    if: |
      (github.event_name == 'pull_request' && contains(github.event.pull_request.labels.*.name, 'run-benchmark')) ||
      (github.event_name == 'push' && github.ref == 'refs/heads/main')
    strategy:
      matrix:
        group:
          - aws-g5-4xlarge-cache
          - aws-p4d-24xlarge-plus
    # Each matrix entry selects the runner group of the same name.
    runs-on:
      group: "${{ matrix.group }}"
    container:
      image: huggingface/transformers-pytorch-gpu
      options: --gpus all --privileged --ipc host
    steps:
      # Check out the PR head commit for pull requests, otherwise the commit
      # that triggered the run.
      - name: Get repo
        uses: actions/checkout@v4
        with:
          ref: "${{ github.event.pull_request.head.sha || github.sha }}"

      # Install the PostgreSQL client development package and client tools
      # used by the later database steps.
      - name: Install libpq-dev & psql
        run: |
          # apt-get rather than apt: apt is an interactive front end whose CLI
          # is not guaranteed stable, and it warns when used in scripts.
          apt-get update
          apt-get install -y libpq-dev postgresql-client

      # Install the packages listed in benchmark/requirements.txt.
      - name: Install benchmark script dependencies
        run: |
          python3 -m pip install -r benchmark/requirements.txt

      # Replace the preinstalled transformers package with an editable install
      # (with the `torch` extra) built from /transformers.
      # NOTE(review): this installs from /transformers, not from the workspace
      # checkout made by the "Get repo" step — confirm that this is intended.
      - name: Reinstall transformers in edit mode (remove the one installed during docker image build)
        working-directory: /transformers
        # The default `run` shell exits on the first failing command, so the
        # install is skipped if the uninstall fails.
        run: |
          python3 -m pip uninstall -y transformers
          python3 -m pip install -e ".[torch]"

      # Apply benchmark/init_db.sql to the `metrics` database. Connection
      # settings reach psql through the standard PG* environment variables.
      # NOTE(review): without `-v ON_ERROR_STOP=1`, psql keeps going after a
      # failing SQL statement — confirm that this is intended.
      - name: Run database init script
        env:
          PGHOST: ${{ secrets.TRANSFORMERS_BENCHMARKS_PGHOST }}
          PGDATABASE: metrics
          PGUSER: transformers_benchmarks
          PGPASSWORD: ${{ secrets.TRANSFORMERS_BENCHMARKS_PGPASSWORD }}
        run: psql -f benchmark/init_db.sql

      # Run benchmark/llama.py with three positional arguments: branch name,
      # commit SHA, and the (truncated) subject line of the checked-out commit.
      # NOTE(review): PGDATABASE is not set here, unlike in the init step —
      # presumably benchmark/llama.py selects the database itself; confirm.
      - name: Run benchmark
        run: |
          # Let git operate on the workspace checkout so `git show` below works
          # (presumably needed because of file ownership inside the container).
          git config --global --add safe.directory /__w/transformers/transformers
          # Subject line of the checked-out commit, cut to 70 characters.
          commit_msg=$(git show -s --format=%s | cut -c1-70)
          python3 benchmark/llama.py "$BENCHMARK_BRANCH" "$BENCHMARK_COMMIT_ID" "$commit_msg"
        env:
          # Passed through the environment instead of being interpolated into
          # the script text: `github.head_ref` is chosen by the PR author, and
          # expanding it inside `run:` would allow shell command injection.
          BENCHMARK_BRANCH: ${{ github.head_ref || github.ref_name }}
          # PR head SHA for pull requests, the pushed commit SHA otherwise.
          BENCHMARK_COMMIT_ID: ${{ github.event.pull_request.head.sha || github.sha }}
          HF_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
          # Enable this to see debug logs
          # HF_HUB_VERBOSITY: debug
          # TRANSFORMERS_VERBOSITY: debug
          PGHOST: ${{ secrets.TRANSFORMERS_BENCHMARKS_PGHOST }}
          PGUSER: transformers_benchmarks
          PGPASSWORD: ${{ secrets.TRANSFORMERS_BENCHMARKS_PGPASSWORD }}
