#include <cuda.h>
#include <mma.h>
#include <cuda/pipeline>
#include <torch/extension.h>
#include <cuda_runtime.h>
#include "cutlass/cutlass.h"
#include "cutlass/arch/memory_sm75.h"
#include <vector>
#include "utils/gemmv2/src_iterator.h"
#include "utils/gemmv2/computer.h"
#include "utils/gemmv2/store_util.h"
#include "utils/gemmv2/meta_util.h"
#include <stdio.h>

#include "cuda_bf16.h"


using namespace nvcuda;

// Compile-time problem configuration shared by every kernel and launcher in
// this file.
// NOTE(review): M, N and K are very short macro names; any identifier spelled
// M, N or K that appears after this point (including in headers included
// later) will be textually replaced.

// Instruction size: shape parameters of the wmma accumulator fragments.
#define M 16
#define N 16
#define K 8

// Block Tile Size: output tile handled by one thread block (TileM x TileN) and
// the depth consumed per pipeline step (TileK). The bf16 path reinterprets its
// inputs as float, so there TileK counts pairs of bf16 values.
#define TileN 128
#define TileM 128
#define TileK 32

// Pipeline depth. Used by the fp32 kernel and by all host-side shared-memory
// size computations; the bf16 kernels pass their stage count as a template
// argument instead.
#define Stages 2

// Warp Tile Size: portion of the block tile owned by one warp. The launchers
// derive the block size as TileM * TileN / wTileM / wTileN warps.
#define wTileN 64
#define wTileM 64


//
//  GeMM kernel with 50% output & block sparsity
//

// Per-thread-block body of the fp32 block-sparse GEMM.
//
// Each thread block accumulates one TileM x TileN tile: row tile blockIdx.x and
// the blockIdx.y-th NONZERO column tile of that row. The real column-tile index
// is looked up in `indices`. After the k-loop the accumulators are handed to
// MetaUtil, which receives `output_matrix` and `metadata` (presumably it
// compresses the tile and writes both -- see utils/gemmv2/meta_util.h).
//
// Parameters:
//   m, n, k        Problem dimensions. k is assumed to be a multiple of TileK
//                  (any remainder is ignored). n is not used by this function.
//   nnz            Number of nonzero elements in each row; nnz / TileN gives
//                  the number of nonzero column tiles per row tile.
//   lhs_matrix     Left operand (read-only).
//   rhs_matrix     Right operand (read-only).
//   indices        Column-tile indices, nnz / TileN entries per row tile.
//   output_matrix  Destination passed to MetaUtil.
//   metadata       Destination passed to MetaUtil.
//
// Launch requirements (as set up by the host wrappers in this file):
//   grid  = (m / TileM, nnz / TileN [, batch]), block = NumWarp * 32 threads,
//   dynamic shared memory large enough for both the staged operand tiles and
//   the MetaUtil scratch area.
__device__ void block_spGemmKernel_(
    int m, int n, int k, int nnz,           // nnz: the number of nonzeros (in element) in each row 
    const float* __restrict__ lhs_matrix,
    const float* __restrict__ rhs_matrix,
    const int* __restrict__ indices,
    float* __restrict__ output_matrix,
    int16_t* __restrict__ metadata)
{
    static_assert(TileM % wTileM == 0 && TileN % wTileN == 0,
                  "block tile must be an integer number of warp tiles");
    static_assert(wTileM % M == 0 && wTileN % N == 0,
                  "warp tile must be an integer number of wmma fragments");

    // Get static variables
    // Fix: the N-direction warp count must be derived from wTileN (the original
    // divided by wTileM, which only worked because both warp tiles are equal).
    constexpr int NWarps = TileN / wTileN;  // Number of warps along the N dimension
    constexpr int MWarps = TileM / wTileM;  // Number of warps along the M dimension
    constexpr int NWarpTiles = wTileN / N;
    constexpr int MWarpTiles = wTileM / M;
    constexpr int NumWarp = NWarps * MWarps;

    // dynamic shared memory
    extern __shared__ float smem[];

    // The number of nonzero blocks per row
    int nnz_block = nnz / TileN;

    // get tile offset
    // We use blockIdx.x to index the Tile row, blockIdx.y to index the NONZERO Tile column
    int m_offset = blockIdx.x * TileM;
    int n_offset = indices[blockIdx.y + nnz_block * blockIdx.x] * TileN;

    // Create CUDA pipeline (thread scope: each thread tracks only its own copies)
    cuda::pipeline<cuda::thread_scope_thread> pipe = cuda::make_pipeline();

    SRCIterator<TileM, TileN, TileK, 4, NumWarp, Stages> src_loader(lhs_matrix, rhs_matrix, smem, m_offset, n_offset, k);
    Computer<MWarps, NWarps, TileM, TileN, TileK, Stages, wTileM, wTileN, 4, M, N, K> computer(smem);


    // One accumulator fragment per wmma tile inside this warp's tile.
    wmma::fragment<wmma::accumulator, M, N, K, float> c[MWarpTiles][NWarpTiles];

    // Set the fragment to 0
    #pragma unroll
    for (int i = 0; i < MWarpTiles; i++){
        #pragma unroll
        for (int j = 0; j < NWarpTiles; j++){
            wmma::fill_fragment(c[i][j], 0.0f);
        }
    }

    // TODO: Currently, we assume that there is no residual on k dimension
    int k_batch = k / TileK; 
    // Start pipeline: fetch_batch is kept at most `Stages` batches ahead of
    // compute_batch.
    for (int compute_batch = 0, fetch_batch = 0; compute_batch < k_batch; compute_batch ++){
        for (; fetch_batch < k_batch && fetch_batch < (compute_batch + Stages); fetch_batch ++){
            pipe.producer_acquire();
            src_loader.Load_async(fetch_batch);
            pipe.producer_commit();
        }
        // Wait for this thread's oldest committed batch, then barrier so the
        // whole block sees every thread's part of it.
        pipe.consumer_wait();
        __syncthreads();
        computer.compute_block_tile(c, compute_batch);
        pipe.consumer_release();
        // Barrier before the next iteration issues new loads into shared memory.
        __syncthreads();
    }

    // Get meta data. Note the column offset is blockIdx.y * TileN, i.e. the
    // position among the NONZERO column tiles, not n_offset.
    MetaUtil<MWarps, NWarps, wTileM, wTileN, M, N, K, TileM, TileN> meta_out(metadata, smem, output_matrix, m_offset, blockIdx.y * TileN, m, nnz);
    meta_out.get_meta_data(c);

    __syncthreads();

    meta_out.write_nonzeros();
}


// Per-thread-block body of the bf16 block-sparse GEMM.
//
// Same tiling as the fp32 version: the block handles row tile blockIdx.x and
// the blockIdx.y-th NONZERO column tile of that row (looked up in `indices`).
// The bf16 operands are reinterpreted as float so that one float carries two
// bf16 values; consequently `k` is expressed in floats (bf16 pairs), which is
// how the host launchers in this file compute it. Accumulation is in float.
//
// Template parameters:
//   kStages        Pipeline depth passed to the loader and the compute helper.
//
// Parameters:
//   m, n, k        Problem dimensions; k in units of float (two bf16) and
//                  assumed to be a multiple of TileK. n is not used here.
//   nnz            Number of nonzero elements in each row.
//   lhs_matrix     Left operand (read-only); must be suitably aligned to be
//                  read through a float pointer.
//   rhs_matrix     Right operand (read-only); same alignment requirement.
//   indices        Column-tile indices, nnz / TileN entries per row tile.
//   output_matrix  Destination passed to MetaUtil_bf16.
//   metadata       Destination passed to MetaUtil_bf16.
template <int kStages>
__device__ void block_spGemmKernel_bf16_(
    int m, int n, int k, int nnz,
    const nv_bfloat16* __restrict__ lhs_matrix,
    const nv_bfloat16* __restrict__ rhs_matrix,
    const int* __restrict__ indices,
    nv_bfloat16* __restrict__ output_matrix,
    int16_t* __restrict__ metadata)
{
    static_assert(TileM % wTileM == 0 && TileN % wTileN == 0,
                  "block tile must be an integer number of warp tiles");
    static_assert(wTileM % M == 0 && wTileN % N == 0,
                  "warp tile must be an integer number of wmma fragments");

    // Get static variables
    // Fix: the N-direction warp count must be derived from wTileN (the original
    // divided by wTileM, which only worked because both warp tiles are equal).
    constexpr int NWarps = TileN / wTileN;  // Number of warps along the N dimension
    constexpr int MWarps = TileM / wTileM;  // Number of warps along the M dimension
    constexpr int NWarpTiles = wTileN / N;
    constexpr int MWarpTiles = wTileM / M;
    constexpr int NumWarp = NWarps * MWarps;

    // Two nv_bf16 is equivalent with a single float, so we convert the pointers to float
    const float* lhs_matrix_f = reinterpret_cast<const float*>(lhs_matrix);
    const float* rhs_matrix_f = reinterpret_cast<const float*>(rhs_matrix);

    // dynamic shared memory (we still use float)
    extern __shared__ float smem[];

    // The number of nonzero blocks per row
    int nnz_block = nnz / TileN;

    // get tile offset
    // We use blockIdx.x to index the Tile row, blockIdx.y to index the NONZERO Tile column
    int m_offset = blockIdx.x * TileM;
    int n_offset = indices[blockIdx.y + nnz_block * blockIdx.x] * TileN;

    // Create CUDA pipeline (thread scope: each thread tracks only its own copies)
    cuda::pipeline<cuda::thread_scope_thread> pipe = cuda::make_pipeline();

    SRCIteratorInterleaved<TileM, TileN, TileK, 4, NumWarp, kStages> src_loader(lhs_matrix_f, rhs_matrix_f, smem, m_offset, n_offset, k);
    Computer<MWarps, NWarps, TileM, TileN, TileK, kStages, wTileM, wTileN, 4, M, N, K> computer(smem);

    // One float accumulator fragment per wmma tile inside this warp's tile.
    wmma::fragment<wmma::accumulator, M, N, K, float> c[MWarpTiles][NWarpTiles];

    // Set the fragment to 0
    #pragma unroll
    for (int i = 0; i < MWarpTiles; i++){
        #pragma unroll
        for (int j = 0; j < NWarpTiles; j++){
            wmma::fill_fragment(c[i][j], 0.0f);
        }
    }

    // TODO: Currently, we assume that there is no residual on k dimension
    int k_batch = k / TileK; 
    // Start pipeline: fetch_batch is kept at most `kStages` batches ahead of
    // compute_batch.
    for (int compute_batch = 0, fetch_batch = 0; compute_batch < k_batch; compute_batch ++){
        for (; fetch_batch < k_batch && fetch_batch < (compute_batch + kStages); fetch_batch ++){
            pipe.producer_acquire();
            src_loader.Load_async(fetch_batch);
            pipe.producer_commit();
        }
        // Wait for this thread's oldest committed batch, then barrier so the
        // whole block sees every thread's part of it.
        pipe.consumer_wait();
        __syncthreads();
        computer.compute_block_tile_bf16(c, compute_batch);
        pipe.consumer_release();
        // Barrier before the next iteration issues new loads into shared memory.
        __syncthreads();
    }
    __syncthreads();

    // The column offset is blockIdx.y * TileN, i.e. the position among the
    // NONZERO column tiles, not n_offset.
    MetaUtil_bf16<MWarps, NWarps, wTileM, wTileN, M, N, K, TileM, TileN> meta_out(metadata, smem, output_matrix, m_offset, blockIdx.y * TileN, m, nnz);
    meta_out.get_meta_data(c);

    __syncthreads();

    meta_out.write_nonzeros(); 
}


// Kernel entry point for a single (non-batched) fp32 problem.
// Forwards every argument unchanged to block_spGemmKernel_, which does the
// work for the tile selected by blockIdx.x / blockIdx.y.
__global__ void block_spGemmKernel(
    int m, int n, int k, int nnz,
    const float* __restrict__ lhs_matrix,
    const float* __restrict__ rhs_matrix,
    const int* __restrict__ indices,
    float* __restrict__ output_matrix,
    int16_t* __restrict__ metadata)
{
    block_spGemmKernel_(
        m, n, k, nnz,
        lhs_matrix, rhs_matrix,
        indices,
        output_matrix, metadata);
}


// Kernel entry point for a batch of fp32 problems.
//
// blockIdx.z selects the batch entry. Each `*_b` pointer is the base of a
// batched buffer and the matching `*_stride` is the number of ELEMENTS between
// consecutive entries. The per-entry pointers are then passed to
// block_spGemmKernel_.
//
// The entry offsets are computed in 64-bit: with 32-bit arithmetic
// `entry_idx * stride` overflows once a batched buffer holds 2^31 or more
// elements.
__global__ void block_batchedSpGemmKernel(
    int m, int n, int k, int nnz,
    const float* __restrict__ lhs_matrix_b, int lhs_stride,
    const float* __restrict__ rhs_matrix_b, int rhs_stride,
    const int* __restrict__ indices_b, int indices_stride,
    float* __restrict__ output_matrix_b, int output_stride,
    int16_t* __restrict__ metadata_b, int meta_stride)
{
    // Get the entry index (widened so the offset products below are 64-bit)
    const int64_t entry_idx = blockIdx.z;

    // Get the input pointer for the current entry in the batch
    const float* lhs_matrix = lhs_matrix_b + entry_idx * lhs_stride;
    const float* rhs_matrix = rhs_matrix_b + entry_idx * rhs_stride;
    const int* indices = indices_b + entry_idx * indices_stride;
    float* output_matrix = output_matrix_b + entry_idx * output_stride;
    int16_t* metadata = metadata_b + entry_idx * meta_stride;

    block_spGemmKernel_(m, n, k, nnz, lhs_matrix, rhs_matrix, indices, output_matrix, metadata);
}


// Kernel entry point for a single (non-batched) bf16 problem.
// Forwards every argument unchanged to block_spGemmKernel_bf16_, instantiated
// with a pipeline depth of 2.
__global__ void block_spGemmKernel_bf16(
    int m, int n, int k, int nnz,
    const nv_bfloat16* __restrict__ lhs_matrix,
    const nv_bfloat16* __restrict__ rhs_matrix,
    const int* __restrict__ indices,
    nv_bfloat16* __restrict__ output_matrix,
    int16_t* __restrict__ metadata)
{
    block_spGemmKernel_bf16_<2>(
        m, n, k, nnz,
        lhs_matrix, rhs_matrix,
        indices,
        output_matrix, metadata);
}


// Kernel entry point for a batch of bf16 problems.
//
// blockIdx.z selects the batch entry. Each `*_b` pointer is the base of a
// batched buffer and the matching `*_stride` is the number of ELEMENTS of that
// buffer's type (bf16 for lhs/rhs/output) between consecutive entries. The
// per-entry pointers are then passed to block_spGemmKernel_bf16_<2>.
//
// The entry offsets are computed in 64-bit: with 32-bit arithmetic
// `entry_idx * stride` overflows once a batched buffer holds 2^31 or more
// elements.
__global__ void block_batchedSpGemmKernel_bf16(
    int m, int n, int k, int nnz,
    const nv_bfloat16* __restrict__ lhs_matrix_b, int lhs_stride,
    const nv_bfloat16* __restrict__ rhs_matrix_b, int rhs_stride,
    const int* __restrict__ indices_b, int indices_stride,
    nv_bfloat16* __restrict__ output_matrix_b, int output_stride, 
    int16_t* __restrict__ metadata_b, int meta_stride)
{
    // Get the entry index (widened so the offset products below are 64-bit)
    const int64_t entry_idx = blockIdx.z;

    // Get the input pointer for the current entry in the batch
    const nv_bfloat16* lhs_matrix = lhs_matrix_b + entry_idx * lhs_stride;
    const nv_bfloat16* rhs_matrix = rhs_matrix_b + entry_idx * rhs_stride;
    const int* indices = indices_b + entry_idx * indices_stride;
    nv_bfloat16* output_matrix = output_matrix_b + entry_idx * output_stride;
    int16_t* metadata = metadata_b + entry_idx * meta_stride;

    block_spGemmKernel_bf16_<2>(m, n, k, nnz, lhs_matrix, rhs_matrix, indices, output_matrix, metadata);
}


// Host launcher for the single-matrix fp32 kernel.
//
// Arguments:
//   lhs_matrix  float32 CUDA tensor of shape (m, k); m must be a positive
//               multiple of TileM and k a multiple of TileK.
//   rhs_matrix  float32 CUDA tensor of shape (n, k).
//   indices     int32 CUDA tensor holding, for each of the m / TileM row
//               tiles, the same number of column-tile indices. The VALUES are
//               not validated here (that would need a device read); the kernel
//               multiplies them by TileN to form an offset.
//
// Returns {output_matrix, metadata} where output_matrix is float32 (m, nnz/2)
// and metadata is int16 (m, nnz/8), nnz = (indices per row tile) * TileN.
//
// NOTE(review): the kernel is launched on the default stream of the current
// device, not on PyTorch's current stream / the tensors' device -- confirm this
// is intended.
std::vector<torch::Tensor> block_sddmmv2_cuda(
    torch::Tensor lhs_matrix,
    torch::Tensor rhs_matrix,
    torch::Tensor indices)
{
    TORCH_CHECK(lhs_matrix.is_cuda() && rhs_matrix.is_cuda() && indices.is_cuda(),
                "block_sddmmv2_cuda: lhs_matrix, rhs_matrix and indices must be CUDA tensors");
    TORCH_CHECK(lhs_matrix.device() == rhs_matrix.device() && lhs_matrix.device() == indices.device(),
                "block_sddmmv2_cuda: all inputs must live on the same device");
    TORCH_CHECK(lhs_matrix.dim() == 2 && rhs_matrix.dim() == 2,
                "block_sddmmv2_cuda: lhs_matrix and rhs_matrix must be 2-D");
    TORCH_CHECK(rhs_matrix.size(1) == lhs_matrix.size(1),
                "block_sddmmv2_cuda: lhs_matrix and rhs_matrix must have the same inner dimension");

    // The kernel receives raw pointers and assumes densely packed storage.
    // contiguous() is a no-op for tensors that are already contiguous.
    lhs_matrix = lhs_matrix.contiguous();
    rhs_matrix = rhs_matrix.contiguous();
    indices = indices.contiguous();

    const int m = lhs_matrix.size(0);
    const int k = lhs_matrix.size(1);
    const int n = rhs_matrix.size(0);

    // m / TileM is used as a divisor and as grid.x, so it must be >= 1.
    TORCH_CHECK(m > 0 && m % TileM == 0,
                "block_sddmmv2_cuda: m (", m, ") must be a positive multiple of ", TileM);
    // The kernel ignores any k remainder, which would silently give wrong sums.
    TORCH_CHECK(k % TileK == 0,
                "block_sddmmv2_cuda: k (", k, ") must be a multiple of ", TileK);

    const int row_tiles = m / TileM;
    TORCH_CHECK(indices.numel() % row_tiles == 0,
                "block_sddmmv2_cuda: indices.numel() (", indices.numel(),
                ") must be a multiple of the number of row tiles (", row_tiles, ")");

    const int nnz_block = indices.numel() / row_tiles;
    const int nnz = nnz_block * TileN;

    auto options_val = torch::TensorOptions().dtype(torch::kFloat32).device(lhs_matrix.device());
    auto output_matrix = torch::empty({m, nnz/2}, options_val);

    auto options_meta = torch::TensorOptions().dtype(torch::kInt16).device(lhs_matrix.device());
    auto metadata = torch::empty({m, nnz/8}, options_meta);

    // Nothing to compute; a grid with a zero dimension is an invalid launch.
    if (nnz_block == 0) {
        return {output_matrix, metadata};
    }

    // Shared memory must fit both the output staging area and the pipelined
    // operand tiles; this exceeds the 48 KB default, hence the opt-in.
    int maxbytes = max(TileM * (TileN / 2 + 8) * sizeof(float), (TileM + TileN) * TileK * Stages * sizeof(float));
    cudaError_t err = cudaFuncSetAttribute(block_spGemmKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, maxbytes);
    TORCH_CHECK(err == cudaSuccess,
                "block_sddmmv2_cuda: cudaFuncSetAttribute failed: ", cudaGetErrorString(err));

    dim3 grid;
    grid.x = row_tiles;
    grid.y = nnz_block;

    dim3 block;
    block.x = TileM * TileN / wTileM / wTileN * 32;

    block_spGemmKernel<<<grid, block, maxbytes>>>(
        m, n, k, nnz, 
        lhs_matrix.data_ptr<float>(), rhs_matrix.data_ptr<float>(),
        indices.data_ptr<int>(), output_matrix.data_ptr<float>(),
        metadata.data_ptr<int16_t>()
    );
    // Launch-configuration errors are only reported through cudaGetLastError().
    err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "block_sddmmv2_cuda: kernel launch failed: ", cudaGetErrorString(err));

    return {output_matrix, metadata};
}

// Host launcher for the batched fp32 kernel.
//
// Arguments:
//   lhs_matrix  float32 CUDA tensor of shape (..., m, k); all leading
//               dimensions are flattened into one batch dimension. m must be a
//               positive multiple of TileM, k a positive multiple of TileK.
//   rhs_matrix  float32 CUDA tensor of shape (..., n, k) with the same batch
//               size as lhs_matrix.
//   indices     int32 CUDA tensor holding, per batch entry and per row tile,
//               the same number of column-tile indices. The VALUES are not
//               validated here.
//
// Returns {output_matrix, metadata}: float32 (batch, m, nnz/2) and int16
// (batch, m, nnz/8), nnz = (indices per row tile) * TileN.
//
// NOTE(review): the kernel is launched on the default stream of the current
// device, not on PyTorch's current stream / the tensors' device -- confirm this
// is intended.
std::vector<torch::Tensor> batched_block_sddmmv2_cuda(
    torch::Tensor lhs_matrix,
    torch::Tensor rhs_matrix,
    torch::Tensor indices)
{
    TORCH_CHECK(lhs_matrix.is_cuda() && rhs_matrix.is_cuda() && indices.is_cuda(),
                "batched_block_sddmmv2_cuda: lhs_matrix, rhs_matrix and indices must be CUDA tensors");
    TORCH_CHECK(lhs_matrix.device() == rhs_matrix.device() && lhs_matrix.device() == indices.device(),
                "batched_block_sddmmv2_cuda: all inputs must live on the same device");
    TORCH_CHECK(lhs_matrix.dim() >= 2 && rhs_matrix.dim() >= 2,
                "batched_block_sddmmv2_cuda: lhs_matrix and rhs_matrix must have at least 2 dimensions");
    TORCH_CHECK(rhs_matrix.size(-1) == lhs_matrix.size(-1),
                "batched_block_sddmmv2_cuda: lhs_matrix and rhs_matrix must have the same inner dimension");

    // The kernel receives raw pointers plus per-entry element strides and
    // assumes densely packed storage. contiguous() is a no-op when already so.
    lhs_matrix = lhs_matrix.contiguous();
    rhs_matrix = rhs_matrix.contiguous();
    indices = indices.contiguous();

    const int m = lhs_matrix.size(-2);
    const int k = lhs_matrix.size(-1);
    const int n = rhs_matrix.size(-2);

    // m / TileM and m * k are used as divisors below, so both must be non-zero.
    TORCH_CHECK(m > 0 && m % TileM == 0,
                "batched_block_sddmmv2_cuda: m (", m, ") must be a positive multiple of ", TileM);
    TORCH_CHECK(k > 0 && k % TileK == 0,
                "batched_block_sddmmv2_cuda: k (", k, ") must be a positive multiple of ", TileK);

    // Per-entry element counts, computed in 64-bit to avoid int overflow.
    const int64_t lhs_elems = static_cast<int64_t>(m) * k;
    const int64_t rhs_elems = static_cast<int64_t>(n) * k;
    const int64_t batch64 = lhs_matrix.numel() / lhs_elems;
    TORCH_CHECK(batch64 > 0, "batched_block_sddmmv2_cuda: batch must not be empty");
    TORCH_CHECK(rhs_matrix.numel() == batch64 * rhs_elems,
                "batched_block_sddmmv2_cuda: rhs_matrix batch size does not match lhs_matrix");

    const int row_tiles = m / TileM;
    TORCH_CHECK(indices.numel() % (batch64 * row_tiles) == 0,
                "batched_block_sddmmv2_cuda: indices.numel() (", indices.numel(),
                ") must be a multiple of batch * row tiles (", batch64 * row_tiles, ")");

    const int64_t idx_elems = indices.numel() / batch64;
    const int64_t nnz_block64 = idx_elems / row_tiles;
    const int64_t nnz64 = nnz_block64 * TileN;
    const int64_t out_elems = static_cast<int64_t>(m) * nnz64 / 2;
    const int64_t meta_elems = static_cast<int64_t>(m) * nnz64 / 8;

    // The kernel interface carries sizes and strides as 32-bit int.
    constexpr int64_t kIntMax = 2147483647;
    TORCH_CHECK(batch64 <= kIntMax && lhs_elems <= kIntMax && rhs_elems <= kIntMax &&
                idx_elems <= kIntMax && nnz64 <= kIntMax && out_elems <= kIntMax,
                "batched_block_sddmmv2_cuda: per-entry sizes exceed the 32-bit range supported by the kernel");

    const int batch_size = static_cast<int>(batch64);
    const int nnz_block = static_cast<int>(nnz_block64);
    const int nnz = static_cast<int>(nnz64);

    auto options_val = torch::TensorOptions().dtype(torch::kFloat32).device(lhs_matrix.device());
    auto output_matrix = torch::empty({batch_size, m, nnz/2}, options_val);

    auto options_meta = torch::TensorOptions().dtype(torch::kInt16).device(lhs_matrix.device());
    auto metadata = torch::empty({batch_size, m, nnz/8}, options_meta);

    // Nothing to compute; a grid with a zero dimension is an invalid launch.
    if (nnz_block == 0) {
        return {output_matrix, metadata};
    }

    const int lhs_stride = static_cast<int>(lhs_elems);
    const int rhs_stride = static_cast<int>(rhs_elems);
    const int output_stride = static_cast<int>(out_elems);
    const int indices_stride = static_cast<int>(idx_elems);
    const int meta_stride = static_cast<int>(meta_elems);

    // Shared memory must fit both the output staging area and the pipelined
    // operand tiles; this exceeds the 48 KB default, hence the opt-in.
    int maxbytes = max(TileM * (TileN / 2 + 8) * sizeof(float), (TileM + TileN) * TileK * Stages * sizeof(float));
    cudaError_t err = cudaFuncSetAttribute(block_batchedSpGemmKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, maxbytes);
    TORCH_CHECK(err == cudaSuccess,
                "batched_block_sddmmv2_cuda: cudaFuncSetAttribute failed: ", cudaGetErrorString(err));

    dim3 grid;
    grid.x = row_tiles;
    grid.y = nnz_block;
    grid.z = batch_size;

    dim3 block;
    block.x = TileM * TileN / wTileM / wTileN * 32;

    block_batchedSpGemmKernel<<<grid, block, maxbytes>>>(
        m, n, k, nnz, 
        lhs_matrix.data_ptr<float>(), lhs_stride,
        rhs_matrix.data_ptr<float>(), rhs_stride,
        indices.data_ptr<int>(), indices_stride,
        output_matrix.data_ptr<float>(), output_stride,
        metadata.data_ptr<int16_t>(), meta_stride
    );
    // Launch-configuration errors (e.g. grid.y / grid.z above the device
    // limit) are only reported through cudaGetLastError().
    err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "batched_block_sddmmv2_cuda: kernel launch failed: ", cudaGetErrorString(err));

    return {output_matrix, metadata};
}


// Host launcher for the single-matrix bf16 kernel.
//
// Arguments:
//   lhs_matrix  bfloat16 CUDA tensor of shape (m, 2*k); m must be a positive
//               multiple of TileM. The kernel reads the data as float (two
//               bf16 per float), so the inner dimension must be a multiple of
//               2 * TileK.
//   rhs_matrix  bfloat16 CUDA tensor of shape (n, 2*k).
//   indices     int32 CUDA tensor holding, for each of the m / TileM row
//               tiles, the same number of column-tile indices. The VALUES are
//               not validated here.
//
// Returns {output_matrix, metadata}: bfloat16 (m, nnz/2) and int16
// (m, nnz/16), nnz = (indices per row tile) * TileN.
//
// NOTE(review): the kernel is launched on the default stream of the current
// device, not on PyTorch's current stream / the tensors' device -- confirm this
// is intended.
std::vector<torch::Tensor> block_sddmmv2_bf16_cuda(
    torch::Tensor lhs_matrix,
    torch::Tensor rhs_matrix,
    torch::Tensor indices)
{
    TORCH_CHECK(lhs_matrix.is_cuda() && rhs_matrix.is_cuda() && indices.is_cuda(),
                "block_sddmmv2_bf16_cuda: lhs_matrix, rhs_matrix and indices must be CUDA tensors");
    TORCH_CHECK(lhs_matrix.device() == rhs_matrix.device() && lhs_matrix.device() == indices.device(),
                "block_sddmmv2_bf16_cuda: all inputs must live on the same device");
    // data_ptr() below is untyped, so the dtype has to be checked explicitly.
    TORCH_CHECK(lhs_matrix.scalar_type() == torch::kBFloat16 && rhs_matrix.scalar_type() == torch::kBFloat16,
                "block_sddmmv2_bf16_cuda: lhs_matrix and rhs_matrix must be bfloat16");
    TORCH_CHECK(lhs_matrix.dim() == 2 && rhs_matrix.dim() == 2,
                "block_sddmmv2_bf16_cuda: lhs_matrix and rhs_matrix must be 2-D");
    TORCH_CHECK(rhs_matrix.size(1) == lhs_matrix.size(1),
                "block_sddmmv2_bf16_cuda: lhs_matrix and rhs_matrix must have the same inner dimension");
    // Any remainder would be silently dropped by the "/ 2" and by the kernel's
    // k / TileK loop count.
    TORCH_CHECK(lhs_matrix.size(1) % (2 * TileK) == 0,
                "block_sddmmv2_bf16_cuda: inner dimension (", lhs_matrix.size(1),
                ") must be a multiple of ", 2 * TileK);

    // The kernel receives raw pointers and assumes densely packed storage.
    // contiguous() is a no-op for tensors that are already contiguous.
    lhs_matrix = lhs_matrix.contiguous();
    rhs_matrix = rhs_matrix.contiguous();
    indices = indices.contiguous();

    const int m = lhs_matrix.size(0);
    const int k = lhs_matrix.size(1) / 2;   // in floats, i.e. pairs of bf16
    const int n = rhs_matrix.size(0);

    // m / TileM is used as a divisor and as grid.x, so it must be >= 1.
    TORCH_CHECK(m > 0 && m % TileM == 0,
                "block_sddmmv2_bf16_cuda: m (", m, ") must be a positive multiple of ", TileM);

    const int row_tiles = m / TileM;
    TORCH_CHECK(indices.numel() % row_tiles == 0,
                "block_sddmmv2_bf16_cuda: indices.numel() (", indices.numel(),
                ") must be a multiple of the number of row tiles (", row_tiles, ")");

    const int nnz_block = indices.numel() / row_tiles;
    const int nnz = nnz_block * TileN;

    auto options_val = torch::TensorOptions().dtype(torch::kBFloat16).device(lhs_matrix.device());
    auto output_matrix = torch::empty({m, nnz/2}, options_val);

    auto options_meta = torch::TensorOptions().dtype(torch::kInt16).device(lhs_matrix.device());
    auto metadata = torch::empty({m, nnz/16}, options_meta);

    // Nothing to compute; a grid with a zero dimension is an invalid launch.
    if (nnz_block == 0) {
        return {output_matrix, metadata};
    }

    // Shared memory must fit both the output staging area and the pipelined
    // operand tiles; this exceeds the 48 KB default, hence the opt-in.
    int maxbytes = max(TileM * (TileN + 8) * sizeof(nv_bfloat16), (TileM + TileN) * TileK * Stages * sizeof(float));
    cudaError_t err = cudaFuncSetAttribute(block_spGemmKernel_bf16, cudaFuncAttributeMaxDynamicSharedMemorySize, maxbytes);
    TORCH_CHECK(err == cudaSuccess,
                "block_sddmmv2_bf16_cuda: cudaFuncSetAttribute failed: ", cudaGetErrorString(err));

    dim3 grid;
    grid.x = row_tiles;
    grid.y = nnz_block;

    dim3 block;
    block.x = TileM * TileN / wTileM / wTileN * 32;

    block_spGemmKernel_bf16<<<grid, block, maxbytes>>>(
        m, n, k, nnz,
        static_cast<const nv_bfloat16*>(lhs_matrix.data_ptr()),
        static_cast<const nv_bfloat16*>(rhs_matrix.data_ptr()),
        indices.data_ptr<int>(),
        static_cast<nv_bfloat16*>(output_matrix.data_ptr()),
        metadata.data_ptr<int16_t>()
    );
    // Launch-configuration errors are only reported through cudaGetLastError().
    err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "block_sddmmv2_bf16_cuda: kernel launch failed: ", cudaGetErrorString(err));

    return {output_matrix, metadata};
}

// Host launcher for the batched bf16 kernel.
//
// Arguments:
//   lhs_matrix  bfloat16 CUDA tensor of shape (..., m, 2*k); all leading
//               dimensions are flattened into one batch dimension. m must be a
//               positive multiple of TileM. The kernel reads the data as float
//               (two bf16 per float), so the inner dimension must be a
//               positive multiple of 2 * TileK.
//   rhs_matrix  bfloat16 CUDA tensor of shape (..., n, 2*k) with the same
//               batch size as lhs_matrix.
//   indices     int32 CUDA tensor holding, per batch entry and per row tile,
//               the same number of column-tile indices. The VALUES are not
//               validated here.
//
// Returns {output_matrix, metadata}: bfloat16 (batch, m, nnz/2) and int16
// (batch, m, nnz/16), nnz = (indices per row tile) * TileN.
//
// NOTE(review): the kernel is launched on the default stream of the current
// device, not on PyTorch's current stream / the tensors' device -- confirm this
// is intended.
std::vector<torch::Tensor> batched_block_sddmmv2_bf16_cuda(
    torch::Tensor lhs_matrix,
    torch::Tensor rhs_matrix,
    torch::Tensor indices)
{
    TORCH_CHECK(lhs_matrix.is_cuda() && rhs_matrix.is_cuda() && indices.is_cuda(),
                "batched_block_sddmmv2_bf16_cuda: lhs_matrix, rhs_matrix and indices must be CUDA tensors");
    TORCH_CHECK(lhs_matrix.device() == rhs_matrix.device() && lhs_matrix.device() == indices.device(),
                "batched_block_sddmmv2_bf16_cuda: all inputs must live on the same device");
    // data_ptr() below is untyped, so the dtype has to be checked explicitly.
    TORCH_CHECK(lhs_matrix.scalar_type() == torch::kBFloat16 && rhs_matrix.scalar_type() == torch::kBFloat16,
                "batched_block_sddmmv2_bf16_cuda: lhs_matrix and rhs_matrix must be bfloat16");
    TORCH_CHECK(lhs_matrix.dim() >= 2 && rhs_matrix.dim() >= 2,
                "batched_block_sddmmv2_bf16_cuda: lhs_matrix and rhs_matrix must have at least 2 dimensions");
    TORCH_CHECK(rhs_matrix.size(-1) == lhs_matrix.size(-1),
                "batched_block_sddmmv2_bf16_cuda: lhs_matrix and rhs_matrix must have the same inner dimension");
    // Any remainder would be silently dropped by the "/ 2" and by the kernel's
    // k / TileK loop count; a zero inner dimension would divide by zero below.
    TORCH_CHECK(lhs_matrix.size(-1) > 0 && lhs_matrix.size(-1) % (2 * TileK) == 0,
                "batched_block_sddmmv2_bf16_cuda: inner dimension (", lhs_matrix.size(-1),
                ") must be a positive multiple of ", 2 * TileK);

    // The kernel receives raw pointers plus per-entry element strides and
    // assumes densely packed storage. contiguous() is a no-op when already so.
    lhs_matrix = lhs_matrix.contiguous();
    rhs_matrix = rhs_matrix.contiguous();
    indices = indices.contiguous();

    const int m = lhs_matrix.size(-2);
    const int k = lhs_matrix.size(-1) / 2;   // in floats, i.e. pairs of bf16
    const int n = rhs_matrix.size(-2);

    // m / TileM is used as a divisor and as grid.x, so it must be >= 1.
    TORCH_CHECK(m > 0 && m % TileM == 0,
                "batched_block_sddmmv2_bf16_cuda: m (", m, ") must be a positive multiple of ", TileM);

    // Per-entry element counts in bf16 elements, computed in 64-bit to avoid
    // int overflow.
    const int64_t lhs_elems = static_cast<int64_t>(m) * k * 2;
    const int64_t rhs_elems = static_cast<int64_t>(n) * k * 2;
    const int64_t batch64 = lhs_matrix.numel() / lhs_elems;
    TORCH_CHECK(batch64 > 0, "batched_block_sddmmv2_bf16_cuda: batch must not be empty");
    TORCH_CHECK(rhs_matrix.numel() == batch64 * rhs_elems,
                "batched_block_sddmmv2_bf16_cuda: rhs_matrix batch size does not match lhs_matrix");

    const int row_tiles = m / TileM;
    TORCH_CHECK(indices.numel() % (batch64 * row_tiles) == 0,
                "batched_block_sddmmv2_bf16_cuda: indices.numel() (", indices.numel(),
                ") must be a multiple of batch * row tiles (", batch64 * row_tiles, ")");

    const int64_t idx_elems = indices.numel() / batch64;
    const int64_t nnz_block64 = idx_elems / row_tiles;
    const int64_t nnz64 = nnz_block64 * TileN;
    const int64_t out_elems = static_cast<int64_t>(m) * nnz64 / 2;
    const int64_t meta_elems = static_cast<int64_t>(m) * nnz64 / 16;

    // The kernel interface carries sizes and strides as 32-bit int.
    constexpr int64_t kIntMax = 2147483647;
    TORCH_CHECK(batch64 <= kIntMax && lhs_elems <= kIntMax && rhs_elems <= kIntMax &&
                idx_elems <= kIntMax && nnz64 <= kIntMax && out_elems <= kIntMax,
                "batched_block_sddmmv2_bf16_cuda: per-entry sizes exceed the 32-bit range supported by the kernel");

    const int batch_size = static_cast<int>(batch64);
    const int nnz_block = static_cast<int>(nnz_block64);
    const int nnz = static_cast<int>(nnz64);

    auto options_val = torch::TensorOptions().dtype(torch::kBFloat16).device(lhs_matrix.device());
    auto output_matrix = torch::empty({batch_size, m, nnz/2}, options_val);

    auto options_meta = torch::TensorOptions().dtype(torch::kInt16).device(lhs_matrix.device());
    auto metadata = torch::empty({batch_size, m, nnz/16}, options_meta);

    // Nothing to compute; a grid with a zero dimension is an invalid launch.
    if (nnz_block == 0) {
        return {output_matrix, metadata};
    }

    const int lhs_stride = static_cast<int>(lhs_elems);
    const int rhs_stride = static_cast<int>(rhs_elems);
    const int output_stride = static_cast<int>(out_elems);
    const int indices_stride = static_cast<int>(idx_elems);
    const int meta_stride = static_cast<int>(meta_elems);

    // Same device body as the non-batched bf16 launcher, so the same
    // shared-memory formula is used (the original used TileN / 2 here; with the
    // current tile sizes the pipeline term dominates either way, so the value
    // requested is unchanged).
    int maxbytes = max(TileM * (TileN + 8) * sizeof(nv_bfloat16), (TileM + TileN) * TileK * Stages * sizeof(float));
    cudaError_t err = cudaFuncSetAttribute(block_batchedSpGemmKernel_bf16, cudaFuncAttributeMaxDynamicSharedMemorySize, maxbytes);
    TORCH_CHECK(err == cudaSuccess,
                "batched_block_sddmmv2_bf16_cuda: cudaFuncSetAttribute failed: ", cudaGetErrorString(err));

    dim3 grid;
    grid.x = row_tiles;
    grid.y = nnz_block;
    grid.z = batch_size;

    dim3 block;
    block.x = TileM * TileN / wTileM / wTileN * 32;

    block_batchedSpGemmKernel_bf16<<<grid, block, maxbytes>>>(
        m, n, k, nnz, 
        static_cast<const nv_bfloat16*>(lhs_matrix.data_ptr()), lhs_stride,
        static_cast<const nv_bfloat16*>(rhs_matrix.data_ptr()), rhs_stride,
        indices.data_ptr<int>(), indices_stride,
        static_cast<nv_bfloat16*>(output_matrix.data_ptr()), output_stride,
        metadata.data_ptr<int16_t>(), meta_stride
    );
    // Launch-configuration errors (e.g. grid.y / grid.z above the device
    // limit) are only reported through cudaGetLastError().
    err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "batched_block_sddmmv2_bf16_cuda: kernel launch failed: ", cudaGetErrorString(err));

    return {output_matrix, metadata};
}