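// cuda_add.cu - Compiled into cuda_add.so. This file contains device code
// (a __global__ kernel and a <<<...>>> launch), so it must be built as CUDA
// (e.g. with nvcc, or with `-x cu` if it keeps a .cpp extension).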
#include <cuda_runtime.h>

#include <limits>
#include <string>

#include "xla/ffi/api/ffi.h"

namespace ffi = xla::ffi;

// Elementwise float addition: c[i] = a[i] + b[i] for every i in [0, size).
//
// Launch layout: 1-D grid of 1-D blocks. Any non-zero grid/block size is
// valid: each thread starts at its global index and advances by the total
// number of launched threads, so the launch does not have to cover `size`
// exactly. No shared memory is used.
//
// Preconditions: a, b and c must be device-accessible and hold at least
// `size` floats each. A non-positive `size` makes the kernel a no-op.
__global__ void elementwise_add_kernel(const float* a, const float* b, float* c, int size) {
    // Index math is done in 64 bits: blockIdx.x * blockDim.x would otherwise be
    // evaluated as 32-bit unsigned and then narrowed to a signed int.
    const long long stride = static_cast<long long>(gridDim.x) * blockDim.x;
    const long long first = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;

    for (long long idx = first; idx < size; idx += stride) {
        c[idx] = a[idx] + b[idx];
    }
}

// FFI wrapper: enqueues the elementwise float32 addition c = a + b on `stream`.
//
// Parameters:
//   stream - CUDA stream the kernel is launched on; the launch is asynchronous
//            and this function does not synchronize the stream.
//   a, b   - input buffers.
//   c      - result buffer that receives a[i] + b[i].
//
// Returns:
//   Success          - the kernel was enqueued, or the buffers are empty.
//   InvalidArgument  - a, b and c do not all have the same element count, or
//                      the element count does not fit the kernel's int size.
//   Internal         - the launch was rejected by the CUDA runtime.
ffi::Error CudaAddImpl(cudaStream_t stream, 
                       ffi::Buffer<ffi::F32> a,
                       ffi::Buffer<ffi::F32> b, 
                       ffi::ResultBuffer<ffi::F32> c) {
    constexpr int kBlockSize = 256;

    const size_t count = a.element_count();

    // The kernel indexes all three buffers with the same index, so a shorter
    // b or c would be read/written out of bounds.
    if (b.element_count() != count || c->element_count() != count) {
        return ffi::Error::InvalidArgument(
            "CudaAdd: a, b and c must have the same number of elements");
    }

    // A grid of 0 blocks is an invalid launch configuration; adding empty
    // buffers is simply a no-op.
    if (count == 0) {
        return ffi::Error::Success();
    }

    // The kernel takes an int size, and the ceil-division below adds
    // kBlockSize - 1; reject counts that would overflow either.
    constexpr size_t kMaxCount =
        static_cast<size_t>(std::numeric_limits<int>::max() - (kBlockSize - 1));
    if (count > kMaxCount) {
        return ffi::Error::InvalidArgument(
            "CudaAdd: element count exceeds the supported 32-bit range");
    }

    const int size = static_cast<int>(count);
    const int grid_size = (size + kBlockSize - 1) / kBlockSize;  // ceil-div

    elementwise_add_kernel<<<grid_size, kBlockSize, 0, stream>>>(
        a.typed_data(), b.typed_data(), c->typed_data(), size);

    // Launch-configuration errors are only reported through cudaGetLastError().
    const cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        return ffi::Error::Internal(
            std::string("CUDA kernel launch failed: ") + cudaGetErrorString(err));
    }

    return ffi::Error::Success();
}

// Register the FFI handler: binds the handler symbol `CudaAdd` to CudaAddImpl.
// The binding entries below are listed in the same order as CudaAddImpl's
// parameters (stream, a, b, c) and must be kept in sync with its signature.
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    CudaAdd, CudaAddImpl,
    ffi::Ffi::Bind()
        .Ctx<ffi::PlatformStream<cudaStream_t>>()  // CUDA stream passed as `stream`
        .Arg<ffi::Buffer<ffi::F32>>()              // input a (float32)
        .Arg<ffi::Buffer<ffi::F32>>()              // input b (float32)
        .Ret<ffi::Buffer<ffi::F32>>()              // output c (float32)
);
