on:
  # Runs triggered from the pull-request merge queue.
  merge_group:
  pull_request:
    types:
      # Labeling is the pull-request event that carries a label payload,
      # which the CUDA job's `if` condition inspects.
      - labeled
      # Ordinary PR activity also starts the workflow so that it reports a
      # result on the PR (to let the PR pass the test).
      - opened
      - reopened
      - synchronize
  # Manual trigger.
  workflow_dispatch:
# Runs are grouped per workflow and git ref; when no ref is available the run
# id is used instead, which puts that run in a group of its own.
# A newer run in the same group cancels the one still in progress.
concurrency:
  cancel-in-progress: true
  group: '${{ github.workflow }}-${{ github.ref || github.run_id }}'
# Display name of this workflow.
name: 'Test CUDA'
jobs:
  # Builds the package with CUDA support and runs the Python, C++, LAMMPS and
  # i-PI test suites inside a CUDA development container.
  test_cuda:
    name: Test Python and C++ on CUDA
    # Runner selected by the `nvidia` label.
    runs-on: nvidia
    # https://github.com/deepmodeling/deepmd-kit/pull/2884#issuecomment-1744216845
    container:
      image: nvidia/cuda:12.6.2-cudnn-devel-ubuntu22.04
      # Expose every GPU of the host to the container.
      options: --gpus all
    # Run only inside the deepmodeling organization, and only when a pull
    # request is labeled "Test CUDA", on manual dispatch, or from the merge
    # queue. `&&` binds tighter than `||` inside the parentheses.
    if: github.repository_owner == 'deepmodeling' && (github.event_name == 'pull_request' && github.event.label && github.event.label.name == 'Test CUDA' || github.event_name == 'workflow_dispatch' || github.event_name == 'merge_group')
    steps:
    # Installed before checkout; presumably neither tool is present in the
    # base image - TODO confirm.
    - name: Make sudo and git work
      run: apt-get update && apt-get install -y sudo git
    - uses: actions/checkout@v4
    - uses: actions/setup-python@v5
      with:
        python-version: '3.11'
        # cache: 'pip'
    - name: Setup MPI
      uses: mpi4py/setup-mpi@v1
      with:
        mpi: mpich
    # wget and unzip are needed by the "Download libtorch" step below.
    - name: Install wget and unzip
      run: apt-get update && apt-get install -y wget unzip
    - uses: lukka/get-cmake@latest
      with:
        useLocalCache: true
        useCloudCache: false
    # Disabled step (`if: false`), kept for reference: installs the CUDA
    # toolkit and cuDNN from the NVIDIA apt repository.
    - run: |
        wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.0-1_all.deb \
        && sudo dpkg -i cuda-keyring_1.0-1_all.deb \
        && sudo apt-get update \
        && sudo apt-get -y install cuda-12-3 libcudnn8=8.9.5.*-1+cuda12.3
      if: false  # skip as we use nvidia image
    - run: python -m pip install -U uv
    # Install the pinned deep-learning frameworks into the system interpreter.
    - run: source/install/uv_with_retry.sh pip install --system "tensorflow~=2.18.0rc2" "torch~=2.6.0" "jax[cuda12]==0.5.0"
    # Export the install locations of torch and tensorflow, then install this
    # project in editable mode together with mpi4py, forcing deepmd-kit to be
    # reinstalled. The extras argument is quoted so the shell never treats the
    # square brackets as a pathname pattern.
    - run: |
        export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
        export TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
        source/install/uv_with_retry.sh pip install --system -v -e ".[gpu,test,lmp,cu12,torch,jax]" mpi4py --reinstall-package deepmd-kit
      # Environment values are quoted so they are strings at the YAML level
      # instead of relying on the runner to stringify ints and booleans.
      env:
        DP_VARIANT: cuda
        DP_ENABLE_NATIVE_OPTIMIZATION: '1'
        DP_ENABLE_PYTORCH: '1'
    # Check that the installed command-line entry point starts.
    - run: dp --version
    # Python test suite; `--durations=0` reports the duration of every test.
    - run: python -m pytest source/tests --durations=0
      env:
        NUM_WORKERS: '0'
        CUDA_VISIBLE_DEVICES: '0'
        # See https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
        XLA_PYTHON_CLIENT_PREALLOCATE: 'false'
    - name: Convert models
      run: source/tests/infer/convert-models.sh
    # Fetch the libtorch 2.6.0 (CUDA 12.4) archive and unpack it in the
    # working directory; the next steps reference $GITHUB_WORKSPACE/libtorch.
    - name: Download libtorch
      run: |
        wget https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.6.0%2Bcu124.zip -O libtorch.zip
        unzip libtorch.zip
    # Run the local C++ test script against the downloaded libtorch.
    # NOTE(review): CUDA_PATH is not set anywhere in this workflow; it is
    # assumed to come from the container or runner environment - confirm.
    - run: |
        export CMAKE_PREFIX_PATH=$GITHUB_WORKSPACE/libtorch
        export LD_LIBRARY_PATH=$CUDA_PATH/lib64:/usr/lib/x86_64-linux-gnu/:$LD_LIBRARY_PATH
        source/install/test_cc_local.sh
      env:
        OMP_NUM_THREADS: '1'
        TF_INTRA_OP_PARALLELISM_THREADS: '1'
        TF_INTER_OP_PARALLELISM_THREADS: '1'
        LMP_CXX11_ABI_0: '1'
        CMAKE_GENERATOR: Ninja
        DP_VARIANT: cuda
        DP_USE_MPICH2: '1'
    # LAMMPS and i-PI tests, using binaries and libraries under
    # $GITHUB_WORKSPACE/dp_test. If the LAMMPS tests fail, log.lammps is
    # printed before the step exits with a failure.
    - run: |
        export LD_LIBRARY_PATH=$CUDA_PATH/lib64:/usr/lib/x86_64-linux-gnu/:$GITHUB_WORKSPACE/dp_test/lib:$GITHUB_WORKSPACE/libtorch/lib:$LD_LIBRARY_PATH
        export PATH=$GITHUB_WORKSPACE/dp_test/bin:$PATH
        python -m pytest -s source/lmp/tests || (cat log.lammps && exit 1)
        python -m pytest source/ipi/tests
      env:
        OMP_NUM_THREADS: '1'
        TF_INTRA_OP_PARALLELISM_THREADS: '1'
        TF_INTER_OP_PARALLELISM_THREADS: '1'
        LAMMPS_PLUGIN_PATH: ${{ github.workspace }}/dp_test/lib/deepmd_lmp
        CUDA_VISIBLE_DEVICES: '0'
  # Summary job: reports one overall result for the CUDA job.
  pass:
    name: Pass testing on CUDA
    # Evaluate the outcome even when the needed job failed or was skipped.
    if: always()
    needs:
    - test_cuda
    runs-on: ubuntu-latest
    steps:
    - name: Decide whether the needed jobs succeeded or failed
      uses: re-actors/alls-green@release/v1
      with:
        # A skipped test_cuda job is accepted rather than treated as a failure.
        allowed-skips: test_cuda
        jobs: ${{ toJSON(needs) }}
