#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
#include <limits>
#include <math.h>
// Number of consecutive rows of A processed by each thread; their results
// are combined into a single packed int32 output word.
#define NUM_ELEM_PER_THREAD 2
// Row width of A and B in half elements. Also used as the per-row bit stride
// inside the packed output word (for both the index and the sign fields).
#define N 8
// Bit offset inside the packed output word at which the sign bits start.
#define PACKED 16
// Row width of A and B in half2 pairs (N / 2).
#define N2 4
// Number of rows of B staged in shared memory per tile; the host launcher
// also uses it as the thread-block size.
#define TILE_SIZE 32

// Builds a 2-bit mask describing the signs of the two halves packed in `h`:
//   bit 0 is set when the low  half is NOT strictly less than zero,
//   bit 1 is set when the high half is NOT strictly less than zero.
// The comparison is `__hlt(x, 0)`, so any value for which that comparison is
// false (including zero) produces a set bit.
__device__ __forceinline__ int32_t half2_extract_sign(half2 h) {
    const half zero = __float2half(0.0f);

    int32_t mask = 0;
    if (!__hlt(__low2half(h), zero)) {
        mask |= 1;
    }
    if (!__hlt(__high2half(h), zero)) {
        mask |= 2;
    }
    return mask;
}

// distArgminHalfPackedKernel
//
// For every row a of A, takes |a| element-wise and finds the row of B with
// the smallest squared Euclidean distance to |a| (ties resolve to the lowest
// B index because the comparison is a strict `<`). Each thread handles
// NUM_ELEM_PER_THREAD consecutive rows of A and emits one int32 word:
//
//   bits [i*N, i*N + N)                 : argmin index for the thread's i-th row
//   bit  PACKED + i*N + 2*j + {0, 1}    : sign bits of half2 pair j of row i as
//                                         returned by half2_extract_sign
//                                         (bit set == element not < 0)
//
// If no candidate ever compares smaller than the initial 1e30f (e.g. M == 0
// or every distance is NaN/inf), the index stays -1 and its two's-complement
// bits are OR-ed in unchanged, exactly as before.
//
// Parameters:
//   A : [L, N] half, row-major, base pointer 16-byte aligned (read as float4)
//   B : [M, N] half, row-major, base pointer 16-byte aligned (read as float4)
//   L : number of rows in A; only complete groups of NUM_ELEM_PER_THREAD rows
//       are processed
//   M : number of rows in B
//   C : output, one int32 per group of NUM_ELEM_PER_THREAD rows of A
//
// Launch: 1-D grid and 1-D block covering at least L / NUM_ELEM_PER_THREAD
// threads. Uses TILE_SIZE * N * sizeof(half) bytes of static shared memory.
__global__ void distArgminHalfPackedKernel(
    const half* __restrict__ A,  // [L, N]
    const half* __restrict__ B,  // [M, N]
    int L,
    int M,
    int32_t* __restrict__ C
)
{
    static_assert(N == 2 * N2, "N2 must be the number of half2 pairs per row");
    static_assert(N2 * sizeof(half2) == sizeof(float4),
                  "one row must be exactly one float4 for the vectorized loads");
    static_assert(NUM_ELEM_PER_THREAD * N <= PACKED,
                  "index fields must not overlap the sign fields");
    static_assert(PACKED + NUM_ELEM_PER_THREAD * N <= 32,
                  "packed result must fit in 32 bits");

    // 64-bit row index so that neither the row computation nor row * N can
    // overflow for large L.
    const long long row =
        (static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x) * NUM_ELEM_PER_THREAD;
    // A thread is active only when its whole group of rows lies inside A.
    // Inactive threads still help loading B tiles and hit every barrier.
    const bool active = (row + NUM_ELEM_PER_THREAD <= static_cast<long long>(L));

    half2 a_reg2[NUM_ELEM_PER_THREAD][N2];
    float minDist[NUM_ELEM_PER_THREAD];
    int   minIdx[NUM_ELEM_PER_THREAD];
    // Unsigned accumulator: shifting into bit 31 is well defined.
    uint32_t packed = 0;

    #pragma unroll
    for (int i = 0; i < NUM_ELEM_PER_THREAD; i++) {
        minDist[i] = 1e30f;
        minIdx[i] = -1;
    }

    if (active) {
        const float4* A4 = reinterpret_cast<const float4*>(A + row * N);
        #pragma unroll
        for (int i = 0; i < NUM_ELEM_PER_THREAD; i++) {
            // Load one full row as a float4 and view the (properly aligned)
            // copy as half2 pairs; a_reg2 itself is never type-punned.
            const float4 a_raw = A4[i];
            const half2* a_h2 = reinterpret_cast<const half2*>(&a_raw);
            #pragma unroll
            for (int j = 0; j < N2; j++) {
                const uint32_t signs = static_cast<uint32_t>(half2_extract_sign(a_h2[j]));
                packed |= signs << (PACKED + i * N + j * 2);
                a_reg2[i][j] = __habs2(a_h2[j]);
            }
        }
    }

    // Rows are stored and read back as float4, hence the 16-byte alignment.
    __shared__ __align__(16) half2 b_tile2[TILE_SIZE][N2];

    #pragma unroll
    for (int tileStart = 0; tileStart < M; tileStart += TILE_SIZE) {
        // Cooperative tile load; the stride makes it correct for any
        // blockDim.x, not only blockDim.x == TILE_SIZE.
        for (int r = threadIdx.x; r < TILE_SIZE; r += blockDim.x) {
            const int tileRow = tileStart + r;
            if (tileRow < M) {
                const float4* B4 =
                    reinterpret_cast<const float4*>(B + static_cast<size_t>(tileRow) * N);
                *reinterpret_cast<float4*>(&b_tile2[r][0]) = B4[0];
            }
        }
        // Every thread reads rows written by other threads. Even within a
        // single warp this needs an explicit barrier on GPUs with independent
        // thread scheduling. The loop bounds depend only on M, so all threads
        // of the block reach this barrier.
        __syncthreads();

        if (active) {
            #pragma unroll
            for (int t = 0; t < TILE_SIZE; t++) {
                const int b_idx = tileStart + t;
                // Stop before touching tile rows that were never written.
                if (b_idx >= M) break;

                const float4 b_reg = *reinterpret_cast<const float4*>(&b_tile2[t][0]);
                const half2* b_h2 = reinterpret_cast<const half2*>(&b_reg);

                #pragma unroll
                for (int i = 0; i < NUM_ELEM_PER_THREAD; i++) {
                    float dist = 0.f;
                    #pragma unroll
                    for (int k = 0; k < N2; k++) {
                        // Difference in half precision, squared and summed in float.
                        const float2 diff = __half22float2(__hsub2(a_reg2[i][k], b_h2[k]));
                        dist += (diff.x * diff.x + diff.y * diff.y);
                    }
                    if (dist < minDist[i]) {
                        minDist[i] = dist;
                        minIdx[i] = b_idx;
                    }
                }
            }
        }
        // Do not overwrite the tile until every thread has finished reading it.
        __syncthreads();
    }

    if (active) {
        #pragma unroll
        for (int i = 0; i < NUM_ELEM_PER_THREAD; i++) {
            packed |= static_cast<uint32_t>(minIdx[i]) << (i * N);
        }
        C[row / NUM_ELEM_PER_THREAD] = static_cast<int32_t>(packed);
    }
}


// Host entry point: validates the inputs and launches
// distArgminHalfPackedKernel on A's device, using PyTorch's current CUDA
// stream for that device.
//
// Args:
//   A : CUDA half tensor of shape [L, N], contiguous, 16-byte aligned,
//       with L a multiple of NUM_ELEM_PER_THREAD.
//   B : CUDA half tensor of shape [M, N] on the same device, contiguous,
//       16-byte aligned, with 1 <= M <= 256.
//
// Returns:
//   int32 tensor of shape [L / NUM_ELEM_PER_THREAD] on A's device holding the
//   packed words written by the kernel. For L == 0 an empty tensor is
//   returned without launching anything.
//
// Raises (via TORCH_CHECK) when any precondition above is violated or when
// the kernel launch reports an error.
torch::Tensor dist_argmin_half_packed(
    torch::Tensor A,
    torch::Tensor B
){
    TORCH_CHECK(A.is_cuda(), "A must be a CUDA tensor");
    TORCH_CHECK(B.is_cuda(), "B must be a CUDA tensor");
    TORCH_CHECK(A.device() == B.device(), "A and B must be on the same device");
    // The kernel reinterprets the raw storage as half, so the dtype must match.
    TORCH_CHECK(A.scalar_type() == torch::kHalf, "A must have dtype float16");
    TORCH_CHECK(B.scalar_type() == torch::kHalf, "B must have dtype float16");
    TORCH_CHECK(A.is_contiguous(), "A must be contiguous");
    TORCH_CHECK(B.is_contiguous(), "B must be contiguous");
    TORCH_CHECK(A.dim() == 2 && A.size(1) == N && A.size(0) % NUM_ELEM_PER_THREAD == 0,
                "A must have shape [L, ", N, "] with L divisible by ", NUM_ELEM_PER_THREAD);
    TORCH_CHECK(B.dim() == 2 && B.size(1) == N,
                "B must have shape [M, ", N, "]");
    // Each argmin index is stored in an 8-bit field, and an empty B has no argmin.
    TORCH_CHECK(B.size(0) >= 1 && B.size(0) <= 256,
                "B must have between 1 and 256 rows, got ", B.size(0));

    const int64_t L = A.size(0);
    const int64_t M = B.size(0);
    TORCH_CHECK(L <= static_cast<int64_t>(std::numeric_limits<int>::max()),
                "A has too many rows for the kernel's int row count: ", L);

    // Rows are fetched with 16-byte float4 loads.
    TORCH_CHECK(reinterpret_cast<uintptr_t>(A.data_ptr()) % sizeof(float4) == 0,
                "A's data pointer must be 16-byte aligned");
    TORCH_CHECK(reinterpret_cast<uintptr_t>(B.data_ptr()) % sizeof(float4) == 0,
                "B's data pointer must be 16-byte aligned");

    // Make A's device current for the allocation and the launch, and restore
    // the caller's device on every exit path (including exceptions).
    const c10::cuda::CUDAGuard device_guard(A.device());

    auto out_opts = torch::TensorOptions().dtype(torch::kInt32).device(A.device());
    auto C = torch::empty({L / NUM_ELEM_PER_THREAD}, out_opts);

    // A zero-sized grid is an invalid launch configuration; nothing to do.
    if (L == 0) {
        return C;
    }

    // One thread per shared-tile row of the kernel.
    const int blockSize = TILE_SIZE;
    const int64_t rowsPerBlock = static_cast<int64_t>(blockSize) * NUM_ELEM_PER_THREAD;
    const int gridSize = static_cast<int>((L + rowsPerBlock - 1) / rowsPerBlock);

    // Launch on PyTorch's current stream so the kernel is ordered with the
    // operations that produced A and B and that will consume C.
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    distArgminHalfPackedKernel<<<gridSize, blockSize, 0, stream>>>(
        reinterpret_cast<const half*>(A.data_ptr()),
        reinterpret_cast<const half*>(B.data_ptr()),
        (int)L,
        (int)M,
        C.data_ptr<int32_t>()
    );

    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "distArgminHalfPackedKernel launch failed: ", cudaGetErrorString(err));

    return C;
}
