#!/bin/bash
# compile.sh - Compilation script for argmin.so
#
# Builds argmin.cu (in the current directory) into the shared library
# argmin.so, compiling against the JAX FFI headers reported by
# `jax.ffi.include_dir()` of the selected Python interpreter.
#
# Environment overrides (all optional; defaults reproduce the original build):
#   PYTHON    - Python interpreter used to locate JAX   (default: python)
#   CUDA_ARCH - value passed to nvcc as -arch=...        (default: sm_80)
#   NVCC      - CUDA compiler driver                     (default: nvcc)
#
# Exits non-zero (and does not print the success message) if any step fails.
set -euo pipefail

PYTHON=${PYTHON:-python}
CUDA_ARCH=${CUDA_ARCH:-sm_80}
NVCC=${NVCC:-nvcc}

die() {
  printf 'compile.sh: %s\n' "$*" >&2
  exit 1
}

# Get JAX include directory. Fail loudly here rather than letting nvcc run
# with an empty/bogus -isystem path and report a confusing missing-header error.
JAX_INCLUDE=$("$PYTHON" -c "from jax import ffi; print(ffi.include_dir(), end='')") \
  || die "could not determine JAX FFI include dir (is jax installed for '$PYTHON'?)"
[[ -n "$JAX_INCLUDE" ]] || die "JAX FFI include dir is empty"
[[ -d "$JAX_INCLUDE" ]] || die "JAX FFI include dir does not exist: $JAX_INCLUDE"

# Compile the CUDA code into a shared library.
# -Xcompiler -fPIC: position-independent host code, required for a .so.
# The include path is quoted so directories containing spaces still work.
"$NVCC" -shared -Xcompiler -fPIC \
  -isystem "$JAX_INCLUDE" \
  -arch="$CUDA_ARCH" \
  -o argmin.so \
  argmin.cu \
  -lcuda -lcudart

# Only reached if nvcc succeeded (set -e aborts on failure above).
echo "Compiled argmin.so successfully!"