#include "xla/ffi/api/ffi.h"
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <mma.h>

namespace ffi = xla::ffi;
using namespace nvcuda;

// Per-batch k-means assignment costs, one block per batch element:
//   centroid c_k    = totals[k] / counts[k]
//   costs[n][k]     = ||c_k||^2 - 2 * <x_n, c_k>
// The ||x_n||^2 term is not included (it is the same for every k).
//
// Launch: gridDim.x = batch size, blockDim.x = 128 (exactly 4 warps; the
// tile-row split and the acc_results scratch below are sized for 4 warps).
// Requires bf16 WMMA (SM80+).
// NOTE(review): counts[k] == 0 yields inf/NaN centroids and therefore inf/NaN
// costs for that k — confirm callers never pass empty clusters or handle it.
__global__ void costs_kernel(
    const __nv_bfloat16* xs,      // [B, 512, 64]
    const __nv_bfloat16* totals,  // [B, 64, 64] 
    const int32_t* counts,        // [B, 64]
    __nv_bfloat16* costs          // [B, 512, 64] output
) {
    constexpr int kN = 512;       // points per batch element
    constexpr int kK = 64;        // centroids
    constexpr int kD = 64;        // feature dimension
    constexpr int kThreads = 128; // expected blockDim.x
    constexpr int kWarps = kThreads / 32;
    constexpr int kTile = 16;     // WMMA m16n16k16
    constexpr int kRowTilesPerWarp = kN / kTile / kWarps;  // 8

    const int tid = threadIdx.x;
    const int warp_id = tid / 32;
    const int lane_id = tid % 32;

    // 64-bit batch offsets: bid * 512 * 64 overflows int once bid >= 65536.
    const size_t bid = blockIdx.x;
    xs += bid * kN * kD;
    totals += bid * kK * kD;
    counts += bid * kK;
    costs += bid * kN * kK;

    // WMMA load/store pointers must be 256-bit (32-byte) aligned.
    __shared__ __align__(32) __nv_bfloat16 centroids[kK * kD];  // [64, 64] row-major
    __shared__ float centroid_sq_norms[kK];                      // fp32 to avoid extra rounding
    __shared__ __align__(32) float acc_results[kWarps][kTile * kTile];  // one scratch tile per warp

    // Divide in fp32: counts above 256 are not exactly representable in bf16.
    for (int i = tid; i < kK * kD; i += kThreads) {
        const int k = i / kD;
        centroids[i] = __float2bfloat16(__bfloat162float(totals[i]) / float(counts[k]));
    }
    __syncthreads();

    // csq = einsum('kd,kd->k', cs, cs), over the bf16 centroids used by WMMA.
    for (int k = tid; k < kK; k += kThreads) {
        float sum = 0.0f;
        for (int d = 0; d < kD; d++) {
            const float val = __bfloat162float(centroids[k * kD + d]);
            sum += val * val;
        }
        centroid_sq_norms[k] = sum;
    }
    __syncthreads();

    wmma::fragment<wmma::matrix_a, kTile, kTile, kTile, __nv_bfloat16, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, kTile, kTile, kTile, __nv_bfloat16, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, kTile, kTile, kTile, float> acc_frag;

    float* scratch = acc_results[warp_id];
    const int first_tile_row = warp_id * kRowTilesPerWarp;

    for (int tile_row = first_tile_row; tile_row < first_tile_row + kRowTilesPerWarp; tile_row++) {
        for (int tile_col = 0; tile_col < kK / kTile; tile_col++) {
            wmma::fill_fragment(acc_frag, 0.0f);

            for (int tile_k = 0; tile_k < kD / kTile; tile_k++) {
                // A = xs[tile_row*16 : +16, tile_k*16 : +16] (row-major, ldm = D)
                wmma::load_matrix_sync(a_frag, xs + tile_row * kTile * kD + tile_k * kTile, kD);
                // B = centroids[tile_col*16 : +16, tile_k*16 : +16]^T: reading the
                // row-major centroid rows as column-major yields the transpose.
                wmma::load_matrix_sync(b_frag, centroids + tile_col * kTile * kD + tile_k * kTile, kD);
                wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
            }

            wmma::store_matrix_sync(scratch, acc_frag, kTile, wmma::mem_row_major);
            // Each lane reads elements stored by other lanes: order the
            // shared-memory store before the reads.
            __syncwarp();

            // Lanes 0-15 handle even rows, lanes 16-31 odd rows; lane % 16 picks the column.
            const int col_idx = lane_id % kTile;
            const int out_col = tile_col * kTile + col_idx;
            const float norm = centroid_sq_norms[out_col];
            for (int row_idx = lane_id / kTile; row_idx < kTile; row_idx += 2) {
                const int out_row = tile_row * kTile + row_idx;
                // Combine in fp32 and round to bf16 once.
                costs[out_row * kK + out_col] =
                    __float2bfloat16(norm - 2.0f * scratch[row_idx * kTile + col_idx]);
            }
            // All lanes must finish reading before the next store overwrites scratch.
            __syncwarp();
        }
    }
}

// Unbatched bf16 similarity: similarity = xs @ centroids^T with fp32 accumulation,
// for a single [512, 64] x [64, 64] problem.
// Launch layout: one block of 128 threads (4 warps); warp w produces output
// rows [w*128, w*128 + 128) as 8 consecutive 16-row tiles.
// No FFI handler in this file launches this kernel.
__global__ void _similarity_kernel(
    const __nv_bfloat16* xs,        // [512, 64]
    const __nv_bfloat16* centroids, // [64, 64]
    float* similarity       // [512, 64] output
) {
    constexpr int kTile = 16;          // WMMA m16n16k16
    constexpr int kPoints = 512;       // rows of xs / similarity
    constexpr int kDim = 64;           // feature dimension (ldm of xs and centroids)
    constexpr int kCentroids = 64;     // columns of similarity (its ldm)
    constexpr int kTilesPerWarp = 8;

    wmma::fragment<wmma::matrix_a, kTile, kTile, kTile, __nv_bfloat16, wmma::row_major> x_tile;
    wmma::fragment<wmma::matrix_b, kTile, kTile, kTile, __nv_bfloat16, wmma::col_major> c_tile;
    wmma::fragment<wmma::accumulator, kTile, kTile, kTile, float> sum_tile;

    const int warp = threadIdx.x / 32;
    const int first_row_tile = warp * kTilesPerWarp;

    int row_tile = first_row_tile;
    while (row_tile < first_row_tile + kTilesPerWarp && row_tile * kTile < kPoints) {
        const __nv_bfloat16* x_rows = xs + row_tile * kTile * kDim;
        float* out_rows = similarity + row_tile * kTile * kCentroids;

        for (int col_tile = 0; col_tile < kCentroids / kTile; ++col_tile) {
            // Centroid rows read column-major act as the transposed B tile.
            const __nv_bfloat16* c_rows = centroids + col_tile * kTile * kDim;

            wmma::fill_fragment(sum_tile, 0.0f);
            for (int offset = 0; offset < kDim; offset += kTile) {
                wmma::load_matrix_sync(x_tile, x_rows + offset, kDim);
                wmma::load_matrix_sync(c_tile, c_rows + offset, kDim);
                wmma::mma_sync(sum_tile, x_tile, c_tile, sum_tile);
            }
            wmma::store_matrix_sync(out_rows + col_tile * kTile, sum_tile, kCentroids, wmma::mem_row_major);
        }
        ++row_tile;
    }
}

// Batched fp16 similarity: similarity[b] = xs[b] @ centroids[b]^T.
// Launch: gridDim.x = batch size (one block per batch element); blockDim.x must
// be a non-zero multiple of 32. Warps stride over 16-row tiles of xs.
// Requires half-precision WMMA (SM70+).
// NOTE: the accumulator fragment is __half, so each D-term dot product is
// accumulated in fp16 rather than fp32; large inputs may lose precision or
// overflow.
template<int N, int K, int D>
__global__ void similarity_kernel(
    const __half* xs,        // [B, 512, 64] [B, N, D]
    const __half* centroids, // [B, 64, 64] [B, K, D]
    __half* similarity       // [B, 512, 64] output [B, N, K]
) {
    constexpr int WMMA_M = 16;  // rows of xs per tile
    constexpr int WMMA_N = 16;  // centroids per tile
    constexpr int WMMA_K = 16;  // reduction depth per mma
    // The tile loops have no tail handling.
    static_assert(N % WMMA_M == 0, "N must be a multiple of 16");
    static_assert(K % WMMA_N == 0, "K must be a multiple of 16");
    static_assert(D % WMMA_K == 0, "D must be a multiple of 16");

    const int warp_id = threadIdx.x / 32;
    const int warpcount = blockDim.x / 32;  // Number of warps in the block
    // 64-bit batch offsets: batch_idx * N * D overflows int for large batches.
    const size_t batch_idx = blockIdx.x;    // One block per batch element

    const __half* my_xs = xs + batch_idx * N * D;
    const __half* my_centroids = centroids + batch_idx * K * D;
    __half* my_similarity = similarity + batch_idx * N * K;

    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, __half, wmma::row_major> xs_frag;
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, __half, wmma::col_major> cs_frag;
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, __half> acc_frag;

    for (int tile_n = warp_id * WMMA_M; tile_n < N; tile_n += warpcount * WMMA_M) {
        for (int tile_k = 0; tile_k < K; tile_k += WMMA_N) {
            wmma::fill_fragment(acc_frag, 0.0f);
            for (int tile_d = 0; tile_d < D; tile_d += WMMA_K) {
                wmma::load_matrix_sync(xs_frag, my_xs + tile_n * D + tile_d, D);
                // Row-major centroid rows read as column-major give centroids^T.
                wmma::load_matrix_sync(cs_frag, my_centroids + tile_k * D + tile_d, D);
                wmma::mma_sync(acc_frag, xs_frag, cs_frag, acc_frag);
            }
            wmma::store_matrix_sync(my_similarity + tile_n * K + tile_k, acc_frag, K, wmma::mem_row_major);
        }
    }
}

// Debug helper: every lane of the calling warp prints its value of `x`, after a
// "Warp [id]:" prefix from lane 0 and followed by a newline from lane 0.
// Lanes take turns in lane order inside a divergent loop. Without __syncwarp
// this relies on the warp staying converged; on Volta+ (independent thread
// scheduling) the interleaving is not guaranteed. Output from different warps
// may also interleave.
__device__ void warp_print(float x) {
    const int lane = threadIdx.x % 32;
    const int warp = threadIdx.x / 32;

    if (lane == 0) printf("Warp [%d]:", warp);

    for (int turn = 0; turn < 32; ++turn) {
        if (turn != lane) continue;
        printf(" %.2f", x);
    }

    if (lane == 0) printf("\n");
}

// K-means (Lloyd iterations) for one batch element per block, held entirely in
// shared memory.
//
// Per batch element:
//   xs_ptr     [N, D]  points (read only)
//   totals_ptr [K, D]  in: per-cluster coordinate sums that define the initial
//                      centroids; out: sums for the final assignment
//   counts_ptr [K]     in: per-cluster sizes for the initial centroids;
//                      out: sizes for the final assignment
//   labels_ptr [N]     out: nearest-centroid index of each point
//
// Each iteration:
//   1. centroid_k = totals_k / counts_k (rounded to fp16 for WMMA);
//   2. assign each point to argmin_k ||c_k||^2 - 2 <x, c_k> (ties -> lowest k);
//   3. re-accumulate totals/counts from that assignment.
// A cluster with counts_k <= 0 has no centroid: its cost is +inf, so it gets
// no points, instead of a NaN centroid.
//
// Reductions: centroid norms use warp shuffles (one centroid row per warp);
// counts and totals use shared-memory atomicAdd.
//
// Launch requirements:
//   * gridDim.x = batch size; blockDim.x a non-zero multiple of 32.
//   * dynamic shared memory = N*(D+8)*sizeof(__half)                 (xs)
//                           + (blockDim.x/32)*8*(K+4)*sizeof(float)  (per-warp sim)
//   * half-precision WMMA m8n32k16 (SM70+).
//
// TODO(perf): consider vectorized half2 loads/stores throughout.
// TODO(perf): re-check the Dpad/Kpad padding against bank-conflict profiles.
template<int N, int K, int D, int iters>
__global__ void kmeans_kernel(
    const __half* xs_ptr,        // [B, 512, 64] [B, N, D]
    __half* totals_ptr, // [B, 64, 64] [B, K, D]
    int32_t* counts_ptr,        // [B, 64] [B, K]
    int32_t* labels_ptr // [B, 512] [B, N]
) {
    // Tile sizes for the m8n32k16 WMMA shape.
    constexpr int TN = 8;   // points per tile (M)
    constexpr int TK = 32;  // centroids per tile (N)
    constexpr int TD = 16;  // reduction depth per mma (K)
    static_assert(N % TN == 0, "N must be a multiple of 8");
    static_assert(K % TK == 0, "K must be a multiple of 32 (argmin also reads one centroid per lane)");
    static_assert(D % TD == 0, "D must be a multiple of 16");
    static_assert(iters >= 1, "labels are only produced by an iteration");

    // Row padding to spread shared-memory banks. WMMA ldm must stay a multiple
    // of 8 halves (Dpad) / 4 floats (Kpad).
    constexpr int Dpad = D + 8;
    constexpr int Kpad = K + 4;

    const int tid = threadIdx.x;
    const int tcnt = blockDim.x;
    constexpr int lcnt = 32;
    const int lid = tid % lcnt;
    const int wid = tid / lcnt;
    const int wcnt = tcnt / lcnt;

    // 64-bit batch offsets: batch_idx * N * D overflows int for large batches.
    const size_t batch_idx = blockIdx.x;  // One block per batch element
    xs_ptr += batch_idx * N * D;
    totals_ptr += batch_idx * K * D;
    counts_ptr += batch_idx * K;
    labels_ptr += batch_idx * N;

    // WMMA load/store pointers must be 256-bit (32-byte) aligned.
    __shared__ __align__(32) __half centroids[K][Dpad];
    __shared__ float totals[K][Dpad];
    __shared__ float centroid_sq_norms[K];
    __shared__ int32_t counts[K];
    __shared__ int32_t labels[N];

    // Dynamic shared memory: xs[N][Dpad] followed by sim[wcnt][TN][Kpad].
    // xs_size is a multiple of 32 bytes because N % 8 == 0 and D % 16 == 0.
    extern __shared__ __align__(32) char shared_memory[];
    __half (*xs)[Dpad] = reinterpret_cast<__half (*)[Dpad]>(shared_memory);
    const size_t xs_size = N * Dpad * sizeof(__half);
    float (*sim)[TN][Kpad] = reinterpret_cast<float (*)[TN][Kpad]>(shared_memory + xs_size);

    wmma::fragment<wmma::matrix_a, TN, TK, TD, __half, wmma::row_major> xs_frag;
    wmma::fragment<wmma::matrix_b, TN, TK, TD, __half, wmma::col_major> cs_frag;
    wmma::fragment<wmma::accumulator, TN, TK, TD, float> acc_frag;

    // Load points, initial counts and totals.
    for (int i = tid; i < N * D; i += tcnt) {
        xs[i / D][i % D] = xs_ptr[i];
    }
    for (int i = tid; i < K; i += tcnt) {
        counts[i] = counts_ptr[i];
    }
    for (int i = tid; i < K * D; i += tcnt) {
        totals[i / D][i % D] = __half2float(totals_ptr[i]);
    }
    __syncthreads();

    const float kInf = __int_as_float(0x7f800000);  // +infinity

    for (int iter = 0; iter < iters; iter++) {
        // (1) Centroids and squared norms. One centroid row per warp, so the
        // loop is warp-uniform and the shuffle reduction sees a single k.
        for (int k = wid; k < K; k += wcnt) {
            const int cnt = counts[k];
            float sq = 0.0f;
            for (int d = lid; d < D; d += lcnt) {
                const float c = cnt > 0 ? totals[k][d] / float(cnt) : 0.0f;
                const __half c_h = __float2half(c);
                centroids[k][d] = c_h;
                totals[k][d] = 0.0f;  // reset for this iteration's accumulation
                // Norm of the rounded centroid that WMMA actually multiplies with.
                const float c_r = __half2float(c_h);
                sq += c_r * c_r;
            }
            for (int stride = lcnt / 2; stride > 0; stride >>= 1) {
                sq += __shfl_xor_sync(0xffffffff, sq, stride);
            }
            if (lid == 0) {
                centroid_sq_norms[k] = cnt > 0 ? sq : kInf;  // empty cluster: never chosen
            }
        }
        __syncthreads();  // all counts[] reads done before the reset

        for (int i = tid; i < K; i += tcnt) {
            counts[i] = 0;
        }
        __syncthreads();  // centroids, norms, zeroed totals/counts visible to all

        // (2)+(3) Assign TN-point tiles per warp and accumulate.
        for (int n = wid * TN; n < N; n += wcnt * TN) {
            // Dot products <x, c> for this tile into the warp's sim scratch.
            for (int k = 0; k < K; k += TK) {
                wmma::fill_fragment(acc_frag, 0.0f);
                for (int d = 0; d < D; d += TD) {
                    wmma::load_matrix_sync(xs_frag, &xs[n][d], Dpad);
                    // Row-major centroid rows read as column-major give centroids^T.
                    wmma::load_matrix_sync(cs_frag, &centroids[k][d], Dpad);
                    wmma::mma_sync(acc_frag, xs_frag, cs_frag, acc_frag);
                }
                wmma::store_matrix_sync(&sim[wid][0][k], acc_frag, Kpad, wmma::mem_row_major);
            }
            __syncwarp();  // stored tile visible to every lane before cross-lane reads

            for (int p = 0; p < TN; p++) {
                // Per-lane argmin over k = lid, lid + 32, ...
                float min_val = centroid_sq_norms[lid] - 2.0f * sim[wid][p][lid];
                int32_t min_idx = lid;
                for (int i = lid + lcnt; i < K; i += lcnt) {
                    const float other_val = centroid_sq_norms[i] - 2.0f * sim[wid][p][i];
                    if (other_val < min_val) {  // i > min_idx, so ties keep the lower index
                        min_val = other_val;
                        min_idx = i;
                    }
                }
                // Warp argmin with (value, index) ordering.
                for (int stride = 1; stride < lcnt; stride <<= 1) {
                    const float other_val = __shfl_xor_sync(0xffffffff, min_val, stride);
                    const int other_idx = __shfl_xor_sync(0xffffffff, min_idx, stride);
                    const bool update = (other_val < min_val || (other_val == min_val && other_idx < min_idx));
                    min_val = update ? other_val : min_val;
                    min_idx = update ? other_idx : min_idx;
                }
                // Use lane 0's winner everywhere so labels/counts and totals always
                // agree, even if NaN inputs made the butterfly results lane-dependent.
                min_idx = __shfl_sync(0xffffffff, min_idx, 0);

                if (lid == 0) {
                    labels[n + p] = min_idx;
                    atomicAdd(&counts[min_idx], 1);
                }
                for (int i = lid; i < D; i += lcnt) {
                    atomicAdd(&totals[min_idx][i], __half2float(xs[n + p][i]));
                }
            }
            __syncwarp();  // all sim reads done before the next tile overwrites it
        }
        __syncthreads();  // totals and counts complete for the next iteration
    }

    // Write out labels, totals and counts of the final assignment.
    for (int i = tid; i < N; i += tcnt) {
        labels_ptr[i] = labels[i];
    }
    for (int i = tid; i < K * D; i += tcnt) {
        totals_ptr[i] = __float2half(totals[i / D][i % D]);
    }
    for (int i = tid; i < K; i += tcnt) {
        counts_ptr[i] = counts[i];
    }
}



// FFI wrapper function
// XLA FFI handler for costs_kernel.
//
// Checks the buffer shapes:
//   xs     [..., 512, 64]
//   totals [..., 64, 64]
//   counts [..., 64]
//   costs  [..., 512, 64]
// It also checks that all four hold the same number of batch elements (the
// product of the leading dimensions). It then launches one 128-thread block
// per batch element on `stream`.
//
// Only launch errors are reported here; faults during kernel execution surface
// at a later synchronizing call.
ffi::Error CostsImpl(cudaStream_t stream,
                     ffi::Buffer<ffi::BF16> xs,
                     ffi::Buffer<ffi::BF16> totals,
                     ffi::Buffer<ffi::S32> counts,
                     ffi::ResultBuffer<ffi::BF16> costs) {
    // Product of all dimensions except the last `trailing` ones.
    auto leading_size = [](auto dims, size_t trailing) {
        int64_t n = 1;
        for (size_t i = 0; i + trailing < dims.size(); i++) {
            n *= dims[i];
        }
        return n;
    };

    // Validate input dimensions
    auto ndims_xs = xs.dimensions().size();
    if (ndims_xs < 2 ||
        xs.dimensions()[ndims_xs - 2] != 512 ||
        xs.dimensions()[ndims_xs - 1] != 64) {
        return ffi::Error::InvalidArgument("xs must have shape [..., 512, 64]");
    }

    auto ndims_totals = totals.dimensions().size();
    if (ndims_totals < 2 ||
        totals.dimensions()[ndims_totals - 2] != 64 ||
        totals.dimensions()[ndims_totals - 1] != 64) {
        return ffi::Error::InvalidArgument("totals must have shape [..., 64, 64]");
    }

    auto ndims_counts = counts.dimensions().size();
    if (ndims_counts < 1 || counts.dimensions()[ndims_counts - 1] != 64) {
        return ffi::Error::InvalidArgument("counts must have shape [..., 64]");
    }

    auto ndims_costs = costs->dimensions().size();
    if (ndims_costs < 2 ||
        costs->dimensions()[ndims_costs - 2] != 512 ||
        costs->dimensions()[ndims_costs - 1] != 64) {
        return ffi::Error::InvalidArgument("costs must have shape [..., 512, 64]");
    }

    // The kernel offsets every buffer by blockIdx.x, so all buffers must have
    // the batch size derived from xs. Otherwise it reads/writes out of bounds.
    const int64_t batch_size = leading_size(xs.dimensions(), 2);
    if (leading_size(totals.dimensions(), 2) != batch_size ||
        leading_size(counts.dimensions(), 1) != batch_size ||
        leading_size(costs->dimensions(), 2) != batch_size) {
        return ffi::Error::InvalidArgument(
            "xs, totals, counts and costs must have the same batch size");
    }
    // A zero-block launch is an invalid configuration; there is nothing to do.
    if (batch_size == 0) {
        return ffi::Error::Success();
    }

    // Launch kernel with 4 warps (128 threads) per batch element
    const int num_warps = 4;
    costs_kernel<<<static_cast<unsigned int>(batch_size), num_warps * 32, 0, stream>>>(
        reinterpret_cast<const __nv_bfloat16*>(xs.typed_data()),
        reinterpret_cast<const __nv_bfloat16*>(totals.typed_data()),
        reinterpret_cast<const int32_t*>(counts.typed_data()),
        reinterpret_cast<__nv_bfloat16*>(costs->typed_data())
    );

    // Check for CUDA launch errors
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        return ffi::Error::Internal(
            std::string("CUDA error: ") + cudaGetErrorString(err));
    }

    return ffi::Error::Success();
}

// FFI wrapper function for similarity kernel
// XLA FFI handler for similarity_kernel<512, 64, 64>.
//
// Checks the buffer shapes:
//   xs         [..., 512, 64]
//   centroids  [..., 64, 64]
//   similarity [..., 512, 64]
// It also checks that all three hold the same number of batch elements. It
// then launches one 128-thread block per batch element on `stream`.
//
// Only launch errors are reported here.
ffi::Error SimilarityImpl(cudaStream_t stream,
                          ffi::Buffer<ffi::F16> xs,
                          ffi::Buffer<ffi::F16> centroids,
                          ffi::ResultBuffer<ffi::F16> similarity) {
    constexpr int N = 512; // Number of data points
    constexpr int K = 64;  // Number of centroids
    constexpr int D = 64;  // Dimension of each data point and centroid

    // Product of all dimensions except the last `trailing` ones.
    auto leading_size = [](auto dims, size_t trailing) {
        int64_t n = 1;
        for (size_t i = 0; i + trailing < dims.size(); i++) {
            n *= dims[i];
        }
        return n;
    };

    // Validate input dimensions
    auto ndims_xs = xs.dimensions().size();
    if (ndims_xs < 2 ||
        xs.dimensions()[ndims_xs - 2] != N ||
        xs.dimensions()[ndims_xs - 1] != D) {
        return ffi::Error::InvalidArgument("xs must have shape [..., 512, 64]");
    }

    auto ndims_centroids = centroids.dimensions().size();
    if (ndims_centroids < 2 ||
        centroids.dimensions()[ndims_centroids - 2] != K ||
        centroids.dimensions()[ndims_centroids - 1] != D) {
        return ffi::Error::InvalidArgument("centroids must have shape [..., 64, 64]");
    }

    auto ndims_similarity = similarity->dimensions().size();
    if (ndims_similarity < 2 ||
        similarity->dimensions()[ndims_similarity - 2] != N ||
        similarity->dimensions()[ndims_similarity - 1] != K) {
        return ffi::Error::InvalidArgument("similarity must have shape [..., 512, 64]");
    }

    // All buffers are offset by blockIdx.x, so their batch sizes must agree.
    const int64_t batch_size = leading_size(xs.dimensions(), 2);
    if (leading_size(centroids.dimensions(), 2) != batch_size ||
        leading_size(similarity->dimensions(), 2) != batch_size) {
        return ffi::Error::InvalidArgument(
            "xs, centroids and similarity must have the same batch size");
    }
    // A zero-block launch is an invalid configuration; there is nothing to do.
    if (batch_size == 0) {
        return ffi::Error::Success();
    }

    // Launch kernel with 4 warps (128 threads) per batch element
    const int num_warps = 4;
    similarity_kernel<N, K, D><<<static_cast<unsigned int>(batch_size), num_warps * 32, 0, stream>>>(
        reinterpret_cast<const __half*>(xs.typed_data()),
        reinterpret_cast<const __half*>(centroids.typed_data()),
        reinterpret_cast<__half*>(similarity->typed_data())
    );

    // Check for CUDA launch errors
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        return ffi::Error::Internal(
            std::string("CUDA error: ") + cudaGetErrorString(err));
    }

    return ffi::Error::Success();
}

// FFI wrapper function for kmeans kernel
// XLA FFI handler for kmeans_kernel<512, 64, 64, 100>.
//
// Checks the buffer shapes:
//   xs     [..., 512, 64]
//   totals [..., 64, 64]
//   counts [..., 64]
//   labels [..., 512]
// It also checks that all four hold the same number of batch elements. It then
// launches one 512-thread block (16 warps) per batch element on `stream`, with
// the dynamic shared memory the kernel's layout needs.
//
// NOTE(review): kmeans_kernel writes the final totals and counts back through
// the `totals` and `counts` pointers, although both are bound as Args (inputs).
// This is only safe if the caller aliases them to outputs or otherwise never
// reuses them — confirm on the caller side.
ffi::Error KmeansImpl(cudaStream_t stream,
                      ffi::Buffer<ffi::F16> xs,
                      ffi::Buffer<ffi::F16> totals,
                      ffi::Buffer<ffi::S32> counts,
                      ffi::ResultBuffer<ffi::S32> labels) {
    constexpr int N = 512; // Number of data points
    constexpr int K = 64;  // Number of centroids
    constexpr int D = 64;  // Dimension of each data point and centroid
    constexpr int iters = 100; // Number of iterations

    // Product of all dimensions except the last `trailing` ones.
    auto leading_size = [](auto dims, size_t trailing) {
        int64_t n = 1;
        for (size_t i = 0; i + trailing < dims.size(); i++) {
            n *= dims[i];
        }
        return n;
    };

    // Validate input dimensions
    auto ndims_xs = xs.dimensions().size();
    if (ndims_xs < 2 ||
        xs.dimensions()[ndims_xs - 2] != N ||
        xs.dimensions()[ndims_xs - 1] != D) {
        return ffi::Error::InvalidArgument("xs must have shape [..., 512, 64]");
    }

    auto ndims_totals = totals.dimensions().size();
    if (ndims_totals < 2 ||
        totals.dimensions()[ndims_totals - 2] != K ||
        totals.dimensions()[ndims_totals - 1] != D) {
        return ffi::Error::InvalidArgument("totals must have shape [..., 64, 64]");
    }

    auto ndims_counts = counts.dimensions().size();
    if (ndims_counts < 1 || counts.dimensions()[ndims_counts - 1] != K) {
        return ffi::Error::InvalidArgument("counts must have shape [..., 64]");
    }

    auto ndims_labels = labels->dimensions().size();
    if (ndims_labels < 1 || labels->dimensions()[ndims_labels - 1] != N) {
        return ffi::Error::InvalidArgument("labels must have shape [..., 512]");
    }

    // All buffers are offset by blockIdx.x, so their batch sizes must agree.
    const int64_t batch_size = leading_size(xs.dimensions(), 2);
    if (leading_size(totals.dimensions(), 2) != batch_size ||
        leading_size(counts.dimensions(), 1) != batch_size ||
        leading_size(labels->dimensions(), 1) != batch_size) {
        return ffi::Error::InvalidArgument(
            "xs, totals, counts and labels must have the same batch size");
    }
    // A zero-block launch is an invalid configuration; there is nothing to do.
    if (batch_size == 0) {
        return ffi::Error::Success();
    }

    // Launch kernel with 16 warps (512 threads) per batch element
    const int num_warps = 16;
    // Must match kmeans_kernel's dynamic layout: xs[N][D + 8] halves followed by
    // one [8][K + 4] float scratch tile per warp (its TN = 8 and Kpad = K + 4).
    // For N = 512, K = D = 64 and 16 warps this is 108544 bytes (106 KiB).
    const size_t dynamic_shared_memory_size =
        N * (D + 8) * sizeof(__half) + num_warps * 8 * (K + 4) * sizeof(float);
    // More than 48 KiB of dynamic shared memory requires an explicit opt-in.
    cudaError_t attr_err = cudaFuncSetAttribute(
        kmeans_kernel<N, K, D, iters>,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        static_cast<int>(dynamic_shared_memory_size)
    );
    if (attr_err != cudaSuccess) {
        return ffi::Error::Internal(
            std::string("CUDA error (cudaFuncSetAttribute): ") + cudaGetErrorString(attr_err));
    }

    kmeans_kernel<N, K, D, iters><<<static_cast<unsigned int>(batch_size), num_warps * 32,
                                    dynamic_shared_memory_size, stream>>>(
        reinterpret_cast<const __half*>(xs.typed_data()),
        reinterpret_cast<__half*>(totals.typed_data()),
        reinterpret_cast<int32_t*>(counts.typed_data()),
        reinterpret_cast<int32_t*>(labels->typed_data())
    );

    // Check for CUDA launch errors
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        return ffi::Error::Internal(
            std::string("CUDA error: ") + cudaGetErrorString(err));
    }

    return ffi::Error::Success();
}

// Register the FFI handler
// KmeansCosts -> CostsImpl. The Ctx/Arg/Ret entries are listed in the same
// order as CostsImpl's parameters. Leading "..." dimensions are batch dims.
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    KmeansCosts, CostsImpl,
    ffi::Ffi::Bind()
        .Ctx<ffi::PlatformStream<cudaStream_t>>()  // CUDA stream context
        .Arg<ffi::Buffer<ffi::BF16>>()             // xs [..., 512, 64]
        .Arg<ffi::Buffer<ffi::BF16>>()             // totals [..., 64, 64]
        .Arg<ffi::Buffer<ffi::S32>>()              // counts [..., 64]
        .Ret<ffi::Buffer<ffi::BF16>>()             // costs [..., 512, 64]
);

// KmeansSimilarity -> SimilarityImpl (fp16 xs @ centroids^T).
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    KmeansSimilarity, SimilarityImpl,
    ffi::Ffi::Bind()
        .Ctx<ffi::PlatformStream<cudaStream_t>>()  // CUDA stream context
        .Arg<ffi::Buffer<ffi::F16>>()             // xs [..., 512, 64]
        .Arg<ffi::Buffer<ffi::F16>>()             // centroids [..., 64, 64]
        .Ret<ffi::Buffer<ffi::F16>>()             // similarity [..., 512, 64]
);

// KmeansKernel -> KmeansImpl (full k-means loop on device).
// NOTE(review): KmeansImpl passes the totals and counts argument buffers to
// kmeans_kernel, which writes its final totals/counts into them. Confirm the
// caller aliases these inputs to outputs (or never reuses them).
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    KmeansKernel, KmeansImpl,
    ffi::Ffi::Bind()
        .Ctx<ffi::PlatformStream<cudaStream_t>>()  // CUDA stream context
        .Arg<ffi::Buffer<ffi::F16>>()             // xs [..., 512, 64]
        .Arg<ffi::Buffer<ffi::F16>>()             // totals [..., 64, 64] (written by the kernel)
        .Arg<ffi::Buffer<ffi::S32>>()             // counts [..., 64] (written by the kernel)
        .Ret<ffi::Buffer<ffi::S32>>()             // labels [..., 512]
);
