// auction_matching.cu

#include <cuda_runtime.h>
// #include <device_launch_parameters.h>
// #include <thrust/device_vector.h>
// #include <thrust/host_vector.h>
#include <iostream>
// #include <limits>
// #include <vector>
#include <cfloat> // Include this header for FLT_MAX
// #include <stdio.h>

// Include necessary headers
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"

// Upper bound on the number of tasks (columns, M) per cost matrix.
// greedy_matching_residual sizes its per-thread `col_matched` array with this
// value, and AuctionMatchingImpl rejects inputs whose last dimension exceeds it.
#define MAX_N 64  // Maximum allowed dimension

// Previously used (more conservative) settings, kept for reference:
// #define MAX_ITR 50000
// #define EPSILON_END_FACTOR 10000000.0f

// Iteration cap for the auction loop in AuctionKernelBatch; once the iteration
// counter exceeds it, the kernel stops bidding and falls back to
// greedy_matching_residual.
#define MAX_ITR 2000
// The auction loop ends once every worker that can be matched is matched and
// epsilon has dropped to (1 / (n + 1)) / EPSILON_END_FACTOR or below.
#define EPSILON_END_FACTOR 1000.0f

// Short alias for the XLA FFI API namespace used by the host wrapper below.
namespace ffi = xla::ffi;


// Atomic floating-point maximum implemented as a compare-and-swap loop on the
// 32-bit pattern of the float.
//
//   address  location holding the running maximum
//   val      candidate value
//
// Returns the value that was stored at `address` immediately before this
// call's successful compare-and-swap.
__device__ float atomicMaxFloat(float* address, float val) {
    int* const word = (int*)address;
    int observed = *word;

    for (;;) {
        // Snapshot the bits we believe are stored, compute what should replace
        // them, and publish only if nobody changed the word in the meantime.
        const int expected = observed;
        const int desired = __float_as_int(fmaxf(val, __int_as_float(expected)));
        observed = atomicCAS(word, expected, desired);
        if (observed == expected) {
            break;  // our swap went through
        }
        // Another thread won the race; retry against the freshly observed bits.
    }

    return __int_as_float(observed);
}

// Raises *maxVal to `val` when `val` is strictly greater than the stored value
// and, after a successful update, stores `idx` into *maxIdx.
//
//   maxVal  running maximum (its bits are updated through an int atomicCAS)
//   maxIdx  index associated with the running maximum
//   val     candidate value
//   idx     index to record if `val` becomes the new maximum
//
// NOTE(review): the value update (atomicCAS) and the index update (atomicExch)
// are two separate atomic operations. Between them another thread can install
// a larger value together with its own index, after which this thread's
// atomicExch still runs and overwrites that index. When several threads
// target the same slot concurrently, *maxIdx is therefore not guaranteed to
// belong to the value left in *maxVal. Ties (val equal to the stored value)
// never replace the stored index.
__device__ void atomicMaxWithIndex(float* maxVal, int* maxIdx, float val, int idx) {
    // Atomically find the maximum float value
    int* address_as_int = (int*)maxVal;  // Treat float as int for atomic operations
    int old = *address_as_int, assumed;

    do {
        assumed = old;
        float oldFloat = __int_as_float(assumed);

        // If the stored value is already greater than or equal to val, there
        // is nothing to update: leave both the value and the index untouched.
        if (val <= oldFloat) break;

        // Otherwise, attempt to replace; this only succeeds if the stored bits
        // still equal `assumed`.
        old = atomicCAS(address_as_int, assumed, __float_as_int(val));

        // If the max value is updated, also update the index
        // (separate atomic operation -- see the NOTE(review) above).
        if (old == assumed) {
            atomicExch(maxIdx, idx);  // Atomic exchange to set the max index
        }
    } while (assumed != old);  // CAS lost a race: retry against the new contents
}



// // Device function to perform greedy matching on the residual unmatched rows and columns
// __device__ void greedy_matching_residual(
//     const float* cost_matrix,  // [N x M], the most recent cost matrix
//     int N,
//     int M,
//     int* unassigned_worker_mask,
//     int* worker_to_task)
// {
//     // Initialize arrays to keep track of matched rows and columns
//     // bool row_matched[MAX_N] = {false};
//     bool col_matched[MAX_N] = {false};

//     // Mark already matched rows and columns
//     for (int i = 0; i < N; i++)
//     {
//         if (!unassigned_worker_mask[i])
//         {
//             int task = worker_to_task[i];
//             col_matched[task] = true;
//         }
//     }
//     // __syncthreads();

//     // For each unmatched row, find the unmatched column with minimum cost
//     for (int i = 0; i < N; i++)
//     {
//         if (unassigned_worker_mask[i])
//         {
//             float min_cost = FLT_MAX;
//             int min_col = -1;

//             // Search over unmatched columns
//             for (int j = 0; j < M; j++)
//             {
//                 if (!col_matched[j])
//                 {
//                     float cost = cost_matrix[i * M + j];
//                     if (cost < min_cost)
//                     {
//                         min_cost = cost;
//                         min_col = j;
//                     }
//                 }
//             }

//             if (min_col != -1)
//             {
//                 // Add to matched pairs
//                 unassigned_worker_mask[i] = 0;
//                 col_matched[min_col] = true;
//                 worker_to_task[i] = min_col;
//             }
//         }
//     }
// }


// Sequential greedy completion of a partial worker->task matching.
//
// Repeatedly picks, among all still-unassigned workers and still-free tasks,
// the single cheapest (worker, task) edge and commits it, until no unassigned
// worker has a free task with cost below FLT_MAX. Ties are resolved in favour
// of the lowest worker index, then the lowest task index (strict `<` tests).
//
//   cost_matrix             [N x M] costs, row-major
//   N                       number of workers (rows)
//   M                       number of tasks (columns); must not exceed MAX_N
//   unassigned_worker_mask  length N, non-zero = unassigned; cleared on assignment
//   worker_to_task          length N, task held by each assigned worker; written
//                           for every worker this routine assigns
__device__ void greedy_matching_residual(
    const float* cost_matrix,
    int N,
    int M,
    int* unassigned_worker_mask,
    int* worker_to_task)
{
    // Per-thread occupancy table for the tasks.
    bool task_taken[MAX_N] = {false};

    // Seed the table from workers that already hold a valid task.
    for (int w = 0; w < N; ++w) {
        if (unassigned_worker_mask[w]) {
            continue;
        }
        const int held = worker_to_task[w];
        if (held < 0 || held >= M) {
            continue;
        }
        task_taken[held] = true;
    }

    // One (worker, task) pair is committed per pass; stop when a pass finds none.
    for (;;) {
        float lowest_cost = FLT_MAX;
        int pick_worker = -1;
        int pick_task = -1;

        for (int w = 0; w < N; ++w) {
            if (!unassigned_worker_mask[w]) {
                continue;
            }

            // Cheapest free task for this worker.
            float row_min = FLT_MAX;
            int row_argmin = -1;
            for (int t = 0; t < M; ++t) {
                if (task_taken[t]) {
                    continue;
                }
                const float c = cost_matrix[w * M + t];
                if (c < row_min) {
                    row_min = c;
                    row_argmin = t;
                }
            }

            // Keep it only if it beats the best edge seen so far in this pass.
            if (row_argmin != -1 && row_min < lowest_cost) {
                lowest_cost = row_min;
                pick_worker = w;
                pick_task = row_argmin;
            }
        }

        if (pick_worker == -1 || pick_task == -1) {
            break;  // nothing left that can be matched
        }

        unassigned_worker_mask[pick_worker] = 0;
        worker_to_task[pick_worker] = pick_task;
        task_taken[pick_task] = true;
    }
}



// Batched epsilon-scaling auction solver for small (rectangular) assignment
// problems. Each thread block solves one cost matrix.
//
// Launch contract:
//   grid    1-D, one block per cost matrix (blocks with blockIdx.x >= b exit).
//   block   1-D with blockDim.x >= max(n, m). Thread i plays worker i (i < n)
//           and task i (i < m).
//   shared  dynamic shared memory of at least
//             n * (3*sizeof(int) + sizeof(float))
//           + m * (2*sizeof(int) + 2*sizeof(float))
//           + 2*sizeof(int) + sizeof(float) bytes.
//   m <= MAX_N (required by the greedy fallback).
//
// Parameters:
//   b                 number of cost matrices in the batch
//   n                 number of workers (rows)
//   m                 number of tasks (columns)
//   C                 [b, n, m] row-major costs; the assignment minimises cost
//   workerToTask_out  [b, n] output; task index chosen for each worker, or -1
//                     when the worker ends up without a task
//
// Algorithm, per iteration:
//   1. every unassigned worker bids on its most profitable task
//      (profit = -cost - price; bid = best - second best + epsilon);
//   2. per task, the highest bid wins (equal bids: highest worker index);
//   3. winning bids raise the task price, the previous holder is released and
//      the winner is assigned;
//   4. unassigned workers are counted. Once only max(n - m, 0) remain, either
//      epsilon is already small enough and the loop ends, or epsilon is divided
//      by 10 and all workers are put back into the auction (prices are kept).
// After more than MAX_ITR iterations the remaining workers are matched greedily.
__global__ void AuctionKernelBatch(
    int b,
    int n,
    int m,
    const float* C,
    int* workerToTask_out
) {
    extern __shared__ char shared_memory[];

    // Carve the dynamic shared buffer (all elements are 4 bytes wide).
    int* unassignedWorkers_mask = (int*)shared_memory;                 // n ints
    float* prices_per_task = (float*)(unassignedWorkers_mask + n);     // m floats
    int* bids_per_worker = (int*)(prices_per_task + m);                // n ints
    float* bidValues = (float*)(bids_per_worker + n);                  // n floats
    float* bestbidValues_per_task = (float*)(bidValues + n);           // m floats
    int* workerToTask = (int*)(bestbidValues_per_task + m);            // n ints
    int* taskToWorker = (int*)(workerToTask + n);                      // m ints
    int* bidders_per_task = (int*)(taskToWorker + m);                  // m ints
    float* epsilons = (float*)(bidders_per_task + m);                  // 1 float
    int* numUnassigned = (int*)(epsilons + 1);                         // 1 int
    int* numChanges = (int*)(numUnassigned + 1);                       // 1 int

    const int block_idx = blockIdx.x;
    const int localIdx = threadIdx.x;
    if (block_idx >= b) return;  // uniform for the whole block, so no barrier is skipped

    // 64-bit offset: b * n * m may not fit in an int for large batches.
    const float* cost = C + (size_t)block_idx * n * m;
    const float epsilon_floor = 1.0f / (n + 1) / EPSILON_END_FACTOR;

    // ---- Initialisation ---------------------------------------------------
    if (localIdx < n) {
        unassignedWorkers_mask[localIdx] = 1;  // every worker starts unassigned
        bids_per_worker[localIdx] = -1;
        bidValues[localIdx] = 0.0f;
        workerToTask[localIdx] = -1;
    }
    if (localIdx < m) {
        prices_per_task[localIdx] = 0.0f;
        bestbidValues_per_task[localIdx] = -FLT_MAX;
        taskToWorker[localIdx] = -1;
        bidders_per_task[localIdx] = -1;
    }
    if (localIdx == 0) {
        numUnassigned[0] = n;
        numChanges[0] = 0;
        epsilons[0] = 1.0f / (n + 1);
    }
    // The loop condition below reads numUnassigned[0] written by thread 0, so
    // all threads must see the initial state before evaluating it.
    __syncthreads();

    int iter_count = 0;
    const int ideal_num_unassigned = max(n - m, 0);

    // Every pass through the loop body ends with a barrier, so the value of
    // numUnassigned[0] read here is the same for all threads of the block.
    while (numUnassigned[0] > ideal_num_unassigned)
    {
        // ---- Bidding phase ------------------------------------------------
        if (localIdx < n && unassignedWorkers_mask[localIdx] == 1) {
            const int worker = localIdx;

            float maxProfit = -FLT_MAX;
            float secondMaxProfit = -FLT_MAX;
            int bestTask = -1;

            for (int j = 0; j < m; ++j) {
                const float profit = -cost[worker * m + j] - prices_per_task[j];
                if (profit > maxProfit) {
                    secondMaxProfit = maxProfit;
                    maxProfit = profit;
                    bestTask = j;
                } else if (profit > secondMaxProfit) {
                    secondMaxProfit = profit;
                }
            }

            bids_per_worker[worker] = bestTask;
            bidValues[worker] = maxProfit - secondMaxProfit + epsilons[0];
        }
        // Clear the per-task bid slots for this round.
        if (localIdx < m) {
            bestbidValues_per_task[localIdx] = -FLT_MAX;
            bidders_per_task[localIdx] = -1;
        }
        __syncthreads();

        // ---- Winner selection ---------------------------------------------
        // Two steps so that value and index can never disagree: first reduce
        // the maximum bid per task, then let exactly the bidders holding that
        // maximum claim the slot (equal bids: the highest worker index wins).
        const int myTask = (localIdx < n) ? bids_per_worker[localIdx] : -1;
        const float myBid = (myTask != -1) ? bidValues[localIdx] : 0.0f;
        if (myTask != -1) {
            atomicMaxFloat(&bestbidValues_per_task[myTask], myBid);
        }
        __syncthreads();
        if (myTask != -1 && myBid == bestbidValues_per_task[myTask]) {
            atomicMax(&bidders_per_task[myTask], localIdx);
        }
        __syncthreads();

        // ---- Price update and release of previous holders -----------------
        // The counters are cleared here, a full barrier before anyone adds to
        // them again.
        if (localIdx == 0) {
            numChanges[0] = 0;
            numUnassigned[0] = 0;
        }
        if (localIdx < m && bestbidValues_per_task[localIdx] > -FLT_MAX) {
            const int task = localIdx;
            prices_per_task[task] += bestbidValues_per_task[task];

            const int new_worker = bidders_per_task[task];
            const int prev_worker = taskToWorker[task];
            if (prev_worker != -1 && prev_worker != new_worker) {
                atomicExch(&workerToTask[prev_worker], -1);
                atomicExch(&unassignedWorkers_mask[prev_worker], 1);
            }
        }
        // Releases must be complete before wins are recorded: a worker released
        // from one task may win another task in the same round.
        __syncthreads();

        // ---- Record the winners -------------------------------------------
        if (localIdx < m && bestbidValues_per_task[localIdx] > -FLT_MAX) {
            const int task = localIdx;
            const int new_worker = bidders_per_task[task];
            if (new_worker != -1) {
                atomicExch(&workerToTask[new_worker], task);
                taskToWorker[task] = new_worker;
                atomicExch(&unassignedWorkers_mask[new_worker], 0);
                atomicAdd(&numChanges[0], 1);  // diagnostic counter only
            }
        }
        __syncthreads();

        // ---- Count the workers that are still unassigned ------------------
        if (localIdx < n && unassignedWorkers_mask[localIdx] == 1) {
            atomicAdd(&numUnassigned[0], 1);
        }
        __syncthreads();

        // ---- Epsilon scaling ----------------------------------------------
        if (numUnassigned[0] == ideal_num_unassigned) {
            if (epsilons[0] <= epsilon_floor) {
                break;  // matching is complete at the final epsilon (uniform exit)
            }

            // Start the next phase: every worker re-enters the auction with a
            // ten times smaller epsilon while task prices are kept.
            // NOTE(review): workerToTask / taskToWorker are intentionally not
            // cleared here (unchanged from the previous behaviour); stale
            // holders are released when their task is won again.
            if (localIdx < n) {
                unassignedWorkers_mask[localIdx] = 1;
            }
            // All threads must have read epsilons[0] / numUnassigned[0] above
            // before thread 0 overwrites them.
            __syncthreads();
            if (localIdx == 0) {
                epsilons[0] = epsilons[0] / 10.0f;
                numUnassigned[0] = n;
            }
        }

        // Reset the per-round bid scratch space.
        if (localIdx < n) {
            bids_per_worker[localIdx] = -1;
            bidValues[localIdx] = 0.0f;
        }
        iter_count++;
        __syncthreads();

        // ---- Iteration cap: finish with a sequential greedy matching ------
        if (iter_count > MAX_ITR) {
            if (localIdx == 0) {
                // Workers flagged as unassigned may still carry an assignment
                // from an earlier epsilon phase. The greedy pass treats their
                // tasks as free, so drop those entries first to avoid reporting
                // the same task for two workers.
                for (int i = 0; i < n; ++i) {
                    if (unassignedWorkers_mask[i]) {
                        workerToTask[i] = -1;
                    }
                }
                greedy_matching_residual(cost, n, m, unassignedWorkers_mask, workerToTask);
            }
            break;
        }
    }
    __syncthreads();

    // Publish the result: row `block_idx` of the [b, n] output.
    if (localIdx < n) {
        workerToTask_out[(size_t)block_idx * n + localIdx] = workerToTask[localIdx];
    }
}


// XLA FFI entry point: solves a batch of assignment problems with
// AuctionKernelBatch.
//
// Parameters:
//   stream                  CUDA stream supplied by XLA; the kernel is enqueued
//                           on it and this function returns without waiting for
//                           the kernel to finish.
//   cost_matrices_buffer    F32 device buffer of shape [B, N, M]
//                           (B matrices, N workers/rows, M tasks/columns).
//   matched_col_out_buffer  S32 device buffer with B * N elements; entry
//                           [b, i] receives the task chosen for worker i of
//                           matrix b, or -1 if the worker is left unassigned.
//
// Returns:
//   kInvalidArgument if the input is not rank 3, M exceeds MAX_N, N exceeds the
//   maximum block size, or the output buffer does not hold B * N elements;
//   kUnknown with the CUDA error string if the kernel launch fails;
//   success otherwise (including for an empty batch, where nothing is launched).
ffi::Error AuctionMatchingImpl(
    cudaStream_t stream,
    ffi::Buffer<ffi::DataType::F32> cost_matrices_buffer,
    ffi::Result<ffi::Buffer<ffi::DataType::S32>> matched_col_out_buffer)
{
    const auto dims = cost_matrices_buffer.dimensions();
    if (dims.size() != 3) {
        return ffi::Error(ffi::ErrorCode::kInvalidArgument,
                          "Cost matrices must have shape [B, N, M]");
    }
    const int B = static_cast<int>(dims[0]);
    const int N = static_cast<int>(dims[1]);
    const int M = static_cast<int>(dims[2]);

    // The greedy fallback in the kernel keeps a MAX_N-sized per-task table.
    if (M > MAX_N) {
        return ffi::Error(ffi::ErrorCode::kInvalidArgument,
                          "M exceeds MAX_N");
    }

    // One thread per worker: N is limited by the CUDA block-size limit.
    constexpr int kMaxThreadsPerBlock = 1024;
    if (N > kMaxThreadsPerBlock) {
        return ffi::Error(ffi::ErrorCode::kInvalidArgument,
                          "N exceeds the maximum number of threads per block (1024)");
    }

    // The kernel writes exactly B * N ints; refuse anything else rather than
    // writing out of bounds or leaving part of the output undefined.
    const auto out_dims = matched_col_out_buffer->dimensions();
    long long out_elems = 1;
    for (size_t i = 0; i < out_dims.size(); ++i) {
        out_elems *= static_cast<long long>(out_dims[i]);
    }
    if (out_elems != static_cast<long long>(B) * N) {
        return ffi::Error(ffi::ErrorCode::kInvalidArgument,
                          "Output must have shape [B, N]");
    }

    // Nothing to compute (and a zero-sized grid/block is an invalid launch).
    if (B == 0 || N == 0) {
        return ffi::Error::Success();
    }

    // Dynamic shared memory; must match the layout carved in AuctionKernelBatch:
    // per worker 3 ints + 1 float, per task 2 ints + 2 floats, plus two int
    // counters and the current epsilon.
    const size_t sharedMemSize =
        N * (sizeof(int) * 3 + sizeof(float) * 1) +
        M * (sizeof(int) * 2 + sizeof(float) * 2) +
        2 * sizeof(int) + sizeof(float);

    const int gridSize = B;                  // one block per cost matrix
    const int blockSize = (N > M) ? N : M;   // one thread per worker and per task

    // Enqueue on the stream XLA handed us so the kernel is ordered after the
    // producers of the input and before the consumers of the output. No device
    // synchronisation is performed here.
    AuctionKernelBatch<<<gridSize, blockSize, sharedMemSize, stream>>>(
        B, N, M, cost_matrices_buffer.typed_data(), matched_col_out_buffer->typed_data()
    );

    // Launch-configuration errors are reported synchronously.
    const cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        return ffi::Error(ffi::ErrorCode::kUnknown,
                          cudaGetErrorString(err));
    }

    return ffi::Error::Success();
}

// Register the handler: defines the FFI handler symbol `AuctionMatching` and
// binds it to AuctionMatchingImpl. The binding lists, in the same order as the
// parameters of AuctionMatchingImpl: the platform CUDA stream (context), one
// F32 buffer argument and one S32 buffer result.
// (Earlier variant used the macro XLA_FFI_DEFINE_HANDLER instead:)
// XLA_FFI_DEFINE_HANDLER(
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    AuctionMatching,
    AuctionMatchingImpl,
    ffi::Ffi::Bind()
        .Ctx<ffi::PlatformStream<cudaStream_t>>() // For CUDA stream
        .Arg<ffi::Buffer<ffi::DataType::F32>>()   // cost_matrices_buffer
        .Ret<ffi::Buffer<ffi::DataType::S32>>()   // matched_col_out_buffer
);


// // Main function to test the AuctionKernel
// int main() {
//     // Define problem size
//     const int n = 64;
//     const int b = 10;

//     // Define cost matrix (flattened)
//     // float h_C[b* n * n] = {
//     //     90, 75, 75, 80,
//     //     35, 85, 55, 65,
//     //     125, 95, 90, 105,
//     //     45, 110, 95, 115
//     // };
//     float* h_C = (float*)malloc(sizeof(float) * b * n * n);

//     srand(3);  // Seed for reproducibility
//     printf("Start\n");

//     // Allocate device memory
//     float* d_C;
//     int* d_workerToTask_out;
//     cudaMalloc(&d_C, b * n * n * sizeof(float));
//     printf("Allocations\n");
//     cudaMalloc(&d_workerToTask_out, b* n * sizeof(int));
//     printf("Allocations2\n");

    
//     for (int itr=0; itr<100; itr++){
//         printf("Iteration %d\n", itr);

//         for (int j = 0; j < b; j++)
//         {
//             for (int i = 0; i < n * n; i++)
//             {
//                 h_C[j * n * n + i] = (float)(rand() % 100);
//             }
//         }


//         // Copy cost matrix to device
//         cudaMemcpy(d_C, h_C, b* n * n * sizeof(float), cudaMemcpyHostToDevice);

//         // Define epsilon
//         float epsilon = 1.0f / (n + 1);

//         // Calculate shared memory size
//         size_t sharedMemSize = n * (sizeof(int) * 5 + sizeof(float) * 3) + sizeof(int);

//         // Launch kernel
//         int blockSize = n; // Since n is small and fits in one block
//         int gridSize = b;  // Single block launch
//         AuctionKernelBatch<<<gridSize, blockSize, sharedMemSize>>>(
//             b, n, d_C, epsilon, d_workerToTask_out
//         );
//         cudaDeviceSynchronize();

//         // Check for kernel launch errors
//         cudaError_t err = cudaGetLastError();
//         if (err != cudaSuccess) {
//             std::cerr << "CUDA Error: " << cudaGetErrorString(err) << std::endl;
//             // Free device memory before exiting
//             cudaFree(d_C);
//             cudaFree(d_workerToTask_out);
//             return -1;
//         }

//         // Copy results back to host
//         int h_workerToTask_out[b*n];
//         cudaMemcpy(h_workerToTask_out, d_workerToTask_out, b*n * sizeof(int), cudaMemcpyDeviceToHost);

//         // Print the results
//         std::cout << "Assignments (worker -> task):" << std::endl;
//         for (int j = 0; j < b; j++)
//         {
//             for (int i = 0; i < n; ++i) {
//                 std::cout << "batch " << j << " Worker " << i << " assigned to Task " << h_workerToTask_out[j*n + i] << std::endl;
//             }
//         }
//     }

//     // for (int i = 0; i < n; ++i) {
//     //     std::cout << "Worker " << i << " assigned to Task " << h_workerToTask_out[i] << std::endl;
//     // }

//     // Free device memory
//     cudaFree(d_C);
//     cudaFree(d_workerToTask_out);

//     return 0;
// }