import ctypes
import jax
import jax.numpy as jnp

# Open the compiled CUDA shared library. ctypes.CDLL is the class that
# ctypes.cdll.LoadLibrary instantiates, so this is the same load. The path is
# relative, so the loader resolves it against the process's current working
# directory, not against this script's location.
cuda_add_lib = ctypes.CDLL('./add.so')

# Expose the library's CudaAdd symbol to XLA as the custom-call target
# "cuda_add" for the CUDA platform; ffi_call later looks it up by this name.
jax.ffi.register_ffi_target(
    "cuda_add",
    jax.ffi.pycapsule(cuda_add_lib.CudaAdd),
    platform="CUDA",
)

def cuda_elementwise_add(a, b):
    """Add two float32 arrays elementwise using the ``cuda_add`` FFI target.

    Args:
        a: float32 array.
        b: float32 array with exactly the same shape as ``a``.

    Returns:
        A float32 array of shape ``a.shape`` holding the result produced by
        the ``cuda_add`` custom call (or an empty array for zero-element
        inputs, without invoking the kernel).

    Raises:
        ValueError: if either input is not float32, or if the shapes differ.
    """
    # The kernel is built for float32 only; reject anything else before XLA
    # hands it buffers of the wrong element type.
    if a.dtype != jnp.float32 or b.dtype != jnp.float32:
        raise ValueError(
            f"Inputs must be float32, got {a.dtype} and {b.dtype}"
        )

    # No broadcasting happens here: the output spec below is built from
    # a.shape, so b has to match it exactly.
    if a.shape != b.shape:
        raise ValueError(
            f"Input shapes must match, got {a.shape} and {b.shape}"
        )

    # Zero-element inputs have a trivially empty result, so skip the custom
    # call entirely. NOTE(review): presumably the kernel derives its launch
    # grid from the element count, and a zero-block launch is an invalid CUDA
    # configuration — confirm against the kernel source.
    if a.size == 0:
        return jnp.zeros(a.shape, a.dtype)

    # "broadcast_all": under jax.vmap every input is broadcast to a shared
    # leading batch dimension and the kernel runs once on the batched
    # buffers, which is sound because addition is elementwise.
    return jax.ffi.ffi_call(
        "cuda_add",  # must match the name given to register_ffi_target
        jax.ShapeDtypeStruct(a.shape, a.dtype),
        vmap_method="broadcast_all",
    )(a, b)

# Example usage: run each case through the CUDA kernel and compare it with
# JAX's own elementwise addition.
if __name__ == "__main__":
    # 1D case.
    a = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=jnp.float32)
    b = jnp.array([5.0, 6.0, 7.0, 8.0], dtype=jnp.float32)

    result = cuda_elementwise_add(a, b)

    expected = a + b  # JAX reference
    print(f"CUDA result: {result}")
    print(f"Expected:    {expected}")
    print(f"Match: {jnp.allclose(result, expected)}")

    # 2D case: the wrapper forwards arbitrary shapes, so verify it as well
    # instead of only printing the kernel output.
    a_2d = jnp.ones((2, 3), dtype=jnp.float32)
    b_2d = jnp.ones((2, 3), dtype=jnp.float32) * 2
    result_2d = cuda_elementwise_add(a_2d, b_2d)
    print(f"2D result: {result_2d}")
    print(f"2D match: {jnp.allclose(result_2d, a_2d + b_2d)}")

    # vmap case: exercises the vmap_method="broadcast_all" batching path.
    batched_a = jnp.array([[1, 2], [3, 4]], dtype=jnp.float32)
    batched_b = jnp.array([[5, 6], [7, 8]], dtype=jnp.float32)
    vmap_result = jax.vmap(cuda_elementwise_add)(batched_a, batched_b)
    print(f"vmap result: {vmap_result}")
    print(f"vmap match: {jnp.allclose(vmap_result, batched_a + batched_b)}")
