#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
// Channels covered by one packed int32 word: the packed tensor has D/PACKED words
// per row, and each word's upper PACKED bits are per-channel sign bits.
#define PACKED 16
// Number of entries in each codebook table (codebook_idx / codebook_scale / codebook_offset).
#define CODEBOOK_SZ 256
// Number of 4-bit codes stored in one codebook_idx entry; one 8-bit index inside a
// packed word selects an entry that decodes this many consecutive channels.
#define CODEBOOK_DIM 8
// NOTE(review): not referenced by the code visible in this file — confirm whether it is still needed.
#define DQ_PACK 8

// Dequantization kernel: rebuilds half-precision values from packed codebook codes.
//
// Expected launch layout (as configured by restore_quantized_dq):
//   gridDim  = (B, H, L / window_size)    one block per (batch, head, window)
//   blockDim = (window_size, D / PACKED)  threadIdx.x = row inside the window,
//                                         threadIdx.y = which PACKED-channel slice
//   dynamic shared memory = D * sizeof(float), holding the window's mean vector.
//
// Each packed int32 word covers PACKED (=16) channels of one row:
//   bits  0..7   index of the codebook entry used for channels 0..7 of the slice
//   bits  8..15  index of the codebook entry used for channels 8..15 of the slice
//   bits 16..31  one sign bit per channel (set: keep the value, clear: negate it)
// Entry codebook_idx[k] holds CODEBOOK_DIM 4-bit codes (nibble j for channel j);
// a code c decodes to c * codebook_scale[k] + codebook_offset[k].
// Final value: ((+/-decoded) * norm2[row] + mean[window][channel]) * norm[row].
//
// Contiguous layouts implied by the indexing below:
//   packed (B, H, L, D/PACKED), norm / norm2 (B, H, L), mean (B, H, L/window_size, D),
//   out (B, H, L, D).
template <int D, int window_size>
__global__ void restore_quantized_dq_kernel(
    const int32_t* __restrict__ packed,
    const half* __restrict__ norm,
    const half* __restrict__ mean,
    const half* __restrict__ norm2,
    const int32_t* __restrict__ codebook_idx,
    const half* __restrict__ codebook_scale,
    const half* __restrict__ codebook_offset,
    half* __restrict__ out,
    int B, int H, int L
) {
    static_assert(D % PACKED == 0, "D must be a multiple of PACKED");
    static_assert(PACKED == 2 * CODEBOOK_DIM,
                  "a packed word must hold exactly two 8-bit codebook indices");

    const int b = blockIdx.x, h = blockIdx.y, group = blockIdx.z;
    const int in_group = threadIdx.x;
    const int num_group = L / window_size;

    // Flattened thread id: all window_size * (D/PACKED) threads share the cooperative
    // shared-memory loads instead of every threadIdx.y slice repeating them.
    const int tid = threadIdx.y * blockDim.x + threadIdx.x;
    const int nthreads = blockDim.x * blockDim.y;

    __shared__ int32_t C[CODEBOOK_SZ];
    __shared__ half CS[CODEBOOK_SZ];
    __shared__ half CO[CODEBOOK_SZ];
    extern __shared__ float M[];

    // Load the codebook tables.
    for (int j = tid; j < CODEBOOK_SZ; j += nthreads) {
        C[j] = codebook_idx[j];
        CS[j] = codebook_scale[j];
        CO[j] = codebook_offset[j];
    }

    // Load this window's mean vector; 64-bit arithmetic avoids int overflow on large tensors.
    const int64_t mean_base = (((int64_t)b * H + h) * num_group + group) * D;
    for (int i = tid; i < D; i += nthreads) {
        M[i] = __half2float(mean[mean_base + i]);
    }

    __syncthreads();

    // Flat row index into the (B, H, L) leading dimensions, computed in 64 bits.
    const int64_t row = ((int64_t)b * H + h) * L + (int64_t)group * window_size + in_group;
    const float n2 = __half2float(norm2[row]);
    const float n1 = __half2float(norm[row]);

    const int slice = threadIdx.y;
    const int base_i = slice * PACKED;
    const int32_t word = packed[row * (D / PACKED) + slice];
    half* __restrict__ out_row = out + row * D + base_i;

    #pragma unroll
    for (int i = 0; i < PACKED; i++) {
        // 8-bit codebook index: low byte for the first CODEBOOK_DIM channels, next byte for the rest.
        const int entry = (word >> ((i / CODEBOOK_DIM) * 8)) & 255;
        float curr_val = (float)((C[entry] >> (4 * (i % CODEBOOK_DIM))) & 15) * __half2float(CS[entry]);
        // Round the product to half before adding the offset so the result matches torch's c*cs + co in half.
        curr_val = __half2float(__float2half(curr_val)) + __half2float(CO[entry]);

        const bool sign = (word >> (PACKED + i)) & 1;
        curr_val = (sign ? curr_val : -curr_val);
        curr_val *= n2;
        curr_val += M[base_i + i];
        curr_val *= n1;

        out_row[i] = __float2half(curr_val);
    }
}

// C++ 함수(바인딩)
// Host-side binding: launches restore_quantized_dq_kernel to dequantize packed codes.
//
// Template parameters:
//   D           - channels per row (must be a multiple of PACKED).
//   window_size - rows per quantization window; one mean vector is stored per window.
//
// Arguments (all CUDA tensors on the same device, contiguous):
//   packed          (B, H, L, D/16)              int32 (read as int32_t, not int64)
//   norm            (B, H, L, 1)                 half
//   mean            (B, H, L/window_size, 1, D)  half
//   norm2           (B, H, L, 1)                 half
//   codebook_idx    (256, 1)                     int32
//   codebook_scale  (256, 1)                     half
//   codebook_offset (256, 1)                     half
// Returns a newly allocated (B, H, L, D) half tensor.
// Throws c10::Error (via TORCH_CHECK) on invalid inputs or if the launch fails.
// The kernel runs asynchronously on PyTorch's current CUDA stream for packed's device.
template <int D, int window_size>
torch::Tensor restore_quantized_dq(
    torch::Tensor packed,
    torch::Tensor norm,
    torch::Tensor mean,
    torch::Tensor norm2,
    torch::Tensor codebook_idx,
    torch::Tensor codebook_scale,
    torch::Tensor codebook_offset
) {
    static_assert(D % PACKED == 0, "D must be a multiple of PACKED");
    static_assert(window_size * (D / PACKED) <= 1024, "block would exceed 1024 threads");

    // The kernel reinterprets raw pointers, so placement, layout and dtype must be
    // validated here; a mismatch would otherwise read garbage or fault.
    auto check_input = [&](const torch::Tensor& t, const char* name, at::ScalarType dtype) {
        TORCH_CHECK(t.is_cuda(), name, " must be a CUDA tensor");
        TORCH_CHECK(t.device() == packed.device(), name, " must be on the same device as packed");
        TORCH_CHECK(t.is_contiguous(), name, " must be contiguous");
        TORCH_CHECK(t.scalar_type() == dtype, name, " must have dtype ", dtype, ", got ", t.scalar_type());
    };
    check_input(packed, "packed", torch::kInt32);
    check_input(norm, "norm", torch::kHalf);
    check_input(mean, "mean", torch::kHalf);
    check_input(norm2, "norm2", torch::kHalf);
    check_input(codebook_idx, "codebook_idx", torch::kInt32);
    check_input(codebook_scale, "codebook_scale", torch::kHalf);
    check_input(codebook_offset, "codebook_offset", torch::kHalf);

    TORCH_CHECK(packed.dim() == 4, "packed must be 4-D (B, H, L, D/", PACKED, "), got ", packed.dim(), "-D");
    TORCH_CHECK(packed.size(3) == D / PACKED,
                "packed.size(3) must be ", D / PACKED, ", got ", packed.size(3));

    const int B = packed.size(0);
    const int H = packed.size(1);
    const int L = packed.size(2);

    // Only full windows are launched; a partial tail would leave output rows uninitialized.
    TORCH_CHECK(L % window_size == 0,
                "L (", L, ") must be a multiple of window_size (", window_size, ")");

    const int64_t rows = (int64_t)B * H * L;
    TORCH_CHECK(norm.numel() == rows, "norm must have B*H*L = ", rows, " elements, got ", norm.numel());
    TORCH_CHECK(norm2.numel() == rows, "norm2 must have B*H*L = ", rows, " elements, got ", norm2.numel());
    const int64_t mean_elems = (int64_t)B * H * (L / window_size) * D;
    TORCH_CHECK(mean.numel() == mean_elems,
                "mean must have B*H*(L/window_size)*D = ", mean_elems, " elements, got ", mean.numel());
    TORCH_CHECK(codebook_idx.numel() >= CODEBOOK_SZ, "codebook_idx must have at least ", CODEBOOK_SZ, " entries");
    TORCH_CHECK(codebook_scale.numel() >= CODEBOOK_SZ, "codebook_scale must have at least ", CODEBOOK_SZ, " entries");
    TORCH_CHECK(codebook_offset.numel() >= CODEBOOK_SZ, "codebook_offset must have at least ", CODEBOOK_SZ, " entries");

    // gridDim.y and gridDim.z are limited to 65535 blocks.
    TORCH_CHECK(H <= 65535, "H (", H, ") exceeds the CUDA grid y-dimension limit");
    TORCH_CHECK(L / window_size <= 65535, "L / window_size (", L / window_size, ") exceeds the CUDA grid z-dimension limit");

    // Allocate the output tensor.
    auto out = torch::empty({B, H, L, D}, packed.options().dtype(torch::kHalf));
    if (rows == 0) {
        // A zero-sized grid is an invalid launch configuration; nothing to compute.
        return out;
    }

    // Scoped device switch (restored on return) and PyTorch's current stream, so the
    // launch is ordered with the caller's surrounding work.
    const c10::cuda::CUDAGuard device_guard(packed.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    // One block per (batch, head, window); threads = window rows x PACKED-channel slices.
    const dim3 block(window_size, D / PACKED);
    const dim3 grid(B, H, L / window_size);

    // Dynamic shared memory holds the window's D-element mean vector as float.
    const size_t shared_mem_bytes = D * sizeof(float);

    restore_quantized_dq_kernel<D, window_size><<<grid, block, shared_mem_bytes, stream>>>(
        packed.data_ptr<int32_t>(),
        reinterpret_cast<const half*>(norm.data_ptr<at::Half>()),
        reinterpret_cast<const half*>(mean.data_ptr<at::Half>()),
        reinterpret_cast<const half*>(norm2.data_ptr<at::Half>()),
        codebook_idx.data_ptr<int32_t>(),
        reinterpret_cast<const half*>(codebook_scale.data_ptr<at::Half>()),
        reinterpret_cast<const half*>(codebook_offset.data_ptr<at::Half>()),
        reinterpret_cast<half*>(out.data_ptr<at::Half>()),
        B, H, L
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return out;
}


// Explicit instantiation of the host binding for D = 128 channels and a
// quantization window of 64 rows. Expected inputs:
//   packed (B, H, L, 128/16) int32; norm, norm2 (B, H, L, 1) half;
//   mean (B, H, L/64, 1, 128) half; codebook_idx (256, 1) int32;
//   codebook_scale, codebook_offset (256, 1) half.
template torch::Tensor restore_quantized_dq<128, 64>(
    torch::Tensor /* packed */,
    torch::Tensor /* norm */,
    torch::Tensor /* mean */,
    torch::Tensor /* norm2 */,
    torch::Tensor /* codebook_idx */,
    torch::Tensor /* codebook_scale */,
    torch::Tensor /* codebook_offset */
);
