#include <cuda.h>
#include <torch/extension.h>
#include <cuda_runtime.h>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_sparse.h"
#include <stdio.h>
#include <vector>
#include "cuda_bf16.h"
#include "utils/block_iterator/block_tile_access_iterator.h"

// Tile sizes at the different levels of the GEMM hierarchy, each written as
// cutlass::gemm::GemmShape<M, N, K>: the threadblock tile, the warp tile and
// the instruction-level shape.

// Float path.
using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

// bfloat16 path: same M and N extents as the float path, K doubled at every level.
using ThreadblockShape_bf16 = cutlass::gemm::GemmShape<128, 64, 64>;
using WarpShape_bf16 = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape_bf16 = cutlass::gemm::GemmShape<16, 8, 32>;

// Threadblock swizzles. ThreadblockSwizzle is used by the single-matrix
// kernels and launchers; bThreadblockSwizzle is used by the batched ones,
// which additionally query it for the batch index (get_batch_idx()).
using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
using bThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle;

// CUTLASS default GEMM configuration for float operands with tensor-op math
// on Sm80. Only its kAlignmentA, kAlignmentB and Operator members are used
// in this file (by the float DefaultSparseMma below).
using DefaultConfig = cutlass::gemm::device::DefaultGemmConfiguration<
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, float, float, float, float>;

// Epilogue functor of the float path: a LinearCombination working on
// 128 bits' worth of float elements at a time (128 / 32 = 4). Going by the
// scale-type name only alpha scaling is applied; the launchers pass
// alpha = 1 and beta = 0.
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    float, 128 / cutlass::sizeof_bits<float>::value, float, float, 
    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>;

// Epilogue functor of the bfloat16 path: same structure with bfloat16 as the
// first element type (128 / 16 = 8 elements at a time); the remaining two
// element types stay float.
using EpilogueOp_bf16 = cutlass::epilogue::thread::LinearCombination<
    cutlass::bfloat16_t, 128 / cutlass::sizeof_bits<cutlass::bfloat16_t>::value, float, float,
    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>;

// Number of pipeline stages of the multistage main loop (passed to
// DefaultSparseMma and SparseMmaMultistage).
constexpr int NumStages = 3;
// Sparse block size: template argument of the block iterators for B, and the
// divisor that turns a row count into a count of block rows
// (m / BlockSize, blockIdx.x * kM / BlockSize).
constexpr int BlockSize = 128;

// CUTLASS default threadblock-level sparse MMA for the float path
// (row-major A, B and C, Sm80 tensor ops). It is not instantiated as a main
// loop itself; it only supplies component types (IteratorA, IteratorE,
// MmaCore, ThreadMapB, AccessTypeB) for the aliases that follow.
using DefaultSparseMma = typename cutlass::gemm::threadblock::DefaultSparseMma<
    float, cutlass::layout::RowMajor, DefaultConfig::kAlignmentA, 
    float, cutlass::layout::RowMajor, DefaultConfig::kAlignmentB,
    float, cutlass::layout::RowMajor, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
    ThreadblockShape, WarpShape, InstructionShape, NumStages, DefaultConfig::Operator>;

// bfloat16 counterpart: bfloat16 A and B with 128-bit alignment
// (128 / 16 = 8 elements), float as the third element type, and
// OpMultiplyAdd as the math operator.
using DefaultSparseMma_bf16 = typename cutlass::gemm::threadblock::DefaultSparseMma<
    cutlass::bfloat16_t, cutlass::layout::RowMajor, 128 / cutlass::sizeof_bits<cutlass::bfloat16_t>::value,
    cutlass::bfloat16_t, cutlass::layout::RowMajor, 128 / cutlass::sizeof_bits<cutlass::bfloat16_t>::value,
    float, cutlass::layout::RowMajor, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
    ThreadblockShape_bf16, WarpShape_bf16, InstructionShape_bf16, NumStages, cutlass::arch::OpMultiplyAdd>;

// Project-specific global-memory iterator for operand B (declared in
// utils/block_iterator/block_tile_access_iterator.h). It covers a
// kK x kN tile of row-major B, reuses the default thread map and access type,
// and is parameterised on BlockSize. The kernels construct it with an extra
// pointer into the block-index table.
using blockIteratorB = cutlass::transform::threadblock::block::blockPredicatedTileAccessIterator<
    cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
    float, cutlass::layout::RowMajor, 0, DefaultSparseMma::ThreadMapB, DefaultSparseMma::AccessTypeB, BlockSize>;

// bfloat16 version of the block iterator for B.
using blockIteratorB_bf16 = cutlass::transform::threadblock::block::blockPredicatedTileAccessIterator<
    cutlass::MatrixShape<ThreadblockShape_bf16::kK, ThreadblockShape_bf16::kN>,
    cutlass::bfloat16_t, cutlass::layout::RowMajor, 0, DefaultSparseMma_bf16::ThreadMapB, DefaultSparseMma_bf16::AccessTypeB, BlockSize>;

// Threadblock-level main loop of the float path: SparseMmaMultistage built
// from the DefaultSparseMma components, except that the global-memory
// iterator for B is replaced by blockIteratorB.
using Mma = cutlass::gemm::threadblock::SparseMmaMultistage<
    DefaultSparseMma::MmaCore::Shape, DefaultSparseMma::IteratorA, DefaultSparseMma::MmaCore::SmemIteratorA,
    DefaultSparseMma::MmaCore::kCacheOpA, blockIteratorB, DefaultSparseMma::MmaCore::SmemIteratorB,
    DefaultSparseMma::MmaCore::kCacheOpB, float, cutlass::layout::RowMajor,
    DefaultSparseMma::IteratorE, DefaultSparseMma::MmaCore::SmemIteratorE, DefaultSparseMma::MmaCore::kCacheOpE,
    DefaultSparseMma::MmaCore::MmaPolicy, NumStages>;

// Threadblock-level main loop of the bfloat16 path, assembled the same way
// from DefaultSparseMma_bf16 and blockIteratorB_bf16.
using Mma_bf16 = cutlass::gemm::threadblock::SparseMmaMultistage<
    DefaultSparseMma_bf16::MmaCore::Shape, DefaultSparseMma_bf16::IteratorA, DefaultSparseMma_bf16::MmaCore::SmemIteratorA,
    DefaultSparseMma_bf16::MmaCore::kCacheOpA, blockIteratorB_bf16, DefaultSparseMma_bf16::MmaCore::SmemIteratorB,
    DefaultSparseMma_bf16::MmaCore::kCacheOpB, cutlass::bfloat16_t, cutlass::layout::RowMajor,
    DefaultSparseMma_bf16::IteratorE, DefaultSparseMma_bf16::MmaCore::SmemIteratorE, DefaultSparseMma_bf16::MmaCore::kCacheOpE,
    DefaultSparseMma_bf16::MmaCore::MmaPolicy, NumStages>;

// Threadblock-level epilogue of the float path. The third argument,
// ThreadblockShape::kK / WarpShape::kK, evaluates to 32 / 32 = 1.
using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
    ThreadblockShape, typename Mma::Operator, ThreadblockShape::kK / WarpShape::kK, EpilogueOp,
    EpilogueOp::kCount>::Epilogue;

// Threadblock-level epilogue of the bfloat16 path (third argument 64 / 64 = 1).
using Epilogue_bf16 = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
    ThreadblockShape_bf16, typename Mma_bf16::Operator, ThreadblockShape_bf16::kK / WarpShape_bf16::kK, EpilogueOp_bf16,
    EpilogueOp_bf16::kCount>::Epilogue;

// Shared-memory layout of the float kernels. The kernels reinterpret their
// dynamic shared-memory allocation as this type, and the float launchers pass
// sizeof(SharedStorage) as the dynamic shared-memory size. Being a union, the
// main-loop and epilogue storage occupy the same bytes.
union SharedStorage {
    typename Mma::SharedStorage main_loop;          // used by the Mma main loop
    typename Epilogue::SharedStorage epilogue;      // used by the Epilogue
};

// Shared-memory layout of the bfloat16 kernels, which reinterpret their
// dynamic shared-memory allocation as this type. Being a union, the
// main-loop and epilogue storage occupy the same bytes.
union SharedStorage_bf16 {
    typename Mma_bf16::SharedStorage main_loop;     // used by the Mma_bf16 main loop
    typename Epilogue_bf16::SharedStorage epilogue; // used by Epilogue_bf16
};

//////////////////////////////////////////////////////////////////////////////
// SpMM kernel for hybrid Blocked-ELL & 50% structured sparsity under float //
//////////////////////////////////////////////////////////////////////////////


// SpMM kernel for one float matrix product whose left operand is stored in
// the hybrid Blocked-ELL + 50% structured-sparse format.
//
// Launch configuration (as set up by block_spmmv2_cuda):
//   grid  : ThreadblockSwizzle::get_grid_shape(grid_tiled_shape)
//   block : Mma::WarpCount::kCount * 32 threads, one-dimensional
//   smem  : sizeof(SharedStorage) bytes of dynamic shared memory
// The CUTLASS types used here are configured for cutlass::arch::Sm80.
//
// Parameters:
//   problem_size      GEMM extents; m() and n() size the output, k() and n()
//                     size operand B.
//   problem_size_sp   only k() is read: the K extent walked by the main loop.
//   grid_tiled_shape  number of threadblock tiles along each dimension.
//   params_A / ptr_A  iterator parameters and base pointer of operand A,
//                     whose extent is {m, k_sp / Mma::kSparse}.
//   params_B / ptr_B  iterator parameters and base pointer of operand B.
//   params_D / ptr_D  iterator parameters and base pointer of the output.
//   params_E / ptr_E  iterator parameters and base pointer of operand E
//                     (presumably the structured-sparsity metadata - confirm).
//   output_op_        parameters of the epilogue output functor.
//   indices           block-index table, nnz_block entries per block row.
//   nnz_block         number of nonzero blocks in each block row.
//   gemm_k_size       not read by this kernel; kept so the launch signature
//                     stays unchanged.
__global__ void BlockCutlassSpmmKernel(
    cutlass::gemm::GemmCoord problem_size,
    cutlass::gemm::GemmCoord problem_size_sp,
    cutlass::gemm::GemmCoord grid_tiled_shape,
    typename Mma::IteratorA::Params params_A,
    float* __restrict__ ptr_A,
    typename Mma::IteratorB::Params params_B,
    float* __restrict__ ptr_B,
    typename Epilogue::OutputTileIterator::Params params_D,
    float* __restrict__ ptr_D,
    typename Mma::IteratorE::Params params_E,
    Mma::ElementE* __restrict__ ptr_E,
    typename Epilogue::OutputOp::Params output_op_,
    int* __restrict__ indices,                         // the block indices
    int nnz_block,                                     // the number of nonzero blocks in each row
    int gemm_k_size)
{
    // Row of the block-index table that belongs to this threadblock's block row.
    int row_block = blockIdx.x * ThreadblockShape::kM / BlockSize;
    int* indices_t = indices + nnz_block * row_block;

    // The dynamic shared-memory allocation is viewed as the main-loop/epilogue union.
    extern __shared__ int SharedStorageBase[];

    SharedStorage& shared_storage = *reinterpret_cast<SharedStorage *>(SharedStorageBase);

    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(grid_tiled_shape);

    // Early exit if the CTA is out of range. The whole threadblock leaves
    // together, so the warp shuffle further down is still reached by full warps.
    if (grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
        grid_tiled_shape.n() <= threadblock_tile_offset.n())
    {
        return;
    }

    // Initial tile locations in logical coordinates: A and E start at this
    // tile's first row, B at this tile's first column.
    cutlass::MatrixCoord tb_offset_A{
        threadblock_tile_offset.m() * Mma::Shape::kM,
        0
    };

    cutlass::MatrixCoord tb_offset_B{
        0,
        threadblock_tile_offset.n() * Mma::Shape::kN
    };

    cutlass::MatrixCoord tb_offset_E{
        threadblock_tile_offset.m() * Mma::Shape::kM,
        0
    };

    // K extent of the main loop and the number of threadblock-K tiles (ceil-div).
    int problem_size_k = problem_size_sp.k();

    int gemm_k_iterations = (problem_size_k + Mma::Shape::kK - 1) / Mma::Shape::kK;

    // Position within the threadblock
    int thread_idx = threadIdx.x;

    // Construct iterators to the A, B, and E operands
    typename Mma::IteratorA iterator_A(
        params_A,
        ptr_A,
        {problem_size.m(), problem_size_k / Mma::kSparse},
        thread_idx,
        tb_offset_A
    );

    // The B iterator additionally receives this block row's index list.
    typename Mma::IteratorB iterator_B(
        params_B,
        ptr_B,
        {problem_size.k(), problem_size.n()},
        thread_idx,
        tb_offset_B,
        indices_t
    );

    typename Mma::IteratorE iterator_E(
        params_E,
        ptr_E,
        {problem_size.m(),
        problem_size_k / Mma::kSparse / Mma::kElementsPerElementE},
        thread_idx,
        tb_offset_E
    );

    // Broadcast the warp_id computed by lane 0 to ensure dependent code
    // is compiled as warp-uniform
    int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
    int lane_idx = threadIdx.x % 32;

    //
    //  Main loop
    //

    // Construct thread-scoped matrix multiply
    Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

    typename Mma::FragmentC accumulators;

    accumulators.clear();

    if (gemm_k_iterations > 0){
        mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_E, accumulators);
    }

    //
    //  Epilogue
    //

    Epilogue::OutputOp output_op(output_op_);

    // The tile offset is queried again here, exactly as in the original code.
    threadblock_tile_offset = threadblock_swizzle.get_tile_offset(grid_tiled_shape);

    cutlass::MatrixCoord threadblock_offset(
        threadblock_tile_offset.m() * Mma::Shape::kM,
        threadblock_tile_offset.n() * Mma::Shape::kN
    );

    typename Epilogue::OutputTileIterator iterator_D(
        params_D,
        ptr_D,
        problem_size.mn(),
        thread_idx,
        threadblock_offset
    );

    Epilogue epilogue(
        shared_storage.epilogue,
        thread_idx,
        warp_idx,
        lane_idx
    );

    // iterator_D is passed both as destination and as source operand.
    epilogue(output_op, iterator_D, accumulators, iterator_D);
}


// Launches BlockCutlassSpmmKernel for one float SpMM and returns the {m, n}
// float32 result, allocated on tensor_b's device.
//
//   tensor_a            float values of the sparse operand; m = size(0) and
//                       the K extent walked by the kernel is 2 * size(1).
//   tensor_b            float dense operand, read as a packed row-major
//                       {k, n} matrix.
//   tensor_e_reordered  E operand; its storage is reinterpreted as
//                       Mma::ElementE.
//   indices             int32 block-index table holding
//                       numel / (m / BlockSize) entries per block row.
//
// Raises a c10::Error when m is not a positive multiple of BlockSize, or when
// a CUDA runtime call or the kernel launch reports an error.
// The kernel is launched on the default stream.
// NOTE(review): the packed layouts assume contiguous tensors - confirm
// against the callers.
torch::Tensor block_spmmv2_cuda(
    torch::Tensor tensor_a,
    torch::Tensor tensor_b,
    torch::Tensor tensor_e_reordered,
    torch::Tensor indices)
{
    // Surface CUDA runtime failures instead of silently dropping them.
    auto check_cuda = [](cudaError_t status, const char* what) {
        TORCH_CHECK(status == cudaSuccess, what, " failed: ", cudaGetErrorString(status));
    };

    const int m = tensor_a.size(0);
    const int n = tensor_b.size(1);
    const int k = tensor_b.size(0);
    // tensor_a stores half of the K extent (50% structured sparsity); the
    // block sparsity has already reduced K before that.
    const int k_sp = tensor_a.size(1) * 2;

    // m / BlockSize is used as a divisor below, and every threadblock row
    // reads its own row of the index table, so a partial block row would
    // index past the end of `indices`.
    TORCH_CHECK(m > 0 && m % BlockSize == 0,
                "block_spmmv2_cuda: tensor_a.size(0) must be a positive multiple of ",
                BlockSize, ", got ", m);

    // Number of nonzero blocks in each of the (m / BlockSize) block rows
    const int nnz_block = indices.numel() / (m / BlockSize);

    auto options_val = torch::TensorOptions().dtype(torch::kFloat32).device(tensor_b.device());
    auto output_matrix = torch::empty({m, n}, options_val);

    // Nothing to compute, and a grid with a zero dimension cannot be launched.
    if (n == 0) {
        return output_matrix;
    }

    // Problem sizes: the dense view (m, n, k) and the view whose K is the
    // extent actually walked by the main loop.
    cutlass::gemm::GemmCoord problem_size(m, n, k);
    cutlass::gemm::GemmCoord problem_size_sp(m, n, k_sp);

    auto layout_a = cutlass::layout::RowMajor::packed(cutlass::make_Coord(problem_size.m(), k_sp / 2));
    auto layout_b = cutlass::layout::RowMajor::packed(problem_size.kn());
    auto layout_e = Mma::LayoutE::packed(cutlass::make_Coord(problem_size.m(), k_sp / Mma::kSparse / Mma::kElementsPerElementE));
    auto layout_d = cutlass::layout::RowMajor::packed(problem_size.mn());

    float alpha = 1.0f;
    float beta = 0.0f;

    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
        problem_size,
        {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
        1
    );

    dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
    dim3 block(Mma::WarpCount::kCount * 32, 1, 1);

    int smem_size = int(sizeof(SharedStorage));

    // Opt in to the required amount of dynamic shared memory.
    check_cuda(
        cudaFuncSetAttribute(BlockCutlassSpmmKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size),
        "cudaFuncSetAttribute(MaxDynamicSharedMemorySize)");
    check_cuda(
        cudaFuncSetAttribute(BlockCutlassSpmmKernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100),
        "cudaFuncSetAttribute(PreferredSharedMemoryCarveout)");

    // K extent rounded up to a whole number of threadblock-K tiles.
    int gemm_k_size = ((k_sp + Mma::Shape::kK - 1) / Mma::Shape::kK) * Mma::Shape::kK;

    BlockCutlassSpmmKernel<<<grid, block, smem_size>>>(
        problem_size, problem_size_sp,
        grid_tiled_shape,
        layout_a, tensor_a.data_ptr<float>(),
        layout_b, tensor_b.data_ptr<float>(),
        layout_d, output_matrix.data_ptr<float>(),
        layout_e, (Mma::ElementE*)tensor_e_reordered.data_ptr(),
        {alpha, beta},
        indices.data_ptr<int>(), nnz_block,
        gemm_k_size);

    // Launch-configuration errors are only visible through cudaGetLastError().
    check_cuda(cudaGetLastError(), "BlockCutlassSpmmKernel launch");

    return output_matrix;
}


//////////////////////////////////////////////////////////////////////////////////////
// Batched SpMM kernel for hybrid Blocked-ELL & 50% structured sparsity under float //
//////////////////////////////////////////////////////////////////////////////////////

// Batched float SpMM kernel for the hybrid Blocked-ELL + 50% structured-sparse
// format. Each threadblock handles one output tile of one batch entry.
//
// The batch entry comes from bThreadblockSwizzle::get_batch_idx(). Every
// operand iterator is advanced by (batch index * stride_X) elements, and the
// block-index table by (batch index * stride_indices) entries.
// Launch contract (as set up by batched_block_spmmv2_cuda):
// Mma::WarpCount::kCount * 32 threads per block and sizeof(SharedStorage)
// bytes of dynamic shared memory. gemm_k_size is accepted but not read.
__global__ void batchedBlockCutlassSpmmKernel(
    cutlass::gemm::GemmCoord problem_size,
    cutlass::gemm::GemmCoord problem_size_sp,
    cutlass::gemm::GemmCoord grid_tiled_shape,
    typename Mma::IteratorA::Params params_A,
    float* __restrict__ ptr_A, int64_t stride_A,
    typename Mma::IteratorB::Params params_B,
    float* __restrict__ ptr_B, int64_t stride_B,
    typename Epilogue::OutputTileIterator::Params params_D,
    float* __restrict__ ptr_D, int64_t stride_D,
    typename Mma::IteratorE::Params params_E,
    Mma::ElementE* __restrict__ ptr_E, int64_t stride_E,
    typename Epilogue::OutputOp::Params output_op_,
    int* __restrict__ indices, int64_t stride_indices,
    int nnz_block, int gemm_k_size)
{
    bThreadblockSwizzle swizzle;

    cutlass::gemm::GemmCoord tile_coord = swizzle.get_tile_offset(grid_tiled_shape);

    // Threadblocks that fall outside the tiled grid have nothing to do.
    const bool tile_in_range = tile_coord.m() < grid_tiled_shape.m() &&
                               tile_coord.n() < grid_tiled_shape.n();
    if (!tile_in_range) {
        return;
    }

    // View the dynamic shared-memory allocation as the main-loop/epilogue union.
    extern __shared__ int SharedStorageBase[];
    SharedStorage& smem = *reinterpret_cast<SharedStorage *>(SharedStorageBase);

    const int batch = swizzle.get_batch_idx();

    // Locate this batch entry's index table, then the row for this block row.
    const int block_row = blockIdx.x * ThreadblockShape::kM / BlockSize;
    int* block_indices = indices + batch * stride_indices + nnz_block * block_row;

    const int tid = threadIdx.x;

    // K extent walked by the main loop and its tile count (ceil-div).
    const int k_extent = problem_size_sp.k();
    const int k_tiles = (k_extent + Mma::Shape::kK - 1) / Mma::Shape::kK;

    // Logical start coordinates of this tile in A / E (row) and B (column).
    const int tile_row_start = tile_coord.m() * Mma::Shape::kM;
    const int tile_col_start = tile_coord.n() * Mma::Shape::kN;

    cutlass::MatrixCoord origin_A(tile_row_start, 0);
    cutlass::MatrixCoord origin_B(0, tile_col_start);
    cutlass::MatrixCoord origin_E(tile_row_start, 0);

    // Operand A, shifted to the current batch entry.
    typename Mma::IteratorA iterator_A(
        params_A, ptr_A,
        {problem_size.m(), k_extent / Mma::kSparse},
        tid, origin_A);
    iterator_A.add_pointer_offset(stride_A * batch);

    // Operand B; this iterator also receives the block-index row.
    typename Mma::IteratorB iterator_B(
        params_B, ptr_B,
        {problem_size.k(), problem_size.n()},
        tid, origin_B, block_indices);
    iterator_B.add_pointer_offset(stride_B * batch);

    // Operand E, shifted to the current batch entry.
    typename Mma::IteratorE iterator_E(
        params_E, ptr_E,
        {problem_size.m(), k_extent / Mma::kSparse / Mma::kElementsPerElementE},
        tid, origin_E);
    iterator_E.add_pointer_offset(stride_E * batch);

    // Lane 0 broadcasts the warp id so that code depending on it is
    // compiled as warp-uniform.
    const int warp = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
    const int lane = threadIdx.x % 32;

    // ---- Main loop ----
    Mma mma(smem.main_loop, tid, warp, lane);

    typename Mma::FragmentC acc;
    acc.clear();

    if (k_tiles > 0) {
        mma(k_tiles, acc, iterator_A, iterator_B, iterator_E, acc);
    }

    // ---- Epilogue ----
    Epilogue::OutputOp output_op(output_op_);

    // The tile coordinate is queried a second time, as in the main-loop setup.
    tile_coord = swizzle.get_tile_offset(grid_tiled_shape);

    cutlass::MatrixCoord origin_D(
        tile_coord.m() * Mma::Shape::kM,
        tile_coord.n() * Mma::Shape::kN);

    typename Epilogue::OutputTileIterator iterator_D(
        params_D, ptr_D, problem_size.mn(), tid, origin_D);
    iterator_D.add_pointer_offset(stride_D * batch);

    Epilogue epilogue(smem.epilogue, tid, warp, lane);

    // iterator_D serves as both destination and source operand.
    epilogue(output_op, iterator_D, acc, iterator_D);
}


// Launches batchedBlockCutlassSpmmKernel for a batch of float SpMMs and
// returns the {batch_size, m, n} float32 result on tensor_b's device.
//
//   tensor_a            float values of the sparse operand; m = size(-2) and
//                       the K extent walked by the kernel is 2 * size(-1).
//   tensor_b            float dense operand; the last two dimensions are
//                       {k, n} and batch_size = numel / (k * n).
//   tensor_e_reordered  E operand; its storage is reinterpreted as
//                       Mma::ElementE.
//   indices             int32 block-index table, numel / batch_size entries
//                       per batch entry.
//
// Raises a c10::Error for an empty B, an empty batch, an m that is not a
// positive multiple of BlockSize, or a failing CUDA call / kernel launch.
// The kernel is launched on the default stream.
// NOTE(review): the per-batch strides assume densely packed, contiguous
// tensors - confirm against the callers.
torch::Tensor batched_block_spmmv2_cuda(
    torch::Tensor tensor_a,
    torch::Tensor tensor_b,
    torch::Tensor tensor_e_reordered,
    torch::Tensor indices)
{
    // Surface CUDA runtime failures instead of silently dropping them.
    auto check_cuda = [](cudaError_t status, const char* what) {
        TORCH_CHECK(status == cudaSuccess, what, " failed: ", cudaGetErrorString(status));
    };

    const int m = tensor_a.size(-2);
    const int n = tensor_b.size(-1);
    const int k = tensor_b.size(-2);
    // tensor_a stores half of the K extent (50% structured sparsity); the
    // block sparsity has already reduced K before that.
    const int k_sp = tensor_a.size(-1) * 2;

    // All of these values are used as divisors below.
    TORCH_CHECK(k > 0 && n > 0,
                "batched_block_spmmv2_cuda: tensor_b must have non-empty trailing dimensions");
    TORCH_CHECK(m > 0 && m % BlockSize == 0,
                "batched_block_spmmv2_cuda: tensor_a.size(-2) must be a positive multiple of ",
                BlockSize, ", got ", m);

    // 64-bit product: k * n may not fit in an int.
    const int batch_size = tensor_b.numel() / (static_cast<int64_t>(k) * n);
    TORCH_CHECK(batch_size > 0, "batched_block_spmmv2_cuda: empty batch");

    // Number of nonzero blocks in each of the (m / BlockSize) block rows
    const int nnz_block = indices.numel() / batch_size / (m / BlockSize);

    auto options_val = torch::TensorOptions().dtype(torch::kFloat32).device(tensor_b.device());
    auto output_matrix = torch::empty({batch_size, m, n}, options_val);

    // Problem sizes: the dense view (m, n, k) and the view whose K is the
    // extent actually walked by the main loop.
    cutlass::gemm::GemmCoord problem_size(m, n, k);
    cutlass::gemm::GemmCoord problem_size_sp(m, n, k_sp);

    auto layout_a = cutlass::layout::RowMajor::packed(cutlass::make_Coord(problem_size.m(), k_sp / 2));
    auto layout_b = cutlass::layout::RowMajor::packed(problem_size.kn());
    auto layout_e = Mma::LayoutE::packed(cutlass::make_Coord(problem_size.m(), k_sp / Mma::kSparse / Mma::kElementsPerElementE));
    auto layout_d = cutlass::layout::RowMajor::packed(problem_size.mn());

    // Element distance between consecutive batch entries of each operand.
    // The products are formed in 64 bits so large matrices do not overflow int.
    const int64_t m64 = m;
    int64_t stride_a = m64 * k_sp / 2;
    int64_t stride_b = static_cast<int64_t>(k) * n;
    // NOTE(review): 8 presumably equals Mma::kSparse * Mma::kElementsPerElementE - confirm.
    int64_t stride_e = m64 * k_sp / 8;
    int64_t stride_d = m64 * n;
    int64_t stride_indices = indices.numel() / batch_size;

    float alpha = 1.0f;
    float beta = 0.0f;

    bThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
        problem_size,
        {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
        batch_size
    );

    dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
    dim3 block(Mma::WarpCount::kCount * 32, 1, 1);

    int smem_size = int(sizeof(SharedStorage));

    // Opt in to the required amount of dynamic shared memory.
    check_cuda(
        cudaFuncSetAttribute(batchedBlockCutlassSpmmKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size),
        "cudaFuncSetAttribute(MaxDynamicSharedMemorySize)");
    check_cuda(
        cudaFuncSetAttribute(batchedBlockCutlassSpmmKernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100),
        "cudaFuncSetAttribute(PreferredSharedMemoryCarveout)");

    // K extent rounded up to a whole number of threadblock-K tiles.
    int gemm_k_size = ((k_sp + Mma::Shape::kK - 1) / Mma::Shape::kK) * Mma::Shape::kK;

    batchedBlockCutlassSpmmKernel<<<grid, block, smem_size>>>(
        problem_size, problem_size_sp,
        grid_tiled_shape,
        layout_a, tensor_a.data_ptr<float>(), stride_a,
        layout_b, tensor_b.data_ptr<float>(), stride_b,
        layout_d, output_matrix.data_ptr<float>(), stride_d,
        layout_e, (Mma::ElementE*)tensor_e_reordered.data_ptr(), stride_e,
        {alpha, beta},
        indices.data_ptr<int>(), stride_indices,
        nnz_block, gemm_k_size
    );

    // Launch-configuration errors are only visible through cudaGetLastError().
    check_cuda(cudaGetLastError(), "batchedBlockCutlassSpmmKernel launch");

    return output_matrix;
}


/////////////////////////////////////////////////////////////////////////////////
// SpMM kernel for hybrid Blocked-ELL & 50% structured sparsity under bfloat16 //
/////////////////////////////////////////////////////////////////////////////////

// SpMM kernel for one bfloat16 matrix product whose left operand is stored in
// the hybrid Blocked-ELL + 50% structured-sparse format.
//
// Launch configuration (as set up by block_spmmv2_bf16_cuda):
//   grid  : ThreadblockSwizzle::get_grid_shape(grid_tiled_shape)
//   block : Mma_bf16::WarpCount::kCount * 32 threads, one-dimensional
//   smem  : the kernel views its dynamic shared memory as SharedStorage_bf16,
//           so at least sizeof(SharedStorage_bf16) bytes are required
// The CUTLASS types used here are configured for cutlass::arch::Sm80.
//
// Parameters:
//   problem_size      GEMM extents; m() and n() size the output, k() and n()
//                     size operand B.
//   problem_size_sp   only k() is read: the K extent walked by the main loop.
//   grid_tiled_shape  number of threadblock tiles along each dimension.
//   params_A / ptr_A  iterator parameters and base pointer of operand A,
//                     whose extent is {m, k_sp / Mma_bf16::kSparse}.
//   params_B / ptr_B  iterator parameters and base pointer of operand B.
//   params_D / ptr_D  iterator parameters and base pointer of the output.
//   params_E / ptr_E  iterator parameters and base pointer of operand E
//                     (presumably the structured-sparsity metadata - confirm).
//   output_op_        parameters of the epilogue output functor.
//   indices           block-index table, nnz_block entries per block row.
//   nnz_block         number of nonzero blocks in each block row.
//   gemm_k_size       not read by this kernel; kept so the launch signature
//                     stays unchanged.
__global__ void BlockCutlassSpmmKernel_bf16(
    cutlass::gemm::GemmCoord problem_size,
    cutlass::gemm::GemmCoord problem_size_sp,
    cutlass::gemm::GemmCoord grid_tiled_shape,
    typename Mma_bf16::IteratorA::Params params_A,
    cutlass::bfloat16_t* __restrict__ ptr_A,
    typename Mma_bf16::IteratorB::Params params_B,
    cutlass::bfloat16_t* __restrict__ ptr_B,
    typename Epilogue_bf16::OutputTileIterator::Params params_D,
    cutlass::bfloat16_t* __restrict__ ptr_D,
    typename Mma_bf16::IteratorE::Params params_E,
    Mma_bf16::ElementE* __restrict__ ptr_E,
    typename Epilogue_bf16::OutputOp::Params output_op_,
    int* __restrict__ indices,                         // the block indices
    int nnz_block,                                     // the number of nonzero blocks in each row
    int gemm_k_size)
{
    // Row of the block-index table that belongs to this threadblock's block row.
    int row_block = blockIdx.x * ThreadblockShape_bf16::kM / BlockSize;
    int* indices_t = indices + nnz_block * row_block;

    // The dynamic shared-memory allocation is viewed as the main-loop/epilogue union.
    extern __shared__ int SharedStorageBase[];

    SharedStorage_bf16& shared_storage = *reinterpret_cast<SharedStorage_bf16 *>(SharedStorageBase);

    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(grid_tiled_shape);

    // Early exit if the CTA is out of range. The whole threadblock leaves
    // together, so the warp shuffle further down is still reached by full warps.
    if (grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
        grid_tiled_shape.n() <= threadblock_tile_offset.n())
    {
        return;
    }

    // Initial tile locations in logical coordinates: A and E start at this
    // tile's first row, B at this tile's first column.
    cutlass::MatrixCoord tb_offset_A{
        threadblock_tile_offset.m() * Mma_bf16::Shape::kM,
        0
    };

    cutlass::MatrixCoord tb_offset_B{
        0,
        threadblock_tile_offset.n() * Mma_bf16::Shape::kN
    };

    cutlass::MatrixCoord tb_offset_E{
        threadblock_tile_offset.m() * Mma_bf16::Shape::kM,
        0
    };

    // K extent of the main loop and the number of threadblock-K tiles (ceil-div).
    int problem_size_k = problem_size_sp.k();

    int gemm_k_iterations = (problem_size_k + Mma_bf16::Shape::kK - 1) / Mma_bf16::Shape::kK;

    // Position within the threadblock
    int thread_idx = threadIdx.x;

    // Construct iterators to the A, B, and E operands
    typename Mma_bf16::IteratorA iterator_A(
        params_A,
        ptr_A,
        {problem_size.m(), problem_size_k / Mma_bf16::kSparse},
        thread_idx,
        tb_offset_A
    );

    // The B iterator additionally receives this block row's index list.
    typename Mma_bf16::IteratorB iterator_B(
        params_B,
        ptr_B,
        {problem_size.k(), problem_size.n()},
        thread_idx,
        tb_offset_B,
        indices_t
    );

    // The E extent must be derived from the bfloat16 configuration. Using the
    // float Mma constants here would disagree with the layout built by the
    // launcher and with the batched bfloat16 kernel.
    typename Mma_bf16::IteratorE iterator_E(
        params_E,
        ptr_E,
        {problem_size.m(),
        problem_size_k / Mma_bf16::kSparse / Mma_bf16::kElementsPerElementE},
        thread_idx,
        tb_offset_E
    );

    // Broadcast the warp_id computed by lane 0 to ensure dependent code
    // is compiled as warp-uniform
    int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
    int lane_idx = threadIdx.x % 32;

    //
    //  Main loop
    //

    // Construct thread-scoped matrix multiply
    Mma_bf16 mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

    typename Mma_bf16::FragmentC accumulators;

    accumulators.clear();

    if (gemm_k_iterations > 0){
        mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_E, accumulators);
    }

    //
    //  Epilogue
    //

    Epilogue_bf16::OutputOp output_op(output_op_);

    // The tile offset is queried again here, exactly as in the original code.
    threadblock_tile_offset = threadblock_swizzle.get_tile_offset(grid_tiled_shape);

    cutlass::MatrixCoord threadblock_offset(
        threadblock_tile_offset.m() * Mma_bf16::Shape::kM,
        threadblock_tile_offset.n() * Mma_bf16::Shape::kN
    );

    typename Epilogue_bf16::OutputTileIterator iterator_D(
        params_D,
        ptr_D,
        problem_size.mn(),
        thread_idx,
        threadblock_offset
    );

    Epilogue_bf16 epilogue(
        shared_storage.epilogue,
        thread_idx,
        warp_idx,
        lane_idx
    );

    // iterator_D is passed both as destination and as source operand.
    epilogue(output_op, iterator_D, accumulators, iterator_D);
}


// Launches BlockCutlassSpmmKernel_bf16 for one bfloat16 SpMM and returns the
// {m, n} bfloat16 result, allocated on tensor_b's device.
//
//   tensor_a            bfloat16 values of the sparse operand; m = size(0)
//                       and the K extent walked by the kernel is 2 * size(1).
//   tensor_b            bfloat16 dense operand, read as a packed row-major
//                       {k, n} matrix.
//   tensor_e_reordered  E operand; its storage is reinterpreted as
//                       Mma_bf16::ElementE.
//   indices             int32 block-index table holding
//                       numel / (m / BlockSize) entries per block row.
//
// Raises a c10::Error when tensor_a / tensor_b are not bfloat16, when m is
// not a positive multiple of BlockSize, or when a CUDA runtime call or the
// kernel launch reports an error.
// The kernel is launched on the default stream.
// NOTE(review): the packed layouts assume contiguous tensors - confirm
// against the callers.
torch::Tensor block_spmmv2_bf16_cuda(
    torch::Tensor tensor_a,
    torch::Tensor tensor_b,
    torch::Tensor tensor_e_reordered,
    torch::Tensor indices)
{
    // Surface CUDA runtime failures instead of silently dropping them.
    auto check_cuda = [](cudaError_t status, const char* what) {
        TORCH_CHECK(status == cudaSuccess, what, " failed: ", cudaGetErrorString(status));
    };

    // The data pointers are cast without a type check further down, so the
    // element type has to be verified explicitly.
    TORCH_CHECK(tensor_a.scalar_type() == torch::kBFloat16 &&
                tensor_b.scalar_type() == torch::kBFloat16,
                "block_spmmv2_bf16_cuda: tensor_a and tensor_b must be bfloat16");

    const int m = tensor_a.size(0);
    const int n = tensor_b.size(1);
    const int k = tensor_b.size(0);
    // tensor_a stores half of the K extent (50% structured sparsity); the
    // block sparsity has already reduced K before that.
    const int k_sp = tensor_a.size(1) * 2;

    // m / BlockSize is used as a divisor below, and every threadblock row
    // reads its own row of the index table, so a partial block row would
    // index past the end of `indices`.
    TORCH_CHECK(m > 0 && m % BlockSize == 0,
                "block_spmmv2_bf16_cuda: tensor_a.size(0) must be a positive multiple of ",
                BlockSize, ", got ", m);

    // Number of nonzero blocks in each of the (m / BlockSize) block rows
    const int nnz_block = indices.numel() / (m / BlockSize);

    auto options_val = torch::TensorOptions().dtype(torch::kBFloat16).device(tensor_b.device());
    auto output_matrix = torch::empty({m, n}, options_val);

    // Nothing to compute, and a grid with a zero dimension cannot be launched.
    if (n == 0) {
        return output_matrix;
    }

    // Problem sizes: the dense view (m, n, k) and the view whose K is the
    // extent actually walked by the main loop.
    cutlass::gemm::GemmCoord problem_size(m, n, k);
    cutlass::gemm::GemmCoord problem_size_sp(m, n, k_sp);

    auto layout_a = cutlass::layout::RowMajor::packed(cutlass::make_Coord(problem_size.m(), k_sp / 2));
    auto layout_b = cutlass::layout::RowMajor::packed(problem_size.kn());
    auto layout_e = Mma_bf16::LayoutE::packed(cutlass::make_Coord(problem_size.m(), k_sp / Mma_bf16::kSparse / Mma_bf16::kElementsPerElementE));
    auto layout_d = cutlass::layout::RowMajor::packed(problem_size.mn());

    float alpha = 1.0f;
    float beta = 0.0f;

    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
        problem_size,
        {ThreadblockShape_bf16::kM, ThreadblockShape_bf16::kN, ThreadblockShape_bf16::kK},
        1
    );

    dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
    dim3 block(Mma_bf16::WarpCount::kCount * 32, 1, 1);

    // The kernel views its dynamic shared memory as SharedStorage_bf16, so
    // that type (not the float SharedStorage) determines the allocation size.
    int smem_size = int(sizeof(SharedStorage_bf16));

    // Opt in to the required amount of dynamic shared memory.
    check_cuda(
        cudaFuncSetAttribute(BlockCutlassSpmmKernel_bf16, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size),
        "cudaFuncSetAttribute(MaxDynamicSharedMemorySize)");
    check_cuda(
        cudaFuncSetAttribute(BlockCutlassSpmmKernel_bf16, cudaFuncAttributePreferredSharedMemoryCarveout, 100),
        "cudaFuncSetAttribute(PreferredSharedMemoryCarveout)");

    // K extent rounded up to a whole number of threadblock-K tiles.
    int gemm_k_size = ((k_sp + Mma_bf16::Shape::kK - 1) / Mma_bf16::Shape::kK) * Mma_bf16::Shape::kK;

    BlockCutlassSpmmKernel_bf16<<<grid, block, smem_size>>>(
        problem_size, problem_size_sp,
        grid_tiled_shape,
        layout_a, (cutlass::bfloat16_t*)tensor_a.data_ptr(),
        layout_b, (cutlass::bfloat16_t*)tensor_b.data_ptr(),
        layout_d, (cutlass::bfloat16_t*)output_matrix.data_ptr(),
        layout_e, (Mma_bf16::ElementE*)tensor_e_reordered.data_ptr(),
        {alpha, beta},
        indices.data_ptr<int>(), nnz_block,
        gemm_k_size);

    // Launch-configuration errors are only visible through cudaGetLastError().
    check_cuda(cudaGetLastError(), "BlockCutlassSpmmKernel_bf16 launch");

    return output_matrix;
}


/////////////////////////////////////////////////////////////////////////////////////////
// Batched SpMM kernel for hybrid Blocked-ELL & 50% structured sparsity under bfloat16 //
/////////////////////////////////////////////////////////////////////////////////////////

// Batched bfloat16 SpMM kernel for the hybrid Blocked-ELL + 50%
// structured-sparse format. Each threadblock handles one output tile of one
// batch entry.
//
// The batch entry comes from bThreadblockSwizzle::get_batch_idx(). Every
// operand iterator is advanced by (batch index * stride_X) elements, and the
// block-index table by (batch index * stride_indices) entries.
// Launch contract (as set up by batched_block_spmmv2_bf16_cuda):
// Mma_bf16::WarpCount::kCount * 32 threads per block and
// sizeof(SharedStorage_bf16) bytes of dynamic shared memory.
// gemm_k_size is accepted but not read.
__global__ void batchedBlockCutlassSpmmKernel_bf16(
    cutlass::gemm::GemmCoord problem_size,
    cutlass::gemm::GemmCoord problem_size_sp,
    cutlass::gemm::GemmCoord grid_tiled_shape,
    typename Mma_bf16::IteratorA::Params params_A,
    cutlass::bfloat16_t* __restrict__ ptr_A, int64_t stride_A,
    typename Mma_bf16::IteratorB::Params params_B,
    cutlass::bfloat16_t* __restrict__ ptr_B, int64_t stride_B,
    typename Epilogue_bf16::OutputTileIterator::Params params_D,
    cutlass::bfloat16_t* __restrict__ ptr_D, int64_t stride_D,
    typename Mma_bf16::IteratorE::Params params_E,
    Mma_bf16::ElementE* __restrict__ ptr_E, int64_t stride_E,
    typename Epilogue_bf16::OutputOp::Params output_op_,
    int* __restrict__ indices, int64_t stride_indices,
    int nnz_block, int gemm_k_size)
{
    bThreadblockSwizzle swizzle;

    cutlass::gemm::GemmCoord tile_coord = swizzle.get_tile_offset(grid_tiled_shape);

    // Threadblocks that fall outside the tiled grid have nothing to do.
    const bool tile_in_range = tile_coord.m() < grid_tiled_shape.m() &&
                               tile_coord.n() < grid_tiled_shape.n();
    if (!tile_in_range) {
        return;
    }

    // View the dynamic shared-memory allocation as the main-loop/epilogue union.
    extern __shared__ int SharedStorageBase[];
    SharedStorage_bf16& smem = *reinterpret_cast<SharedStorage_bf16 *>(SharedStorageBase);

    const int batch = swizzle.get_batch_idx();

    // Locate this batch entry's index table, then the row for this block row.
    const int block_row = blockIdx.x * ThreadblockShape_bf16::kM / BlockSize;
    int* block_indices = indices + batch * stride_indices + nnz_block * block_row;

    const int tid = threadIdx.x;

    // K extent walked by the main loop and its tile count (ceil-div).
    const int k_extent = problem_size_sp.k();
    const int k_tiles = (k_extent + Mma_bf16::Shape::kK - 1) / Mma_bf16::Shape::kK;

    // Logical start coordinates of this tile in A / E (row) and B (column).
    const int tile_row_start = tile_coord.m() * Mma_bf16::Shape::kM;
    const int tile_col_start = tile_coord.n() * Mma_bf16::Shape::kN;

    cutlass::MatrixCoord origin_A(tile_row_start, 0);
    cutlass::MatrixCoord origin_B(0, tile_col_start);
    cutlass::MatrixCoord origin_E(tile_row_start, 0);

    // Operand A, shifted to the current batch entry.
    typename Mma_bf16::IteratorA iterator_A(
        params_A, ptr_A,
        {problem_size.m(), k_extent / Mma_bf16::kSparse},
        tid, origin_A);
    iterator_A.add_pointer_offset(stride_A * batch);

    // Operand B; this iterator also receives the block-index row.
    typename Mma_bf16::IteratorB iterator_B(
        params_B, ptr_B,
        {problem_size.k(), problem_size.n()},
        tid, origin_B, block_indices);
    iterator_B.add_pointer_offset(stride_B * batch);

    // Operand E, shifted to the current batch entry.
    typename Mma_bf16::IteratorE iterator_E(
        params_E, ptr_E,
        {problem_size.m(), k_extent / Mma_bf16::kSparse / Mma_bf16::kElementsPerElementE},
        tid, origin_E);
    iterator_E.add_pointer_offset(stride_E * batch);

    // Lane 0 broadcasts the warp id so that code depending on it is
    // compiled as warp-uniform.
    const int warp = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
    const int lane = threadIdx.x % 32;

    // ---- Main loop ----
    Mma_bf16 mma(smem.main_loop, tid, warp, lane);

    typename Mma_bf16::FragmentC acc;
    acc.clear();

    if (k_tiles > 0) {
        mma(k_tiles, acc, iterator_A, iterator_B, iterator_E, acc);
    }

    // ---- Epilogue ----
    Epilogue_bf16::OutputOp output_op(output_op_);

    // The tile coordinate is queried a second time, as in the main-loop setup.
    tile_coord = swizzle.get_tile_offset(grid_tiled_shape);

    cutlass::MatrixCoord origin_D(
        tile_coord.m() * Mma_bf16::Shape::kM,
        tile_coord.n() * Mma_bf16::Shape::kN);

    typename Epilogue_bf16::OutputTileIterator iterator_D(
        params_D, ptr_D, problem_size.mn(), tid, origin_D);
    iterator_D.add_pointer_offset(stride_D * batch);

    Epilogue_bf16 epilogue(smem.epilogue, tid, warp, lane);

    // iterator_D serves as both destination and source operand.
    epilogue(output_op, iterator_D, acc, iterator_D);
}


// Launches batchedBlockCutlassSpmmKernel_bf16 for a batch of bfloat16 SpMMs
// and returns the {batch_size, m, n} bfloat16 result on tensor_b's device.
//
//   tensor_a            bfloat16 values of the sparse operand; m = size(-2)
//                       and the K extent walked by the kernel is 2 * size(-1).
//   tensor_b            bfloat16 dense operand; the last two dimensions are
//                       {k, n} and batch_size = numel / (k * n).
//   tensor_e_reordered  E operand; its storage is reinterpreted as
//                       Mma_bf16::ElementE.
//   indices             int32 block-index table, numel / batch_size entries
//                       per batch entry.
//
// Raises a c10::Error when tensor_a / tensor_b are not bfloat16, for an empty
// B, an empty batch, an m that is not a positive multiple of BlockSize, or a
// failing CUDA call / kernel launch.
// The kernel is launched on the default stream.
// NOTE(review): the per-batch strides assume densely packed, contiguous
// tensors - confirm against the callers.
torch::Tensor batched_block_spmmv2_bf16_cuda(
    torch::Tensor tensor_a,
    torch::Tensor tensor_b,
    torch::Tensor tensor_e_reordered,
    torch::Tensor indices)
{
    // Surface CUDA runtime failures instead of silently dropping them.
    auto check_cuda = [](cudaError_t status, const char* what) {
        TORCH_CHECK(status == cudaSuccess, what, " failed: ", cudaGetErrorString(status));
    };

    // The data pointers are cast without a type check further down, so the
    // element type has to be verified explicitly.
    TORCH_CHECK(tensor_a.scalar_type() == torch::kBFloat16 &&
                tensor_b.scalar_type() == torch::kBFloat16,
                "batched_block_spmmv2_bf16_cuda: tensor_a and tensor_b must be bfloat16");

    const int m = tensor_a.size(-2);
    const int n = tensor_b.size(-1);
    const int k = tensor_b.size(-2);
    // tensor_a stores half of the K extent (50% structured sparsity); the
    // block sparsity has already reduced K before that.
    const int k_sp = tensor_a.size(-1) * 2;

    // All of these values are used as divisors below.
    TORCH_CHECK(k > 0 && n > 0,
                "batched_block_spmmv2_bf16_cuda: tensor_b must have non-empty trailing dimensions");
    TORCH_CHECK(m > 0 && m % BlockSize == 0,
                "batched_block_spmmv2_bf16_cuda: tensor_a.size(-2) must be a positive multiple of ",
                BlockSize, ", got ", m);

    // 64-bit product: k * n may not fit in an int.
    const int batch_size = tensor_b.numel() / (static_cast<int64_t>(k) * n);
    TORCH_CHECK(batch_size > 0, "batched_block_spmmv2_bf16_cuda: empty batch");

    // Number of nonzero blocks in each of the (m / BlockSize) block rows
    const int nnz_block = indices.numel() / batch_size / (m / BlockSize);

    auto options_val = torch::TensorOptions().dtype(torch::kBFloat16).device(tensor_b.device());
    auto output_matrix = torch::empty({batch_size, m, n}, options_val);

    // Problem sizes: the dense view (m, n, k) and the view whose K is the
    // extent actually walked by the main loop.
    cutlass::gemm::GemmCoord problem_size(m, n, k);
    cutlass::gemm::GemmCoord problem_size_sp(m, n, k_sp);

    auto layout_a = cutlass::layout::RowMajor::packed(cutlass::make_Coord(problem_size.m(), k_sp / 2));
    auto layout_b = cutlass::layout::RowMajor::packed(problem_size.kn());
    auto layout_e = Mma_bf16::LayoutE::packed(cutlass::make_Coord(problem_size.m(), k_sp / Mma_bf16::kSparse / Mma_bf16::kElementsPerElementE));
    auto layout_d = cutlass::layout::RowMajor::packed(problem_size.mn());

    // Element distance between consecutive batch entries of each operand.
    // The products are formed in 64 bits so large matrices do not overflow int.
    const int64_t m64 = m;
    int64_t stride_a = m64 * k_sp / 2;
    int64_t stride_b = static_cast<int64_t>(k) * n;
    // NOTE(review): 16 presumably equals
    // Mma_bf16::kSparse * Mma_bf16::kElementsPerElementE - confirm.
    int64_t stride_e = m64 * k_sp / 16;
    int64_t stride_d = m64 * n;
    int64_t stride_indices = indices.numel() / batch_size;

    // Passed as float, like the single-matrix bfloat16 launcher does for the
    // same epilogue parameter type; 1 and 0 are exact in either type.
    float alpha = 1.0f;
    float beta = 0.0f;

    bThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
        problem_size,
        {ThreadblockShape_bf16::kM, ThreadblockShape_bf16::kN, ThreadblockShape_bf16::kK},
        batch_size
    );

    dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
    dim3 block(Mma_bf16::WarpCount::kCount * 32, 1, 1);

    int smem_size = int(sizeof(SharedStorage_bf16));

    // Opt in to the required amount of dynamic shared memory.
    check_cuda(
        cudaFuncSetAttribute(batchedBlockCutlassSpmmKernel_bf16, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size),
        "cudaFuncSetAttribute(MaxDynamicSharedMemorySize)");
    check_cuda(
        cudaFuncSetAttribute(batchedBlockCutlassSpmmKernel_bf16, cudaFuncAttributePreferredSharedMemoryCarveout, 100),
        "cudaFuncSetAttribute(PreferredSharedMemoryCarveout)");

    // K extent rounded up to a whole number of threadblock-K tiles.
    int gemm_k_size = ((k_sp + Mma_bf16::Shape::kK - 1) / Mma_bf16::Shape::kK) * Mma_bf16::Shape::kK;

    batchedBlockCutlassSpmmKernel_bf16<<<grid, block, smem_size>>>(
        problem_size, problem_size_sp,
        grid_tiled_shape,
        layout_a, (cutlass::bfloat16_t*)tensor_a.data_ptr(), stride_a,
        layout_b, (cutlass::bfloat16_t*)tensor_b.data_ptr(), stride_b,
        layout_d, (cutlass::bfloat16_t*)output_matrix.data_ptr(), stride_d,
        layout_e, (Mma_bf16::ElementE*)tensor_e_reordered.data_ptr(), stride_e,
        {alpha, beta},
        indices.data_ptr<int>(), stride_indices,
        nnz_block, gemm_k_size
    );

    // Launch-configuration errors are only visible through cudaGetLastError().
    check_cuda(cudaGetLastError(), "batchedBlockCutlassSpmmKernel_bf16 launch");

    return output_matrix;
}