#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cstdint>
// Number of vector elements described by one 32-bit packed word
// (the kernel reads one int32 per PACKED elements of a vector).
#define PACKED 16
// Threads per block used by the host launcher; each thread handles one (b, h, l) vector.
#define TILE_SIZE 64
// Number of entries in the codebook (second extent of the shared-memory copy).
#define CODEBOOK_SZ 256
// Number of components per codebook entry (first extent of the shared-memory copy).
#define CODEBOOK_DIM 8

/**
 * Rescales one norm entry per (b, h, l) vector in place.
 *
 * Each thread handles a single D-element vector. It decodes the quantized
 * vector from `packed` and `codebook`, accumulates
 *     orig_sum = sum(orig[k]^2)   and   dot_sum = sum(orig[k] * decoded[k]),
 * and multiplies the matching `norm` entry by orig_sum / dot_sum.
 *
 * Packed word layout, as decoded below (one int32 per PACKED elements):
 *   - bits 0..7   : codebook index for elements 0..CODEBOOK_DIM-1
 *   - bits 8..15  : codebook index for elements CODEBOOK_DIM..PACKED-1
 *   - bits 16..31 : one sign bit per element (1 = keep, 0 = negate)
 *
 * Launch layout: grid = (B, H, ceil(L / blockDim.x)); only threadIdx.x is
 * used, so a 1-D thread block is expected. Static shared memory:
 * CODEBOOK_DIM * CODEBOOK_SZ floats.
 *
 * @tparam D       Elements per vector; must be a multiple of PACKED.
 * @param packed   Packed codes, D / PACKED int32 words per vector, vectors
 *                 laid out contiguously in (b, h, l) order.
 * @param codebook CODEBOOK_SZ entries of CODEBOOK_DIM halves, entry-major
 *                 (element i of entry j is at j * CODEBOOK_DIM + i).
 * @param norm     One half per vector; read and overwritten.
 * @param orig     Reference values, D halves per vector.
 * @param B        Batch extent (unused in the kernel; blockIdx.x supplies b).
 * @param H        Second extent, used for flat indexing.
 * @param L        Third extent; threads with l >= L do no work.
 */
template <int D>
__global__ void adjust_scale_kernel(
    const int32_t* __restrict__ packed,
    const half* __restrict__ codebook,
    half* __restrict__ norm,
    const half* __restrict__ orig,
    int B,
    int H,
    int L
) {
    static_assert(D % PACKED == 0, "D must be a multiple of PACKED");

    const int b = blockIdx.x;
    const int h = blockIdx.y;
    const int tid = threadIdx.x;
    const int l = blockIdx.z * blockDim.x + tid;

    __shared__ float C[CODEBOOK_DIM][CODEBOOK_SZ];

    // Cooperatively stage the codebook in shared memory, transposed to
    // [component][entry] and widened to float. Striding by blockDim.x keeps
    // the load complete for any 1-D block size.
    #pragma unroll
    for (int i = 0; i < CODEBOOK_DIM; i++) {
        for (int j = tid; j < CODEBOOK_SZ; j += blockDim.x) {
            C[i][j] = __half2float(codebook[j * CODEBOOK_DIM + i]);
        }
    }

    // The barrier must be reached by every thread of the block, so it has to
    // come BEFORE the tail guard below; out-of-range threads still took part
    // in the codebook load above.
    __syncthreads();

    if (l >= L) return;

    // 64-bit flat offsets: vec_idx * D overflows 32-bit int for large tensors.
    const int64_t vec_idx = (static_cast<int64_t>(b) * H + h) * L + l;
    const half* orig_vec = orig + vec_idx * D;
    const int32_t* packed_vec = packed + vec_idx * (D / PACKED);

    float orig_sum = 0.0f;
    float dot_sum = 0.0f;

    #pragma unroll
    for (int base_i = 0; base_i < D; base_i += PACKED) {
        const int32_t packed_idx = packed_vec[base_i / PACKED];
        #pragma unroll
        for (int i = 0; i < PACKED; i++) {
            const float original_val = __half2float(orig_vec[base_i + i]);
            // One 8-bit codebook index per group of CODEBOOK_DIM elements.
            const int codebook_idx = (packed_idx >> ((i / CODEBOOK_DIM) * 8)) & 255;
            float curr_val = C[i % CODEBOOK_DIM][codebook_idx];
            // Sign bits occupy the upper PACKED bits: set = keep, clear = negate.
            const bool sign = (packed_idx >> (PACKED + i)) & 1;
            curr_val = sign ? curr_val : -curr_val;

            orig_sum += original_val * original_val;
            dot_sum += original_val * curr_val;
        }
    }

    // NOTE(review): dot_sum == 0 (e.g. an all-zero orig vector) yields inf/NaN
    // here; behaviour is intentionally left as in the original formula.
    const float scale = __half2float(norm[vec_idx]) * (orig_sum / dot_sum);
    norm[vec_idx] = __float2half(scale);
}

/**
 * Host launcher for adjust_scale_kernel<D>.
 *
 * Validates the inputs, launches one thread per (b, h, l) vector on the
 * current PyTorch CUDA stream of `packed`'s device, and returns `norm`, which
 * the kernel updates in place.
 *
 * @tparam D       Elements per vector; must be a multiple of PACKED.
 * @param packed   int32 CUDA tensor, (B, H, L, D//PACKED), contiguous.
 * @param codebook half CUDA tensor with at least CODEBOOK_SZ * CODEBOOK_DIM
 *                 elements, contiguous.
 * @param norm     half CUDA tensor, (B, H, L, 1), contiguous; modified in place.
 * @param orig     half CUDA tensor, (B, H, L, D), contiguous.
 * @return         `norm` (same tensor object that was passed in).
 * @throws c10::Error (via TORCH_CHECK) on invalid inputs or a failed launch.
 */
template <int D>
torch::Tensor adjust_scale(
    torch::Tensor packed, // (B, H, L, D//PACKED)
    torch::Tensor codebook,
    torch::Tensor norm, // (B, H, L, 1)
    torch::Tensor orig // (B, H, L, D)
) {
    static_assert(D % PACKED == 0, "D must be a multiple of PACKED");

    TORCH_CHECK(packed.is_cuda() && codebook.is_cuda() && norm.is_cuda() && orig.is_cuda(),
                "packed, codebook, norm and orig must all be CUDA tensors");
    TORCH_CHECK(codebook.device() == packed.device() &&
                norm.device() == packed.device() &&
                orig.device() == packed.device(),
                "packed, codebook, norm and orig must be on the same device");
    TORCH_CHECK(packed.is_contiguous(),      "packed must be contiguous");
    TORCH_CHECK(codebook.is_contiguous(),      "codebook must be contiguous");
    TORCH_CHECK(norm.is_contiguous(),      "norm must be contiguous");
    TORCH_CHECK(orig.is_contiguous(),      "orig must be contiguous");
    TORCH_CHECK(packed.dim() >= 3, "packed must have at least 3 dimensions (B, H, L, ...)");

    const int64_t B = packed.size(0);
    const int64_t H = packed.size(1);
    const int64_t L = packed.size(2);
    const int64_t num_vectors = B * H * L;

    // The kernel indexes every buffer as a flat array; make sure each one
    // actually holds the number of elements it will touch.
    TORCH_CHECK(packed.numel() == num_vectors * (D / PACKED),
                "packed must hold D/PACKED int32 words per (B, H, L) vector");
    TORCH_CHECK(orig.numel() == num_vectors * D,
                "orig must hold D elements per (B, H, L) vector");
    TORCH_CHECK(norm.numel() == num_vectors,
                "norm must hold one element per (B, H, L) vector");
    TORCH_CHECK(codebook.numel() >= CODEBOOK_SZ * CODEBOOK_DIM,
                "codebook must hold at least CODEBOOK_SZ * CODEBOOK_DIM elements");

    // A zero-sized grid is an invalid launch configuration; nothing to do.
    if (num_vectors == 0) {
        return norm;
    }

    const int64_t l_blocks = (L + TILE_SIZE - 1) / TILE_SIZE;
    TORCH_CHECK(B <= INT32_MAX && H <= 65535 && l_blocks <= 65535,
                "adjust_scale: tensor extents exceed CUDA grid limits");

    // Select packed's device for the launch and restore the previous device
    // on scope exit.
    const c10::cuda::CUDAGuard device_guard(packed.device());
    // Launch on PyTorch's current stream so the kernel is ordered with the
    // ops that produced the inputs and those that consume `norm`.
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    const dim3 block(TILE_SIZE);
    const dim3 grid(static_cast<unsigned int>(B),
                    static_cast<unsigned int>(H),
                    static_cast<unsigned int>(l_blocks));

    adjust_scale_kernel<D><<<grid, block, 0, stream>>>(
        reinterpret_cast<const int32_t*>(packed.data_ptr<int32_t>()),
        reinterpret_cast<const half*>(codebook.data_ptr<at::Half>()),
        reinterpret_cast<half*>(norm.data_ptr<at::Half>()),
        reinterpret_cast<const half*>(orig.data_ptr<at::Half>()),
        static_cast<int>(B), static_cast<int>(H), static_cast<int>(L)
    );

    // Kernel launches do not return an error code; surface bad configurations here.
    const cudaError_t launch_err = cudaGetLastError();
    TORCH_CHECK(launch_err == cudaSuccess,
                "adjust_scale kernel launch failed: ", cudaGetErrorString(launch_err));

    return norm;
}


// Explicit instantiation of adjust_scale for D = 128: forces the host function
// (and the adjust_scale_kernel<128> it launches) to be emitted in this
// translation unit.
template torch::Tensor adjust_scale<128>(
    torch::Tensor packed, // (B, H, L, D//PACKED)
    torch::Tensor codebook,
    torch::Tensor norm, // (B, H, L, 1)
    torch::Tensor orig // (B, H, L, D)
);
