#include <torch/extension.h>
#include <vector>

// Launcher defined outside this file (presumably the CUDA source compiled
// into the same extension). This declaration must match that definition.
std::vector<torch::Tensor> block_sddmmv2_cuda(
    torch::Tensor lhs_matrix,
    torch::Tensor rhs_matrix,
    torch::Tensor indices);


/// @brief Python-facing entry point for the block SDDMM kernel.
///
/// Checks that every input lives on a CUDA device, then forwards the tensors
/// unchanged to block_sddmmv2_cuda(). The device check turns a misplaced CPU
/// tensor into a readable Python RuntimeError. Without it, the tensor would
/// reach the CUDA launcher, which may not validate devices itself.
///
/// @param lhs_matrix left-hand dense operand; must be a CUDA tensor.
/// @param rhs_matrix right-hand dense operand; must be a CUDA tensor.
/// @param indices    index tensor consumed by the kernel; must be a CUDA
///                   tensor (layout defined by the CUDA implementation).
/// @return the vector returned by block_sddmmv2_cuda(), unchanged.
/// @throws c10::Error (RuntimeError in Python) if any input is not on CUDA.
std::vector<torch::Tensor> block_sddmmv2(
    torch::Tensor lhs_matrix,
    torch::Tensor rhs_matrix,
    torch::Tensor indices)
{
    TORCH_CHECK(lhs_matrix.is_cuda(), "block_sddmmv2: lhs_matrix must be a CUDA tensor");
    TORCH_CHECK(rhs_matrix.is_cuda(), "block_sddmmv2: rhs_matrix must be a CUDA tensor");
    TORCH_CHECK(indices.is_cuda(), "block_sddmmv2: indices must be a CUDA tensor");
    return block_sddmmv2_cuda(lhs_matrix, rhs_matrix, indices);
}


// bf16 launcher defined outside this file (presumably the CUDA source
// compiled into the same extension). This declaration must match it.
std::vector<torch::Tensor> block_sddmmv2_bf16_cuda(
    torch::Tensor lhs_matrix,
    torch::Tensor rhs_matrix,
    torch::Tensor indices);


/// @brief Python-facing entry point for the bf16 variant of the block SDDMM kernel.
///
/// Checks that every input lives on a CUDA device, then forwards the tensors
/// unchanged to block_sddmmv2_bf16_cuda(). The check reports a misplaced CPU
/// tensor as a Python RuntimeError rather than leaving it to the CUDA code.
///
/// @param lhs_matrix left-hand dense operand; must be a CUDA tensor.
/// @param rhs_matrix right-hand dense operand; must be a CUDA tensor.
/// @param indices    index tensor consumed by the kernel; must be a CUDA tensor.
/// @return the vector returned by block_sddmmv2_bf16_cuda(), unchanged.
/// @throws c10::Error (RuntimeError in Python) if any input is not on CUDA.
/// NOTE(review): expected dtype (presumably bfloat16) is not checked here;
/// confirm whether the CUDA side validates it.
std::vector<torch::Tensor> block_sddmmv2_bf16(
    torch::Tensor lhs_matrix,
    torch::Tensor rhs_matrix,
    torch::Tensor indices)
{
    TORCH_CHECK(lhs_matrix.is_cuda(), "block_sddmmv2_bf16: lhs_matrix must be a CUDA tensor");
    TORCH_CHECK(rhs_matrix.is_cuda(), "block_sddmmv2_bf16: rhs_matrix must be a CUDA tensor");
    TORCH_CHECK(indices.is_cuda(), "block_sddmmv2_bf16: indices must be a CUDA tensor");
    return block_sddmmv2_bf16_cuda(lhs_matrix, rhs_matrix, indices);
}


// Batched launcher defined outside this file (presumably the CUDA source
// compiled into the same extension). This declaration must match it.
std::vector<torch::Tensor> batched_block_sddmmv2_cuda(
    torch::Tensor lhs_matrix,
    torch::Tensor rhs_matrix,
    torch::Tensor indices);


/// @brief Python-facing entry point for the batched block SDDMM kernel.
///
/// Checks that every input lives on a CUDA device, then forwards the tensors
/// unchanged to batched_block_sddmmv2_cuda(). The check reports a misplaced
/// CPU tensor as a Python RuntimeError rather than leaving it to the CUDA code.
///
/// @param lhs_matrix left-hand dense operand; must be a CUDA tensor.
/// @param rhs_matrix right-hand dense operand; must be a CUDA tensor.
/// @param indices    index tensor consumed by the kernel; must be a CUDA tensor.
/// @return the vector returned by batched_block_sddmmv2_cuda(), unchanged.
/// @throws c10::Error (RuntimeError in Python) if any input is not on CUDA.
std::vector<torch::Tensor> batched_block_sddmmv2(
    torch::Tensor lhs_matrix,
    torch::Tensor rhs_matrix,
    torch::Tensor indices)
{
    TORCH_CHECK(lhs_matrix.is_cuda(), "batched_block_sddmmv2: lhs_matrix must be a CUDA tensor");
    TORCH_CHECK(rhs_matrix.is_cuda(), "batched_block_sddmmv2: rhs_matrix must be a CUDA tensor");
    TORCH_CHECK(indices.is_cuda(), "batched_block_sddmmv2: indices must be a CUDA tensor");
    return batched_block_sddmmv2_cuda(lhs_matrix, rhs_matrix, indices);
}

// Batched bf16 launcher defined outside this file (presumably the CUDA source
// compiled into the same extension). This declaration must match it.
std::vector<torch::Tensor> batched_block_sddmmv2_bf16_cuda(
    torch::Tensor lhs_matrix,
    torch::Tensor rhs_matrix,
    torch::Tensor indices);


/// @brief Python-facing entry point for the batched bf16 block SDDMM kernel.
///
/// Checks that every input lives on a CUDA device, then forwards the tensors
/// unchanged to batched_block_sddmmv2_bf16_cuda(). The check reports a
/// misplaced CPU tensor as a Python RuntimeError rather than leaving it to the
/// CUDA code.
///
/// @param lhs_matrix left-hand dense operand; must be a CUDA tensor.
/// @param rhs_matrix right-hand dense operand; must be a CUDA tensor.
/// @param indices    index tensor consumed by the kernel; must be a CUDA tensor.
/// @return the vector returned by batched_block_sddmmv2_bf16_cuda(), unchanged.
/// @throws c10::Error (RuntimeError in Python) if any input is not on CUDA.
/// NOTE(review): expected dtype (presumably bfloat16) is not checked here;
/// confirm whether the CUDA side validates it.
std::vector<torch::Tensor> batched_block_sddmmv2_bf16(
    torch::Tensor lhs_matrix,
    torch::Tensor rhs_matrix,
    torch::Tensor indices)
{
    TORCH_CHECK(lhs_matrix.is_cuda(), "batched_block_sddmmv2_bf16: lhs_matrix must be a CUDA tensor");
    TORCH_CHECK(rhs_matrix.is_cuda(), "batched_block_sddmmv2_bf16: rhs_matrix must be a CUDA tensor");
    TORCH_CHECK(indices.is_cuda(), "batched_block_sddmmv2_bf16: indices must be a CUDA tensor");
    return batched_block_sddmmv2_bf16_cuda(lhs_matrix, rhs_matrix, indices);
}


// Python bindings. Each entry exposes one C++ wrapper under the same name.
// py::arg gives the parameters Python-visible names, so callers may use
// keywords (e.g. indices=...). Positional calls work exactly as before.
// Docstrings name the variant so help() can tell the four entries apart.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
    m.def("block_sddmmv2", &block_sddmmv2,
          "Custom SDDMM kernel (block-sparse)",
          py::arg("lhs_matrix"), py::arg("rhs_matrix"), py::arg("indices"));
    m.def("block_sddmmv2_bf16", &block_sddmmv2_bf16,
          "Custom SDDMM kernel (block-sparse, bf16 variant)",
          py::arg("lhs_matrix"), py::arg("rhs_matrix"), py::arg("indices"));
    m.def("batched_block_sddmmv2", &batched_block_sddmmv2,
          "Custom SDDMM kernel (batched, block-sparse)",
          py::arg("lhs_matrix"), py::arg("rhs_matrix"), py::arg("indices"));
    m.def("batched_block_sddmmv2_bf16", &batched_block_sddmmv2_bf16,
          "Custom SDDMM kernel (batched, block-sparse, bf16 variant)",
          py::arg("lhs_matrix"), py::arg("rhs_matrix"), py::arg("indices"));
}