#include "xla/ffi/api/ffi.h"
#include <cuda_runtime.h>
#include <math.h>
#include <cooperative_groups.h>
//namespace ffi = xla::ffi;
//using namespace nvcuda;
using namespace xla::ffi;
namespace cg = cooperative_groups;

// Scans the K labels for point i and returns the label with the smallest
// cost among labels whose counter is still below cap. If every such label has
// a cost that never compares below +inf (+inf or NaN), the first label below
// cap is returned instead. Returns -1 when every label is at capacity.
// Counters are read through a volatile pointer so that each retry observes
// the values other threads have written since the previous scan.
template <int K>
__device__ __forceinline__ int assign_pick_label(
    const float* costs_ptr,           // this batch element's [K, N] costs
    const volatile int32_t* counts,   // per-label claimed-slot counters
    int64_t n,
    int32_t cap,
    int64_t i
) {
    float min_cost = INFINITY;
    int min_idx = -1;
    int first_free = -1;
    for (int k = 0; k < K; ++k) {
        if (counts[k] >= cap) continue; // label already full
        if (first_free < 0) first_free = k;
        const float cost = costs_ptr[k * n + i];
        if (cost < min_cost) {
            min_cost = cost;
            min_idx = k;
        }
    }
    return min_idx >= 0 ? min_idx : first_free;
}

// Capacity-constrained greedy assignment of N points to K labels. Each thread
// block handles one batch element (blockIdx.x). Any block size works; the
// host wrapper launches num_warps * 32 threads. Uses a static __shared__
// array of K int32 counters and no dynamic shared memory.
//
// Each point (strided over the block's threads) picks the cheapest label that
// is not yet full and claims a slot with an atomicAdd on that label's shared
// counter. If the claim overshoots CAP because other threads filled the label
// first, the increment is rolled back and the search is repeated.
//
// A counter that has reached CAP never drops below CAP again, because only
// overshooting increments are rolled back. So every retry permanently rules
// out at least one label, and seeing all labels full means no slot will ever
// become free. Such points are marked unassigned instead of spinning, so the
// kernel terminates even when N > K * CAP.
//
// Per batch element outputs:
//   cnt[k]    number of points assigned to label k (<= CAP)
//   lab[i]    label of point i (0 for a point that could not be placed)
//   fwd[k, s] index of the point in slot s of label k, -1 if the slot is empty
//   bwd[i]    slot of point i within its label (0 for an unplaced point)
// Which point wins a contended slot depends on atomic ordering, so results
// may vary between runs once labels saturate.
template<int K>
__global__ void assign_kernel(
    const float* costs_ptr, // [B, K, N]
    int32_t* cnt_ptr, // [B, K]
    int32_t* lab_ptr, // [B, N]
    int32_t* fwd_ptr, // [B, K, CAP]
    int32_t* bwd_ptr, // [B, N]
    int32_t N,
    int32_t CAP
) {
    const int tid = threadIdx.x;
    const int tcnt = blockDim.x;

    // 64-bit offsets: batch_idx * K * N (and K * CAP) can exceed INT32_MAX.
    const int64_t batch_idx = blockIdx.x;
    const int64_t n = N;
    const int64_t slots = int64_t(K) * CAP;
    costs_ptr += batch_idx * K * n;
    cnt_ptr += batch_idx * K;
    lab_ptr += batch_idx * n;
    fwd_ptr += batch_idx * slots;
    bwd_ptr += batch_idx * n;

    __shared__ int32_t counts[K];
    for (int k = tid; k < K; k += tcnt) {
        counts[k] = 0;
    }
    for (int64_t s = tid; s < slots; s += tcnt) {
        fwd_ptr[s] = -1; // every slot starts empty
    }
    // All counters must be zeroed before any thread starts claiming slots.
    __syncthreads();

    for (int64_t i = tid; i < n; i += tcnt) {
        int label = assign_pick_label<K>(costs_ptr, counts, n, CAP, i);
        int slot = -1;
        while (label >= 0) {
            slot = atomicAdd(&counts[label], 1);
            if (slot < CAP) break; // slot successfully claimed
            // Overshoot: another thread filled this label first. Undo and
            // rescan; this label now reads >= CAP and is skipped.
            atomicAdd(&counts[label], -1);
            label = assign_pick_label<K>(costs_ptr, counts, n, CAP, i);
        }

        if (label >= 0) {
            lab_ptr[i] = label;
            fwd_ptr[int64_t(label) * CAP + slot] = static_cast<int32_t>(i);
            bwd_ptr[i] = slot;
        } else {
            // Every label is full: leave fwd untouched and give an arbitrary
            // in-range label/slot so downstream gathers stay in bounds.
            lab_ptr[i] = 0;
            bwd_ptr[i] = 0;
        }
    }

    // Wait for all claims (and rollbacks) before publishing the final counts.
    __syncthreads();
    for (int k = tid; k < K; k += tcnt) {
        cnt_ptr[k] = counts[k];
    }
}

// Launches assign_kernel<K> on `stream` with one block per batch element and
// num_warps * 32 threads per block. costs is laid out as [..., K, N]: every
// leading dimension is treated as a batch dimension.
//
// Returns InvalidArgument for rank < 2 or a batch too large for gridDim.x.
// Returns Internal if the launch is rejected. Only launch-time errors are
// visible here; faults during kernel execution surface at a later
// synchronizing call. An empty batch is a no-op that returns Success.
template <int K>
Error assign_host(
    cudaStream_t stream,
    Buffer<F32> costs,
    ResultBuffer<S32> cnt,
    ResultBuffer<S32> lab,
    ResultBuffer<S32> fwd,
    ResultBuffer<S32> bwd,
    int64_t N,
    int64_t CAP,
    int64_t num_warps
) {
    auto dims = costs.dimensions();
    // Check before any `dims.size() - 2`: size() is unsigned and would wrap.
    if (dims.size() < 2) return Error::InvalidArgument("costs must be rank 2 plus batch dims");

    int64_t batch_size = 1;
    for (size_t i = 0; i + 2 < dims.size(); ++i) {
        batch_size *= dims[i];
    }
    // Nothing to compute. A grid of 0 blocks would also be rejected as an
    // invalid launch configuration.
    if (batch_size == 0) return Error::Success();
    constexpr int64_t kMaxGridX = 2147483647; // gridDim.x upper limit
    if (batch_size > kMaxGridX) {
        return Error::InvalidArgument("too many batch elements for a single launch");
    }

    dim3 block(static_cast<unsigned int>(num_warps * 32));
    dim3 grid(static_cast<unsigned int>(batch_size));
    assign_kernel<K><<<grid, block, 0, stream>>>(
        costs.typed_data(),
        cnt->typed_data(),
        lab->typed_data(),
        fwd->typed_data(),
        bwd->typed_data(),
        int(N),
        int(CAP)
    );
    cudaError_t last_err = cudaGetLastError();
    if (last_err != cudaSuccess) return Error::Internal(std::string("CUDA error: ") + cudaGetErrorString(last_err));
    return Error::Success();
}

// FFI implementation: validates inputs, then dispatches to the assign_host
// instantiation that matches the label count K = costs.dimensions()[rank - 2].
//
// Validation, all before any launch:
//   * costs must have rank >= 2;
//   * the N attribute must equal the last dimension of costs (the kernel uses
//     N as the row stride, so a mismatch would read out of bounds);
//   * N and CAP must fit in int32, because the kernel takes int32 sizes;
//   * num_warps must be in [1, 32] (at most 1024 threads per block);
//   * K must be one of the compiled sizes listed below.
// Each violation returns InvalidArgument.
Error assign_host_multiK(
    cudaStream_t stream,
    Buffer<F32> costs,
    ResultBuffer<S32> cnt,
    ResultBuffer<S32> lab,
    ResultBuffer<S32> fwd,
    ResultBuffer<S32> bwd,
    int64_t N,
    int64_t CAP,
    int64_t num_warps
) {
    auto dims = costs.dimensions();
    // Compare before subtracting: dims.size() is unsigned.
    if (dims.size() < 2) return Error::InvalidArgument("costs must be rank 2 plus batch dims");
    const size_t label_axis = dims.size() - 2;
    const int64_t K = dims[label_axis];

    if (dims[label_axis + 1] != N) {
        return Error::InvalidArgument("N must equal the last dimension of costs");
    }
    constexpr int64_t kMaxInt32 = 2147483647;
    if (N > kMaxInt32) return Error::InvalidArgument("N must fit in int32");
    if (CAP < 0 || CAP > kMaxInt32) return Error::InvalidArgument("CAP must be in [0, 2^31 - 1]");
    if (num_warps < 1 || num_warps > 32) return Error::InvalidArgument("num_warps must be in [1, 32]");

    switch (K) {
        case 8: return assign_host<8>(stream, costs, cnt, lab, fwd, bwd, N, CAP, num_warps);
        case 16: return assign_host<16>(stream, costs, cnt, lab, fwd, bwd, N, CAP, num_warps);
        case 32: return assign_host<32>(stream, costs, cnt, lab, fwd, bwd, N, CAP, num_warps);
        case 64: return assign_host<64>(stream, costs, cnt, lab, fwd, bwd, N, CAP, num_warps);
        case 128: return assign_host<128>(stream, costs, cnt, lab, fwd, bwd, N, CAP, num_warps);
        case 256: return assign_host<256>(stream, costs, cnt, lab, fwd, bwd, N, CAP, num_warps);
        case 512: return assign_host<512>(stream, costs, cnt, lab, fwd, bwd, N, CAP, num_warps);
        default: return Error::InvalidArgument("K must be 8/16/32/64/128/256/512");
    }
}

// Exports the FFI handler symbol `assign_indices`, implemented by
// assign_host_multiK.
// The binding order below is the calling convention and must stay in sync
// with assign_host_multiK's parameter list:
//   * the platform CUDA stream;
//   * one f32 input (costs);
//   * four s32 results in the order cnt, lab, fwd, bwd;
//   * the int64 attributes "N", "CAP" and "num_warps".
// kCmdBufferCompatible marks the handler as usable inside XLA command buffers
// (CUDA-graph capture). The implementation only enqueues a kernel on the given
// stream and does not synchronize.
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    assign_indices,
    assign_host_multiK,
    Ffi::Bind()
        .Ctx<PlatformStream<cudaStream_t>>() // stream
        .Arg<Buffer<F32>>() // costs
        .Ret<Buffer<S32>>() // cnt
        .Ret<Buffer<S32>>() // lab
        .Ret<Buffer<S32>>() // fwd
        .Ret<Buffer<S32>>() // bwd
        .Attr<int64_t>("N") // N
        .Attr<int64_t>("CAP") // CAP
        .Attr<int64_t>("num_warps"), // num_warps
    {Traits::kCmdBufferCompatible}
);

