#include <hip/hip_runtime.h>
#include <rocblas/rocblas.h>

#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <vector>

// Evaluates a HIP runtime call exactly once and terminates the program with a
// diagnostic (error string, file, line) if it does not return hipSuccess.
// The do { } while (0) wrapper makes the macro a single statement, so
// `if (x) HIP_CHECK(a); else ...` parses correctly. The local has an unusual
// name so it cannot capture or shadow a caller variable used inside `cmd`.
#define HIP_CHECK(cmd)                                                         \
    do {                                                                       \
        hipError_t hip_check_err_ = (cmd);                                     \
        if (hip_check_err_ != hipSuccess) {                                    \
            std::cerr << "HIP error: " << hipGetErrorString(hip_check_err_)    \
                      << " at " << __FILE__ << ":" << __LINE__ << std::endl;   \
            exit(EXIT_FAILURE);                                                \
        }                                                                      \
    } while (0)

// Evaluates a rocBLAS call exactly once and terminates the program with a
// diagnostic (numeric rocblas_status, file, line) if it does not return
// rocblas_status_success. Uses the same single-statement do { } while (0)
// form and collision-resistant local name as HIP_CHECK.
#define ROCBLAS_CHECK(cmd)                                                     \
    do {                                                                       \
        rocblas_status rocblas_check_status_ = (cmd);                          \
        if (rocblas_check_status_ != rocblas_status_success) {                 \
            std::cerr << "rocBLAS error: "                                     \
                      << static_cast<int>(rocblas_check_status_)               \
                      << " at " << __FILE__ << ":" << __LINE__ << std::endl;   \
            exit(EXIT_FAILURE);                                                \
        }                                                                      \
    } while (0)

// Unordered compaction: for every element of `input` strictly greater than
// `threshold`, appends that element's index to `indices` and increments
// `*count` (via atomicAdd).
//
// Launch: 1-D grid of 1-D blocks, no dynamic shared memory. The body loops
// with a stride of gridDim.x * blockDim.x, so any grid size covers all `n`
// elements; with ceil(n / blockDim.x) blocks each thread handles one element.
//
// Preconditions:
//   - `*count` must be zeroed before launch; the kernel only increments it.
//   - `indices` must have room for `n` entries (worst case: every element passes).
// The order of entries in `indices` follows atomicAdd arrival order and is
// therefore nondeterministic; sort on the host if a stable order is needed.
__global__ void find_indices_kernel(const float* __restrict__ input,
                                    int* __restrict__ indices,
                                    int* __restrict__ count,
                                    float threshold, int n) {
    // 64-bit index so `idx += stride` cannot overflow when n is near INT_MAX.
    const long long stride = static_cast<long long>(gridDim.x) * blockDim.x;
    for (long long idx = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
         idx < n; idx += stride) {
        if (input[idx] > threshold) {
            const int pos = atomicAdd(count, 1);
            indices[pos] = static_cast<int>(idx);
        }
    }
}

int main() {
    const int M = 65536;  // Number of rows in matrix
    const int N = 768;    // Dimension of vectors
    const float threshold = 100000.0f;
    
    std::cout << "ROCm Dot Product Benchmark" << std::endl;
    std::cout << "Matrix size: " << M << "x" << N << std::endl;
    std::cout << "Vector size: 1x" << N << std::endl;
    std::cout << "Threshold: " << threshold << std::endl;
    std::cout << "----------------------------------------" << std::endl;
    
    // Allocate host memory
    std::vector<float> h_vector(N);
    std::vector<float> h_matrix(M * N);
    std::vector<float> h_result(M);
    std::vector<int> h_indices(M);
    
    // Initialize data
    srand(42);
    for (int i = 0; i < N; i++) {
        h_vector[i] = static_cast<float>(rand()) / RAND_MAX * 10.0f;
    }
    
    for (int i = 0; i < M * N; i++) {
        h_matrix[i] = static_cast<float>(rand()) / RAND_MAX * 10.0f;
    }
    
    // Make some results definitely exceed threshold for testing
    for (int i = 0; i < 100; i++) {
        for (int j = 0; j < N; j++) {
            h_matrix[i * N + j] = 1000.0f + static_cast<float>(rand()) / RAND_MAX * 100.0f;
        }
    }
    
    // Allocate device memory
    float *d_vector, *d_matrix, *d_result;
    int *d_indices, *d_count;
    
    HIP_CHECK(hipMalloc(&d_vector, N * sizeof(float)));
    HIP_CHECK(hipMalloc(&d_matrix, M * N * sizeof(float)));
    HIP_CHECK(hipMalloc(&d_result, M * sizeof(float)));
    HIP_CHECK(hipMalloc(&d_indices, M * sizeof(int)));
    HIP_CHECK(hipMalloc(&d_count, sizeof(int)));
    
    // Copy data to device
    HIP_CHECK(hipMemcpy(d_vector, h_vector.data(), N * sizeof(float), 
                        hipMemcpyHostToDevice));
    HIP_CHECK(hipMemcpy(d_matrix, h_matrix.data(), M * N * sizeof(float), 
                        hipMemcpyHostToDevice));
    
    // Create rocBLAS handle
    rocblas_handle handle;
    ROCBLAS_CHECK(rocblas_create_handle(&handle));
    
    // Warmup run
    float alpha = 1.0f;
    float beta = 0.0f;
    ROCBLAS_CHECK(rocblas_sgemv(handle, rocblas_operation_none,
                                M, N, &alpha, d_matrix, M,
                                d_vector, 1, &beta, d_result, 1));
    HIP_CHECK(hipDeviceSynchronize());
    
    // Benchmark parameters
    const int num_iterations = 100;
    std::vector<double> latencies;
    
    std::cout << "Running " << num_iterations << " iterations..." << std::endl;
    
    for (int iter = 0; iter < num_iterations; iter++) {
        // Reset count
        int zero = 0;
        HIP_CHECK(hipMemcpy(d_count, &zero, sizeof(int), hipMemcpyHostToDevice));
        
        // Start timing
        auto start = std::chrono::high_resolution_clock::now();
        
        // Perform matrix-vector multiplication using rocBLAS
        // C = alpha * A * x + beta * C
        // A is MxN matrix (stored column-major in rocBLAS, but we treat as row-major)
        // We want result[i] = sum_j(matrix[i][j] * vector[j])
        // This is equivalent to: result = matrix * vector
        // In rocBLAS with row-major storage, we need to use transposed operation
        ROCBLAS_CHECK(rocblas_sgemv(handle, rocblas_operation_transpose,
                                    N, M, &alpha, d_matrix, N,
                                    d_vector, 1, &beta, d_result, 1));
        
        // Find indices with values > threshold
        int threadsPerBlock = 256;
        int blocksPerGrid = (M + threadsPerBlock - 1) / threadsPerBlock;
        hipLaunchKernelGGL(find_indices_kernel, dim3(blocksPerGrid), 
                          dim3(threadsPerBlock), 0, 0,
                          d_result, d_indices, d_count, threshold, M);
        
        // Synchronize to ensure all operations complete
        HIP_CHECK(hipDeviceSynchronize());
        
        // End timing
        auto end = std::chrono::high_resolution_clock::now();
        std::chrono::duration<double, std::micro> elapsed = end - start;
        latencies.push_back(elapsed.count());
    }
    
    // Copy final results back to host
    HIP_CHECK(hipMemcpy(h_result.data(), d_result, M * sizeof(float), 
                        hipMemcpyDeviceToHost));
    int h_count;
    HIP_CHECK(hipMemcpy(&h_count, d_count, sizeof(int), hipMemcpyDeviceToHost));
    HIP_CHECK(hipMemcpy(h_indices.data(), d_indices, h_count * sizeof(int), 
                        hipMemcpyDeviceToHost));
    
    // Calculate statistics
    double sum = 0.0;
    double min_latency = latencies[0];
    double max_latency = latencies[0];
    
    for (double lat : latencies) {
        sum += lat;
        min_latency = std::min(min_latency, lat);
        max_latency = std::max(max_latency, lat);
    }
    
    double mean_latency = sum / num_iterations;
    
    // Calculate standard deviation
    double variance = 0.0;
    for (double lat : latencies) {
        variance += (lat - mean_latency) * (lat - mean_latency);
    }
    variance /= num_iterations;
    double std_dev = std::sqrt(variance);
    
    // Print results
    std::cout << "\n=== Performance Results ===" << std::endl;
    std::cout << "Mean latency: " << mean_latency << " μs" << std::endl;
    std::cout << "Min latency:  " << min_latency << " μs" << std::endl;
    std::cout << "Max latency:  " << max_latency << " μs" << std::endl;
    std::cout << "Std dev:      " << std_dev << " μs" << std::endl;
    std::cout << "\n=== Results ===" << std::endl;
    std::cout << "Number of indices > " << threshold << ": " << h_count << std::endl;
    
    if (h_count > 0) {
        std::cout << "First 10 indices and their values:" << std::endl;
        for (int i = 0; i < std::min(10, h_count); i++) {
            int idx = h_indices[i];
            std::cout << "  Index " << idx << ": " << h_result[idx] << std::endl;
        }
    }
    
    // Verify a few results manually
    std::cout << "\n=== Verification (first 3 rows) ===" << std::endl;
    for (int i = 0; i < 3; i++) {
        float manual_sum = 0.0f;
        for (int j = 0; j < N; j++) {
            manual_sum += h_matrix[i * N + j] * h_vector[j];
        }
        std::cout << "Row " << i << " - GPU: " << h_result[i] 
                  << ", CPU: " << manual_sum 
                  << ", Diff: " << std::abs(h_result[i] - manual_sum) << std::endl;
    }
    
    // Cleanup
    ROCBLAS_CHECK(rocblas_destroy_handle(handle));
    HIP_CHECK(hipFree(d_vector));
    HIP_CHECK(hipFree(d_matrix));
    HIP_CHECK(hipFree(d_result));
    HIP_CHECK(hipFree(d_indices));
    HIP_CHECK(hipFree(d_count));
    
    std::cout << "\nDone!" << std::endl;
    
    return 0;
}
