#include <torch/extension.h>
#include <vector>

// Forward declaration of the implementation called below. Its body is not in
// this file (presumably a CUDA translation unit, going by the `_cuda` suffix —
// confirm against the build sources).
torch::Tensor block_spmmv2_cuda(torch::Tensor tensor_a_,
                                torch::Tensor tensor_b_,
                                torch::Tensor tensor_e_,
                                torch::Tensor indices);

// Entry point registered with Python as "block_spmmv2".
//
// Hands the four tensors to block_spmmv2_cuda in the same order and returns
// whatever that call produces. No validation or conversion is done here.
torch::Tensor block_spmmv2(torch::Tensor tensor_a_,
                           torch::Tensor tensor_b_,
                           torch::Tensor tensor_e_,
                           torch::Tensor indices) {
    torch::Tensor output = block_spmmv2_cuda(tensor_a_, tensor_b_, tensor_e_, indices);
    return output;
}


// Forward declaration of the implementation called below. Its body is not in
// this file (presumably a CUDA translation unit handling bfloat16 data, going
// by the `_bf16_cuda` suffix — confirm against the build sources).
torch::Tensor block_spmmv2_bf16_cuda(torch::Tensor tensor_a_,
                                     torch::Tensor tensor_b_,
                                     torch::Tensor tensor_e_,
                                     torch::Tensor indices);

// Entry point registered with Python as "block_spmmv2_bf16".
//
// Hands the four tensors to block_spmmv2_bf16_cuda in the same order and
// returns whatever that call produces. No validation or conversion is done
// here.
torch::Tensor block_spmmv2_bf16(torch::Tensor tensor_a_,
                                torch::Tensor tensor_b_,
                                torch::Tensor tensor_e_,
                                torch::Tensor indices) {
    torch::Tensor output = block_spmmv2_bf16_cuda(tensor_a_, tensor_b_, tensor_e_, indices);
    return output;
}


// Forward declaration of the implementation called below. Its body is not in
// this file (presumably a CUDA translation unit for batched inputs, going by
// the name — confirm against the build sources).
torch::Tensor batched_block_spmmv2_cuda(torch::Tensor tensor_a_,
                                        torch::Tensor tensor_b_,
                                        torch::Tensor tensor_e_,
                                        torch::Tensor indices);

// Entry point registered with Python as "batched_block_spmmv2".
//
// Hands the four tensors to batched_block_spmmv2_cuda in the same order and
// returns whatever that call produces. No validation or conversion is done
// here.
torch::Tensor batched_block_spmmv2(torch::Tensor tensor_a_,
                                   torch::Tensor tensor_b_,
                                   torch::Tensor tensor_e_,
                                   torch::Tensor indices) {
    torch::Tensor output = batched_block_spmmv2_cuda(tensor_a_, tensor_b_, tensor_e_, indices);
    return output;
}


// Forward declaration of the implementation called below. Its body is not in
// this file (presumably a CUDA translation unit for batched bfloat16 inputs,
// going by the name — confirm against the build sources).
torch::Tensor batched_block_spmmv2_bf16_cuda(torch::Tensor tensor_a_,
                                             torch::Tensor tensor_b_,
                                             torch::Tensor tensor_e_,
                                             torch::Tensor indices);

// Entry point registered with Python as "batched_block_spmmv2_bf16".
//
// Hands the four tensors to batched_block_spmmv2_bf16_cuda in the same order
// and returns whatever that call produces. No validation or conversion is
// done here.
torch::Tensor batched_block_spmmv2_bf16(torch::Tensor tensor_a_,
                                        torch::Tensor tensor_b_,
                                        torch::Tensor tensor_e_,
                                        torch::Tensor indices) {
    torch::Tensor output =
        batched_block_spmmv2_bf16_cuda(tensor_a_, tensor_b_, tensor_e_, indices);
    return output;
}


// Python module definition for this extension. Each of the four host wrappers
// in this file is registered under the same name as its C++ function, and all
// four carry the same docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    // def() returns the module by reference, so the registrations are chained
    // in the same order as before.
    m.def("block_spmmv2", &block_spmmv2, "Cutlass SpMM kernel")
        .def("block_spmmv2_bf16", &block_spmmv2_bf16, "Cutlass SpMM kernel")
        .def("batched_block_spmmv2", &batched_block_spmmv2, "Cutlass SpMM kernel")
        .def("batched_block_spmmv2_bf16", &batched_block_spmmv2_bf16, "Cutlass SpMM kernel");
}