#include <hip/hip_runtime.h>
#include <rocblas/rocblas.h>

#include <algorithm>
#include <cctype>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <string>
#include <utility>
#include <vector>

// Evaluates a HIP runtime call once; on any status other than hipSuccess prints the error string
// with the call site and terminates the process.
// do { } while (0) makes the macro a single statement (safe in unbraced if/else), and the
// uniquely-named local avoids capturing a caller variable that appears inside `cmd`.
#define HIP_CHECK(cmd)                                                         \
    do {                                                                       \
        hipError_t hip_check_err_ = (cmd);                                     \
        if (hip_check_err_ != hipSuccess) {                                    \
            std::cerr << "HIP error: " << hipGetErrorString(hip_check_err_)    \
                      << " at " << __FILE__ << ":" << __LINE__ << std::endl;   \
            std::exit(EXIT_FAILURE);                                           \
        }                                                                      \
    } while (0)

// Evaluates a rocBLAS call once; on any status other than rocblas_status_success prints the status
// (name and numeric code) with the call site and terminates the process.
// do { } while (0) makes the macro a single statement, and the uniquely-named local avoids
// capturing a caller variable that appears inside `cmd`.
#define ROCBLAS_CHECK(cmd)                                                     \
    do {                                                                       \
        rocblas_status rocblas_check_status_ = (cmd);                          \
        if (rocblas_check_status_ != rocblas_status_success) {                 \
            std::cerr << "rocBLAS error: "                                     \
                      << rocblas_status_to_string(rocblas_check_status_)       \
                      << " (" << static_cast<int>(rocblas_check_status_) << ")"\
                      << " at " << __FILE__ << ":" << __LINE__ << std::endl;   \
            std::exit(EXIT_FAILURE);                                           \
        }                                                                      \
    } while (0)

// Compile-time configuration shared by the kernels and host code in this file.
constexpr int K{64};            // how many top elements to select
constexpr int MERGE_FACTOR{8};  // group size: each run of 8 consecutive scores collapses to its max
constexpr int BLOCK_SIZE{256};  // threads per block for the merge/scan launches; also sizes topk_scan_kernel's shared arrays

// Collapses each group of MERGE_FACTOR consecutive input scores into its maximum.
// Launch: 1D grid with at least output_size threads; thread g handles input[g*MERGE_FACTOR, ...),
// clipped at input_size. Writes the group max to output[g] and the absolute input index of that
// max to max_indices[g]. Comparison is strictly '>' starting from -INFINITY, so the first
// occurrence wins ties and NaNs are never selected (an all-NaN/-inf group yields -inf at its start).
__global__ void merge8_max_kernel(const float* __restrict__ input,
                                  float* __restrict__ output,
                                  int* __restrict__ max_indices,
                                  int input_size, int output_size) {
    const int group = blockIdx.x * blockDim.x + threadIdx.x;
    if (group >= output_size) {
        return;  // surplus thread in the last block
    }

    const int first = group * MERGE_FACTOR;
    const int stop = min(first + MERGE_FACTOR, input_size);  // exclusive; last group may be short

    float best = -INFINITY;
    int best_at = first;
    for (int j = first; j < stop; ++j) {
        const float candidate = input[j];
        if (candidate > best) {
            best = candidate;
            best_at = j;
        }
    }

    output[group] = best;
    max_indices[group] = best_at;
}

// Same contract as merge8_max_kernel, but reads each complete group of 8 scores as two float4
// loads when the group's address is 16-byte aligned; partial or misaligned groups fall back to
// scalar loads. Launch: 1D grid with at least output_size threads, one thread per group.
// Selection rule (identical in both paths): start from -INFINITY and keep the first element that is
// strictly greater, so ties go to the lowest index and NaNs are never chosen.
__global__ void merge8_max_vectorized_kernel(const float* __restrict__ input,
                                              float* __restrict__ output,
                                              int* __restrict__ max_indices,
                                              int input_size, int output_size) {
    static_assert(MERGE_FACTOR == 8, "vectorized path loads exactly two float4 per group");

    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= output_size) {
        return;
    }

    const int start_idx = tid * MERGE_FACTOR;
    const int end_idx = min(start_idx + MERGE_FACTOR, input_size);
    const float* group = input + start_idx;

    float max_val = -INFINITY;
    int max_idx = start_idx;

    // An index that is a multiple of 4 is not sufficient: the base pointer itself may be offset,
    // so check the real address before reinterpreting it as float4.
    const bool aligned = (reinterpret_cast<uintptr_t>(group) % 16) == 0;

    if (end_idx - start_idx == MERGE_FACTOR && aligned) {
        const float4 v0 = reinterpret_cast<const float4*>(group)[0];
        const float4 v1 = reinterpret_cast<const float4*>(group)[1];
        const float vals[8] = {v0.x, v0.y, v0.z, v0.w, v1.x, v1.y, v1.z, v1.w};

        #pragma unroll
        for (int i = 0; i < 8; i++) {
            if (vals[i] > max_val) {
                max_val = vals[i];
                max_idx = start_idx + i;
            }
        }
    } else {
        // Scalar fallback: short tail group or misaligned address.
        for (int i = start_idx; i < end_idx; i++) {
            const float val = input[i];
            if (val > max_val) {
                max_val = val;
                max_idx = i;
            }
        }
    }

    output[tid] = max_val;
    max_indices[tid] = max_idx;
}

// Top-K kernel: each thread scans a portion and maintains local top-k
__global__ void topk_scan_kernel(const float* __restrict__ values,
                                 const int* __restrict__ original_indices,
                                 float* __restrict__ block_top_values,
                                 int* __restrict__ block_top_indices,
                                 int n, int k) {
    constexpr int LOCAL_K = 8;  // Each thread keeps top 8
    float local_vals[LOCAL_K];
    int local_idxs[LOCAL_K];
    
    for (int i = 0; i < LOCAL_K; i++) {
        local_vals[i] = -INFINITY;
        local_idxs[i] = -1;
    }
    
    int tid = threadIdx.x;
    int global_tid = blockIdx.x * blockDim.x + tid;
    int stride = blockDim.x * gridDim.x;
    
    // Grid-stride loop - each thread scans every stride-th element
    for (int i = global_tid; i < n; i += stride) {
        float val = values[i];
        
        // Check if this value should be in local top-k
        if (val > local_vals[LOCAL_K - 1]) {
            // Find insertion position
            int pos = LOCAL_K - 1;
            while (pos > 0 && val > local_vals[pos - 1]) {
                pos--;
            }
            
            // Shift elements down
            for (int j = LOCAL_K - 1; j > pos; j--) {
                local_vals[j] = local_vals[j - 1];
                local_idxs[j] = local_idxs[j - 1];
            }
            
            local_vals[pos] = val;
            // Use original index from the merge step
            local_idxs[pos] = (original_indices != nullptr) ? original_indices[i] : i;
        }
    }
    
    // Now merge across threads in the block using shared memory
    __shared__ float s_vals[BLOCK_SIZE * LOCAL_K];
    __shared__ int s_idxs[BLOCK_SIZE * LOCAL_K];
    
    // Each thread writes its local top-k to shared memory
    for (int i = 0; i < LOCAL_K; i++) {
        s_vals[tid * LOCAL_K + i] = local_vals[i];
        s_idxs[tid * LOCAL_K + i] = local_idxs[i];
    }
    __syncthreads();
    
    // Thread 0 does final merge to get block's top-k
    if (tid == 0) {
        float result_vals[K];
        int result_idxs[K];
        
        // Initialize with -inf
        for (int i = 0; i < K; i++) {
            result_vals[i] = -INFINITY;
            result_idxs[i] = -1;
        }
        
        int total_candidates = BLOCK_SIZE * LOCAL_K;
        
        // Selection: repeatedly find max and add to result
        for (int r = 0; r < K; r++) {
            float max_val = -INFINITY;
            int max_pos = -1;
            
            for (int i = 0; i < total_candidates; i++) {
                if (s_vals[i] > max_val) {
                    max_val = s_vals[i];
                    max_pos = i;
                }
            }
            
            if (max_pos >= 0) {
                result_vals[r] = max_val;
                result_idxs[r] = s_idxs[max_pos];
                s_vals[max_pos] = -INFINITY;  // Mark as used
            }
        }
        
        // Write to global memory
        for (int i = 0; i < K; i++) {
            block_top_values[blockIdx.x * K + i] = result_vals[i];
            block_top_indices[blockIdx.x * K + i] = result_idxs[i];
        }
    }
}

// Merges the per-block candidate lists (num_blocks * k entries) into the global top-k.
// Launch: a single block. Dynamic shared memory must hold num_blocks*k floats followed by
// num_blocks*k ints. All threads cooperatively stage the candidates; thread 0 then performs k
// selection passes. Each pass takes the first strictly-largest remaining value, and writes
// -INFINITY / -1 once no value > -INFINITY is left.
__global__ void topk_final_merge_kernel(const float* __restrict__ block_top_values,
                                        const int* __restrict__ block_top_indices,
                                        float* __restrict__ final_values,
                                        int* __restrict__ final_indices,
                                        int num_blocks, int k) {
    extern __shared__ float shared_mem[];
    const int count = num_blocks * k;
    float* cand_vals = shared_mem;
    int* cand_idxs = reinterpret_cast<int*>(cand_vals + count);

    // Block-strided copy of every candidate into shared memory.
    for (int slot = threadIdx.x; slot < count; slot += blockDim.x) {
        cand_vals[slot] = block_top_values[slot];
        cand_idxs[slot] = block_top_indices[slot];
    }
    __syncthreads();

    if (threadIdx.x != 0) {
        return;
    }

    for (int rank = 0; rank < k; ++rank) {
        int winner = -1;
        float winner_val = -INFINITY;
        for (int slot = 0; slot < count; ++slot) {
            const float v = cand_vals[slot];
            if (v > winner_val) {
                winner_val = v;
                winner = slot;
            }
        }

        if (winner < 0) {
            final_values[rank] = -INFINITY;
            final_indices[rank] = -1;
            continue;
        }

        final_values[rank] = winner_val;
        final_indices[rank] = cand_idxs[winner];
        cand_vals[winner] = -INFINITY;  // consume so later passes skip it
    }
}

// GPU top-k on merged scores
void topk_gpu(const float* d_values, const int* d_original_indices,
              int* d_top_indices, float* d_top_values,
              int n, int k, hipStream_t stream = 0) {
    int num_blocks = min((n + BLOCK_SIZE - 1) / BLOCK_SIZE, 16);
    
    float *d_block_values;
    int *d_block_indices;
    
    HIP_CHECK(hipMalloc(&d_block_values, num_blocks * k * sizeof(float)));
    HIP_CHECK(hipMalloc(&d_block_indices, num_blocks * k * sizeof(int)));
    
    // Step 1: Each block finds its local top-k
    hipLaunchKernelGGL(topk_scan_kernel, dim3(num_blocks), dim3(BLOCK_SIZE), 
                       0, stream, d_values, d_original_indices,
                       d_block_values, d_block_indices, n, k);
    
    // Step 2: Merge block results
    size_t shared_mem_size = num_blocks * k * (sizeof(float) + sizeof(int));
    hipLaunchKernelGGL(topk_final_merge_kernel, dim3(1), dim3(256),
                       shared_mem_size, stream, d_block_values, d_block_indices,
                       d_top_values, d_top_indices, num_blocks, k);
    
    HIP_CHECK(hipFree(d_block_values));
    HIP_CHECK(hipFree(d_block_indices));
}

// Host-side top-k alternative: copies merged values and their original indices from the device,
// selects the k largest with std::partial_sort, and copies the result back.
//   d_merged_values / d_original_indices: device arrays of merged_size entries.
//   d_top_indices / d_top_values:         device arrays of k entries receiving the result.
// Only values > -INFINITY are candidates. This mirrors the GPU kernels and keeps the comparator
// a strict weak ordering (NaN would break it). Slots beyond the number of candidates are filled
// with -INFINITY / -1. Synchronizes `stream` before returning, because the async H2D copies read
// host vectors local to this function.
void topk_cpu(const float* d_merged_values, const int* d_original_indices,
              int* d_top_indices, float* d_top_values,
              int merged_size, int k, hipStream_t stream = 0) {
    if (k <= 0) {
        return;
    }

    const int count = std::max(merged_size, 0);
    std::vector<float> h_values(count);
    std::vector<int> h_indices(count);

    if (count > 0) {
        HIP_CHECK(hipMemcpyAsync(h_values.data(), d_merged_values,
                                  count * sizeof(float),
                                  hipMemcpyDeviceToHost, stream));
        HIP_CHECK(hipMemcpyAsync(h_indices.data(), d_original_indices,
                                  count * sizeof(int),
                                  hipMemcpyDeviceToHost, stream));
    }
    HIP_CHECK(hipStreamSynchronize(stream));

    // (value, original_index) pairs for every valid candidate.
    std::vector<std::pair<float, int>> pairs;
    pairs.reserve(count);
    for (int i = 0; i < count; i++) {
        if (h_values[i] > -INFINITY) {
            pairs.emplace_back(h_values[i], h_indices[i]);
        }
    }

    // Never sort past the end when fewer than k candidates exist.
    const int take = std::min(k, static_cast<int>(pairs.size()));
    std::partial_sort(pairs.begin(), pairs.begin() + take, pairs.end(),
                      [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
                          return a.first > b.first;
                      });

    std::vector<float> h_top_values(k, -INFINITY);
    std::vector<int> h_top_indices(k, -1);
    for (int i = 0; i < take; i++) {
        h_top_values[i] = pairs[i].first;
        h_top_indices[i] = pairs[i].second;
    }

    HIP_CHECK(hipMemcpyAsync(d_top_values, h_top_values.data(), k * sizeof(float),
                              hipMemcpyHostToDevice, stream));
    HIP_CHECK(hipMemcpyAsync(d_top_indices, h_top_indices.data(), k * sizeof(int),
                              hipMemcpyHostToDevice, stream));
    // h_top_values / h_top_indices are destroyed on return; wait until the copies have read them.
    HIP_CHECK(hipStreamSynchronize(stream));
}

// Chooses where merge8_topk performs the top-k stage on the merged scores.
enum class TopKMethod : int {
    GPU = 0,          // topk_gpu: on-device scan + final-merge kernels
    CPU_PARTIAL = 1,  // topk_cpu: copy to host, std::partial_sort, copy back
};

// Main merge8 + top-k function
void merge8_topk(const float* d_scores, int* d_top_indices, float* d_top_values,
                 int n, int k, TopKMethod method = TopKMethod::GPU,
                 bool use_vectorized = true, hipStream_t stream = 0) {
    // Calculate merged size
    int merged_size = (n + MERGE_FACTOR - 1) / MERGE_FACTOR;
    
    // Allocate temporary buffers for merged scores and indices
    float* d_merged_scores;
    int* d_merged_indices;
    
    HIP_CHECK(hipMalloc(&d_merged_scores, merged_size * sizeof(float)));
    HIP_CHECK(hipMalloc(&d_merged_indices, merged_size * sizeof(int)));
    
    // Step 1: Merge every 8 scores by finding max
    int merge_blocks = (merged_size + BLOCK_SIZE - 1) / BLOCK_SIZE;
    
    if (use_vectorized) {
        hipLaunchKernelGGL(merge8_max_vectorized_kernel, dim3(merge_blocks), 
                           dim3(BLOCK_SIZE), 0, stream,
                           d_scores, d_merged_scores, d_merged_indices, n, merged_size);
    } else {
        hipLaunchKernelGGL(merge8_max_kernel, dim3(merge_blocks), dim3(BLOCK_SIZE),
                           0, stream, d_scores, d_merged_scores, d_merged_indices, 
                           n, merged_size);
    }
    
    // Step 2: Find top-k on merged scores
    switch (method) {
        case TopKMethod::GPU:
            topk_gpu(d_merged_scores, d_merged_indices, d_top_indices, 
                     d_top_values, merged_size, k, stream);
            break;
        case TopKMethod::CPU_PARTIAL:
            topk_cpu(d_merged_scores, d_merged_indices, d_top_indices,
                     d_top_values, merged_size, k, stream);
            break;
    }
    
    HIP_CHECK(hipFree(d_merged_scores));
    HIP_CHECK(hipFree(d_merged_indices));
}

// Benchmark driver: builds an M x N random matrix and an N-vector, computes the M scores with
// rocblas_sgemv, then times 100 iterations of sgemv + merge8_topk and reports latency statistics.
// Finally it prints the GPU top-K and compares it with a host-side merge-8 + top-K reference.
// Command line: [M] [gpu|cpu] [scalar|vectorized] [-m<value> | -m <value>] [-h|--help].
int main(int argc, char** argv) {
    int M = 16384;        // Number of rows in matrix (can be overridden by argument)
    const int N = 128;    // Dimension of vectors

    // Parse command line arguments
    TopKMethod method = TopKMethod::GPU;
    std::string method_name = "GPU";
    bool use_vectorized = true;

    for (int i = 1; i < argc; i++) {
        std::string arg = argv[i];
        if (arg == "gpu" || arg == "GPU") {
            method = TopKMethod::GPU;
            method_name = "GPU";
        } else if (arg == "cpu" || arg == "CPU") {
            method = TopKMethod::CPU_PARTIAL;
            method_name = "CPU (partial_sort)";
        } else if (arg == "scalar") {
            use_vectorized = false;
        } else if (arg == "vectorized") {
            use_vectorized = true;
        } else if (arg.substr(0, 2) == "-m" || arg.substr(0, 2) == "-M") {
            // Parse -m<value>, -M<value>, or "-m <value>"
            if (arg.length() > 2) {
                M = std::atoi(arg.substr(2).c_str());
            } else if (i + 1 < argc) {
                M = std::atoi(argv[++i]);
            }
            if (M <= 0) {
                std::cerr << "Error: M must be a positive integer" << std::endl;
                return 1;
            }
        } else if (std::isdigit(static_cast<unsigned char>(arg[0]))) {
            // Treat standalone number as M. The unsigned char cast avoids UB for negative chars.
            M = std::atoi(arg.c_str());
            if (M <= 0) {
                std::cerr << "Error: M must be a positive integer" << std::endl;
                return 1;
            }
        } else if (arg == "-h" || arg == "--help") {
            std::cout << "Usage: " << argv[0] << " [M] [gpu|cpu] [scalar|vectorized]" << std::endl;
            std::cout << "  M          - Number of scores/rows (default: 16384)" << std::endl;
            std::cout << "  -m<value>  - Set M explicitly (e.g., -m8192)" << std::endl;
            std::cout << "  gpu        - GPU-based top-k kernel (default)" << std::endl;
            std::cout << "  cpu        - CPU partial_sort" << std::endl;
            std::cout << "  scalar     - Use scalar merge kernel" << std::endl;
            std::cout << "  vectorized - Use vectorized merge kernel (default)" << std::endl;
            return 0;
        }
    }

    int merged_size = (M + MERGE_FACTOR - 1) / MERGE_FACTOR;
    // Element count computed in size_t: M * N overflows int once M exceeds ~16.7M rows.
    const size_t matrix_elems = static_cast<size_t>(M) * N;

    std::cout << "ROCm Top-K with Merge-8 Benchmark" << std::endl;
    std::cout << "Matrix size: " << M << "x" << N << std::endl;
    std::cout << "Vector size: 1x" << N << std::endl;
    std::cout << "Original scores: " << M << std::endl;
    std::cout << "Merged scores (8->1): " << merged_size << std::endl;
    std::cout << "Top-K: " << K << std::endl;
    std::cout << "Top-K Method: " << method_name << std::endl;
    std::cout << "Merge Kernel: " << (use_vectorized ? "Vectorized" : "Scalar") << std::endl;
    std::cout << "----------------------------------------" << std::endl;

    // Allocate host memory
    std::vector<float> h_vector(N);
    std::vector<float> h_matrix(matrix_elems);
    std::vector<float> h_result(M);
    std::vector<int> h_top_indices(K);
    std::vector<float> h_top_values(K);

    // Initialize data
    srand(42);
    for (int i = 0; i < N; i++) {
        h_vector[i] = static_cast<float>(rand()) / RAND_MAX * 10.0f;
    }

    for (size_t i = 0; i < matrix_elems; i++) {
        h_matrix[i] = static_cast<float>(rand()) / RAND_MAX * 10.0f;
    }

    // Make some results definitely larger for testing
    // Spread them out so they fall in different merge groups
    for (int i = 0; i < 100; i++) {
        int target_row = i * 8;  // Every 8th row to test merge behavior
        if (target_row < M) {
            for (int j = 0; j < N; j++) {
                h_matrix[static_cast<size_t>(target_row) * N + j] =
                    1000.0f + static_cast<float>(rand()) / RAND_MAX * 100.0f;
            }
        }
    }

    // Allocate device memory
    float *d_vector, *d_matrix, *d_result;
    int *d_top_indices;
    float *d_top_values;

    HIP_CHECK(hipMalloc(&d_vector, N * sizeof(float)));
    HIP_CHECK(hipMalloc(&d_matrix, matrix_elems * sizeof(float)));
    HIP_CHECK(hipMalloc(&d_result, M * sizeof(float)));
    HIP_CHECK(hipMalloc(&d_top_indices, K * sizeof(int)));
    HIP_CHECK(hipMalloc(&d_top_values, K * sizeof(float)));

    // Copy data to device
    HIP_CHECK(hipMemcpy(d_vector, h_vector.data(), N * sizeof(float),
                        hipMemcpyHostToDevice));
    HIP_CHECK(hipMemcpy(d_matrix, h_matrix.data(), matrix_elems * sizeof(float),
                        hipMemcpyHostToDevice));

    // Create rocBLAS handle
    rocblas_handle handle;
    ROCBLAS_CHECK(rocblas_create_handle(&handle));

    // Warmup run. The row-major M x N matrix is a column-major N x M matrix with lda = N,
    // so the transposed sgemv yields the M row scores.
    float alpha = 1.0f;
    float beta = 0.0f;
    ROCBLAS_CHECK(rocblas_sgemv(handle, rocblas_operation_transpose,
                                N, M, &alpha, d_matrix, N,
                                d_vector, 1, &beta, d_result, 1));
    merge8_topk(d_result, d_top_indices, d_top_values, M, K, method, use_vectorized);
    HIP_CHECK(hipDeviceSynchronize());

    // Benchmark parameters
    const int num_iterations = 100;
    std::vector<double> latencies;

    std::cout << "Running " << num_iterations << " iterations..." << std::endl;

    for (int iter = 0; iter < num_iterations; iter++) {
        auto start = std::chrono::high_resolution_clock::now();

        // Perform matrix-vector multiplication using rocBLAS
        ROCBLAS_CHECK(rocblas_sgemv(handle, rocblas_operation_transpose,
                                    N, M, &alpha, d_matrix, N,
                                    d_vector, 1, &beta, d_result, 1));

        // Merge every 8 scores and find top-K
        merge8_topk(d_result, d_top_indices, d_top_values, M, K, method, use_vectorized);

        HIP_CHECK(hipDeviceSynchronize());

        auto end = std::chrono::high_resolution_clock::now();
        std::chrono::duration<double, std::micro> elapsed = end - start;
        latencies.push_back(elapsed.count());
    }

    // Copy final results back to host
    HIP_CHECK(hipMemcpy(h_result.data(), d_result, M * sizeof(float),
                        hipMemcpyDeviceToHost));
    HIP_CHECK(hipMemcpy(h_top_indices.data(), d_top_indices, K * sizeof(int),
                        hipMemcpyDeviceToHost));
    HIP_CHECK(hipMemcpy(h_top_values.data(), d_top_values, K * sizeof(float),
                        hipMemcpyDeviceToHost));

    // Calculate statistics
    double sum = 0.0;
    double min_latency = latencies[0];
    double max_latency = latencies[0];

    for (double lat : latencies) {
        sum += lat;
        min_latency = std::min(min_latency, lat);
        max_latency = std::max(max_latency, lat);
    }

    double mean_latency = sum / num_iterations;

    double variance = 0.0;
    for (double lat : latencies) {
        variance += (lat - mean_latency) * (lat - mean_latency);
    }
    variance /= num_iterations;
    double std_dev = std::sqrt(variance);

    // Print results
    std::cout << "\n=== Performance Results ===" << std::endl;
    std::cout << "Mean latency: " << mean_latency << " μs" << std::endl;
    std::cout << "Min latency:  " << min_latency << " μs" << std::endl;
    std::cout << "Max latency:  " << max_latency << " μs" << std::endl;
    std::cout << "Std dev:      " << std_dev << " μs" << std::endl;

    // Print top-K results
    std::cout << "\n=== Top-" << K << " Results (with Merge-8) ===" << std::endl;
    std::cout << "Showing first 10 and last 5 of top-" << K << ":" << std::endl;
    for (int i = 0; i < 10; i++) {
        int orig_idx = h_top_indices[i];
        int merge_group = orig_idx / MERGE_FACTOR;
        std::cout << "  Rank " << (i+1) << ": Index " << orig_idx
                  << " (merge group " << merge_group << ")"
                  << ", Value " << h_top_values[i] << std::endl;
    }
    std::cout << "  ..." << std::endl;
    for (int i = K - 5; i < K; i++) {
        int orig_idx = h_top_indices[i];
        int merge_group = orig_idx / MERGE_FACTOR;
        std::cout << "  Rank " << (i+1) << ": Index " << orig_idx
                  << " (merge group " << merge_group << ")"
                  << ", Value " << h_top_values[i] << std::endl;
    }

    // CPU verification - compute merge-8 and top-k on CPU
    std::cout << "\n=== CPU Verification (Merge-8 + Top-K) ===" << std::endl;

    // First merge every 8 scores
    std::vector<std::pair<float, int>> merged_pairs(merged_size);
    for (int g = 0; g < merged_size; g++) {
        int start = g * MERGE_FACTOR;
        int end = std::min(start + MERGE_FACTOR, M);
        float max_val = -INFINITY;
        int max_idx = start;

        for (int i = start; i < end; i++) {
            if (h_result[i] > max_val) {
                max_val = h_result[i];
                max_idx = i;
            }
        }
        merged_pairs[g] = {max_val, max_idx};
    }

    // Small M can leave fewer than K merged groups; never sort or index past the end.
    const int cpu_k = std::min(K, merged_size);

    // Find top-k on merged pairs
    std::partial_sort(merged_pairs.begin(), merged_pairs.begin() + cpu_k, merged_pairs.end(),
                      [](const auto& a, const auto& b) { return a.first > b.first; });

    std::cout << "CPU Top-10 (with Merge-8):" << std::endl;
    for (int i = 0; i < std::min(10, cpu_k); i++) {
        int orig_idx = merged_pairs[i].second;
        int merge_group = orig_idx / MERGE_FACTOR;
        std::cout << "  Rank " << (i+1) << ": Index " << orig_idx
                  << " (merge group " << merge_group << ")"
                  << ", Value " << merged_pairs[i].first << std::endl;
    }

    // Check if GPU results match CPU (at most cpu_k matches are possible)
    int matches = 0;
    for (int i = 0; i < K; i++) {
        for (int j = 0; j < cpu_k; j++) {
            if (h_top_indices[i] == merged_pairs[j].second) {
                matches++;
                break;
            }
        }
    }
    std::cout << "\nGPU top-" << K << " indices matching CPU: " << matches << "/" << cpu_k << std::endl;

    // Cleanup
    ROCBLAS_CHECK(rocblas_destroy_handle(handle));
    HIP_CHECK(hipFree(d_vector));
    HIP_CHECK(hipFree(d_matrix));
    HIP_CHECK(hipFree(d_result));
    HIP_CHECK(hipFree(d_top_indices));
    HIP_CHECK(hipFree(d_top_values));

    std::cout << "\nDone!" << std::endl;

    return 0;
}
