name: wheels

on:
  # Reusable workflow: a caller invokes it once per wheel configuration.
  workflow_call:
    secrets:
      # Exported as TWINE_PASSWORD to the twine upload steps.
      twine_password:
        required: true
    inputs:
      # Exported as TWINE_USERNAME to the twine upload steps.
      twine_username:
        type: string
        required: true
      # Runner label; also decides whether the manylinux container is used.
      os:
        type: string
        required: true
      # Python version used to select the interpreter for the build.
      python:
        type: string
        required: true
      # Exact torch release pinned into requirements.txt before building.
      torch_version:
        type: string
        required: true
        description: 'Example: 1.13.1'
      # Selects the CUDA toolkit to install and the PyTorch wheel index.
      cuda_short_version:
        type: string
        required: true
        description: 'Example: 117 for 11.7'
      # When true, the built wheel is uploaded with twine.
      publish:
        type: boolean
        required: true
      # When true (together with `publish`), a source distribution is
      # built and uploaded as well.
      sdist:
        type: boolean
        required: true
      # Optional content written verbatim to ~/.pypirc.
      pypirc:
        type: string
        required: false
      # Value forwarded to TORCH_CUDA_ARCH_LIST.
      arch_list:
        type: string
        required: false
        default: '5.0+PTX 6.0 6.1 7.0 7.5 8.0 8.6'

# this yaml file can be cleaned up using yaml anchors, but they're not supported in github actions yet
# https://github.com/actions/runner/issues/1182

env:
  # Architectures to compile for, taken from the `arch_list` input.
  # The original note here reads: "you need at least cuda 5.0 for some of
  # the stuff compiled here".
  # NOTE(review): "5.0" presumably refers to the GPU compute capability
  # rather than a CUDA toolkit release - confirm.
  TORCH_CUDA_ARCH_LIST: ${{ inputs.arch_list }}
  # Single build job; the build will crash otherwise.
  MAX_JOBS: '1'
  # Otherwise distutils will complain on windows about multiple versions
  # of msvc.
  DISTUTILS_USE_SDK: '1'
  XFORMERS_BUILD_TYPE: 'Release'
  # Workflow-wide default; the upload steps override it with the
  # `twine_username` input.
  TWINE_USERNAME: __token__
  XFORMERS_PACKAGE_FROM: 'wheel-${{ github.ref_name }}'

jobs:
  # Builds one wheel for a single (os, python, torch, cuda) combination,
  # stores it as a workflow artifact and, when requested through the
  # `publish` / `sdist` inputs, uploads it with twine.
  build_wheels:
    name: ${{ inputs.os }}-py${{ inputs.python }}-torch${{ inputs.torch_version }}+cu${{ inputs.cuda_short_version }}
    runs-on: ${{ inputs.os }}
    env:
      # Alias for the python interpreter used by every step.
      # When `os` contains 'ubuntu' this is `python<inputs.python>`;
      # otherwise it is just 'python3', because windows does not have a
      # per-version binary.
      PY: python${{ contains(inputs.os, 'ubuntu') && inputs.python || '3' }}

    # Ubuntu runners build inside the manylinux2014 image; for any other
    # `os` the expression yields null, i.e. no container.
    container: ${{ contains(inputs.os, 'ubuntu') && 'quay.io/pypa/manylinux2014_x86_64' || null }}
    # 360 minutes = 6 hours.
    timeout-minutes: 360
    # Every `run` step uses bash unless it sets its own `shell`.
    defaults:
      run:
        shell: bash
    steps:
      # Translate the short CUDA version input (e.g. "117") into the full
      # toolkit version and the Linux installer URL, and publish both as
      # step outputs: CUDA_VERSION and CUDA_INSTALL_SCRIPT.
      - id: cuda_info
        shell: python
        run: |
          import os
          import sys
          print(sys.version)
          # https://github.com/Jimver/cuda-toolkit/blob/master/src/links/linux-links.ts
          cuda_releases = {
            "118": ("11.8.0", "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"),
            "117": ("11.7.1", "https://developer.download.nvidia.com/compute/cuda/11.7.1/local_installers/cuda_11.7.1_515.65.01_linux.run"),
            "116": ("11.6.2", "https://developer.download.nvidia.com/compute/cuda/11.6.2/local_installers/cuda_11.6.2_510.47.03_linux.run"),
          }
          short_version = "${{ inputs.cuda_short_version }}"
          if short_version not in cuda_releases:
            # Fail with a readable message instead of a bare KeyError traceback.
            # Plain concatenation keeps this valid on old interpreters too.
            sys.exit(
              "Unsupported cuda_short_version '" + short_version
              + "'; expected one of: " + ", ".join(sorted(cuda_releases))
            )
          full_version, install_script = cuda_releases[short_version]
          # Append to the output file: "r+" would start writing at offset 0
          # and clobber anything already recorded there.
          with open(os.environ['GITHUB_OUTPUT'], "a") as fp:
            fp.write("CUDA_VERSION=" + full_version + "\n")
            fp.write("CUDA_INSTALL_SCRIPT=" + install_script + "\n")
      # Check out the repository, including all nested submodules, into the
      # workspace root.
      - name: Recursive checkout
        uses: actions/checkout@v3
        with:
          submodules: recursive
          path: "."
          # 0 = fetch the complete history, so that tags are available.
          fetch-depth: 0 # for tags

      # Write the caller-provided twine configuration to ~/.pypirc and echo
      # it to the log. Skipped when the `pypirc` input is empty.
      - name: Setup twine config
        if: inputs.pypirc
        env:
          # Hand the content over through the environment instead of pasting
          # it into the script: quotes, `$` or backticks inside the config
          # would otherwise be interpreted by the shell.
          PYPIRC_CONTENT: ${{ inputs.pypirc }}
        run: |
          printf '%s\n' "$PYPIRC_CONTENT" > ~/.pypirc
          cat ~/.pypirc

      # Compute the wheel version with packaging/compute_wheel_version.py and
      # export it as BUILD_VERSION to all following steps.
      - name: Define version
        run: |
          set -Eeuo pipefail
          # Tell git to treat every directory as safe.
          # NOTE(review): presumably needed because the checkout is owned by a
          # different user inside the container - confirm.
          git config --global --add safe.directory "*"
          $PY -m pip install packaging
          version="$($PY packaging/compute_wheel_version.py)"
          # Quote the file path so it is never word-split or glob-expanded.
          echo "BUILD_VERSION=$version" >> "${GITHUB_ENV}"
          cat "${GITHUB_ENV}"

      # Pin torch to the requested release for this build: every line of
      # requirements.txt that contains "torch" is deleted, an exact
      # `torch == <torch_version>` pin is appended, and the resulting file is
      # printed to the log.
      - name: Setup proper pytorch dependency in "requirements.txt"
        run: |
          sed -i '/torch/d' ./requirements.txt
          echo "torch == ${{ inputs.torch_version }}" >> ./requirements.txt
          cat ./requirements.txt

      # Windows runners only: prepare the machine through the repository's
      # local composite action, handing it the full CUDA version resolved by
      # the `cuda_info` step and the requested python version.
      - name: (Windows) Setup Runner
        if: runner.os == 'Windows'
        uses: ./.github/actions/setup-windows-runner
        with:
          python: ${{ inputs.python }}
          cuda: ${{ steps.cuda_info.outputs.CUDA_VERSION }}

      # Install the build and upload tooling plus the project requirements.
      # The PyTorch wheel index matching the CUDA version is added as an
      # extra package index. The folded scalar below collapses into one
      # single-line command.
      - name: Install dependencies
        run: >-
          $PY -m pip install
          wheel setuptools twine
          -r requirements.txt
          --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cuda_short_version }}

      # Linux runners only: install a few system tools, then download and run
      # the CUDA installer selected by the `cuda_info` step (toolkit only,
      # non-interactive) and remove the installer afterwards.
      - name: (Linux) install cuda
        if: runner.os == 'Linux'
        run: |
          set -e
          yum install wget git prename -y
          wget -q "${{ steps.cuda_info.outputs.CUDA_INSTALL_SCRIPT }}" -O cuda.run
          sh ./cuda.run --silent --toolkit
          rm ./cuda.run

      # Build the binary wheel into dist/.
      - name: Build wheel
        env:
          # Extra bdist_wheel arguments: the manylinux platform tag when `os`
          # contains 'ubuntu', nothing otherwise.
          PLAT_ARG: ${{ contains(inputs.os, 'ubuntu') && '--plat-name manylinux2014_x86_64' || '' }}
        # $PLAT_ARG is intentionally unquoted: it must split into two
        # arguments when set and disappear entirely when empty.
        run: $PY setup.py bdist_wheel -d dist/ -k $PLAT_ARG

      # Store the built wheel(s) as a workflow artifact. This step has no
      # `if`, so it does not depend on the `publish` input. The artifact name
      # is the same string as the job name.
      - uses: actions/upload-artifact@v3
        with:
          name: ${{ inputs.os }}-py${{ inputs.python }}-torch${{ inputs.torch_version }}+cu${{ inputs.cuda_short_version }}
          path: dist/*.whl

      # Publish the wheel(s) from dist/ with twine; only runs when the
      # `publish` input is true.
      - name: Upload wheel to PyPi
        if: inputs.publish
        env:
          # Step-level credentials; TWINE_USERNAME here takes precedence over
          # the workflow-level default.
          TWINE_PASSWORD: ${{ secrets.twine_password }}
          TWINE_USERNAME: ${{ inputs.twine_username }}
        run: $PY -m twine upload dist/*.whl

      # Build and publish the source distribution; only runs when both the
      # `publish` and `sdist` inputs are true. The wheel output directory is
      # removed and requirements.txt is restored from HEAD before the sdist
      # is built into sdist/ and uploaded with twine.
      - name: Upload source distribution to PyPi
        if: inputs.publish && inputs.sdist
        env:
          # Step-level credentials; TWINE_USERNAME here takes precedence over
          # the workflow-level default.
          TWINE_PASSWORD: ${{ secrets.twine_password }}
          TWINE_USERNAME: ${{ inputs.twine_username }}
        run: |
          rm -rf dist/
          # unpin pytorch version
          git checkout HEAD -- ./requirements.txt
          $PY setup.py sdist -d sdist/
          $PY -m twine upload sdist/*
# Note: it might be helpful to have additional steps that test if the built wheels actually work
