#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
// Number of vector elements covered by one int32 word of `packed`: the kernel
// walks D in steps of PACKED and reads one word per step.
#define PACKED 32
// Number of codebook entries staged in shared memory; each entry is selected
// by an 8-bit field extracted from a `packed` word.
#define CODEBOOK_SZ 256
// Number of consecutive elements decoded from one codebook entry (eight 4-bit
// values stored in the entry's int32).
#define CODEBOOK_DIM 8
// Number of 4-bit quantized values stored per int32 word of norm_idx / mean_idx.
#define DQ_PACK 8
// Number of consecutive mean elements that share one mean_scale / mean_offset value.
#define MEAN_GROUP 32


// Fused dot product of `nh` query vectors against L quantized vectors followed
// by R full-precision ("residual") vectors, for every (batch, head) pair.
//
// Launch layout (as used by the host launcher in this file):
//   grid  = (B, H, ceil((L + R) / window_size)), block = window_size threads.
//   blockIdx.x -> b, blockIdx.y -> h, blockIdx.z -> window (group) index;
//   thread t of a block handles position l = group * window_size + t.
// Shared memory: static only (codebook, decoded window mean, queries, inv_freq).
//
// Blocks with group < L / window_size take the quantized path:
//   1. decode the window mean from 4-bit values (mean_idx * mean_scale + mean_offset),
//   2. accumulate norm2[l] * <decoded codebook values, query_had>,
//   3. add <mean rotated by angle inv_freq[i] * (l - offset[b]), query>,
//      rotating element pairs (i, i + D/2),
//   4. scale the sum by the 4-bit-decoded norm n1 and store it.
// Remaining blocks take the residual path: each pair (i, i + D/2) of the
// residual vector is rotated with residual_cos / residual_sin (products rounded
// to half precision, as the original did) and dotted with `query`.
//
// Preconditions (not checked here):
//   * L is a multiple of window_size; otherwise the first residual block would
//     compute a negative residual row.
//   * All half buffers read through float4 (query, query_had, inv_freq,
//     residual, residual_cos, residual_sin) are 16-byte aligned.
//   * NOTE(review): query_had is presumably the query in the same transformed
//     basis as the codebook values (name suggests Hadamard) - confirm with caller.
//
// Buffer layouts implied by the indexing below (contiguous, row-major):
//   packed [B][H][L][D/PACKED], norm_idx [B*H*L/DQ_PACK], norm_scale/offset
//   [B*H*L/window_size], mean_idx [B][H][L/window_size][D/DQ_PACK],
//   mean_scale/offset [B][H][L/window_size][D/MEAN_GROUP], norm2 [B][H][L],
//   codebook_* [CODEBOOK_SZ], inv_freq [D/2], offset [B],
//   residual_cos/sin [B][R][D], residual [B][H][R][D],
//   query/query_had [B][H*nh][D], out [B][H*nh][L+R].
template <int D, int window_size, int nh>
__global__ void quantized_dot_product_fused_residual_dq_1bit_kernel(
    const int32_t* __restrict__ packed,
    const int32_t* __restrict__ norm_idx,
    const half* __restrict__ norm_scale,
    const half* __restrict__ norm_offset,
    const int32_t* __restrict__ mean_idx,
    const half* __restrict__ mean_scale,
    const half* __restrict__ mean_offset,
    const half* __restrict__ norm2,
    const int32_t* __restrict__ codebook_idx,
    const half* __restrict__ codebook_scale,
    const half* __restrict__ codebook_offset,
    const half* __restrict__ inv_freq,
    const int32_t* __restrict__ offset,
    const half* __restrict__ residual_cos,
    const half* __restrict__ residual_sin,
    const half* __restrict__ residual,
    const half* __restrict__ query_had,
    const half* __restrict__ query,
    float* __restrict__ out,
    const int B, const int H, const int L, const int R
) {
    static_assert(D % PACKED == 0, "D must be a multiple of PACKED");
    static_assert(window_size % DQ_PACK == 0, "window_size must be a multiple of DQ_PACK");

    const int tid = threadIdx.x;
    const int t_size = blockDim.x;
    const int b = blockIdx.x;
    const int h = blockIdx.y;
    const int group = blockIdx.z;
    const int l = group * window_size + tid;
    const int h_n = h * nh;

    // C is read as 8-byte pairs; the half arrays are read/written as float4,
    // so they must be 16-byte aligned.
    __shared__ __align__(8) int32_t C[CODEBOOK_SZ][2];
    __shared__ __align__(16) half M[D];
    __shared__ __align__(16) half Q_had[nh][D];
    __shared__ __align__(16) half Q[nh][D];
    __shared__ __align__(16) half IF[D/2];

    const int num_group = L / window_size;

    // All global offsets are computed in 64-bit to avoid int overflow on large inputs.
    const long long q_base = (static_cast<long long>(b) * H * nh + h_n) * D;
    const long long out_base = (static_cast<long long>(b) * H * nh + h_n) * (L + R) + l;

    if (group >= num_group) {
        // ---- Residual (full-precision) path ----
        // Stage the nh query vectors in shared memory, 8 halves per float4.
        #pragma unroll
        for (int i = 0; i < nh; i++) {
            const float4* src = reinterpret_cast<const float4*>(query + q_base + static_cast<long long>(i) * D);
            float4* dst = reinterpret_cast<float4*>(Q[i]);
            for (int j = tid; j < D / 8; j += t_size) {
                dst[j] = src[j];
            }
        }
        __syncthreads();
        // Safe to exit only after the barrier: tail threads past L + R have no output.
        if (l >= L + R) {
            return;
        }

        float sum[nh];
        #pragma unroll
        for (int i = 0; i < nh; i++) {
            sum[i] = 0.0f;
        }

        float4 residual_reg;
        float4 residual_reg2;
        float4 sin_reg;
        float4 cos_reg;
        float4 query_reg[nh];
        float4 query_reg2[nh];

        const int idx = l - L;  // row inside the residual buffers
        const long long res_base = ((static_cast<long long>(b) * H + h) * R + idx) * D;
        const long long trig_base = (static_cast<long long>(b) * R + idx) * D;

        #pragma unroll
        for (int i = 0; i < D / 2; i++) {
            const int lane = i % 8;
            if (lane == 0) {
                // Refill the 8-half register tiles for both halves of the vector.
                residual_reg = *reinterpret_cast<const float4*>(&residual[res_base + i]);
                residual_reg2 = *reinterpret_cast<const float4*>(&residual[res_base + i + D / 2]);
                sin_reg = *reinterpret_cast<const float4*>(&residual_sin[trig_base + i]);
                cos_reg = *reinterpret_cast<const float4*>(&residual_cos[trig_base + i]);

                #pragma unroll
                for (int j = 0; j < nh; j++) {
                    query_reg[j] = *reinterpret_cast<const float4*>(&Q[j][i]);
                    query_reg2[j] = *reinterpret_cast<const float4*>(&Q[j][i + D / 2]);
                }
            }

            const float s = __half2float(reinterpret_cast<const half*>(&sin_reg)[lane]);
            const float c = __half2float(reinterpret_cast<const half*>(&cos_reg)[lane]);
            const float r1 = __half2float(reinterpret_cast<const half*>(&residual_reg)[lane]);
            const float r2 = __half2float(reinterpret_cast<const half*>(&residual_reg2)[lane]);

            // Each product is rounded to half before the add/sub, matching the original arithmetic.
            const float rotated_r1 = __half2float(__float2half(r1 * c)) - __half2float(__float2half(r2 * s));
            const float rotated_r2 = __half2float(__float2half(r1 * s)) + __half2float(__float2half(r2 * c));

            #pragma unroll
            for (int j = 0; j < nh; j++) {
                sum[j] = fmaf(rotated_r1, __half2float(reinterpret_cast<const half*>(&query_reg[j])[lane]), sum[j]);
                sum[j] = fmaf(rotated_r2, __half2float(reinterpret_cast<const half*>(&query_reg2[j])[lane]), sum[j]);
            }
        }

        #pragma unroll
        for (int i = 0; i < nh; i++) {
            out[out_base + static_cast<long long>(i) * (L + R)] = sum[i];
        }
    }
    else {
        // ---- Quantized path (every thread here has l < L) ----
        const int off = offset[b];

        // Stage the codebook: word 0 holds eight packed 4-bit values,
        // word 1 holds {scale, offset} as two halves.
        for (int j = tid; j < CODEBOOK_SZ; j += t_size) {
            C[j][0] = codebook_idx[j];
            half* scale_offset = reinterpret_cast<half*>(&C[j][1]);
            scale_offset[0] = codebook_scale[j];
            scale_offset[1] = codebook_offset[j];
        }

        // Decode this window's mean vector: DQ_PACK 4-bit values per int32,
        // one scale/offset per MEAN_GROUP elements.
        const long long mean_base = ((static_cast<long long>(b) * H + h) * num_group + group) * D;
        for (int i = tid; i < D / DQ_PACK; i += t_size) {
            int32_t mean_idx_reg = mean_idx[mean_base / DQ_PACK + i];
            const half mean_scale_reg = mean_scale[(mean_base + i * DQ_PACK) / MEAN_GROUP];
            const half mean_offset_reg = mean_offset[(mean_base + i * DQ_PACK) / MEAN_GROUP];
            #pragma unroll
            for (int j = 0; j < DQ_PACK; j++) {
                M[i * DQ_PACK + j] = __hfma(__int2half_rn(mean_idx_reg & 15), mean_scale_reg, mean_offset_reg);
                mean_idx_reg >>= 4;
            }
        }

        // Stage both query variants for the nh heads served by this block.
        #pragma unroll
        for (int i = 0; i < nh; i++) {
            const long long head_base = q_base + static_cast<long long>(i) * D;
            const float4* src = reinterpret_cast<const float4*>(query + head_base);
            const float4* src_had = reinterpret_cast<const float4*>(query_had + head_base);
            float4* dst = reinterpret_cast<float4*>(Q[i]);
            float4* dst_had = reinterpret_cast<float4*>(Q_had[i]);
            for (int j = tid; j < D / 8; j += t_size) {
                dst[j] = src[j];
                dst_had[j] = src_had[j];
            }
        }

        // Stage the D/2 rotation frequencies.
        for (int i = tid; i < D / 16; i += t_size) {
            reinterpret_cast<float4*>(IF)[i] = reinterpret_cast<const float4*>(inv_freq)[i];
        }
        __syncthreads();

        float sum[nh];
        #pragma unroll
        for (int i = 0; i < nh; i++) {
            sum[i] = 0.0f;
        }

        // Flat (b, h, l) index shared by norm_idx / norm_scale / norm2 / packed.
        const long long index = (static_cast<long long>(b) * H + h) * L + l;
        const int32_t norm_idx_reg = norm_idx[index / DQ_PACK];
        const half norm_scale_reg = norm_scale[index / window_size];
        const half norm_offset_reg = norm_offset[index / window_size];
        const half n1 = __hfma(__int2half_rn((norm_idx_reg >> (4 * (l % DQ_PACK))) & 15), norm_scale_reg, norm_offset_reg);
        const float n2f = __half2float(norm2[index]);

        float4 query_reg[nh];
        float4 query_reg2[nh];

        // Part 1: codebook-decoded values dotted with query_had.
        const long long packed_base = index * (D / PACKED);
        #pragma unroll
        for (int base_i = 0; base_i < D; base_i += PACKED) {
            const int32_t packed_word = packed[packed_base + base_i / PACKED];
            #pragma unroll
            for (int chunk = 0; chunk < PACKED / CODEBOOK_DIM; chunk++) {
                // One 8-bit code selects the entry that supplies CODEBOOK_DIM elements.
                const int code = (packed_word >> (chunk * 8)) & 255;
                const int2 entry = *reinterpret_cast<const int2*>(&C[code][0]);
                int32_t codebook_idx_reg = entry.x;
                const half codebook_scale_reg = reinterpret_cast<const half*>(&entry.y)[0];
                const half codebook_offset_reg = reinterpret_cast<const half*>(&entry.y)[1];

                #pragma unroll
                for (int j = 0; j < nh; j++) {
                    query_reg[j] = *reinterpret_cast<const float4*>(&Q_had[j][base_i + chunk * CODEBOOK_DIM]);
                }

                #pragma unroll
                for (int e = 0; e < CODEBOOK_DIM; e++) {
                    float curr = __half2float(__hfma(__int2half_rn(codebook_idx_reg & 15), codebook_scale_reg, codebook_offset_reg));
                    codebook_idx_reg >>= 4;
                    curr = n2f * curr;
                    #pragma unroll
                    for (int j = 0; j < nh; j++) {
                        sum[j] = fmaf(curr, __half2float(reinterpret_cast<const half*>(&query_reg[j])[e]), sum[j]);
                    }
                }
            }
        }

        // Part 2: window mean, rotated by the position-dependent angle, dotted with query.
        float4 m_reg;
        float4 m_reg2;
        float4 inv_freq_reg;
        const float position = static_cast<float>(l - off);

        #pragma unroll
        for (int i = 0; i < D / 2; i++) {
            const int lane = i % 8;
            if (lane == 0) {
                m_reg = *reinterpret_cast<const float4*>(&M[i]);
                m_reg2 = *reinterpret_cast<const float4*>(&M[i + D / 2]);
                #pragma unroll
                for (int j = 0; j < nh; j++) {
                    query_reg[j] = *reinterpret_cast<const float4*>(&Q[j][i]);
                    query_reg2[j] = *reinterpret_cast<const float4*>(&Q[j][i + D / 2]);
                }
                inv_freq_reg = *reinterpret_cast<const float4*>(&IF[i]);
            }
            const float theta = __half2float(reinterpret_cast<const half*>(&inv_freq_reg)[lane]) * position;

            float c, s;
            sincosf(theta, &s, &c);

            const float m1 = __half2float(reinterpret_cast<const half*>(&m_reg)[lane]);
            const float m2 = __half2float(reinterpret_cast<const half*>(&m_reg2)[lane]);

            const float rotated_m1 = m1 * c - m2 * s;
            const float rotated_m2 = m1 * s + m2 * c;

            #pragma unroll
            for (int j = 0; j < nh; j++) {
                sum[j] = fmaf(__half2float(reinterpret_cast<const half*>(&query_reg[j])[lane]), rotated_m1,
                              fmaf(__half2float(reinterpret_cast<const half*>(&query_reg2[j])[lane]), rotated_m2, sum[j]));
            }
        }

        // Apply the per-position norm and store one score per head.
        const float n1f = __half2float(n1);
        #pragma unroll
        for (int i = 0; i < nh; i++) {
            out[out_base + static_cast<long long>(i) * (L + R)] = n1f * sum[i];
        }
    }
}

// C++ function (binding): host launcher for
// quantized_dot_product_fused_residual_dq_1bit_kernel.
namespace {

// Raw device pointers and problem sizes forwarded verbatim to the kernel.
struct FusedResidualDq1bitKernelArgs {
    const int32_t* packed;
    const int32_t* norm_idx;
    const half* norm_scale;
    const half* norm_offset;
    const int32_t* mean_idx;
    const half* mean_scale;
    const half* mean_offset;
    const half* norm2;
    const int32_t* codebook_idx;
    const half* codebook_scale;
    const half* codebook_offset;
    const half* inv_freq;
    const int32_t* offset;
    const half* residual_cos;
    const half* residual_sin;
    const half* residual;
    const half* query_had;
    const half* query;
    float* out;
    int B;
    int H;
    int L;
    int R;
};

// Launches the kernel specialization for a compile-time `nh`
// (no dynamic shared memory; default stream, as in the original code).
template <int D, int window_size, int nh>
void launch_fused_residual_dq_1bit(const FusedResidualDq1bitKernelArgs& a, const dim3 grid, const dim3 block) {
    quantized_dot_product_fused_residual_dq_1bit_kernel<D, window_size, nh><<<grid, block, 0>>>(
        a.packed, a.norm_idx, a.norm_scale, a.norm_offset,
        a.mean_idx, a.mean_scale, a.mean_offset, a.norm2,
        a.codebook_idx, a.codebook_scale, a.codebook_offset,
        a.inv_freq, a.offset, a.residual_cos, a.residual_sin, a.residual,
        a.query_had, a.query, a.out,
        a.B, a.H, a.L, a.R
    );
}

}  // namespace

// Validates the inputs, allocates the float output of shape (B, H * nh, 1, L + R)
// and launches the fused kernel with grid (B, H, ceil((L + R) / window_size)) and
// window_size threads per block.
//
// Raises (via TORCH_CHECK) when an input is not contiguous, is not a CUDA tensor
// on the same device as `packed`, when `nh` is not one of {1, 2, 4, 8}, when L is
// not a multiple of window_size, or when the device switch / kernel launch fails.
//
// NOTE(review): the kernel is launched on the default stream and the current
// device set by cudaSetDevice is not restored afterwards; this matches the
// original behavior - confirm it is what callers expect.
template <int D, int window_size>
torch::Tensor quantized_dot_product_fused_residual_dq_1bit(
    torch::Tensor packed,      // (B,H,L,D//32), int32
    torch::Tensor norm_idx,      // (B,H,L//window_size,window_size//DQ_PACK),    int32
    torch::Tensor norm_scale,      // (B,H,L//window_size,1),    half
    torch::Tensor norm_offset,      // (B,H,L//window_size,1),    half
    torch::Tensor mean_idx,      // (B,H, L//window_size,1,D//DQ_PACK), int32
    torch::Tensor mean_scale,      // (B,H, L//window_size,1,D//32), half
    torch::Tensor mean_offset,      // (B,H, L//window_size,1,D//32), half
    torch::Tensor norm2,     // (B,H,L,1),    half
    torch::Tensor codebook_idx,   // (256,1),      int32,
    torch::Tensor codebook_scale, // (256,1),      half,
    torch::Tensor codebook_offset, // (256,1),      half,
    torch::Tensor inv_freq, // (D//2), half
    torch::Tensor offset, // (B), int
    torch::Tensor residual_cos, // (B, R, D) half
    torch::Tensor residual_sin, // (B, R, D) half
    torch::Tensor residual, // (B, H, R, D) half
    torch::Tensor query_had, // (B, H * nh, 1, D) half
    torch::Tensor query, // (B, H * nh, 1, D) half
    int nh
) {
    // The kernel indexes raw pointers, so every input must be dense and live on
    // the device the kernel runs on.
    const auto check_input = [&packed](const torch::Tensor& t, const char* name) {
        TORCH_CHECK(t.is_contiguous(), name, " must be contiguous");
        TORCH_CHECK(t.is_cuda() && t.device() == packed.device(),
                    name, " must be a CUDA tensor on the same device as packed");
    };
    check_input(packed, "packed");
    check_input(norm_idx, "norm_idx");
    check_input(norm_scale, "norm_scale");
    check_input(norm_offset, "norm_offset");
    check_input(mean_idx, "mean_idx");
    check_input(mean_scale, "mean_scale");
    check_input(mean_offset, "mean_offset");
    check_input(norm2, "norm2");
    check_input(codebook_idx, "codebook_idx");
    check_input(codebook_scale, "codebook_scale");
    check_input(codebook_offset, "codebook_offset");
    check_input(inv_freq, "inv_freq");
    check_input(offset, "offset");
    check_input(residual_cos, "residual_cos");
    check_input(residual_sin, "residual_sin");
    check_input(residual, "residual");
    check_input(query_had, "query_had");
    check_input(query, "query");

    // Only these head-group sizes have a kernel specialization; anything else
    // used to return an uninitialized tensor.
    TORCH_CHECK(nh == 1 || nh == 2 || nh == 4 || nh == 8,
                "nh must be one of 1, 2, 4, 8 (got ", nh, ")");

    const int B = static_cast<int>(packed.size(0));
    const int H = static_cast<int>(packed.size(1));
    const int L = static_cast<int>(packed.size(2));
    const int R = static_cast<int>(residual.size(2));

    // The kernel derives the residual row as l - L for every block past
    // L / window_size, which is only valid when L fills whole windows.
    TORCH_CHECK(L % window_size == 0,
                "L (", L, ") must be a multiple of window_size (", window_size, ")");

    const int dev_index = packed.device().index();
    const cudaError_t set_device_err = cudaSetDevice(dev_index);
    TORCH_CHECK(set_device_err == cudaSuccess,
                "cudaSetDevice failed: ", cudaGetErrorString(set_device_err));

    auto out = torch::empty({B, H * nh, 1, L+R}, packed.options().dtype(torch::kFloat));

    // Nothing to compute; a zero-sized grid would be an invalid launch configuration.
    if (B == 0 || H == 0 || L + R == 0) {
        return out;
    }

    // grid: (B, H, ceil((L + R) / window_size)), block: window_size threads
    const dim3 block(window_size);
    const dim3 grid(B, H, (L+R+window_size-1)/window_size);

    FusedResidualDq1bitKernelArgs args;
    args.packed = reinterpret_cast<const int32_t*>(packed.data_ptr<int32_t>());
    args.norm_idx = reinterpret_cast<const int32_t*>(norm_idx.data_ptr<int32_t>());
    args.norm_scale = reinterpret_cast<const half*>(norm_scale.data_ptr<at::Half>());
    args.norm_offset = reinterpret_cast<const half*>(norm_offset.data_ptr<at::Half>());
    args.mean_idx = reinterpret_cast<const int32_t*>(mean_idx.data_ptr<int32_t>());
    args.mean_scale = reinterpret_cast<const half*>(mean_scale.data_ptr<at::Half>());
    args.mean_offset = reinterpret_cast<const half*>(mean_offset.data_ptr<at::Half>());
    args.norm2 = reinterpret_cast<const half*>(norm2.data_ptr<at::Half>());
    args.codebook_idx = reinterpret_cast<const int32_t*>(codebook_idx.data_ptr<int32_t>());
    args.codebook_scale = reinterpret_cast<const half*>(codebook_scale.data_ptr<at::Half>());
    args.codebook_offset = reinterpret_cast<const half*>(codebook_offset.data_ptr<at::Half>());
    args.inv_freq = reinterpret_cast<const half*>(inv_freq.data_ptr<at::Half>());
    args.offset = reinterpret_cast<const int32_t*>(offset.data_ptr<int32_t>());
    args.residual_cos = reinterpret_cast<const half*>(residual_cos.data_ptr<at::Half>());
    args.residual_sin = reinterpret_cast<const half*>(residual_sin.data_ptr<at::Half>());
    args.residual = reinterpret_cast<const half*>(residual.data_ptr<at::Half>());
    args.query_had = reinterpret_cast<const half*>(query_had.data_ptr<at::Half>());
    args.query = reinterpret_cast<const half*>(query.data_ptr<at::Half>());
    args.out = reinterpret_cast<float*>(out.data_ptr<float>());
    args.B = B;
    args.H = H;
    args.L = L;
    args.R = R;

    // nh is a template parameter of the kernel, so dispatch the runtime value.
    switch (nh) {
        case 1: launch_fused_residual_dq_1bit<D, window_size, 1>(args, grid, block); break;
        case 2: launch_fused_residual_dq_1bit<D, window_size, 2>(args, grid, block); break;
        case 4: launch_fused_residual_dq_1bit<D, window_size, 4>(args, grid, block); break;
        default: launch_fused_residual_dq_1bit<D, window_size, 8>(args, grid, block); break;
    }

    // Kernel launches do not return errors directly; surface bad configurations here.
    const cudaError_t launch_err = cudaGetLastError();
    TORCH_CHECK(launch_err == cudaSuccess,
                "quantized_dot_product_fused_residual_dq_1bit kernel launch failed: ",
                cudaGetErrorString(launch_err));
    return out;
}


// Explicit instantiation of the host launcher for D = 128, window_size = 64.
template torch::Tensor quantized_dot_product_fused_residual_dq_1bit<128, 64>(
    torch::Tensor,  // packed:          (B,H,L,D//32), int32
    torch::Tensor,  // norm_idx:        (B,H,L//window_size,window_size//DQ_PACK), int32
    torch::Tensor,  // norm_scale:      (B,H,L//window_size,1), half
    torch::Tensor,  // norm_offset:     (B,H,L//window_size,1), half
    torch::Tensor,  // mean_idx:        (B,H,L//window_size,1,D//DQ_PACK), int32
    torch::Tensor,  // mean_scale:      (B,H,L//window_size,1,D//32), half
    torch::Tensor,  // mean_offset:     (B,H,L//window_size,1,D//32), half
    torch::Tensor,  // norm2:           (B,H,L,1), half
    torch::Tensor,  // codebook_idx:    (256,1), int32
    torch::Tensor,  // codebook_scale:  (256,1), half
    torch::Tensor,  // codebook_offset: (256,1), half
    torch::Tensor,  // inv_freq:        (D//2), half
    torch::Tensor,  // offset:          (B), int
    torch::Tensor,  // residual_cos:    (B, R, D), half
    torch::Tensor,  // residual_sin:    (B, R, D), half
    torch::Tensor,  // residual:        (B, H, R, D), half
    torch::Tensor,  // query_had:       (B, H * nh, 1, D), half
    torch::Tensor,  // query:           (B, H * nh, 1, D), half
    int             // nh
);
