#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
#define PACKED 16
#define WARPSIZE 32
#define CODEBOOK_SZ 256
#define CODEBOOK_DIM 8
#define GROUP_PER_BLOCK 4

// Warp-wide float sum via a shuffle-down tree (offsets 16, 8, 4, 2, 1).
// After the loop, lane 0 holds the total of all 32 lanes; other lanes hold
// partial sums only. Every lane of the warp must call this (full mask).
// Adapted from KIVI: https://github.com/jy-yuan/KIVI/blob/main/quant/csrc/gemv_cuda.cu
__device__ __forceinline__ float warp_reduce_sum(float sum) {
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        const float neighbour = __shfl_down_sync(0xffffffff, sum, offset);
        sum += neighbour;
    }
    return sum;
}



// half2 variant of the warp shuffle-down tree reduction (offsets 16..1).
// Both halves are summed independently with __hadd2; lane 0 ends up with the
// warp-wide totals, other lanes with partial sums. All 32 lanes must call it.
// Adapted from KIVI: https://github.com/jy-yuan/KIVI/blob/main/quant/csrc/gemv_cuda.cu
__device__ __forceinline__ half2 warp_reduce_sum_half2(half2 sum) {
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        const half2 neighbour = __shfl_down_sync(0xffffffff, sum, offset);
        sum = __hadd2(sum, neighbour);
    }
    return sum;
}


// Sixteen consecutive half values (32 bytes), 32-byte aligned so a whole
// slice can be stored to global memory with a single struct assignment.
struct __align__(32) half16 { half data[16]; };


// Fused dequantize + weighted sum for one (batch b = blockIdx.x, head h = blockIdx.y)
// pair over GROUP_PER_BLOCK consecutive windows of `window_size` tokens
// (windows blockIdx.z * GROUP_PER_BLOCK ...). For each j in [0, nh) and each
// channel c it accumulates
//     acc_j[c] += weight[b, h*nh + j, t] * v_t[c]
// over the tokens t of those windows, where
//   * t <  L (quantized): v_t[c] = n1_t * (n2_t * s_t[c] * codebook[idx_t][c % 8] + mean[b, h, t / window_size][c])
//   * t >= L (residual):  v_t[c] = residual[b, h, t - L, c]
// and stores acc_j to out[b, h*nh + j, blockIdx.z, :]. The host still has to
// reduce over that window-block axis.
//
// Packed word layout (one int32 per 16 channels): bits 0-7 pick the codebook
// row for channels 0-7, bits 8-15 for channels 8-15, and bit 16 + c is the sign
// of channel c (bit set keeps the sign, bit clear negates).
//
// Launch: blockDim = (WARPSIZE, D / PACKED); threadIdx.y selects a 16-channel
// slice, threadIdx.x a run of token_per_thread tokens inside a window.
// gridDim = (B, H, ceil(ceil((L + R) / window_size) / GROUP_PER_BLOCK)).
// Static shared memory: the whole codebook plus one mean vector of D halves.
// Preconditions: L % window_size == 0 (checked by the host wrapper); global
// buffers must be 16-byte aligned (float4 loads). Accumulation is in half.
template <int D, int window_size, int token_per_thread, int nh>
__global__ void quantized_weighted_sum_residual_kernel(
    const int32_t* __restrict__ packed,
    const half* __restrict__ norm,
    const half* __restrict__ mean,
    const half* __restrict__ norm2,
    const half* __restrict__ residual,
    const half* __restrict__ codebook,
    const half* __restrict__ weight,
    half* __restrict__ out,
    const int w_stride_0, const int w_stride_1, const int w_stride_2,
    const int B, const int H, const int L, const int R
) {
    static_assert(D % PACKED == 0, "D must be a multiple of PACKED");
    static_assert(window_size == token_per_thread * WARPSIZE,
                  "one warp must cover exactly one window");

    int in_group = threadIdx.x * token_per_thread;
    int b = blockIdx.x, h = blockIdx.y, group_base = blockIdx.z * GROUP_PER_BLOCK;
    int h_n = h*nh;

    int num_group_quantized = L / window_size;
    int num_total_group = (L+R+window_size-1) / window_size;

    // 64-bit (batch, head) linear index: the int products used previously
    // (e.g. H*L*D for `packed`) overflow for long sequences.
    const long long bh = (long long)b * H + h;

    // Both arrays are accessed through 16-byte float4 loads/stores, so they
    // need 16-byte (not 8-byte) alignment.
    __shared__ __align__(16) half C[CODEBOOK_SZ][CODEBOOK_DIM];
    __shared__ __align__(16) half M[D];

    int tid = threadIdx.y * blockDim.x + threadIdx.x;
    int t_size = blockDim.x * blockDim.y;

    // Cooperative codebook copy into shared memory; each float4 moves 8 halves.
    int num_float4_entries = (CODEBOOK_SZ * CODEBOOK_DIM) / 8;

    #pragma unroll
    for (int i = tid; i < num_float4_entries; i += t_size) {
        ((float4*)&C[i][0])[0] = ((const float4*)codebook)[i];
    }

    __syncthreads();
    half n2;
    half n1;
    int32_t packed_idx;
    int base_i = threadIdx.y * PACKED;
    half W[nh];
    half2 curr_val;
    float4 codebook_reg;
    float4 m_reg;
    float4 res_reg;

    half16 res[nh];
    for(int i=0; i<PACKED; i++) {
        for(int j=0; j<nh; j++) {
            res[j].data[i] = __float2half(0.f);
        }
    }

    #pragma unroll
    for(int g=0; g<GROUP_PER_BLOCK; g++) {
        // `group` depends only on blockIdx.z, so every branch below (and the
        // __syncthreads calls inside them) is uniform across the block.
        int group = group_base + g;
        int l = group * window_size + in_group;
        if(group >= num_total_group) break;
        else if(group >= num_group_quantized) {
            // Residual (unquantized) tokens: t >= L.
            #pragma unroll
            for(int k=0; k<token_per_thread; k++) {
                if(l+k >= L+R) {
                    break;
                }
                #pragma unroll
                for(int j=0; j<nh; j++) {
                    W[j] = weight[(long long)b * w_stride_0 + (long long)(h_n + j) * w_stride_1 + l + k];
                }
                // Row t - L of residual[b, h], starting at this thread's channel slice.
                const half* res_row = residual + (bh * R + (l + k - L)) * D + base_i;
                #pragma unroll
                for(int i=0; i<PACKED; i+=2) {
                    if(i % 8 == 0) {
                        res_reg = ((const float4 *)&res_row[i])[0];
                    }
                    curr_val = ((half2 *)&res_reg)[(i%8)>>1];
                    #pragma unroll
                    for(int j=0; j<nh; j++) {
                        ((half2*)&res[j])[(i>>1)] = __hfma2({W[j], W[j]}, curr_val, ((half2*)&res[j])[(i>>1)]);
                    }
                }
            }
        }
        else {
            // Quantized tokens: stage this window's mean vector in shared memory.
            const long long mean_idx = (bh * num_group_quantized + group) * D;
            #pragma unroll 4
            for(int i=tid; i<D/8; i+=t_size) {
                ((float4 *)M)[i] = ((const float4 *)&mean[mean_idx])[i];
            }
            __syncthreads();
            for(int k=0; k<token_per_thread; k++) {
                const long long tok = bh * L + l + k;  // linear (b, h, t) index
                n1 = norm[tok];
                n2 = norm2[tok];
                #pragma unroll
                for(int j=0; j<nh; j++) {
                    W[j] = weight[(long long)b * w_stride_0 + (long long)(h_n + j) * w_stride_1 + l + k];
                }
                packed_idx = packed[tok * (D/PACKED) + base_i/PACKED];
                #pragma unroll
                for(int i=0; i<PACKED; i+=2) {
                    if(i % CODEBOOK_DIM == 0) {
                        int codebook_idx = (packed_idx >> ((i/CODEBOOK_DIM) * 8)) & 255;
                        codebook_reg = ((float4 *)C[codebook_idx])[0];
                        m_reg = ((float4 *)&M[base_i + i])[0];
                    }
                    curr_val = ((half2 *)&codebook_reg)[((i%CODEBOOK_DIM)>>1)];
                    // Sign bits live at 16 + channel; a cleared bit flips the
                    // IEEE sign bit (bit 15 of the low half, bit 31 of the high half).
                    int32_t signs = (packed_idx >> (PACKED+i));
                    int32_t sign1 = ((signs) & 1)^1;
                    int32_t sign2 = ((signs >> 1) & 1)^1;
                    ((int32_t*)&curr_val)[0] ^= ((sign2 << 31) | (sign1 << 15));

                    // n1 * (n2 * codeword + mean)
                    curr_val = __hfma2({n2, n2}, curr_val, ((half2 *)&m_reg)[(i%CODEBOOK_DIM)>>1]);
                    curr_val = __hmul2({n1, n1}, curr_val);
                    #pragma unroll
                    for(int j=0; j<nh; j++) {
                        ((half2*)&res[j])[(i>>1)] = __hfma2({W[j], W[j]}, curr_val, ((half2*)&res[j])[(i>>1)]);
                    }
                }
            }
        }
        // M is overwritten by the next quantized group; wait for all readers.
        __syncthreads();
    }

    int g_dim = ((L+R+window_size-1)/window_size + GROUP_PER_BLOCK-1)/GROUP_PER_BLOCK;
    // out is (B, H*nh, g_dim, D); this block owns slot blockIdx.z of rows h_n .. h_n+nh-1.
    const long long out_row0 = (long long)b * H * nh + h_n;
    #pragma unroll
    for(int i=0; i<PACKED/2; i++) {
        #pragma unroll
        for(int j=0; j<nh; j++) {
            ((half2*)&res[j])[i] = warp_reduce_sum_half2(((half2*)&res[j])[i]);
        }
    }
    #pragma unroll
    for(int j=0; j<nh; j++) {
        if(threadIdx.x == 0) {
            half16* curr_out = (half16 *)(out + ((out_row0 + j) * g_dim + blockIdx.z) * D + base_i);
            curr_out[0] = res[j];
        }
    }
}

// C++ host function (binding entry point)
// Launches quantized_weighted_sum_residual_kernel for one compile-time head-group
// size NH. Every NH shares this single argument list, so the kernel parameters
// (residual precedes codebook) cannot drift out of order between branches.
template <int D, int window_size, int NH>
static void launch_quantized_weighted_sum_residual_kernel(
    const torch::Tensor& packed,
    const torch::Tensor& norm,
    const torch::Tensor& mean,
    const torch::Tensor& norm2,
    const torch::Tensor& residual,
    const torch::Tensor& codebook,
    const torch::Tensor& weight,
    const torch::Tensor& out,
    const dim3 grid,
    const dim3 block,
    cudaStream_t stream,
    int B, int H, int L, int R
) {
    quantized_weighted_sum_residual_kernel<D, window_size, window_size / WARPSIZE, NH>
        <<<grid, block, 0, stream>>>(
            reinterpret_cast<const int32_t*>(packed.data_ptr<int32_t>()),
            reinterpret_cast<const half*>(norm.data_ptr<at::Half>()),
            reinterpret_cast<const half*>(mean.data_ptr<at::Half>()),
            reinterpret_cast<const half*>(norm2.data_ptr<at::Half>()),
            reinterpret_cast<const half*>(residual.data_ptr<at::Half>()),
            reinterpret_cast<const half*>(codebook.data_ptr<at::Half>()),
            reinterpret_cast<const half*>(weight.data_ptr<at::Half>()),
            reinterpret_cast<half*>(out.data_ptr<at::Half>()),
            static_cast<int>(weight.stride(0)),
            static_cast<int>(weight.stride(1)),
            static_cast<int>(weight.stride(2)),
            B, H, L, R);
}

// Weighted sum over L quantized tokens plus R residual tokens, for nh weight
// rows per head. Returns a half tensor of shape (B, H * nh, 1, D).
//
// Layouts as indexed by the kernel:
//   packed   (B, H, L, D/16)             int32 (codebook indices + sign bits)
//   norm     (B, H, L, 1)                half
//   mean     (B, H, L/window_size, 1, D) half
//   norm2    (B, H, L, 1)                half
//   residual (B, H, R, D)                half, contiguous
//   codebook (256, 8)                    half, contiguous
//   weight   (B, H*nh, ..., L+R)         half; may be non-contiguous, but the
//                                        last (token) dimension must have stride 1
//   nh       one of 1, 2, 4, 8
// L must be a multiple of window_size. Launches on PyTorch's current CUDA stream.
template <int D, int window_size>
torch::Tensor quantized_weighted_sum_residual(
    torch::Tensor packed,
    torch::Tensor norm,
    torch::Tensor mean,
    torch::Tensor norm2,
    torch::Tensor residual,
    torch::Tensor codebook,
    torch::Tensor weight,
    int nh
) {
    TORCH_CHECK(nh == 1 || nh == 2 || nh == 4 || nh == 8,
                "nh must be one of 1, 2, 4, 8 (got ", nh, ")");
    TORCH_CHECK(packed.is_cuda(),            "packed must be a CUDA tensor");
    TORCH_CHECK(packed.is_contiguous(),      "packed must be contiguous");
    TORCH_CHECK(norm.is_contiguous(),      "norm must be contiguous");
    TORCH_CHECK(mean.is_contiguous(),      "mean must be contiguous");
    TORCH_CHECK(norm2.is_contiguous(),     "norm2 must be contiguous");
    TORCH_CHECK(residual.is_contiguous(),  "residual must be contiguous");
    TORCH_CHECK(codebook.is_contiguous(),  "codebook must be contiguous");

    const int B = packed.size(0);
    const int H = packed.size(1);
    const int L = packed.size(2);
    const int R = residual.size(2);

    // Tokens at index >= L * (L / window_size) would otherwise be read from
    // residual at a negative row offset.
    TORCH_CHECK(L % window_size == 0,
                "L (", L, ") must be a multiple of window_size (", window_size, ")");
    // Non-contiguous weight is allowed, but the kernel walks tokens with stride 1.
    TORCH_CHECK(weight.dim() >= 3, "weight must have at least 3 dimensions");
    TORCH_CHECK(weight.size(-1) >= L + R,
                "weight's last dimension must cover L + R tokens");
    TORCH_CHECK(weight.size(-1) <= 1 || weight.stride(-1) == 1,
                "weight's last (token) dimension must have stride 1");

    // Scoped device switch (restored on return) and the stream PyTorch is using.
    const c10::cuda::CUDAGuard device_guard(packed.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    // One grid z-slot per GROUP_PER_BLOCK windows of window_size tokens.
    const int g_dim = ((L + R + window_size - 1) / window_size + GROUP_PER_BLOCK - 1) / GROUP_PER_BLOCK;
    auto out = torch::empty({B, H * nh, g_dim, D}, packed.options().dtype(torch::kHalf));

    // An empty output would mean a zero-sized grid, which is an invalid launch;
    // the reduction below then yields the correct (empty or all-zero) result.
    if (out.numel() > 0) {
        const dim3 block(WARPSIZE, D / PACKED);
        const dim3 grid(B, H, g_dim);

        switch (nh) {
            case 1:
                launch_quantized_weighted_sum_residual_kernel<D, window_size, 1>(
                    packed, norm, mean, norm2, residual, codebook, weight, out,
                    grid, block, stream, B, H, L, R);
                break;
            case 2:
                launch_quantized_weighted_sum_residual_kernel<D, window_size, 2>(
                    packed, norm, mean, norm2, residual, codebook, weight, out,
                    grid, block, stream, B, H, L, R);
                break;
            case 4:
                launch_quantized_weighted_sum_residual_kernel<D, window_size, 4>(
                    packed, norm, mean, norm2, residual, codebook, weight, out,
                    grid, block, stream, B, H, L, R);
                break;
            case 8:
                launch_quantized_weighted_sum_residual_kernel<D, window_size, 8>(
                    packed, norm, mean, norm2, residual, codebook, weight, out,
                    grid, block, stream, B, H, L, R);
                break;
        }
        const cudaError_t err = cudaGetLastError();
        TORCH_CHECK(err == cudaSuccess,
                    "quantized_weighted_sum_residual_kernel launch failed: ",
                    cudaGetErrorString(err));
    }

    // Reduce the per-window-block partial sums.
    return out.sum(-2, true);
}


// Explicit instantiations for head dimension D = 128 and window sizes 32, 64, 128.
// Argument order: packed (B, H, L, D/16) int32; norm (B, H, L, 1) half;
// mean (B, H, L/window_size, 1, D) half; norm2 (B, H, L, 1) half;
// residual (B, H, R, D) half; codebook (256, 8) half;
// weight (B, H*nh, ..., L+R) half with stride-1 token axis; nh.
template torch::Tensor quantized_weighted_sum_residual<128, 32>(
    torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
    torch::Tensor, torch::Tensor, torch::Tensor, int);

template torch::Tensor quantized_weighted_sum_residual<128, 64>(
    torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
    torch::Tensor, torch::Tensor, torch::Tensor, int);

template torch::Tensor quantized_weighted_sum_residual<128, 128>(
    torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
    torch::Tensor, torch::Tensor, torch::Tensor, int);
