/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <faiss/gpu/utils/DeviceUtils.h>
#include <faiss/gpu/utils/DeviceDefs.cuh>
#include <faiss/gpu/utils/DeviceTensor.cuh>
#include <faiss/gpu/utils/Select.cuh>

namespace faiss {
namespace gpu {

// Number of warps that the kernel is instantiated with
constexpr int kWarps = 8;
constexpr int kLanes = kWarpSize;

constexpr int kMaxDistance = std::numeric_limits<int>::max();

// Performs a binary matrix multiplication, returning the lowest k results in
// `vecs` for each `query` in terms of Hamming distance (a fused kernel)
// Each warp calculates distance for a single query
template <int NumWarpQ, int NumThreadQ, typename BinaryType>
__launch_bounds__(kWarps* kLanes) __global__ void binaryDistanceAnySize(
        const Tensor<BinaryType, 2, true> vecs,
        const Tensor<BinaryType, 2, true> query,
        Tensor<int, 2, true> outK,
        Tensor<idx_t, 2, true> outV,
        int k) {
    if constexpr ((NumWarpQ == 1 && NumThreadQ == 1) || NumWarpQ >= kWarpSize) {
        // A matrix tile (query, k)
        __shared__ BinaryType
                queryTile[kWarps][kLanes + 1]; // avoid bank conflict

        // B matrix tile (vec, k)
        __shared__ BinaryType
                vecTile[kLanes][kLanes + 1]; // avoid bank conflict

        WarpSelect<
                int,
                idx_t,
                false,
                Comparator<int>,
                NumWarpQ,
                NumThreadQ,
                kWarps * kLanes>
                heap(kMaxDistance, -1, k);

        int warpId = threadIdx.y;
        int laneId = threadIdx.x;

        // Each warp handles a single query
        idx_t warpQuery = idx_t(blockIdx.x) * kWarps + warpId;
        bool queryInBounds = warpQuery < query.getSize(0);

        // Each warp loops through the entire chunk of vectors
        for (idx_t blockVec = 0; blockVec < vecs.getSize(0);
             blockVec += kLanes) {
            int threadDistance = 0;

            // Reduction dimension
            for (idx_t blockK = 0; blockK < vecs.getSize(1); blockK += kLanes) {
                idx_t laneK = blockK + laneId;
                bool kInBounds = laneK < vecs.getSize(1);

                queryTile[warpId][laneId] = queryInBounds && kInBounds
                        ? query[warpQuery][laneK]
                        : 0;

                // kWarps warps are responsible for loading 32 vecs
#pragma unroll
                for (int i = 0; i < kLanes / kWarps; ++i) {
                    int warpVec = i * kWarps + warpId;
                    idx_t vec = blockVec + warpVec;
                    bool vecInBounds = vec < vecs.getSize(0);

                    vecTile[warpVec][laneId] =
                            vecInBounds && kInBounds ? vecs[vec][laneK] : 0;
                }

                __syncthreads();

                // Compare distances
#pragma unroll
                for (int i = 0; i < kLanes; ++i) {
                    threadDistance +=
                            __popc(queryTile[warpId][i] ^ vecTile[laneId][i]);
                }

                __syncthreads();
            }

            // Lanes within a warp are different vec results against the same
            // query Only submit distances which represent real (query, vec)
            // pairs
            bool valInBounds =
                    queryInBounds && (blockVec + laneId < vecs.getSize(0));
            threadDistance = valInBounds ? threadDistance : kMaxDistance;
            idx_t id = valInBounds ? blockVec + laneId : idx_t(-1);

            heap.add(threadDistance, id);
        }

        heap.reduce();

        if (warpQuery < query.getSize(0)) {
            heap.writeOut(outK[warpQuery].data(), outV[warpQuery].data(), k);
        }
    }
}

// Version of the kernel that avoids a loop over the reduction dimension, and
// thus avoids reloading the query vectors
template <
        int NumWarpQ,
        int NumThreadQ,
        typename BinaryType,
        int ReductionLimit = kLanes>
__global__ void __launch_bounds__(kWarps* kLanes) binaryDistanceLimitSize(
        const Tensor<BinaryType, 2, true> vecs,
        const Tensor<BinaryType, 2, true> query,
        Tensor<int, 2, true> outK,
        Tensor<idx_t, 2, true> outV,
        int k) {
    if constexpr ((NumWarpQ == 1 && NumThreadQ == 1) || NumWarpQ >= kWarpSize) {
        // A matrix tile (query, k)
        __shared__ BinaryType
                queryTile[kWarps][kLanes + 1]; // avoid bank conflict

        // B matrix tile (vec, k)
        __shared__ BinaryType
                vecTile[kLanes][kLanes + 1]; // avoid bank conflict

        WarpSelect<
                int,
                idx_t,
                false,
                Comparator<int>,
                NumWarpQ,
                NumThreadQ,
                kWarps * kLanes>
                heap(kMaxDistance, -1, k);

        int warpId = threadIdx.y;
        int laneId = threadIdx.x;

        // Each warp handles a single query
        int laneK = laneId;
        idx_t warpQuery = idx_t(blockIdx.x) * kWarps + warpId;
        bool kInBounds = laneK < vecs.getSize(1);
        bool queryInBounds = warpQuery < query.getSize(0);

        queryTile[warpId][laneId] =
                queryInBounds && kInBounds ? query[warpQuery][laneK] : 0;

        // Each warp loops through the entire chunk of vectors
        for (idx_t blockVec = 0; blockVec < vecs.getSize(0);
             blockVec += kLanes) {
            int threadDistance = 0;

            // kWarps warps are responsible for loading 32 vecs
#pragma unroll
            for (int i = 0; i < kLanes / kWarps; ++i) {
                int warpVec = i * kWarps + warpId;
                idx_t vec = blockVec + warpVec;
                bool vecInBounds = vec < vecs.getSize(0);

                vecTile[warpVec][laneId] =
                        vecInBounds && kInBounds ? vecs[vec][laneK] : 0;
            }

            __syncthreads();

            // Compare distances
#pragma unroll
            for (int i = 0; i < ReductionLimit; ++i) {
                threadDistance +=
                        __popc(queryTile[warpId][i] ^ vecTile[laneId][i]);
            }

            __syncthreads();

            // Lanes within a warp are different vec results against the same
            // query Only submit distances which represent real (query, vec)
            // pairs
            bool valInBounds =
                    queryInBounds && (blockVec + laneId < vecs.getSize(0));
            threadDistance = valInBounds ? threadDistance : kMaxDistance;
            idx_t id = valInBounds ? blockVec + laneId : idx_t(-1);

            heap.add(threadDistance, id);
        }

        heap.reduce();

        if (warpQuery < query.getSize(0)) {
            heap.writeOut(outK[warpQuery].data(), outV[warpQuery].data(), k);
        }
    }
}

template <typename BinaryType>
void runBinaryDistanceAnySize(
        Tensor<BinaryType, 2, true>& vecs,
        Tensor<BinaryType, 2, true>& query,
        Tensor<int, 2, true>& outK,
        Tensor<idx_t, 2, true>& outV,
        int k,
        cudaStream_t stream) {
    dim3 grid(utils::divUp(query.getSize(0), kWarps));
    dim3 block(getWarpSizeCurrentDevice(), kWarps);

    if (k == 1) {
        binaryDistanceAnySize<1, 1, BinaryType>
                <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k);
    } else if (k <= 32 && getWarpSizeCurrentDevice() == 32) {
        binaryDistanceAnySize<32, 2, BinaryType>
                <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k);
    } else if (k <= 64) {
        binaryDistanceAnySize<64, 3, BinaryType>
                <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k);
    } else if (k <= 128) {
        binaryDistanceAnySize<128, 3, BinaryType>
                <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k);
    } else if (k <= 256) {
        binaryDistanceAnySize<256, 4, BinaryType>
                <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k);
    } else if (k <= 512) {
        binaryDistanceAnySize<512, 8, BinaryType>
                <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k);
    } else if (k <= 1024) {
        binaryDistanceAnySize<1024, 8, BinaryType>
                <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k);
    }
#if GPU_MAX_SELECTION_K >= 2048
    else if (k <= 2048) {
        binaryDistanceAnySize<2048, 8, BinaryType>
                <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k);
    }
#endif
}

template <typename BinaryType, int ReductionLimit>
void runBinaryDistanceLimitSize(
        Tensor<BinaryType, 2, true>& vecs,
        Tensor<BinaryType, 2, true>& query,
        Tensor<int, 2, true>& outK,
        Tensor<idx_t, 2, true>& outV,
        int k,
        cudaStream_t stream) {
    dim3 grid(utils::divUp(query.getSize(0), kWarps));
    dim3 block(getWarpSizeCurrentDevice(), kWarps);

    if (k == 1) {
        binaryDistanceLimitSize<1, 1, BinaryType, ReductionLimit>
                <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k);
    } else if (k <= 32 && getWarpSizeCurrentDevice() == 32) {
        binaryDistanceLimitSize<32, 2, BinaryType, ReductionLimit>
                <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k);
    } else if (k <= 64) {
        binaryDistanceLimitSize<64, 3, BinaryType, ReductionLimit>
                <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k);
    } else if (k <= 128) {
        binaryDistanceLimitSize<128, 3, BinaryType, ReductionLimit>
                <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k);
    } else if (k <= 256) {
        binaryDistanceLimitSize<256, 4, BinaryType, ReductionLimit>
                <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k);
    } else if (k <= 512) {
        binaryDistanceLimitSize<512, 8, BinaryType, ReductionLimit>
                <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k);
    } else if (k <= 1024) {
        binaryDistanceLimitSize<1024, 8, BinaryType, ReductionLimit>
                <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k);
    }
#if GPU_MAX_SELECTION_K >= 2048
    else if (k <= 2048) {
        binaryDistanceLimitSize<2048, 8, BinaryType, ReductionLimit>
                <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k);
    }
#endif
}

void runBinaryDistance(
        Tensor<unsigned char, 2, true>& vecs,
        Tensor<unsigned char, 2, true>& query,
        Tensor<int, 2, true>& outK,
        Tensor<idx_t, 2, true>& outV,
        int k,
        cudaStream_t stream) {
    FAISS_ASSERT(k <= GPU_MAX_SELECTION_K);
    FAISS_ASSERT(vecs.getSize(1) == query.getSize(1));

    FAISS_ASSERT(outK.getSize(1) == k);
    FAISS_ASSERT(outV.getSize(1) == k);

    // For the optimized uint32 kernel, we handle 32 * 8 = 256 max dims
    constexpr int kReductionLimit32 = 8;

    // For the optimized uint8 kernel, we handle 8 * 16 = 128 max dims
    constexpr int kReductionLimit8 = 16;

    // All other cases (large or small) go through the general kernel

    if (vecs.getSize(1) % sizeof(unsigned int) == 0 &&
        (vecs.getSize(1) / sizeof(unsigned int)) <= kReductionLimit32) {
        auto vecs32 = vecs.castResize<unsigned int>();
        auto query32 = query.castResize<unsigned int>();

        // Optimize for vectors with dimensions a multiple of 32 that are less
        // than 32 * kReductionLimit (256) dimensions in size
        runBinaryDistanceLimitSize<unsigned int, kReductionLimit32>(
                vecs32, query32, outK, outV, k, stream);

    } else if (vecs.getSize(1) <= kReductionLimit8) {
        // Optimize for vectors with dimensions a multiple of 32 that are less
        // than 32 * kReductionLimit (256) dimensions in size
        runBinaryDistanceLimitSize<unsigned char, kReductionLimit8>(
                vecs, query, outK, outV, k, stream);
    } else {
        // Arbitrary size kernel
        runBinaryDistanceAnySize<unsigned char>(
                vecs, query, outK, outV, k, stream);
    }
}

} // namespace gpu
} // namespace faiss
