#include "xla/ffi/api/ffi.h"
#include <cuda_runtime.h>

namespace ffi = xla::ffi;

/**
 * Row-wise argmin over a row-major float matrix of 32 rows x 64 columns.
 *
 * Launch layout: only threadIdx.x is used for indexing, and threads with
 * threadIdx.x >= 32 do nothing. The host wrapper in this file launches a
 * single block of 32 threads. No shared memory is used; each participating
 * thread scans one full row sequentially.
 *
 * @param x_ptr   Device-accessible pointer to at least 32 * 64 floats.
 * @param out_ptr Device-accessible pointer to at least 32 ints; out_ptr[r]
 *                receives the column index of the smallest value in row r.
 *
 * The comparison is strict, so on ties the lowest column index wins. A NaN
 * never compares less-than, so a NaN after column 0 is never selected.
 */
__global__ void argmin_kernel(const float* x_ptr, int* out_ptr) {
    constexpr int kRows = 32;
    constexpr int kCols = 64;

    const int row = threadIdx.x;
    if (row >= kRows) {
        return;  // Surplus threads have no row to scan.
    }

    // Start of this thread's row in the row-major input.
    const float* row_data = x_ptr + row * kCols;

    // Seed the running minimum with column 0, then scan the remaining columns.
    int best_col = 0;
    float best_val = row_data[0];
    for (int col = 1; col != kCols; ++col) {
        const float candidate = row_data[col];
        if (candidate < best_val) {
            best_val = candidate;
            best_col = col;
        }
    }

    out_ptr[row] = best_col;
}

// XLA FFI entry point for the row-wise argmin custom call.
//
// Validates that `x` has shape [32, 64] and `out` has shape [32], then
// enqueues argmin_kernel on `stream` as a single block of 32 threads.
// The launch is asynchronous: only launch-time errors reported by
// cudaGetLastError() are surfaced here, and no stream synchronization is done.
//
// Returns:
//   InvalidArgument - the input or output shape does not match.
//   Internal        - cudaGetLastError() reported an error after the launch.
//   Success         - the kernel was enqueued.
ffi::Error ArgminImpl(cudaStream_t stream,
                      ffi::Buffer<ffi::F32> x,
                      ffi::ResultBuffer<ffi::S32> out) {
    // The input must be a rank-2 buffer of exactly 32 x 64 elements.
    const auto in_dims = x.dimensions();
    const bool input_ok =
        in_dims.size() == 2 && in_dims[0] == 32 && in_dims[1] == 64;
    if (!input_ok) {
        return ffi::Error::InvalidArgument("Input must have shape [32, 64]");
    }

    // The output must be a rank-1 buffer with one index per input row.
    const auto out_dims = out->dimensions();
    const bool output_ok = out_dims.size() == 1 && out_dims[0] == 32;
    if (!output_ok) {
        return ffi::Error::InvalidArgument("Output must have shape [32]");
    }

    // One block, one thread per row; no dynamic shared memory.
    const dim3 grid(1);
    const dim3 block(32);
    argmin_kernel<<<grid, block, 0, stream>>>(x.typed_data(),
                                              out->typed_data());

    // Kernel launches return nothing, so query the runtime for launch errors.
    const cudaError_t launch_status = cudaGetLastError();
    if (launch_status == cudaSuccess) {
        return ffi::Error::Success();
    }
    return ffi::Error::Internal(
        std::string("CUDA error: ") + cudaGetErrorString(launch_status));
}

// Register the FFI handler: exposes ArgminImpl under the handler symbol
// `Argmin`. The bindings below are listed in the same order as ArgminImpl's
// parameters (stream, input buffer, result buffer).
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    Argmin, ArgminImpl,
    ffi::Ffi::Bind()
        .Ctx<ffi::PlatformStream<cudaStream_t>>()  // `stream`: CUDA stream the kernel is launched on
        .Arg<ffi::Buffer<ffi::F32>>()              // `x`: F32 input; ArgminImpl requires shape [32, 64]
        .Ret<ffi::Buffer<ffi::S32>>()              // `out`: S32 result; ArgminImpl requires shape [32]
);
