#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
// PACKED: number of elements along D covered by one int32 word of `packed`
//         (the kernel reads one packed word per PACKED elements).
#define PACKED 32
// TILE_SIZE: threads per block used by the host launcher; each thread
//            processes one (b, h, l) row.
#define TILE_SIZE 64
// CODEBOOK_SZ: number of codebook entries staged in shared memory; the
//              8-bit fields extracted from a packed word index into it.
#define CODEBOOK_SZ 256
// CODEBOOK_DIM: consecutive elements decoded from one codebook entry, one
//               4-bit field of the entry's int32 per element.
#define CODEBOOK_DIM 8
// DQ_PACK: not referenced in this part of the file.
// NOTE(review): confirm whether it is still used elsewhere.
#define DQ_PACK 8

// Rescales the per-row value in `norm` using a codebook reconstruction of the row.
//
// Launch layout (as used by the host wrapper):
//   grid  = (B, H, ceil(L / blockDim.x)), block = 1-D (TILE_SIZE threads).
//   Each thread owns one row (b, h, l) of length D.
// Shared memory: static, CODEBOOK_SZ * (4 + 2 + 2) bytes for the codebook tables.
//
// Per row, each element i is reconstructed as
//   code = 8-bit field (i / CODEBOOK_DIM) of packed word i / PACKED
//   q    = 4-bit field (i % CODEBOOK_DIM) of codebook_idx[code]
//   deq  = half(q * codebook_scale[code]) + codebook_offset[code]
// and norm is updated in place:
//   norm *= sum(orig^2) / sum(orig * deq).
// If sum(orig * deq) == 0 the result is inf/NaN; no special-casing is done.
// NOTE(review): confirm callers never hit that case.
//
// Tensor layouts follow the host wrapper's comments: packed (B, H, L, D/PACKED),
// norm (B, H, L, 1), orig (B, H, L, D), all contiguous.
template <int D>
__global__ void adjust_scale_dq_1bit_kernel(
    const int32_t* __restrict__ packed,
    const int32_t* __restrict__ codebook_idx,
    const half* __restrict__ codebook_scale,
    const half* __restrict__ codebook_offset,
    half* __restrict__ norm,
    const half* __restrict__ orig,
    int B,
    int H,
    int L
) {
    static_assert(D % PACKED == 0, "D must be a multiple of PACKED");
    static_assert(PACKED / CODEBOOK_DIM <= 4,
                  "one int32 packed word holds at most four 8-bit codebook indices");
    static_assert(CODEBOOK_DIM <= 8,
                  "one int32 codebook entry holds at most eight 4-bit fields");
    static_assert(CODEBOOK_SZ >= 256, "8-bit codes index up to entry 255");
    (void)B;  // batch extent is implied by gridDim.x

    __shared__ int32_t C[CODEBOOK_SZ];
    __shared__ half CS[CODEBOOK_SZ];
    __shared__ half CO[CODEBOOK_SZ];

    // Stage the codebook cooperatively. The stride is blockDim.x, so any block
    // size works. All threads (including tail threads with no row) take part,
    // so every thread reaches the barrier.
    const int tid = threadIdx.x;
    for (int j = tid; j < CODEBOOK_SZ; j += blockDim.x) {
        C[j] = codebook_idx[j];
        CS[j] = codebook_scale[j];
        CO[j] = codebook_offset[j];
    }
    __syncthreads();

    // The last tile along L may be partial. Exiting here is safe because no
    // __syncthreads() follows.
    const int l = blockIdx.z * blockDim.x + tid;
    if (l >= L) return;

    // Use 64-bit offsets. row * D overflows int for large B*H*L.
    const int64_t row = ((int64_t)blockIdx.x * H + blockIdx.y) * L + l;
    const int32_t* packed_row = packed + row * (D / PACKED);
    const half* orig_row = orig + row * D;

    float orig_sum = 0.f, dot_sum = 0.f;
    #pragma unroll
    for (int base_i = 0; base_i < D; base_i += PACKED) {
        const int32_t word = packed_row[base_i / PACKED];
        #pragma unroll
        for (int i = 0; i < PACKED; i++) {
            const float x = __half2float(orig_row[base_i + i]);
            const int code = (word >> ((i / CODEBOOK_DIM) * 8)) & 255;
            const int q = (C[code] >> (4 * (i % CODEBOOK_DIM))) & 15;
            float deq = (float)q * __half2float(CS[code]);
            // Round the product to half before adding the offset, so the
            // result matches torch's half-precision c * cs + co.
            deq = __half2float(__float2half(deq)) + __half2float(CO[code]);

            orig_sum += x * x;
            dot_sum += x * deq;
        }
    }

    const float scale = __half2float(norm[row]) * (orig_sum / dot_sum);
    norm[row] = __float2half(scale);
}

// Host launcher for adjust_scale_dq_1bit_kernel<D>. Updates `norm` in place and returns it.
//
// Requirements (checked):
// - All tensors are contiguous CUDA tensors on the same device.
// - packed is (B, H, L, D/PACKED) int32; orig is (B, H, L, D) half;
//   norm holds B*H*L half values.
// - Each codebook tensor has at least CODEBOOK_SZ entries.
//
// The kernel runs on PyTorch's current CUDA stream for the tensors' device.
// The caller's current device is restored on return.
template <int D>
torch::Tensor adjust_scale_dq_1bit(
    torch::Tensor packed, // (B, H, L, D//PACKED)
    torch::Tensor codebook_idx,
    torch::Tensor codebook_scale,
    torch::Tensor codebook_offset,
    torch::Tensor norm, // (B, H, L, 1)
    torch::Tensor orig // (B, H, L, D)
) {
    static_assert(D % PACKED == 0, "D must be a multiple of PACKED");

    // Device placement: the kernel dereferences every pointer on one GPU.
    TORCH_CHECK(packed.is_cuda(), "packed must be a CUDA tensor");
    const auto dev = packed.device();
    TORCH_CHECK(codebook_idx.device() == dev, "codebook_idx must be on the same device as packed");
    TORCH_CHECK(codebook_scale.device() == dev, "codebook_scale must be on the same device as packed");
    TORCH_CHECK(codebook_offset.device() == dev, "codebook_offset must be on the same device as packed");
    TORCH_CHECK(norm.device() == dev, "norm must be on the same device as packed");
    TORCH_CHECK(orig.device() == dev, "orig must be on the same device as packed");

    TORCH_CHECK(packed.is_contiguous(),      "packed must be contiguous");
    TORCH_CHECK(codebook_idx.is_contiguous(),      "codebook_idx must be contiguous");
    TORCH_CHECK(codebook_scale.is_contiguous(),      "codebook_scale must be contiguous");
    TORCH_CHECK(codebook_offset.is_contiguous(),      "codebook_offset must be contiguous");
    TORCH_CHECK(norm.is_contiguous(),      "norm must be contiguous");
    TORCH_CHECK(orig.is_contiguous(),      "orig must be contiguous");

    // Shape validation. The kernel indexes raw pointers with these extents.
    TORCH_CHECK(packed.dim() == 4, "packed must have shape (B, H, L, D/PACKED)");
    const int64_t B = packed.size(0);
    const int64_t H = packed.size(1);
    const int64_t L = packed.size(2);
    TORCH_CHECK(packed.size(3) == D / PACKED,
                "packed.size(3) must be ", D / PACKED, ", got ", packed.size(3));
    TORCH_CHECK(orig.dim() == 4 && orig.size(0) == B && orig.size(1) == H &&
                orig.size(2) == L && orig.size(3) == D,
                "orig must have shape (B, H, L, ", D, ") matching packed");
    TORCH_CHECK(norm.numel() == B * H * L, "norm must hold one value per (B, H, L) row");
    TORCH_CHECK(codebook_idx.numel() >= CODEBOOK_SZ,
                "codebook_idx must have at least ", CODEBOOK_SZ, " entries");
    TORCH_CHECK(codebook_scale.numel() >= CODEBOOK_SZ,
                "codebook_scale must have at least ", CODEBOOK_SZ, " entries");
    TORCH_CHECK(codebook_offset.numel() >= CODEBOOK_SZ,
                "codebook_offset must have at least ", CODEBOOK_SZ, " entries");

    // Hardware grid limits: gridDim.x <= 2^31-1, gridDim.y/z <= 65535.
    const int64_t num_tiles = (L + TILE_SIZE - 1) / TILE_SIZE;
    TORCH_CHECK(B <= 0x7fffffffLL, "B exceeds the CUDA grid x-dimension limit");
    TORCH_CHECK(H <= 65535, "H exceeds the CUDA grid y-dimension limit (65535)");
    TORCH_CHECK(num_tiles <= 65535, "L is too large for the CUDA grid z-dimension limit");

    // A zero-sized grid is an invalid launch configuration, so skip empty inputs.
    if (B == 0 || H == 0 || L == 0) {
        return norm;
    }

    // Scoped device switch (restored on exit) and PyTorch's current stream, so
    // the launch is ordered with surrounding PyTorch work.
    const c10::cuda::CUDAGuard device_guard(dev);
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    dim3 blockDim(TILE_SIZE);
    dim3 gridDim((unsigned)B, (unsigned)H, (unsigned)num_tiles);

    adjust_scale_dq_1bit_kernel<D><<<gridDim, blockDim, 0, stream>>>(
        packed.data_ptr<int32_t>(),
        codebook_idx.data_ptr<int32_t>(),
        reinterpret_cast<const half*>(codebook_scale.data_ptr<at::Half>()),
        reinterpret_cast<const half*>(codebook_offset.data_ptr<at::Half>()),
        reinterpret_cast<half*>(norm.data_ptr<at::Half>()),
        reinterpret_cast<const half*>(orig.data_ptr<at::Half>()),
        (int)B, (int)H, (int)L
    );

    // Launch-configuration errors are only reported through cudaGetLastError().
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "adjust_scale_dq_1bit_kernel launch failed: ", cudaGetErrorString(err));

    return norm;
}


// Explicit instantiation of the host launcher (and, transitively, its kernel)
// for D = 128, the last dimension of `orig`. Only this value of D is compiled
// in this translation unit. Any other D needs its own explicit instantiation.
// The parameter list must stay in sync with the template definition above.
template torch::Tensor adjust_scale_dq_1bit<128>(
    torch::Tensor packed, // (B, H, L, D//PACKED)
    torch::Tensor codebook_idx,
    torch::Tensor codebook_scale,
    torch::Tensor codebook_offset,
    torch::Tensor norm, // (B, H, L, 1)
    torch::Tensor orig // (B, H, L, D)
);
