//#include <torch/extension.h>
#include <torch/types.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>

// Integer ceiling division: smallest integer q such that q * y >= x
// (for non-negative x and positive y). Used to size the launch grid.
#define CEIL(x,y)   (((x) + (y) - 1) / (y))
// Number of threads per block used for the kernel launch in this file.
#define THREADS 512

// Makes names from the `at` namespace (e.g. Tensor, used below) available unqualified.
using namespace at;

// Rebuilds floating-point values from separate sign / exponent / mantissa
// integer fields.
//
// For every element the three fields are combined as
//     (sign << 31) + (exponent << 23) + mantissa
// and the resulting 32-bit pattern is reinterpreted as a float, which is then
// converted to scalar_t when stored into `output`.
//
// Launch layout: blockIdx.x * blockDim.x + threadIdx.x selects the position
// inside a row of `num_features_per_batch` elements, blockIdx.y selects the
// row. Threads whose position falls past the end of the row exit immediately.
//
// All four arrays are indexed with the same flat index
//     blockIdx.y * num_features_per_batch + position.
template <typename scalar_t>
__global__ void reconstruct_fp_cuda_forward_kernel(
    const int* __restrict__           sign,
    const int* __restrict__           exponent,
    const int* __restrict__           mantissa,
    scalar_t*  __restrict__           output,
    size_t                          num_features_per_batch) {

    // Position inside the row; computed in size_t so the product cannot wrap
    // around in 32-bit arithmetic.
    const size_t col = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;

    // Threads past the end of the row have nothing to do.
    if (col >= num_features_per_batch) return;

    // Flat element index, again kept in size_t to avoid truncation for large tensors.
    const size_t idx = static_cast<size_t>(blockIdx.y) * num_features_per_batch + col;

    // Assemble the bit pattern with unsigned arithmetic: shifting a signed int
    // into the sign bit is not portable, while unsigned wrap-around is well
    // defined and yields the same bits.
    const unsigned int bits = (static_cast<unsigned int>(sign[idx])     << 31)
                            + (static_cast<unsigned int>(exponent[idx]) << 23)
                            +  static_cast<unsigned int>(mantissa[idx]);

    // Reinterpret the bits as a float via the CUDA intrinsic instead of a
    // reference cast, then convert to the output element type.
    output[idx] = static_cast<scalar_t>(__uint_as_float(bits));

}



// main forward function
//
// Fills `output` by launching reconstruct_fp_cuda_forward_kernel with one
// thread per element. `output` is treated as batch_size rows (its first
// dimension) of numel / batch_size elements each; `sign`, `exponent` and
// `mantissa` are read as int tensors using the same flat indexing.
//
// Raises (via TORCH_CHECK) when the inputs do not have as many elements as
// `output`, when any tensor is not a contiguous CUDA tensor, when the batch
// size exceeds the grid's y-dimension limit, or when the kernel launch itself
// is rejected by the CUDA runtime.
//
// NOTE(review): the launch passes no stream argument, so it runs on the
// default CUDA stream.
void reconstruct_fp_cuda_forward(
    const Tensor    sign,
    const Tensor    exponent,
    const Tensor    mantissa,
    Tensor&         output) {

    // Nothing to compute for an empty output. Returning here also avoids a
    // division by a zero batch size and a zero-sized (invalid) launch grid.
    const auto total_elements = output.numel();
    if (total_elements == 0) return;

    // The kernel indexes all four buffers with the same flat index, so the
    // inputs must provide as many elements as the output and be densely packed.
    TORCH_CHECK(sign.numel() == total_elements &&
                exponent.numel() == total_elements &&
                mantissa.numel() == total_elements,
                "reconstruct_fp_cuda_forward: sign, exponent and mantissa must have ",
                total_elements, " elements to match output");
    TORCH_CHECK(sign.is_contiguous() && exponent.is_contiguous() &&
                mantissa.is_contiguous() && output.is_contiguous(),
                "reconstruct_fp_cuda_forward: all tensors must be contiguous");
    TORCH_CHECK(sign.is_cuda() && exponent.is_cuda() &&
                mantissa.is_cuda() && output.is_cuda(),
                "reconstruct_fp_cuda_forward: all tensors must be CUDA tensors");

    // get tensor information of output
    const auto batch_size = output.size(0);
    const auto num_features_per_batch = total_elements / batch_size;

    // The batch index is mapped onto grid.y, which CUDA limits to 65535 blocks.
    TORCH_CHECK(batch_size <= 65535,
                "reconstruct_fp_cuda_forward: batch size ", batch_size,
                " exceeds the maximum grid y-dimension (65535)");

    // launch kernels
    const dim3 grid( static_cast<unsigned int>(CEIL(num_features_per_batch, THREADS)),
                     static_cast<unsigned int>(batch_size) );
    const dim3 blocks( THREADS, 1, 1 );
    AT_DISPATCH_FLOATING_TYPES(output.scalar_type(), "reconstruct_fp_forward_cuda", ([&] {
        reconstruct_fp_cuda_forward_kernel<scalar_t><<<grid, blocks>>>(
                                    sign.data_ptr<int>(),
                                    exponent.data_ptr<int>(),
                                    mantissa.data_ptr<int>(),
                                    output.data_ptr<scalar_t>(),
                                    static_cast<size_t>(num_features_per_batch));
    }));

    // Kernel launches do not report errors directly; surface configuration
    // failures here instead of leaving `output` silently unwritten.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "reconstruct_fp_cuda_forward: kernel launch failed: ",
                cudaGetErrorString(launch_status));

}

// Remove the helper macros defined at the top of this file so they do not
// remain defined past this point.
#undef CEIL
#undef THREADS
