# This workflow will:
# - Create a new GitHub release
# - Build wheels for supported architectures
# - Deploy the wheels to the GitHub release
# - Release the source distribution to PyPI
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

name: Build wheels and deploy

# Run only when a tag matching v* is pushed. Tag filters are supported on the
# `push` event; the `create` event has no `tags` filter, so with `create` the
# workflow was not restricted to version tags (it also covers new branches).
on:
  push:
    tags:
      - 'v*'

jobs:

  # Creates the GitHub release for the tag that triggered this workflow. The
  # build_wheels job depends on it and uploads its wheels to that release.
  setup_release:
    name: Create Release
    runs-on: ubuntu-latest
    steps:
      - name: Get the tag version
        id: extract_branch
        # Strip the "refs/tags/" prefix from GITHUB_REF and expose the remainder
        # as the `branch` step output. The value is appended to $GITHUB_OUTPUT
        # because the `::set-output` workflow command is deprecated.
        run: |
          echo "branch=${GITHUB_REF#refs/tags/}" >> "$GITHUB_OUTPUT"
        shell: bash

      - name: Create Release
        id: create_release
        uses: actions/create-release@v1
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        with:
          # The tag name doubles as the release title.
          tag_name: ${{ steps.extract_branch.outputs.branch }}
          release_name: ${{ steps.extract_branch.outputs.branch }}

  build_wheels:
    name: Build Wheel
    needs: setup_release
    runs-on: ${{ matrix.os }}

    strategy:
      # Keep the other matrix combinations running when one of them fails.
      fail-fast: false
      matrix:
        # ubuntu-20.04 is preferred over 22.04 for wider glibc compatibility.
        # The manylinux docker image would be the ideal base, but installing
        # CUDA on manylinux has not been worked out yet.
        # NOTE(review): confirm the ubuntu-20.04 hosted runner label is still
        # available; GitHub retires older runner images over time.
        os:
          - ubuntu-20.04
        python-version:
          - '3.9'
          - '3.10'
          - '3.11'
          - '3.12'
          - '3.13'
        torch-version:
          - '2.1.2'
          - '2.2.2'
          - '2.3.1'
          - '2.4.0'
          - '2.5.1'
          - '2.6.0.dev20241001'
        cuda-version:
          - '11.8.0'
          - '12.3.2'
        # Two wheel flavours are built: with and without the C++11 ABI
        # (-D_GLIBCXX_USE_CXX11_ABI). Official PyTorch wheels currently do not
        # use it, whereas the PyTorch shipped in nvcr images is compiled with it.
        # A wheel built without the C++11 ABI fails to import on nvcr images with
        # "undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs".
        cxx11_abi:
          - 'FALSE'
          - 'TRUE'
        # Unsupported Python/PyTorch pairs, per
        # https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
        exclude:
          # Python 3.12 requires PyTorch >= 2.2
          - python-version: '3.12'
            torch-version: '2.1.2'
          # Python 3.13 requires PyTorch >= 2.5
          - python-version: '3.13'
            torch-version: '2.1.2'
          - python-version: '3.13'
            torch-version: '2.2.2'
          - python-version: '3.13'
            torch-version: '2.3.1'
          - python-version: '3.13'
            torch-version: '2.4.0'

    steps:
      # Fetch the repository sources into the runner workspace.
      - name: Checkout
        uses: actions/checkout@v4

      # Select the Python interpreter for this matrix entry.
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '${{ matrix.python-version }}'

      - name: Set CUDA and PyTorch versions
        # Publishes condensed version identifiers to later steps via $GITHUB_ENV:
        #   MATRIX_CUDA_VERSION   - CUDA major and minor, no dot (11.8.0 -> 118)
        #   MATRIX_TORCH_VERSION  - torch major.minor          (2.5.1  -> 2.5)
        #   WHEEL_CUDA_VERSION    - CUDA major only            (11.8.0 -> 11)
        #   MATRIX_PYTHON_VERSION - Python major and minor, no dot (3.10 -> 310)
        run: |
          cuda_version='${{ matrix.cuda-version }}'
          torch_version='${{ matrix.torch-version }}'
          python_version='${{ matrix.python-version }}'
          {
            echo "MATRIX_CUDA_VERSION=$(echo "$cuda_version" | awk -F '.' '{print $1 $2}')"
            echo "MATRIX_TORCH_VERSION=$(echo "$torch_version" | awk -F '.' '{print $1 "." $2}')"
            echo "WHEEL_CUDA_VERSION=$(echo "$cuda_version" | awk -F '.' '{print $1}')"
            echo "MATRIX_PYTHON_VERSION=$(echo "$python_version" | awk -F '.' '{print $1 $2}')"
          } >> "$GITHUB_ENV"

      - name: Free up disk space
        # Deletes preinstalled directories to reclaim runner disk space.
        # Approach borrowed from:
        # https://github.com/easimon/maximize-build-space/blob/master/action.yml
        # https://github.com/easimon/maximize-build-space/tree/test-report
        if: runner.os == 'Linux'
        run: sudo rm -rf /usr/share/dotnet /opt/ghc /opt/hostedtoolcache/CodeQL

      - name: Set up swap space
        # Configures 10 GB of swap on Linux runners.
        if: ${{ runner.os == 'Linux' }}
        uses: pierotofy/set-swap-space@v1.0
        with:
          swap-size-gb: 10

      - name: Install CUDA ${{ matrix.cuda-version }}
        id: cuda-toolkit
        # Skipped for a 'cpu' matrix value.
        if: matrix.cuda-version != 'cpu'
        uses: Jimver/cuda-toolkit@v0.2.19
        with:
          cuda: ${{ matrix.cuda-version }}
          # The action defaults to the "local" method, which ran into caching
          # errors for CUDA 11.8 and 12.1, so the network method is used for every
          # version. Earlier per-version selection, kept for reference:
          # method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }}
          method: network
          sub-packages: '["nvcc"]'
          linux-local-args: '["--toolkit"]'

      - name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
        # Installs pinned build prerequisites (setuptools, typing-extensions), then
        # the PyTorch build for this matrix entry, and finally prints the nvcc,
        # Python, PyTorch and CUDA_HOME details to the job log.
        #
        # TORCH_CUDA_VERSION picks the PyTorch wheel index suffix (cu<NNN>): for the
        # torch major.minor in MATRIX_TORCH_VERSION it takes the `minv` table entry
        # when MATRIX_CUDA_VERSION is below 120, otherwise the `maxv` entry.
        # MATRIX_TORCH_VERSION, MATRIX_CUDA_VERSION and MATRIX_PYTHON_VERSION are
        # read from the environment populated by the earlier
        # "Set CUDA and PyTorch versions" step.
        #
        # torch versions containing "dev" are installed from direct nightly wheel
        # URLs (together with a hard-coded pytorch-triton wheel); all other
        # versions come from the release index for the chosen CUDA suffix.
        run: |
          pip install --upgrade pip
          # For some reason torch 2.2.0 on python 3.12 errors saying no setuptools
          pip install setuptools==68.0.0
          # With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error
          # AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable
          pip install typing-extensions==4.12.2
          # We want to figure out the CUDA version to download pytorch
          # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
          # This code is ugly, maybe there's a better way to do this.
          export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
            minv = {'2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118}[env['MATRIX_TORCH_VERSION']]; \
            maxv = {'2.1': 121, '2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124, '2.6': 124}[env['MATRIX_TORCH_VERSION']]; \
            print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \
          )
          if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
            # pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
            # Can't use --no-deps because we need cudnn etc.
            # Hard-coding this version of pytorch-triton for torch 2.6.0.dev20241001
            pip install jinja2
            pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.1.0%2Bcf34004b8a-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl
            pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl
          else
            pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
          fi
          nvcc --version
          python --version
          python -c "import torch; print('PyTorch:', torch.__version__)"
          python -c "import torch; print('CUDA:', torch.version.cuda)"
          python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"
        shell:
          bash

      - name: Build wheel
        run: |
          # setuptools >= 49.6.0 is needed to compile the extension when the system
          # CUDA version (e.g. 11.7) differs from the PyTorch CUDA version (e.g. 11.6):
          # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810
          # That alone still failed, hence the newer pinned setuptools below.
          pip install setuptools==68.0.0
          pip install ninja packaging wheel
          export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
          export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
          # MAX_JOBS is capped so the GitHub runner does not run out of memory.
          MAX_JOBS=2 \
          CAUSAL_CONV1D_FORCE_BUILD="TRUE" \
          CAUSAL_CONV1D_FORCE_CXX11_ABI=${{ matrix.cxx11_abi }} \
            python setup.py bdist_wheel --dist-dir=dist
          # Rename the wheel: the second "-" of its file name becomes
          # "+<local_tag>-", where local_tag encodes CUDA major, torch major.minor
          # and the C++11 ABI setting.
          local_tag=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }}
          wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+${local_tag}-/2")
          ls dist/*whl | xargs -I {} mv {} dist/${wheel_name}
          # Expose the final file name to the release upload step.
          echo "wheel_name=${wheel_name}" >> $GITHUB_ENV

      - name: Log Built Wheels
        # List the dist directory contents in the job log.
        run: ls dist

      - name: Get the tag version
        id: extract_branch
        # Strip the "refs/tags/" prefix from GITHUB_REF and expose the remainder
        # as the `branch` step output. The value is appended to $GITHUB_OUTPUT
        # because the `::set-output` workflow command is deprecated.
        run: |
          echo "branch=${GITHUB_REF#refs/tags/}" >> "$GITHUB_OUTPUT"
        shell: bash

      - name: Get Release with tag
        id: get_current_release
        # Looks up the release for the extracted tag; its upload_url output is
        # consumed by the asset upload step.
        uses: joutvhu/get-release@v1
        env:
          GITHUB_TOKEN: '${{ secrets.GITHUB_TOKEN }}'
        with:
          tag_name: '${{ steps.extract_branch.outputs.branch }}'

      - name: Upload Release Asset
        id: upload_release_asset
        # Attaches the renamed wheel (file name taken from the wheel_name
        # environment variable) to the release found by get_current_release.
        uses: actions/upload-release-asset@v1
        with:
          upload_url: '${{ steps.get_current_release.outputs.upload_url }}'
          asset_name: '${{ env.wheel_name }}'
          asset_path: './dist/${{ env.wheel_name }}'
          asset_content_type: 'application/*'
        env:
          GITHUB_TOKEN: '${{ secrets.GITHUB_TOKEN }}'

  # Builds the source distribution with the CUDA build skipped and uploads it
  # with twine. Runs after the build_wheels job.
  publish_package:
    name: Publish package
    runs-on: ubuntu-latest
    needs: build_wheels

    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-python@v5
        with:
          python-version: '3.10'

      - name: Install dependencies
        # torch comes from the CPU-only index so that nothing CUDA-related is
        # downloaded in this job.
        run: |
          pip install ninja packaging setuptools wheel twine
          pip install torch --index-url https://download.pytorch.org/whl/cpu

      - name: Build core package
        run: python setup.py sdist --dist-dir=dist
        env:
          CAUSAL_CONV1D_SKIP_CUDA_BUILD: 'TRUE'

      - name: Deploy
        # Token-based authentication: the user name is the literal "__token__"
        # and the password is the API token stored in the repository secrets.
        run: python -m twine upload dist/*
        env:
          TWINE_USERNAME: '__token__'
          TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
