//#include <torch/extension.h>
#include <torch/types.h>
#include <ATen/cuda/CUDAContext.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <limits>
#include <vector>

// Edge length of the square tiles: used both for the shared-memory tile
// arrays in the kernel and for the thread-block dimensions at launch.
#define BLOCK_SIZE 16
// Smaller of x and y. Function-like macro: each argument may be evaluated twice,
// so do not pass expressions with side effects.
#define MIN(x, y)               ((x) < (y) ? (x) : (y))
// Ceiling division for integer operands (intended for x >= 0, y > 0).
#define CEIL(x,y)               (((x) + (y) - 1) / (y))
//#define DIM_X   16
//#define DIM_Y   16
//#define BLK_M   96
//#define BLK_N   96
//#define BLK_K   16
//#define DIM_XA  32
//#define DIM_YA  8
//#define DIM_XB  32
//#define DIM_YB  8
//#define THR_M   (BLK_M / DIM_X)
//#define THR_N   (BLK_N / DIM_Y) 

//#define fetch(A, m, n, bound)   offs_##A[MIN((m)*LD##A+n, bound)]
//#define cazA                    (blx*BLK_M*LDA + idxA*LDA + idyA)
//#define cazB                    (bly*BLK_N*LDB + idxB*LDB + idyB)


// Bring ATen names into scope (e.g. the unqualified `Tensor` used by the host function below).
using namespace at;


// Tiled single-precision matrix product  C = A * B^T.
//
//   A : M x K, row-major, consecutive rows LDA elements apart
//   B : N x K, row-major, consecutive rows LDB elements apart
//   C : M x N, row-major, consecutive rows LDC elements apart (overwritten)
//
// Launch requirements:
//   * blockDim must be exactly (BLOCK_SIZE, BLOCK_SIZE, 1); the shared tiles
//     are indexed directly with threadIdx.
//   * gridDim.x tiles the N dimension, gridDim.y tiles the M dimension, i.e.
//     grid = (ceil(N / BLOCK_SIZE), ceil(M / BLOCK_SIZE)).
//   * Static shared memory only: one BLOCK_SIZE x BLOCK_SIZE and one
//     BLOCK_SIZE x (BLOCK_SIZE + 1) float tile per block.
//
// M, N and K need not be multiples of BLOCK_SIZE; partial tiles are padded
// with zeros, which leaves the accumulated sums unchanged.
//
// The template parameter `scalar_t` is not used by the body (all data is
// float); it is kept so that existing `kernel<scalar_t><<<...>>>` call sites
// keep compiling.
template <typename scalar_t>
__global__ void sgemm_cuda_forward_kernel(int M, int N, int K,
                        const float * __restrict__ A, int LDA,
                        const float * __restrict__ B, int LDB,
                        float *       __restrict__ C, int LDC)
{
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;

    // 64-bit indices: row * leading-dimension can exceed INT_MAX for large matrices.
    const long long tileRowA = (long long)blockIdx.y * BLOCK_SIZE; // first row of A (and of C) for this block
    const long long tileRowB = (long long)blockIdx.x * BLOCK_SIZE; // first row of B (= first column of C)

    // Output element owned by this thread.
    const long long rowC = tileRowA + ty;
    const long long colC = tileRowB + tx;

    // Rows this thread fetches into shared memory. threadIdx.x always walks
    // along K, so neighbouring threads read neighbouring addresses of both
    // A and B (coalesced global loads).
    const long long loadRowA = tileRowA + ty;
    const long long loadRowB = tileRowB + ty;
    const bool rowAValid = loadRowA < M;
    const bool rowBValid = loadRowB < N;

    __shared__ float sA[BLOCK_SIZE][BLOCK_SIZE];
    // +1 padding: the compute loop reads sB[tx][j], i.e. a column-style access
    // across the warp; the padding spreads those reads over distinct banks.
    __shared__ float sB[BLOCK_SIZE][BLOCK_SIZE + 1];

    float out = 0.0f;
    for (long long k0 = 0; k0 < K; k0 += BLOCK_SIZE) {
        // Stage one K-slice of A and B; out-of-range positions become 0 so the
        // inner loop needs no bounds checks.
        const long long k = k0 + tx;
        const bool kValid = k < K;
        sA[ty][tx] = (rowAValid && kValid) ? A[loadRowA * LDA + k] : 0.0f;
        sB[ty][tx] = (rowBValid && kValid) ? B[loadRowB * LDB + k] : 0.0f;

        // All tile writes must be visible before any thread reads them.
        __syncthreads();

        // sA[ty][j] = A[rowC][k0 + j],  sB[tx][j] = B[colC][k0 + j]
        #pragma unroll
        for (int j = 0; j < BLOCK_SIZE; ++j) {
            out += sA[ty][j] * sB[tx][j];
        }

        // Nobody may overwrite the tiles while another thread still reads them.
        __syncthreads();
    }

    if (rowC < M && colC < N) {
        C[rowC * LDC + colC] = out;
    }
}



// Main forward function: computes C = A * B^T in single precision on the GPU.
//
//   A : float32 CUDA tensor; its elements are read as a row-major (M x K)
//       matrix with M = A.numel() / K (the last dimension of A should be K)
//   B : float32 CUDA tensor of shape (N x K)
//   C : float32 CUDA tensor, contiguous, with M * N elements; overwritten
//
// A and B are made contiguous if necessary (no copy when they already are).
// The kernel is enqueued on PyTorch's current CUDA stream and is not
// synchronized here. Any violated precondition, or a failed kernel launch,
// is reported through TORCH_CHECK.
void sgemm_cuda_forward(
    const Tensor    A, // M x K
    const Tensor    B, // N x K
    Tensor&         C) { // M x N

    TORCH_CHECK(A.is_cuda() && B.is_cuda() && C.is_cuda(),
                "sgemm_cuda_forward: A, B and C must be CUDA tensors");
    TORCH_CHECK(A.device() == B.device() && A.device() == C.device(),
                "sgemm_cuda_forward: A, B and C must be on the same device");
    // The kernel is float-only ("sgemm"), so reject other dtypes explicitly.
    TORCH_CHECK(A.scalar_type() == at::kFloat &&
                B.scalar_type() == at::kFloat &&
                C.scalar_type() == at::kFloat,
                "sgemm_cuda_forward: A, B and C must be float32 tensors");
    TORCH_CHECK(B.dim() == 2, "sgemm_cuda_forward: B must be 2-D (N x K), got ",
                B.dim(), " dimension(s)");
    TORCH_CHECK(C.is_contiguous(), "sgemm_cuda_forward: C must be contiguous");

    // get tensor information of input
    const int64_t N = B.size(0);
    const int64_t K = B.size(1); // NOTE: A.size(-1) and B.size(1) should be the same as K

    if (K == 0) {
        // Empty reduction dimension: every dot product is an empty sum.
        // (Also avoids dividing by K below.)
        C.zero_();
        return;
    }

    TORCH_CHECK(A.numel() % K == 0,
                "sgemm_cuda_forward: A.numel() (", A.numel(),
                ") is not a multiple of K = B.size(1) (", K, ")");
    const int64_t M = A.numel() / K;
    TORCH_CHECK(C.numel() == M * N,
                "sgemm_cuda_forward: C must hold M * N = ", M * N,
                " elements, got ", C.numel());

    if (M == 0 || N == 0) {
        return; // nothing to compute; a zero-sized grid would be an invalid launch
    }

    // The kernel takes its sizes as int.
    const int64_t int_max = std::numeric_limits<int>::max();
    TORCH_CHECK(N <= int_max && K <= int_max,
                "sgemm_cuda_forward: N and K must fit in a 32-bit int");

    const Tensor A_c = A.contiguous();
    const Tensor B_c = B.contiguous();
    const float* a_ptr = A_c.data_ptr<float>();
    const float* b_ptr = B_c.data_ptr<float>();
    float*       c_ptr = C.data_ptr<float>();

    // launch kernels
    const dim3 blocks(BLOCK_SIZE, BLOCK_SIZE);
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    // gridDim.y is limited to 65535 blocks, so very tall A matrices are
    // processed in horizontal slabs of at most that many row tiles each.
    const int64_t max_rows_per_launch = static_cast<int64_t>(65535) * BLOCK_SIZE;

    for (int64_t row0 = 0; row0 < M; row0 += max_rows_per_launch) {
        const int64_t rows = MIN(M - row0, max_rows_per_launch);
        const dim3 grid(static_cast<unsigned int>(CEIL(N, BLOCK_SIZE)),
                        static_cast<unsigned int>(CEIL(rows, BLOCK_SIZE)));

        // Leading dimensions are the row strides of the row-major matrices:
        // K for A and B, N for C.
        sgemm_cuda_forward_kernel<float><<<grid, blocks, 0, stream>>>(
                                    (int)rows, (int)N, (int)K,
                                    a_ptr + row0 * K, (int)K,
                                    b_ptr,            (int)K,
                                    c_ptr + row0 * N, (int)N);

        // Kernel launches do not return a status; fetch launch-configuration errors explicitly.
        const cudaError_t err = cudaGetLastError();
        TORCH_CHECK(err == cudaSuccess,
                    "sgemm_cuda_forward: kernel launch failed: ",
                    cudaGetErrorString(err));
    }

}

// Remove the helper macros defined at the top of this file so they do not
// leak into any code that follows.
#undef BLOCK_SIZE
#undef MIN
#undef CEIL

//#undef DIM_X   
//#undef DIM_Y   
//#undef BLK_M   
//#undef BLK_N   
//#undef BLK_K   
//#undef DIM_XA  
//#undef DIM_YA  
//#undef DIM_XB  
//#undef DIM_YB  
//#undef THR_M   
//#undef THR_N    

//#undef fetch
//#undef cazA
//#undef cazB

