#!/bin/bash
# compile.sh - Compilation script for kmeans.so (built from kmeans.cu)

# Abort on the first failing command, unset variable, or failed pipeline
# stage. Without this, a failed JAX lookup or nvcc run would still fall
# through to the success message below.
set -euo pipefail

# Overridable settings; the defaults reproduce the original hard-coded build.
#   PYTHON    - interpreter used to locate the JAX FFI headers
#   CUDA_ARCH - GPU architecture passed to nvcc via -arch
PYTHON=${PYTHON:-python}
CUDA_ARCH=${CUDA_ARCH:-sm_80}

# Get JAX include directory (location of the XLA FFI headers).
JAX_INCLUDE=$("$PYTHON" -c "from jax import ffi; print(ffi.include_dir(), end='')")
if [[ -z "$JAX_INCLUDE" ]]; then
    printf 'error: could not determine JAX FFI include directory\n' >&2
    exit 1
fi

# Compile the CUDA code into a shared library.
# -Xcompiler -fPIC forwards -fPIC to the host compiler, as needed for -shared.
# -isystem treats the JAX headers as system headers (suppresses their warnings).
# The include path is quoted so a directory containing spaces stays one argument.
if ! nvcc -shared -Xcompiler -fPIC \
    -isystem "${JAX_INCLUDE}" \
    -arch="${CUDA_ARCH}" \
    -o kmeans.so \
    kmeans.cu \
    -lcuda -lcudart; then
    printf 'error: nvcc failed to build kmeans.so\n' >&2
    exit 1
fi

echo "Compiled kmeans.so successfully!"
