#include <iostream>
#include <vector>

#include <cuda.h>
#include <cuda_runtime.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>
using namespace std;
#define BLOCK_SIZE 16

// C++ interface
//
// Argument-validation helpers. Each expands to TORCH_CHECK, which throws a
// c10::Error (a RuntimeError in Python) when the condition is false.
// `x.is_cuda()` replaces the deprecated `x.type().is_cuda()`.
#define CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
// Wrapped in do/while(0) so `CHECK_INPUT(x);` is a single statement, e.g. it
// stays correct as the body of an unbraced `if`.
#define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while (0)


// Host-side launchers; their definitions appear later in this file.
at::Tensor forward_cuda(const at::Tensor& A, const at::Tensor& B, const at::Tensor& scales);
std::vector<at::Tensor> backward_cuda( const at::Tensor& A,	const at::Tensor& B, const at::Tensor& Y, const at::Tensor& grad_Y, const at::Tensor& scales);


// Python-facing forward entry point.
//   A, B   : contiguous CUDA tensors on the same device (passed to
//            forward_cuda as Q and K).
//   scales : per-head pooling factors. forward_cuda reads this tensor's data
//            on the host, so it must live in CPU memory.
// Runs forward_cuda with A's device made current.
at::Tensor MRA_forward(const at::Tensor& A, const at::Tensor& B, const at::Tensor& scales) {
  CHECK_INPUT(A);
  CHECK_INPUT(B);
  TORCH_CHECK(A.device() == B.device(), "A and B must be on the same CUDA device");
  // A CUDA `scales` would be dereferenced from host code and crash the process.
  TORCH_CHECK(scales.device().is_cpu(), "scales must be a CPU tensor");
  const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
  return forward_cuda(A, B, scales);
}

// Python-facing backward entry point; returns what backward_cuda returns
// (gradients with respect to A and B).
//   A, B, grad_Y : contiguous CUDA tensors on the same device.
//   Y            : validated as a contiguous CUDA tensor and forwarded as-is.
//   scales       : per-head pooling factors, read on the host by
//                  backward_cuda, so it must be a CPU tensor.
std::vector<at::Tensor> MRA_backward(
        const at::Tensor& A,
        const at::Tensor& B,
        const at::Tensor& Y,
        const at::Tensor& grad_Y,
        const at::Tensor& scales){
    CHECK_INPUT(A);
    CHECK_INPUT(B);
    CHECK_INPUT(Y);
    CHECK_INPUT(grad_Y);
    TORCH_CHECK(A.device() == B.device() && A.device() == grad_Y.device(),
                "A, B and grad_Y must be on the same CUDA device");
    // A CUDA `scales` would be dereferenced from host code and crash the process.
    TORCH_CHECK(scales.device().is_cpu(), "scales must be a CPU tensor");

    const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
    return backward_cuda(A, B, Y, grad_Y, scales);
}



// Python bindings:
//   forward(A, B, scales)             -> Y
//   backward(A, B, Y, grad_Y, scales) -> [grad_A, grad_B]
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward" , &MRA_forward , "My MRA forward  (CUDA)");
    m.def("backward", &MRA_backward, "My MRA backward (CUDA)");
}


namespace{

    // NOTE(review): legacy kernel, not launched by the host code in this file
    // (forward_cuda launches forward_cuda_kernel). It always pools exactly two
    // rows (r, r+1) and two columns (c, c+1) and divides by 4, so it only
    // matches forward_cuda_kernel when scale == 2. It has no bounds checks.
    template <typename scalar_t>
    __global__ void forward_cuda_kernel0(
              at::PackedTensorAccessor64<scalar_t,4,at::RestrictPtrTraits> Y,
        const at::PackedTensorAccessor64<scalar_t,4,at::RestrictPtrTraits> Q,
        const at::PackedTensorAccessor64<scalar_t,4,at::RestrictPtrTraits> K,
        const size_t bsz, 
        const size_t n, 
        const size_t qlen,
        const size_t klen,
        const size_t d_head,
        const size_t scale){

        const int row = threadIdx.x;
        const int col = threadIdx.y;
        const int block_r = blockIdx.y;
        const int block_c = blockIdx.z;
        const int b       = blockIdx.x;
        const int r = block_r*BLOCK_SIZE*scale + scale*row;
        const int c = block_c*BLOCK_SIZE*scale + scale*col;

        scalar_t Y_ij=0;
        __shared__ scalar_t As[BLOCK_SIZE][BLOCK_SIZE];
        __shared__ scalar_t Bs[BLOCK_SIZE][BLOCK_SIZE];
        for(int d=0; d<d_head/BLOCK_SIZE; d++){
            As[col][row] = (Q[b][n][r][BLOCK_SIZE*d+col] + Q[b][n][r+1][BLOCK_SIZE*d+col]);
            Bs[col][row] = (K[b][n][c][BLOCK_SIZE*d+row] + K[b][n][c+1][BLOCK_SIZE*d+row]);
            __syncthreads();

            for(int e=0; e<BLOCK_SIZE; e++){
                Y_ij += As[e][row] * Bs[col][e];
            }
            __syncthreads();            
        }
        Y_ij /= 4.0f;
        Y[b][n][r]  [c]   = Y_ij;
        Y[b][n][r]  [c+1] = Y_ij;
        Y[b][n][r+1][c]   = Y_ij;
        Y[b][n][r+1][c+1] = Y_ij;
    }
    
/*    
    template <typename scalar_t>
    __global__ void forward_cuda_kernel_register(
              at::PackedTensorAccessor64<scalar_t,4,at::RestrictPtrTraits> Y,
        const at::PackedTensorAccessor64<scalar_t,4,at::RestrictPtrTraits> Q,
        const at::PackedTensorAccessor64<scalar_t,4,at::RestrictPtrTraits> K,
        const size_t bsz, 
        const size_t n, 
        const size_t qlen,
        const size_t klen,
        const size_t d_head,
        const size_t scale){

        const int row = threadIdx.x;
        const int col = threadIdx.y;
        const int b       = blockIdx.x;
        const int block_r = blockIdx.y;
        const int block_c = blockIdx.z;   

        const int r = 2*block_r*BLOCK_SIZE + 2*row;
        const int c = block_c*BLOCK_SIZE*BLOCK_SIZE + BLOCK_SIZE*row + col;

        __shared__ scalar_t As[2*BLOCK_SIZE][2*BLOCK_SIZE];
        //register   scalar_t col_B[BLOCK_SIZE]{};
        register   scalar_t col_Y[2*BLOCK_SIZE]{};
        register   scalar_t B_i;
        for( int d=0; d<d_head; d+=(2*BLOCK_SIZE) ){
            As[2*col]  [2*row]   = Q[b][n][r]  [d+2*col];
            As[2*col]  [2*row+1] = Q[b][n][r+1][d+2*col];
            As[2*col+1][2*row]   = Q[b][n][r]  [d+2*col+1];
            As[2*col+1][2*row+1] = Q[b][n][r+1][d+2*col+1];
            //col_B[d + 16*row + col] = K[b][n][c][d+row];
            __syncthreads();

            for(int i=0; i<2*BLOCK_SIZE; i++){
                B_i = K[b][n][c][d + i];
                for(int j=0; j<2*BLOCK_SIZE; j++){
                    col_Y[j] += As[i][j] * B_i;
                }
            } 
            __syncthreads();
        }
        for(int e=0; e<2*BLOCK_SIZE; e++){
            Y[b][n][2*BLOCK_SIZE*block_r + e][c] = col_Y[e];  
        }
        
      
    }
*/

    // Multi-resolution forward for head `n`:
    //   Y[b][n][r][c] = (1/scale^2) * sum_{i,j<scale} Q[b][n][R+i] . K[b][n][C+j]
    // with R = scale*floor(r/scale) and C = scale*floor(c/scale). Each
    // scale x scale tile of Y therefore receives the mean of the matching tile
    // of Q K^T. The kernel computes this as the dot product of sum-pooled Q and
    // K rows, divided by scale^2 and broadcast over the tile.
    //
    // Launch: blockDim = (BLOCK_SIZE, BLOCK_SIZE), gridDim = (bsz, pooled-row
    // tiles, pooled-col tiles). threadIdx.x selects a pooled row of Q,
    // threadIdx.y a pooled row of K. d_head is consumed in BLOCK_SIZE-wide
    // chunks staged through the two static __shared__ tiles.
    // Pooled rows/cols outside qlen/scale (klen/scale) and features >= d_head
    // are loaded as 0 and never stored, so partial tiles do not access memory
    // out of bounds.
    template <typename scalar_t>
    __global__ void forward_cuda_kernel(
              at::PackedTensorAccessor64<scalar_t,4,at::RestrictPtrTraits> Y,
        const at::PackedTensorAccessor64<scalar_t,4,at::RestrictPtrTraits> Q,
        const at::PackedTensorAccessor64<scalar_t,4,at::RestrictPtrTraits> K,
        const size_t bsz,
        const size_t n,
        const size_t qlen,
        const size_t klen,
        const size_t d_head,
        const size_t scale){

        const int row = threadIdx.x;
        const int col = threadIdx.y;
        const int block_r = blockIdx.y;
        const int block_c = blockIdx.z;
        const int b       = blockIdx.x;

        // Pooled (coarse) row of Q and row of K owned by this thread.
        const size_t pr = (size_t)block_r*BLOCK_SIZE + row;
        const size_t pc = (size_t)block_c*BLOCK_SIZE + col;
        // A pooled index is usable only if all `scale` fine rows behind it exist.
        const bool r_ok = pr < qlen/scale;
        const bool c_ok = pc < klen/scale;
        const size_t r = scale * pr;
        const size_t c = scale * pc;

        scalar_t Y_ij=0;
        __shared__ scalar_t As[BLOCK_SIZE][BLOCK_SIZE];
        __shared__ scalar_t Bs[BLOCK_SIZE][BLOCK_SIZE];
        for(size_t d=0; d<d_head; d+=BLOCK_SIZE){
            // As[f][row] = pooled Q row `pr` at feature d+f;
            // Bs[col][f] = pooled K row `pc` at feature d+f.
            // Out-of-range entries are staged as 0 so they add nothing below.
            scalar_t q_sum = 0;
            if (r_ok && d + col < d_head) {
                for(size_t i=0; i<scale; i++){
                    q_sum += Q[b][n][r+i][d+col];
                }
            }
            scalar_t k_sum = 0;
            if (c_ok && d + row < d_head) {
                for(size_t i=0; i<scale; i++){
                    k_sum += K[b][n][c+i][d+row];
                }
            }
            As[col][row] = q_sum;
            Bs[col][row] = k_sum;
            // Barriers stay outside the guards: every thread must reach them.
            __syncthreads();

            for(int e=0; e<BLOCK_SIZE; e++){
                Y_ij += As[e][row] * Bs[col][e];
            }
            __syncthreads();
        }

        if (r_ok && c_ok) {
            Y_ij /= (scale*scale);
            for(size_t i=0; i<scale; i++){
                for(size_t j=0; j<scale; j++){
                    Y[b][n][r+i][c+j] = Y_ij;
                }
            }
        }
    }

    // Gradient w.r.t. Q for head `n`:
    //   grad_Q[b][n][scale*q + i][d] = sum_k grad_Y_scale[b][q][k] * Kpool[k][d]
    // for every i < scale, where Kpool[k] = sum_{j<scale} K[b][n][scale*k + j]
    // and grad_Y_scale holds the scale x scale block means of grad_Y (written
    // by grad_Y_avg_cuda_kernel0).
    // Launch: blockDim = (BLOCK_SIZE, BLOCK_SIZE), gridDim = (bsz, pooled-q
    // tiles, d_head tiles). threadIdx.x -> pooled query q, threadIdx.y -> feature d.
    // Pooled indices beyond qlen/scale or klen/scale and features >= d_head are
    // zero-padded on load and skipped on store.
    // Y and grad_K are not used by this kernel.
    template <typename scalar_t>
    __global__ void backward_cuda_kernel_Q(
        const at::PackedTensorAccessor64<scalar_t,4,at::RestrictPtrTraits> Q,
        const at::PackedTensorAccessor64<scalar_t,4,at::RestrictPtrTraits> K,
        const at::PackedTensorAccessor64<scalar_t,4,at::RestrictPtrTraits> Y,
        const at::PackedTensorAccessor64<scalar_t,3,at::RestrictPtrTraits> grad_Y_scale,
              at::PackedTensorAccessor64<scalar_t,4,at::RestrictPtrTraits> grad_Q,
              at::PackedTensorAccessor64<scalar_t,4,at::RestrictPtrTraits> grad_K,
        const size_t bsz,
        const size_t n,
        const size_t qlen,
        const size_t klen,
        const size_t d_head,
        const size_t scale)
    {
        const int row = threadIdx.x;
        const int col = threadIdx.y;
        const int block_q = blockIdx.y;
        const int block_d = blockIdx.z;
        const int b       = blockIdx.x;
        const size_t q = (size_t)block_q*BLOCK_SIZE + row;  // pooled query index
        const size_t d = (size_t)block_d*BLOCK_SIZE + col;  // feature index
        const size_t n_q = qlen/scale;
        const size_t n_k = klen/scale;
        const bool q_ok = q < n_q;
        const bool d_ok = d < d_head;

        __shared__ scalar_t Ys[BLOCK_SIZE][BLOCK_SIZE];
        __shared__ scalar_t Ks[BLOCK_SIZE][BLOCK_SIZE];
        scalar_t Q_ij=0;
        for(size_t k=0; k<n_k; k+=BLOCK_SIZE){
            // Ys[kk][row] = grad_Y_scale[q][k+kk]; Ks[col][kk] = Kpool[k+kk][d].
            // Guarding k+col / k+row keeps stale entries of the shared
            // grad_Y_scale buffer, and rows past klen, out of the sum.
            Ys[col][row] = (q_ok && k + col < n_k) ? grad_Y_scale[b][q][k + col] : scalar_t(0);
            scalar_t k_sum = 0;
            if (d_ok && k + row < n_k) {
                for(size_t i=0; i<scale; i++){
                    k_sum += K[b][n][scale*(k + row) + i][d];
                }
            }
            Ks[col][row] = k_sum;
            __syncthreads();

            for(int e=0; e<BLOCK_SIZE; e++){
                Q_ij += Ys[e][row] * Ks[col][e];
            }
            __syncthreads();
        }
        if (q_ok && d_ok) {
            for(size_t i=0; i<scale; i++){
                grad_Q[b][n][scale*q + i][d] = Q_ij;
            }
        }
    }


    // Gradient w.r.t. K for head `n`:
    //   grad_K[b][n][scale*k + i][d] = sum_q grad_Y_scale[b][q][k] * Qpool[q][d]
    // for every i < scale, where Qpool[q] = sum_{j<scale} Q[b][n][scale*q + j].
    // (Y = Q K^T  =>  dL/dK = (dL/dY)^T Q.)
    // Launch: blockDim = (BLOCK_SIZE, BLOCK_SIZE), gridDim = (bsz, pooled-k
    // tiles, d_head tiles). threadIdx.x -> pooled key k, threadIdx.y -> feature d.
    // Out-of-range pooled indices / features are zero-padded and not stored.
    // Y and grad_Q are not used by this kernel.
    template <typename scalar_t>
    __global__ void backward_cuda_kernel_K(
        const at::PackedTensorAccessor64<scalar_t,4,at::RestrictPtrTraits> Q,
        const at::PackedTensorAccessor64<scalar_t,4,at::RestrictPtrTraits> K,
        const at::PackedTensorAccessor64<scalar_t,4,at::RestrictPtrTraits> Y,
        const at::PackedTensorAccessor64<scalar_t,3,at::RestrictPtrTraits> grad_Y_scale,
              at::PackedTensorAccessor64<scalar_t,4,at::RestrictPtrTraits> grad_Q,
              at::PackedTensorAccessor64<scalar_t,4,at::RestrictPtrTraits> grad_K,
        const size_t bsz,
        const size_t n,
        const size_t qlen,
        const size_t klen,
        const size_t d_head,
        const size_t scale)
    {
        const int row = threadIdx.x;
        const int col = threadIdx.y;
        const int block_k = blockIdx.y;
        const int block_d = blockIdx.z;
        const int b       = blockIdx.x;
        const size_t k = (size_t)block_k*BLOCK_SIZE + row;  // pooled key index
        const size_t d = (size_t)block_d*BLOCK_SIZE + col;  // feature index
        const size_t n_q = qlen/scale;
        const size_t n_k = klen/scale;
        const bool k_ok = k < n_k;
        const bool d_ok = d < d_head;

        __shared__ scalar_t Ys[BLOCK_SIZE][BLOCK_SIZE];
        __shared__ scalar_t Qs[BLOCK_SIZE][BLOCK_SIZE];
        scalar_t K_ij=0;
        for(size_t q=0; q<n_q; q+=BLOCK_SIZE){
            // Ys[qq][row] = grad_Y_scale[q+qq][k]; Qs[col][qq] = Qpool[q+qq][d].
            Ys[col][row] = (k_ok && q + col < n_q) ? grad_Y_scale[b][q + col][k] : scalar_t(0);
            scalar_t q_sum = 0;
            if (d_ok && q + row < n_q) {
                for(size_t i=0; i<scale; i++){
                    q_sum += Q[b][n][scale*(q + row) + i][d];
                }
            }
            Qs[col][row] = q_sum;
            __syncthreads();

            for(int e=0; e<BLOCK_SIZE; e++){
                K_ij += Ys[e][row] * Qs[col][e];
            }
            __syncthreads();
        }
        if (k_ok && d_ok) {
            for(size_t i=0; i<scale; i++){
                grad_K[b][n][scale*k + i][d] = K_ij;
            }
        }
    }


    // grad_Y_scale[b][q][k] = mean of the scale x scale block of grad_Y[b][n]
    // whose top-left corner is (scale*q, scale*k), for q < qlen/scale and
    // k < klen/scale. One thread per output element.
    // Launch: 1-D thread blocks over the flattened (q, k) index; gridDim.y = bsz.
    template <typename scalar_t>
    __global__ void grad_Y_avg_cuda_kernel0(
        const at::PackedTensorAccessor64<scalar_t,4,at::RestrictPtrTraits> grad_Y,
              at::PackedTensorAccessor64<scalar_t,3,at::RestrictPtrTraits> grad_Y_scale,
        const size_t bsz,
        const size_t n,
        const size_t qlen,
        const size_t klen,
        const size_t d_head,
        const size_t scale)
    {
        const int b = blockIdx.y;
        const size_t idx = (size_t)blockIdx.x*blockDim.x + threadIdx.x;
        const size_t n_q = qlen/scale;
        const size_t n_k = klen/scale;

        // Bound by n_q*n_k, not (qlen*klen)/(scale*scale): the latter is larger
        // when qlen or klen is not a multiple of scale. That would give q >= n_q
        // (reads past grad_Y), or `idx % 0` when klen < scale.
        if (idx < n_q * n_k){
            const size_t k = idx % n_k;
            const size_t q = idx / n_k;

            scalar_t temp = 0;
            #pragma unroll
            for(size_t i=0; i<scale; i++){
                for(size_t j=0; j<scale; j++){
                    temp += grad_Y[b][n][scale*q + i][scale*k + j];
                }
            }
            grad_Y_scale[b][q][k] = temp/(scale*scale);
        }
    }

    // Same result as grad_Y_avg_cuda_kernel0, but with a 2-D thread layout:
    // threadIdx.x/blockIdx.x -> pooled q, threadIdx.y/blockIdx.y -> pooled k,
    // blockIdx.z -> batch. Its launch in backward_cuda is commented out.
    template <typename scalar_t>
    __global__ void grad_Y_avg_cuda_kernel(
        const at::PackedTensorAccessor64<scalar_t,4,at::RestrictPtrTraits> grad_Y,
              at::PackedTensorAccessor64<scalar_t,3,at::RestrictPtrTraits> grad_Y_scale,
        const size_t bsz, 
        const size_t n, 
        const size_t qlen,
        const size_t klen,
        const size_t d_head,
        const size_t scale) 
    {
        const int b = blockIdx.z;
        const int q = blockIdx.x*blockDim.x + threadIdx.x;
        const int k = blockIdx.y*blockDim.y + threadIdx.y;
        
        if ( q < qlen/scale && k < klen/scale ){
            scalar_t temp = 0;
            #pragma unroll
            for(int i=0; i<scale; i++){
                for(int j=0; j<scale; j++){
                    temp += grad_Y[b][n][scale*q + i][scale*k + j];
                }
            }
            grad_Y_scale[b][q][k] = temp/(scale*scale);
        }
    }
    
}//end namespace


// 1. forward
//
// Multi-resolution forward. For every head n it launches forward_cuda_kernel
// with pooling factor scales[n]. Each scales[n] x scales[n] tile of Y[:, n]
// receives the mean of the matching tile of Q[:, n] K[:, n]^T.
//   Q      : bsz x n_head x qlen x d_head
//   K      : bsz x n_head x klen x d_head (same dtype/device as Q)
//   scales : CPU int32 tensor with >= n_head entries, each >= 1.
//            qlen and klen must be multiples of BLOCK_SIZE*scales[n], and
//            d_head must be a multiple of BLOCK_SIZE.
//   returns Y : bsz x n_head x qlen x klen
// Kernels are enqueued on PyTorch's current CUDA stream.
at::Tensor forward_cuda( const at::Tensor& Q, const at::Tensor& K, const at::Tensor& scales)  {
    TORCH_CHECK(Q.dim() == 4 && K.dim() == 4,
                "Q and K must be 4-D (bsz, n_head, len, d_head)");
    const auto bsz     = Q.size(0);
    const auto n_head  = Q.size(1);
    const auto qlen    = Q.size(2);
    const auto d_head  = Q.size(3);
    const auto klen    = K.size(2);
    TORCH_CHECK(K.size(0) == bsz && K.size(1) == n_head && K.size(3) == d_head,
                "K must match Q in batch, head and feature dimensions");
    TORCH_CHECK(d_head % BLOCK_SIZE == 0, "d_head must be a multiple of ", BLOCK_SIZE);

    // The per-head scale is read on the host, so scales must be CPU int32;
    // contiguous() makes h_scales[n] address the n-th element.
    TORCH_CHECK(scales.device().is_cpu(), "scales must be a CPU tensor");
    TORCH_CHECK(scales.scalar_type() == at::kInt, "scales must be an int32 tensor");
    TORCH_CHECK(scales.numel() >= n_head, "scales must have at least n_head entries");
    const at::Tensor scales_c = scales.contiguous();
    const int* h_scales = scales_c.data_ptr<int>();
    // Validate every head before launching anything. A zero scale would divide
    // by zero below, and lengths that are not tile multiples would leave parts
    // of Y unwritten.
    for(int64_t n=0; n<n_head; n++){
        const int scale = h_scales[n];
        TORCH_CHECK(scale >= 1, "scales[", n, "] must be >= 1, got ", scale);
        TORCH_CHECK(qlen % (BLOCK_SIZE*scale) == 0 && klen % (BLOCK_SIZE*scale) == 0,
                    "qlen and klen must be multiples of BLOCK_SIZE*scales[", n, "] = ",
                    BLOCK_SIZE*scale);
    }

    auto options = at::TensorOptions().dtype(Q.dtype())
                                       .layout(at::kStrided)
                                       .device(at::kCUDA, Q.get_device())
                                       .requires_grad(true);
    auto Y = at::zeros( {bsz, n_head, qlen, klen}, options );
    if (Y.numel() == 0) {
        return Y;  // nothing to compute; a zero-sized grid is a launch error
    }

    const dim3 threads(BLOCK_SIZE, BLOCK_SIZE);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    for(int64_t n=0; n<n_head; n++){
        const int scale = h_scales[n];
        const dim3 blocks(bsz, qlen/BLOCK_SIZE/scale, klen/BLOCK_SIZE/scale);

        AT_DISPATCH_FLOATING_TYPES(Q.scalar_type(), "forward_cuda kernel", ([&] {
          forward_cuda_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
            Y.packed_accessor64<scalar_t,4,at::RestrictPtrTraits>(),
            Q.packed_accessor64<scalar_t,4,at::RestrictPtrTraits>(),
            K.packed_accessor64<scalar_t,4,at::RestrictPtrTraits>(),
            bsz, n, qlen, klen, d_head, scale);
        }));
        AT_CUDA_CHECK(cudaGetLastError());  // surface launch-configuration errors
    }

    return Y;
}
    
    
// 2. backward
//
// Gradients of forward_cuda's output with respect to Q and K.
//   Q, K, scales : same shape/scale requirements as forward_cuda.
//   grad_Y       : bsz x n_head x qlen x klen, same dtype/device as Q.
//   Y            : forwarded to the kernels, which do not read it.
// For each head n with s = scales[n]:
//   1. grad_Y_avg_cuda_kernel0 writes the s x s block means of grad_Y[:, n]
//      into the leading (qlen/s) x (klen/s) corner of the scratch buffer
//      grad_Y_scale (the buffer is reused for every head);
//   2. backward_cuda_kernel_Q / backward_cuda_kernel_K combine those means with
//      sum-pooled K / Q and write the result to all s rows of each block.
// All kernels are enqueued, in order, on PyTorch's current CUDA stream.
// Returns {grad_Q, grad_K}.
std::vector<at::Tensor> backward_cuda(
    const at::Tensor& Q,
    const at::Tensor& K,
    const at::Tensor& Y,
    const at::Tensor& grad_Y,
    const at::Tensor& scales)
{
    TORCH_CHECK(Q.dim() == 4 && K.dim() == 4,
                "Q and K must be 4-D (bsz, n_head, len, d_head)");
    const auto bsz     = Q.size(0);
    const auto n_head  = Q.size(1);
    const auto qlen    = Q.size(2);
    const auto d_head  = Q.size(3);
    const auto klen    = K.size(2);
    TORCH_CHECK(K.size(0) == bsz && K.size(1) == n_head && K.size(3) == d_head,
                "K must match Q in batch, head and feature dimensions");
    TORCH_CHECK(grad_Y.dim() == 4 && grad_Y.size(0) == bsz && grad_Y.size(1) == n_head
                && grad_Y.size(2) == qlen && grad_Y.size(3) == klen,
                "grad_Y must have shape (bsz, n_head, qlen, klen)");
    TORCH_CHECK(d_head % BLOCK_SIZE == 0, "d_head must be a multiple of ", BLOCK_SIZE);

    // The per-head scale is read on the host, so scales must be CPU int32;
    // contiguous() makes h_scales[n] address the n-th element.
    TORCH_CHECK(scales.device().is_cpu(), "scales must be a CPU tensor");
    TORCH_CHECK(scales.scalar_type() == at::kInt, "scales must be an int32 tensor");
    TORCH_CHECK(scales.numel() >= n_head, "scales must have at least n_head entries");
    const at::Tensor scales_c = scales.contiguous();
    const int* h_scales = scales_c.data_ptr<int>();
    // Validate every head before launching. The kernels' tile loops assume
    // qlen/s and klen/s are multiples of BLOCK_SIZE; otherwise they read past
    // the end of Q/K.
    for(int64_t n=0; n<n_head; n++){
        const int scale = h_scales[n];
        TORCH_CHECK(scale >= 1, "scales[", n, "] must be >= 1, got ", scale);
        TORCH_CHECK(qlen % (BLOCK_SIZE*scale) == 0 && klen % (BLOCK_SIZE*scale) == 0,
                    "qlen and klen must be multiples of BLOCK_SIZE*scales[", n, "] = ",
                    BLOCK_SIZE*scale);
    }

    auto grad_Q = at::zeros_like(Q);
    auto grad_K = at::zeros_like(K);
    if (grad_Q.numel() == 0 || grad_K.numel() == 0) {
        // Empty Q or K: every gradient entry is zero and a zero-sized grid
        // would be a launch error.
        return {grad_Q, grad_K};
    }

    // Scratch buffer holding one head's block-averaged grad_Y (see step 1).
    auto grad_Y_scale = at::zeros( {bsz, qlen, klen}, Q.options() );

    const dim3 threads(BLOCK_SIZE, BLOCK_SIZE);
    const int64_t n_block_d = d_head / BLOCK_SIZE;
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    for(int64_t n=0; n<n_head; n++){
        const int scale = h_scales[n];
        const dim3 blocks_Q(bsz, qlen/BLOCK_SIZE/scale, n_block_d);
        const dim3 blocks_K(bsz, klen/BLOCK_SIZE/scale, n_block_d);

        // One thread per pooled grad_Y element, 256 threads per block.
        const int64_t N_avg = (qlen/scale) * (klen/scale);
        const dim3 blocks_avg( (N_avg+255)/256, bsz );

        AT_DISPATCH_FLOATING_TYPES(Q.scalar_type(), "grad_Y_avg_cuda_kernel0", ([&] {
            grad_Y_avg_cuda_kernel0<scalar_t><<<blocks_avg, 256, 0, stream>>>(
                grad_Y.packed_accessor64<scalar_t,4,at::RestrictPtrTraits>(),
                grad_Y_scale.packed_accessor64<scalar_t,3,at::RestrictPtrTraits>(),
                bsz, n, qlen, klen, d_head, scale);
        }));
        AT_CUDA_CHECK(cudaGetLastError());

        AT_DISPATCH_FLOATING_TYPES(Q.scalar_type(), "backward_cuda_kernel_Q", ([&] {
            backward_cuda_kernel_Q<scalar_t><<<blocks_Q, threads, 0, stream>>>(
                Q.packed_accessor64<scalar_t,4,at::RestrictPtrTraits>(),
                K.packed_accessor64<scalar_t,4,at::RestrictPtrTraits>(),
                Y.packed_accessor64<scalar_t,4,at::RestrictPtrTraits>(),
                grad_Y_scale.packed_accessor64<scalar_t,3,at::RestrictPtrTraits>(),
                grad_Q.packed_accessor64<scalar_t,4,at::RestrictPtrTraits>(),
                grad_K.packed_accessor64<scalar_t,4,at::RestrictPtrTraits>(),
                bsz, n, qlen, klen, d_head, scale);
        }));
        AT_CUDA_CHECK(cudaGetLastError());

        AT_DISPATCH_FLOATING_TYPES(Q.scalar_type(), "backward_cuda_kernel_K", ([&] {
            backward_cuda_kernel_K<scalar_t><<<blocks_K, threads, 0, stream>>>(
                Q.packed_accessor64<scalar_t,4,at::RestrictPtrTraits>(),
                K.packed_accessor64<scalar_t,4,at::RestrictPtrTraits>(),
                Y.packed_accessor64<scalar_t,4,at::RestrictPtrTraits>(),
                grad_Y_scale.packed_accessor64<scalar_t,3,at::RestrictPtrTraits>(),
                grad_Q.packed_accessor64<scalar_t,4,at::RestrictPtrTraits>(),
                grad_K.packed_accessor64<scalar_t,4,at::RestrictPtrTraits>(),
                bsz, n, qlen, klen, d_head, scale);
        }));
        AT_CUDA_CHECK(cudaGetLastError());
    }

    return {grad_Q, grad_K};
}


      
      
