#!/bin/bash
# compile.sh - Compilation script for cluster.so (built from cluster.cu)

# Usage: ./compile.sh
#   Run from the directory containing cluster.cu and cutlass/.
#   Optional environment overrides:
#     PYTHON     Python interpreter that has JAX installed (default: python)
#     CUDA_ARCH  value passed to nvcc -arch                (default: sm_80)

# Fail fast. Without this, a failed JAX lookup leaves JAX_INCLUDE empty and
# nvcc would silently treat the following flag as the -isystem argument.
set -euo pipefail

PYTHON=${PYTHON:-python}
CUDA_ARCH=${CUDA_ARCH:-sm_80}

die() {
  printf 'compile.sh: %s\n' "$*" >&2
  exit 1
}

command -v nvcc >/dev/null 2>&1 || die "nvcc not found on PATH"

# Get JAX include directory (headers for the JAX FFI).
JAX_INCLUDE=$("$PYTHON" -c "from jax import ffi; print(ffi.include_dir(), end='')") \
  || die "could not query the JAX FFI include dir using '$PYTHON'"
[[ -d "$JAX_INCLUDE" ]] \
  || die "JAX FFI include dir '$JAX_INCLUDE' is not a directory"

# Compile the CUDA code into a shared library.
# Flags are kept in an array so paths containing spaces survive intact.
nvcc_args=(
  -shared -Xcompiler -fPIC
  -isystem "$JAX_INCLUDE"   # system include: suppresses warnings from JAX headers
  -I./cutlass/include
  -arch="$CUDA_ARCH"
  -o cluster.so
  cluster.cu
  -lcuda -lcudart
  --resource-usage          # report per-kernel register / memory usage
  -lineinfo                 # keep source line info for profilers/debuggers
  -std=c++17
)
nvcc "${nvcc_args[@]}"
