#include <torch/extension.h>

#include <vector>

using namespace at;


// CUDA forward declarations
// Declaration of the CUDA-side entry point. The definition is not in this
// file (presumably it lives in the companion .cu source -- confirm against
// the build script).
//
//   A  : M x K operand
//   B  : N x K operand (only the sign of each value is used)
//   C  : M x N operand, taken by mutable reference (presumably the result
//        buffer -- confirm against the implementation)
//   num_systolic_row / num_extra_bits / rounding_mode :
//        integer settings handed to the implementation as-is; their exact
//        meaning is not visible from this file.
void prealign_linear_cuda_forward(
    const Tensor A,
    const Tensor B,
    Tensor& C,
    const int num_systolic_row,
    const int num_extra_bits,
    const int rounding_mode);

// C++ interface

// Argument-validation helpers. Each one raises through TORCH_CHECK with a
// message that names the offending argument (via #x stringification).
//
// Previous form of CHECK_CUDA, kept for reference:
//#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
// Runs both checks, CUDA first. Wrapped in do { } while (0) so the macro
// behaves as a single statement: without the wrapper, using it as the body of
// an unbraced if/else/loop would guard only the first check.
#define CHECK_INPUT(x) \
    do {                   \
        CHECK_CUDA(x);     \
        CHECK_CONTIGUOUS(x); \
    } while (0)

// Validating front end for prealign_linear_cuda_forward.
//
// Each tensor argument is checked, in the order A, B, C, to be a CUDA tensor
// and to be contiguous; a failed check raises through TORCH_CHECK. When all
// checks pass, every argument is forwarded unchanged to the CUDA entry point.
//
//   A  : M x K operand
//   B  : N x K operand (only the sign of each value is used)
//   C  : M x N operand, taken by mutable reference
//   num_systolic_row / num_extra_bits / rounding_mode :
//        forwarded as-is; their meaning is defined by the CUDA implementation.
void prealign_linear_forward(
    const Tensor A,
    const Tensor B,
    Tensor& C,
    const int num_systolic_row,
    const int num_extra_bits,
    const int rounding_mode)
{
    // Validate all operands before dispatching.
    CHECK_INPUT(A);
    CHECK_INPUT(B);
    CHECK_INPUT(C);

    prealign_linear_cuda_forward(
        A, B, C,
        num_systolic_row, num_extra_bits, rounding_mode);
}

// Python bindings for the extension module.
//
// "forward" is bound to the validating wrapper prealign_linear_forward rather
// than to prealign_linear_cuda_forward directly, so the CUDA/contiguity checks
// on A, B and C run for every call made from Python. Binding the CUDA entry
// point directly would leave those checks unused and let unvalidated tensors
// reach the CUDA implementation. The wrapper has the same parameter list, so
// the Python-visible call signature is unchanged.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &prealign_linear_forward, "linear function with pre-align input (A - FP32 input, B - binary weight in FP32 format)  (CUDA)");
}

