#!/bin/bash
# compile.sh - Example compilation script for add.so (built from add.cu)
#
# Compiles add.cu into the shared library add.so, with the JAX FFI headers
# on the include path. Output is written to the current directory.
#
# Environment overrides (all optional):
#   PYTHON     - Python interpreter used to locate the JAX FFI headers
#                (default: python)
#   CUDA_ARCH  - value passed to nvcc -arch (default: sm_80)
set -euo pipefail

readonly PYTHON="${PYTHON:-python}"
readonly CUDA_ARCH="${CUDA_ARCH:-sm_80}"

die() { printf 'error: %s\n' "$*" >&2; exit 1; }

# Get JAX include directory. Checked explicitly: an empty value would make
# "-isystem" swallow the following nvcc argument as its directory.
JAX_INCLUDE=$("${PYTHON}" -c "from jax import ffi; print(ffi.include_dir(), end='')") \
    || die "could not locate JAX FFI headers (is jax installed for ${PYTHON}?)"
[[ -n "${JAX_INCLUDE}" ]] || die "jax.ffi.include_dir() returned an empty path"

# Compile the CUDA code into a shared library.
# NOTE(review): -lcuda links the CUDA driver API; it may be droppable if
# add.cu only uses the runtime API (libcudart) — confirm.
nvcc -shared -Xcompiler -fPIC \
    -isystem "${JAX_INCLUDE}" \
    -arch="${CUDA_ARCH}" \
    -o add.so \
    add.cu \
    -lcuda -lcudart \
    || die "nvcc failed to build add.so"

# Only reached when nvcc succeeded.
echo "Compiled add.so successfully!"
