#include <torch/extension.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <iostream>
#include <limits>
#include <math.h>
#define NUM_ELEM_PER_THREAD 4
#define N 8
#define PACKED 32
#define N2 4
#define TILE_SIZE 32
#define DQ_PACK 8

// Multiplies a 32-bit integer by a half value: both operands are widened to
// float, multiplied there, and the product is converted back to half.
__device__ __forceinline__ half imulh(int32_t a, half b) {
    const float lhs = static_cast<float>(a);
    const float rhs = __half2float(b);
    const float product = lhs * rhs;
    return __float2half(product);
}

// Adds two half values: both operands are widened to float, summed there,
// and the sum is converted back to half.
__device__ __forceinline__ half haddh(half a, half b) {
    const float lhs = __half2float(a);
    const float rhs = __half2float(b);
    const float sum = lhs + rhs;
    return __float2half(sum);
}

// Multiplies two half values: both operands are widened to float, multiplied
// there, and the product is converted back to half.
__device__ __forceinline__ half hmulh(half a, half b) {
    const float lhs = __half2float(a);
    const float rhs = __half2float(b);
    const float product = lhs * rhs;
    return __float2half(product);
}

// distArgminHalfPacked1BitKernel
//
// For every row of A, finds the row of the dequantized matrix B that has the
// smallest squared Euclidean distance to it, and writes the winning row
// indices packed four-per-int32 into C.
//
// Dequantization: each int32 of B_idx holds N (= 8) 4-bit codes, lowest
// nibble first. Element k of B row m is code_k * B_scale[m] + B_offset[m],
// evaluated in half precision.
//
// Distance: the per-element difference is taken in half precision
// (__hsub2), then squared and accumulated in float. Ties keep the lowest
// index (strict '<' comparison). If no candidate beats the initial 1e30f
// bound (e.g. M == 0), the stored 8-bit slot is 0xFF.
//
// Launch layout: 1D grid, 1D block. Each thread handles
// NUM_ELEM_PER_THREAD consecutive rows of A starting at
// (global thread id) * NUM_ELEM_PER_THREAD. The block cooperatively
// dequantizes TILE_SIZE rows of B into shared memory per iteration, one row
// per thread, so the kernel must be launched with blockDim.x == TILE_SIZE.
//
// Preconditions (not checked here):
//   - A is 16-byte aligned (rows are read as float4) and N == 8.
//   - L is a multiple of NUM_ELEM_PER_THREAD (all rows of an in-range thread
//     are read without a per-row bound check).
//   - Each output index is stored in N (= 8) bits, so indices >= 256 are
//     truncated to their low 8 bits.
//
// Parameters:
//   A        [L, N]   half input rows.
//   B_idx    [M, N/8] packed 4-bit codes.
//   B_scale  [M, N/8] per-row scale.
//   B_offset [M, N/8] per-row offset.
//   L, M     row counts of A and B.
//   C        output; C[r / NUM_ELEM_PER_THREAD] receives the indices for rows
//            r .. r + NUM_ELEM_PER_THREAD - 1, row i in bits [i*N, i*N + N).
__global__ void dist_argmin_half_packed_dq_1bit_kernel(
    const half* __restrict__ A,  // [L, N]
    const int32_t* __restrict__ B_idx,  // [M, N/8]
    const half* __restrict__ B_scale,  // [M, N/8]
    const half* __restrict__ B_offset,  // [M, N/8]
    int L,
    int M,
    int32_t* __restrict__ C
)
{
    const int row = (blockIdx.x * blockDim.x + threadIdx.x) * NUM_ELEM_PER_THREAD;
    // Threads past the end of A still help load B tiles and must reach every
    // barrier, so they are masked rather than returned early.
    const bool active = row < L;

    // 16-byte alignment is required because each row is filled with a single
    // float4 store below.
    __align__(16) half2 a_reg2[NUM_ELEM_PER_THREAD][N2];
    float minDist[NUM_ELEM_PER_THREAD];
    int   minIdx[NUM_ELEM_PER_THREAD];

    #pragma unroll
    for(int i=0; i<NUM_ELEM_PER_THREAD; i++) {
        minDist[i] = 1e30f;
        minIdx[i] = -1;
    }

    if (active) {
        // Widen before multiplying so the element offset cannot overflow int.
        const float4* A2 = reinterpret_cast<const float4*>(A + static_cast<size_t>(row) * N);
        #pragma unroll
        for(int i=0; i<NUM_ELEM_PER_THREAD; i++) {
            *reinterpret_cast<float4*>(&a_reg2[i][0]) = A2[i]; // Only available when N==8
        }
    }

    // Each tile row is moved as one float4, hence the 16-byte alignment.
    __shared__ __align__(16) half2 b_tile[TILE_SIZE][N2];

    #pragma unroll
    for(int tileStart = 0; tileStart < M; tileStart += TILE_SIZE){
        const int tileRow = tileStart + threadIdx.x;

        if(tileRow < M) {
            int32_t b_idx = B_idx[tileRow];
            const half b_offset = B_offset[tileRow];
            const half b_scale = B_scale[tileRow];
            float4 b_reg;
            #pragma unroll
            for(int k = 0; k < N; k++){
                // Low nibble is element k; dequantize as code * scale + offset.
                reinterpret_cast<half*>(&b_reg)[k] =
                    __hadd(__hmul(__int2half_rn(b_idx & 0xF), b_scale), b_offset);
                b_idx >>= 4;
            }
            *reinterpret_cast<float4*>(&b_tile[threadIdx.x][0]) = b_reg;
        }
        // Make the tile visible to every thread before anyone reads it.
        // Warp-synchronous execution must not be assumed, even for a
        // single-warp block.
        __syncthreads();

        if (active) {
            #pragma unroll
            for(int t = 0; t < TILE_SIZE; t++){
                const int b_idx = tileStart + t;
                // Rows past M were never written into the tile.
                if(b_idx >= M) break;

                const float4 b_reg = *reinterpret_cast<const float4*>(&b_tile[t][0]);
                const half2* b_vec = reinterpret_cast<const half2*>(&b_reg);

                #pragma unroll
                for(int i=0; i<NUM_ELEM_PER_THREAD; i++) {
                    float dist = 0.f;
                    #pragma unroll
                    for(int k = 0; k < N2; k++){
                        float2 diff = __half22float2(__hsub2(a_reg2[i][k], b_vec[k]));
                        dist += (diff.x * diff.x + diff.y * diff.y);
                    }
                    if(dist < minDist[i]){
                        minDist[i] = dist;
                        minIdx[i] = b_idx;
                    }
                }
            }
        }
        // Do not let any thread overwrite the tile while others still read it.
        __syncthreads();
    }

    if (active) {
        // Pack in unsigned arithmetic and mask each index to its N-bit slot:
        // avoids shifting negative/overflowing signed values and keeps one
        // row's index from spilling into its neighbour's bits.
        uint32_t packed = 0;
        #pragma unroll
        for(int i=0; i<NUM_ELEM_PER_THREAD; i++) {
            packed |= (static_cast<uint32_t>(minIdx[i]) & 0xFFu) << (i * N);
        }
        C[row/NUM_ELEM_PER_THREAD] = static_cast<int32_t>(packed);
    }
}


// Host entry point: validates the inputs and launches
// dist_argmin_half_packed_dq_1bit_kernel.
//
// Arguments (all contiguous CUDA tensors on the same device):
//   A        [L, N]   float16, L divisible by NUM_ELEM_PER_THREAD.
//   B_idx    [M, N/8] int32.
//   B_scale  [M, N/8] float16.
//   B_offset [M, N/8] float16.
//
// Returns an int32 tensor of shape [L / NUM_ELEM_PER_THREAD] on A's device;
// each element packs NUM_ELEM_PER_THREAD row indices, N bits apiece.
//
// Raises (via TORCH_CHECK) if any shape, dtype, device, alignment or size
// requirement is violated, or if the device switch or kernel launch fails.
//
// Side effect: the current CUDA device is set to A's device and is not
// restored.
torch::Tensor dist_argmin_half_packed_dq_1bit(
    torch::Tensor A,
    torch::Tensor B_idx,
    torch::Tensor B_scale,
    torch::Tensor B_offset
){
    TORCH_CHECK(A.is_cuda(), "A must be a CUDA tensor");
    TORCH_CHECK(B_idx.is_cuda(), "B_idx must be a CUDA tensor");
    TORCH_CHECK(B_scale.is_cuda(), "B_scale must be a CUDA tensor");
    TORCH_CHECK(B_offset.is_cuda(), "B_offset must be a CUDA tensor");
    TORCH_CHECK(
        B_idx.device() == A.device() &&
        B_scale.device() == A.device() &&
        B_offset.device() == A.device(),
        "A, B_idx, B_scale and B_offset must be on the same device");

    TORCH_CHECK(A.is_contiguous(), "A must be contiguous");
    TORCH_CHECK(B_idx.is_contiguous(), "B_idx must be contiguous");
    TORCH_CHECK(B_scale.is_contiguous(), "B_scale must be contiguous");
    TORCH_CHECK(B_offset.is_contiguous(), "B_offset must be contiguous");

    // The raw pointers are reinterpreted below, so a dtype mismatch would
    // silently read garbage.
    TORCH_CHECK(A.scalar_type() == torch::kFloat16, "A must be float16");
    TORCH_CHECK(B_idx.scalar_type() == torch::kInt32, "B_idx must be int32");
    TORCH_CHECK(B_scale.scalar_type() == torch::kFloat16, "B_scale must be float16");
    TORCH_CHECK(B_offset.scalar_type() == torch::kFloat16, "B_offset must be float16");

    TORCH_CHECK(A.dim() == 2 && A.size(1) == N && A.size(0)%NUM_ELEM_PER_THREAD==0);
    TORCH_CHECK(B_idx.dim() == 2 && B_idx.size(1) == N/8);
    TORCH_CHECK(B_scale.dim() == 2 && B_scale.size(1) == N/8);
    TORCH_CHECK(B_offset.dim() == 2 && B_offset.size(1) == N/8);

    const int64_t L = A.size(0);
    const int64_t M = B_idx.size(0);

    // The kernel indexes B_scale / B_offset with the same row index as B_idx.
    TORCH_CHECK(B_scale.size(0) == M && B_offset.size(0) == M,
                "B_scale and B_offset must have the same number of rows as B_idx (", M, ")");
    // Each result index is stored in N bits of the packed output word.
    TORCH_CHECK(M <= (int64_t{1} << N),
                "B_idx has ", M, " rows but packed indices only hold ", N, " bits");
    // The kernel uses 32-bit row / element indices.
    TORCH_CHECK(L <= static_cast<int64_t>(std::numeric_limits<int>::max() / N),
                "A has too many rows for 32-bit indexing: ", L);
    // Rows of A are loaded as 16-byte vectors.
    TORCH_CHECK(reinterpret_cast<uintptr_t>(A.data_ptr()) % 16 == 0,
                "A must be 16-byte aligned");

    const int dev_index = A.device().index();
    const cudaError_t set_err = cudaSetDevice(dev_index);
    TORCH_CHECK(set_err == cudaSuccess,
                "cudaSetDevice(", dev_index, ") failed: ", cudaGetErrorString(set_err));

    auto out_opts = torch::TensorOptions().dtype(torch::kInt32).device(A.device());
    auto C = torch::empty({L/NUM_ELEM_PER_THREAD}, out_opts);

    // A zero-sized grid is an invalid launch configuration; nothing to do.
    if (L == 0) {
        return C;
    }

    // The kernel fills its shared tile one row per thread, so the block size
    // must equal TILE_SIZE.
    const int blockSize = TILE_SIZE;
    const int64_t rowsPerBlock = static_cast<int64_t>(blockSize) * NUM_ELEM_PER_THREAD;
    const int gridSize = static_cast<int>((L + rowsPerBlock - 1) / rowsPerBlock);

    // NOTE(review): launched without a stream argument, i.e. on the default
    // stream — confirm this is ordered correctly with the caller's stream.
    dist_argmin_half_packed_dq_1bit_kernel<<<gridSize, blockSize>>>(
        reinterpret_cast<const half*>(A.data_ptr()),
        reinterpret_cast<const int32_t*>(B_idx.data_ptr()),
        reinterpret_cast<const half*>(B_scale.data_ptr()),
        reinterpret_cast<const half*>(B_offset.data_ptr()),
        (int)L,
        (int)M,
        C.data_ptr<int32_t>()
    );

    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "dist_argmin_half_packed_dq_1bit_kernel launch failed: ",
                cudaGetErrorString(err));

    return C;
}
