#!/usr/bin/env bash
# compile.sh - Compile assign_indices.cu into the shared library
# assign_indices.so, using the JAX FFI headers.
#
# Usage: ./compile.sh   (run from the directory containing assign_indices.cu;
#                        the output assign_indices.so is written there too)
#
# Environment overrides (defaults reproduce the original behavior):
#   PYTHON     Python interpreter that has JAX installed  (default: python)
#   CUDA_ARCH  value passed to nvcc -arch                 (default: sm_80)
set -euo pipefail

PYTHON=${PYTHON:-python}
CUDA_ARCH=${CUDA_ARCH:-sm_80}

die() { printf 'compile.sh: %s\n' "$*" >&2; exit 1; }

command -v nvcc >/dev/null 2>&1 || die "nvcc not found in PATH"

# Get JAX include directory. Fail loudly if the query fails or returns
# nothing. Otherwise an empty value would silently drop out of the nvcc
# command line, and nvcc would take the next flag as the -isystem argument.
JAX_INCLUDE=$("$PYTHON" -c "from jax import ffi; print(ffi.include_dir(), end='')") \
  || die "failed to query JAX FFI include dir using '$PYTHON'"
[[ -n "$JAX_INCLUDE" ]] || die "JAX FFI include dir is empty"

# Compile the CUDA code into a shared library.
# Expansions are quoted so include paths containing spaces stay one argument.
nvcc -shared -Xcompiler -fPIC \
  -isystem "$JAX_INCLUDE" \
  -arch="$CUDA_ARCH" \
  -o assign_indices.so \
  assign_indices.cu \
  -lcuda -lcudart \
  -std=c++17
