#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>

#include <vector>

using namespace at;


// Host-side launcher for the SGEMM kernel. It is defined in another translation
// unit (presumably the companion .cu file; it is not visible here). Keep the
// parameter types exactly as written so this declaration keeps linking against
// that definition. C is the only non-const operand, so it presumably receives
// the result.
void sgemm_cuda_forward(const Tensor A, const Tensor B, Tensor& C);

// C++ interface

// Argument-validation helpers built on TORCH_CHECK.
//
// Each macro is wrapped in `do { ... } while (0)` so that it expands to exactly
// one statement. The previous CHECK_INPUT expanded to two statements, so in
// `if (cond) CHECK_INPUT(x);` the contiguity check ran unconditionally, and an
// `else` after it failed to compile. TORCH_CHECK's own expansion is not
// guaranteed to be a single statement either, so the per-check macros are
// wrapped too. The argument is parenthesized so that expressions are safe;
// `#x` still stringifies the caller's original text for the error message.
// (The older `x.type().is_cuda()` spelling is deprecated; use `x.is_cuda()`.)
#define CHECK_CUDA(x) \
    do { TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor"); } while (0)
#define CHECK_CONTIGUOUS(x) \
    do { TORCH_CHECK((x).is_contiguous(), #x " must be contiguous"); } while (0)
#define CHECK_INPUT(x) \
    do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while (0)

void sgemm_forward(
    const Tensor    A,
    const Tensor    B,
    Tensor&         C) {

    CHECK_INPUT(A);
    CHECK_INPUT(B);
    CHECK_INPUT(C);

    return sgemm_cuda_forward(A, B, C); 
}

// Python binding: exposes sgemm_forward as `forward(A, B, C)` on the extension
// module built under TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
  mod.def("forward",
          &sgemm_forward,
          "SGEMM operation (CUDA)");
}

