#include <torch/extension.h>

#include <vector>

using namespace at;


// CUDA forward declarations
//
// reconstruct_fp_cuda_forward is implemented outside this file (presumably in
// the extension's .cu source -- confirm against the build configuration); this
// declaration must stay in sync with that definition.
// It receives the same four tensors that reconstruct_fp_forward below forwards
// after validating them as contiguous CUDA tensors. `output` is the only
// non-const parameter, so the result is presumably written into it in place.
// NOTE(review): dtypes and shapes of sign/exponent/mantissa are not visible
// here; the Python binding's docstring only states that the output is FP32.
void reconstruct_fp_cuda_forward(
    const Tensor    sign,
    const Tensor    exponent,
    const Tensor    mantissa,
    Tensor&         output);

// C++ interface

// Argument-validation helpers built on TORCH_CHECK.
// Each macro is wrapped in do { ... } while (0) so that it expands to exactly
// one statement: this keeps it correct as the body of an unbraced if/else and
// forces a trailing semicolon at the call site. (A bare "a; b" expansion would
// run the second check unconditionally under `if (cond) CHECK_INPUT(t);`.)
// Tensor::is_cuda() replaces the deprecated x.type().is_cuda() spelling.
#define CHECK_CUDA(x) \
  do { TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor"); } while (0)
#define CHECK_CONTIGUOUS(x) \
  do { TORCH_CHECK((x).is_contiguous(), #x " must be contiguous"); } while (0)
#define CHECK_INPUT(x) \
  do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while (0)

void reconstruct_fp_forward(
    const Tensor    sign,
    const Tensor    exponent,
    const Tensor    mantissa,
    Tensor&         output) {

    CHECK_INPUT(sign);
    CHECK_INPUT(exponent);
    CHECK_INPUT(mantissa);
    CHECK_INPUT(output);

    return reconstruct_fp_cuda_forward(sign, exponent, mantissa, output); 
}

// Python bindings: registers reconstruct_fp_forward under the name "forward"
// in the extension module whose name is supplied by the build
// (TORCH_EXTENSION_NAME).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext) {
  ext.def("forward",
          &reconstruct_fp_forward,
          "sign/exponent/mantissa to FP32 output (CUDA)");
}

