//#include <torch/extension.h>
#include <torch/types.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>

// Integer ceiling division: for positive integer operands, yields the number
// of chunks of size `y` needed to cover `x` items.
#define CEIL(x,y)   (((x) + (y) - 1) / (y))
// Threads per block used for the kernel launch in this file.
#define THREADS 512

// Lets ATen names (e.g. Tensor) be used unqualified below.
using namespace at;

// Returns the IEEE-754 binary32 bit pattern of `value`.
static __device__ __forceinline__ int decompose_fp_float32_bits(float value) {
    return __float_as_int(value);
}

// double overload: the outputs are 32-bit binary32 fields, so the value is
// first narrowed to float and that float's bit pattern is returned. (Reading
// the first 4 bytes of the double instead would yield the low half of its
// 52-bit mantissa, which is not a sign/exponent/mantissa triple of anything.)
static __device__ __forceinline__ int decompose_fp_float32_bits(double value) {
    return __float_as_int(static_cast<float>(value));
}

// Splits each input element into its IEEE-754 binary32 fields.
//
// Launch layout:
//   blockIdx.x * blockDim.x + threadIdx.x -> column inside one batch row
//   blockIdx.y                            -> batch row
// Threads whose column is >= num_features_per_batch exit without touching
// memory, so grid.x only has to cover num_features_per_batch (rounded up).
// No shared memory is used.
//
// Parameters:
//   sign      out: 0 or 1 (bit 31)
//   exponent  out: biased exponent, 0..255 (bits 30..23)
//   mantissa  out: fraction bits, 0..0x7FFFFF (bits 22..0)
//   input     in : values to decompose; all four arrays are indexed with the
//                  same flat index  blockIdx.y * num_features_per_batch + col
//   num_features_per_batch  number of elements per batch row
template <typename scalar_t>
__global__ void decompose_fp_cuda_forward_kernel(
    int*       __restrict__           sign,
    int*       __restrict__           exponent,
    int*       __restrict__           mantissa,
    const scalar_t* __restrict__      input,
    size_t                          num_features_per_batch) {

    // column of this thread; widened before the multiply so it cannot wrap
    const size_t col = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;

    // break threads not in the range
    if (col >= num_features_per_batch) return;

    // 64-bit flat index: a 32-bit index would wrap for tensors with more than
    // 2^32 elements
    const size_t idx = static_cast<size_t>(blockIdx.y) * num_features_per_batch + col;

    // work on the unsigned bit pattern so every shift is a logical shift
    const unsigned int bits =
        static_cast<unsigned int>(decompose_fp_float32_bits(input[idx]));

    // mask data to get sign/exponent/mantissa
    sign[idx]     = static_cast<int>(bits >> 31);
    exponent[idx] = static_cast<int>((bits >> 23) & 0xFFu);
    mantissa[idx] = static_cast<int>(bits & 0x007FFFFFu);

}



// main forward function
//
// Launches decompose_fp_cuda_forward_kernel over `input`. Dimension 0 of
// `input` is mapped to grid.y (one row of blocks per batch entry) and the
// remaining numel / size(0) elements of each entry are covered by
// grid.x blocks of THREADS threads.
//
// Parameters:
//   sign, exponent, mantissa  output tensors, accessed through data_ptr<int>();
//                             one entry is written per input element.
//   input                     floating-point tensor (types dispatched by
//                             AT_DISPATCH_FLOATING_TYPES).
//
// Raises (TORCH_CHECK) when a tensor is not a contiguous CUDA tensor, when an
// output holds fewer elements than `input`, when size(0) exceeds the CUDA
// grid.y limit, or when the kernel launch itself is rejected.
//
// The kernel is launched without an explicit stream and this function does
// not synchronize; in-kernel faults surface at the next synchronizing call.
void decompose_fp_cuda_forward(
    Tensor&         sign,
    Tensor&         exponent,
    Tensor&         mantissa,
    const Tensor    input) {

    // the kernel indexes raw device pointers with one flat index, so every
    // tensor has to live on the GPU and be densely packed
    TORCH_CHECK(input.is_cuda() && sign.is_cuda() && exponent.is_cuda() && mantissa.is_cuda(),
                "decompose_fp_cuda_forward: all tensors must be CUDA tensors");
    TORCH_CHECK(input.is_contiguous() && sign.is_contiguous() &&
                exponent.is_contiguous() && mantissa.is_contiguous(),
                "decompose_fp_cuda_forward: all tensors must be contiguous");

    const auto total = input.numel();
    TORCH_CHECK(sign.numel() >= total && exponent.numel() >= total && mantissa.numel() >= total,
                "decompose_fp_cuda_forward: output tensors must hold at least ",
                total, " elements");

    // get tensor information of input
    const auto batch_size = input.size(0);

    // Nothing to decompose. Returning here also avoids dividing by a zero
    // batch_size below and launching a grid with a zero dimension.
    if (total == 0) return;

    const auto num_features_per_batch = total / batch_size;

    // grid.y is limited to 65535 blocks on all CUDA devices
    constexpr int64_t kMaxGridDimY = 65535;
    TORCH_CHECK(batch_size <= kMaxGridDimY,
                "decompose_fp_cuda_forward: input.size(0) = ", batch_size,
                " exceeds the CUDA grid.y limit of ", kMaxGridDimY);

    // launch kernels
    const dim3 grid( static_cast<unsigned int>(CEIL(num_features_per_batch, THREADS)),
                     static_cast<unsigned int>(batch_size) );
    const dim3 blocks( THREADS, 1, 1 );
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "fp_to_int_forward_cuda", ([&] {
        decompose_fp_cuda_forward_kernel<scalar_t><<<grid, blocks>>>(
                                    sign.data_ptr<int>(), 
                                    exponent.data_ptr<int>(), 
                                    mantissa.data_ptr<int>(),
                                    input.data_ptr<scalar_t>(),
                                    num_features_per_batch);
    }));

    // a kernel launch returns no status; a rejected launch configuration is
    // only visible through cudaGetLastError()
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "decompose_fp_cuda_forward: kernel launch failed: ",
                cudaGetErrorString(launch_status));

}

// Remove the helper macros defined at the top of this file so they do not
// remain defined for any code that follows.
#undef CEIL
#undef THREADS
