#include <torch/extension.h>
#include <vector>

std::vector<torch::Tensor> sddmmv2_cuda(
    torch::Tensor lhs_matrix,
    torch::Tensor rhs_matrix,
    torch::optional<torch::Tensor> mask);


std::vector<torch::Tensor> sddmmv2(
    torch::Tensor lhs_matrix,
    torch::Tensor rhs_matrix,
    torch::optional<torch::Tensor> mask)
{
    return sddmmv2_cuda(lhs_matrix, rhs_matrix, mask);
}

std::vector<torch::Tensor> sddmmv2_bf16_cuda(
    torch::Tensor lhs_matrix,
    torch::Tensor rhs_matrix,
    torch::optional<torch::Tensor> mask);


std::vector<torch::Tensor> sddmmv2_bf16(
    torch::Tensor lhs_matrix,
    torch::Tensor rhs_matrix,
    torch::optional<torch::Tensor> mask)
{
    return sddmmv2_bf16_cuda(lhs_matrix, rhs_matrix, mask);
}

std::vector<torch::Tensor> batched_sddmmv2_cuda(
    torch::Tensor lhs_matrix,
    torch::Tensor rhs_matrix,
    torch::optional<torch::Tensor> mask);


std::vector<torch::Tensor> batched_sddmmv2(
    torch::Tensor lhs_matrix,
    torch::Tensor rhs_matrix,
    torch::optional<torch::Tensor> mask)
{
    return batched_sddmmv2_cuda(lhs_matrix, rhs_matrix, mask);
}

std::vector<torch::Tensor> batched_sddmmv2_bf16_cuda(
    torch::Tensor lhs_matrix,
    torch::Tensor rhs_matrix,
    torch::optional<torch::Tensor> mask);


std::vector<torch::Tensor> batched_sddmmv2_bf16(
    torch::Tensor lhs_matrix,
    torch::Tensor rhs_matrix,
    torch::optional<torch::Tensor> mask)
{
    return batched_sddmmv2_bf16_cuda(lhs_matrix, rhs_matrix, mask);
}

// Python bindings for this extension. Each exported name maps to one C++
// wrapper defined in this file, and each wrapper simply forwards to its
// *_cuda launcher.
//
// The docstrings distinguish the four variants, so `help()` in Python tells
// them apart. Previously `bsddmmv2` shared `sddmmv2`'s text, and
// `bsddmmv2_bf16` omitted both "batched" and "bf16".
// Naming the arguments with pybind11::arg lets Python callers pass them by
// keyword and shows them in the generated signature. Positional calls behave
// exactly as before.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
    m.def("sddmmv2", &sddmmv2, "Custom SDDMM kernel",
          pybind11::arg("lhs_matrix"), pybind11::arg("rhs_matrix"),
          pybind11::arg("mask"));
    m.def("sddmmv2_bf16", &sddmmv2_bf16, "Custom SDDMM bf16 kernel",
          pybind11::arg("lhs_matrix"), pybind11::arg("rhs_matrix"),
          pybind11::arg("mask"));
    m.def("bsddmmv2", &batched_sddmmv2, "Custom batched SDDMM kernel",
          pybind11::arg("lhs_matrix"), pybind11::arg("rhs_matrix"),
          pybind11::arg("mask"));
    m.def("bsddmmv2_bf16", &batched_sddmmv2_bf16,
          "Custom batched SDDMM bf16 kernel",
          pybind11::arg("lhs_matrix"), pybind11::arg("rhs_matrix"),
          pybind11::arg("mask"));
}