// hungarian_kernel.cu

#include <cuda_runtime.h>
#include <device_launch_parameters.h>
#include <stdio.h>

// Include necessary headers
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"

// Upper bound on the matrix dimension N.
// hungarian_matching_kernel sizes its per-thread scratch arrays from this
// constant (three MAX_N * MAX_N float matrices plus several MAX_N-length
// index arrays), and both host entry points in this file reject N > MAX_N
// before launching.
// #define MAX_N 128  // Alternative bound, currently disabled
#define MAX_N 64  // Maximum allowed dimension

// Short alias for the XLA FFI namespace used by the FFI wrapper and handler.
namespace ffi = xla::ffi;

// Device function to sum elements of a float array.
//
// Adds x[0] .. x[N-1] in index order and returns the total.
// Returns 0.0f when N <= 0; x is not dereferenced in that case, so an empty
// range (even a null pointer) is safe to pass.
__device__ float reduce_sumf(const float* x, int N)
{
    // Start from zero instead of x[0] so an empty range never touches memory.
    float sum = 0.0f;
    for (int i = 0; i < N; i++)
    {
        sum += x[i];
    }
    return sum;
}

// Device function that turns a float array into an equality mask, in place:
// each of the first N entries of vec is overwritten with 1.0f if it compared
// equal to val (exact floating-point ==) and with 0.0f otherwise.
__device__ void vec_equalf(float* vec, float val, int N)
{
    int idx = 0;
    while (idx < N)
    {
        const bool is_match = (vec[idx] == val);
        vec[idx] = is_match ? 1.0f : 0.0f;
        ++idx;
    }
}

// Device function: linear membership test on an integer array.
// Returns true when k equals one of the first vec_max entries of vec, and
// false otherwise (including when vec_max <= 0, in which case vec is not read).
__device__ bool k_in_vec(const int* vec, int k, int vec_max)
{
    bool found = false;
    // Stop scanning as soon as a match has been seen.
    for (int pos = 0; pos < vec_max && !found; ++pos)
    {
        found = (vec[pos] == k);
    }
    return found;
}

// Device function that gathers, in ascending order, every integer in [0, N)
// that does not occur among the first vec_max entries of vec.
//   res            receives the collected integers, written from index 0
//   collected_cnt  is reset to 0 and then receives the number of integers
//                  written to res
__device__ void collect_not_in_vec(const int* vec, int vec_max, int N, int* res, int* collected_cnt)
{
    int& count = *collected_cnt;
    count = 0;

    int candidate = 0;
    while (candidate < N)
    {
        if (!k_in_vec(vec, candidate, vec_max))
        {
            res[count] = candidate;
            ++count;
        }
        ++candidate;
    }
}


// Device function to perform greedy matching on the residual unmatched rows and columns.
//
// The first matched_cnt entries of matched_pair_i / matched_pair_j are taken
// as existing (row, column) pairs. Every remaining row, visited in ascending
// order, is paired with the cheapest column that is still free, and the new
// pair is appended right after the existing ones.
//
// Parameters:
//   cost_matrix     [N x N] row-major costs used to rank candidate columns
//   N               matrix dimension; the flag arrays below hold MAX_N
//                   entries, so N must not exceed MAX_N
//   matched_pair_i  [N] in/out: matched rows
//   matched_pair_j  [N] in/out: matched columns; -1 is stored for a row only
//                   when no free column is left at all
//   matched_cnt     number of valid pairs on entry; passed by value, so the
//                   caller does not observe the final count
__device__ void greedy_matching_residual(
    const float* cost_matrix,  // [N x N], the most recent cost matrix
    int N,
    int* matched_pair_i,       // [N], matched rows
    int* matched_pair_j,       // [N], matched columns
    int matched_cnt)           // Number of already matched pairs
{
    // Flags recording which rows and columns are already part of a pair
    bool row_matched[MAX_N] = {false};
    bool col_matched[MAX_N] = {false};

    // Mark already matched rows and columns. Indices outside [0, N) (for
    // example a -1 placeholder) are skipped instead of indexing out of bounds.
    for (int i = 0; i < matched_cnt; i++)
    {
        int row = matched_pair_i[i];
        int col = matched_pair_j[i];
        if (row >= 0 && row < N)
        {
            row_matched[row] = true;
        }
        if (col >= 0 && col < N)
        {
            col_matched[col] = true;
        }
    }

    // For each unmatched row, find the unmatched column with minimum cost.
    // The matched_cnt < N bound keeps every append inside the N-entry
    // output arrays even if the incoming pairs were inconsistent.
    for (int i = 0; i < N && matched_cnt < N; i++)
    {
        if (row_matched[i])
        {
            continue;
        }

        float min_cost = 0.0f;  // only meaningful once min_col != -1
        int min_col = -1;

        // Search over unmatched columns
        for (int j = 0; j < N; j++)
        {
            if (col_matched[j])
            {
                continue;
            }
            float cost = cost_matrix[i * N + j];
            // The first free column is accepted unconditionally, so the
            // choice does not depend on a magic "large cost" sentinel and
            // arbitrarily large costs are still assigned.
            if (min_col == -1 || cost < min_cost)
            {
                min_cost = cost;
                min_col = j;
            }
        }

        // Add to matched pairs; min_col stays -1 only when every column is
        // already taken, which marks the failure for this row.
        matched_pair_i[matched_cnt] = i;
        matched_pair_j[matched_cnt] = min_col;
        matched_cnt++;

        // Mark row and column as matched
        row_matched[i] = true;
        if (min_col != -1)
        {
            col_matched[min_col] = true;
        }
    }
}


// Kernel function implementing the Hungarian matching algorithm for batches.
//
// Launch layout: 1-D grid of 1-D blocks. The thread whose global index is
// blockIdx.x * blockDim.x + threadIdx.x solves the batch element with that
// index entirely on its own; threads with an index >= B exit immediately.
// No shared memory and no inter-thread synchronization are used.
//
// Per-thread storage: three MAX_N * MAX_N float arrays plus several
// MAX_N-length int arrays, all declared as automatic variables
// (about 48 KiB of matrices per thread at MAX_N = 64).
//
// Parameters:
//   cost_matrices       [B x N x N] row-major input costs (read only)
//   matched_pair_i_out  [B x N] output: row index of each matched pair
//   matched_pair_j_out  [B x N] output: column index of each matched pair
//   N                   matrix dimension. 1..MAX_N is solved; N <= 0 writes
//                       nothing; N > MAX_N fills this batch's outputs with -1
//   B                   number of matrices in the batch
//
// Method, per batch element:
//   1. Subtract each row's minimum, then each column's minimum.
//   2. Build a 0/1 mask of exactly-zero entries and greedily pick zeros,
//      always from the row or column that currently has the fewest zeros.
//   3. If N pairs were picked, write them out and finish.
//   4. Otherwise mark rows/columns starting from the unmatched rows, take the
//      smallest entry in a marked row and unmarked column, subtract it from
//      those entries and add it to entries in an unmarked row and marked
//      column, then repeat from step 2.
//   5. After MAX_ITER (10) adjustment rounds, or when step 4 finds no entry to
//      adjust, complete the partial matching with greedy_matching_residual
//      and report the fallback with a device printf.
__global__ void hungarian_matching_kernel(
    const float* cost_matrices,    // [B x N x N]
    int* matched_pair_i_out,       // [B x N]
    int* matched_pair_j_out,       // [B x N]
    int N,
    int B)
{
    int batch_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (batch_idx >= B)
        return;

    // An empty matrix has nothing to match and nothing to write.
    if (N <= 0)
        return;

    // Offset pointers to the current batch. The offsets are computed in
    // size_t so that batch_idx * N * N cannot overflow 32-bit int arithmetic.
    const size_t pair_offset = (size_t)batch_idx * (size_t)N;
    const float* cost_matrix = cost_matrices + pair_offset * (size_t)N;
    int* matched_pair_i = matched_pair_i_out + pair_offset;
    int* matched_pair_j = matched_pair_j_out + pair_offset;

    // The scratch arrays below only hold MAX_N x MAX_N entries. Refuse larger
    // problems instead of overrunning them, and flag the failure with -1.
    if (N > MAX_N)
    {
        for (int k = 0; k < N; k++)
        {
            matched_pair_i[k] = -1;
            matched_pair_j[k] = -1;
        }
        return;
    }

    // Local arrays
    float local_cost_matrix[MAX_N * MAX_N];        // working (reduced) cost matrix
    float zero_mask_matrix[MAX_N * MAX_N];         // becomes 1.0f where the reduced cost is exactly 0
    float zero_mask_matrix_masked[MAX_N * MAX_N];  // zero mask with used rows/columns cleared
    int temp_matched_pair_i[MAX_N];                // rows of the pairs found this round
    int temp_matched_pair_j[MAX_N];                // columns of the pairs found this round
    int non_matched_i_vec[MAX_N];                  // marked rows (starts as the unmatched rows)
    int non_matched_i_cnt = 0;
    int marked_cols[MAX_N];                        // marked columns
    int marked_col_cnt = 0;
    int updated_matched_i_vec[MAX_N];              // rows that are NOT marked
    int updated_matched_i_cnt = 0;

    // Copy cost matrix to local arrays
    for (int idx = 0; idx < N * N; idx++)
    {
        local_cost_matrix[idx] = cost_matrix[idx];
    }

    // Subtract row minima
    for (int row = 0; row < N; row++)
    {
        float row_min = local_cost_matrix[row * N];
        for (int col = 1; col < N; col++)
        {
            float val = local_cost_matrix[row * N + col];
            if (val < row_min)
                row_min = val;
        }
        for (int col = 0; col < N; col++)
        {
            local_cost_matrix[row * N + col] -= row_min;
        }
    }

    // Subtract column minima
    for (int col = 0; col < N; col++)
    {
        float col_min = local_cost_matrix[col];
        for (int row = 1; row < N; row++)
        {
            float val = local_cost_matrix[row * N + col];
            if (val < col_min)
                col_min = val;
        }
        for (int row = 0; row < N; row++)
        {
            local_cost_matrix[row * N + col] -= col_min;
        }
    }

    // Initialize zero_mask_matrix with the adjusted cost matrix
    for (int idx = 0; idx < N * N; idx++)
    {
        zero_mask_matrix[idx] = local_cost_matrix[idx];
    }

    const int MAX_ITER = 10;
    int iter = 0;

    while (true)
    {
        // Mask elements with zero: each row becomes 1.0f where the reduced
        // cost was exactly 0 and 0.0f elsewhere
        for (int j = 0; j < N; j++)
        {
            vec_equalf(zero_mask_matrix + j * N, 0.0f, N);
        }

        // Copy zero_mask_matrix to zero_mask_matrix_masked
        for (int idx = 0; idx < N * N; idx++)
        {
            zero_mask_matrix_masked[idx] = zero_mask_matrix[idx];
        }

        // Initialize matching
        int num_matched_pair = 0;
        for (int k = 0; k < N; k++)
        {
            temp_matched_pair_i[k] = -1;
            temp_matched_pair_j[k] = -1;
        }

        // Perform greedy matching: at most one pair is picked per step
        for (int step = 0; step < N; step++)
        {
            int selected_row_idx = -1;
            int selected_col_idx = -1;
            // Zero counts are at most N <= MAX_N, so 1e6f is a safe
            // "nothing found" marker for them.
            float min_zero_val_row = 1e6f;
            float min_zero_val_col = 1e6f;

            // Find the row with the minimum (non-zero) number of zeros
            for (int k = 0; k < N; k++)
            {
                float* row = zero_mask_matrix_masked + k * N;
                float nzero_in_row = reduce_sumf(row, N);
                if (nzero_in_row != 0.0f && nzero_in_row < min_zero_val_row)
                {
                    min_zero_val_row = nzero_in_row;
                    selected_row_idx = k;
                }
            }
            if (min_zero_val_row == 1e6f)
            {
                break;  // no selectable zero remains
            }

            // Find the column with the minimum (non-zero) number of zeros
            for (int k = 0; k < N; k++)
            {
                // Sum over column k
                float nzero_in_col = 0.0f;
                for (int i = 0; i < N; i++)
                {
                    nzero_in_col += zero_mask_matrix_masked[i * N + k];
                }
                if (nzero_in_col != 0.0f && nzero_in_col < min_zero_val_col)
                {
                    min_zero_val_col = nzero_in_col;
                    selected_col_idx = k;
                }
            }
            if (min_zero_val_col == 1e6f)
            {
                break;  // no selectable zero remains
            }

            // Decide whether to select a zero from the row or column
            if (min_zero_val_row <= min_zero_val_col)
            {
                // Select the first zero in the selected row
                selected_col_idx = -1; // Reset selected_col_idx
                for (int k = 0; k < N; k++)
                {
                    if (zero_mask_matrix_masked[selected_row_idx * N + k] == 1.0f)
                    {
                        selected_col_idx = k;
                        break;
                    }
                }

                if (selected_col_idx == -1)
                    break; // No available zero
            }
            else
            {
                // Select the first zero in the selected column
                selected_row_idx = -1; // Reset selected_row_idx
                for (int i = 0; i < N; i++)
                {
                    if (zero_mask_matrix_masked[i * N + selected_col_idx] == 1.0f)
                    {
                        selected_row_idx = i;
                        break;
                    }
                }

                if (selected_row_idx == -1)
                    break; // No available zero
            }

            // Clear the selected row and column to avoid re-selection
            for (int l = 0; l < N; l++)
            {
                zero_mask_matrix_masked[l * N + selected_col_idx] = 0.0f;
                zero_mask_matrix_masked[selected_row_idx * N + l] = 0.0f;
            }

            // Record the matched pair
            temp_matched_pair_i[num_matched_pair] = selected_row_idx;
            temp_matched_pair_j[num_matched_pair] = selected_col_idx;
            num_matched_pair++;
        }

        if (num_matched_pair == N)
        {
            // Matching complete: publish the pairs and stop
            for (int k = 0; k < N; k++)
            {
                matched_pair_i[k] = temp_matched_pair_i[k];
                matched_pair_j[k] = temp_matched_pair_j[k];
            }
            break;
        }

        // Select minimum cover lines. Start with every unmatched row marked.
        collect_not_in_vec(temp_matched_pair_i, num_matched_pair, N, non_matched_i_vec, &non_matched_i_cnt);

        marked_col_cnt = 0;
        bool check_switch = true;
        // Repeat until a pass marks nothing new (at most N passes)
        for (int loop = 0; loop < N && check_switch; loop++)
        {
            check_switch = false;

            // Collect columns: mark each column holding a zero in a marked row
            for (int k = 0; k < non_matched_i_cnt; k++)
            {
                int row_idx = non_matched_i_vec[k];
                float* row_array = zero_mask_matrix + row_idx * N;
                for (int t = 0; t < N; t++)
                {
                    if (row_array[t] == 1.0f && !k_in_vec(marked_cols, t, marked_col_cnt))
                    {
                        marked_cols[marked_col_cnt] = t;
                        marked_col_cnt++;
                        check_switch = true;
                    }
                }
            }

            // Mark each matched row whose matched column is marked
            for (int k = 0; k < num_matched_pair; k++)
            {
                int matched_i = temp_matched_pair_i[k];
                int matched_j = temp_matched_pair_j[k];
                if (!k_in_vec(non_matched_i_vec, matched_i, non_matched_i_cnt) && k_in_vec(marked_cols, matched_j, marked_col_cnt))
                {
                    non_matched_i_vec[non_matched_i_cnt] = matched_i;
                    non_matched_i_cnt++;
                    check_switch = true;
                }
            }
        }

        // updated_matched_i_vec = rows that ended up unmarked
        collect_not_in_vec(non_matched_i_vec, non_matched_i_cnt, N, updated_matched_i_vec, &updated_matched_i_cnt);

        // Adjust cost matrix: find the smallest entry lying in a marked row
        // and an unmarked column. A found-flag is used instead of a large
        // sentinel so the sentinel can never leak into the cost matrix.
        bool found_uncovered = false;
        float min_val = 0.0f;
        for (int row = 0; row < N; row++)
        {
            if (!k_in_vec(updated_matched_i_vec, row, updated_matched_i_cnt))
            {
                for (int c = 0; c < N; c++)
                {
                    if (!k_in_vec(marked_cols, c, marked_col_cnt))
                    {
                        float cur_cost = local_cost_matrix[row * N + c];
                        if (!found_uncovered || cur_cost < min_val)
                        {
                            min_val = cur_cost;
                            found_uncovered = true;
                        }
                    }
                }
            }
        }

        if (found_uncovered)
        {
            // Subtract min_val from non-marked elements
            for (int row = 0; row < N; row++)
            {
                if (!k_in_vec(updated_matched_i_vec, row, updated_matched_i_cnt))
                {
                    for (int c = 0; c < N; c++)
                    {
                        if (!k_in_vec(marked_cols, c, marked_col_cnt))
                        {
                            local_cost_matrix[row * N + c] -= min_val;
                        }
                    }
                }
                else
                {
                    // Add min_val to elements in intersection
                    for (int c = 0; c < N; c++)
                    {
                        if (k_in_vec(marked_cols, c, marked_col_cnt))
                        {
                            local_cost_matrix[row * N + c] += min_val;
                        }
                    }
                }
            }

            // Copy the updated local cost matrix to zero_mask_matrix for the next iteration
            for (int idx = 0; idx < N * N; idx++)
            {
                zero_mask_matrix[idx] = local_cost_matrix[idx];
            }

            // Increase iteration counter
            iter++;
        }

        // Give up on exact matching when the iteration budget is spent, or
        // when there was nothing to adjust: in that case the matrix is
        // unchanged, so further rounds would only repeat this one.
        if (!found_uncovered || iter >= MAX_ITER)
        {
            // Perform greedy matching on the residual unmatched rows and columns
            greedy_matching_residual(local_cost_matrix, N, temp_matched_pair_i, temp_matched_pair_j, num_matched_pair);
            for (int k = 0; k < N; k++)
            {
                matched_pair_i[k] = temp_matched_pair_i[k];
                matched_pair_j[k] = temp_matched_pair_j[k];
            }

            // Handle failure
            if (found_uncovered)
            {
                printf("Exceeded maximum iterations\n");
            }
            else
            {
                printf("No uncovered entry left; used greedy fallback\n");
            }
            return;
        }
    }
}

// FFI wrapper function.
//
// Validates the buffer shapes and enqueues hungarian_matching_kernel on
// `stream`. The launch is asynchronous: only launch-time errors are reported
// here, and the function does not wait for the kernel to finish.
//
// Parameters:
//   stream                     CUDA stream the kernel is launched on
//   cost_matrices_buffer       F32 input, must have shape [B, N, N], N <= MAX_N
//   matched_pair_i_out_buffer  S32 result receiving matched row indices;
//                              must hold at least B * N elements
//   matched_pair_j_out_buffer  S32 result receiving matched column indices;
//                              must hold at least B * N elements
//
// Returns:
//   Success when the kernel was launched, or when B or N is 0 (nothing to do);
//   kInvalidArgument for a bad shape, N > MAX_N, a batch size that does not
//   fit in int, or undersized outputs; kUnknown with the CUDA error string
//   when the launch itself fails.
ffi::Error HungarianMatchingImpl(
    cudaStream_t stream,
    ffi::Buffer<ffi::DataType::F32> cost_matrices_buffer,
    ffi::Result<ffi::Buffer<ffi::DataType::S32>> matched_pair_i_out_buffer,
    ffi::Result<ffi::Buffer<ffi::DataType::S32>> matched_pair_j_out_buffer)
{
    // Get the dimensions
    auto dims = cost_matrices_buffer.dimensions();
    if (dims.size() != 3 || dims[1] != dims[2]) {
        return ffi::Error(ffi::ErrorCode::kInvalidArgument,
                          "Cost matrices must have shape [B, N, N]");
    }

    // Check N on the full-width dimension before narrowing it to int, so a
    // huge value cannot wrap around into the accepted range.
    if (dims[1] > MAX_N) {
        return ffi::Error(ffi::ErrorCode::kInvalidArgument,
                          "N exceeds MAX_N");
    }
    int B = static_cast<int>(dims[0]);
    int N = static_cast<int>(dims[1]);
    if (B != dims[0]) {
        return ffi::Error(ffi::ErrorCode::kInvalidArgument,
                          "Batch size B does not fit in int");
    }

    // The kernel writes B * N ints into each output; make sure they fit.
    const long long required = static_cast<long long>(B) * N;
    auto out_i_dims = matched_pair_i_out_buffer->dimensions();
    auto out_j_dims = matched_pair_j_out_buffer->dimensions();
    long long out_i_count = 1;
    for (size_t d = 0; d < out_i_dims.size(); ++d) {
        out_i_count *= out_i_dims[d];
    }
    long long out_j_count = 1;
    for (size_t d = 0; d < out_j_dims.size(); ++d) {
        out_j_count *= out_j_dims[d];
    }
    if (out_i_count < required || out_j_count < required) {
        return ffi::Error(ffi::ErrorCode::kInvalidArgument,
                          "Output buffers must hold at least B * N elements");
    }

    // An empty batch (or empty matrices) needs no work. Launching with a
    // grid of zero blocks would be rejected as an invalid configuration.
    if (B == 0 || N == 0) {
        return ffi::Error::Success();
    }

    // Get the raw device pointers
    float* cost_matrices = cost_matrices_buffer.typed_data();
    int* matched_pair_i_out = matched_pair_i_out_buffer->typed_data();
    int* matched_pair_j_out = matched_pair_j_out_buffer->typed_data();

    // Launch the kernel: one thread per block, one block per batch element
    int threadsPerBlock = 1;
    int blocksPerGrid = (B + threadsPerBlock - 1) / threadsPerBlock;
    hungarian_matching_kernel<<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
        cost_matrices, matched_pair_i_out, matched_pair_j_out, N, B);

    // Check for kernel launch errors
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        return ffi::Error(ffi::ErrorCode::kUnknown,
                          cudaGetErrorString(err));
    }

    return ffi::Error::Success();
}

// Register the handler.
// Defines the FFI handler symbol `HungarianMatching` and binds it to
// HungarianMatchingImpl. The binding entries are listed in the same order as
// HungarianMatchingImpl's parameters: the CUDA stream taken from the execution
// context, one F32 argument buffer, then two S32 result buffers.
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    HungarianMatching,
    HungarianMatchingImpl,
    ffi::Ffi::Bind()
        .Ctx<ffi::PlatformStream<cudaStream_t>>() // CUDA stream -> `stream`
        .Arg<ffi::Buffer<ffi::DataType::F32>>()   // cost_matrices_buffer
        .Ret<ffi::Buffer<ffi::DataType::S32>>()   // matched_pair_i_out_buffer
        .Ret<ffi::Buffer<ffi::DataType::S32>>()   // matched_pair_j_out_buffer
);



// Host code for testing the CUDA implementation.
//
// Builds B random N x N cost matrices (integer-valued costs in [0, 99], fixed
// seed), runs hungarian_matching_kernel on them and prints the matched row and
// column indices of every batch element.
//
// Returns 0 on success, -1 when N exceeds MAX_N, and 1 when a host allocation
// or any CUDA call fails (the failing step is reported on stderr). All
// buffers are released on every path.
int main()
{
    int N = 7;  // Dimension (can be assigned dynamically)
    int B = 1;  // Batch size

    if (N > MAX_N)
    {
        printf("N exceeds MAX_N (%d)\n", MAX_N);
        return -1;
    }

    const size_t cost_bytes = sizeof(float) * B * N * N;
    const size_t pair_bytes = sizeof(int) * B * N;

    // Allocate host memory
    float* h_cost_matrix = (float*)malloc(cost_bytes);
    int* h_matched_pair_i = (int*)malloc(pair_bytes);
    int* h_matched_pair_j = (int*)malloc(pair_bytes);

    // Device pointers start out null so the cleanup below is always safe.
    float* d_cost_matrix = NULL;
    int* d_matched_pair_i = NULL;
    int* d_matched_pair_j = NULL;

    // Reports a failed CUDA call on stderr; returns true when it succeeded.
    auto cuda_ok = [](cudaError_t status, const char* what) -> bool
    {
        if (status != cudaSuccess)
        {
            fprintf(stderr, "%s failed: %s\n", what, cudaGetErrorString(status));
            return false;
        }
        return true;
    };

    bool ok = (h_cost_matrix != NULL) && (h_matched_pair_i != NULL) && (h_matched_pair_j != NULL);
    if (!ok)
    {
        fprintf(stderr, "Host allocation failed\n");
    }

    if (ok)
    {
        // Generate random cost matrix
        srand(3);  // Seed for reproducibility
        for (int b = 0; b < B; b++)
        {
            for (int i = 0; i < N * N; i++)
            {
                h_cost_matrix[b * N * N + i] = (float)(rand() % 100);
            }
        }
    }

    // Allocate device memory and copy data to device. The && chain stops at
    // the first failing step.
    ok = ok
        && cuda_ok(cudaMalloc((void**)&d_cost_matrix, cost_bytes), "cudaMalloc(cost matrix)")
        && cuda_ok(cudaMalloc((void**)&d_matched_pair_i, pair_bytes), "cudaMalloc(matched rows)")
        && cuda_ok(cudaMalloc((void**)&d_matched_pair_j, pair_bytes), "cudaMalloc(matched columns)")
        && cuda_ok(cudaMemcpy(d_cost_matrix, h_cost_matrix, cost_bytes, cudaMemcpyHostToDevice),
                   "cudaMemcpy(host to device)");

    if (ok)
    {
        // Launch kernel
        int threadsPerBlock = 256;
        int blocksPerGrid = (B + threadsPerBlock - 1) / threadsPerBlock;
        hungarian_matching_kernel<<<blocksPerGrid, threadsPerBlock>>>(d_cost_matrix, d_matched_pair_i, d_matched_pair_j, N, B);

        // cudaGetLastError catches a bad launch configuration; the
        // synchronize surfaces faults raised while the kernel was running.
        ok = cuda_ok(cudaGetLastError(), "kernel launch")
            && cuda_ok(cudaDeviceSynchronize(), "kernel execution")
            // Copy results back to host
            && cuda_ok(cudaMemcpy(h_matched_pair_i, d_matched_pair_i, pair_bytes, cudaMemcpyDeviceToHost),
                       "cudaMemcpy(matched rows)")
            && cuda_ok(cudaMemcpy(h_matched_pair_j, d_matched_pair_j, pair_bytes, cudaMemcpyDeviceToHost),
                       "cudaMemcpy(matched columns)");
    }

    if (ok)
    {
        // Print results
        for (int b = 0; b < B; b++)
        {
            printf("Batch %d:\n", b);
            printf("Matched pairs (i): ");
            for (int k = 0; k < N; k++)
            {
                printf("%d ", h_matched_pair_i[b * N + k]);
            }
            printf("\n");
            printf("Matched pairs (j): ");
            for (int k = 0; k < N; k++)
            {
                printf("%d ", h_matched_pair_j[b * N + k]);
            }
            printf("\n");
        }
    }

    // Clean up (free and cudaFree both accept null pointers)
    free(h_cost_matrix);
    free(h_matched_pair_i);
    free(h_matched_pair_j);
    cudaFree(d_cost_matrix);
    cudaFree(d_matched_pair_i);
    cudaFree(d_matched_pair_j);

    return ok ? 0 : 1;
}
