#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
// Compile-time layout constants shared by the kernel and the host binding.
constexpr int PACKED = 32;        // dims covered by one int32 word of `packed` (4 codes x CODEBOOK_DIM)
constexpr int CODEBOOK_SZ = 256;  // codebook rows; one 8-bit code selects a row
constexpr int CODEBOOK_DIM = 8;   // halves per codebook row (read as one float4)


// Fused score kernel over a codebook-quantized ("1-bit") cache.
//
// For token l of head h in batch b, and each query row h*nh + j (j < nh):
//
//   out[b, h*nh + j, 0, l] = norm[b, h, l] * (
//         norm2[b, h, l] * sum_{d < D} C[code(l, d)][d % 8] * query_had[b, h*nh + j, d]
//       + sum_{i < D/2} ( query[b, h*nh + j, i]       * r1_i
//                       + query[b, h*nh + j, i + D/2] * r2_i ) )
//
// code(l, d) is byte (d % 32) / 8 of int32 word d / 32 of token l in `packed`.
// (r1_i, r2_i) is (mean[g, i], mean[g, i + D/2]) of this block's group g,
// rotated by theta_i = inv_freq[i] * (l - offset[b]).
//
// Launch layout (as configured by the host binding):
//   grid  = (B, H, L / window_size)
//   block = (window_size), with one thread per token of the group.
// Shared memory: static only. It holds the codebook, the group mean, nh rows of
//   query and query_had, and inv_freq.
// Preconditions (validated on the host, not here):
//   * blockDim.x == window_size, and L is a multiple of window_size.
//   * codebook, mean, inv_freq, query and query_had are 16-byte aligned,
//     because they are read with float4 vector loads.
template <int D, int window_size, int nh>
__global__ void quantized_dot_product_fused_1bit_kernel(
    const int32_t* __restrict__ packed,
    const half* __restrict__ norm,
    const half* __restrict__ mean,
    const half* __restrict__ norm2,
    const half* __restrict__ codebook,
    const half* __restrict__ inv_freq,
    const int32_t* __restrict__ offset,
    const half* __restrict__ query_had,
    const half* __restrict__ query,
    half* __restrict__ out,
    const int B, const int H, const int L
) {
    static_assert(CODEBOOK_DIM == 8, "a codebook row is read as exactly one float4 (8 halves)");
    static_assert(PACKED == 4 * CODEBOOK_DIM, "each int32 word must hold four 8-bit codebook indices");
    static_assert(D % PACKED == 0, "D must be a multiple of PACKED");
    static_assert(D % 16 == 0, "D/2 must be a multiple of 8 for the float4 loads of mean/query/inv_freq");
    (void)B;  // B is only needed by the host to size the grid.

    const int tid = threadIdx.x;
    const int t_size = blockDim.x;
    const int b = blockIdx.x, h = blockIdx.y, group = blockIdx.z;
    const int l = group * window_size + tid;
    const int h_n = h * nh;  // first query row served by packed head h
    const int off = offset[b];
    const int num_group = L / window_size;

    // Every access to these arrays below is a float4 (16-byte) vector access,
    // so 16-byte alignment is required. __align__(8) is not sufficient.
    __shared__ __align__(16) half C[CODEBOOK_SZ][CODEBOOK_DIM];
    __shared__ __align__(16) half M[D];
    __shared__ __align__(16) half Q_had[nh][D];
    __shared__ __align__(16) half Q[nh][D];
    __shared__ __align__(16) half IF[D / 2];

    // 64-bit base offsets, so that large B*H*L products cannot overflow int.
    const int64_t bh = static_cast<int64_t>(b) * H + h;                // (b, h) row of (B, H, ...) tensors
    const int64_t q_row0 = static_cast<int64_t>(b) * H * nh + h_n;     // first (b, h*nh) query row

    // ---- Cooperative loads into shared memory (8 halves per float4) ----
    const float4* cb_src = reinterpret_cast<const float4*>(codebook);
    for (int i = tid; i < (CODEBOOK_SZ * CODEBOOK_DIM) / 8; i += t_size) {
        reinterpret_cast<float4*>(&C[0][0])[i] = cb_src[i];
    }

    const float4* mean_src = reinterpret_cast<const float4*>(mean + (bh * num_group + group) * D);
    for (int i = tid; i < D / 8; i += t_size) {
        reinterpret_cast<float4*>(M)[i] = mean_src[i];
    }

    #pragma unroll
    for (int j = 0; j < nh; ++j) {
        const float4* q_src = reinterpret_cast<const float4*>(query + (q_row0 + j) * D);
        const float4* qh_src = reinterpret_cast<const float4*>(query_had + (q_row0 + j) * D);
        for (int i = tid; i < D / 8; i += t_size) {
            reinterpret_cast<float4*>(Q[j])[i] = q_src[i];
            reinterpret_cast<float4*>(Q_had[j])[i] = qh_src[i];
        }
    }

    const float4* if_src = reinterpret_cast<const float4*>(inv_freq);
    for (int i = tid; i < D / 16; i += t_size) {  // D/2 halves == D/16 float4s
        reinterpret_cast<float4*>(IF)[i] = if_src[i];
    }
    __syncthreads();

    // No barrier follows, so out-of-range threads may leave safely here.
    if (l >= L) return;

    const int64_t tok = bh * L + l;  // flat (b, h, l) index into (B, H, L) tensors
    const float n1 = __half2float(norm[tok]);
    const float n2 = __half2float(norm2[tok]);

    float sum[nh];
    #pragma unroll
    for (int j = 0; j < nh; ++j) sum[j] = 0.0f;

    // ---- Quantized term: n2 * <codebook vectors, query_had> ----
    const int32_t* packed_row = packed + tok * (D / PACKED);
    #pragma unroll
    for (int w = 0; w < D / PACKED; ++w) {
        const uint32_t word = static_cast<uint32_t>(packed_row[w]);
        #pragma unroll
        for (int s = 0; s < PACKED / CODEBOOK_DIM; ++s) {
            const int dim0 = w * PACKED + s * CODEBOOK_DIM;
            const int code = static_cast<int>((word >> (8 * s)) & 255u);
            const float4 cb = reinterpret_cast<const float4*>(C[code])[0];
            const half* cb_h = reinterpret_cast<const half*>(&cb);

            float4 qh[nh];
            #pragma unroll
            for (int j = 0; j < nh; ++j) {
                qh[j] = reinterpret_cast<const float4*>(&Q_had[j][dim0])[0];
            }

            #pragma unroll
            for (int d = 0; d < CODEBOOK_DIM; ++d) {
                const float curr = n2 * __half2float(cb_h[d]);
                #pragma unroll
                for (int j = 0; j < nh; ++j) {
                    const float qv = __half2float(reinterpret_cast<const half*>(&qh[j])[d]);
                    sum[j] = fmaf(curr, qv, sum[j]);
                }
            }
        }
    }

    // ---- Mean term: <rotated group mean, query>, pairing dims (i, i + D/2) ----
    const float pos = float(l - off);
    #pragma unroll
    for (int i0 = 0; i0 < D / 2; i0 += 8) {
        const float4 m_lo = reinterpret_cast<const float4*>(&M[i0])[0];
        const float4 m_hi = reinterpret_cast<const float4*>(&M[i0 + D / 2])[0];
        const float4 freq = reinterpret_cast<const float4*>(&IF[i0])[0];
        float4 q_lo[nh];
        float4 q_hi[nh];
        #pragma unroll
        for (int j = 0; j < nh; ++j) {
            q_lo[j] = reinterpret_cast<const float4*>(&Q[j][i0])[0];
            q_hi[j] = reinterpret_cast<const float4*>(&Q[j][i0 + D / 2])[0];
        }

        #pragma unroll
        for (int k = 0; k < 8; ++k) {
            const float theta = __half2float(reinterpret_cast<const half*>(&freq)[k]) * pos;
            float s, c;
            sincosf(theta, &s, &c);

            const float m1 = __half2float(reinterpret_cast<const half*>(&m_lo)[k]);
            const float m2 = __half2float(reinterpret_cast<const half*>(&m_hi)[k]);
            const float r1 = m1 * c - m2 * s;
            const float r2 = m1 * s + m2 * c;

            #pragma unroll
            for (int j = 0; j < nh; ++j) {
                const float q1 = __half2float(reinterpret_cast<const half*>(&q_lo[j])[k]);
                const float q2 = __half2float(reinterpret_cast<const half*>(&q_hi[j])[k]);
                sum[j] = fmaf(q1, r1, fmaf(q2, r2, sum[j]));
            }
        }
    }

    // ---- Write out[b, h*nh + j, 0, l] ----
    #pragma unroll
    for (int j = 0; j < nh; ++j) {
        out[(q_row0 + j) * L + l] = __float2half(n1 * sum[j]);
    }
}

// Launches quantized_dot_product_fused_1bit_kernel<D, window_size, NH> on
// `stream`. It uses one block per (batch, head, group) and window_size threads
// per block, and raises a c10::Error if the launch itself fails.
template <int D, int window_size, int NH>
static void launch_quantized_dot_product_fused_1bit_kernel(
    const torch::Tensor& packed,
    const torch::Tensor& norm,
    const torch::Tensor& mean,
    const torch::Tensor& norm2,
    const torch::Tensor& codebook,
    const torch::Tensor& inv_freq,
    const torch::Tensor& offset,
    const torch::Tensor& query_had,
    const torch::Tensor& query,
    torch::Tensor& out,
    int B, int H, int L,
    cudaStream_t stream
) {
    const dim3 grid(B, H, L / window_size);
    const dim3 block(window_size);
    quantized_dot_product_fused_1bit_kernel<D, window_size, NH><<<grid, block, 0, stream>>>(
        packed.data_ptr<int32_t>(),
        reinterpret_cast<const half*>(norm.data_ptr<at::Half>()),
        reinterpret_cast<const half*>(mean.data_ptr<at::Half>()),
        reinterpret_cast<const half*>(norm2.data_ptr<at::Half>()),
        reinterpret_cast<const half*>(codebook.data_ptr<at::Half>()),
        reinterpret_cast<const half*>(inv_freq.data_ptr<at::Half>()),
        offset.data_ptr<int32_t>(),
        reinterpret_cast<const half*>(query_had.data_ptr<at::Half>()),
        reinterpret_cast<const half*>(query.data_ptr<at::Half>()),
        reinterpret_cast<half*>(out.data_ptr<at::Half>()),
        B, H, L
    );
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "quantized_dot_product_fused_1bit: kernel launch failed: ", cudaGetErrorString(err));
}

// C++ entry point (binding) for the fused quantized score kernel.
//
// Returns a half tensor of shape (B, H * nh, 1, L), with one score per query
// row and cached token (see the kernel header for the exact formula).
// The kernel runs on PyTorch's current CUDA stream for packed's device.
// Supported nh values are 1, 2, 4 and 8. Invalid shapes, devices or
// alignment raise c10::Error, instead of returning uninitialized memory or
// faulting on the device.
template <int D, int window_size>
torch::Tensor quantized_dot_product_fused_1bit(
    torch::Tensor packed,      // (B,H,L,D//32), int32
    torch::Tensor norm,      // (B,H,L,1),    half
    torch::Tensor mean,      // (B,H, L//window_size,1,D), half
    torch::Tensor norm2,     // (B,H,L,1),    half
    torch::Tensor codebook,   // (256,8),      half,
    torch::Tensor inv_freq, // (D//2), half
    torch::Tensor offset, // (B), int
    torch::Tensor query_had, // (B, H * nh, 1, D) half
    torch::Tensor query, // (B, H * nh, 1, D) half
    int nh
) {
    TORCH_CHECK(packed.is_cuda(), "packed must be a CUDA tensor");
    for (const torch::Tensor& t : {norm, mean, norm2, codebook, inv_freq, offset, query_had, query}) {
        TORCH_CHECK(t.device() == packed.device(),
                    "all inputs must be on the same CUDA device as packed");
    }

    TORCH_CHECK(packed.is_contiguous(),      "packed must be contiguous");
    TORCH_CHECK(norm.is_contiguous(),      "norm must be contiguous");
    TORCH_CHECK(mean.is_contiguous(),      "mean must be contiguous");
    TORCH_CHECK(norm2.is_contiguous(),     "norm2 must be contiguous");
    TORCH_CHECK(codebook.is_contiguous(),  "codebook must be contiguous");
    TORCH_CHECK(inv_freq.is_contiguous(),  "inv_freq must be contiguous");
    TORCH_CHECK(offset.is_contiguous(),  "offset must be contiguous");
    TORCH_CHECK(query_had.is_contiguous(),  "query_had must be contiguous");
    TORCH_CHECK(query.is_contiguous(),  "query must be contiguous");

    // Any other nh would previously launch nothing and return torch::empty garbage.
    TORCH_CHECK(nh == 1 || nh == 2 || nh == 4 || nh == 8,
                "nh must be one of 1, 2, 4, 8 (got ", nh, ")");

    TORCH_CHECK(packed.dim() == 4 && packed.size(3) == D / PACKED,
                "packed must have shape (B, H, L, ", D / PACKED, ")");
    const int B = packed.size(0);
    const int H = packed.size(1);
    const int L = packed.size(2);

    // The grid covers L / window_size groups only, so tail tokens would otherwise
    // be left uninitialized in `out`.
    TORCH_CHECK(L % window_size == 0,
                "L (", L, ") must be a multiple of window_size (", window_size, ")");
    TORCH_CHECK(norm.numel() == static_cast<int64_t>(B) * H * L, "norm must have B*H*L elements");
    TORCH_CHECK(norm2.numel() == static_cast<int64_t>(B) * H * L, "norm2 must have B*H*L elements");
    TORCH_CHECK(mean.numel() == static_cast<int64_t>(B) * H * (L / window_size) * D,
                "mean must have B*H*(L/window_size)*D elements");
    TORCH_CHECK(codebook.numel() == CODEBOOK_SZ * CODEBOOK_DIM,
                "codebook must have shape (", CODEBOOK_SZ, ", ", CODEBOOK_DIM, ")");
    TORCH_CHECK(inv_freq.numel() == D / 2, "inv_freq must have D/2 elements");
    TORCH_CHECK(offset.numel() == B, "offset must have B elements");
    TORCH_CHECK(query.numel() == static_cast<int64_t>(B) * H * nh * D,
                "query must have B*H*nh*D elements");
    TORCH_CHECK(query_had.numel() == static_cast<int64_t>(B) * H * nh * D,
                "query_had must have B*H*nh*D elements");

    // The kernel reads these tensors with float4 (16-byte) vector loads.
    auto check_aligned16 = [](const torch::Tensor& t, const char* name) {
        TORCH_CHECK(reinterpret_cast<uintptr_t>(t.data_ptr()) % 16 == 0,
                    name, " must be 16-byte aligned (the kernel uses float4 loads)");
    };
    check_aligned16(codebook, "codebook");
    check_aligned16(mean, "mean");
    check_aligned16(inv_freq, "inv_freq");
    check_aligned16(query, "query");
    check_aligned16(query_had, "query_had");

    // Scoped device switch. Unlike cudaSetDevice, it restores the previous device on return.
    const c10::cuda::CUDAGuard device_guard(packed.device());

    auto out = torch::empty({B, H * nh, 1, L}, packed.options().dtype(torch::kHalf));
    if (out.numel() == 0) {
        return out;  // a zero-sized grid is an invalid launch configuration
    }

    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    switch (nh) {
        case 1:
            launch_quantized_dot_product_fused_1bit_kernel<D, window_size, 1>(
                packed, norm, mean, norm2, codebook, inv_freq, offset, query_had, query, out, B, H, L, stream);
            break;
        case 2:
            launch_quantized_dot_product_fused_1bit_kernel<D, window_size, 2>(
                packed, norm, mean, norm2, codebook, inv_freq, offset, query_had, query, out, B, H, L, stream);
            break;
        case 4:
            launch_quantized_dot_product_fused_1bit_kernel<D, window_size, 4>(
                packed, norm, mean, norm2, codebook, inv_freq, offset, query_had, query, out, B, H, L, stream);
            break;
        case 8:
            launch_quantized_dot_product_fused_1bit_kernel<D, window_size, 8>(
                packed, norm, mean, norm2, codebook, inv_freq, offset, query_had, query, out, B, H, L, stream);
            break;
        default:
            TORCH_CHECK(false, "unsupported nh: ", nh);
    }
    return out;
}

// Explicit instantiations for the supported (D, window_size) pairs.
// Parameter order for every instantiation:
//   packed, norm, mean, norm2, codebook, inv_freq, offset, query_had, query, nh
template torch::Tensor quantized_dot_product_fused_1bit<128, 32>(
    torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
    torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, int);

template torch::Tensor quantized_dot_product_fused_1bit<128, 64>(
    torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
    torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, int);

template torch::Tensor quantized_dot_product_fused_1bit<128, 128>(
    torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
    torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, int);
