#include "xla/ffi/api/ffi.h"
#include <cuda_runtime.h>
#include <math.h>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <string>
#include <cutlass/numeric_types.h>
namespace ffi = xla::ffi;
using cutlass::bfloat16_t;

// Placeholder for the k-means assignment step: for now it only zero-fills
// both outputs and does not read the inputs.
//
// Launch layout: one thread block per batch element (blockIdx.x is the batch
// index). The threads of a block stride over that batch element's outputs, so
// any blockDim.x >= 1 produces the same result. No shared memory is used.
//
// Template parameters:
//   K - number of centroids per batch element.
//   D - feature dimension of each point / centroid.
//
// Parameters:
//   xs_ptr            - input points, [B, N, D] (not read yet).
//   centroids_ptr     - input centroids (not read yet).
//                       NOTE(review): originally annotated as [K, D], but the
//                       per-batch offset below implies [B, K, D] - confirm
//                       against the caller before this pointer is dereferenced.
//   labels_ptr        - output labels, [B, N]; every entry is set to 0.
//   new_centroids_ptr - output centroids, K * D bf16 values per batch element;
//                       every entry is set to 0.
//   N                 - number of points per batch element (must be >= 0).
template <int K, int D>
__global__ void assign_kernel(
    const bfloat16_t* xs_ptr, // [B, N, D]
    const bfloat16_t* centroids_ptr, // see NOTE(review) above
    int32_t* labels_ptr, // [B, N]
    bfloat16_t* new_centroids_ptr, // K * D values per batch element
    int32_t N
) {
    const int tid = threadIdx.x;
    const int tcnt = blockDim.x;

    // Offsets are computed in 64 bits: batch_idx * N * D can exceed INT_MAX
    // for large inputs even though each factor fits in 32 bits.
    const size_t batch_idx = blockIdx.x;
    const size_t points = static_cast<size_t>(N);
    xs_ptr += batch_idx * points * D;
    centroids_ptr += batch_idx * K * D;
    new_centroids_ptr += batch_idx * K * D;
    labels_ptr += batch_idx * points;

    // placeholder - just write out zeros
    const bfloat16_t zero(0.0f);
    for (int i = tid; i < K * D; i += tcnt) {
        new_centroids_ptr[i] = zero;
    }
    for (int i = tid; i < N; i += tcnt) {
        labels_ptr[i] = 0;
    }
}


// Host entry point for the `assign` FFI handler: validates the call and
// launches assign_kernel<64, 64> on `stream` with one block per batch element.
//
// Parameters:
//   stream        - CUDA stream the kernel is enqueued on.
//   xs            - input points, [..., N, D]; all leading dimensions are
//                   flattened into the batch.
//   centroids     - input centroids (only forwarded to the kernel).
//   labels        - result buffer; must hold batch * N int32 elements.
//   new_centroids - result buffer; must hold batch * K * D bf16 elements.
//   N             - number of points per batch element; must equal the
//                   second-to-last dimension of `xs`.
//   num_warps     - warps per thread block; block size is num_warps * 32.
//
// Returns:
//   Success when the kernel was enqueued (or there was nothing to do for an
//   empty batch), InvalidArgument when the shapes/attributes are inconsistent
//   with what the kernel will write, and Internal when the launch itself fails.
//   In-kernel faults are asynchronous and are not reported here.
ffi::Error assign_host(
    cudaStream_t stream,
    ffi::Buffer<ffi::BF16> xs,
    ffi::Buffer<ffi::BF16> centroids,
    ffi::ResultBuffer<ffi::S32> labels,
    ffi::ResultBuffer<ffi::BF16> new_centroids,
    int64_t N,
    int64_t num_warps
) {
    // Problem sizes the kernel is instantiated for.
    constexpr int kK = 64;
    constexpr int kD = 64;
    constexpr int64_t kWarpSize = 32;
    // Upper bound on threads per block on current CUDA devices.
    constexpr int64_t kMaxThreadsPerBlock = 1024;
    constexpr int64_t kInt32Max = std::numeric_limits<int32_t>::max();

    const auto dims = xs.dimensions();
    const size_t rank = dims.size();
    // Guard before computing rank - 2: the subtraction is unsigned and would
    // wrap around for rank < 2.
    if (rank < 2) {
        return ffi::Error::InvalidArgument(
            "assign: xs must have rank >= 2 ([..., N, D]), got rank " + std::to_string(rank));
    }
    if (dims[rank - 1] != kD) {
        return ffi::Error::InvalidArgument(
            "assign: last dimension of xs must be " + std::to_string(kD) +
            ", got " + std::to_string(dims[rank - 1]));
    }
    // N is narrowed to int32 for the kernel, so reject values that would be
    // truncated, and values that disagree with the actual buffer shape.
    if (N < 0 || N > kInt32Max || N != dims[rank - 2]) {
        return ffi::Error::InvalidArgument(
            "assign: attribute N=" + std::to_string(N) +
            " must be a non-negative int32 equal to xs.shape[-2]=" +
            std::to_string(dims[rank - 2]));
    }
    // Compare against the quotient so a huge num_warps cannot overflow.
    if (num_warps < 1 || num_warps > kMaxThreadsPerBlock / kWarpSize) {
        return ffi::Error::InvalidArgument(
            "assign: num_warps must be in [1, " +
            std::to_string(kMaxThreadsPerBlock / kWarpSize) + "], got " +
            std::to_string(num_warps));
    }

    int64_t batch_size = 1;
    for (size_t i = 0; i + 2 < rank; ++i) {
        batch_size *= dims[i];
    }
    // An empty batch has no outputs to write; a zero-sized grid would be
    // rejected by the launch as an invalid configuration.
    if (batch_size == 0) {
        return ffi::Error::Success();
    }
    // gridDim.x is limited to 2^31 - 1 blocks.
    if (batch_size < 0 || batch_size > kInt32Max) {
        return ffi::Error::InvalidArgument(
            "assign: batch size " + std::to_string(batch_size) + " exceeds the grid limit");
    }

    // The kernel writes N labels and K * D centroid values per batch element;
    // make sure the result buffers are exactly that large.
    const size_t batch = static_cast<size_t>(batch_size);
    if (labels->element_count() != batch * static_cast<size_t>(N)) {
        return ffi::Error::InvalidArgument(
            "assign: labels must hold batch * N = " +
            std::to_string(batch * static_cast<size_t>(N)) + " elements, got " +
            std::to_string(labels->element_count()));
    }
    if (new_centroids->element_count() != batch * kK * kD) {
        return ffi::Error::InvalidArgument(
            "assign: new_centroids must hold batch * K * D = " +
            std::to_string(batch * kK * kD) + " elements, got " +
            std::to_string(new_centroids->element_count()));
    }

    const dim3 block(static_cast<unsigned int>(num_warps * kWarpSize));
    const dim3 grid(static_cast<unsigned int>(batch_size));
    // BF16 buffers are reinterpreted as cutlass::bfloat16_t through their
    // untyped pointers; both are 16-bit storage for the same format.
    assign_kernel<kK, kD><<<grid, block, 0, stream>>>(
        static_cast<const bfloat16_t*>(xs.untyped_data()),
        static_cast<const bfloat16_t*>(centroids.untyped_data()),
        labels->typed_data(),
        static_cast<bfloat16_t*>(new_centroids->untyped_data()),
        static_cast<int32_t>(N)
    );
    // Kernel launches do not return a status; launch-configuration errors are
    // picked up here.
    cudaError_t last_err = cudaGetLastError();
    if (last_err != cudaSuccess) {
        return ffi::Error::Internal(std::string("CUDA error: ") + cudaGetErrorString(last_err));
    }
    return ffi::Error::Success();
}

// Defines the exported FFI handler symbol `assign`, which decodes the call
// frame and forwards to assign_host. The binding order below must match the
// parameter order of assign_host exactly.
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    assign,
    assign_host,
    ffi::Ffi::Bind()
        .Ctx<ffi::PlatformStream<cudaStream_t>>() // stream
        .Arg<ffi::Buffer<ffi::BF16>>() // xs
        .Arg<ffi::Buffer<ffi::BF16>>() // centroids
        .Ret<ffi::Buffer<ffi::S32>>() // labels
        .Ret<ffi::Buffer<ffi::BF16>>() // new_centroids
        .Attr<int64_t>("N") // N
        .Attr<int64_t>("num_warps"), // num_warps
    // Traits lives in xla::ffi; only the `ffi` alias is in scope here, so the
    // name has to be qualified.
    {ffi::Traits::kCmdBufferCompatible}
);
