#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <cstdint>


// Fused "rotate window mean by RoPE, then dot with the queries" kernel.
//
// For each position l of batch b and head h:
//   * takes the mean vector of l's window, mean[b, h, l / window_size, :];
//   * rotates it by the angles theta_i = inv_freq[i] * (l - offset[b]),
//     pairing element i with element i + D/2 (rotate-half layout);
//   * dots it with each query vector query[b, h * nh + j, :] for j < nh;
//   * scales the result by norm[b, h, l];
//   * stores it as half in out[b, h * nh + j, l].
// Accumulation is done in float.
//
// Expected launch (as issued by window_rope_dot_product):
//   gridDim  = (B, H, L / window_size)  -- one block per (batch, head, window)
//   blockDim = window_size              -- threadIdx.x = position inside window
// Shared memory: static, (D + nh * D + D / 2) halves.
// Preconditions: D % 16 == 0 (enforced below), and the mean / query / inv_freq
// base pointers are 16-byte aligned because they are read as float4.
template <int D, int window_size, int nh>
__global__ void window_rope_dot_product_kernel(
    const half* __restrict__ mean,
    const half* __restrict__ norm,
    const half* __restrict__ query,
    const half* __restrict__ inv_freq,
    const int32_t* __restrict__ offset,
    half* __restrict__ out,
    const int B, const int H, const int L
) {
    static_assert(D % 16 == 0, "D must be a multiple of 16 (float4 loads of D/2 halves)");
    constexpr int kVec = 8;  // halves per float4

    const int tid = threadIdx.x;
    const int t_size = blockDim.x;
    const int b = blockIdx.x, h = blockIdx.y, group = blockIdx.z;
    const int l = group * window_size + tid;  // absolute position handled by this thread
    const int h_n = h * nh;                    // first query head scored against head h
    const int off = offset[b];
    const int num_group = L / window_size;

    // All three arrays are accessed through float4, which needs 16-byte alignment.
    __shared__ __align__(16) half M[D];
    __shared__ __align__(16) half Q[nh][D];
    __shared__ __align__(16) half IF[D / 2];

    // Global offsets are 64-bit: B * H * nh * L can exceed INT_MAX for long inputs.
    const int64_t bh = int64_t(b) * H + h;

    // Cooperative copy of this window's mean vector (D halves).
    const float4* mean_src =
        reinterpret_cast<const float4*>(mean + (bh * num_group + group) * D);
    for (int i = tid; i < D / kVec; i += t_size) {
        reinterpret_cast<float4*>(M)[i] = mean_src[i];
    }

    // The nh query vectors for heads h_n .. h_n + nh - 1 are contiguous (nh * D halves).
    const float4* query_src = reinterpret_cast<const float4*>(query + bh * nh * D);
    #pragma unroll
    for (int j = 0; j < nh; j++) {
        for (int i = tid; i < D / kVec; i += t_size) {
            reinterpret_cast<float4*>(Q[j])[i] = query_src[j * (D / kVec) + i];
        }
    }

    // inv_freq: D/2 halves = D/16 float4s.
    for (int i = tid; i < D / (2 * kVec); i += t_size) {
        reinterpret_cast<float4*>(IF)[i] = reinterpret_cast<const float4*>(inv_freq)[i];
    }

    __syncthreads();

    float sum[nh];
    #pragma unroll
    for (int j = 0; j < nh; j++) {
        sum[j] = 0.0f;
    }

    const float pos = float(l - off);  // RoPE position relative to this batch's offset
    const float n = __half2float(norm[bh * L + l]);

    // Walk the D/2 rotation pairs (i, i + D/2) in chunks of 8. One float4 of each
    // operand is staged in registers, and the inner loop is fully unrolled, so the
    // per-element half extraction uses compile-time indices and stays in registers.
    for (int base = 0; base < D / 2; base += kVec) {
        const float4 m_reg1 = *reinterpret_cast<const float4*>(&M[base]);
        const float4 m_reg2 = *reinterpret_cast<const float4*>(&M[base + D / 2]);
        const float4 inv_freq_reg = *reinterpret_cast<const float4*>(&IF[base]);
        float4 query_reg1[nh];
        float4 query_reg2[nh];
        #pragma unroll
        for (int j = 0; j < nh; j++) {
            query_reg1[j] = *reinterpret_cast<const float4*>(&Q[j][base]);
            query_reg2[j] = *reinterpret_cast<const float4*>(&Q[j][base + D / 2]);
        }

        const half* m1_h = reinterpret_cast<const half*>(&m_reg1);
        const half* m2_h = reinterpret_cast<const half*>(&m_reg2);
        const half* if_h = reinterpret_cast<const half*>(&inv_freq_reg);

        #pragma unroll
        for (int k = 0; k < kVec; k++) {
            const float theta = __half2float(if_h[k]) * pos;
            const float c = std::cos(theta);
            const float s = std::sin(theta);

            // Rotate-half RoPE: (x1, x2) -> (x1*c - x2*s, x1*s + x2*c).
            const float m1 = __half2float(m1_h[k]);
            const float m2 = __half2float(m2_h[k]);
            const float rotated_m1 = m1 * c - m2 * s;
            const float rotated_m2 = m1 * s + m2 * c;

            #pragma unroll
            for (int j = 0; j < nh; j++) {
                const half* q1_h = reinterpret_cast<const half*>(&query_reg1[j]);
                const half* q2_h = reinterpret_cast<const half*>(&query_reg2[j]);
                sum[j] += __half2float(q1_h[k]) * rotated_m1 + __half2float(q2_h[k]) * rotated_m2;
            }
        }
    }

    #pragma unroll
    for (int j = 0; j < nh; j++) {
        out[(bh * nh + j) * L + l] = __float2half(n * sum[j]);
    }
    (void)h_n;
}

// C++ host binding.

// Launches window_rope_dot_product_kernel for one compile-time query-group size
// NH on `stream`. The inputs must already have been validated by
// window_rope_dot_product.
template <int D, int window_size, int NH>
static void launch_window_rope_dot_product(
    const torch::Tensor& mean, const torch::Tensor& norm,
    const torch::Tensor& query, const torch::Tensor& inv_freq,
    const torch::Tensor& offset, const torch::Tensor& out,
    int B, int H, int L, cudaStream_t stream
) {
    // One block per (batch, head, full window); one thread per window position.
    const dim3 grid(B, H, L / window_size);
    const dim3 block(window_size);
    window_rope_dot_product_kernel<D, window_size, NH><<<grid, block, 0, stream>>>(
        reinterpret_cast<const half*>(mean.data_ptr<at::Half>()),
        reinterpret_cast<const half*>(norm.data_ptr<at::Half>()),
        reinterpret_cast<const half*>(query.data_ptr<at::Half>()),
        reinterpret_cast<const half*>(inv_freq.data_ptr<at::Half>()),
        offset.data_ptr<int32_t>(),
        reinterpret_cast<half*>(out.data_ptr<at::Half>()),
        B, H, L
    );
}

// For every position l, rotates the mean vector of l's window by RoPE at
// position (l - offset[b]), dots it with each of the nh query vectors of head h,
// and scales the result by norm[b, h, l].
//
// Returns a (B, H * nh, 1, L) half tensor. Positions in a trailing partial window
// (l >= (L / window_size) * window_size) are not computed by the kernel and are
// left uninitialized.
// Raises c10::Error in these cases:
//   * nh is not 1, 2, 4 or 8;
//   * an input is non-contiguous, not a CUDA tensor, or on a different device;
//   * an input has the wrong element count or is not 16-byte aligned where float4
//     reads need it;
//   * the grid exceeds CUDA limits;
//   * the kernel launch fails.
template <int D, int window_size>
torch::Tensor window_rope_dot_product(
    torch::Tensor mean,      // (B, H, L // window_size, 1, D), half
    torch::Tensor norm,      // (B, H, L, 1), half
    torch::Tensor query,     // (B, H * nh, 1, D), half
    torch::Tensor inv_freq,  // (D // 2), half
    torch::Tensor offset,    // (B), int32
    int nh
) {
    TORCH_CHECK(mean.is_contiguous(),      "mean must be contiguous");
    TORCH_CHECK(norm.is_contiguous(),      "norm must be contiguous");
    TORCH_CHECK(query.is_contiguous(),  "query must be contiguous");
    TORCH_CHECK(inv_freq.is_contiguous(), "inv_freq must be contiguous");
    TORCH_CHECK(offset.is_contiguous(), "offset must be contiguous");
    // Only these query-group sizes are instantiated; anything else used to return
    // an uninitialized tensor.
    TORCH_CHECK(nh == 1 || nh == 2 || nh == 4 || nh == 8,
                "window_rope_dot_product: nh must be 1, 2, 4 or 8, got ", nh);

    const torch::Device device = mean.device();
    auto on_device = [&](const torch::Tensor& t) {
        return t.is_cuda() && t.device() == device;
    };
    TORCH_CHECK(on_device(mean) && on_device(norm) && on_device(query) &&
                    on_device(inv_freq) && on_device(offset),
                "window_rope_dot_product: all inputs must be CUDA tensors on the same device");

    const int B = norm.size(0);
    const int H = norm.size(1);
    const int L = norm.size(2);
    const int num_groups = L / window_size;

    // The kernel indexes raw pointers with these layouts, so mismatches would read
    // out of bounds.
    TORCH_CHECK(norm.numel() == int64_t(B) * H * L,
                "window_rope_dot_product: norm must have shape (B, H, L, 1)");
    TORCH_CHECK(mean.numel() == int64_t(B) * H * num_groups * D,
                "window_rope_dot_product: mean must have shape (B, H, L // window_size, 1, D)");
    TORCH_CHECK(query.numel() == int64_t(B) * H * nh * D,
                "window_rope_dot_product: query must have shape (B, H * nh, 1, D)");
    TORCH_CHECK(inv_freq.numel() == D / 2,
                "window_rope_dot_product: inv_freq must have D // 2 elements");
    TORCH_CHECK(offset.numel() == B,
                "window_rope_dot_product: offset must have B elements");
    // CUDA limits gridDim.y and gridDim.z to 65535.
    TORCH_CHECK(H <= 65535 && num_groups <= 65535,
                "window_rope_dot_product: H and L // window_size must be <= 65535");

    // mean, query and inv_freq are read as float4 by the kernel.
    auto aligned16 = [](const torch::Tensor& t) {
        return reinterpret_cast<std::uintptr_t>(t.data_ptr()) % 16 == 0;
    };
    TORCH_CHECK(aligned16(mean) && aligned16(query) && aligned16(inv_freq),
                "window_rope_dot_product: mean, query and inv_freq must be 16-byte aligned");

    auto out = torch::empty({B, H * nh, 1, L}, mean.options().dtype(torch::kHalf));
    if (out.numel() == 0 || num_groups == 0) {
        // A zero-sized grid is an invalid launch configuration; there is no full window to compute.
        return out;
    }

    // Run on the inputs' device and on PyTorch's current stream so the kernel is
    // ordered with surrounding PyTorch work; the guard restores the caller's device.
    const c10::cuda::CUDAGuard device_guard(device);
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    switch (nh) {
        case 1:
            launch_window_rope_dot_product<D, window_size, 1>(
                mean, norm, query, inv_freq, offset, out, B, H, L, stream);
            break;
        case 2:
            launch_window_rope_dot_product<D, window_size, 2>(
                mean, norm, query, inv_freq, offset, out, B, H, L, stream);
            break;
        case 4:
            launch_window_rope_dot_product<D, window_size, 4>(
                mean, norm, query, inv_freq, offset, out, B, H, L, stream);
            break;
        case 8:
            launch_window_rope_dot_product<D, window_size, 8>(
                mean, norm, query, inv_freq, offset, out, B, H, L, stream);
            break;
        default:
            TORCH_CHECK(false, "window_rope_dot_product: unsupported nh ", nh);
    }

    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "window_rope_dot_product: kernel launch failed: ", cudaGetErrorString(err));
    return out;
}

// Explicit instantiations of the host binding for head dim D = 128 and window
// sizes 32, 64 and 128.
template torch::Tensor window_rope_dot_product<128, 32> (
    torch::Tensor mean,      // (B, H, L // window_size, 1, D), half
    torch::Tensor norm,      // (B, H, L, 1), half
    torch::Tensor query,     // (B, H * nh, 1, D), half
    torch::Tensor inv_freq,  // (D // 2), half
    torch::Tensor offset,    // (B), int32
    int nh
);


template torch::Tensor window_rope_dot_product<128, 64> (
    torch::Tensor mean,      // (B, H, L // window_size, 1, D), half
    torch::Tensor norm,      // (B, H, L, 1), half
    torch::Tensor query,     // (B, H * nh, 1, D), half
    torch::Tensor inv_freq,  // (D // 2), half
    torch::Tensor offset,    // (B), int32
    int nh
);

template torch::Tensor window_rope_dot_product<128, 128> (
    torch::Tensor mean,      // (B, H, L // window_size, 1, D), half
    torch::Tensor norm,      // (B, H, L, 1), half
    torch::Tensor query,     // (B, H * nh, 1, D), half
    torch::Tensor inv_freq,  // (D // 2), half
    torch::Tensor offset,    // (B), int32
    int nh
);
