# Primary CI workflow: CPU checks, GPU tests, and tagged releases.
name: Main

# Runs of this workflow on the same ref share a single concurrency group.
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}

on:
  # Pushes to main and version tags (the release job keys off the tag ref).
  push:
    branches: [main]
    tags: ['v*.*.*']
  # Pull requests targeting either of these base branches.
  pull_request:
    branches: [main, Torch2]

env:
  TOKENIZERS_PARALLELISM: 'false'
  PYTHONPATH: ./
  # Change this to invalidate existing cache.
  CACHE_PREFIX: v2

jobs:
  # CPU-only checks. The base matrix runs Lint on both Python versions; each
  # `include` entry sets a different `task`, so it is added as its own
  # Python 3.10 combination rather than being merged into the Lint jobs.
  checks:
    name: ${{ matrix.task.name }} (py ${{ matrix.python }})
    runs-on: [ubuntu-latest]
    timeout-minutes: 10
    strategy:
      # Let every task report its own result instead of cancelling siblings.
      fail-fast: false
      matrix:
        python: ['3.8', '3.10']
        task:
          - name: Lint
            run: |
              ruff check .

        include:
          - python: '3.10'
            task:
              name: Test
              run: pytest -v --color=yes --durations=5 tests/

          - python: '3.10'
            task:
              name: Type check
              run: mypy .

          - python: '3.10'
            task:
              name: Build
              run: |
                python -m build

          - python: '3.10'
            task:
              name: Style
              run: |
                isort --check .
                black --check .

          - python: '3.10'
            task:
              name: Data pipeline
              run: |
                python scripts/prepare_memmap_dataset.py test_fixtures/*.json.gz -o /tmp/c4-sample.npy --validate --ack-deprecated

    steps:
      - uses: actions/checkout@v3

      # Repository-local action; later steps activate `.venv`, which is
      # presumably created here (confirm against the action definition).
      - name: Setup Python environment
        uses: ./.github/actions/setup-venv
        with:
          python-version: ${{ matrix.python }}
          cache-prefix: ${{ env.CACHE_PREFIX }}

      # The primary key ends in the commit SHA, so an exact hit only happens
      # when re-running the same commit. The restore keys fall back first to
      # the same ref, then to any ref with matching dependency files.
      - name: Restore mypy cache
        if: matrix.task.name == 'Type check'
        uses: actions/cache@v3
        with:
          path: .mypy_cache
          key: mypy-${{ env.CACHE_PREFIX }}-${{ runner.os }}-${{ matrix.python }}-${{ hashFiles('*requirements.txt', '*pyproject.toml') }}-${{ github.ref }}-${{ github.sha }}
          restore-keys: |
            mypy-${{ env.CACHE_PREFIX }}-${{ runner.os }}-${{ matrix.python }}-${{ hashFiles('*requirements.txt', '*pyproject.toml') }}-${{ github.ref }}
            mypy-${{ env.CACHE_PREFIX }}-${{ runner.os }}-${{ matrix.python }}-${{ hashFiles('*requirements.txt', '*pyproject.toml') }}

      # Runs the command string defined for this matrix entry inside the venv.
      - name: ${{ matrix.task.name }}
        run: |
          . .venv/bin/activate
          ${{ matrix.task.run }}

      # Only the Build task produces `dist`; the release job downloads it.
      # upload-artifact and download-artifact must stay on the same major
      # version: v3 has been retired on github.com, and artifacts uploaded
      # with v4 can only be fetched with download-artifact v4.
      - name: Upload package distribution files
        if: matrix.task.name == 'Build'
        uses: actions/upload-artifact@v4
        with:
          name: package
          path: dist

      # Runs even when an earlier step failed.
      - name: Clean up
        if: always()
        run: |
          . .venv/bin/activate
          pip uninstall -y olmo

  # Submits the GPU-marked test suite to Beaker as a single-GPU task.
  gpu_tests:
    name: GPU Tests
    runs-on: ubuntu-latest
    timeout-minutes: 8
    env:
      BEAKER_TOKEN: ${{ secrets.BEAKER_TOKEN }}
      BEAKER_IMAGE: olmo-torch2-test
      BEAKER_WORKSPACE: ai2/llm-testing
    steps:
      # On pull requests, test the PR head commit; on any other event use
      # GITHUB_SHA. Exactly one of the two steps below runs.
      - name: Determine current commit SHA (pull request)
        if: github.event_name == 'pull_request'
        run: |
          echo "COMMIT_SHA=${{ github.event.pull_request.head.sha }}" >> $GITHUB_ENV

      - name: Determine current commit SHA (push)
        if: github.event_name != 'pull_request'
        run: |
          echo "COMMIT_SHA=$GITHUB_SHA" >> $GITHUB_ENV

      # Skipped when the BEAKER_TOKEN secret is not available to this run.
      # NOTE(review): GITHUB_TOKEN is passed as a plain env var value inside
      # the Beaker spec below -- confirm this exposure is acceptable.
      - name: GPU Tests
        uses: allenai/beaker-run-action@v1.2
        if: env.BEAKER_TOKEN != ''
        with:
          spec: |
            version: v2
            description: GPU Tests
            budget: ai2/oe-training
            tasks:
              - name: tests
                image:
                  beaker: ${{ env.BEAKER_IMAGE }}
                context:
                  priority: normal
                resources:
                  gpuCount: 1
                constraints:
                  cluster:
                    - ai2/general-cirrascale
                    - ai2/general-cirrascale-a100-80g-ib
                    - ai2/allennlp-cirrascale
                envVars:
                  - name: COMMIT_SHA
                    value: ${{ env.COMMIT_SHA }}
                  - name: GITHUB_TOKEN
                    value: ${{ secrets.GITHUB_TOKEN }}
                  - name: CUDA_LAUNCH_BLOCKING
                    value: "1"
                  - name: CUBLAS_WORKSPACE_CONFIG
                    value: ":16:8"
                  - name: TOKENIZERS_PARALLELISM
                    value: "false"
                command: ["/entrypoint.sh", "pytest", "-v", "-m", "gpu", "tests/", "-k", "not hf_olmo"]
                result:
                  path: /unused
          token: ${{ env.BEAKER_TOKEN }}
          workspace: ${{ env.BEAKER_WORKSPACE }}

  # Publishes to PyPI and creates a GitHub release. Runs only for tag refs,
  # and only after every `checks` matrix job has succeeded.
  release:
    name: Release
    runs-on: ubuntu-latest
    needs: [checks]
    if: startsWith(github.ref, 'refs/tags/')
    steps:
      # fetch-depth 0 fetches the full history instead of a shallow clone.
      - uses: actions/checkout@v3
        with:
          fetch-depth: 0

      - name: Setup Python environment
        uses: ./.github/actions/setup-venv
        with:
          python-version: '3.10'
          cache-prefix: ${{ env.CACHE_PREFIX }}

      # RELEASE_VERSION drops the `refs/tags/v` prefix; TAG keeps the `v`.
      - name: Prepare environment
        run: |
          echo "RELEASE_VERSION=${GITHUB_REF#refs/tags/v}" >> $GITHUB_ENV
          echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV

      # Must use the same major version as the upload step in `checks`.
      - name: Download package distribution files
        uses: actions/download-artifact@v4
        with:
          name: package
          path: dist

      # The notes file is written next to the workspace directory
      # (`<workspace>-RELEASE_NOTES.md`), not inside the checkout.
      - name: Generate release notes
        run: |
          . .venv/bin/activate
          python scripts/release_notes.py > ${{ github.workspace }}-RELEASE_NOTES.md

      # Credentials are supplied through twine's environment variables so the
      # secret is never expanded into the script text or the command line.
      - name: Publish package to PyPI
        env:
          TWINE_USERNAME: __token__
          TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
        run: |
          . .venv/bin/activate
          twine upload dist/*

      # Tags containing `rc` are published as pre-releases.
      - name: Publish GitHub release
        uses: softprops/action-gh-release@v1
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        with:
          body_path: ${{ github.workspace }}-RELEASE_NOTES.md
          prerelease: ${{ contains(env.TAG, 'rc') }}
          files: |
            dist/*
