#include <torch/extension.h>

#include <vector>

using namespace at;


// CUDA forward declarations
// Device-side implementation. It is defined in a separate CUDA translation unit
// that is not visible here (NOTE(review): confirm which .cu file). The
// signature must match that definition exactly for linking to succeed.
// truncate_mantissa_forward below performs argument validation before
// calling it. `output` is taken by non-const reference, so the implementation
// presumably writes its result there (TODO confirm against the .cu source).
void truncate_mantissa_cuda_forward(
    const int       mbits,
    const Tensor    input,
    Tensor&         output);

// C++ interface

// Argument-validation helpers. Each one fails via TORCH_CHECK with a message
// naming the offending argument when its condition does not hold.
// Tensor::is_cuda() replaces the older, deprecated x.type().is_cuda() spelling.
#define CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
// Wrapped in do/while(0) so CHECK_INPUT expands to a single statement. Without
// it, an unbraced `if (c) CHECK_INPUT(t);` would guard only the CUDA check,
// and `if (c) CHECK_INPUT(t); else ...` would not compile.
#define CHECK_INPUT(x)     \
  do {                     \
    CHECK_CUDA(x);         \
    CHECK_CONTIGUOUS(x);   \
  } while (0)

void truncate_mantissa_forward(
    const int       mbits,
    const Tensor    input,
    Tensor&         output) {

    CHECK_INPUT(input);
    CHECK_INPUT(output);

    return truncate_mantissa_cuda_forward(mbits, input, output); 
}

// Python binding: exposes truncate_mantissa_forward as `forward`.
// Naming the arguments lets Python callers pass them by keyword and makes
// help() show real parameter names instead of arg0/arg1/arg2. Positional
// calls behave exactly as before.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &truncate_mantissa_forward,
        "truncate_mantissa from FP32 input (CUDA)",
        pybind11::arg("mbits"), pybind11::arg("input"), pybind11::arg("output"));
}

