//#include <torch/extension.h>
#include <torch/types.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <cstdint>
#include <limits>
#include <vector>

// Width of the square 2-D thread block (BLOCK_SIZE x BLOCK_SIZE threads) and of
// the shared-memory tiles used by the kernel.
#define BLOCK_SIZE 16
// Integer helpers. Arguments are evaluated more than once, so do not pass
// expressions with side effects. CEIL/FLOOR assume non-negative x and positive y.
#define MIN(x, y)               ((x) < (y) ? (x) : (y))
#define CEIL(x,y)               (((x) + (y) - 1) / (y))
#define FLOOR(x,y)              ((x) / (y))

// information related to FP format (follows IEEE FP32)
// M_HIDDEN: implicit leading one of a normal significand (bit 23).
// F_BITS:   width of the stored fraction field.
// M_BITS:   significand width including the hidden bit.
#define M_HIDDEN 0x800000
#define F_BITS 23
#define M_BITS 24

// Lets the host code below use ATen names (Tensor, IntArrayRef, abs,
// max_pool1d, ...) unqualified.
using namespace at;

// decompose fp data
//
// Split an IEEE-754 single-precision value into its raw bit fields:
//   sign[0] <- sign bit (0 or 1)
//   exp[0]  <- biased 8-bit exponent field (0..255)
//   frac[0] <- 23-bit stored fraction field (hidden bit NOT included)
//
// The bit pattern is read with the __float_as_uint intrinsic; reinterpreting the
// float through an `int &` cast (as done previously) violates C++ strict-aliasing
// rules.
inline __device__ void decompose_fp(int* sign, int* exp, int* frac, float data) {

    const unsigned int bits = __float_as_uint(data);

    sign[0] = (int)(bits >> 31);
    exp[0]  = (int)((bits >> 23) & 0xFFu);
    frac[0] = (int)(bits & 0x007FFFFFu);

}

// get mantissa
//
// Build the full significand from a biased exponent field and a 23-bit fraction
// field. A positive exponent marks a normal number, whose significand carries the
// implicit leading one (M_HIDDEN). Otherwise (subnormal or zero) the significand
// is just the stored fraction.
inline __device__ int get_mantissa(int exp, int frac) {

    const int hidden_one = (exp > 0) ? M_HIDDEN : 0;

    return frac + hidden_one;

}

// align mantissa
//
// Right-shift significand `m` by `shift_amount` bits so it is expressed relative
// to a larger (shared) exponent. Shifts of 32 or more bits flush to 0.
//
// A negative shift (the value's exponent exceeds the alignment exponent) can occur
// for the clamped out-of-range loads in the kernel's tail tiles. Shifting by a
// negative count is undefined behavior in C++, so it also returns 0, which matches
// the clamped-shift result the hardware would produce. Those lanes are never
// accumulated.
//
// Declared __host__ __device__ so it can also be unit-tested on the host.
inline __host__ __device__ int align_mantissa(int shift_amount, int m) {

    if ( (shift_amount < 0) || (shift_amount > 31) ) {
        return 0;
    }

    return m >> shift_amount;

}

// get extra bits after shifting
//
// Return the `num_extra_bits` bits of `frac` that sit just below the cut made by
// a right shift of `shift_amount` (i.e. the highest bits that align_mantissa
// discards), packed into a num_extra_bits-wide field with the most significant
// discarded bit on top. When fewer than num_extra_bits bits are shifted out, the
// available bits are left-aligned in the field and the remaining low bits are 0.
//
// A non-positive shift discards nothing and returns 0. The original code shifted
// by a negative count here, which is undefined behavior. A window that lies
// entirely above bit 31 also returns 0.
//
// Preconditions: frac >= 0 and 0 <= num_extra_bits <= 30 (so the field fits in an int).
// `idx` is unused; it is kept only for interface compatibility (debug hook).
//
// Declared __host__ __device__ so it can also be unit-tested on the host.
inline __host__ __device__ int get_extra_bits(int shift_amount, int frac, int num_extra_bits, int idx) {

    (void)idx;

    // Nothing was shifted out, or the field is empty: no extra bits.
    if ( (shift_amount <= 0) || (num_extra_bits <= 0) ) {
        return 0;
    }

    if ( shift_amount >= num_extra_bits ) {
        // The whole window [shift_amount - num_extra_bits, shift_amount) was shifted out.
        const int low = shift_amount - num_extra_bits;
        if ( low > 31 ) {
            return 0;
        }
        const int mask = ( 1 << num_extra_bits ) - 1;
        return (frac >> low) & mask;
    }

    // Only shift_amount (< num_extra_bits) bits were shifted out: left-align them
    // inside the num_extra_bits-wide field.
    const int mask = ( 1 << shift_amount ) - 1;
    return (frac & mask) << (num_extra_bits - shift_amount);

}

// get the bit position of the leading one
//
// Return the 1-based position of the most significant set bit of `val`
// (1 when val == 1, 64 when bit 63 is set), or 0 when val == 0.
inline __device__ int get_leading_one_bp(uint64_t val) {

    // __clzll counts the leading zero bits of a 64-bit word and returns 64 for 0.
    const int leading_zeros = __clzll(static_cast<long long>(val));

    return 64 - leading_zeros;

}

// 64-bit shift helpers for normalization(). In C++, shifting a 64-bit value by
// 64 or more bits, or by a negative count, is undefined, so these helpers define
// those cases explicitly.

// Shift v right by `amount` bits (left by -amount when amount < 0). A shift of
// 64 or more bits in either direction yields 0.
inline __device__ uint64_t shift_right_clamped_u64(uint64_t v, int amount) {
    if ( (amount >= 64) || (amount <= -64) ) {
        return 0;
    }
    return ( amount >= 0 ) ? ( v >> amount ) : ( v << -amount );
}

// Return bit `pos` of v (false when pos is outside [0, 63]).
inline __device__ bool bit_at_u64(uint64_t v, int pos) {
    return ( pos >= 0 ) && ( pos < 64 ) && ( ( ( v >> pos ) & 1ull ) != 0 );
}

// Return true if any bit strictly below position `pos` is set in v.
inline __device__ bool any_bit_below_u64(uint64_t v, int pos) {
    if ( pos <= 0 ) {
        return false;
    }
    if ( pos >= 64 ) {
        return v != 0;
    }
    return ( v & ( ( 1ull << pos ) - 1ull ) ) != 0;
}

// normalize fp output
//
// Convert a signed fixed-point accumulator back into an IEEE-754 float.
//
//   m_add          : signed sum of aligned significands, each carrying
//                    num_extra_bits extra low-order bits.
//   e_add          : biased exponent the significands were aligned to. 0 means
//                    the aligned operands were subnormal, with effective exponent 1.
//   num_extra_bits : number of extra low-order bits carried in m_add.
//   rounding_mode  : 1 = round to nearest, ties to even. Any other value
//                    truncates (rounds toward zero).
//   idx            : unused (debug hook, kept for interface compatibility).
//
// The result represents m_add * 2^(E - 127 - 23 - num_extra_bits), with
// E = max(e_add, 1), rounded per rounding_mode. A zero sum returns +0.0f, and
// results too large for FP32 saturate to +/-inf.
inline __device__ float normalization(int64_t m_add, int e_add, int num_extra_bits, int rounding_mode, int idx) {

    (void)idx;

    // An exactly-zero sum has no leading one to normalize. Without this early exit,
    // a normal-range e_add sends it down the normal path, where f_out - M_HIDDEN
    // underflows and a small non-zero garbage value is produced.
    if ( m_add == 0 ) {
        return 0.0f;
    }

    const bool in_norm = ( e_add > 0 );

    // Sign and magnitude. Unsigned negation is well defined even for INT64_MIN and
    // does not depend on which integer abs() overload is selected.
    const uint32_t s_out = ( m_add < 0 ) ? 1u : 0u;
    const uint64_t m_out = ( m_add < 0 ) ? ( 0ull - (uint64_t)m_add ) : (uint64_t)m_add;

    // Normalize the exponent so the leading one lands on bit (M_BITS - 1).
    const int leading_one_bp = get_leading_one_bp(m_out);
    int shift_amount = leading_one_bp - M_BITS; // > 0: shift right, < 0: shift left
    int e_out = e_add + shift_amount - num_extra_bits;

    // The output is normal when the exponent is positive. It is also normal when
    // e_out == 0 and the leading one sits exactly at the hidden-bit position; that
    // case is only reachable for a subnormal input, which gets the +1 below.
    const bool out_norm = ( e_out > 0 ) ||
                          ( ( e_out == 0 ) && ( leading_one_bp >= ( M_BITS + num_extra_bits ) ) );

    if ( !out_norm ) {
        // Subnormal output: the exponent field is 0 (effective exponent 1), so shift
        // further right by the exponent deficit. A normal input's scale differs from
        // the subnormal scale by one extra bit.
        shift_amount -= e_out;
        e_out = 0;
        if ( in_norm ) {
            shift_amount += 1;
        }
    } else if ( !in_norm ) {
        // Normal output from a subnormal input: the input's effective exponent is 1, not 0.
        e_out += 1;
    }

    // Exponent overflow: saturate to +/-inf instead of corrupting the sign bit.
    if ( e_out >= 0xFF ) {
        return __uint_as_float( ( s_out << 31 ) | 0x7F800000u );
    }

    // Normalize the fraction. The shift_amount lowest bits of m_out are discarded.
    uint32_t f_out = (uint32_t)shift_right_clamped_u64(m_out, shift_amount);

    if ( out_norm ) {
        // Drop the leading (hidden) one of normal values.
        f_out -= M_HIDDEN;
    }

    // rounding_mode 0 (or anything other than 1): truncation, no rounding.
    if ( rounding_mode == 1 ) {
        // Round to nearest, ties to even, using the discarded bits:
        // guard = bit (shift_amount - 1), round = bit (shift_amount - 2),
        // sticky = OR of every lower bit.
        // NOTE(review): worth validating against a reference implementation.
        const bool g  = bit_at_u64(m_out, shift_amount - 1);
        const bool r  = bit_at_u64(m_out, shift_amount - 2);
        const bool s  = any_bit_below_u64(m_out, shift_amount - 2);
        const bool f1 = ( f_out & 1u ) != 0;
        if ( g && ( r || s || f1 ) ) {
            f_out += 1u;
        }
        // No explicit re-normalization: a carry out of the 23-bit fraction field
        // flows into the exponent field through the '+' packing below. That gives
        // the correct result for normal->next binade, subnormal->smallest normal,
        // and largest finite->inf.
    }

    // update output
    const uint32_t out_int = ( s_out << 31 ) + ( ( (uint32_t)e_out ) << 23 ) + f_out;
    return __uint_as_float(out_int);

}


// Pre-aligned (block-floating-point) linear forward kernel: C = A * B^T.
//
// Expected launch (as set up by prealign_linear_cuda_forward):
//   blockDim = (BLOCK_SIZE, BLOCK_SIZE); the shared tiles below are indexed by
//   threadIdx.x/.y, so no other block shape is valid.
//   grid = (ceil(N / BLOCK_SIZE), ceil(M / BLOCK_SIZE)); x indexes N, y indexes M.
//   Static shared memory: two BLOCK_SIZE^2 int tiles plus one BLOCK_SIZE^2 int8 tile.
//
// A is M x K, B is N x K and C is M x N, all dense row-major (leading dimensions
// K, K and N; the LDA/LDB/LDC arguments and scalar_t are currently unused).
// max_Aabs is assumed to hold, for every row of A, the max |A| of each
// num_systolic_row-wide chunk along K (M x num_split, row-major), as produced by
// prealign_linear_cuda_forward.
//
// For every chunk, the significands of A are aligned to the exponent of the
// chunk's max |A| (keeping num_extra_bits extra low-order bits). They are then
// multiplied by B (converted to int8) and accumulated in 64-bit fixed point.
// Each chunk's sum is converted back to float by normalization(), and the
// per-chunk floats are summed in FP32.
template <typename scalar_t>
__global__ void prealign_linear_cuda_forward_kernel(int M, int N, int K,
                                        const int num_systolic_row,
                                        const int num_split,
                                        const int num_extra_bits,
                                        const int rounding_mode,
                                        const float * __restrict__ max_Aabs,
                                        const float * __restrict__ A, int LDA,
                                        const float * __restrict__ B, int LDB,
                                        float *       __restrict__ C, int LDC)
{
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;

    const int idxN = blockIdx.x * blockDim.x + tx; // index of N-dim (row of B / column of C)
    const int idxM = blockIdx.y * blockDim.y + ty; // index of M-dim (row of A / row of C)

    __shared__ int      sA_m[BLOCK_SIZE][BLOCK_SIZE];         // signed aligned significands of A (with extra bits)
    __shared__ int      sA_e_aligned[BLOCK_SIZE][BLOCK_SIZE]; // alignment exponent of each staged A element
    __shared__ int8_t   sB[BLOCK_SIZE][BLOCK_SIZE];           // B converted to int8

    // Last valid flat index of each input. Out-of-range loads in the tail tiles are
    // clamped here; their values are never accumulated (see the i + j < K guard).
    const int boundA    = M * K - 1;
    const int boundB    = N * K - 1;
    const int boundMaxA = M * num_split - 1;

    float   f_add = 0.0f;         // FP32 sum of the normalized per-chunk results
    int64_t m_add = 0;            // fixed-point accumulator of the current chunk
    int sys_array_row_count = 0;  // K positions accumulated into m_add so far

    for (int i = 0; i < K; i += BLOCK_SIZE) {

        // ---- stage one BLOCK_SIZE-wide K tile of A and B into shared memory ----
        const int k_a = tx + i;   // K position of the A element loaded by this thread
        const int sys_array_idx_a = FLOOR(k_a, num_systolic_row);

        // The alignment exponent is the exponent field of max|A| over this element's chunk.
        const int e_aligned =
            ( __float_as_int(max_Aabs[MIN(idxM * num_split + sys_array_idx_a, boundMaxA)]) & 0x7F800000 ) >> 23;
        sA_e_aligned[ty][tx] = e_aligned;

        const float tmp_b = B[MIN(idxN * K + ty + i, boundB)];
        const float tmp_a = A[MIN(idxM * K + k_a, boundA)];

        // B: converted to int8.
        // NOTE(review): values outside [-128, 127] make this float->int8
        // conversion undefined.
        sB[ty][tx] = (int8_t)tmp_b;

        // A: decompose, then align the significand to the chunk exponent.
        int sign_a, exp_a, frac_a;
        decompose_fp(&sign_a, &exp_a, &frac_a, tmp_a);
        const int man_a = get_mantissa(exp_a, frac_a);

        int shift_amount = e_aligned - exp_a;
        if ( (shift_amount > 0) && (exp_a == 0) ) {
            // Subnormals have effective exponent 1, not 0.
            shift_amount -= 1;
        }
        const int m_aligned  = align_mantissa(shift_amount, man_a);
        const int extra_bits = get_extra_bits(shift_amount, man_a, num_extra_bits, idxM);

        // Signed aligned significand with num_extra_bits extra low-order bits.
        sA_m[ty][tx] = (1 + (-2) * sign_a) * ((m_aligned << num_extra_bits) + extra_bits);

        // synchronize the local threads
        __syncthreads();

        // ---- accumulate the tile ----
        for (int j = 0; j < BLOCK_SIZE; j++) {
            if ( i + j < K ) {
                // Widen before multiplying: the int * int8 product can exceed 32 bits
                // once extra bits are used and |B| is large.
                m_add += (int64_t)sA_m[ty][j] * (int64_t)sB[j][tx];
                sys_array_row_count += 1;

                if ( sys_array_row_count == num_systolic_row ) {
                    // Chunk complete: convert to float with the chunk's exponent and restart.
                    f_add += normalization(m_add, sA_e_aligned[ty][j], num_extra_bits, rounding_mode, idxM);
                    m_add = 0;
                    sys_array_row_count = 0;
                } else if ( (i + j) == (K - 1) ) {
                    // Last, partial chunk: flush what has been accumulated.
                    f_add += normalization(m_add, sA_e_aligned[ty][j], num_extra_bits, rounding_mode, idxM);
                }
            }
        }

        // Keep the tiles alive until every thread has finished reading them.
        __syncthreads();
    }

    if ( (idxM < M) && (idxN < N) ) {
        C[idxM * N + idxN] = f_add;
    }

}

// main forward function
//
// Computes C = A * B^T with block-floating-point pre-alignment of A (see
// prealign_linear_cuda_forward_kernel).
//
//   A : float32 CUDA tensor of shape (..., K); all leading dims are flattened
//       into M rows.
//   B : float32 CUDA tensor of shape (N, K). Each element is converted to int8
//       inside the kernel.
//   C : preallocated, contiguous float32 CUDA tensor with M * N elements,
//       written in place as an M x N row-major matrix.
//   num_systolic_row : width of the K chunks that share one alignment exponent (> 0).
//   num_extra_bits   : extra low-order bits kept when aligning significands (0..7,
//                      so that a 24-bit significand plus extra bits fits in an int).
//   rounding_mode    : 1 = round to nearest even. Any other value truncates.
//
// Throws (TORCH_CHECK) on invalid arguments or on a kernel launch failure.
void prealign_linear_cuda_forward(
    const Tensor    A, // (..., K)
    const Tensor    B, // N x K ( converted to int8 in the kernel )
    Tensor&         C, // M x N
    const int       num_systolic_row,
    const int       num_extra_bits,
    const int       rounding_mode) {

    TORCH_CHECK(B.dim() == 2,
                "prealign_linear_cuda_forward: B must be 2-D (N x K), got ", B.dim(), "-D");
    TORCH_CHECK(A.dim() >= 1 && A.size(-1) == B.size(1),
                "prealign_linear_cuda_forward: last dim of A must equal B.size(1)");
    TORCH_CHECK(A.is_cuda() && B.is_cuda() && C.is_cuda(),
                "prealign_linear_cuda_forward: A, B and C must be CUDA tensors");
    TORCH_CHECK(A.scalar_type() == kFloat && B.scalar_type() == kFloat && C.scalar_type() == kFloat,
                "prealign_linear_cuda_forward: A, B and C must be float32");
    TORCH_CHECK(num_systolic_row > 0,
                "prealign_linear_cuda_forward: num_systolic_row must be positive");
    TORCH_CHECK(num_extra_bits >= 0 && num_extra_bits <= 7,
                "prealign_linear_cuda_forward: num_extra_bits must be in [0, 7]");

    // get tensor information of input
    const int64_t N = B.size(0);
    const int64_t K = B.size(1);
    int64_t M = 1; // product of A's leading dims (also valid when K == 0)
    for (int64_t d = 0; d + 1 < A.dim(); ++d) {
        M *= A.size(d);
    }

    TORCH_CHECK(C.is_contiguous() && C.numel() == M * N,
                "prealign_linear_cuda_forward: C must be contiguous with M * N = ", M * N, " elements");

    // Nothing to write (and a zero-sized grid is an invalid launch configuration).
    if ( M == 0 || N == 0 ) {
        return;
    }
    // Empty reduction over K.
    if ( K == 0 ) {
        C.zero_();
        return;
    }

    // The kernel indexes with 32-bit ints and maps M onto grid.y.
    const int64_t int_max = std::numeric_limits<int>::max();
    TORCH_CHECK(M * K <= int_max && N * K <= int_max && M * N <= int_max,
                "prealign_linear_cuda_forward: tensors too large for 32-bit indexing");
    TORCH_CHECK(CEIL(M, BLOCK_SIZE) <= 65535,
                "prealign_linear_cuda_forward: M too large for the grid.y limit");

    // The kernel assumes dense row-major inputs.
    const Tensor A_c = A.contiguous();
    const Tensor B_c = B.contiguous();

    const int num_split = (int)CEIL(K, (int64_t)num_systolic_row);

    // Per-(row, chunk) max |A|. View A as (1, M, K), because max_pool1d takes a 3-D
    // input, then take the max over non-overlapping num_systolic_row-wide windows
    // along K. ceil_mode keeps the final partial window.
    const int64_t pool_kernel   = num_systolic_row;
    const int64_t pool_padding  = 0;
    const int64_t pool_dilation = 1;
    const Tensor max_Aabs = max_pool1d( abs(A_c).reshape({1, M, K}),
                                        IntArrayRef(pool_kernel), IntArrayRef({}),
                                        IntArrayRef(pool_padding), IntArrayRef(pool_dilation),
                                        true ).contiguous();
    TORCH_CHECK(max_Aabs.size(-1) == num_split,
                "prealign_linear_cuda_forward: unexpected max_pool1d output width");

    // launch kernels
    const dim3 grid( (unsigned int)CEIL(N, BLOCK_SIZE), (unsigned int)CEIL(M, BLOCK_SIZE) );
    const dim3 blocks( BLOCK_SIZE, BLOCK_SIZE );

    // NOTE(review): the launch uses the legacy default stream and the current
    // device, not PyTorch's current stream or A's device. Confirm callers do not
    // depend on non-default streams or multi-GPU placement.
    prealign_linear_cuda_forward_kernel<float><<<grid, blocks>>>(
                                            (int)M, (int)N, (int)K,
                                            num_systolic_row,
                                            num_split,
                                            num_extra_bits,
                                            rounding_mode,
                                            max_Aabs.data_ptr<float>(),
                                            A_c.data_ptr<float>(), (int)K, // row-major M x K
                                            B_c.data_ptr<float>(), (int)K, // row-major N x K
                                            C.data_ptr<float>(), (int)N);  // row-major M x N

    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "prealign_linear_cuda_forward: kernel launch failed: ", cudaGetErrorString(err));

}



// Release the file-local helper macros so they do not leak past this point.
// FLOOR was previously left defined, unlike its sibling helpers.
#undef BLOCK_SIZE
#undef MIN
#undef CEIL
#undef FLOOR

#undef M_HIDDEN
#undef F_BITS
#undef M_BITS
