#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>

#include <cuda.h>
#include <cuda_runtime.h>
#include "cuda_bf16.h"

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_sparse.h"

#include <stdio.h>
#include <iostream>
#include <vector>

// Tile sizes at each level of the GEMM hierarchy (threadblock, warp, and
// tensor-core instruction). Each GemmShape is given as <M, N, K>.

// float path: 128x64x32 threadblock tile, 64x64x32 warp tile, 16x8x16 instruction.
using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

// bf16 path: same M and N extents as the float path, with every K extent doubled.
using ThreadblockShape_bf16 = cutlass::gemm::GemmShape<128, 64, 64>;
using WarpShape_bf16 = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape_bf16 = cutlass::gemm::GemmShape<16, 8, 32>;

// Threadblock swizzles and epilogue output operators shared by the kernels below.

// ThreadblockSwizzle is used by the non-batched kernels and their host launchers;
// bThreadblockSwizzle is used by the batched ones, which also query it for the
// batch index.
using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
using bThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle;

// Default tensor-op configuration for float on Sm80. It supplies the operand
// alignments and the math operator used by the float Mma below; the bf16 Mma
// spells out its alignments and operator explicitly instead.
using DefaultConfig = cutlass::gemm::device::DefaultGemmConfiguration<
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, float, float, float, float>;


// Epilogue output operator for the float path. The element count is however
// many float elements fit in 128 bits.
// NOTE(review): going by the ScaleType name, only the alpha scale is applied;
// the host launchers pass beta = 0 — confirm against the CUTLASS documentation.
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    float, 128 / cutlass::sizeof_bits<float>::value, float, float, 
    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>;

// Epilogue output operator for the bf16 path: bfloat16 output elements, with
// float for the two remaining element-type arguments.
using EpilogueOp_bf16 = cutlass::epilogue::thread::LinearCombination<
    cutlass::bfloat16_t, 128 / cutlass::sizeof_bits<cutlass::bfloat16_t>::value, float, float,
    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>;

// Number of pipeline stages in the GEMM main loop (shared by both data types).
constexpr int NumStages = 3;


// Threadblock-scoped sparse MMA for the float path. All three operands use
// row-major layouts; alignments and the math operator come from DefaultConfig.
using Mma = typename cutlass::gemm::threadblock::DefaultSparseMma<
    float, cutlass::layout::RowMajor, DefaultConfig::kAlignmentA, 
    float, cutlass::layout::RowMajor, DefaultConfig::kAlignmentB,
    float, cutlass::layout::RowMajor, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
    ThreadblockShape, WarpShape, InstructionShape, NumStages, DefaultConfig::Operator>::ThreadblockMma;


// Threadblock-scoped sparse MMA for the bf16 path: bfloat16 A and B operands
// with 128-bit alignment, float accumulation, and the plain multiply-add operator.
using Mma_bf16 = typename cutlass::gemm::threadblock::DefaultSparseMma<
    cutlass::bfloat16_t, cutlass::layout::RowMajor, 128 / cutlass::sizeof_bits<cutlass::bfloat16_t>::value,
    cutlass::bfloat16_t, cutlass::layout::RowMajor, 128 / cutlass::sizeof_bits<cutlass::bfloat16_t>::value,
    float, cutlass::layout::RowMajor, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
    ThreadblockShape_bf16, WarpShape_bf16, InstructionShape_bf16, NumStages, cutlass::arch::OpMultiplyAdd>::ThreadblockMma;

// Threadblock-scoped epilogue for the float path. The third argument is the
// ratio of threadblock K to warp K, which is 32 / 32 = 1 for these shapes.
using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
    ThreadblockShape, typename Mma::Operator, ThreadblockShape::kK / WarpShape::kK, EpilogueOp,
    EpilogueOp::kCount>::Epilogue;

// Threadblock-scoped epilogue for the bf16 path. The K ratio is 64 / 64 = 1.
using Epilogue_bf16 = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
    ThreadblockShape_bf16, typename Mma_bf16::Operator, ThreadblockShape_bf16::kK / WarpShape_bf16::kK, EpilogueOp_bf16,
    EpilogueOp_bf16::kCount>::Epilogue;

// Shared-memory storage for the float kernels. The main loop and the epilogue
// share the same bytes through this union. The host launchers request
// sizeof(SharedStorage) bytes of dynamic shared memory, and the kernels
// reinterpret that buffer as this type.

union SharedStorage {
    typename Mma::SharedStorage main_loop;       // used by the MMA main loop
    typename Epilogue::SharedStorage epilogue;   // used by the epilogue
};

// Shared-memory storage for the bf16 kernels. The main loop and the epilogue
// share the same bytes; the host launchers request sizeof(SharedStorage_bf16)
// bytes of dynamic shared memory.
union SharedStorage_bf16 {
    typename Mma_bf16::SharedStorage main_loop;       // used by the MMA main loop
    typename Epilogue_bf16::SharedStorage epilogue;   // used by the epilogue
};


// The main Kernel

/// Device-side body of the non-batched float sparse GEMM.
///
/// Each thread block handles the Mma::Shape::kM x Mma::Shape::kN output tile
/// selected by the threadblock swizzle. It runs the sparse MMA main loop over
/// the compressed operand A, its metadata E and the dense operand B, then
/// writes the tile through the epilogue.
///
/// Launch requirements visible in this function:
///   - at least sizeof(SharedStorage) bytes of dynamic shared memory;
///   - a one-dimensional thread block (only threadIdx.x is read). The host
///     launcher in this file uses Mma::WarpCount::kCount * 32 threads.
///
/// \param problem_size      GEMM extents (m, n, k).
/// \param grid_tiled_shape  Tile counts of the grid; blocks whose tile lies
///                          outside it in m or n return immediately.
/// \param params_A, ptr_A   Iterator parameters and base pointer of A, whose
///                          extent is (m, k / Mma::kSparse).
/// \param params_B, ptr_B   Iterator parameters and base pointer of B, extent (k, n).
/// \param params_D, ptr_D   Iterator parameters and base pointer of the output,
///                          extent (m, n).
/// \param params_E, ptr_E   Iterator parameters and base pointer of the metadata,
///                          extent (m, k / Mma::kSparse / Mma::kElementsPerElementE).
/// \param output_op_        Parameters of the epilogue output operator.
/// \param gemm_k_size       K extent assigned to one k-slice of the grid.
__device__ void cutlassSpmmKernel_(
    cutlass::gemm::GemmCoord problem_size,
    cutlass::gemm::GemmCoord grid_tiled_shape,
    typename Mma::IteratorA::Params params_A,
    float* __restrict__ ptr_A,
    typename Mma::IteratorB::Params params_B,
    float* __restrict__ ptr_B,
    typename Epilogue::OutputTileIterator::Params params_D,
    float* __restrict__ ptr_D,
    typename Mma::IteratorE::Params params_E,
    Mma::ElementE* __restrict__ ptr_E,
    typename Epilogue::OutputOp::Params output_op_,
    int gemm_k_size)
{
    // Dynamic shared memory, viewed as the main-loop/epilogue union.
    extern __shared__ int SharedStorageBase[];

    SharedStorage& shared_storage = *reinterpret_cast<SharedStorage *>(SharedStorageBase);

    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(grid_tiled_shape);

    // Early exit if CTA is out of range
    if (grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
        grid_tiled_shape.n() <= threadblock_tile_offset.n())
    {
        return;
    }

    // Compute initial location in logical coordinates.
    // A and E are indexed in compressed K, hence the division by kSparse.
    cutlass::MatrixCoord tb_offset_A{
        threadblock_tile_offset.m() * Mma::Shape::kM,
        threadblock_tile_offset.k() * gemm_k_size / Mma::kSparse
    };

    cutlass::MatrixCoord tb_offset_B{
        threadblock_tile_offset.k() * gemm_k_size,
        threadblock_tile_offset.n() * Mma::Shape::kN
    };

    cutlass::MatrixCoord tb_offset_E{
        threadblock_tile_offset.m() * Mma::Shape::kM,
        threadblock_tile_offset.k() * gemm_k_size / Mma::kSparse
    };

    // End of this block's K range, clamped to the problem size
    int problem_size_k = min(problem_size.k(), (threadblock_tile_offset.k() + 1) * gemm_k_size);

    // Number of main-loop iterations: ceil-div of the K range by the tile K
    int gemm_k_iterations = (problem_size_k - tb_offset_B.row() + Mma::Shape::kK - 1) / Mma::Shape::kK;

    // Compute position within threadblock
    int thread_idx = threadIdx.x;

    // Construct iterators to A, B, and E operands
    typename Mma::IteratorA iterator_A(
        params_A,
        ptr_A,
        {problem_size.m(), problem_size_k / Mma::kSparse},
        thread_idx,
        tb_offset_A
    );

    typename Mma::IteratorB iterator_B(
        params_B,
        ptr_B,
        {problem_size_k, problem_size.n()},
        thread_idx,
        tb_offset_B
    );

    typename Mma::IteratorE iterator_E(
        params_E,
        ptr_E,
        {problem_size.m(),
        problem_size_k / Mma::kSparse / Mma::kElementsPerElementE},
        thread_idx,
        tb_offset_E
    );

    // Broadcast the warp_id computed by lane 0 to ensure dependent code
    // is compiled as warp-uniform
    int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
    int lane_idx = threadIdx.x % 32;

    //
    //  Main loop
    //

    // Construct thread-scoped matrix multiply
    Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

    typename Mma::FragmentC accumulators;

    accumulators.clear();

    // With no K iterations the accumulators stay cleared and the epilogue
    // still runs on them.
    if (gemm_k_iterations > 0){
        mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_E, accumulators);
    }

    //
    //  Epilogue
    //

    Epilogue::OutputOp output_op(output_op_);

    // The tile offset is queried again after the main loop.
    // NOTE(review): presumably so it need not stay live across the main loop — confirm.
    threadblock_tile_offset = threadblock_swizzle.get_tile_offset(grid_tiled_shape);

    cutlass::MatrixCoord threadblock_offset(
        threadblock_tile_offset.m() * Mma::Shape::kM,
        threadblock_tile_offset.n() * Mma::Shape::kN
    );

    typename Epilogue::OutputTileIterator iterator_D(
        params_D,
        ptr_D,
        problem_size.mn(),
        thread_idx,
        threadblock_offset
    );

    Epilogue epilogue(
        shared_storage.epilogue,
        thread_idx,
        warp_idx,
        lane_idx
    );

    // iterator_D is passed as both the destination and the source operand.
    epilogue(output_op, iterator_D, accumulators, iterator_D);
}


/// Device-side body of the non-batched bfloat16 sparse GEMM.
///
/// Each thread block handles the Mma_bf16::Shape::kM x Mma_bf16::Shape::kN
/// output tile selected by the threadblock swizzle. It runs the sparse MMA main
/// loop over the compressed operand A, its metadata E and the dense operand B,
/// then writes the tile through the epilogue.
///
/// Launch requirements visible in this function:
///   - at least sizeof(SharedStorage_bf16) bytes of dynamic shared memory;
///   - a one-dimensional thread block (only threadIdx.x is read). The host
///     launcher in this file uses Mma_bf16::WarpCount::kCount * 32 threads.
///
/// \param problem_size      GEMM extents (m, n, k).
/// \param grid_tiled_shape  Tile counts of the grid; blocks whose tile lies
///                          outside it in m or n return immediately.
/// \param params_A, ptr_A   Iterator parameters and base pointer of A, whose
///                          extent is (m, k / Mma_bf16::kSparse).
/// \param params_B, ptr_B   Iterator parameters and base pointer of B, extent (k, n).
/// \param params_D, ptr_D   Iterator parameters and base pointer of the output,
///                          extent (m, n).
/// \param params_E, ptr_E   Iterator parameters and base pointer of the metadata,
///                          extent (m, k / Mma_bf16::kSparse / Mma_bf16::kElementsPerElementE).
/// \param output_op_        Parameters of the epilogue output operator.
/// \param gemm_k_size       K extent assigned to one k-slice of the grid.
__device__ void cutlassSpmmKernel_bf16_(
    cutlass::gemm::GemmCoord problem_size,
    cutlass::gemm::GemmCoord grid_tiled_shape,
    typename Mma_bf16::IteratorA::Params params_A,
    cutlass::bfloat16_t* __restrict__ ptr_A,
    typename Mma_bf16::IteratorB::Params params_B,
    cutlass::bfloat16_t* __restrict__ ptr_B,
    typename Epilogue_bf16::OutputTileIterator::Params params_D,
    cutlass::bfloat16_t* __restrict__ ptr_D,
    typename Mma_bf16::IteratorE::Params params_E,
    Mma_bf16::ElementE* __restrict__ ptr_E,
    typename Epilogue_bf16::OutputOp::Params output_op_,
    int gemm_k_size)
{
    // Dynamic shared memory, viewed as the main-loop/epilogue union.
    extern __shared__ int SharedStorageBase[];

    SharedStorage_bf16& shared_storage = *reinterpret_cast<SharedStorage_bf16 *>(SharedStorageBase);

    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(grid_tiled_shape);

    // Early exit if CTA is out of range
    if (grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
        grid_tiled_shape.n() <= threadblock_tile_offset.n())
    {
        return;
    }

    // Compute initial location in logical coordinates.
    // A and E are indexed in compressed K, hence the division by kSparse.
    cutlass::MatrixCoord tb_offset_A{
        threadblock_tile_offset.m() * Mma_bf16::Shape::kM,
        threadblock_tile_offset.k() * gemm_k_size / Mma_bf16::kSparse
    };

    cutlass::MatrixCoord tb_offset_B{
        threadblock_tile_offset.k() * gemm_k_size,
        threadblock_tile_offset.n() * Mma_bf16::Shape::kN
    };

    cutlass::MatrixCoord tb_offset_E{
        threadblock_tile_offset.m() * Mma_bf16::Shape::kM,
        threadblock_tile_offset.k() * gemm_k_size / Mma_bf16::kSparse
    };

    // End of this block's K range, clamped to the problem size
    int problem_size_k = min(problem_size.k(), (threadblock_tile_offset.k() + 1) * gemm_k_size);

    // Number of main-loop iterations: ceil-div of the K range by the tile K
    int gemm_k_iterations = (problem_size_k - tb_offset_B.row() + Mma_bf16::Shape::kK - 1) / Mma_bf16::Shape::kK;

    // Compute position within threadblock
    int thread_idx = threadIdx.x;

    // Construct iterators to A, B, and E operands.
    // The compressed extent of A uses the bf16 Mma's own sparsity factor,
    // consistent with every other expression in this function.
    typename Mma_bf16::IteratorA iterator_A(
        params_A,
        ptr_A,
        {problem_size.m(), problem_size_k / Mma_bf16::kSparse},
        thread_idx,
        tb_offset_A
    );

    typename Mma_bf16::IteratorB iterator_B(
        params_B,
        ptr_B,
        {problem_size_k, problem_size.n()},
        thread_idx,
        tb_offset_B
    );

    typename Mma_bf16::IteratorE iterator_E(
        params_E,
        ptr_E,
        {problem_size.m(),
        problem_size_k / Mma_bf16::kSparse / Mma_bf16::kElementsPerElementE},
        thread_idx,
        tb_offset_E
    );

    // Broadcast the warp_id computed by lane 0 to ensure dependent code
    // is compiled as warp-uniform
    int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
    int lane_idx = threadIdx.x % 32;

    //
    //  Main loop
    //

    // Construct thread-scoped matrix multiply
    Mma_bf16 mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

    typename Mma_bf16::FragmentC accumulators;

    accumulators.clear();

    // With no K iterations the accumulators stay cleared and the epilogue
    // still runs on them.
    if (gemm_k_iterations > 0){
        mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_E, accumulators);
    }

    //
    //  Epilogue
    //

    Epilogue_bf16::OutputOp output_op(output_op_);

    // The tile offset is queried again after the main loop.
    // NOTE(review): presumably so it need not stay live across the main loop — confirm.
    threadblock_tile_offset = threadblock_swizzle.get_tile_offset(grid_tiled_shape);

    cutlass::MatrixCoord threadblock_offset(
        threadblock_tile_offset.m() * Mma_bf16::Shape::kM,
        threadblock_tile_offset.n() * Mma_bf16::Shape::kN
    );

    typename Epilogue_bf16::OutputTileIterator iterator_D(
        params_D,
        ptr_D,
        problem_size.mn(),
        thread_idx,
        threadblock_offset
    );

    Epilogue_bf16 epilogue(
        shared_storage.epilogue,
        thread_idx,
        warp_idx,
        lane_idx
    );

    // iterator_D is passed as both the destination and the source operand.
    epilogue(output_op, iterator_D, accumulators, iterator_D);
}


/// Kernel entry point for the non-batched float sparse GEMM.
///
/// Forwards every argument unchanged to the device function cutlassSpmmKernel_,
/// which does all of the work. The dynamic shared memory supplied at launch is
/// consumed there.
__global__ void cutlassSpmmKernel(
    cutlass::gemm::GemmCoord problem_size,
    cutlass::gemm::GemmCoord grid_tiled_shape,
    typename Mma::IteratorA::Params params_A,
    float* __restrict__ ptr_A,
    typename Mma::IteratorB::Params params_B,
    float* __restrict__ ptr_B,
    typename Epilogue::OutputTileIterator::Params params_D,
    float* __restrict__ ptr_D,
    typename Mma::IteratorE::Params params_E,
    Mma::ElementE* __restrict__ ptr_E,
    typename Epilogue::OutputOp::Params output_op_,
    int gemm_k_size)
{
    cutlassSpmmKernel_(
        problem_size,
        grid_tiled_shape,
        params_A, ptr_A,    // compressed sparse operand
        params_B, ptr_B,    // dense operand
        params_D, ptr_D,    // output
        params_E, ptr_E,    // sparsity metadata
        output_op_,
        gemm_k_size);
}


/// Kernel entry point for the non-batched bfloat16 sparse GEMM.
///
/// Forwards every argument unchanged to the device function
/// cutlassSpmmKernel_bf16_, which does all of the work. The dynamic shared
/// memory supplied at launch is consumed there.
__global__ void cutlassSpmmKernel_bf16(
    cutlass::gemm::GemmCoord problem_size,
    cutlass::gemm::GemmCoord grid_tiled_shape,
    typename Mma_bf16::IteratorA::Params params_A,
    cutlass::bfloat16_t* __restrict__ ptr_A,
    typename Mma_bf16::IteratorB::Params params_B,
    cutlass::bfloat16_t* __restrict__ ptr_B,
    typename Epilogue_bf16::OutputTileIterator::Params params_D,
    cutlass::bfloat16_t* __restrict__ ptr_D,
    typename Mma_bf16::IteratorE::Params params_E,
    Mma_bf16::ElementE* __restrict__ ptr_E,
    typename Epilogue_bf16::OutputOp::Params output_op_,
    int gemm_k_size)
{
    cutlassSpmmKernel_bf16_(
        problem_size,
        grid_tiled_shape,
        params_A, ptr_A,    // compressed sparse operand
        params_B, ptr_B,    // dense operand
        params_D, ptr_D,    // output
        params_E, ptr_E,    // sparsity metadata
        output_op_,
        gemm_k_size);
}


/// Batched float sparse GEMM kernel.
///
/// Each thread block handles one Mma::Shape::kM x Mma::Shape::kN output tile of
/// one batch entry. The batch index comes from the batched threadblock swizzle,
/// and every operand pointer is advanced by `stride_X * batch index` elements
/// before use. The whole K extent is processed by a single block (there is no
/// K split), so `gemm_k_size` is accepted but not read.
///
/// Launch requirements visible in this function:
///   - at least sizeof(SharedStorage) bytes of dynamic shared memory;
///   - a one-dimensional thread block (only threadIdx.x is read).
///
/// \param problem_size      Per-batch GEMM extents (m, n, k).
/// \param grid_tiled_shape  Tile counts of the grid; blocks whose tile lies
///                          outside it in m or n return immediately.
/// \param params_*, ptr_*   Iterator parameters and base pointers of A, B, D and E.
/// \param stride_*          Element offset between consecutive batch entries.
/// \param output_op_        Parameters of the epilogue output operator.
/// \param gemm_k_size       Unused in this kernel.
__global__ void batchedCutlassSpmmKernel(
    cutlass::gemm::GemmCoord problem_size,
    cutlass::gemm::GemmCoord grid_tiled_shape,
    typename Mma::IteratorA::Params params_A,
    float* __restrict__ ptr_A, int64_t stride_A,
    typename Mma::IteratorB::Params params_B,
    float* __restrict__ ptr_B, int64_t stride_B,
    typename Epilogue::OutputTileIterator::Params params_D,
    float* __restrict__ ptr_D, int64_t stride_D,
    typename Mma::IteratorE::Params params_E,
    Mma::ElementE* __restrict__ ptr_E, int64_t stride_E,
    typename Epilogue::OutputOp::Params output_op_,
    int gemm_k_size)
{
    // View the dynamic shared memory as the main-loop/epilogue union.
    extern __shared__ int SharedStorageBase[];
    SharedStorage& smem = *reinterpret_cast<SharedStorage *>(SharedStorageBase);

    bThreadblockSwizzle swizzle;
    cutlass::gemm::GemmCoord tile_coord = swizzle.get_tile_offset(grid_tiled_shape);

    // Blocks that map outside the tiled problem have nothing to do.
    const bool tile_in_grid = tile_coord.m() < grid_tiled_shape.m() &&
                              tile_coord.n() < grid_tiled_shape.n();
    if (!tile_in_grid) {
        return;
    }

    const int batch = swizzle.get_batch_idx();

    // Starting coordinates of this block's tiles; K always starts at zero.
    cutlass::MatrixCoord origin_A(tile_coord.m() * Mma::Shape::kM, 0);
    cutlass::MatrixCoord origin_B(0, tile_coord.n() * Mma::Shape::kN);
    cutlass::MatrixCoord origin_E(tile_coord.m() * Mma::Shape::kM, 0);

    // Full K extent, and the number of main-loop iterations needed to cover it.
    const int k_extent = problem_size.k();
    const int k_iterations =
        (k_extent - origin_B.row() + Mma::Shape::kK - 1) / Mma::Shape::kK;

    const int tid = threadIdx.x;

    // Operand iterators, each moved to this block's batch entry.
    typename Mma::IteratorA iterator_A(
        params_A, ptr_A,
        {problem_size.m(), k_extent / Mma::kSparse},
        tid, origin_A);
    iterator_A.add_pointer_offset(stride_A * batch);

    typename Mma::IteratorB iterator_B(
        params_B, ptr_B,
        {k_extent, problem_size.n()},
        tid, origin_B);
    iterator_B.add_pointer_offset(stride_B * batch);

    typename Mma::IteratorE iterator_E(
        params_E, ptr_E,
        {problem_size.m(), k_extent / Mma::kSparse / Mma::kElementsPerElementE},
        tid, origin_E);
    iterator_E.add_pointer_offset(stride_E * batch);

    // The warp index is broadcast from lane 0 so that code depending on it
    // is compiled as warp-uniform.
    const int warp = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
    const int lane = threadIdx.x % 32;

    // ---- Main loop ----
    Mma mma(smem.main_loop, tid, warp, lane);

    typename Mma::FragmentC accumulators;
    accumulators.clear();

    if (k_iterations > 0) {
        mma(k_iterations, accumulators, iterator_A, iterator_B, iterator_E, accumulators);
    }

    // ---- Epilogue ----
    Epilogue::OutputOp output_op(output_op_);

    // The tile coordinate is queried again after the main loop.
    tile_coord = swizzle.get_tile_offset(grid_tiled_shape);

    cutlass::MatrixCoord origin_D(
        tile_coord.m() * Mma::Shape::kM,
        tile_coord.n() * Mma::Shape::kN);

    typename Epilogue::OutputTileIterator iterator_D(
        params_D, ptr_D, problem_size.mn(), tid, origin_D);
    iterator_D.add_pointer_offset(stride_D * batch);

    Epilogue epilogue(smem.epilogue, tid, warp, lane);

    // iterator_D is passed as both the destination and the source operand.
    epilogue(output_op, iterator_D, accumulators, iterator_D);
}


/// Batched bfloat16 sparse GEMM kernel.
///
/// Each thread block handles one Mma_bf16::Shape::kM x Mma_bf16::Shape::kN
/// output tile of one batch entry. The batch index comes from the batched
/// threadblock swizzle, and every operand pointer is advanced by
/// `stride_X * batch index` elements before use. The whole K extent is
/// processed by a single block (there is no K split), so `gemm_k_size` is
/// accepted but not read.
///
/// Launch requirements visible in this function:
///   - at least sizeof(SharedStorage_bf16) bytes of dynamic shared memory;
///   - a one-dimensional thread block (only threadIdx.x is read).
///
/// \param problem_size      Per-batch GEMM extents (m, n, k).
/// \param grid_tiled_shape  Tile counts of the grid; blocks whose tile lies
///                          outside it in m or n return immediately.
/// \param params_*, ptr_*   Iterator parameters and base pointers of A, B, D and E.
/// \param stride_*          Element offset between consecutive batch entries.
/// \param output_op_        Parameters of the epilogue output operator.
/// \param gemm_k_size       Unused in this kernel.
__global__ void batchedCutlassSpmmKernel_bf16(
    cutlass::gemm::GemmCoord problem_size,
    cutlass::gemm::GemmCoord grid_tiled_shape,
    typename Mma_bf16::IteratorA::Params params_A,
    cutlass::bfloat16_t* __restrict__ ptr_A, int64_t stride_A,
    typename Mma_bf16::IteratorB::Params params_B,
    cutlass::bfloat16_t* __restrict__ ptr_B, int64_t stride_B,
    typename Epilogue_bf16::OutputTileIterator::Params params_D,
    cutlass::bfloat16_t* __restrict__ ptr_D, int64_t stride_D,
    typename Mma_bf16::IteratorE::Params params_E,
    Mma_bf16::ElementE* __restrict__ ptr_E, int64_t stride_E,
    typename Epilogue_bf16::OutputOp::Params output_op_,
    int gemm_k_size)
{
    // View the dynamic shared memory as the main-loop/epilogue union.
    extern __shared__ int SharedStorageBase[];
    SharedStorage_bf16& smem = *reinterpret_cast<SharedStorage_bf16 *>(SharedStorageBase);

    bThreadblockSwizzle swizzle;
    cutlass::gemm::GemmCoord tile_coord = swizzle.get_tile_offset(grid_tiled_shape);

    // Blocks that map outside the tiled problem have nothing to do.
    const bool tile_in_grid = tile_coord.m() < grid_tiled_shape.m() &&
                              tile_coord.n() < grid_tiled_shape.n();
    if (!tile_in_grid) {
        return;
    }

    const int batch = swizzle.get_batch_idx();

    // Starting coordinates of this block's tiles; K always starts at zero.
    cutlass::MatrixCoord origin_A(tile_coord.m() * Mma_bf16::Shape::kM, 0);
    cutlass::MatrixCoord origin_B(0, tile_coord.n() * Mma_bf16::Shape::kN);
    cutlass::MatrixCoord origin_E(tile_coord.m() * Mma_bf16::Shape::kM, 0);

    // Full K extent, and the number of main-loop iterations needed to cover it.
    const int k_extent = problem_size.k();
    const int k_iterations =
        (k_extent - origin_B.row() + Mma_bf16::Shape::kK - 1) / Mma_bf16::Shape::kK;

    const int tid = threadIdx.x;

    // Operand iterators, each moved to this block's batch entry.
    typename Mma_bf16::IteratorA iterator_A(
        params_A, ptr_A,
        {problem_size.m(), k_extent / Mma_bf16::kSparse},
        tid, origin_A);
    iterator_A.add_pointer_offset(stride_A * batch);

    typename Mma_bf16::IteratorB iterator_B(
        params_B, ptr_B,
        {k_extent, problem_size.n()},
        tid, origin_B);
    iterator_B.add_pointer_offset(stride_B * batch);

    typename Mma_bf16::IteratorE iterator_E(
        params_E, ptr_E,
        {problem_size.m(), k_extent / Mma_bf16::kSparse / Mma_bf16::kElementsPerElementE},
        tid, origin_E);
    iterator_E.add_pointer_offset(stride_E * batch);

    // The warp index is broadcast from lane 0 so that code depending on it
    // is compiled as warp-uniform.
    const int warp = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
    const int lane = threadIdx.x % 32;

    // ---- Main loop ----
    Mma_bf16 mma(smem.main_loop, tid, warp, lane);

    typename Mma_bf16::FragmentC accumulators;
    accumulators.clear();

    if (k_iterations > 0) {
        mma(k_iterations, accumulators, iterator_A, iterator_B, iterator_E, accumulators);
    }

    // ---- Epilogue ----
    Epilogue_bf16::OutputOp output_op(output_op_);

    // The tile coordinate is queried again after the main loop.
    tile_coord = swizzle.get_tile_offset(grid_tiled_shape);

    cutlass::MatrixCoord origin_D(
        tile_coord.m() * Mma_bf16::Shape::kM,
        tile_coord.n() * Mma_bf16::Shape::kN);

    typename Epilogue_bf16::OutputTileIterator iterator_D(
        params_D, ptr_D, problem_size.mn(), tid, origin_D);
    iterator_D.add_pointer_offset(stride_D * batch);

    Epilogue_bf16 epilogue(smem.epilogue, tid, warp, lane);

    // iterator_D is passed as both the destination and the source operand.
    epilogue(output_op, iterator_D, accumulators, iterator_D);
}


/// Host launcher for the non-batched float sparse GEMM.
///
/// Computes an (m, n) float32 result from the compressed sparse operand
/// `tensor_a`, the dense operand `tensor_b` of shape (k, n), and the sparsity
/// metadata `tensor_e_reordered`. m is tensor_a.size(0); the compressed A is
/// described to the kernel as (m, k / 2) row-major, and B and the output as
/// packed row-major.
///
/// The kernel is enqueued on PyTorch's current CUDA stream for the inputs'
/// device. The call does not synchronize.
///
/// \param tensor_a            Compressed sparse operand; float32, CUDA, contiguous.
/// \param tensor_b            Dense operand of shape (k, n); float32, CUDA, contiguous.
/// \param tensor_e_reordered  Metadata buffer, reinterpreted as Mma::ElementE.
///                            NOTE(review): its dtype and layout are not
///                            validated here — confirm against the producer.
/// \return A new float32 tensor of shape (m, n) on tensor_b's device.
/// \throws c10::Error if an input is not on CUDA, the inputs are on different
///         devices, A or B is not contiguous or not float32, or a CUDA runtime
///         call or the kernel launch reports an error.
torch::Tensor spmmv2_cuda(
    torch::Tensor tensor_a,
    torch::Tensor tensor_b,
    torch::Tensor tensor_e_reordered)
{
    // Turns a failed CUDA runtime call into an exception instead of ignoring it.
    auto check_cuda = [](cudaError_t status, const char* what) {
        TORCH_CHECK(status == cudaSuccess,
                    "spmmv2_cuda: ", what, " failed: ", cudaGetErrorString(status));
    };

    TORCH_CHECK(tensor_a.is_cuda() && tensor_b.is_cuda() && tensor_e_reordered.is_cuda(),
                "spmmv2_cuda: all inputs must be CUDA tensors");
    TORCH_CHECK(tensor_a.device() == tensor_b.device() &&
                tensor_e_reordered.device() == tensor_b.device(),
                "spmmv2_cuda: all inputs must be on the same device");
    // The kernel reads raw pointers with packed row-major layouts, so strided
    // views would be misread.
    TORCH_CHECK(tensor_a.is_contiguous() && tensor_b.is_contiguous(),
                "spmmv2_cuda: tensor_a and tensor_b must be contiguous");

    const int m = tensor_a.size(0);
    const int n = tensor_b.size(1);
    const int k = tensor_b.size(0);

    // Make the inputs' device current and launch on PyTorch's current stream,
    // so the kernel is ordered with the surrounding PyTorch work.
    const c10::cuda::CUDAGuard device_guard(tensor_b.device());
    cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream();

    auto options_val = torch::TensorOptions().dtype(torch::kFloat32).device(tensor_b.device());
    auto output_matrix = torch::empty({m, n}, options_val);

    // An empty output would produce a zero-sized grid, which is an invalid launch.
    if (m == 0 || n == 0) {
        return output_matrix;
    }

    // Create a tuple of problem size for matrix multiplication
    cutlass::gemm::GemmCoord problem_size(m, n, k);

    auto layout_a = cutlass::layout::RowMajor::packed(cutlass::make_Coord(problem_size.m(), problem_size.k() / 2));
    auto layout_b = cutlass::layout::RowMajor::packed(problem_size.kn());
    auto layout_e = Mma::LayoutE::packed(cutlass::make_Coord(problem_size.m(), problem_size.k() / Mma::kSparse / Mma::kElementsPerElementE));
    auto layout_d = cutlass::layout::RowMajor::packed(problem_size.mn());

    float alpha = 1.0f;
    float beta = 0.0f;

    ThreadblockSwizzle threadblock_swizzle;

    // Single k-slice: the grid is tiled over M and N only.
    cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
        problem_size,
        {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
        1
    );

    dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
    dim3 block(Mma::WarpCount::kCount * 32, 1, 1);

    int smem_size = int(sizeof(SharedStorage));

    // Opt in to the required amount of dynamic shared memory.
    check_cuda(
        cudaFuncSetAttribute(cutlassSpmmKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size),
        "cudaFuncSetAttribute(MaxDynamicSharedMemorySize)");
    check_cuda(
        cudaFuncSetAttribute(cutlassSpmmKernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100),
        "cudaFuncSetAttribute(PreferredSharedMemoryCarveout)");

    // K rounded up to a whole number of threadblock K tiles.
    int gemm_k_size = ((problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK) * Mma::Shape::kK;

    cutlassSpmmKernel<<<grid, block, smem_size, stream>>>(
        problem_size, grid_tiled_shape,
        layout_a, tensor_a.data_ptr<float>(),
        layout_b, tensor_b.data_ptr<float>(),
        layout_d, output_matrix.data_ptr<float>(),
        layout_e, static_cast<Mma::ElementE*>(tensor_e_reordered.data_ptr()),
        {alpha, beta}, gemm_k_size);
    check_cuda(cudaGetLastError(), "cutlassSpmmKernel launch");

    return output_matrix;
}


/// Host launcher for the non-batched bfloat16 sparse GEMM.
///
/// Computes an (m, n) bfloat16 result from the compressed sparse operand
/// `tensor_a`, the dense operand `tensor_b` of shape (k, n), and the sparsity
/// metadata `tensor_e_reordered`. m is tensor_a.size(0); the compressed A is
/// described to the kernel as (m, k / 2) row-major, and B and the output as
/// packed row-major.
///
/// The kernel is enqueued on PyTorch's current CUDA stream for the inputs'
/// device. The call does not synchronize.
///
/// \param tensor_a            Compressed sparse operand; bfloat16, CUDA, contiguous.
/// \param tensor_b            Dense operand of shape (k, n); bfloat16, CUDA, contiguous.
/// \param tensor_e_reordered  Metadata buffer, reinterpreted as Mma_bf16::ElementE.
///                            NOTE(review): its dtype and layout are not
///                            validated here — confirm against the producer.
/// \return A new bfloat16 tensor of shape (m, n) on tensor_b's device.
/// \throws c10::Error if an input is not on CUDA, the inputs are on different
///         devices, A or B is not contiguous or not bfloat16, or a CUDA runtime
///         call or the kernel launch reports an error.
torch::Tensor spmmv2_bf16_cuda(
    torch::Tensor tensor_a,
    torch::Tensor tensor_b,
    torch::Tensor tensor_e_reordered)
{
    // Turns a failed CUDA runtime call into an exception instead of ignoring it.
    auto check_cuda = [](cudaError_t status, const char* what) {
        TORCH_CHECK(status == cudaSuccess,
                    "spmmv2_bf16_cuda: ", what, " failed: ", cudaGetErrorString(status));
    };

    TORCH_CHECK(tensor_a.is_cuda() && tensor_b.is_cuda() && tensor_e_reordered.is_cuda(),
                "spmmv2_bf16_cuda: all inputs must be CUDA tensors");
    TORCH_CHECK(tensor_a.device() == tensor_b.device() &&
                tensor_e_reordered.device() == tensor_b.device(),
                "spmmv2_bf16_cuda: all inputs must be on the same device");
    // The raw pointers are reinterpreted as bfloat16 with packed row-major
    // layouts, so the dtype and contiguity must match exactly.
    TORCH_CHECK(tensor_a.scalar_type() == torch::kBFloat16 &&
                tensor_b.scalar_type() == torch::kBFloat16,
                "spmmv2_bf16_cuda: tensor_a and tensor_b must be bfloat16");
    TORCH_CHECK(tensor_a.is_contiguous() && tensor_b.is_contiguous(),
                "spmmv2_bf16_cuda: tensor_a and tensor_b must be contiguous");

    const int m = tensor_a.size(0);
    const int n = tensor_b.size(1);
    const int k = tensor_b.size(0);

    // Make the inputs' device current and launch on PyTorch's current stream,
    // so the kernel is ordered with the surrounding PyTorch work.
    const c10::cuda::CUDAGuard device_guard(tensor_b.device());
    cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream();

    auto options_val = torch::TensorOptions().dtype(torch::kBFloat16).device(tensor_b.device());
    auto output_matrix = torch::empty({m, n}, options_val);

    // An empty output would produce a zero-sized grid, which is an invalid launch.
    if (m == 0 || n == 0) {
        return output_matrix;
    }

    // Create a tuple of problem size for matrix multiplication
    cutlass::gemm::GemmCoord problem_size(m, n, k);

    auto layout_a = cutlass::layout::RowMajor::packed(cutlass::make_Coord(problem_size.m(), problem_size.k() / 2));
    auto layout_b = cutlass::layout::RowMajor::packed(problem_size.kn());
    // Use the bf16 Mma's own metadata layout, matching the kernel's IteratorE.
    auto layout_e = Mma_bf16::LayoutE::packed(cutlass::make_Coord(problem_size.m(), problem_size.k() / Mma_bf16::kSparse / Mma_bf16::kElementsPerElementE));
    auto layout_d = cutlass::layout::RowMajor::packed(problem_size.mn());

    cutlass::bfloat16_t alpha = cutlass::bfloat16_t(1.0);
    cutlass::bfloat16_t beta = cutlass::bfloat16_t(0.0);

    ThreadblockSwizzle threadblock_swizzle;

    // Single k-slice: the grid is tiled over M and N only.
    cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
        problem_size,
        {ThreadblockShape_bf16::kM, ThreadblockShape_bf16::kN, ThreadblockShape_bf16::kK},
        1
    );

    dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
    dim3 block(Mma_bf16::WarpCount::kCount * 32, 1, 1);

    int smem_size = int(sizeof(SharedStorage_bf16));

    // Opt in to the required amount of dynamic shared memory.
    check_cuda(
        cudaFuncSetAttribute(cutlassSpmmKernel_bf16, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size),
        "cudaFuncSetAttribute(MaxDynamicSharedMemorySize)");
    check_cuda(
        cudaFuncSetAttribute(cutlassSpmmKernel_bf16, cudaFuncAttributePreferredSharedMemoryCarveout, 100),
        "cudaFuncSetAttribute(PreferredSharedMemoryCarveout)");

    // K rounded up to a whole number of threadblock K tiles.
    int gemm_k_size = ((problem_size.k() + Mma_bf16::Shape::kK - 1) / Mma_bf16::Shape::kK) * Mma_bf16::Shape::kK;

    cutlassSpmmKernel_bf16<<<grid, block, smem_size, stream>>>(
        problem_size, grid_tiled_shape,
        layout_a, static_cast<cutlass::bfloat16_t*>(tensor_a.data_ptr()),
        layout_b, static_cast<cutlass::bfloat16_t*>(tensor_b.data_ptr()),
        layout_d, static_cast<cutlass::bfloat16_t*>(output_matrix.data_ptr()),
        layout_e, static_cast<Mma_bf16::ElementE*>(tensor_e_reordered.data_ptr()),
        {alpha, beta}, gemm_k_size);
    check_cuda(cudaGetLastError(), "cutlassSpmmKernel_bf16 launch");

    return output_matrix;
}


/// Host launcher for the batched float sparse GEMM.
///
/// `tensor_b` has shape (..., k, n); the product of its leading dimensions is
/// the batch size. m is tensor_a.size(-2). For every batch entry the kernel
/// reads a compressed A of m * k / 2 elements, a B of k * n elements and a
/// metadata block of m * k / 8 elements, and writes an (m, n) result.
///
/// The kernel is enqueued on PyTorch's current CUDA stream for the inputs'
/// device. The call does not synchronize.
///
/// \param tensor_a            Compressed sparse operand; float32, CUDA, contiguous.
/// \param tensor_b            Dense operand of shape (..., k, n); float32, CUDA, contiguous.
/// \param tensor_e_reordered  Metadata buffer, reinterpreted as Mma::ElementE.
///                            NOTE(review): its dtype and layout, and whether
///                            A and E carry the same batch size as B, are not
///                            validated here — confirm against the caller.
/// \return A new float32 tensor of shape (batch, m, n) on tensor_b's device.
/// \throws c10::Error if an input is not on CUDA, the inputs are on different
///         devices, A or B is not contiguous or not float32, or a CUDA runtime
///         call or the kernel launch reports an error.
torch::Tensor batched_spmmv2_cuda(
    torch::Tensor tensor_a,
    torch::Tensor tensor_b,
    torch::Tensor tensor_e_reordered)
{
    // Turns a failed CUDA runtime call into an exception instead of ignoring it.
    auto check_cuda = [](cudaError_t status, const char* what) {
        TORCH_CHECK(status == cudaSuccess,
                    "batched_spmmv2_cuda: ", what, " failed: ", cudaGetErrorString(status));
    };

    TORCH_CHECK(tensor_a.is_cuda() && tensor_b.is_cuda() && tensor_e_reordered.is_cuda(),
                "batched_spmmv2_cuda: all inputs must be CUDA tensors");
    TORCH_CHECK(tensor_a.device() == tensor_b.device() &&
                tensor_e_reordered.device() == tensor_b.device(),
                "batched_spmmv2_cuda: all inputs must be on the same device");
    // The kernel reads raw pointers with packed row-major layouts and fixed
    // batch strides, so strided views would be misread.
    TORCH_CHECK(tensor_a.is_contiguous() && tensor_b.is_contiguous(),
                "batched_spmmv2_cuda: tensor_a and tensor_b must be contiguous");

    const int m = tensor_a.size(-2);
    const int n = tensor_b.size(-1);
    const int k = tensor_b.size(-2);

    // Batch size = product of tensor_b's leading dimensions. This equals
    // numel / (k * n) whenever k * n is non-zero, but neither overflows int
    // nor divides by zero for empty matrices.
    int64_t batch_count = 1;
    for (int64_t d = 0; d + 2 < tensor_b.dim(); ++d) {
        batch_count *= tensor_b.size(d);
    }
    const int batch_size = static_cast<int>(batch_count);

    // Make the inputs' device current and launch on PyTorch's current stream,
    // so the kernel is ordered with the surrounding PyTorch work.
    const c10::cuda::CUDAGuard device_guard(tensor_b.device());
    cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream();

    auto options_val = torch::TensorOptions().dtype(torch::kFloat32).device(tensor_b.device());
    auto output_matrix = torch::empty({batch_size, m, n}, options_val);

    // An empty output would produce a zero-sized grid, which is an invalid launch.
    if (batch_size == 0 || m == 0 || n == 0) {
        return output_matrix;
    }

    // Create a tuple of problem size for matrix multiplication
    cutlass::gemm::GemmCoord problem_size(m, n, k);

    auto layout_a = cutlass::layout::RowMajor::packed(cutlass::make_Coord(problem_size.m(), problem_size.k() / 2));
    auto layout_b = cutlass::layout::RowMajor::packed(cutlass::make_Coord(problem_size.k(), problem_size.n()));
    auto layout_e = Mma::LayoutE::packed(cutlass::make_Coord(problem_size.m(), problem_size.k() / Mma::kSparse / Mma::kElementsPerElementE));
    auto layout_d = cutlass::layout::RowMajor::packed(cutlass::make_Coord(problem_size.m(), problem_size.n()));

    // Per-batch element strides, computed in 64-bit so that large m * k or
    // k * n products do not overflow int before widening.
    const int64_t m64 = m;
    const int64_t n64 = n;
    const int64_t k64 = k;
    int64_t stride_a = m64 * k64 / 2;
    int64_t stride_b = k64 * n64;
    // NOTE(review): 8 presumably equals Mma::kSparse * Mma::kElementsPerElementE — confirm.
    int64_t stride_e = m64 * k64 / 8;
    int64_t stride_d = m64 * n64;

    float alpha = 1.0f;
    float beta = 0.0f;

    bThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
        problem_size,
        {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
        batch_size
    );

    dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape);

    dim3 block(Mma::WarpCount::kCount * 32, 1, 1);

    int smem_size = int(sizeof(SharedStorage));

    // Opt in to the required amount of dynamic shared memory.
    check_cuda(
        cudaFuncSetAttribute(batchedCutlassSpmmKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size),
        "cudaFuncSetAttribute(MaxDynamicSharedMemorySize)");
    check_cuda(
        cudaFuncSetAttribute(batchedCutlassSpmmKernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100),
        "cudaFuncSetAttribute(PreferredSharedMemoryCarveout)");

    // K rounded up to a whole number of threadblock K tiles.
    int gemm_k_size = ((problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK) * Mma::Shape::kK;

    batchedCutlassSpmmKernel<<<grid, block, smem_size, stream>>>(
        problem_size, grid_tiled_shape,
        layout_a, tensor_a.data_ptr<float>(), stride_a,
        layout_b, tensor_b.data_ptr<float>(), stride_b,
        layout_d, output_matrix.data_ptr<float>(), stride_d,
        layout_e, static_cast<Mma::ElementE*>(tensor_e_reordered.data_ptr()), stride_e,
        {alpha, beta}, gemm_k_size);
    check_cuda(cudaGetLastError(), "batchedCutlassSpmmKernel launch");

    return output_matrix;
}


/// Host launcher for the batched bfloat16 sparse GEMM.
///
/// `tensor_b` has shape (..., k, n); the product of its leading dimensions is
/// the batch size. m is tensor_a.size(-2). For every batch entry the kernel
/// reads a compressed A of m * k / 2 elements, a B of k * n elements and a
/// metadata block of m * k / 16 elements, and writes an (m, n) result.
///
/// The kernel is enqueued on PyTorch's current CUDA stream for the inputs'
/// device. The call does not synchronize.
///
/// \param tensor_a            Compressed sparse operand; bfloat16, CUDA, contiguous.
/// \param tensor_b            Dense operand of shape (..., k, n); bfloat16, CUDA, contiguous.
/// \param tensor_e_reordered  Metadata buffer, reinterpreted as Mma_bf16::ElementE.
///                            NOTE(review): its dtype and layout, and whether
///                            A and E carry the same batch size as B, are not
///                            validated here — confirm against the caller.
/// \return A new bfloat16 tensor of shape (batch, m, n) on tensor_b's device.
/// \throws c10::Error if an input is not on CUDA, the inputs are on different
///         devices, A or B is not contiguous or not bfloat16, or a CUDA runtime
///         call or the kernel launch reports an error.
torch::Tensor batched_spmmv2_bf16_cuda(
    torch::Tensor tensor_a,
    torch::Tensor tensor_b,
    torch::Tensor tensor_e_reordered)
{
    // Turns a failed CUDA runtime call into an exception instead of ignoring it.
    auto check_cuda = [](cudaError_t status, const char* what) {
        TORCH_CHECK(status == cudaSuccess,
                    "batched_spmmv2_bf16_cuda: ", what, " failed: ", cudaGetErrorString(status));
    };

    TORCH_CHECK(tensor_a.is_cuda() && tensor_b.is_cuda() && tensor_e_reordered.is_cuda(),
                "batched_spmmv2_bf16_cuda: all inputs must be CUDA tensors");
    TORCH_CHECK(tensor_a.device() == tensor_b.device() &&
                tensor_e_reordered.device() == tensor_b.device(),
                "batched_spmmv2_bf16_cuda: all inputs must be on the same device");
    // The raw pointers are reinterpreted as bfloat16 with packed row-major
    // layouts and fixed batch strides, so dtype and contiguity must match exactly.
    TORCH_CHECK(tensor_a.scalar_type() == torch::kBFloat16 &&
                tensor_b.scalar_type() == torch::kBFloat16,
                "batched_spmmv2_bf16_cuda: tensor_a and tensor_b must be bfloat16");
    TORCH_CHECK(tensor_a.is_contiguous() && tensor_b.is_contiguous(),
                "batched_spmmv2_bf16_cuda: tensor_a and tensor_b must be contiguous");

    const int m = tensor_a.size(-2);
    const int n = tensor_b.size(-1);
    const int k = tensor_b.size(-2);

    // Batch size = product of tensor_b's leading dimensions. This equals
    // numel / (k * n) whenever k * n is non-zero, but neither overflows int
    // nor divides by zero for empty matrices.
    int64_t batch_count = 1;
    for (int64_t d = 0; d + 2 < tensor_b.dim(); ++d) {
        batch_count *= tensor_b.size(d);
    }
    const int batch_size = static_cast<int>(batch_count);

    // Make the inputs' device current and launch on PyTorch's current stream,
    // so the kernel is ordered with the surrounding PyTorch work.
    const c10::cuda::CUDAGuard device_guard(tensor_b.device());
    cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream();

    auto options_val = torch::TensorOptions().dtype(torch::kBFloat16).device(tensor_b.device());
    auto output_matrix = torch::empty({batch_size, m, n}, options_val);

    // An empty output would produce a zero-sized grid, which is an invalid launch.
    if (batch_size == 0 || m == 0 || n == 0) {
        return output_matrix;
    }

    // Create a tuple of problem size for matrix multiplication
    cutlass::gemm::GemmCoord problem_size(m, n, k);

    auto layout_a = cutlass::layout::RowMajor::packed(cutlass::make_Coord(problem_size.m(), problem_size.k() / 2));
    auto layout_b = cutlass::layout::RowMajor::packed(cutlass::make_Coord(problem_size.k(), problem_size.n()));
    // Use the bf16 Mma's own metadata layout, matching the kernel's IteratorE.
    auto layout_e = Mma_bf16::LayoutE::packed(cutlass::make_Coord(problem_size.m(), problem_size.k() / Mma_bf16::kSparse / Mma_bf16::kElementsPerElementE));
    auto layout_d = cutlass::layout::RowMajor::packed(cutlass::make_Coord(problem_size.m(), problem_size.n()));

    // Per-batch element strides, computed in 64-bit so that large m * k or
    // k * n products do not overflow int before widening.
    const int64_t m64 = m;
    const int64_t n64 = n;
    const int64_t k64 = k;
    int64_t stride_a = m64 * k64 / 2;
    int64_t stride_b = k64 * n64;
    // NOTE(review): 16 presumably equals Mma_bf16::kSparse * Mma_bf16::kElementsPerElementE — confirm.
    int64_t stride_e = m64 * k64 / 16;
    int64_t stride_d = m64 * n64;

    cutlass::bfloat16_t alpha = cutlass::bfloat16_t(1.0);
    cutlass::bfloat16_t beta = cutlass::bfloat16_t(0.0);

    bThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
        problem_size,
        {ThreadblockShape_bf16::kM, ThreadblockShape_bf16::kN, ThreadblockShape_bf16::kK},
        batch_size
    );

    dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape);

    dim3 block(Mma_bf16::WarpCount::kCount * 32, 1, 1);

    int smem_size = int(sizeof(SharedStorage_bf16));

    // Opt in to the required amount of dynamic shared memory.
    check_cuda(
        cudaFuncSetAttribute(batchedCutlassSpmmKernel_bf16, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size),
        "cudaFuncSetAttribute(MaxDynamicSharedMemorySize)");
    check_cuda(
        cudaFuncSetAttribute(batchedCutlassSpmmKernel_bf16, cudaFuncAttributePreferredSharedMemoryCarveout, 100),
        "cudaFuncSetAttribute(PreferredSharedMemoryCarveout)");

    // K rounded up to a whole number of threadblock K tiles.
    int gemm_k_size = ((problem_size.k() + Mma_bf16::Shape::kK - 1) / Mma_bf16::Shape::kK) * Mma_bf16::Shape::kK;

    batchedCutlassSpmmKernel_bf16<<<grid, block, smem_size, stream>>>(
        problem_size, grid_tiled_shape,
        layout_a, static_cast<cutlass::bfloat16_t*>(tensor_a.data_ptr()), stride_a,
        layout_b, static_cast<cutlass::bfloat16_t*>(tensor_b.data_ptr()), stride_b,
        layout_d, static_cast<cutlass::bfloat16_t*>(output_matrix.data_ptr()), stride_d,
        layout_e, static_cast<Mma_bf16::ElementE*>(tensor_e_reordered.data_ptr()), stride_e,
        {alpha, beta}, gemm_k_size);
    check_cuda(cudaGetLastError(), "batchedCutlassSpmmKernel_bf16 launch");

    return output_matrix;
}