#include <torch/extension.h>

#include <vector>

using namespace at;


// Forward declaration of the CUDA-side implementation; the definition is not
// in this file (presumably it lives in the accompanying .cu source -- confirm).
//
// sign, exponent and mantissa are taken by mutable reference, so they look
// like caller-allocated outputs that the implementation fills in; input is
// taken by value as a const Tensor handle.
// NOTE(review): expected dtypes and shapes are not visible from this
// declaration -- check the kernel before relying on any.
void decompose_fp_cuda_forward(Tensor& sign,
                               Tensor& exponent,
                               Tensor& mantissa,
                               const Tensor input);

// C++ interface

// Argument-validation helpers used by the C++ entry points in this file.
// Each check goes through TORCH_CHECK and names the offending argument in the
// error message via stringification (#x).
//
// CHECK_CUDA calls Tensor::is_cuda() directly; the earlier spelling kept here
// for reference was:
//   TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
// CHECK_INPUT runs both checks. It is wrapped in do { } while (0) so that it
// behaves as exactly one statement: without the wrapper,
// `if (cond) CHECK_INPUT(t);` would guard only the CUDA check and always run
// the contiguity check.
#define CHECK_INPUT(x)           \
    do {                         \
        CHECK_CUDA(x);           \
        CHECK_CONTIGUOUS(x);     \
    } while (0)

// Entry point exposed to Python: validates the four tensors, then delegates
// to decompose_fp_cuda_forward.
//
// Every argument must be a contiguous CUDA tensor. The checks run in the
// order sign, exponent, mantissa, input, and the first failing TORCH_CHECK
// reports the offending argument by name.
//
// sign / exponent / mantissa are passed by mutable reference -- presumably
// pre-allocated outputs written by the CUDA implementation (confirm against
// the kernel). No dtype, shape or same-device validation is performed here.
void decompose_fp_forward(Tensor& sign,
                          Tensor& exponent,
                          Tensor& mantissa,
                          const Tensor input) {
    // Reject CPU or non-contiguous tensors before they reach the CUDA code.
    CHECK_INPUT(sign);
    CHECK_INPUT(exponent);
    CHECK_INPUT(mantissa);
    CHECK_INPUT(input);

    // The callee returns void; results are communicated through the tensors.
    decompose_fp_cuda_forward(sign, exponent, mantissa, input);
}

// Python bindings: the extension module exports one function, `forward`,
// bound to decompose_fp_forward. The module name is supplied by the build via
// TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
      "forward",
      &decompose_fp_forward,
      "extract sign/exponent/mantissa from FP32 input (CUDA)");
}

