#include <cuda.h>
#include <mma.h>
#include <cuda/pipeline>
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include "cutlass/cutlass.h"
#include "cutlass/arch/memory_sm75.h"
#include <vector>
#include "utils/gemmv2/src_iterator.h"
#include "utils/gemmv2/computer.h"
#include "utils/gemmv2/store_util.h"
#include "utils/gemmv2/meta_util.h"
#include "utils/gemmv2/mask_util.h"
#include <stdio.h>

#include "cuda_bf16.h"


using namespace nvcuda;

// Problem-size constants shared by every kernel and host launcher in this file.
// NOTE: these are object-like macros with single-letter names (M, N, K); any later
// identifier spelled M, N or K in this translation unit is replaced by the literal.

// WMMA instruction shape: the M x N x K shape passed to
// wmma::fragment<wmma::accumulator, M, N, K, float> by the kernels below.
#define M 16
#define N 16
#define K 8

// Thread-block tile: each block produces a TileM x TileN output tile
// (host launches grid.x = m / TileM, grid.y = n / TileN) and consumes TileK
// elements of the reduction dimension per pipeline step (k_batch = k / TileK).
#define TileN 128
#define TileM 128
#define TileK 32

// Pipeline depth: how many k-tiles may be fetched ahead of the one being computed.
// The host launchers also use it to size the dynamic shared-memory staging buffers.
#define Stages 2

// Per-warp tile; a block runs (TileM / wTileM) * (TileN / wTileN) warps
// (block.x = TileM * TileN / wTileM / wTileN * 32 threads on the host side).
#define wTileN 64
#define wTileM 64


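//
//  SDDMM-style GEMM kernels that store only 50% of the output entries,
//  together with int16 metadata describing which entries were kept.
//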


// Device body of the fp32 kernel: computes one TileM x TileN output tile per thread block
// and writes the kept half of it plus int16 metadata.
//
// Tile placement: rows start at blockIdx.x * TileM, columns at blockIdx.y * TileN.
// Expects blockDim.x == NumWarp * 32 (as launched by the host wrappers) and a dynamic
// shared-memory allocation large enough for both the staging buffers used by
// SRCIterator/Computer and the epilogue buffer used by MetaUtil (all share `smem`).
//
// Params:
//   m, n, k        problem size; k is assumed to be a multiple of TileK (see TODO below).
//   lhs_matrix     left operand, read via SRCIterator (presumably row-major (m, k) — confirm).
//   rhs_matrix     right operand, read via SRCIterator (presumably row-major (n, k) — confirm).
//   output_matrix  receives the kept values (written by MetaUtil::write_nonzeros).
//   metadata       receives the index metadata (written by MetaUtil).
//   mask           optional mask consumed by SeqMaskUtil when Mask is true.
template <bool Mask = false>
__device__ void spGemmKernel_(
    int m, int n, int k,
    const float* __restrict__ lhs_matrix,
    const float* __restrict__ rhs_matrix,
    float* __restrict__ output_matrix,
    int16_t* __restrict__ metadata,
    const float* __restrict__ mask = NULL)
{
    // Get static variables
    constexpr int NWarps = TileN / wTileN;  // Number of warps along the N dimension
    constexpr int MWarps = TileM / wTileM;  // Number of warps along the M dimension
    constexpr int NWarpTiles = wTileN / N;  // WMMA tiles per warp along N
    constexpr int MWarpTiles = wTileM / M;  // WMMA tiles per warp along M
    constexpr int NumWarp = NWarps * MWarps;  // warps per block

    // dynamic shared memory, reused by the loader, the compute stage, the mask loader
    // and the metadata epilogue
    extern __shared__ float smem[];

    // get tile offset (first row / first column of this block's output tile)
    int m_offset = blockIdx.x * TileM;
    int n_offset = blockIdx.y * TileN;

    // Create CUDA pipeline (per-thread scope) used to track the async tile copies
    cuda::pipeline<cuda::thread_scope_thread> pipe = cuda::make_pipeline();

    // NOTE(review): the literal 4 passed to SRCIterator/Computer is presumably the vector
    // width in floats (float4) — confirm against src_iterator.h / computer.h.
    SRCIterator<TileM, TileN, TileK, 4, NumWarp, Stages> src_loader(lhs_matrix, rhs_matrix, smem, m_offset, n_offset, k);
    Computer<MWarps, NWarps, TileM, TileN, TileK, Stages, wTileM, wTileN, 4, M, N, K> computer(smem);
    // MaskUtil<Mask, TileM, TileN, NumWarp, float4, M, N, K, wTileM, wTileN, MWarps, NWarps> mask_util(mask, smem, m_offset, n_offset, n);
    SeqMaskUtil<Mask, float4, TileN, MWarps, NWarps, wTileM, wTileN, M, N, K> mask_util(mask, smem, n_offset);

    // Per-warp accumulators covering the warp's wTileM x wTileN sub-tile
    wmma::fragment<wmma::accumulator, M, N, K, float> c[MWarpTiles][NWarpTiles];

    // Seed the accumulators; exact initialization (mask values vs. zeros) is defined by
    // SeqMaskUtil::load_mask.
    mask_util.load_mask(c);

    // TODO: Currently, we assume that there is no residual on k dimension
    int k_batch = k / TileK; 
    // Start pipeline: before computing k-tile `compute_batch`, issue async loads for every
    // k-tile up to Stages ahead of it, then wait for the oldest outstanding one.
    for (int compute_batch = 0, fetch_batch = 0; compute_batch < k_batch; compute_batch ++){
        for (; fetch_batch < k_batch && fetch_batch < (compute_batch + Stages); fetch_batch ++){
            pipe.producer_acquire();
            src_loader.Load_async(fetch_batch);
            pipe.producer_commit();
        }
        pipe.consumer_wait();
        // Other threads' copies for this stage must be visible before any warp reads it.
        __syncthreads();
        computer.compute_block_tile(c, compute_batch);
        pipe.consumer_release();
        // The stage just consumed is refilled by the next iteration's loads.
        __syncthreads();
    }

    // Get meta data: MetaUtil reuses smem as its staging area for the output tile
    MetaUtil<MWarps, NWarps, wTileM, wTileN, M, N, K, TileM, TileN> meta_out(metadata, smem, output_matrix, m_offset, n_offset, m, n);
    meta_out.get_meta_data(c);

    // All warps must finish staging into smem before the cooperative write-out.
    __syncthreads();

    meta_out.write_nonzeros();    
}


// Device body of the bf16 kernel: computes one TileM x TileN output tile per thread block
// (rows from blockIdx.x * TileM, columns from blockIdx.y * TileN) and writes the kept half
// of it as bf16 plus int16 metadata.
//
// The bf16 operands are reinterpreted as float pointers, so `k` counts pairs of bf16
// values (the host passes lhs.size(-1) / 2). Accumulation is in float fragments.
//
// Template params:
//   kStages  pipeline depth; must match the stage count the host used to size the
//            dynamic shared memory.
//   Mask     whether SeqMaskUtil_bf16 reads `mask`.
template <int kStages, bool Mask = false>
__device__ void spGemmKernel_bf16_(
    int m, int n, int k,
    const nv_bfloat16* __restrict__ lhs_matrix,
    const nv_bfloat16* __restrict__ rhs_matrix,
    nv_bfloat16* __restrict__ output_matrix,
    int16_t* __restrict__ metadata,
    const nv_bfloat16* __restrict__ mask = NULL)
{
    // Get static variables
    constexpr int NWarps = TileN / wTileN;  // Number of warps along the N dimension
    constexpr int MWarps = TileM / wTileM;  // Number of warps along the M dimension
    constexpr int NWarpTiles = wTileN / N;  // WMMA tiles per warp along N
    constexpr int MWarpTiles = wTileM / M;  // WMMA tiles per warp along M
    constexpr int NumWarp = NWarps * MWarps;  // warps per block

    // Two nv_bf16 values occupy the same bytes as a single float, so the loaders work on
    // float pointers (each float carries a pair of bf16 values)
    const float* lhs_matrix_f = reinterpret_cast<const float*>(lhs_matrix);
    const float* rhs_matrix_f = reinterpret_cast<const float*>(rhs_matrix);

    // dynamic shared memory (we still use float)
    extern __shared__ float smem[];

    // get tile offset (first row / first column of this block's output tile)
    int m_offset = blockIdx.x * TileM;
    int n_offset = blockIdx.y * TileN;

    // Create CUDA pipeline (per-thread scope) used to track the async tile copies
    cuda::pipeline<cuda::thread_scope_thread> pipe = cuda::make_pipeline();

    SRCIteratorInterleaved<TileM, TileN, TileK, 4, NumWarp, kStages> src_loader(lhs_matrix_f, rhs_matrix_f, smem, m_offset, n_offset, k);


    Computer<MWarps, NWarps, TileM, TileN, TileK, kStages, wTileM, wTileN, 4, M, N, K> computer(smem);
    // BUG: the mask should be interleaved.
    // NOTE(review): the BUG note appears to refer to the commented-out MaskUtil below;
    // SeqMaskUtil_bf16 is used instead — confirm it handles the interleaved layout.
    // MaskUtil<Mask, TileM, TileN, NumWarp, float4, M, N, K, wTileM, wTileN, MWarps, NWarps> mask_util(mask, smem, m_offset, n_offset, n);  
    SeqMaskUtil_bf16<Mask, float4, TileN, MWarps, NWarps, wTileM, wTileN, M, N, K> mask_util(mask, smem, n_offset);  

    // Per-warp float accumulators covering the warp's wTileM x wTileN sub-tile
    wmma::fragment<wmma::accumulator, M, N, K, float> c[MWarpTiles][NWarpTiles];
    // Seed the accumulators; exact initialization is defined by SeqMaskUtil_bf16::load_mask.
    mask_util.load_mask(c);

    // TODO: Currently, we assume that there is no residual on k dimension
    int k_batch = k / TileK; 
    // Start pipeline: keep up to kStages k-tiles in flight ahead of the one being computed.
    for (int compute_batch = 0, fetch_batch = 0; compute_batch < k_batch; compute_batch ++){
        for (; fetch_batch < k_batch && fetch_batch < (compute_batch + kStages); fetch_batch ++){
            pipe.producer_acquire();
            src_loader.Load_async(fetch_batch);
            pipe.producer_commit();
        }
        pipe.consumer_wait();
        // Other threads' copies for this stage must be visible before any warp reads it.
        __syncthreads();
        computer.compute_block_tile_bf16(c, compute_batch);
        pipe.consumer_release();
        // The stage just consumed is refilled by the next iteration's loads.
        __syncthreads();
    }
    // Barrier before MetaUtil_bf16 reuses smem (also covers the k_batch == 0 case, where
    // the loop above issues no barrier).
    __syncthreads();

    MetaUtil_bf16<MWarps, NWarps, wTileM, wTileN, M, N, K, TileM, TileN> meta_out(metadata, smem, output_matrix, m_offset, n_offset, m, n);
    meta_out.get_meta_data(c);

    // All warps must finish staging into smem before the cooperative write-out.
    __syncthreads();

    meta_out.write_nonzeros(); 
}



// __global__ entry point for the fp32 path.
// Launch layout: one block per TileM x TileN output tile (blockIdx.x walks row tiles,
// blockIdx.y walks column tiles), (TileM / wTileM) * (TileN / wTileN) warps per block, and
// the dynamic shared-memory size computed by the host launcher. All of the work lives in
// spGemmKernel_ so the batched kernel can share it.
template <bool Mask = false>
__global__ void spGemmKernel(int m, int n, int k,
                             const float* __restrict__ lhs_matrix,
                             const float* __restrict__ rhs_matrix,
                             float* __restrict__ output_matrix,
                             int16_t* __restrict__ metadata,
                             const float* __restrict__ mask = NULL)
{
    spGemmKernel_<Mask>(
        m, n, k,
        lhs_matrix, rhs_matrix,
        output_matrix, metadata,
        mask);
}


// __global__ entry point for the bf16 path.
// Launch layout: one block per TileM x TileN output tile (blockIdx.x walks row tiles,
// blockIdx.y walks column tiles) with the dynamic shared memory sized by the host launcher.
//
// The pipeline depth is the file-wide `Stages` constant, the same value the host uses
// to size the dynamic shared memory. A separately hard-coded depth could drift from it
// and overrun the allocation.
template <bool Mask = false>
__global__ void spGemmKernel_bf16(
    int m, int n, int k,
    const nv_bfloat16* __restrict__ lhs_matrix,
    const nv_bfloat16* __restrict__ rhs_matrix,
    nv_bfloat16* __restrict__ output_matrix,
    int16_t* __restrict__ metadata,
    const nv_bfloat16* __restrict__ mask = NULL)
{
    spGemmKernel_bf16_<Stages, Mask>(m, n, k, lhs_matrix, rhs_matrix, output_matrix, metadata, mask);
}


// Batched fp32 kernel: blockIdx.z selects the batch entry, and blockIdx.x / blockIdx.y
// select the output tile exactly as in spGemmKernel.
//
// Params (strides are element distances between consecutive batch entries):
//   lhs_stride, rhs_stride, output_stride, meta_stride  per-operand entry strides.
//   mask_stride  distance between consecutive mask entries.
//   mask_batch   number of consecutive batch entries sharing one mask entry (entry e
//                uses mask entry e / mask_batch). Only read when Mask is true, so the
//                default of 0 is never used as a divisor.
template <bool Mask = false>
__global__ void batchedSpGemmKernel(
    int m, int n, int k,
    const float* __restrict__ lhs_matrix_b,
    int lhs_stride,
    const float* __restrict__ rhs_matrix_b,
    int rhs_stride,
    float* __restrict__ output_matrix_b,
    int output_stride,
    int16_t* __restrict__ metadata_b,
    int meta_stride,
    const float* __restrict__ mask_b = NULL,
    int mask_stride=0, int mask_batch=0)
{
    // Batch entry handled by this block. Held in 64 bits so entry_idx * stride cannot
    // overflow even when each per-entry stride fits in an int.
    const int64_t entry_idx = blockIdx.z;

    // Get the input pointer for the current entry in the batch
    const float* lhs_matrix = lhs_matrix_b + entry_idx * lhs_stride;
    const float* rhs_matrix = rhs_matrix_b + entry_idx * rhs_stride;
    float* output_matrix = output_matrix_b + entry_idx * output_stride;
    int16_t* metadata = metadata_b + entry_idx * meta_stride;

    // Only divide by mask_batch when a mask is actually in use; without a mask it keeps
    // its default of 0.
    const float* mask = mask_b;
    if (Mask) mask = mask_b + (entry_idx / mask_batch) * mask_stride;

    spGemmKernel_<Mask>(m, n, k, lhs_matrix, rhs_matrix, output_matrix, metadata, mask);
}


// Batched bf16 kernel: blockIdx.z selects the batch entry, and blockIdx.x / blockIdx.y
// select the output tile exactly as in spGemmKernel_bf16.
//
// Params (strides are nv_bfloat16 element distances between consecutive batch entries):
//   lhs_stride, rhs_stride, output_stride        per-operand entry strides.
//   meta_stride                                  int16 metadata entry stride.
//   mask_stride  distance between consecutive mask entries.
//   mask_batch   number of consecutive batch entries sharing one mask entry (entry e
//                uses mask entry e / mask_batch). Only read when Mask is true.
//
// The pipeline depth is `Stages`, the value the host uses to size dynamic shared memory.
template <bool Mask = false>
__global__ void batchedSpGemmKernel_bf16(
    int m, int n, int k,
    const nv_bfloat16* __restrict__ lhs_matrix_b,
    int lhs_stride,
    const nv_bfloat16* __restrict__ rhs_matrix_b,
    int rhs_stride,
    nv_bfloat16* __restrict__ output_matrix_b,
    int output_stride,
    int16_t* __restrict__ metadata_b,
    int meta_stride, 
    const nv_bfloat16* __restrict__ mask_b = NULL,
    int mask_stride=0, int mask_batch=0)
{
    // Batch entry handled by this block, in 64 bits so entry_idx * stride cannot overflow.
    const int64_t entry_idx = blockIdx.z;

    // Get the input pointer for the current entry in the batch
    const nv_bfloat16* lhs_matrix = lhs_matrix_b + entry_idx * lhs_stride;
    const nv_bfloat16* rhs_matrix = rhs_matrix_b + entry_idx * rhs_stride;
    nv_bfloat16* output_matrix = output_matrix_b + entry_idx * output_stride;
    int16_t* metadata = metadata_b + entry_idx * meta_stride;

    // Only divide by mask_batch when a mask is actually in use; without a mask it keeps
    // its default of 0.
    const nv_bfloat16* mask = mask_b;
    if (Mask) mask = mask_b + (entry_idx / mask_batch) * mask_stride;

    spGemmKernel_bf16_<Stages, Mask>(m, n, k, lhs_matrix, rhs_matrix, output_matrix, metadata, mask);
}


// Host launcher for the fp32 kernel (spGemmKernel): one TileM x TileN output tile per
// thread block, keeping half of each tile's entries.
//
// Arguments:
//   lhs_matrix  (m, k) float32 CUDA tensor, contiguous.
//   rhs_matrix  (n, k) float32 CUDA tensor on the same device, contiguous.
//   mask        optional float32 CUDA tensor handed to SeqMaskUtil inside the kernel;
//               NOTE(review): its expected shape is defined there — document it.
// Preconditions (checked): m % TileM == 0, n % TileN == 0, k % TileK == 0. The kernel has
//   no residual handling, so partial row/column tiles would otherwise be left as
//   uninitialized memory and a partial k-tile would be silently ignored.
// Returns {values, metadata}: values is (m, n/2) float32, metadata is (m, n/8) int16.
// The launch is asynchronous on PyTorch's current CUDA stream of the inputs' device.
std::vector<torch::Tensor> sddmmv2_cuda(
    torch::Tensor lhs_matrix,
    torch::Tensor rhs_matrix,
    torch::optional<torch::Tensor> mask)
{
    TORCH_CHECK(lhs_matrix.is_cuda() && rhs_matrix.is_cuda(),
                "sddmmv2_cuda: lhs_matrix and rhs_matrix must be CUDA tensors");
    TORCH_CHECK(rhs_matrix.device() == lhs_matrix.device(),
                "sddmmv2_cuda: lhs_matrix and rhs_matrix must be on the same device");
    TORCH_CHECK(lhs_matrix.dim() == 2 && rhs_matrix.dim() == 2,
                "sddmmv2_cuda: lhs_matrix and rhs_matrix must be 2-D");
    // The kernel indexes raw row-major memory and never looks at tensor strides.
    TORCH_CHECK(lhs_matrix.is_contiguous() && rhs_matrix.is_contiguous(),
                "sddmmv2_cuda: lhs_matrix and rhs_matrix must be contiguous");
    TORCH_CHECK(rhs_matrix.size(1) == lhs_matrix.size(1),
                "sddmmv2_cuda: lhs_matrix and rhs_matrix must share the k dimension");

    // Allocation, the attribute call and the launch all target the inputs' device.
    const c10::cuda::CUDAGuard device_guard(lhs_matrix.device());

    // Shared memory: max of the epilogue buffer and the Stages-deep operand staging buffers.
    int maxbytes = max(TileM * (TileN / 2 + 8) * sizeof(float), (TileM + TileN) * TileK * Stages * sizeof(float));

    int m = lhs_matrix.size(0);
    int k = lhs_matrix.size(1);
    int n = rhs_matrix.size(0);

    TORCH_CHECK(m % TileM == 0 && n % TileN == 0 && k % TileK == 0,
                "sddmmv2_cuda: m, n and k must be multiples of ", TileM, ", ", TileN, " and ", TileK,
                "; got m=", m, ", n=", n, ", k=", k);

    auto options_val = torch::TensorOptions().dtype(torch::kFloat32).device(lhs_matrix.device());
    auto output_matrix = torch::empty({m, n/2}, options_val);

    auto options_meta = torch::TensorOptions().dtype(torch::kInt16).device(lhs_matrix.device());
    auto metadata = torch::empty({m, n/8}, options_meta);

    // Nothing to compute, and a zero-sized grid would be an invalid launch configuration.
    if (m == 0 || n == 0) return {output_matrix, metadata};

    dim3 grid(m / TileM, n / TileN);
    dim3 block(TileM * TileN / wTileM / wTileN * 32);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    cudaError_t err;
    if (mask.has_value()){
        const torch::Tensor& mask_t = mask.value();
        TORCH_CHECK(mask_t.is_cuda() && mask_t.device() == lhs_matrix.device(),
                    "sddmmv2_cuda: mask must be a CUDA tensor on the same device as the inputs");
        TORCH_CHECK(mask_t.is_contiguous(), "sddmmv2_cuda: mask must be contiguous");
        err = cudaFuncSetAttribute(spGemmKernel<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, maxbytes);
        TORCH_CHECK(err == cudaSuccess, "sddmmv2_cuda: cudaFuncSetAttribute failed: ", cudaGetErrorString(err));
        spGemmKernel<true><<<grid, block, maxbytes, stream>>>(
            m, n, k, lhs_matrix.data_ptr<float>(), rhs_matrix.data_ptr<float>(),
            output_matrix.data_ptr<float>(), metadata.data_ptr<int16_t>(), mask_t.data_ptr<float>());
    } else {
        err = cudaFuncSetAttribute(spGemmKernel<false>, cudaFuncAttributeMaxDynamicSharedMemorySize, maxbytes);
        TORCH_CHECK(err == cudaSuccess, "sddmmv2_cuda: cudaFuncSetAttribute failed: ", cudaGetErrorString(err));
        spGemmKernel<false><<<grid, block, maxbytes, stream>>>(
            m, n, k, lhs_matrix.data_ptr<float>(), rhs_matrix.data_ptr<float>(),
            output_matrix.data_ptr<float>(), metadata.data_ptr<int16_t>());
    }
    // Launch-configuration errors are only reported through cudaGetLastError.
    err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "sddmmv2_cuda: kernel launch failed: ", cudaGetErrorString(err));

    return {output_matrix, metadata};
}


// Host launcher for the bf16 kernel (spGemmKernel_bf16): one TileM x TileN output tile
// per thread block, keeping half of each tile's entries.
//
// Arguments:
//   lhs_matrix  (m, 2*k) bfloat16 CUDA tensor, contiguous.
//   rhs_matrix  (n, 2*k) bfloat16 CUDA tensor on the same device, contiguous.
//   mask        optional bfloat16 CUDA tensor handed to SeqMaskUtil_bf16 in the kernel;
//               NOTE(review): its expected shape is defined there — document it.
// The kernel reads bf16 pairs through float pointers, so `k` below counts bf16 pairs.
// Preconditions (checked): an even inner dimension, m % TileM == 0, n % TileN == 0 and
//   k % TileK == 0. There is no residual handling in the kernel.
// Returns {values, metadata}: values is (m, n/2) bfloat16, metadata is (m, n/16) int16.
// The launch is asynchronous on PyTorch's current CUDA stream of the inputs' device.
std::vector<torch::Tensor> sddmmv2_bf16_cuda(
    torch::Tensor lhs_matrix,
    torch::Tensor rhs_matrix,
    torch::optional<torch::Tensor> mask)
{
    TORCH_CHECK(lhs_matrix.is_cuda() && rhs_matrix.is_cuda(),
                "sddmmv2_bf16_cuda: lhs_matrix and rhs_matrix must be CUDA tensors");
    TORCH_CHECK(rhs_matrix.device() == lhs_matrix.device(),
                "sddmmv2_bf16_cuda: lhs_matrix and rhs_matrix must be on the same device");
    // The raw pointers are reinterpreted as nv_bfloat16 below, so the dtype must match.
    TORCH_CHECK(lhs_matrix.scalar_type() == torch::kBFloat16 && rhs_matrix.scalar_type() == torch::kBFloat16,
                "sddmmv2_bf16_cuda: lhs_matrix and rhs_matrix must be bfloat16");
    TORCH_CHECK(lhs_matrix.dim() == 2 && rhs_matrix.dim() == 2,
                "sddmmv2_bf16_cuda: lhs_matrix and rhs_matrix must be 2-D");
    TORCH_CHECK(lhs_matrix.is_contiguous() && rhs_matrix.is_contiguous(),
                "sddmmv2_bf16_cuda: lhs_matrix and rhs_matrix must be contiguous");
    TORCH_CHECK(rhs_matrix.size(1) == lhs_matrix.size(1),
                "sddmmv2_bf16_cuda: lhs_matrix and rhs_matrix must share the k dimension");
    TORCH_CHECK(lhs_matrix.size(1) % 2 == 0,
                "sddmmv2_bf16_cuda: the inner dimension must be even (bf16 values are read in pairs)");

    // Allocation, the attribute call and the launch all target the inputs' device.
    const c10::cuda::CUDAGuard device_guard(lhs_matrix.device());

    // int maxbytes = max(TileM * (TileN / 2 + 8) * sizeof(float), (TileM + TileN) * TileK * Stages * sizeof(float));
    int maxbytes = max(TileM * (TileN + 8) * sizeof(nv_bfloat16), (TileM + TileN) * TileK * Stages * sizeof(float));

    int m = lhs_matrix.size(0);
    int k = lhs_matrix.size(1) / 2;
    int n = rhs_matrix.size(0);

    TORCH_CHECK(m % TileM == 0 && n % TileN == 0 && k % TileK == 0,
                "sddmmv2_bf16_cuda: m, n and the inner dimension must be multiples of ",
                TileM, ", ", TileN, " and ", 2 * TileK,
                "; got m=", m, ", n=", n, ", inner=", 2 * k);

    auto options_val = torch::TensorOptions().dtype(torch::kBFloat16).device(lhs_matrix.device());
    auto output_matrix = torch::empty({m, n/2}, options_val);

    auto options_meta = torch::TensorOptions().dtype(torch::kInt16).device(lhs_matrix.device());
    auto metadata = torch::empty({m, n/16}, options_meta);

    // Nothing to compute, and a zero-sized grid would be an invalid launch configuration.
    if (m == 0 || n == 0) return {output_matrix, metadata};

    const nv_bfloat16* lhs_ptr = static_cast<const nv_bfloat16*>(lhs_matrix.data_ptr());
    const nv_bfloat16* rhs_ptr = static_cast<const nv_bfloat16*>(rhs_matrix.data_ptr());
    nv_bfloat16* out_ptr = static_cast<nv_bfloat16*>(output_matrix.data_ptr());
    int16_t* meta_ptr = metadata.data_ptr<int16_t>();

    dim3 grid(m / TileM, n / TileN);
    dim3 block(TileM * TileN / wTileM / wTileN * 32);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    cudaError_t err;
    if (mask.has_value()){
        const torch::Tensor& mask_t = mask.value();
        TORCH_CHECK(mask_t.is_cuda() && mask_t.device() == lhs_matrix.device(),
                    "sddmmv2_bf16_cuda: mask must be a CUDA tensor on the same device as the inputs");
        TORCH_CHECK(mask_t.scalar_type() == torch::kBFloat16, "sddmmv2_bf16_cuda: mask must be bfloat16");
        TORCH_CHECK(mask_t.is_contiguous(), "sddmmv2_bf16_cuda: mask must be contiguous");
        err = cudaFuncSetAttribute(spGemmKernel_bf16<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, maxbytes);
        TORCH_CHECK(err == cudaSuccess, "sddmmv2_bf16_cuda: cudaFuncSetAttribute failed: ", cudaGetErrorString(err));
        spGemmKernel_bf16<true><<<grid, block, maxbytes, stream>>>(
            m, n, k, lhs_ptr, rhs_ptr, out_ptr, meta_ptr,
            static_cast<const nv_bfloat16*>(mask_t.data_ptr()));
    } else {
        err = cudaFuncSetAttribute(spGemmKernel_bf16<false>, cudaFuncAttributeMaxDynamicSharedMemorySize, maxbytes);
        TORCH_CHECK(err == cudaSuccess, "sddmmv2_bf16_cuda: cudaFuncSetAttribute failed: ", cudaGetErrorString(err));
        spGemmKernel_bf16<false><<<grid, block, maxbytes, stream>>>(
            m, n, k, lhs_ptr, rhs_ptr, out_ptr, meta_ptr);
    }
    // Launch-configuration errors are only reported through cudaGetLastError.
    err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "sddmmv2_bf16_cuda: kernel launch failed: ", cudaGetErrorString(err));

    return {output_matrix, metadata};
}


// Host launcher for the batched fp32 kernel (batchedSpGemmKernel).
//
// Arguments:
//   lhs_matrix  (..., m, k) float32 CUDA tensor, contiguous; all leading dims are
//               flattened into one batch dimension of size B.
//   rhs_matrix  float32 CUDA tensor holding one contiguous (n, k) matrix per batch entry.
//   mask        optional float32 CUDA tensor of shape (Bm, ..., p, q). Entry e uses mask
//               entry e / (B / Bm), each p*q elements long; Bm must divide B.
// Preconditions (checked): m % TileM == 0, n % TileN == 0, k % TileK == 0, B <= 65535
//   (the gridDim.z limit).
// Returns {values, metadata}: values is (B, m, n/2) float32, metadata is (B, m, n/8) int16.
// The launch is asynchronous on PyTorch's current CUDA stream of the inputs' device.
std::vector<torch::Tensor> batched_sddmmv2_cuda(
    torch::Tensor lhs_matrix,
    torch::Tensor rhs_matrix,
    torch::optional<torch::Tensor> mask)
{
    TORCH_CHECK(lhs_matrix.is_cuda() && rhs_matrix.is_cuda(),
                "batched_sddmmv2_cuda: lhs_matrix and rhs_matrix must be CUDA tensors");
    TORCH_CHECK(rhs_matrix.device() == lhs_matrix.device(),
                "batched_sddmmv2_cuda: lhs_matrix and rhs_matrix must be on the same device");
    TORCH_CHECK(lhs_matrix.dim() >= 2 && rhs_matrix.dim() >= 2,
                "batched_sddmmv2_cuda: lhs_matrix and rhs_matrix must be at least 2-D");
    // The kernel indexes raw row-major memory and never looks at tensor strides.
    TORCH_CHECK(lhs_matrix.is_contiguous() && rhs_matrix.is_contiguous(),
                "batched_sddmmv2_cuda: lhs_matrix and rhs_matrix must be contiguous");
    TORCH_CHECK(rhs_matrix.size(-1) == lhs_matrix.size(-1),
                "batched_sddmmv2_cuda: lhs_matrix and rhs_matrix must share the k dimension");

    // Allocation, the attribute call and the launch all target the inputs' device.
    const c10::cuda::CUDAGuard device_guard(lhs_matrix.device());

    int maxbytes = max(TileM * (TileN / 2 + 8) * sizeof(float), (TileM + TileN) * TileK * Stages * sizeof(float));

    int m = lhs_matrix.size(-2);
    int k = lhs_matrix.size(-1);
    int n = rhs_matrix.size(-2);

    // Batch size from the leading dims (not numel() / (m * k), which divides by zero
    // when m or k is 0).
    int64_t num_entries = 1;
    for (int64_t d = 0; d + 2 < lhs_matrix.dim(); ++d) num_entries *= lhs_matrix.size(d);
    TORCH_CHECK(num_entries <= 65535,
                "batched_sddmmv2_cuda: at most 65535 batch entries are supported (gridDim.z); got ", num_entries);
    int batch_size = static_cast<int>(num_entries);
    TORCH_CHECK(rhs_matrix.numel() == static_cast<int64_t>(batch_size) * n * k,
                "batched_sddmmv2_cuda: rhs_matrix must hold one (n, k) matrix per lhs_matrix batch entry");

    TORCH_CHECK(m % TileM == 0 && n % TileN == 0 && k % TileK == 0,
                "batched_sddmmv2_cuda: m, n and k must be multiples of ", TileM, ", ", TileN, " and ", TileK,
                "; got m=", m, ", n=", n, ", k=", k);

    int lhs_stride = m * k;
    int rhs_stride = n * k;
    int output_stride = m * n / 2;
    int meta_stride = m * n / 8;

    auto options_val = torch::TensorOptions().dtype(torch::kFloat32).device(lhs_matrix.device());
    auto output_matrix = torch::empty({batch_size, m, n/2}, options_val);

    auto options_meta = torch::TensorOptions().dtype(torch::kInt16).device(lhs_matrix.device());
    auto metadata = torch::empty({batch_size, m, n/8}, options_meta);

    // Nothing to compute, and a zero-sized grid would be an invalid launch configuration.
    if (batch_size == 0 || m == 0 || n == 0) return {output_matrix, metadata};

    dim3 grid(m / TileM, n / TileN, batch_size);
    dim3 block(TileM * TileN / wTileM / wTileN * 32);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    cudaError_t err;
    if (mask.has_value()){
        const torch::Tensor& mask_t = mask.value();
        TORCH_CHECK(mask_t.is_cuda() && mask_t.device() == lhs_matrix.device(),
                    "batched_sddmmv2_cuda: mask must be a CUDA tensor on the same device as the inputs");
        TORCH_CHECK(mask_t.is_contiguous(), "batched_sddmmv2_cuda: mask must be contiguous");
        TORCH_CHECK(mask_t.dim() >= 2, "batched_sddmmv2_cuda: mask must be at least 2-D");
        const int64_t mask_entries = mask_t.size(0);
        const int64_t mask_entry_numel = mask_t.size(-1) * mask_t.size(-2);
        // The kernel maps entry e to mask entry e / (B / Bm); this must stay in range
        // and the divisor must be non-zero.
        TORCH_CHECK(mask_entries > 0 && batch_size % mask_entries == 0,
                    "batched_sddmmv2_cuda: mask.size(0) must divide the batch size ", batch_size,
                    "; got ", mask_entries);
        TORCH_CHECK(mask_t.numel() == mask_entries * mask_entry_numel,
                    "batched_sddmmv2_cuda: each mask entry must consist of exactly its last two dims");
        int mask_stride = static_cast<int>(mask_entry_numel);
        int mask_batch = static_cast<int>(batch_size / mask_entries);
        err = cudaFuncSetAttribute(batchedSpGemmKernel<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, maxbytes);
        TORCH_CHECK(err == cudaSuccess, "batched_sddmmv2_cuda: cudaFuncSetAttribute failed: ", cudaGetErrorString(err));
        batchedSpGemmKernel<true><<<grid, block, maxbytes, stream>>>(
            m, n, k, 
            lhs_matrix.data_ptr<float>(), lhs_stride, 
            rhs_matrix.data_ptr<float>(), rhs_stride,
            output_matrix.data_ptr<float>(), output_stride,
            metadata.data_ptr<int16_t>(), meta_stride,
            mask_t.data_ptr<float>(), mask_stride, mask_batch
        );
    } else {
        err = cudaFuncSetAttribute(batchedSpGemmKernel<false>, cudaFuncAttributeMaxDynamicSharedMemorySize, maxbytes);
        TORCH_CHECK(err == cudaSuccess, "batched_sddmmv2_cuda: cudaFuncSetAttribute failed: ", cudaGetErrorString(err));
        batchedSpGemmKernel<false><<<grid, block, maxbytes, stream>>>(
            m, n, k, 
            lhs_matrix.data_ptr<float>(), lhs_stride, 
            rhs_matrix.data_ptr<float>(), rhs_stride,
            output_matrix.data_ptr<float>(), output_stride,
            metadata.data_ptr<int16_t>(), meta_stride
        );
    }
    // Launch-configuration errors are only reported through cudaGetLastError.
    err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "batched_sddmmv2_cuda: kernel launch failed: ", cudaGetErrorString(err));

    return {output_matrix, metadata};
}

// Host launcher for the batched bf16 kernel (batchedSpGemmKernel_bf16).
//
// Arguments:
//   lhs_matrix  (..., m, 2*k) bfloat16 CUDA tensor, contiguous; all leading dims are
//               flattened into one batch dimension of size B.
//   rhs_matrix  bfloat16 CUDA tensor holding one contiguous (n, 2*k) matrix per batch entry.
//   mask        optional bfloat16 CUDA tensor of shape (Bm, ..., p, q). Entry e uses mask
//               entry e / (B / Bm), each p*q elements long; Bm must divide B.
// The kernel reads bf16 pairs through float pointers, so `k` below counts bf16 pairs.
// Preconditions (checked): an even inner dimension, m % TileM == 0, n % TileN == 0,
//   k % TileK == 0, and B <= 65535 (the gridDim.z limit).
// Returns {values, metadata}: values is (B, m, n/2) bfloat16, metadata is (B, m, n/16) int16.
// The launch is asynchronous on PyTorch's current CUDA stream of the inputs' device.
std::vector<torch::Tensor> batched_sddmmv2_bf16_cuda(
    torch::Tensor lhs_matrix,
    torch::Tensor rhs_matrix,
    torch::optional<torch::Tensor> mask)
{
    TORCH_CHECK(lhs_matrix.is_cuda() && rhs_matrix.is_cuda(),
                "batched_sddmmv2_bf16_cuda: lhs_matrix and rhs_matrix must be CUDA tensors");
    TORCH_CHECK(rhs_matrix.device() == lhs_matrix.device(),
                "batched_sddmmv2_bf16_cuda: lhs_matrix and rhs_matrix must be on the same device");
    // The raw pointers are reinterpreted as nv_bfloat16 below, so the dtype must match.
    TORCH_CHECK(lhs_matrix.scalar_type() == torch::kBFloat16 && rhs_matrix.scalar_type() == torch::kBFloat16,
                "batched_sddmmv2_bf16_cuda: lhs_matrix and rhs_matrix must be bfloat16");
    TORCH_CHECK(lhs_matrix.dim() >= 2 && rhs_matrix.dim() >= 2,
                "batched_sddmmv2_bf16_cuda: lhs_matrix and rhs_matrix must be at least 2-D");
    TORCH_CHECK(lhs_matrix.is_contiguous() && rhs_matrix.is_contiguous(),
                "batched_sddmmv2_bf16_cuda: lhs_matrix and rhs_matrix must be contiguous");
    TORCH_CHECK(rhs_matrix.size(-1) == lhs_matrix.size(-1),
                "batched_sddmmv2_bf16_cuda: lhs_matrix and rhs_matrix must share the k dimension");
    TORCH_CHECK(lhs_matrix.size(-1) % 2 == 0,
                "batched_sddmmv2_bf16_cuda: the inner dimension must be even (bf16 values are read in pairs)");

    // Allocation, the attribute call and the launch all target the inputs' device.
    const c10::cuda::CUDAGuard device_guard(lhs_matrix.device());

    // Same epilogue sizing as sddmmv2_bf16_cuda. The previous sizeof(torch::kBFloat16)
    // measured the ScalarType enum (1 byte), not a bf16 element.
    int maxbytes = max(TileM * (TileN + 8) * sizeof(nv_bfloat16), (TileM + TileN) * TileK * Stages * sizeof(float));

    int m = lhs_matrix.size(-2);
    int k = lhs_matrix.size(-1) / 2;
    int n = rhs_matrix.size(-2);

    // Batch size from the leading dims (not numel() / (m * k * 2), which divides by zero
    // when m or k is 0).
    int64_t num_entries = 1;
    for (int64_t d = 0; d + 2 < lhs_matrix.dim(); ++d) num_entries *= lhs_matrix.size(d);
    TORCH_CHECK(num_entries <= 65535,
                "batched_sddmmv2_bf16_cuda: at most 65535 batch entries are supported (gridDim.z); got ", num_entries);
    int batch_size = static_cast<int>(num_entries);
    TORCH_CHECK(rhs_matrix.numel() == static_cast<int64_t>(batch_size) * n * k * 2,
                "batched_sddmmv2_bf16_cuda: rhs_matrix must hold one (n, 2*k) matrix per lhs_matrix batch entry");

    TORCH_CHECK(m % TileM == 0 && n % TileN == 0 && k % TileK == 0,
                "batched_sddmmv2_bf16_cuda: m, n and the inner dimension must be multiples of ",
                TileM, ", ", TileN, " and ", 2 * TileK,
                "; got m=", m, ", n=", n, ", inner=", 2 * k);

    int lhs_stride = m * k * 2;
    int rhs_stride = n * k * 2;
    int output_stride = m * n / 2;
    int meta_stride = m * n / 16;

    auto options_val = torch::TensorOptions().dtype(torch::kBFloat16).device(lhs_matrix.device());
    auto output_matrix = torch::empty({batch_size, m, n/2}, options_val);

    auto options_meta = torch::TensorOptions().dtype(torch::kInt16).device(lhs_matrix.device());
    auto metadata = torch::empty({batch_size, m, n/16}, options_meta);

    // Nothing to compute, and a zero-sized grid would be an invalid launch configuration.
    if (batch_size == 0 || m == 0 || n == 0) return {output_matrix, metadata};

    const nv_bfloat16* lhs_ptr = static_cast<const nv_bfloat16*>(lhs_matrix.data_ptr());
    const nv_bfloat16* rhs_ptr = static_cast<const nv_bfloat16*>(rhs_matrix.data_ptr());
    nv_bfloat16* out_ptr = static_cast<nv_bfloat16*>(output_matrix.data_ptr());
    int16_t* meta_ptr = metadata.data_ptr<int16_t>();

    dim3 grid(m / TileM, n / TileN, batch_size);
    dim3 block(TileM * TileN / wTileM / wTileN * 32);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    cudaError_t err;
    if (mask.has_value()){
        const torch::Tensor& mask_t = mask.value();
        TORCH_CHECK(mask_t.is_cuda() && mask_t.device() == lhs_matrix.device(),
                    "batched_sddmmv2_bf16_cuda: mask must be a CUDA tensor on the same device as the inputs");
        TORCH_CHECK(mask_t.scalar_type() == torch::kBFloat16, "batched_sddmmv2_bf16_cuda: mask must be bfloat16");
        TORCH_CHECK(mask_t.is_contiguous(), "batched_sddmmv2_bf16_cuda: mask must be contiguous");
        TORCH_CHECK(mask_t.dim() >= 2, "batched_sddmmv2_bf16_cuda: mask must be at least 2-D");
        const int64_t mask_entries = mask_t.size(0);
        const int64_t mask_entry_numel = mask_t.size(-1) * mask_t.size(-2);
        // The kernel maps entry e to mask entry e / (B / Bm); this must stay in range
        // and the divisor must be non-zero.
        TORCH_CHECK(mask_entries > 0 && batch_size % mask_entries == 0,
                    "batched_sddmmv2_bf16_cuda: mask.size(0) must divide the batch size ", batch_size,
                    "; got ", mask_entries);
        TORCH_CHECK(mask_t.numel() == mask_entries * mask_entry_numel,
                    "batched_sddmmv2_bf16_cuda: each mask entry must consist of exactly its last two dims");
        int mask_stride = static_cast<int>(mask_entry_numel);
        int mask_batch = static_cast<int>(batch_size / mask_entries);
        err = cudaFuncSetAttribute(batchedSpGemmKernel_bf16<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, maxbytes);
        TORCH_CHECK(err == cudaSuccess, "batched_sddmmv2_bf16_cuda: cudaFuncSetAttribute failed: ", cudaGetErrorString(err));
        batchedSpGemmKernel_bf16<true><<<grid, block, maxbytes, stream>>>(
            m, n, k, 
            lhs_ptr, lhs_stride, 
            rhs_ptr, rhs_stride,
            out_ptr, output_stride,
            meta_ptr, meta_stride,
            static_cast<const nv_bfloat16*>(mask_t.data_ptr()), mask_stride, mask_batch
        );
    } else {
        err = cudaFuncSetAttribute(batchedSpGemmKernel_bf16<false>, cudaFuncAttributeMaxDynamicSharedMemorySize, maxbytes);
        TORCH_CHECK(err == cudaSuccess, "batched_sddmmv2_bf16_cuda: cudaFuncSetAttribute failed: ", cudaGetErrorString(err));
        batchedSpGemmKernel_bf16<false><<<grid, block, maxbytes, stream>>>(
            m, n, k, 
            lhs_ptr, lhs_stride, 
            rhs_ptr, rhs_stride,
            out_ptr, output_stride,
            meta_ptr, meta_stride
        );
    }
    // Launch-configuration errors are only reported through cudaGetLastError.
    err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "batched_sddmmv2_bf16_cuda: kernel launch failed: ", cudaGetErrorString(err));

    return {output_matrix, metadata};
}