#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#define PACKED 32
#define WARPSIZE 32
#define CODEBOOK_SZ 256
#define CODEBOOK_DIM 8
#define GROUP_PER_BLOCK 4
#define DQ_PACK 8
#define MEAN_GROUP 32

// Warp-wide tree reduction of a float.
// Each step adds the value held `offset` lanes higher (offset = 16, 8, 4, 2, 1), so
// lane 0 ends up with the sum over all 32 lanes; other lanes hold partial sums.
// Uses the full 0xffffffff mask, so all 32 lanes of the warp must call it together.
// Adapted from KIVI: https://github.com/jy-yuan/KIVI/blob/main/quant/csrc/gemv_cuda.cu
__device__ __forceinline__ float warp_reduce_sum(float sum) {
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        const float other = __shfl_down_sync(0xffffffffu, sum, offset);
        sum += other;
    }
    return sum;
}



// Warp-wide tree reduction of a half2, component-wise.
// Each step adds (with __hadd2) the value held `offset` lanes higher
// (offset = 16, 8, 4, 2, 1); lane 0 ends up with the sum over all 32 lanes.
// The full 0xffffffff mask is used instead of __activemask(): every call site must
// have all 32 lanes of the warp present.
// Adapted from KIVI: https://github.com/jy-yuan/KIVI/blob/main/quant/csrc/gemv_cuda.cu
__device__ __forceinline__ half2 warp_reduce_sum_half2(half2 sum) {
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        const half2 other = __shfl_down_sync(0xffffffffu, sum, offset);
        sum = __hadd2(sum, other);
    }
    return sum;
}


// 32 fp16 values packed into one 64-byte-aligned aggregate, so a whole
// 32-element slice can be stored with a single struct assignment.
struct __align__(64) half32 { half data[32]; };
static_assert(sizeof(half32) == 64, "half32 must be exactly 64 bytes");
static_assert(alignof(half32) == 64, "half32 must be 64-byte aligned");


// Fused dequantize + weighted sum over a token sequence made of a quantized prefix
// (L tokens) followed by an fp16 residual (R tokens).
//
// For every (b, h) and every j in [0, nh), this block accumulates, over the tokens t of the
// GROUP_PER_BLOCK windows starting at window blockIdx.z * GROUP_PER_BLOCK:
//     acc[j][:] += weight[b, h*nh + j, t] * v_t[:]
// where
//   * quantized tokens (t < L):
//       v_t = n1_t * (n2_t * cb(packed[b,h,t]) + mean[b,h,t / window_size])
//     - `packed` holds four 8-bit codebook indices per int32 (one per CODEBOOK_DIM dims).
//     - Each codebook entry holds eight 4-bit codes plus one fp16 {scale, offset}.
//     - n1_t is a 4-bit code from `norm_idx` with a per-window scale/offset.
//     - n2_t = norm2[b,h,t].
//     - The mean vector is decoded from 4-bit `mean_idx` codes with one scale/offset per
//       MEAN_GROUP dims.
//   * residual tokens (L <= t < L + R): v_t = residual[b, h, t - L, :].
// The block's partial sum is written to out[b, h*nh + j, blockIdx.z, :]; the host wrapper
// reduces over that axis afterwards.
//
// Launch layout (as set up by the host wrapper):
//   grid  = (B, H, ceil(ceil((L + R) / window_size) / GROUP_PER_BLOCK))
//   block = (WARPSIZE, D / PACKED)
//     - Warp threadIdx.y owns dims [threadIdx.y*PACKED, +PACKED).
//     - Lane threadIdx.x owns tokens [threadIdx.x*token_per_thread, +token_per_thread)
//       of each window.
// Static shared memory: CODEBOOK_SZ * 8 bytes (codebook) + D * 2 bytes (mean vector).
// Preconditions:
//   - blockDim.x == 32 (full-mask warp shuffles).
//   - token_per_thread * WARPSIZE == window_size.
//   - L % window_size == 0.
//   - All tensors except `weight` are dense.
//   - `weight` is unit-stride along its last (token) dimension; only w_stride_0 and
//     w_stride_1 are used (w_stride_2 is unused).
//   - 16-byte aligned `residual` rows, for float4 loads.
//   - A 64-byte aligned `out` slice per warp (written as one half32).
// NOTE(review): accumulation is done in fp16 (half2 FMAs), which may lose precision for
// long sequences.
template <int D, int window_size, int token_per_thread, int nh>
__global__ void quantized_weighted_sum_residual_dq_1bit_kernel(
    const int32_t* __restrict__ packed,
    const int32_t* __restrict__ norm_idx,
    const half* __restrict__ norm_scale,
    const half* __restrict__ norm_offset,
    const int32_t* __restrict__ mean_idx,
    const half* __restrict__ mean_scale,
    const half* __restrict__ mean_offset,
    const half* __restrict__ norm2,
    const half* __restrict__ residual,
    const int32_t* __restrict__ codebook_idx,
    const half* __restrict__ codebook_scale,
    const half* __restrict__ codebook_offset,
    const half* __restrict__ weight,
    half* __restrict__ out,
    const int w_stride_0, const int w_stride_1, const int w_stride_2,
    const int B, const int H, const int L, const int R
) {
    static_assert(D % PACKED == 0, "D must be a multiple of PACKED: each warp owns one PACKED-wide slice");
    static_assert(token_per_thread * WARPSIZE == window_size,
                  "the WARPSIZE lanes of a warp must cover exactly one window of tokens");

    const int in_group = threadIdx.x * token_per_thread;
    const int b = blockIdx.x;
    const int h = blockIdx.y;
    const int group_base = blockIdx.z * GROUP_PER_BLOCK;
    const int h_n = h * nh;
    // Flattened (batch, head) index. All global offsets below are computed in 64-bit so
    // that large B*H*L*D does not overflow int.
    const long long bh = (long long)b * H + h;

    const int num_group_quantized = L / window_size;
    const int num_total_group = (L + R + window_size - 1) / window_size;

    // Codebook entry j: C[j][0] = eight 4-bit codes, C[j][1] = {scale, offset} as two halves.
    // Read back as one 8-byte float2, hence the 8-byte alignment.
    __shared__ __align__(8) int32_t C[CODEBOOK_SZ][2];
    // Dequantized mean vector of the current window. It is read with 16-byte float4 loads
    // below, so it needs 16-byte alignment (8 is not enough).
    __shared__ __align__(16) half M[D];

    const int tid = threadIdx.y * blockDim.x + threadIdx.x;
    const int t_size = blockDim.x * blockDim.y;

    // Stage the whole codebook in shared memory once per block.
    for (int j = tid; j < CODEBOOK_SZ; j += t_size) {
        C[j][0] = codebook_idx[j];
        ((half*)&C[j][1])[0] = codebook_scale[j];
        ((half*)&C[j][1])[1] = codebook_offset[j];
    }
    __syncthreads();

    const int base_i = threadIdx.y * PACKED;   // first dim of this warp's slice
    half W[nh];
    half2 curr_val;
    // The following persist across iterations of the dim loop: they are refreshed every
    // CODEBOOK_DIM (resp. 8) dims and consumed two dims at a time.
    float2 codebook_reg;
    int32_t codebook_idx_reg;
    half codebook_scale_reg;
    half codebook_offset_reg;
    float4 m_reg;
    float4 res_reg;

    // Per-thread accumulators: one PACKED-wide slice per output head j.
    half32 res[nh];
    #pragma unroll
    for (int i = 0; i < PACKED; i++) {
        #pragma unroll
        for (int j = 0; j < nh; j++) {
            res[j].data[i] = __float2half(0.f);
        }
    }

    #pragma unroll
    for (int g = 0; g < GROUP_PER_BLOCK; g++) {
        const int group = group_base + g;
        const int l = group * window_size + in_group;
        // `group` depends only on blockIdx and g, so every branch below (including the
        // early exit and the barriers) is taken uniformly by the whole block.
        if (group >= num_total_group) break;
        else if (group >= num_group_quantized) {
            // Residual window: fp16 vectors read straight from `residual`.
            #pragma unroll
            for (int k = 0; k < token_per_thread; k++) {
                if (l + k >= L + R) {
                    break;   // ragged last window
                }
                #pragma unroll
                for (int j = 0; j < nh; j++) {
                    W[j] = weight[(long long)b * w_stride_0 + (long long)(h_n + j) * w_stride_1 + l + k];
                }
                const half* res_row = residual + (bh * R + (l + k - L)) * D + base_i;
                #pragma unroll
                for (int i = 0; i < PACKED; i += 2) {
                    if (i % 8 == 0) {
                        res_reg = ((const float4*)&res_row[i])[0];   // 8 halves per load
                    }
                    curr_val = ((half2*)&res_reg)[(i % 8) >> 1];
                    #pragma unroll
                    for (int j = 0; j < nh; j++) {
                        ((half2*)&res[j])[(i >> 1)] = __hfma2({W[j], W[j]}, curr_val, ((half2*)&res[j])[(i >> 1)]);
                    }
                }
            }
        }
        else {
            // Quantized window: first decode its mean vector into shared memory.
            const long long mean_base = (bh * num_group_quantized + group) * D;
            for (int i = tid; i < D / DQ_PACK; i += t_size) {
                int32_t mean_idx_reg = mean_idx[mean_base / DQ_PACK + i];
                const half mean_scale_reg = mean_scale[(mean_base + i * DQ_PACK) / MEAN_GROUP];
                const half mean_offset_reg = mean_offset[(mean_base + i * DQ_PACK) / MEAN_GROUP];
                #pragma unroll
                for (int j = 0; j < DQ_PACK; j++) {
                    M[i * DQ_PACK + j] = __hfma(__int2half_rn(mean_idx_reg & 15), mean_scale_reg, mean_offset_reg);
                    mean_idx_reg >>= 4;
                }
            }
            __syncthreads();   // M fully written before any warp reads it

            for (int k = 0; k < token_per_thread; k++) {
                const long long index = bh * L + l + k;   // flattened (b, h, token)
                const int32_t norm_idx_reg = norm_idx[index / DQ_PACK];
                const half norm_scale_reg = norm_scale[index / window_size];
                const half norm_offset_reg = norm_offset[index / window_size];
                // Select this token's 4-bit code inside the int32 word picked by index / DQ_PACK.
                const int norm_shift = 4 * (int)(index % DQ_PACK);
                const half n1 = __hfma(__int2half_rn((norm_idx_reg >> norm_shift) & 15), norm_scale_reg, norm_offset_reg);
                const half n2 = norm2[index];
                #pragma unroll
                for (int j = 0; j < nh; j++) {
                    W[j] = weight[(long long)b * w_stride_0 + (long long)(h_n + j) * w_stride_1 + l + k];
                }
                // Four 8-bit codebook indices covering this warp's PACKED dims.
                const int32_t packed_idx = packed[index * (D / PACKED) + base_i / PACKED];
                #pragma unroll
                for (int i = 0; i < PACKED; i += 2) {
                    if (i % CODEBOOK_DIM == 0) {
                        const int cb_entry = (packed_idx >> ((i / CODEBOOK_DIM) * 8)) & 255;
                        codebook_reg = ((float2*)&C[cb_entry][0])[0];
                        m_reg = ((float4*)&M[base_i + i])[0];
                        codebook_idx_reg = ((int32_t*)&codebook_reg)[0];
                        codebook_scale_reg = ((half*)&codebook_reg)[2];
                        codebook_offset_reg = ((half*)&codebook_reg)[3];
                    }
                    curr_val.x = (__hfma(__int2half_rn(codebook_idx_reg & 15), codebook_scale_reg, codebook_offset_reg));
                    codebook_idx_reg >>= 4;
                    curr_val.y = (__hfma(__int2half_rn(codebook_idx_reg & 15), codebook_scale_reg, codebook_offset_reg));
                    codebook_idx_reg >>= 4;
                    // v = n1 * (n2 * codebook_value + mean)
                    curr_val = __hfma2({n2, n2}, curr_val, ((half2*)&m_reg)[(i % CODEBOOK_DIM) >> 1]);
                    curr_val = __hmul2({n1, n1}, curr_val);
                    #pragma unroll
                    for (int j = 0; j < nh; j++) {
                        ((half2*)&res[j])[(i >> 1)] = __hfma2({W[j], W[j]}, curr_val, ((half2*)&res[j])[(i >> 1)]);
                    }
                }
            }
        }
        // All reads of M for this window are done before the next window overwrites it.
        __syncthreads();
    }

    // Reduce the 32 lanes' partial sums; lane 0 of each warp holds the result.
    const int g_dim = (num_total_group + GROUP_PER_BLOCK - 1) / GROUP_PER_BLOCK;
    const long long out_base_idx = bh * nh * D * g_dim + (long long)blockIdx.z * D + base_i;
    #pragma unroll
    for (int i = 0; i < PACKED / 2; i++) {
        #pragma unroll
        for (int j = 0; j < nh; j++) {
            ((half2*)&res[j])[i] = warp_reduce_sum_half2(((half2*)&res[j])[i]);
        }
    }
    #pragma unroll
    for (int j = 0; j < nh; j++) {
        if (threadIdx.x == 0) {
            half32* curr_out = (half32*)(out + out_base_idx + (long long)j * D * g_dim);
            curr_out[0] = res[j];
        }
    }
}

// C++ entry point (PyTorch binding).
//
// Computes, for every output row (b, h*nh + j), the `weight`-weighted sum over all L + R tokens of:
//   - the dequantized vectors of the quantized prefix (first L tokens), and
//   - the fp16 residual vectors (last R tokens).
// Returns a half tensor of shape (B, H * nh, 1, D).
//
// Template parameters:
//   D           = head dimension (multiple of PACKED).
//   window_size = tokens per quantization window (multiple of WARPSIZE).
// Raises (TORCH_CHECK) on:
//   - non-contiguous quantized inputs;
//   - unsupported nh (must be 1, 2, 4 or 8);
//   - L not a multiple of window_size;
//   - element-count mismatch for packed/residual;
//   - a weight shorter than L + R tokens;
//   - a CUDA device/launch error.
// NOTE(review): the kernel is launched on the legacy default stream, not PyTorch's current
// stream (at::cuda::getCurrentCUDAStream()); confirm callers never use side streams.
template <int D, int window_size>
torch::Tensor quantized_weighted_sum_residual_dq_1bit (
    torch::Tensor packed,      // (B,H,L,D//32), int32: four 8-bit codebook indices per element
    torch::Tensor norm_idx,      // (B,H,L//window_size,window_size//DQ_PACK), int32: eight 4-bit codes per element
    torch::Tensor norm_scale,      // (B,H,L//window_size,1),    half
    torch::Tensor norm_offset,      // (B,H,L//window_size,1),    half
    torch::Tensor mean_idx,      // (B,H, L//window_size,1,D//DQ_PACK), int32: eight 4-bit codes per element
    torch::Tensor mean_scale,      // (B,H, L//window_size,1,D//32), half
    torch::Tensor mean_offset,      // (B,H, L//window_size,1,D//32), half
    torch::Tensor norm2,     // (B,H,L,1),    half
    torch::Tensor residual, // (B,H,R,D), half (made contiguous if it is not)
    torch::Tensor codebook_idx,   // (256,1), int32: eight 4-bit codes per entry
    torch::Tensor codebook_scale, // (256,1),      half
    torch::Tensor codebook_offset, // (256,1),      half
    torch::Tensor weight, // (B, H * nh, 1, L+R), half; last dim made unit-stride if it is not
    int nh
) {
    TORCH_CHECK(packed.is_contiguous(),      "packed must be contiguous");
    TORCH_CHECK(norm_idx.is_contiguous(),      "norm_idx must be contiguous");
    TORCH_CHECK(norm_scale.is_contiguous(),      "norm_scale must be contiguous");
    TORCH_CHECK(norm_offset.is_contiguous(),      "norm_offset must be contiguous");
    TORCH_CHECK(mean_idx.is_contiguous(),      "mean_idx must be contiguous");
    TORCH_CHECK(mean_scale.is_contiguous(),      "mean_scale must be contiguous");
    TORCH_CHECK(mean_offset.is_contiguous(),      "mean_offset must be contiguous");
    TORCH_CHECK(norm2.is_contiguous(),     "norm2 must be contiguous");
    TORCH_CHECK(codebook_idx.is_contiguous(),  "codebook_idx must be contiguous");
    TORCH_CHECK(codebook_scale.is_contiguous(),  "codebook_scale must be contiguous");
    TORCH_CHECK(codebook_offset.is_contiguous(),  "codebook_offset must be contiguous");

    // Select the kernel instantiation up front. An unsupported nh must fail loudly; it must
    // not return the uninitialized torch::empty buffer.
    constexpr int kTokensPerThread = window_size / WARPSIZE;
    decltype(&quantized_weighted_sum_residual_dq_1bit_kernel<D, window_size, kTokensPerThread, 1>) kernel = nullptr;
    switch (nh) {
        case 1: kernel = &quantized_weighted_sum_residual_dq_1bit_kernel<D, window_size, kTokensPerThread, 1>; break;
        case 2: kernel = &quantized_weighted_sum_residual_dq_1bit_kernel<D, window_size, kTokensPerThread, 2>; break;
        case 4: kernel = &quantized_weighted_sum_residual_dq_1bit_kernel<D, window_size, kTokensPerThread, 4>; break;
        case 8: kernel = &quantized_weighted_sum_residual_dq_1bit_kernel<D, window_size, kTokensPerThread, 8>; break;
        default: break;
    }
    TORCH_CHECK(kernel != nullptr, "quantized_weighted_sum_residual_dq_1bit: nh must be 1, 2, 4 or 8, got ", nh);

    // The kernel indexes `residual` as a dense (B,H,R,D) array; no copy if it already is.
    residual = residual.contiguous();
    // The kernel walks the tokens of `weight` with unit stride and only receives the outer
    // strides, so a non-unit last-dim stride would read the wrong elements.
    if (weight.stride(-1) != 1) {
        weight = weight.contiguous();
    }

    const int dev_index = packed.device().index();
    const cudaError_t set_err = cudaSetDevice(dev_index);
    TORCH_CHECK(set_err == cudaSuccess, "cudaSetDevice(", dev_index, ") failed: ", cudaGetErrorString(set_err));

    const int w_stride_0 = weight.stride(0);
    const int w_stride_1 = weight.stride(1);
    const int w_stride_2 = weight.stride(2);
    const int B = packed.size(0);
    const int H = packed.size(1);
    const int L = packed.size(2);
    const int R = residual.size(2);

    // Windows are classified as quantized/residual by L / window_size; a partial quantized
    // window would make the kernel read residual[t - L] with t < L.
    TORCH_CHECK(L % window_size == 0,
                "L (packed.size(2)) must be a multiple of window_size=", window_size, ", got ", L);
    TORCH_CHECK(packed.numel() == (int64_t)B * H * L * (D / PACKED),
                "packed must hold B*H*L*(D/", PACKED, ") int32 elements for D=", D);
    TORCH_CHECK(residual.numel() == (int64_t)B * H * R * D,
                "residual must have shape (B, H, R, D) with D=", D);
    TORCH_CHECK(weight.size(-1) >= L + R,
                "weight last dim must cover L + R = ", L + R, " tokens, got ", weight.size(-1));

    // One partial sum per group of GROUP_PER_BLOCK windows; reduced on the host side below.
    const int g_dim = ((L + R + window_size - 1) / window_size + GROUP_PER_BLOCK - 1) / GROUP_PER_BLOCK;
    auto out = torch::empty({B, H * nh, g_dim, D}, packed.options().dtype(torch::kHalf));

    // A zero-sized grid is an invalid launch configuration; with nothing to compute, the
    // (empty) sum below already yields the correct zeros.
    if (out.numel() > 0) {
        // grid: (B, H, g_dim); block: one warp per PACKED-wide slice of D.
        const dim3 block(WARPSIZE, D / PACKED);
        const dim3 grid(B, H, g_dim);
        const size_t shared_mem_bytes = 0;   // kernel uses static shared memory only

        kernel<<<grid, block, shared_mem_bytes>>>(
            reinterpret_cast<const int32_t*>(packed.data_ptr<int32_t>()),
            reinterpret_cast<const int32_t*>(norm_idx.data_ptr<int32_t>()),
            reinterpret_cast<const half*>(norm_scale.data_ptr<at::Half>()),
            reinterpret_cast<const half*>(norm_offset.data_ptr<at::Half>()),
            reinterpret_cast<const int32_t*>(mean_idx.data_ptr<int32_t>()),
            reinterpret_cast<const half*>(mean_scale.data_ptr<at::Half>()),
            reinterpret_cast<const half*>(mean_offset.data_ptr<at::Half>()),
            reinterpret_cast<const half*>(norm2.data_ptr<at::Half>()),
            reinterpret_cast<const half*>(residual.data_ptr<at::Half>()),
            reinterpret_cast<const int32_t*>(codebook_idx.data_ptr<int32_t>()),
            reinterpret_cast<const half*>(codebook_scale.data_ptr<at::Half>()),
            reinterpret_cast<const half*>(codebook_offset.data_ptr<at::Half>()),
            reinterpret_cast<const half*>(weight.data_ptr<at::Half>()),
            reinterpret_cast<half*>(out.data_ptr<at::Half>()),
            w_stride_0, w_stride_1, w_stride_2,
            B, H, L, R
        );
        const cudaError_t launch_err = cudaGetLastError();
        TORCH_CHECK(launch_err == cudaSuccess,
                    "quantized_weighted_sum_residual_dq_1bit: kernel launch failed: ",
                    cudaGetErrorString(launch_err));
    }

    // Reduce the per-block partial sums: (B, H*nh, g_dim, D) -> (B, H*nh, 1, D).
    return out.sum(-2, true);
}


// Explicit instantiation for head dim D = 128 and 64-token quantization windows, so the
// kernel runs with token_per_thread = 64 / WARPSIZE = 2.
// Argument order matches the template:
//   - packed (int32, four 8-bit codebook indices per element)
//   - norm_idx, norm_scale, norm_offset
//   - mean_idx, mean_scale, mean_offset
//   - norm2
//   - residual (B, H, R, D)
//   - codebook_idx, codebook_scale, codebook_offset
//   - weight (B, H * nh, 1, L + R)
//   - nh
template torch::Tensor quantized_weighted_sum_residual_dq_1bit<128, 64>(
    torch::Tensor /* packed */,
    torch::Tensor /* norm_idx */,
    torch::Tensor /* norm_scale */,
    torch::Tensor /* norm_offset */,
    torch::Tensor /* mean_idx */,
    torch::Tensor /* mean_scale */,
    torch::Tensor /* mean_offset */,
    torch::Tensor /* norm2 */,
    torch::Tensor /* residual */,
    torch::Tensor /* codebook_idx */,
    torch::Tensor /* codebook_scale */,
    torch::Tensor /* codebook_offset */,
    torch::Tensor /* weight */,
    int /* nh */
);
