#include <torch/extension.h>

#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>

#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <cstdint>
#include <iostream>
#include <limits>

// Length of every input row, in half-precision elements. Rows are read as
// packed half2 pairs, so this value has to stay even.
constexpr int N = 8;
// Number of half2 pairs that make up one row.
constexpr int N2 = N/2;
// Number of rows of B staged in shared memory per tile; the host launcher
// also uses it as the thread-block size.
constexpr int TILE_SIZE = 64;

static_assert(N > 0 && N % 2 == 0,
              "N must be a positive even number: rows are read as half2 pairs");
static_assert(TILE_SIZE > 0, "TILE_SIZE must be positive");

// distArgminHalfKernel
//
// For every row of A, finds the row of B with the smallest squared Euclidean
// distance and stores that row's index in C.
//
// Launch layout: 1-D grid of 1-D blocks, one thread per row of A. Any
//   blockDim.x is accepted; each block stages B through a static shared tile
//   of N2 * TILE_SIZE half2 values using a block-stride load.
// Preconditions: A is [L, N] and B is [M, N], dense row-major, and both base
//   pointers are aligned for half2 access (rows are read as half2 pairs).
// Output: C[row] = static_cast<uint8_t>(best index). Distances are accumulated
//   in float. Ties keep the lowest index because the comparison is a strict
//   `<`. If no distance compares below the 1e30 sentinel (M == 0, or NaN/Inf
//   distances) the stored byte is 255, i.e. the uint8_t conversion of -1.
__global__ void distArgminHalfKernel(
    const half* __restrict__ A,  // [L, N]
    const half* __restrict__ B,  // [M, N]
    int L,
    int M,
    uint8_t* __restrict__ C
)
{
    // 64-bit row index: blockIdx.x * blockDim.x is unsigned 32-bit arithmetic
    // and row * N is an element offset; both can leave the int range for
    // large L.
    const long long row =
        static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
    const bool rowValid = row < L;

    // This thread's row of A, converted to float once. Only filled in (and
    // only read) when rowValid is true.
    float2 a_f[N2];
    if (rowValid) {
        const half2* A2 = reinterpret_cast<const half2*>(A + row * N);
        #pragma unroll
        for(int i = 0; i < N2; i++){
            a_f[i] = __half22float2(A2[i]);
        }
    }

    float minDist = 1e30f;
    int   minIdx  = -1;

    __shared__ half2 b_tile2[N2][TILE_SIZE];

    for(int tileStart = 0; tileStart < M; tileStart += TILE_SIZE){
        // Cooperative tile load. The block-stride loop fills every valid slot
        // whatever blockDim.x is, and never writes past TILE_SIZE.
        for(int slot = static_cast<int>(threadIdx.x); slot < TILE_SIZE;
            slot += static_cast<int>(blockDim.x)){
            const int srcRow = tileStart + slot;
            if(srcRow < M) {
                const half2* B2 = reinterpret_cast<const half2*>(
                    B + static_cast<size_t>(srcRow) * N);
                #pragma unroll
                for(int k = 0; k < N2; k++){
                    b_tile2[k][slot] = B2[k];
                }
            }
        }
        // Every thread (including those with no row of A) must reach the
        // barriers, so only the arithmetic below is guarded by rowValid.
        __syncthreads();

        if(rowValid){
            #pragma unroll
            for(int t = 0; t < TILE_SIZE; t++){
                int b_idx = tileStart + t;
                if(b_idx >= M) break;  // last tile may be partially filled

                float dist = 0.f;
                #pragma unroll
                for(int k = 0; k < N2; k++){
                    float2 a = a_f[k];
                    float2 b = __half22float2(b_tile2[k][t]);
                    float2 diff = make_float2(a.x - b.x, a.y - b.y);
                    dist += (diff.x * diff.x + diff.y * diff.y);
                }
                if(dist < minDist){
                    minDist = dist;
                    minIdx  = b_idx;
                }
            }
        }
        // Keep the tile intact until all threads are done reading it.
        __syncthreads();
    }

    if(rowValid) C[row] = static_cast<uint8_t>(minIdx);
}


// Host entry point: for each row of A returns the index of the closest row of
// B (squared Euclidean distance), computed by distArgminHalfKernel.
//
// Args:
//   A: CUDA float16 tensor of shape [L, N].
//   B: CUDA float16 tensor of shape [M, N] on the same device, 1 <= M <= 256
//      so that every index fits in the uint8 output.
// Returns:
//   uint8 tensor of shape [L] on A's device. Empty when L == 0.
// Raises (via TORCH_CHECK): on wrong device, dtype, shape, size limits,
//   pointer misalignment, or a kernel launch error.
//
// Non-contiguous inputs are copied to contiguous storage before the launch.
// The kernel is enqueued on PyTorch's current CUDA stream for A's device and
// is not synchronized here.
torch::Tensor dist_argmin_half(
    torch::Tensor A,
    torch::Tensor B
){
    TORCH_CHECK(A.is_cuda(), "dist_argmin_half: A must be a CUDA tensor");
    TORCH_CHECK(B.is_cuda(), "dist_argmin_half: B must be a CUDA tensor");
    TORCH_CHECK(A.device() == B.device(),
                "dist_argmin_half: A and B must be on the same device");
    // The kernel reinterprets the raw storage as half; any other dtype would
    // be read as garbage.
    TORCH_CHECK(A.scalar_type() == torch::kHalf,
                "dist_argmin_half: A must have dtype float16");
    TORCH_CHECK(B.scalar_type() == torch::kHalf,
                "dist_argmin_half: B must have dtype float16");
    TORCH_CHECK(A.dim() == 2 && A.size(1) == N,
                "dist_argmin_half: A must have shape [L, ", N, "]");
    TORCH_CHECK(B.dim() == 2 && B.size(1) == N,
                "dist_argmin_half: B must have shape [M, ", N, "]");
    TORCH_CHECK(B.size(0) >= 1,
                "dist_argmin_half: B must contain at least one row");
    TORCH_CHECK(B.size(0) <= 256,
                "dist_argmin_half: B may have at most 256 rows (uint8 output)");

    const int64_t L = A.size(0);
    const int64_t M = B.size(0);

    // L and M are handed to the kernel as int. Bounding L by INT_MAX / N
    // conservatively keeps every row index and element offset within 32 bits.
    TORCH_CHECK(L <= static_cast<int64_t>(std::numeric_limits<int>::max() / N),
                "dist_argmin_half: A has too many rows (", L, ")");

    // The kernel indexes rows as dense, row-major memory.
    A = A.contiguous();
    B = B.contiguous();

    // Rows are loaded as half2, which needs half2-aligned base pointers.
    TORCH_CHECK(reinterpret_cast<std::uintptr_t>(A.data_ptr()) % alignof(half2) == 0,
                "dist_argmin_half: A is not aligned for half2 access");
    TORCH_CHECK(reinterpret_cast<std::uintptr_t>(B.data_ptr()) % alignof(half2) == 0,
                "dist_argmin_half: B is not aligned for half2 access");

    auto out_opts = torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
    auto C = torch::empty({L}, out_opts);

    // A zero-sized grid is an invalid launch configuration; nothing to do.
    if (L == 0) {
        return C;
    }

    // Make A's device current for the launch and restore the caller's device
    // on every exit path, including a throwing TORCH_CHECK.
    const c10::cuda::CUDAGuard device_guard(A.device());
    // Launch on PyTorch's current stream so the kernel is ordered after the
    // work that produced A and B and before later consumers of C.
    const cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream();

    const int blockSize = TILE_SIZE;
    const int gridSize = static_cast<int>((L + blockSize - 1) / blockSize);

    distArgminHalfKernel<<<gridSize, blockSize, 0, stream>>>(
        reinterpret_cast<const half*>(A.data_ptr()),
        reinterpret_cast<const half*>(B.data_ptr()),
        static_cast<int>(L),
        static_cast<int>(M),
        C.data_ptr<uint8_t>()
    );

    // Reports launch-configuration errors only; faults raised while the
    // kernel runs surface at a later synchronizing call.
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "dist_argmin_half: kernel launch failed: ", cudaGetErrorString(err));

    return C;
}
