//#include <torch/extension.h>
#include <torch/types.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>

// Integer ceiling division: number of y-sized chunks needed to cover x
// (intended for non-negative x and positive y).
#define CEIL(x,y)   (((x) + (y) - 1) / (y))
// Threads per block (x dimension) used for the kernel launch in this file.
#define THREADS 512

using namespace at;

// Element-wise kernel that truncates the FP32 mantissa of every input element.
//
// Each value is viewed as an IEEE-754 binary32 bit pattern; the sign and
// exponent bits (0xFF800000) are kept untouched and the mantissa bits are
// AND-ed with `m_mask`, i.e. the dropped mantissa bits are zeroed (truncation
// toward zero, no rounding). The result is written to `output` at the same
// linear index.
//
// Launch layout expected:
//   - blockIdx.x / threadIdx.x index the element inside one row of
//     `num_features_per_batch` elements,
//   - blockIdx.y selects the row (a 1-D launch with gridDim.y == 1 over all
//     elements is equally valid).
//   - no shared memory is used.
//
// Parameters:
//   m_mask                  mask applied to the 23 mantissa bits.
//   input                   source elements, read linearly.
//   output                  destination elements, written linearly.
//   num_features_per_batch  number of elements handled per blockIdx.y row.
//
// Notes:
//   - For scalar_t other than float the value is converted to float first,
//     truncated, and converted back, instead of reinterpreting raw bytes of a
//     wider type.
//   - A NaN whose remaining mantissa bits are all zero after masking becomes
//     an infinity of the same sign.
template <typename scalar_t>
__global__ void truncate_mantissa_cuda_forward_kernel(
    uint32_t                        m_mask,
    const scalar_t*__restrict__     input,
    scalar_t*__restrict__           output,
    size_t                          num_features_per_batch) {

    // position inside the row; computed in 64 bits so large tensors do not wrap
    const size_t col = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;

    // threads past the end of the row have nothing to do
    if (col >= num_features_per_batch) return;

    // linear element index across all rows
    const size_t idx = static_cast<size_t>(blockIdx.y) * num_features_per_batch + col;

    // reinterpret the (float) value as its raw 32-bit pattern
    const uint32_t bits = __float_as_uint(static_cast<float>(input[idx]));

    // keep sign + exponent as-is, keep only the selected mantissa bits
    const uint32_t sign_exp  = bits & 0xFF800000u;
    const uint32_t trunc_man = bits & m_mask;

    // reassemble and store the truncated value
    output[idx] = static_cast<scalar_t>(__uint_as_float(sign_exp | trunc_man));
}



// main forward function
// Host entry point: launches the mantissa-truncation kernel over `input`.
//
// Parameters:
//   mbits   number of most-significant FP32 mantissa bits to keep. Values
//           >= 23 (the full FP32 mantissa width) make this function a no-op;
//           negative values are treated as 0.
//   input   source tensor; its storage is walked linearly.
//   output  destination tensor; must hold the same number of elements and the
//           same dtype as `input`.
//
// NOTE(review): when mbits >= 23 nothing is launched and `output` is left
// untouched, exactly as before -- confirm callers pre-fill `output` in that case.
// NOTE(review): both tensors are indexed as flat buffers, so they are assumed
// to share the same memory layout -- verify against callers.
// The kernel is launched on the default stream.
void truncate_mantissa_cuda_forward(
    const int       mbits,
    const Tensor    input,
    Tensor&         output) {

    // only truncate when fewer bits than the FP32 mantissa width (23) are requested
    if ( mbits >= 23 ) return;

    const auto num_elements = input.numel();
    TORCH_CHECK(output.numel() == num_elements,
                "truncate_mantissa_cuda_forward: output must have as many elements as input");

    // nothing to do for empty tensors (also avoids a zero-sized grid launch)
    if ( num_elements == 0 ) return;

    // Number of low mantissa bits to clear, in [1, 23]; clamping negative
    // mbits keeps the shift count in range.
    const int dropped_bits = 23 - (mbits < 0 ? 0 : mbits);

    // Keep the top `mbits` mantissa bits: clear exactly `dropped_bits` low bits.
    // e.g. mbits = 22 -> 0x007FFFFE, mbits = 0 -> 0x00000000.
    const uint32_t m_mask = 0x007FFFFFu & ~((1u << dropped_bits) - 1u);

    // The operation is purely element-wise, so launch a flat 1-D grid over all
    // elements; this avoids the 65535 limit on gridDim.y for large batches and
    // any division by a zero-sized leading dimension.
    const dim3 grid( static_cast<unsigned int>(CEIL(num_elements, THREADS)), 1, 1 );
    const dim3 blocks( THREADS, 1, 1 );
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "truncate_mantissa_forward_cuda", ([&] {
        truncate_mantissa_cuda_forward_kernel<scalar_t><<<grid, blocks>>>(
                                    m_mask,
                                    input.data_ptr<scalar_t>(),
                                    output.data_ptr<scalar_t>(),
                                    static_cast<size_t>(num_elements));

        // kernel launches do not return errors directly; surface bad launch configs here
        const cudaError_t launch_status = cudaGetLastError();
        TORCH_CHECK(launch_status == cudaSuccess,
                    "truncate_mantissa_cuda_forward: kernel launch failed: ",
                    cudaGetErrorString(launch_status));
    }));

}

#undef CEIL
#undef THREADS
