#include <hip/hip_runtime.h>
#include <rocblas/rocblas.h>
#include <iostream>
#include <vector>
#include <chrono>
#include <cmath>
#include <cstdlib>
#include <algorithm>
#include <limits>

// Evaluates a HIP runtime call exactly once; if it does not return hipSuccess,
// prints the HIP error string plus the call site and terminates the process.
// The do { } while (0) wrapper makes the macro a single statement (safe in an
// unbraced if/else). The local has an unusual name so it cannot shadow a
// caller variable referenced inside `cmd`.
#define HIP_CHECK(cmd)                                                         \
    do {                                                                       \
        hipError_t hip_check_err_ = (cmd);                                     \
        if (hip_check_err_ != hipSuccess) {                                    \
            std::cerr << "HIP error: " << hipGetErrorString(hip_check_err_)    \
                      << " at " << __FILE__ << ":" << __LINE__ << std::endl;   \
            exit(EXIT_FAILURE);                                                \
        }                                                                      \
    } while (0)

// Evaluates a rocBLAS call exactly once; if the returned rocblas_status is not
// rocblas_status_success, prints the numeric status plus the call site and
// terminates the process. It is wrapped in do { } while (0) so it acts as one
// statement, and it uses a local name that cannot shadow caller variables.
#define ROCBLAS_CHECK(cmd)                                                     \
    do {                                                                       \
        rocblas_status rocblas_check_status_ = (cmd);                          \
        if (rocblas_check_status_ != rocblas_status_success) {                 \
            std::cerr << "rocBLAS error: "                                     \
                      << static_cast<int>(rocblas_check_status_)               \
                      << " at " << __FILE__ << ":" << __LINE__ << std::endl;   \
            exit(EXIT_FAILURE);                                                \
        }                                                                      \
    } while (0)

// Compile-time configuration shared by the kernels in this file.
constexpr int K{64};                        // how many top elements are selected
constexpr int BLOCK_SIZE{256};              // threads per block; also sizes shared scratch arrays
constexpr int RADIX_BITS{8};                // key bits consumed per radix pass
constexpr int RADIX_SIZE{1 << RADIX_BITS};  // buckets per radix pass (= 256)

// A score paired with the position it came from in the input array.
struct ValueIndex
{
    float value;  // the element itself
    int index;    // position of `value` in the original array
};

// Max-reduces `val` across the calling warp with a __shfl_down tree.
// When the whole warp participates, lane 0 ends up holding the warp maximum;
// other lanes hold partial maxima.
__device__ __forceinline__ float warp_reduce_max(float val) {
    int delta = warpSize / 2;
    while (delta > 0) {
        const float neighbour = __shfl_down(val, delta);
        val = fmaxf(val, neighbour);
        delta /= 2;
    }
    return val;
}

// Sum-reduces `val` across the calling warp with a __shfl_down tree.
// When the whole warp participates, lane 0 ends up holding the warp total;
// other lanes hold partial sums.
__device__ __forceinline__ int warp_reduce_sum(int val) {
    for (int delta = warpSize >> 1; delta != 0; delta >>= 1) {
        const int neighbour = __shfl_down(val, delta);
        val = val + neighbour;
    }
    return val;
}

// Block-wide max reduction built on warp_reduce_max.
//
// Preconditions:
//   * every thread of the block calls this function (it contains __syncthreads());
//   * blockDim.x is a multiple of warpSize (a partial last warp is not counted);
//   * `shared` points to at least blockDim.x / warpSize floats of __shared__ memory.
// Returns the block maximum in threadIdx.x == 0. The other lanes of warp 0 hold
// partial maxima, and threads outside warp 0 get -INFINITY.
__device__ float block_reduce_max(float val, float* shared) {
    const int lane = threadIdx.x % warpSize;
    const int wid = threadIdx.x / warpSize;

    // Lane 0 of each warp now holds that warp's maximum.
    val = warp_reduce_max(val);

    if (lane == 0) shared[wid] = val;
    __syncthreads();

    // Warp 0 gathers one partial per warp; everyone else contributes -INFINITY.
    val = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : -INFINITY;

    // Barrier so that a back-to-back call cannot overwrite `shared` before
    // every thread has read its partial above.
    __syncthreads();

    if (wid == 0) val = warp_reduce_max(val);

    return val;
}

// Histogram pass of a radix select over normalized float keys.
//
// Each value v becomes an unsigned key as follows:
//   1. (v - min_val) / range is clamped to [0, 1] (NaN clamps to 0).
//   2. The result is scaled to the full 32-bit range.
//   3. The key is bit-inverted so that larger values get smaller keys
//      (bucket 0 = largest values).
// Byte `byte_idx` of the key (0 = most significant) selects the bucket. Counts
// are gathered per block in shared memory, then atomically added into
// `histogram` (RADIX_SIZE entries), which must start zeroed for a fresh count.
//
// Uses a grid-stride loop, so any grid size covers all n elements.
// NOTE(review): every element is counted, not only those matching the prefix
// chosen in earlier passes, so byte_idx > 0 alone is not a full radix select.
__global__ void count_radix_kernel(const float* __restrict__ values,
                                   unsigned int* __restrict__ histogram,
                                   int n, int byte_idx, float min_val, float range) {
    __shared__ unsigned int local_hist[RADIX_SIZE];

    // Initialize local histogram
    for (int i = threadIdx.x; i < RADIX_SIZE; i += blockDim.x) {
        local_hist[i] = 0;
    }
    __syncthreads();

    const int shift = (3 - byte_idx) * 8;
    const int stride = blockDim.x * gridDim.x;

    for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < n; idx += stride) {
        float normalized = (values[idx] - min_val) / range;
        normalized = fminf(fmaxf(normalized, 0.0f), 1.0f);

        // 4294967295.0f rounds to 2^32, so normalized == 1 would overflow the
        // float->uint conversion (undefined behavior). Saturate that case explicitly.
        unsigned int uint_val = (normalized >= 1.0f)
            ? 0xFFFFFFFFu
            : static_cast<unsigned int>(normalized * 4294967295.0f);

        // For descending order, flip the bits
        uint_val = ~uint_val;

        const unsigned int bucket = (uint_val >> shift) & 0xFFu;
        atomicAdd(&local_hist[bucket], 1u);
    }
    __syncthreads();

    // Fold this block's counts into the global histogram.
    for (int i = threadIdx.x; i < RADIX_SIZE; i += blockDim.x) {
        if (local_hist[i] > 0) {
            atomicAdd(&histogram[i], local_hist[i]);
        }
    }
}

// Locates the radix bucket that holds the k-th element (1-based) of a
// RADIX_SIZE-entry histogram whose bucket 0 holds the largest values.
//
// Outputs:
//   *pivot_bucket - first bucket whose running count reaches k, or -1 if the
//                   histogram holds fewer than k elements;
//   *prefix_count - number of elements in the buckets before the pivot
//                   (0 when none is found).
// Only threadIdx.x == 0 of each block does any work, so a single-thread launch
// is sufficient. It needs no shared memory: the previous __shared__ prefix
// array was written but never read.
__global__ void find_pivot_kernel(const unsigned int* __restrict__ histogram,
                                  int k, int* pivot_bucket, int* prefix_count) {
    if (threadIdx.x != 0) return;

    // Same unsigned comparison the running `sum >= k` performed (k promoted).
    const unsigned int target = static_cast<unsigned int>(k);

    unsigned int before = 0;  // elements in buckets [0, i)
    int found = -1;
    int found_prefix = 0;

    for (int i = 0; i < RADIX_SIZE; ++i) {
        const unsigned int through = before + histogram[i];
        if (through >= target) {
            found = i;
            found_prefix = static_cast<int>(before);
            break;  // later buckets cannot change the answer
        }
        before = through;
    }

    *pivot_bucket = found;
    *prefix_count = found_prefix;
}

// Collects every element whose radix bucket precedes `pivot_bucket`. The key
// is computed exactly as in count_radix_kernel, and any such element is
// certainly among the top k.
//
// Behavior:
//   * A negative pivot_bucket ("no pivot found") admits every element. This
//     matches the previous implicit unsigned comparison.
//   * Output slots are claimed with atomicAdd on *count, which must start at 0.
//     Output order is therefore nondeterministic.
//   * Only the first k claims are stored, but *count keeps counting past k.
//   * Uses a grid-stride loop, so any grid size covers all n elements.
// NOTE(review): elements inside the pivot bucket itself are not collected here.
__global__ void filter_topk_kernel(const float* __restrict__ values,
                                   int* __restrict__ top_indices,
                                   float* __restrict__ top_values,
                                   int* __restrict__ count,
                                   int n, int k, 
                                   float min_val, float range,
                                   int byte_idx, int pivot_bucket) {
    const int shift = (3 - byte_idx) * 8;
    const int stride = blockDim.x * gridDim.x;

    for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < n; idx += stride) {
        const float val = values[idx];

        float normalized = (val - min_val) / range;
        normalized = fminf(fmaxf(normalized, 0.0f), 1.0f);

        // 4294967295.0f rounds to 2^32; saturate normalized == 1 explicitly to
        // avoid an out-of-range (undefined) float->uint conversion.
        unsigned int uint_val = (normalized >= 1.0f)
            ? 0xFFFFFFFFu
            : static_cast<unsigned int>(normalized * 4294967295.0f);

        // For descending order, flip the bits
        uint_val = ~uint_val;

        const int bucket = static_cast<int>((uint_val >> shift) & 0xFFu);

        const bool before_pivot = (pivot_bucket < 0) || (bucket < pivot_bucket);
        if (before_pivot) {
            const int pos = atomicAdd(count, 1);
            if (pos < k) {
                top_indices[pos] = idx;
                top_values[pos] = val;
            }
        }
    }
}

// Exact top-k selection carried out by a single block. Launch one block; extra
// blocks would redundantly compute and write the same result.
//
// Ordering and output:
//   * Larger value first; ties are broken by smaller index.
//   * Values that are not > -INFINITY (-INFINITY and NaN) are ignored.
//   * Results go to top_values[0..k) / top_indices[0..k) in descending order.
//   * Slots that cannot be filled get (-INFINITY, -1).
//
// Algorithm: k rounds. In each round every thread scans its strided slice for
// the best element ranked strictly after the previous round's winner. A
// shared-memory tree reduction then picks the block-wide winner. Cost is
// O(k * n / blockDim.x) loads per thread, and the result is exact no matter
// how the winners are distributed across threads.
//
// This replaces the earlier version, which was broken in three ways:
//   * its atomicMin on a tid initialised to -1 never picked a winner;
//   * its float atomicMax through int bits mis-ordered negative values;
//   * its 4-per-thread buffer could drop elements.
//
// Precondition: blockDim.x <= BLOCK_SIZE (size of the shared scratch arrays).
__global__ void topk_single_block_kernel(const float* __restrict__ values,
                                         int* __restrict__ top_indices,
                                         float* __restrict__ top_values,
                                         int n, int k) {
    __shared__ float s_values[BLOCK_SIZE];
    __shared__ int s_indices[BLOCK_SIZE];

    // True when (av, ai) ranks strictly before (bv, bi). Index -1 marks an empty slot.
    auto ranks_before = [](float av, int ai, float bv, int bi) -> bool {
        if (ai < 0) return false;
        if (bi < 0) return true;
        return av > bv || (av == bv && ai < bi);
    };

    const int tid = threadIdx.x;
    const int nthreads = blockDim.x;

    bool have_last = false;      // has any round produced a winner yet?
    float last_val = -INFINITY;  // previous winner (the next winner must rank after it)
    int last_idx = -1;

    for (int selected = 0; selected < k; ++selected) {
        // 1. Per-thread candidate: best element ranked after the last winner.
        float best_val = -INFINITY;
        int best_idx = -1;
        for (int i = tid; i < n; i += nthreads) {
            const float v = values[i];
            if (!(v > -INFINITY)) continue;  // skip -INFINITY and NaN
            if (have_last && !ranks_before(last_val, last_idx, v, i)) continue;
            if (ranks_before(v, i, best_val, best_idx)) {
                best_val = v;
                best_idx = i;
            }
        }
        s_values[tid] = best_val;
        s_indices[tid] = best_idx;
        __syncthreads();

        // 2. Tree reduction. Slot tid absorbs slot tid + offset, which is not
        //    written in the same step.
        for (int offset = 1; offset < nthreads; offset <<= 1) {
            if ((tid % (2 * offset)) == 0 && tid + offset < nthreads &&
                ranks_before(s_values[tid + offset], s_indices[tid + offset],
                             s_values[tid], s_indices[tid])) {
                s_values[tid] = s_values[tid + offset];
                s_indices[tid] = s_indices[tid + offset];
            }
            __syncthreads();
        }

        const float win_val = s_values[0];
        const int win_idx = s_indices[0];
        __syncthreads();  // everyone has read the winner before s_* is reused

        if (tid == 0) {
            top_values[selected] = win_val;  // -INFINITY when win_idx == -1
            top_indices[selected] = win_idx;
        }

        // Nothing eligible remains. Every thread sees the same shared winner,
        // so this break is uniform across the block.
        if (win_idx < 0) {
            if (tid == 0) {
                for (int r = selected + 1; r < k; ++r) {
                    top_values[r] = -INFINITY;
                    top_indices[r] = -1;
                }
            }
            break;
        }

        have_last = true;
        last_val = win_val;
        last_idx = win_idx;
    }
}

// Exact top-k selection using warp-shuffle reductions. Only block 0 does the
// work and writes the result, so every element is considered regardless of
// grid size. Other blocks exit immediately.
//
// Ordering and output:
//   * Larger value first; ties are broken by smaller index.
//   * Values that are not > -INFINITY (-INFINITY, NaN) are ignored.
//   * Results go to top_values[0..k) / top_indices[0..k) in descending order.
//   * Unfillable slots get (-INFINITY, -1).
//
// Each of the k rounds works in three steps:
//   1. Every thread finds its best element ranked after the previous winner.
//   2. A __shfl_down tree reduces within each warp.
//   3. Thread 0 reduces the per-warp winners held in shared memory.
//
// The previous version was not a top-k. Its "threshold" read was always
// -INFINITY, so the first K atomic arrivals won, each thread kept only 2
// candidates, and blocks other than block 0 were discarded.
//
// Any blockDim.x up to 1024 is supported; lanes past blockDim.x are ignored.
__global__ void topk_warp_merge_kernel(const float* __restrict__ values,
                                       int* __restrict__ top_indices,
                                       float* __restrict__ top_values,
                                       int n, int k) {
    constexpr int MAX_WARPS = 1024 / 32;  // max threads per block / smallest warp size
    __shared__ float s_warp_vals[MAX_WARPS];
    __shared__ int s_warp_idxs[MAX_WARPS];
    __shared__ float s_win_val;
    __shared__ int s_win_idx;

    // Block-uniform exit: only block 0 produces the result.
    if (blockIdx.x != 0) return;

    // True when (av, ai) ranks strictly before (bv, bi). Index -1 marks "empty".
    auto ranks_before = [](float av, int ai, float bv, int bi) -> bool {
        if (ai < 0) return false;
        if (bi < 0) return true;
        return av > bv || (av == bv && ai < bi);
    };

    const int tid = threadIdx.x;
    const int nthreads = blockDim.x;
    const int lane = tid % warpSize;
    const int wid = tid / warpSize;
    const int num_warps = (nthreads + warpSize - 1) / warpSize;

    bool have_last = false;
    float last_val = -INFINITY;
    int last_idx = -1;

    for (int r = 0; r < k; ++r) {
        // Per-thread candidate ranked after the previous winner.
        float best_val = -INFINITY;
        int best_idx = -1;
        for (int i = tid; i < n; i += nthreads) {
            const float v = values[i];
            if (!(v > -INFINITY)) continue;
            if (have_last && !ranks_before(last_val, last_idx, v, i)) continue;
            if (ranks_before(v, i, best_val, best_idx)) {
                best_val = v;
                best_idx = i;
            }
        }

        // Warp-level reduction. Results from source lanes beyond the block
        // (partial last warp) are ignored.
        for (int offset = warpSize / 2; offset > 0; offset /= 2) {
            const float other_val = __shfl_down(best_val, offset);
            const int other_idx = __shfl_down(best_idx, offset);
            if (tid + offset < nthreads &&
                ranks_before(other_val, other_idx, best_val, best_idx)) {
                best_val = other_val;
                best_idx = other_idx;
            }
        }
        if (lane == 0) {
            s_warp_vals[wid] = best_val;
            s_warp_idxs[wid] = best_idx;
        }
        __syncthreads();

        // Cross-warp reduction, output, and broadcast of the winner.
        if (tid == 0) {
            float win_val = s_warp_vals[0];
            int win_idx = s_warp_idxs[0];
            for (int w = 1; w < num_warps; ++w) {
                if (ranks_before(s_warp_vals[w], s_warp_idxs[w], win_val, win_idx)) {
                    win_val = s_warp_vals[w];
                    win_idx = s_warp_idxs[w];
                }
            }
            top_values[r] = win_val;  // -INFINITY when win_idx == -1
            top_indices[r] = win_idx;
            s_win_val = win_val;
            s_win_idx = win_idx;
        }
        __syncthreads();

        const float win_val = s_win_val;
        const int win_idx = s_win_idx;

        // Uniform across the block: everyone reads the same shared winner.
        if (win_idx < 0) {
            if (tid == 0) {
                for (int j = r + 1; j < k; ++j) {
                    top_values[j] = -INFINITY;
                    top_indices[j] = -1;
                }
            }
            break;
        }

        have_last = true;
        last_val = win_val;
        last_idx = win_idx;
    }
}

// Per-block stage of a two-stage top-k (see topk_merge_blocks_kernel).
//
// Block b does the following:
//   1. Loads the elements starting at b * blockDim.x into shared memory,
//      padding positions past n with (-INFINITY, -1).
//   2. Bitonic-sorts them into DESCENDING order.
//   3. Writes the first k entries (the largest) to
//      block_top_values/indices[b * k + i].
//
// Preconditions: blockDim.x == BLOCK_SIZE (a power of two); k <= BLOCK_SIZE.
// NOTE: NaN inputs break the ordering produced by the comparison network.
__global__ void topk_per_block_kernel(const float* __restrict__ values,
                                      float* __restrict__ block_top_values,
                                      int* __restrict__ block_top_indices,
                                      int n, int k) {
    __shared__ float s_vals[BLOCK_SIZE];
    __shared__ int s_idxs[BLOCK_SIZE];
    
    int tid = threadIdx.x;
    int block_offset = blockIdx.x * blockDim.x;
    int global_idx = block_offset + tid;
    
    // Load data
    if (global_idx < n) {
        s_vals[tid] = values[global_idx];
        s_idxs[tid] = global_idx;
    } else {
        s_vals[tid] = -INFINITY;
        s_idxs[tid] = -1;
    }
    __syncthreads();
    
    // Bitonic sort within the block. The swap condition is the inverse of a
    // textbook ascending bitonic sort, so the final order is descending.
    for (int size = 2; size <= BLOCK_SIZE; size *= 2) {
        for (int stride = size / 2; stride > 0; stride /= 2) {
            int partner = tid ^ stride;
            
            if (partner > tid && partner < BLOCK_SIZE) {
                bool ascending = ((tid & size) == 0);
                
                if ((s_vals[tid] < s_vals[partner]) == ascending) {
                    // Swap
                    float tmp_v = s_vals[tid];
                    s_vals[tid] = s_vals[partner];
                    s_vals[partner] = tmp_v;
                    
                    int tmp_i = s_idxs[tid];
                    s_idxs[tid] = s_idxs[partner];
                    s_idxs[partner] = tmp_i;
                }
            }
            __syncthreads();
        }
    }
    
    // The array is now descending, so the k largest entries sit at the FRONT.
    // The previous code read from the back and returned the k smallest values
    // (mostly -INFINITY padding for a partial last block).
    if (tid < k) {
        block_top_values[blockIdx.x * k + tid] = s_vals[tid];
        block_top_indices[blockIdx.x * k + tid] = s_idxs[tid];
    }
}

// Merge stage for topk_per_block_kernel. It reduces num_blocks candidate
// lists of k entries each (contiguous in block_top_values/indices) to the best
// `keep = min(k, K)` entries, written in descending order to
// final_values/indices[0..keep). Among equal values the earlier candidate wins.
//
// Candidates are streamed through a fixed K * 4 shared buffer:
//   * slots [0, keep) hold the running best;
//   * each chunk of up to K * 4 - keep new candidates is appended behind them;
//   * thread 0 then re-selects the best `keep`.
//
// The previous version had two defects, both fixed here:
//   * it ignored every candidate past the first K * 4;
//   * it wrote shared memory out of bounds when blockDim.x > K * 4.
//
// Any block size works. The kernel is intended for a single-block launch.
__global__ void topk_merge_blocks_kernel(const float* __restrict__ block_top_values,
                                         const int* __restrict__ block_top_indices,
                                         float* __restrict__ final_values,
                                         int* __restrict__ final_indices,
                                         int num_blocks, int k) {
    constexpr int CAPACITY = K * 4;
    __shared__ float s_vals[CAPACITY];
    __shared__ int s_idxs[CAPACITY];

    const int tid = threadIdx.x;
    const int keep = max(0, min(k, K));               // survivors carried between chunks
    const int total_candidates = max(0, num_blocks * k);
    const int chunk = CAPACITY - keep;                // >= 3 * K, always positive

    // Running best starts empty.
    for (int i = tid; i < keep; i += blockDim.x) {
        s_vals[i] = -INFINITY;
        s_idxs[i] = -1;
    }
    __syncthreads();

    for (int base = 0; base < total_candidates; base += chunk) {
        const int count = min(chunk, total_candidates - base);

        // Append the next chunk behind the running best.
        for (int i = tid; i < count; i += blockDim.x) {
            s_vals[keep + i] = block_top_values[base + i];
            s_idxs[keep + i] = block_top_indices[base + i];
        }
        __syncthreads();

        // Selection sort of the best `keep` among the live entries.
        if (tid == 0) {
            const int live = keep + count;
            for (int i = 0; i < keep; i++) {
                int max_idx = i;
                float max_val = s_vals[i];
                for (int j = i + 1; j < live; j++) {
                    if (s_vals[j] > max_val) {
                        max_val = s_vals[j];
                        max_idx = j;
                    }
                }
                if (max_idx != i) {
                    const float tmp_v = s_vals[i];
                    s_vals[i] = s_vals[max_idx];
                    s_vals[max_idx] = tmp_v;

                    const int tmp_i = s_idxs[i];
                    s_idxs[i] = s_idxs[max_idx];
                    s_idxs[max_idx] = tmp_i;
                }
            }
        }
        __syncthreads();  // survivors settled before the next chunk overwrites the tail
    }

    if (tid == 0) {
        for (int i = 0; i < keep; i++) {
            final_values[i] = s_vals[i];
            final_indices[i] = s_idxs[i];
        }
    }
}

// Per-block top-K over a contiguous range. Block b scans
// values[b * elements_per_block, min((b + 1) * elements_per_block, n)).
//
// Output:
//   * K entries are written in descending order to
//     block_top_values/indices[b * K + i].
//   * Ties are broken by smaller index.
//   * Values that are not > -INFINITY (-INFINITY, NaN) are ignored.
//   * Unfillable slots get (-INFINITY, -1).
//   * The `k` argument is unused. The kernel always produces K entries,
//     matching the K output stride.
//
// Algorithm: K rounds. Each round, every thread finds its best element in the
// range ranked after the previous winner, and a shared-memory tree reduction
// picks the block winner. The result is exact.
//
// The previous version had these defects:
//   * it kept only one candidate per thread;
//   * threads tid and tid + K raced on the same shared slot;
//   * it had no barrier before reading that slot;
//   * it compared floats through an int atomicMax.
//
// Precondition: blockDim.x <= BLOCK_SIZE (size of the shared scratch arrays).
__global__ void topk_block_kernel(const float* __restrict__ values,
                                  float* __restrict__ block_top_values,
                                  int* __restrict__ block_top_indices,
                                  int n, int k, int elements_per_block) {
    __shared__ float s_vals[BLOCK_SIZE];
    __shared__ int s_idxs[BLOCK_SIZE];
    (void)k;

    // True when (av, ai) ranks strictly before (bv, bi). Index -1 marks "empty".
    auto ranks_before = [](float av, int ai, float bv, int bi) -> bool {
        if (ai < 0) return false;
        if (bi < 0) return true;
        return av > bv || (av == bv && ai < bi);
    };

    const int tid = threadIdx.x;
    const int nthreads = blockDim.x;
    const int block_start = blockIdx.x * elements_per_block;
    const int block_end = min(block_start + elements_per_block, n);
    float* out_vals = block_top_values + blockIdx.x * K;
    int* out_idxs = block_top_indices + blockIdx.x * K;

    bool have_last = false;
    float last_val = -INFINITY;
    int last_idx = -1;

    for (int r = 0; r < K; ++r) {
        float best_val = -INFINITY;
        int best_idx = -1;
        for (int i = block_start + tid; i < block_end; i += nthreads) {
            const float v = values[i];
            if (!(v > -INFINITY)) continue;
            if (have_last && !ranks_before(last_val, last_idx, v, i)) continue;
            if (ranks_before(v, i, best_val, best_idx)) {
                best_val = v;
                best_idx = i;
            }
        }
        s_vals[tid] = best_val;
        s_idxs[tid] = best_idx;
        __syncthreads();

        // Tree reduction: slot tid absorbs slot tid + offset.
        for (int offset = 1; offset < nthreads; offset <<= 1) {
            if ((tid % (2 * offset)) == 0 && tid + offset < nthreads &&
                ranks_before(s_vals[tid + offset], s_idxs[tid + offset],
                             s_vals[tid], s_idxs[tid])) {
                s_vals[tid] = s_vals[tid + offset];
                s_idxs[tid] = s_idxs[tid + offset];
            }
            __syncthreads();
        }

        const float win_val = s_vals[0];
        const int win_idx = s_idxs[0];
        __syncthreads();  // all reads done before the next round overwrites s_*

        if (tid == 0) {
            out_vals[r] = win_val;  // -INFINITY when win_idx == -1
            out_idxs[r] = win_idx;
        }

        // Uniform: every thread read the same shared winner.
        if (win_idx < 0) {
            if (tid == 0) {
                for (int j = r + 1; j < K; ++j) {
                    out_vals[j] = -INFINITY;
                    out_idxs[j] = -1;
                }
            }
            break;
        }

        have_last = true;
        last_val = win_val;
        last_idx = win_idx;
    }
}

// Per-block stage of the GPU top-K pipeline (launched by topk_gpu).
//
// Fast path:
//   1. Every thread walks the input with a grid-stride loop and keeps its
//      LOCAL_K largest values in registers.
//   2. The block pools those candidates in shared memory.
//   3. Thread 0 picks the block's K largest by repeated selection.
//   4. They are written in descending order to
//      block_top_values/indices[blockIdx.x * K + i]. Unfilled slots get
//      -INFINITY / -1.
// Values not > -INFINITY (-INFINITY, NaN) are ignored. The `k` argument is
// unused; K results are always produced.
//
// Exactness check: any value a thread discarded is <= that thread's smallest
// retained value. So the fast path can only be wrong if some thread that
// discarded something retains a value strictly above the block's K-th result.
// In that case (e.g. many of the largest values land on one grid-stride lane)
// the block recomputes its top-K exactly with K rounds of block-wide
// selection, breaking ties by lower index.
//
// Precondition: blockDim.x <= BLOCK_SIZE (shared candidate pool size).
__global__ void topk_scan_kernel(const float* __restrict__ values,
                                 float* __restrict__ block_top_values,
                                 int* __restrict__ block_top_indices,
                                 int n, int k) {
    constexpr int LOCAL_K = 8;  // candidates kept per thread
    __shared__ float s_vals[BLOCK_SIZE * LOCAL_K];
    __shared__ int s_idxs[BLOCK_SIZE * LOCAL_K];
    __shared__ float s_threshold;  // K-th value of the fast-path result
    __shared__ int s_need_exact;   // set when the fast path may be inexact
    (void)k;

    float local_vals[LOCAL_K];
    int local_idxs[LOCAL_K];
    for (int i = 0; i < LOCAL_K; i++) {
        local_vals[i] = -INFINITY;
        local_idxs[i] = -1;
    }
    bool dropped = false;  // did this thread ever discard a real value?

    const int tid = threadIdx.x;
    const int nthreads = blockDim.x;
    const int global_tid = blockIdx.x * nthreads + tid;
    const int stride = nthreads * gridDim.x;

    // Grid-stride scan keeping a descending list of the LOCAL_K best values.
    for (int i = global_tid; i < n; i += stride) {
        const float val = values[i];
        if (val > local_vals[LOCAL_K - 1]) {
            if (local_idxs[LOCAL_K - 1] >= 0) dropped = true;  // evicting a real value
            int pos = LOCAL_K - 1;
            while (pos > 0 && val > local_vals[pos - 1]) {
                pos--;
            }
            for (int j = LOCAL_K - 1; j > pos; j--) {
                local_vals[j] = local_vals[j - 1];
                local_idxs[j] = local_idxs[j - 1];
            }
            local_vals[pos] = val;
            local_idxs[pos] = i;
        } else if (val > -INFINITY) {
            dropped = true;  // list is full and val does not beat its minimum
        }
    }

    // Pool every thread's candidates.
    for (int i = 0; i < LOCAL_K; i++) {
        s_vals[tid * LOCAL_K + i] = local_vals[i];
        s_idxs[tid * LOCAL_K + i] = local_idxs[i];
    }
    if (tid == 0) s_need_exact = 0;
    __syncthreads();

    float* out_vals = block_top_values + blockIdx.x * K;
    int* out_idxs = block_top_indices + blockIdx.x * K;

    // Fast path: thread 0 selects the K largest pooled candidates.
    if (tid == 0) {
        float result_vals[K];
        int result_idxs[K];
        for (int i = 0; i < K; i++) {
            result_vals[i] = -INFINITY;
            result_idxs[i] = -1;
        }

        // Only slots written by launched threads are valid.
        const int total_candidates = nthreads * LOCAL_K;
        for (int r = 0; r < K; r++) {
            float max_val = -INFINITY;
            int max_pos = -1;
            for (int i = 0; i < total_candidates; i++) {
                if (s_vals[i] > max_val) {
                    max_val = s_vals[i];
                    max_pos = i;
                }
            }
            if (max_pos >= 0) {
                result_vals[r] = max_val;
                result_idxs[r] = s_idxs[max_pos];
                s_vals[max_pos] = -INFINITY;  // mark as used
            }
        }

        for (int i = 0; i < K; i++) {
            out_vals[i] = result_vals[i];
            out_idxs[i] = result_idxs[i];
        }
        s_threshold = result_vals[K - 1];  // -INFINITY if fewer than K were found
    }
    __syncthreads();

    // Could this thread have discarded something that belongs in the top K?
    if (dropped && local_vals[LOCAL_K - 1] > s_threshold) {
        atomicOr(&s_need_exact, 1);
    }
    __syncthreads();

    // Block-uniform: every thread reads the same flag, and nobody writes it again.
    if (s_need_exact == 0) return;

    // Exact fallback: K rounds of block-wide selection over this block's elements.
    auto ranks_before = [](float av, int ai, float bv, int bi) -> bool {
        if (ai < 0) return false;
        if (bi < 0) return true;
        return av > bv || (av == bv && ai < bi);
    };

    bool have_last = false;
    float last_val = -INFINITY;
    int last_idx = -1;

    for (int r = 0; r < K; ++r) {
        float best_val = -INFINITY;
        int best_idx = -1;
        for (int i = global_tid; i < n; i += stride) {
            const float v = values[i];
            if (!(v > -INFINITY)) continue;
            if (have_last && !ranks_before(last_val, last_idx, v, i)) continue;
            if (ranks_before(v, i, best_val, best_idx)) {
                best_val = v;
                best_idx = i;
            }
        }
        s_vals[tid] = best_val;
        s_idxs[tid] = best_idx;
        __syncthreads();

        for (int offset = 1; offset < nthreads; offset <<= 1) {
            if ((tid % (2 * offset)) == 0 && tid + offset < nthreads &&
                ranks_before(s_vals[tid + offset], s_idxs[tid + offset],
                             s_vals[tid], s_idxs[tid])) {
                s_vals[tid] = s_vals[tid + offset];
                s_idxs[tid] = s_idxs[tid + offset];
            }
            __syncthreads();
        }

        const float win_val = s_vals[0];
        const int win_idx = s_idxs[0];
        __syncthreads();  // all reads done before the next round overwrites s_*

        if (tid == 0) {
            out_vals[r] = win_val;
            out_idxs[r] = win_idx;
        }
        if (win_idx < 0) {  // uniform: same shared winner for everyone
            if (tid == 0) {
                for (int j = r + 1; j < K; ++j) {
                    out_vals[j] = -INFINITY;
                    out_idxs[j] = -1;
                }
            }
            break;
        }

        have_last = true;
        last_val = win_val;
        last_idx = win_idx;
    }
}

// Second stage of topk_gpu. It reduces num_blocks candidate lists of k entries
// each (stored back to back in block_top_values/indices) to the overall top k.
//
// Dynamic shared memory: num_blocks * k * (sizeof(float) + sizeof(int)) bytes,
// holding the values followed by the indices.
// Threads cooperate on the load; thread 0 then repeatedly takes the largest
// remaining value (first occurrence wins ties).
// Output is descending. A slot with no candidate left (only -INFINITY/NaN
// remain) gets (-INFINITY, -1).
__global__ void topk_final_merge_kernel(const float* __restrict__ block_top_values,
                                        const int* __restrict__ block_top_indices,
                                        float* __restrict__ final_values,
                                        int* __restrict__ final_indices,
                                        int num_blocks, int k) {
    extern __shared__ float shared_mem[];
    const int candidate_count = num_blocks * k;
    float* cand_vals = shared_mem;
    int* cand_idxs = reinterpret_cast<int*>(shared_mem + candidate_count);

    for (int slot = threadIdx.x; slot < candidate_count; slot += blockDim.x) {
        cand_vals[slot] = block_top_values[slot];
        cand_idxs[slot] = block_top_indices[slot];
    }
    __syncthreads();

    if (threadIdx.x != 0) return;

    for (int rank = 0; rank < k; ++rank) {
        int winner = -1;
        float winner_val = -INFINITY;
        for (int slot = 0; slot < candidate_count; ++slot) {
            const float v = cand_vals[slot];
            if (v > winner_val) {
                winner_val = v;
                winner = slot;
            }
        }

        if (winner < 0) {
            final_values[rank] = -INFINITY;
            final_indices[rank] = -1;
            continue;
        }

        final_values[rank] = winner_val;
        final_indices[rank] = cand_idxs[winner];
        cand_vals[winner] = -INFINITY;  // consumed
    }
}

// Final efficient top-k implementation (GPU)
void topk_gpu(const float* d_values, int* d_top_indices, float* d_top_values,
              int n, int k, hipStream_t stream = 0) {
    // Use fewer blocks for better per-block top-k quality
    int num_blocks = min((n + BLOCK_SIZE - 1) / BLOCK_SIZE, 16);
    
    float *d_block_values;
    int *d_block_indices;
    
    HIP_CHECK(hipMalloc(&d_block_values, num_blocks * k * sizeof(float)));
    HIP_CHECK(hipMalloc(&d_block_indices, num_blocks * k * sizeof(int)));
    
    // Step 1: Each block finds its local top-k using scan approach
    hipLaunchKernelGGL(topk_scan_kernel, dim3(num_blocks), dim3(BLOCK_SIZE), 
                       0, stream, d_values, d_block_values, d_block_indices, n, k);
    
    // Step 2: Merge block results with proper shared memory size
    size_t shared_mem_size = num_blocks * k * (sizeof(float) + sizeof(int));
    hipLaunchKernelGGL(topk_final_merge_kernel, dim3(1), dim3(256),
                       shared_mem_size, stream, d_block_values, d_block_indices,
                       d_top_values, d_top_indices, num_blocks, k);
    
    HIP_CHECK(hipFree(d_block_values));
    HIP_CHECK(hipFree(d_block_indices));
}

// Host-side top-k via std::partial_sort (O(n log k)).
//
// Steps:
//   1. Copy the n device values to the host.
//   2. Rank them: descending value, ties broken by lower index, NaN after every
//      number. This keeps the comparator a strict weak ordering.
//   3. Copy the best k back to d_top_values / d_top_indices.
//
// Edge cases:
//   * k <= 0 is a no-op.
//   * If k > n, the trailing k - n slots get (-INFINITY, -1).
//
// Blocks until the result copies on `stream` finish, because the host staging
// vectors are destroyed on return.
void topk_cpu(const float* d_values, int* d_top_indices, float* d_top_values,
              int n, int k, hipStream_t stream = 0) {
    if (k <= 0) return;
    const int count = std::max(n, 0);

    // Copy values from device to host
    std::vector<float> h_values(count);
    if (count > 0) {
        HIP_CHECK(hipMemcpyAsync(h_values.data(), d_values,
                                 static_cast<size_t>(count) * sizeof(float),
                                 hipMemcpyDeviceToHost, stream));
        HIP_CHECK(hipStreamSynchronize(stream));
    }

    // Create pairs of (value, index) for sorting
    std::vector<std::pair<float, int>> pairs(count);
    for (int i = 0; i < count; i++) {
        pairs[i] = {h_values[i], i};
    }

    // Total order: numbers before NaN, larger value first, then lower index.
    auto ranks_before = [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
        const bool a_nan = std::isnan(a.first);
        const bool b_nan = std::isnan(b.first);
        if (a_nan != b_nan) return b_nan;
        if (!a_nan && a.first != b.first) return a.first > b.first;
        return a.second < b.second;
    };

    // Never ask partial_sort for more elements than exist (k > n was UB).
    const int take = std::min(k, count);
    std::partial_sort(pairs.begin(), pairs.begin() + take, pairs.end(), ranks_before);

    std::vector<float> h_top_values(k, -std::numeric_limits<float>::infinity());
    std::vector<int> h_top_indices(k, -1);
    for (int i = 0; i < take; i++) {
        h_top_values[i] = pairs[i].first;
        h_top_indices[i] = pairs[i].second;
    }

    // Copy results back to device
    HIP_CHECK(hipMemcpyAsync(d_top_values, h_top_values.data(), k * sizeof(float),
                             hipMemcpyHostToDevice, stream));
    HIP_CHECK(hipMemcpyAsync(d_top_indices, h_top_indices.data(), k * sizeof(int),
                             hipMemcpyHostToDevice, stream));
    // The source vectors die on return; wait for the copies to consume them.
    HIP_CHECK(hipStreamSynchronize(stream));
}

// Host-side top-k via std::nth_element plus std::sort of the winners
// (O(n + k log k) average).
//
// Steps:
//   1. Copy the n device values to the host.
//   2. Partition so the best k come first, ranked by descending value, ties
//      broken by lower index, NaN after every number.
//   3. Sort those k.
//   4. Copy them back to d_top_values / d_top_indices.
//
// Edge cases:
//   * k <= 0 is a no-op.
//   * If k > n, the trailing k - n slots get (-INFINITY, -1).
//
// Blocks until the result copies on `stream` finish, because the host staging
// vectors are destroyed on return.
void topk_cpu_nth(const float* d_values, int* d_top_indices, float* d_top_values,
                  int n, int k, hipStream_t stream = 0) {
    if (k <= 0) return;
    const int count = std::max(n, 0);

    // Copy values from device to host
    std::vector<float> h_values(count);
    if (count > 0) {
        HIP_CHECK(hipMemcpyAsync(h_values.data(), d_values,
                                 static_cast<size_t>(count) * sizeof(float),
                                 hipMemcpyDeviceToHost, stream));
        HIP_CHECK(hipStreamSynchronize(stream));
    }

    // Create pairs of (value, index) for sorting
    std::vector<std::pair<float, int>> pairs(count);
    for (int i = 0; i < count; i++) {
        pairs[i] = {h_values[i], i};
    }

    // Total order: numbers before NaN, larger value first, then lower index.
    // (A bare `a.first > b.first` is not a strict weak ordering once NaN appears.)
    auto ranks_before = [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
        const bool a_nan = std::isnan(a.first);
        const bool b_nan = std::isnan(b.first);
        if (a_nan != b_nan) return b_nan;
        if (!a_nan && a.first != b.first) return a.first > b.first;
        return a.second < b.second;
    };

    // Clamp so the nth iterator never passes end() (k > n was UB).
    const int take = std::min(k, count);
    std::nth_element(pairs.begin(), pairs.begin() + take, pairs.end(), ranks_before);
    std::sort(pairs.begin(), pairs.begin() + take, ranks_before);

    std::vector<float> h_top_values(k, -std::numeric_limits<float>::infinity());
    std::vector<int> h_top_indices(k, -1);
    for (int i = 0; i < take; i++) {
        h_top_values[i] = pairs[i].first;
        h_top_indices[i] = pairs[i].second;
    }

    // Copy results back to device
    HIP_CHECK(hipMemcpyAsync(d_top_values, h_top_values.data(), k * sizeof(float),
                             hipMemcpyHostToDevice, stream));
    HIP_CHECK(hipMemcpyAsync(d_top_indices, h_top_indices.data(), k * sizeof(int),
                             hipMemcpyHostToDevice, stream));
    // The source vectors die on return; wait for the copies to consume them.
    HIP_CHECK(hipStreamSynchronize(stream));
}

// Chooses which implementation `topk` dispatches to.
enum class TopKMethod : int {
    GPU,          // topk_gpu: scan + merge kernels on the device
    CPU_PARTIAL,  // topk_cpu: host std::partial_sort, O(n log k)
    CPU_NTH       // topk_cpu_nth: host std::nth_element + std::sort, O(n) average
};

// Single entry point for top-k. It forwards the request unchanged to the
// implementation selected by `method` (GPU kernels by default). A value
// outside the enumerators does nothing.
void topk(const float* d_values, int* d_top_indices, float* d_top_values,
          int n, int k, TopKMethod method = TopKMethod::GPU, hipStream_t stream = 0) {
    if (method == TopKMethod::GPU) {
        topk_gpu(d_values, d_top_indices, d_top_values, n, k, stream);
    } else if (method == TopKMethod::CPU_PARTIAL) {
        topk_cpu(d_values, d_top_indices, d_top_values, n, k, stream);
    } else if (method == TopKMethod::CPU_NTH) {
        topk_cpu_nth(d_values, d_top_indices, d_top_values, n, k, stream);
    }
}

// Counts how many distinct indices in `candidate` also appear in `reference`.
// Rules:
//   - indices outside [0, n) never match;
//   - each reference index can be matched only once, so a result that
//     repeats a correct index cannot inflate the score.
static int count_index_matches(const std::vector<int>& candidate,
                               const std::vector<int>& reference, int n) {
    std::vector<char> pending(static_cast<size_t>(std::max(n, 0)), 0);
    for (int idx : reference) {
        if (idx >= 0 && idx < n) {
            pending[idx] = 1;
        }
    }
    int matches = 0;
    for (int idx : candidate) {
        if (idx >= 0 && idx < n && pending[idx]) {
            ++matches;
            pending[idx] = 0;  // consume so duplicates are not double-counted
        }
    }
    return matches;
}

// Benchmark driver: y = A * x followed by top-K selection over y.
//
// Setup:
//   - A is M x N, stored row-major in h_matrix.
//   - y is computed on the device with rocblas_sgemv.
//   - Top-K uses the method chosen on the command line.
//
// Each timed iteration covers sgemv + topk + hipDeviceSynchronize. After the
// benchmark, the selected indices are compared against a host-side
// std::partial_sort reference over the same y.
//
// Usage: <prog> [gpu|cpu|cpu_nth]   (default: gpu)
// Returns 0 on success and 1 for an unrecognized method argument.
// HIP and rocBLAS failures terminate the process via HIP_CHECK / ROCBLAS_CHECK.
int main(int argc, char** argv) {
    const int M = 16384;  // Number of rows in matrix
    const int N = 128;    // Dimension of vectors
    // The top-k routines pick K of M scores, and the CPU reference below
    // uses cpu_pairs.begin() + K, so K must lie in (0, M].
    static_assert(K > 0 && K <= M, "K must be in the range (0, M]");

    // Parse command line arguments for method selection
    TopKMethod method = TopKMethod::GPU;
    std::string method_name = "GPU";

    if (argc > 1) {
        std::string arg = argv[1];
        if (arg == "gpu" || arg == "GPU") {
            method = TopKMethod::GPU;
            method_name = "GPU";
        } else if (arg == "cpu" || arg == "CPU" || arg == "cpu_partial") {
            method = TopKMethod::CPU_PARTIAL;
            method_name = "CPU (partial_sort)";
        } else if (arg == "cpu_nth" || arg == "CPU_NTH") {
            method = TopKMethod::CPU_NTH;
            method_name = "CPU (nth_element)";
        } else {
            std::cout << "Usage: " << argv[0] << " [gpu|cpu|cpu_nth]" << std::endl;
            std::cout << "  gpu      - GPU-based top-k kernel (default)" << std::endl;
            std::cout << "  cpu      - CPU partial_sort (O(n log k))" << std::endl;
            std::cout << "  cpu_nth  - CPU nth_element + sort (O(n) average)" << std::endl;
            return 1;
        }
    }

    std::cout << "ROCm Top-K Selection Benchmark" << std::endl;
    std::cout << "Matrix size: " << M << "x" << N << std::endl;
    std::cout << "Vector size: 1x" << N << std::endl;
    std::cout << "Top-K: " << K << std::endl;
    std::cout << "Method: " << method_name << std::endl;
    std::cout << "----------------------------------------" << std::endl;

    // Allocate host memory
    std::vector<float> h_vector(N);
    std::vector<float> h_matrix(M * N);
    std::vector<float> h_result(M);
    std::vector<int> h_top_indices(K);
    std::vector<float> h_top_values(K);

    // Initialize data with a fixed seed so runs are reproducible
    srand(42);
    for (int i = 0; i < N; i++) {
        h_vector[i] = static_cast<float>(rand()) / RAND_MAX * 10.0f;
    }

    for (int i = 0; i < M * N; i++) {
        h_matrix[i] = static_cast<float>(rand()) / RAND_MAX * 10.0f;
    }

    // Make some results definitely larger for testing.
    // Entries in the first rows are set to ~1000-1100 instead of 0-10, so
    // those rows' dot products are far larger than the rest.
    // Clamped to M so a smaller matrix cannot be overrun.
    const int boosted_rows = std::min(100, M);
    for (int i = 0; i < boosted_rows; i++) {
        for (int j = 0; j < N; j++) {
            h_matrix[i * N + j] = 1000.0f + static_cast<float>(rand()) / RAND_MAX * 100.0f;
        }
    }

    // Allocate device memory
    float *d_vector, *d_matrix, *d_result;
    int *d_top_indices;
    float *d_top_values;

    HIP_CHECK(hipMalloc(&d_vector, N * sizeof(float)));
    HIP_CHECK(hipMalloc(&d_matrix, M * N * sizeof(float)));
    HIP_CHECK(hipMalloc(&d_result, M * sizeof(float)));
    HIP_CHECK(hipMalloc(&d_top_indices, K * sizeof(int)));
    HIP_CHECK(hipMalloc(&d_top_values, K * sizeof(float)));

    // Copy data to device
    HIP_CHECK(hipMemcpy(d_vector, h_vector.data(), N * sizeof(float),
                        hipMemcpyHostToDevice));
    HIP_CHECK(hipMemcpy(d_matrix, h_matrix.data(), M * N * sizeof(float),
                        hipMemcpyHostToDevice));

    // Create rocBLAS handle
    rocblas_handle handle;
    ROCBLAS_CHECK(rocblas_create_handle(&handle));

    // Warmup run.
    // rocBLAS is column-major: the row-major M x N host matrix is an N x M
    // column-major matrix with lda = N. Transposing it therefore yields one
    // dot product per row, i.e. M results.
    float alpha = 1.0f;
    float beta = 0.0f;
    ROCBLAS_CHECK(rocblas_sgemv(handle, rocblas_operation_transpose,
                                N, M, &alpha, d_matrix, N,
                                d_vector, 1, &beta, d_result, 1));
    topk(d_result, d_top_indices, d_top_values, M, K, method);
    HIP_CHECK(hipDeviceSynchronize());

    // Benchmark parameters
    const int num_iterations = 100;
    std::vector<double> latencies;
    latencies.reserve(num_iterations);

    std::cout << "Running " << num_iterations << " iterations..." << std::endl;

    for (int iter = 0; iter < num_iterations; iter++) {
        // Start timing
        auto start = std::chrono::high_resolution_clock::now();

        // Perform matrix-vector multiplication using rocBLAS
        ROCBLAS_CHECK(rocblas_sgemv(handle, rocblas_operation_transpose,
                                    N, M, &alpha, d_matrix, N,
                                    d_vector, 1, &beta, d_result, 1));

        // Find top-K indices using selected method
        topk(d_result, d_top_indices, d_top_values, M, K, method);

        // Synchronize so the wall-clock measurement covers the finished work
        HIP_CHECK(hipDeviceSynchronize());

        // End timing
        auto end = std::chrono::high_resolution_clock::now();
        std::chrono::duration<double, std::micro> elapsed = end - start;
        latencies.push_back(elapsed.count());
    }

    // Copy final results back to host
    HIP_CHECK(hipMemcpy(h_result.data(), d_result, M * sizeof(float),
                        hipMemcpyDeviceToHost));
    HIP_CHECK(hipMemcpy(h_top_indices.data(), d_top_indices, K * sizeof(int),
                        hipMemcpyDeviceToHost));
    HIP_CHECK(hipMemcpy(h_top_values.data(), d_top_values, K * sizeof(float),
                        hipMemcpyDeviceToHost));

    // Calculate statistics
    double sum = 0.0;
    double min_latency = latencies[0];
    double max_latency = latencies[0];

    for (double lat : latencies) {
        sum += lat;
        min_latency = std::min(min_latency, lat);
        max_latency = std::max(max_latency, lat);
    }

    double mean_latency = sum / num_iterations;

    // Population standard deviation (divides by the sample count)
    double variance = 0.0;
    for (double lat : latencies) {
        variance += (lat - mean_latency) * (lat - mean_latency);
    }
    variance /= num_iterations;
    double std_dev = std::sqrt(variance);

    // Print results
    std::cout << "\n=== Performance Results ===" << std::endl;
    std::cout << "Mean latency: " << mean_latency << " μs" << std::endl;
    std::cout << "Min latency:  " << min_latency << " μs" << std::endl;
    std::cout << "Max latency:  " << max_latency << " μs" << std::endl;
    std::cout << "Std dev:      " << std_dev << " μs" << std::endl;

    // Print top-K results.
    // The head/tail counts are clamped so a small K never indexes out of bounds.
    const int shown_head = std::min(10, K);
    const int shown_tail = std::min(5, K - shown_head);
    std::cout << "\n=== Top-" << K << " Results ===" << std::endl;
    std::cout << "Showing first " << shown_head << " and last " << shown_tail
              << " of top-" << K << ":" << std::endl;
    for (int i = 0; i < shown_head; i++) {
        std::cout << "  Rank " << (i+1) << ": Index " << h_top_indices[i]
                  << ", Value " << h_top_values[i] << std::endl;
    }
    std::cout << "  ..." << std::endl;
    for (int i = K - shown_tail; i < K; i++) {
        std::cout << "  Rank " << (i+1) << ": Index " << h_top_indices[i]
                  << ", Value " << h_top_values[i] << std::endl;
    }

    // Verify by computing top-K on CPU from the same sgemv output
    std::cout << "\n=== CPU Verification ===" << std::endl;
    std::vector<std::pair<float, int>> cpu_pairs(M);
    for (int i = 0; i < M; i++) {
        cpu_pairs[i] = {h_result[i], i};
    }
    std::partial_sort(cpu_pairs.begin(), cpu_pairs.begin() + K, cpu_pairs.end(),
                     [](const auto& a, const auto& b) { return a.first > b.first; });

    std::cout << "CPU Top-" << shown_head << ":" << std::endl;
    for (int i = 0; i < shown_head; i++) {
        std::cout << "  Rank " << (i+1) << ": Index " << cpu_pairs[i].second
                  << ", Value " << cpu_pairs[i].first << std::endl;
    }

    // Compare index sets.
    // Duplicate or out-of-range indices from the benchmarked method do not
    // count as matches. Ties at the K-th value may still cause legitimate
    // mismatches.
    std::vector<int> cpu_top_indices(K);
    for (int i = 0; i < K; i++) {
        cpu_top_indices[i] = cpu_pairs[i].second;
    }
    const int matches = count_index_matches(h_top_indices, cpu_top_indices, M);
    std::cout << "\n" << method_name << " top-" << K << " indices matching CPU reference: "
              << matches << "/" << K << std::endl;

    // Cleanup
    ROCBLAS_CHECK(rocblas_destroy_handle(handle));
    HIP_CHECK(hipFree(d_vector));
    HIP_CHECK(hipFree(d_matrix));
    HIP_CHECK(hipFree(d_result));
    HIP_CHECK(hipFree(d_top_indices));
    HIP_CHECK(hipFree(d_top_values));

    std::cout << "\nDone!" << std::endl;

    return 0;
}
