//
//

#include "FalconnLite.h"
#include "Utilities.h"
#include <queue>
#include <set>
#include <algorithm>
#include <iterator>
#include <fstream>

/**
 * We use two different random vectors R and S to simulate [2D]^2 number of random vectors
 *
 * Maintain the projection matrix MATRIX_FHT for parallel processing.
 * Store binary bits for FHWT transform, especially we have to use them for Dbscan-1NN
 * Using priority queue to extract top-k and top-m
 *
 * - For each point Xi, compute its dot product, extract top-k close/far random vectors
 * - For each random vector Ri, reuse dot product matrix, extract top-k close/far points
 *
 */
void FalconnLite::fht_Index2()
{
    // Compute the data center
    VectorXf vecCenter = VectorXf::Zero(FalconnLite::n_features);
    if (FalconnLite::centering) {
        for (int n = 0; n < FalconnLite::n_points; n++)
            vecCenter += FalconnLite::matrix_X.row(n);
        vecCenter /= FalconnLite::n_points;
    }

    /** Param for embedding L1 and L2 **/
    int iFourierEmbed_D = FalconnLite::ker_n_features / 2; // This is becase we need cos() and sin()

    // See: https://github.com/hichamjanati/srf/blob/master/RFF-I.ipynb
    if (FalconnLite::distance == "L1")
        FalconnLite::matrix_R = cauchyGenerator(iFourierEmbed_D, FalconnLite::n_features, 0, 1.0 / FalconnLite::ker_sigma, FalconnLite::seed); // K(x, y) = exp(-gamma * L1_dist(X, y))) where gamma = 1/sigma
    else if (FalconnLite::distance == "L2")
        FalconnLite::matrix_R = gaussGenerator(iFourierEmbed_D, FalconnLite::n_features, 0, 1.0 / FalconnLite::ker_sigma, FalconnLite::seed); // std = 1/sigma, K(x, y) = exp(-gamma * L2_dist^2(X, y))) where gamma = 1/2 sigma^2

    /** This vector contains (2D)^2 minQue, one queue for each random vector **/
    // Each minQue is a bucket, and each point is hashed into K buckets
    // However, we cannot check all combination (2D)^2, so we greedily check K^2 buckets
    int numBuckets = 4 * FalconnLite::n_proj * FalconnLite::n_proj;
    int num2D = 2 * FalconnLite::n_proj;
    vector<priority_queue< IFPair, vector<IFPair>, greater<> >> vectorMinQue_TopM(numBuckets);

    // Note: Since we do not want to consider (2D)^2 choices of pairs of random vectors as it costs n(2D)^2 (i.e. CEOs idea)
    // Particularly, for each random vector, we have to keep O(n) projection values so that we can have a good estimate
    // and then aggregate to form the top-M points for each pairs of random vectors.
    // Note: We use the heuristic that considers (2D)^2 as number of buckets, and only consider points hashed into this bucket (i.e. Falconn++ idea)
    // We hash a point into (top-P)^2 buckets, and then we only consider the top-M points in each bucket

    int log2Project = log2(FalconnLite::fhtDim);
    bitHD3Generator2(FalconnLite::fhtDim * FalconnLite::n_rotate, FalconnLite::seed, FalconnLite::bitHD1, FalconnLite::bitHD2);

    /** Param for index **/
    FalconnLite::matrix_qProbes = MatrixXi::Zero(FalconnLite::qProbe, FalconnLite::n_points); // the first topK is for close, the second topK is for far away
    FalconnLite::vec2D_Buckets = vector<IVector> (numBuckets);


    // Note: If NUM_LOCKS is large, we might not have enough stack memory if using array
    // if D = 128 = 2^7, then numBuckets = 2^16 = 65536. We aim at 256 KB memory for locks
    // 16K locks is good for million-point data set though it is not good for small data sets.
    constexpr size_t NUM_LOCKS = 16384;
    vector<omp_lock_t> locks(NUM_LOCKS); // NUM_LOCK = 16K locks = only 256 KB
    // Initialize locks
#pragma omp parallel for
    for (size_t i = 0; i < NUM_LOCKS; i++) {
        omp_init_lock(&locks[i]);
    }

    /**
    Parallel for each the point Xi: (1) Compute and store dot product, and (2) Extract top-k close/far random vectors
    **/
#pragma omp parallel for
    for (int n = 0; n < FalconnLite::n_points; ++n) {
        /**
        Random embedding
        **/
        VectorXf vecX = FalconnLite::matrix_X.row(n);
        if (FalconnLite::centering)
            vecX -= vecCenter;

        VectorXf vecEmbed = VectorXf::Zero(FalconnLite::ker_n_features); // sOptics::ker_n_features >= D

        /// must ensure ker_n_features = n_features on Cosine
        if (FalconnLite::distance == "Cosine")
            vecEmbed.segment(0, FalconnLite::n_features) = vecX;
        else if ((FalconnLite::distance == "L1") || (FalconnLite::distance == "L2"))
        {
            VectorXf vecProject = FalconnLite::matrix_R * vecX;
            vecEmbed.segment(0, iFourierEmbed_D) = vecProject.array().cos();
            vecEmbed.segment(iFourierEmbed_D, iFourierEmbed_D) = vecProject.array().sin(); // start from iEmbbed, copy iEmbed elements
        }
        else if (FalconnLite::distance == "Chi2")
            embedChi2(vecX, vecEmbed, FalconnLite::ker_n_features, FalconnLite::n_features, FalconnLite::ker_intervalSampling);
        else if (FalconnLite::distance == "JS")
            embedJS(vecX, vecEmbed, FalconnLite::ker_n_features, FalconnLite::n_features, FalconnLite::ker_intervalSampling);

        /**
        Random projection
        **/

        VectorXf rotatedX1 = VectorXf::Zero(FalconnLite::fhtDim); // NUM_PROJECT > PARAM_KERNEL_EMBED_D
        rotatedX1.segment(0, FalconnLite::ker_n_features) = vecEmbed;

        VectorXf rotatedX2 = rotatedX1;

        for (int r = 0; r < FalconnLite::n_rotate; ++r)
        {
            // Component-wise multiplication with a random sign
            for (int d = 0; d < FalconnLite::fhtDim; ++d)
            {
                rotatedX1(d) *= (2 * static_cast<float>(FalconnLite::bitHD1[r * FalconnLite::fhtDim + d]) - 1);
                rotatedX2(d) *= (2 * static_cast<float>(FalconnLite::bitHD2[r * FalconnLite::fhtDim + d]) - 1);
            }

            // Multiple with Hadamard matrix by calling FWHT transform
            fht_float(rotatedX1.data(), log2Project);
            fht_float(rotatedX2.data(), log2Project);
        }

        // rotateX1 = {x1 * r1, ..., x1 * r_D}
        // rotateX2 = {x1 * s1, ..., x1 * s_D}

        // cout << "We finish random rotating" << endl;

        // This queue is used for finding top-k max hash values and hash index for iProbes on each layer
        priority_queue< IFPair, vector<IFPair>, greater<> > minQueTopQ1, minQueTopI1; // 1st layer
        priority_queue< IFPair, vector<IFPair>, greater<> > minQueTopQ2, minQueTopI2; // 2nd layer

        /**
        We use a priority queue to keep top-max abs projection for each repeat
        Always ensure fhtDim >= n_proj
        **/
        for (int r = 0; r < FalconnLite::n_proj; ++r)
        {
            // 1st rotation
            int iSign = sgn(rotatedX1(r));  //投影符号 +1 or -1
            float fAbsHashValue = iSign * rotatedX1(r); // ｜rotatedX1(r)｜

            int Ri_2D = r; // index of random vector in [2D] after consider the sign
            if (iSign < 0)
                // iBucketIndex |= 1UL << log2Project; // set bit at position log2(D)
                    Ri_2D += FalconnLite::n_proj; // Be aware the case that n_proj is not 2^(log2Proj)

            // qProbe
            if ((int)minQueTopQ1.size() < FalconnLite::qProbe)
                minQueTopQ1.emplace(Ri_2D, fAbsHashValue); // emplace is push without creating temp data
            else if (fAbsHashValue > minQueTopQ1.top().m_fValue)
            {
                minQueTopQ1.pop();
                minQueTopQ1.emplace(Ri_2D, fAbsHashValue); // No need IFPair()
            }

            // iProbe
            if ((int)minQueTopI1.size() < FalconnLite::iProbe)
                minQueTopI1.emplace(Ri_2D, fAbsHashValue); // emplace is push without creating temp data
            else if (fAbsHashValue > minQueTopI1.top().m_fValue)
            {
                minQueTopI1.pop();
                minQueTopI1.emplace(Ri_2D, fAbsHashValue); // No need IFPair()
            }

            // 2nd rotation
            iSign = sgn(rotatedX2(r));
            fAbsHashValue = iSign * rotatedX2(r);

            Ri_2D = r;
            if (iSign < 0)
                // iBucketIndex |= 1UL << log2Project; // set bit at position log2(D)
                    Ri_2D += FalconnLite::n_proj; // set bit at position log2(D)

            // qProbe
            if ((int)minQueTopQ2.size() < FalconnLite::qProbe)
                minQueTopQ2.emplace(Ri_2D, fAbsHashValue);
            else if (fAbsHashValue > minQueTopQ2.top().m_fValue)
            {
                minQueTopQ2.pop();
                minQueTopQ2.emplace(Ri_2D, fAbsHashValue);
            }

            // iProbe (top-iProbe random vector closest to Xn)
            if ((int)minQueTopI2.size() < FalconnLite::iProbe)
                minQueTopI2.emplace(Ri_2D, fAbsHashValue);
            else if (fAbsHashValue > minQueTopI2.top().m_fValue)
            {
                minQueTopI2.pop();
                minQueTopI2.emplace(Ri_2D, fAbsHashValue);
            }
        }

        // Convert to vector
        vector<IFPair> vec_topQ1(FalconnLite::qProbe), vec_topQ2(FalconnLite::qProbe);
        vector<IFPair> vec_topI1(FalconnLite::iProbe), vec_topI2(FalconnLite::iProbe);

        // qProbe
        for (int k = FalconnLite::qProbe - 1; k >= 0; --k)
        {
            vec_topQ1[k] = minQueTopQ1.top();
            minQueTopQ1.pop();

            vec_topQ2[k] = minQueTopQ2.top();
            minQueTopQ2.pop();
        }

        // iProbe-Falconn++
        for (int p = FalconnLite::iProbe - 1; p >= 0; --p)
        {
            vec_topI1[p] = minQueTopI1.top();
            minQueTopI1.pop();

            vec_topI2[p] = minQueTopI2.top();
            minQueTopI2.pop();
        }

        assert(vec_topQ1.size() == FalconnLite::qProbe);
        assert(vec_topQ2.size() == FalconnLite::qProbe);
        assert(vec_topI1.size() == FalconnLite::iProbe);
        assert(vec_topI2.size() == FalconnLite::iProbe);

        // cout << "We finish extracting top-qProbe and top-iProbe for 2 layers." << endl;
        // count << "Now we combine them together." << endl;

        /**
        Use minQue to find the top-qProbe over 2 layers via sum of 2 estimators
        vec1 and vec2 are already sorted, and has length of sOptics::topK
        Note: Heuristic: We consider top-k * top-k pairs for Top-K, and top-p * top-p pairs for Top-M
        Note: We cannot check all combinations due to significant cost
        **/
        priority_queue<IFPair, vector<IFPair>, greater<>> minQueTopQ;

        // qProbe
        for (const auto& ifPair1: vec_topQ1)
        {
            int Ri_2D_1st = ifPair1.m_iIndex;
            float fAbsHashValue1 = ifPair1.m_fValue;

            for (const auto& ifPair2: vec_topQ2)
            {
                int R2_2D_2nd = ifPair2.m_iIndex;
                float fAbsSumHash = ifPair2.m_fValue + fAbsHashValue1; // sum of 2 estimators

                //We have 2D * 2D buckets (i.e. random vectors)
                int iBucketIndex = Ri_2D_1st * num2D + R2_2D_2nd; // (totally we have 2D * 2D buckets)

                // assert(iBucketIndex < vectorMinQue_TopM.size());

                // Push all points into the bucket
                if ((int)minQueTopQ.size() < FalconnLite::qProbe)
                    minQueTopQ.emplace(iBucketIndex, fAbsSumHash);
                else if (fAbsSumHash > minQueTopQ.top().m_fValue)
                {
                    minQueTopQ.pop();
                    minQueTopQ.emplace(iBucketIndex, fAbsSumHash);
                }
            }
        }

        /** Extract the random vector idx (in the form r1 * 2D + r2) for the point idx n **/
        int k = FalconnLite::qProbe - 1;
        // MinQue has the size TopK
        while (!minQueTopQ.empty())
        {
            IFPair ifPair = minQueTopQ.top(); // index is bucketID, value is sumAbsHash
            minQueTopQ.pop();
            FalconnLite::matrix_qProbes(k, n) = ifPair.m_iIndex;
            k--;
        }



        // iProbe for Falconn++
        for (const auto& ifPair1: vec_topI1)
        {
            int Ri_2D_1st = ifPair1.m_iIndex;
            float fAbsHashValue1 = ifPair1.m_fValue;

            for (const auto& ifPair2: vec_topI2)
            {
                int R2_2D_2nd = ifPair2.m_iIndex;
                float fAbsSumHash = ifPair2.m_fValue + fAbsHashValue1; // sum of 2 estimators

                //We have 2D * 2D buckets (i.e. random vectors)
                int iBucketIndex = Ri_2D_1st * num2D + R2_2D_2nd; // (totally we have 2D * 2D buckets)

                // assert(iBucketIndex < vectorMinQue_TopM.size());

                omp_set_lock(&locks[iBucketIndex % NUM_LOCKS]);

                // Note: This implementation is used for controlling the size of dense bucket.

                if ((int)vectorMinQue_TopM[iBucketIndex].size() < FalconnLite::top_m)
                    vectorMinQue_TopM[iBucketIndex].emplace(n, fAbsSumHash);

                else if (fAbsSumHash > vectorMinQue_TopM[iBucketIndex].top().m_fValue)
                {
                    vectorMinQue_TopM[iBucketIndex].pop();
                    vectorMinQue_TopM[iBucketIndex].emplace(n, fAbsSumHash);
                }


                // Push all points into the bucket
                // FalconnLite::vec2D_Buckets[iBucketIndex].push_back(n);

                omp_unset_lock(&locks[iBucketIndex % NUM_LOCKS]);
            }
        }
    }

    // Destroy locks for Falconn++
#pragma omp parallel for
    for (size_t i = 0; i < NUM_LOCKS; ++i) {
        omp_destroy_lock(&locks[i]);
    }

    // Extract top-M for each bucketIdx - Falconn++
#pragma omp parallel for
    for (int b = 0; b < numBuckets; ++b)
    {
        int k = (int)vectorMinQue_TopM[b].size();
        FalconnLite::vec2D_Buckets[b] = IVector(k);
        while (!vectorMinQue_TopM[b].empty())
        {
            FalconnLite::vec2D_Buckets[b][k - 1] = vectorMinQue_TopM[b].top().m_iIndex;
            vectorMinQue_TopM[b].pop();
            k--;
        }
    }

    // int iTotalPoints = 0;
    // for (int i = 0; i < numBuckets; ++i)
    //     iTotalPoints += FalconnLite::vec2D_Buckets[i].size();
    // cout << "Total points in all buckets: " << iTotalPoints << endl;
}

/**
 * We use two different random vectors R and S to simulate [2D]^2 number of random vectors
 *
 * Maintain the projection matrix MATRIX_FHT for parallel processing.
 * Store binary bits for FHWT transform, especially we have to use them for Dbscan-1NN
 * Using priority queue to extract top-k and top-m
 *
 * - For each point Xi, compute its dot product, extract top-k close/far random vectors
 * - For each random vector Ri, reuse dot product matrix, extract top-k close/far points
 *
 * - If using centering, we center all points and the result will be for the new data set X - vecCenter
 */
void FalconnLite::fht_Index2_repeat(int n_repeats)
{
    // Compute the data center
    VectorXf vecCenter = VectorXf::Zero(FalconnLite::n_features);
    if (FalconnLite::centering) {
        for (int n = 0; n < FalconnLite::n_points; n++)
            vecCenter += FalconnLite::matrix_X.row(n);
        vecCenter /= FalconnLite::n_points;
    }


    /** Global parameter **/
    int iFourierEmbed_D = FalconnLite::ker_n_features / 2; // This is becase we need cos() and sin()
    int numBucketsPerRepeat = 4 * FalconnLite::n_proj * FalconnLite::n_proj;
    int num2D = 2 * FalconnLite::n_proj;
    int log2Project = log2(FalconnLite::fhtDim);

    FalconnLite::matrix_qProbes = MatrixXi::Zero(FalconnLite::qProbe * n_repeats, FalconnLite::n_points); // the first topK is for close, the second topK is for far away
    FalconnLite::vec2D_Buckets = vector<IVector> (numBucketsPerRepeat * n_repeats);

    // Note: If NUM_LOCKS is large, we might not have enough stack memory if using array
    // if D = 128 = 2^7, then numBuckets = 2^16 = 65536. We aim at 256 KB memory for locks
    // 16K locks is good for million-point data set though it is not good for small data sets.
    constexpr size_t NUM_LOCKS = 16384;
    vector<omp_lock_t> locks(NUM_LOCKS); // NUM_LOCK = 16K locks = only 256 KB

    // Initialize locks
#pragma omp parallel for
    for (size_t i = 0; i < NUM_LOCKS; i++) {
        omp_init_lock(&locks[i]);
    }

    // NOTE: n_repeats should not be large, e.g., 5 or 10
    // For each repeat, we have different random vectors
    for (int repeat = 0; repeat < n_repeats; ++repeat)
    {
        // See: https://github.com/hichamjanati/srf/blob/master/RFF-I.ipynb
        if (FalconnLite::distance == "L1")
            FalconnLite::matrix_R = cauchyGenerator(iFourierEmbed_D, FalconnLite::n_features, 0, 1.0 / FalconnLite::ker_sigma, FalconnLite::seed); // K(x, y) = exp(-gamma * L1_dist(X, y))) where gamma = 1/sigma
        else if (FalconnLite::distance == "L2")
            FalconnLite::matrix_R = gaussGenerator(iFourierEmbed_D, FalconnLite::n_features, 0, 1.0 / FalconnLite::ker_sigma, FalconnLite::seed); // std = 1/sigma, K(x, y) = exp(-gamma * L2_dist^2(X, y))) where gamma = 1/2 sigma^2

        vector<priority_queue< IFPair, vector<IFPair>, greater<> >> vectorMinQue_TopM(numBucketsPerRepeat);
        bitHD3Generator2(FalconnLite::fhtDim * FalconnLite::n_rotate, FalconnLite::seed, FalconnLite::bitHD1, FalconnLite::bitHD2);

        /**
        Parallel for each the point Xi: (1) Compute and store dot product, and (2) Extract top-k close/far random vectors
        **/
#pragma omp parallel for
        for (int n = 0; n < FalconnLite::n_points; ++n)
        {
            /**
            Random embedding
            **/
            VectorXf vecX = FalconnLite::matrix_X.row(n);
            if (FalconnLite::centering)
                vecX -= vecCenter;

            VectorXf vecEmbed = VectorXf::Zero(FalconnLite::ker_n_features); // sOptics::ker_n_features >= D

            /// must ensure ker_n_features = n_features on Cosine
            if (FalconnLite::distance == "Cosine")
                vecEmbed.segment(0, FalconnLite::n_features) = vecX;
            else if ((FalconnLite::distance == "L1") || (FalconnLite::distance == "L2"))
            {
                VectorXf vecProject = FalconnLite::matrix_R * vecX;
                vecEmbed.segment(0, iFourierEmbed_D) = vecProject.array().cos();
                vecEmbed.segment(iFourierEmbed_D, iFourierEmbed_D) = vecProject.array().sin(); // start from iEmbbed, copy iEmbed elements
            }
            else if (FalconnLite::distance == "Chi2")
                embedChi2(vecX, vecEmbed, FalconnLite::ker_n_features, FalconnLite::n_features, FalconnLite::ker_intervalSampling);
            else if (FalconnLite::distance == "JS")
                embedJS(vecX, vecEmbed, FalconnLite::ker_n_features, FalconnLite::n_features, FalconnLite::ker_intervalSampling);

            /**
            Random projection
            **/

            VectorXf rotatedX1 = VectorXf::Zero(FalconnLite::fhtDim); // NUM_PROJECT > PARAM_KERNEL_EMBED_D
            rotatedX1.segment(0, FalconnLite::ker_n_features) = vecEmbed;

            VectorXf rotatedX2 = rotatedX1;

            for (int r = 0; r < FalconnLite::n_rotate; ++r)
            {
                // Component-wise multiplication with a random sign
                for (int d = 0; d < FalconnLite::fhtDim; ++d)
                {
                    rotatedX1(d) *= (2 * static_cast<float>(FalconnLite::bitHD1[r * FalconnLite::fhtDim + d]) - 1);
                    rotatedX2(d) *= (2 * static_cast<float>(FalconnLite::bitHD2[r * FalconnLite::fhtDim + d]) - 1);
                }

                // Multiple with Hadamard matrix by calling FWHT transform
                fht_float(rotatedX1.data(), log2Project);
                fht_float(rotatedX2.data(), log2Project);
            }

            // rotateX1 = {x1 * r1, ..., x1 * r_D}
            // rotateX2 = {x1 * s1, ..., x1 * s_D}

            // cout << "We finish random rotating" << endl;

            // This queue is used for finding top-k max hash values and hash index for iProbes on each layer
            priority_queue< IFPair, vector<IFPair>, greater<> > minQueTopQ1, minQueTopI1; // 1st layer
            priority_queue< IFPair, vector<IFPair>, greater<> > minQueTopQ2, minQueTopI2; // 2nd layer

            /**
            We use a priority queue to keep top-max abs projection for each repeat
            Always ensure fhtDim >= n_proj
            **/
            for (int r = 0; r < FalconnLite::n_proj; ++r)
            {
                // 1st rotation
                int iSign = sgn(rotatedX1(r));
                float fAbsHashValue = iSign * rotatedX1(r);

                int Ri_2D = r; // index of random vector in [2D] after consider the sign
                if (iSign < 0)
                    // iBucketIndex |= 1UL << log2Project; // set bit at position log2(D)
                        Ri_2D += FalconnLite::n_proj; // Be aware the case that n_proj is not 2^(log2Proj)

                // qProbe
                if ((int)minQueTopQ1.size() < FalconnLite::qProbe)
                    minQueTopQ1.emplace(Ri_2D, fAbsHashValue); // emplace is push without creating temp data
                else if (fAbsHashValue > minQueTopQ1.top().m_fValue)
                {
                    minQueTopQ1.pop();
                    minQueTopQ1.emplace(Ri_2D, fAbsHashValue); // No need IFPair()
                }

                // iProbe
                if ((int)minQueTopI1.size() < FalconnLite::iProbe)
                    minQueTopI1.emplace(Ri_2D, fAbsHashValue); // emplace is push without creating temp data
                else if (fAbsHashValue > minQueTopI1.top().m_fValue)
                {
                    minQueTopI1.pop();
                    minQueTopI1.emplace(Ri_2D, fAbsHashValue); // No need IFPair()
                }

                // 2nd rotation
                iSign = sgn(rotatedX2(r));
                fAbsHashValue = iSign * rotatedX2(r);

                Ri_2D = r;
                if (iSign < 0)
                    // iBucketIndex |= 1UL << log2Project; // set bit at position log2(D)
                        Ri_2D += FalconnLite::n_proj; // set bit at position log2(D)

                // qProbe
                if ((int)minQueTopQ2.size() < FalconnLite::qProbe)
                    minQueTopQ2.emplace(Ri_2D, fAbsHashValue);
                else if (fAbsHashValue > minQueTopQ2.top().m_fValue)
                {
                    minQueTopQ2.pop();
                    minQueTopQ2.emplace(Ri_2D, fAbsHashValue);
                }

                // iProbe (top-iProbe random vector closest to Xn)
                if ((int)minQueTopI2.size() < FalconnLite::iProbe)
                    minQueTopI2.emplace(Ri_2D, fAbsHashValue);
                else if (fAbsHashValue > minQueTopI2.top().m_fValue)
                {
                    minQueTopI2.pop();
                    minQueTopI2.emplace(Ri_2D, fAbsHashValue);
                }
            }

            // Convert to vector
            vector<IFPair> vec_topQ1(FalconnLite::qProbe), vec_topQ2(FalconnLite::qProbe);
            vector<IFPair> vec_topI1(FalconnLite::iProbe), vec_topI2(FalconnLite::iProbe);

            // qProbe
            for (int k = FalconnLite::qProbe - 1; k >= 0; --k)
            {
                vec_topQ1[k] = minQueTopQ1.top();
                minQueTopQ1.pop();

                vec_topQ2[k] = minQueTopQ2.top();
                minQueTopQ2.pop();
            }

            // iProbe-Falconn++
            for (int p = FalconnLite::iProbe - 1; p >= 0; --p)
            {
                vec_topI1[p] = minQueTopI1.top();
                minQueTopI1.pop();

                vec_topI2[p] = minQueTopI2.top();
                minQueTopI2.pop();
            }

            assert(vec_topQ1.size() == FalconnLite::qProbe);
            assert(vec_topQ2.size() == FalconnLite::qProbe);
            assert(vec_topI1.size() == FalconnLite::iProbe);
            assert(vec_topI2.size() == FalconnLite::iProbe);

            // cout << "We finish extracting top-qProbe and top-iProbe for 2 layers." << endl;
            // count << "Now we combine them together." << endl;

            /**
            Use minQue to find the top-qProbe over 2 layers via sum of 2 estimators
            vec1 and vec2 are already sorted, and has length of sOptics::topK
            Note: Heuristic: We consider top-k * top-k pairs for Top-K, and top-p * top-p pairs for Top-M
            Note: We cannot check all combinations due to significant cost
            **/
            priority_queue<IFPair, vector<IFPair>, greater<>> minQueTopQ;

            // qProbe
            for (const auto& ifPair1: vec_topQ1)
            {
                int Ri_2D_1st = ifPair1.m_iIndex;
                float fAbsHashValue1 = ifPair1.m_fValue;

                for (const auto& ifPair2: vec_topQ2)
                {
                    int R2_2D_2nd = ifPair2.m_iIndex;
                    float fAbsSumHash = ifPair2.m_fValue + fAbsHashValue1; // sum of 2 estimators

                    //We have 2D * 2D buckets (i.e. random vectors)
                    int iBucketIndex = Ri_2D_1st * num2D + R2_2D_2nd; // (totally we have 2D * 2D buckets)

                    // assert(iBucketIndex < vectorMinQue_TopM.size());

                    // Push all points into the bucket
                    if ((int)minQueTopQ.size() < FalconnLite::qProbe)
                        minQueTopQ.emplace(iBucketIndex, fAbsSumHash);
                    else if (fAbsSumHash > minQueTopQ.top().m_fValue)
                    {
                        minQueTopQ.pop();
                        minQueTopQ.emplace(iBucketIndex, fAbsSumHash);
                    }
                }
            }

            /** Extract the random vector idx (in the form r1 * 2D + r2) for the point idx n **/
            int k = FalconnLite::qProbe - 1;
            // MinQue has the size TopK
            while (!minQueTopQ.empty())
            {
                IFPair ifPair = minQueTopQ.top(); // index is bucketID, value is sumAbsHash
                minQueTopQ.pop();

                // Be aware of the index shift for different repeat
                // m_iIndex is bucketIdx in [0, numBucketsPerTable)
                FalconnLite::matrix_qProbes(k + repeat * FalconnLite::qProbe, n) = ifPair.m_iIndex + repeat * numBucketsPerRepeat;
                k--;
            }



            // iProbe for Falconn++
            for (const auto& ifPair1: vec_topI1)
            {
                int Ri_2D_1st = ifPair1.m_iIndex;
                float fAbsHashValue1 = ifPair1.m_fValue;

                for (const auto& ifPair2: vec_topI2)
                {
                    int R2_2D_2nd = ifPair2.m_iIndex;
                    float fAbsSumHash = ifPair2.m_fValue + fAbsHashValue1; // sum of 2 estimators

                    //We have 2D * 2D buckets (i.e. random vectors)
                    int iBucketIndex = Ri_2D_1st * num2D + R2_2D_2nd; // (totally we have 2D * 2D buckets)

                    // assert(iBucketIndex < vectorMinQue_TopM.size());

                    omp_set_lock(&locks[iBucketIndex % NUM_LOCKS]);

                    // Note: This implementation is used for controlling the size of dense bucket.

                    if ((int)vectorMinQue_TopM[iBucketIndex].size() < FalconnLite::top_m)
                        vectorMinQue_TopM[iBucketIndex].emplace(n, fAbsSumHash);

                    else if (fAbsSumHash > vectorMinQue_TopM[iBucketIndex].top().m_fValue)
                    {
                        vectorMinQue_TopM[iBucketIndex].pop();
                        vectorMinQue_TopM[iBucketIndex].emplace(n, fAbsSumHash);
                    }

                    // Push all points into the bucket
                    // FalconnLite::vec2D_Buckets[iBucketIndex].push_back(n);

                    omp_unset_lock(&locks[iBucketIndex % NUM_LOCKS]);
                }
            }
        }

        // Extract top-M for each bucketIdx - Falconn++
#pragma omp parallel for
        for (int b = 0; b < numBucketsPerRepeat; ++b)
        {
            // bucket-idx shift for different repeat
            int new_bucketIdx = b + repeat * numBucketsPerRepeat;
            int k = (int)vectorMinQue_TopM[b].size();

            FalconnLite::vec2D_Buckets[new_bucketIdx] = IVector(k);

            while (!vectorMinQue_TopM[b].empty())
            {
                // Be aware of the index shift for different repeat
                FalconnLite::vec2D_Buckets[new_bucketIdx][k-1] = vectorMinQue_TopM[b].top().m_iIndex; // add into the end of vector
                vectorMinQue_TopM[b].pop();
                k--;
            }
        }

        // int iTotalPoints = 0;
        // for (int i = 0; i < FalconnLite::vec2D_Buckets.size(); ++i)
        //     iTotalPoints += FalconnLite::vec2D_Buckets[i].size();
        // cout << "Total points in all buckets: " << iTotalPoints << endl;
    }

    // Destroy locks for Falconn++
#pragma omp parallel for
    for (size_t i = 0; i < NUM_LOCKS; ++i) {
        omp_destroy_lock(&locks[i]);
    }

    // Note: Should not clear when running testcase
    // FalconnLite::matrix_X.resize(0, 0);
    // FalconnLite::matrix_R.resize(0, 0);


    // print out the vec2D
    // {
    //     std::string fname = "vec2D_buckets_q" + std::to_string(FalconnLite::qProbe) + ".txt";
    //     std::ofstream ofs(fname);
    //     for (size_t b = 0; b < FalconnLite::vec2D_Buckets.size(); ++b) {
    //         ofs << b << " " << FalconnLite::vec2D_Buckets[b].size();
    //         for (int t = 0; t < std::min((int)FalconnLite::vec2D_Buckets[b].size(), 5); ++t)
    //             ofs << " " << FalconnLite::vec2D_Buckets[b][t];
    //         ofs << "\n";
    //     }
    // }
    // print out each repeat bucket
    for (int repeat = 0; repeat < n_repeats; ++repeat)
    {
        // ... 原来的桶生成代码 ...

        // 输出当前 repeat 的桶内容
        std::string fname = "vec2D_buckets_repeat_" + std::to_string(repeat) + ".txt";
        std::ofstream ofs(fname);
        if (!ofs.is_open()) {
            std::cerr << "Failed to open file " << fname << std::endl;
            continue;
        }

        int numBuckets = numBucketsPerRepeat; // 每个 repeat 的桶数
        for (int b = 0; b < numBuckets; ++b) {
            int bucketIdx = b + repeat * numBucketsPerRepeat;
            ofs << bucketIdx << " " << FalconnLite::vec2D_Buckets[bucketIdx].size();
            for (int t = 0; t < (int)FalconnLite::vec2D_Buckets[bucketIdx].size(); ++t)
                ofs << " " << FalconnLite::vec2D_Buckets[bucketIdx][t];
            ofs << "\n";
        }
        ofs.close();
    }

}

void FalconnLite::fht_Index2_asym_centering_repeat(int n_repeats)
{
    // Compute the data center
    VectorXf vecCenter = VectorXf::Zero(FalconnLite::n_features);
    for (int n = 0; n < FalconnLite::n_points; n++)
        vecCenter += FalconnLite::matrix_X.row(n);
    vecCenter /= FalconnLite::n_points;


    /** Global parameter **/

    int iFourierEmbed_D = FalconnLite::ker_n_features / 2; // This is becase we need cos() and sin()
    int numBucketsPerRepeat = 4 * FalconnLite::n_proj * FalconnLite::n_proj;
    int num2D = 2 * FalconnLite::n_proj;
    int log2Project = log2(FalconnLite::fhtDim);

    FalconnLite::matrix_qProbes = MatrixXi::Zero(FalconnLite::qProbe * n_repeats, FalconnLite::n_points); // the first topK is for close, the second topK is for far away
    FalconnLite::vec2D_Buckets = vector<IVector> (numBucketsPerRepeat * n_repeats);

    // Note: If NUM_LOCKS is large, we might not have enough stack memory if using array
    // if D = 128 = 2^7, then numBuckets = 2^16 = 65536. We aim at 256 KB memory for locks
    // 16K locks is good for million-point data set though it is not good for small data sets.
    constexpr size_t NUM_LOCKS = 16384;
    vector<omp_lock_t> locks(NUM_LOCKS); // NUM_LOCK = 16K locks = only 256 KB

    // Initialize locks
#pragma omp parallel for
    for (size_t i = 0; i < NUM_LOCKS; i++) {
        omp_init_lock(&locks[i]);
    }

    // NOTE: n_repeats should not be large, e.g., 5 or 10
    // For each repeat, we have different random vectors
    for (int repeat = 0; repeat < n_repeats; ++repeat)
    {
        // See: https://github.com/hichamjanati/srf/blob/master/RFF-I.ipynb
        if (FalconnLite::distance == "L1")
            FalconnLite::matrix_R = cauchyGenerator(iFourierEmbed_D, FalconnLite::n_features, 0, 1.0 / FalconnLite::ker_sigma, FalconnLite::seed); // K(x, y) = exp(-gamma * L1_dist(X, y))) where gamma = 1/sigma
        else if (FalconnLite::distance == "L2")
            FalconnLite::matrix_R = gaussGenerator(iFourierEmbed_D, FalconnLite::n_features, 0, 1.0 / FalconnLite::ker_sigma, FalconnLite::seed); // std = 1/sigma, K(x, y) = exp(-gamma * L2_dist^2(X, y))) where gamma = 1/2 sigma^2

        vector<priority_queue< IFPair, vector<IFPair>, greater<> >> vectorMinQue_TopM(numBucketsPerRepeat);
        bitHD3Generator2(FalconnLite::fhtDim * FalconnLite::n_rotate, FalconnLite::seed, FalconnLite::bitHD1, FalconnLite::bitHD2);

        /**
        Parallel for each the point Xi: (1) Compute and store dot product, and (2) Extract top-k close/far random vectors
        **/
#pragma omp parallel for
        for (int n = 0; n < FalconnLite::n_points; ++n)
        {
            /**
            Random embedding
            **/
            VectorXf vecX = FalconnLite::matrix_X.row(n); // centering if needed
            VectorXf vecEmbed = VectorXf::Zero(FalconnLite::ker_n_features); // sOptics::ker_n_features >= D

            VectorXf vecX_centered = FalconnLite::matrix_X.row(n) - vecCenter; // centering if needed
            VectorXf vecEmbed_centered = VectorXf::Zero(FalconnLite::ker_n_features); // sOptics::ker_n_features >= D

            /// must ensure ker_n_features = n_features on Cosine
            if (FalconnLite::distance == "Cosine") {
                vecEmbed.segment(0, FalconnLite::n_features) = vecX;
                vecEmbed_centered.segment(0, FalconnLite::n_features) = vecX_centered;
            }
            else if ((FalconnLite::distance == "L1") || (FalconnLite::distance == "L2"))
            {
                VectorXf vecProject = FalconnLite::matrix_R * vecX;
                vecEmbed.segment(0, iFourierEmbed_D) = vecProject.array().cos();
                vecEmbed.segment(iFourierEmbed_D, iFourierEmbed_D) = vecProject.array().sin(); // start from iEmbbed, copy iEmbed elements

                vecProject = FalconnLite::matrix_R * vecX_centered;
                vecEmbed_centered.segment(0, iFourierEmbed_D) = vecProject.array().cos();
                vecEmbed_centered.segment(iFourierEmbed_D, iFourierEmbed_D) = vecProject.array().sin(); // start from iEmbbed, copy iEmbed elements
            }

            else if (FalconnLite::distance == "Chi2") {
                embedChi2(vecX, vecEmbed, FalconnLite::ker_n_features, FalconnLite::n_features, FalconnLite::ker_intervalSampling);
                embedChi2(vecX_centered, vecEmbed_centered, FalconnLite::ker_n_features, FalconnLite::n_features, FalconnLite::ker_intervalSampling);
            }

            else if (FalconnLite::distance == "JS") {
                embedJS(vecX, vecEmbed, FalconnLite::ker_n_features, FalconnLite::n_features, FalconnLite::ker_intervalSampling);
                embedJS(vecX_centered, vecEmbed_centered, FalconnLite::ker_n_features, FalconnLite::n_features, FalconnLite::ker_intervalSampling);
            }


            /**
            Random projection
            **/

            VectorXf rotatedX1 = VectorXf::Zero(FalconnLite::fhtDim); // NUM_PROJECT > PARAM_KERNEL_EMBED_D
            VectorXf rotatedX1_centered = VectorXf::Zero(FalconnLite::fhtDim); // NUM_PROJECT > PARAM_KERNEL_EMBED_D

            rotatedX1.segment(0, FalconnLite::ker_n_features) = vecEmbed;
            rotatedX1_centered.segment(0, FalconnLite::ker_n_features) = vecEmbed_centered;

            VectorXf rotatedX2 = rotatedX1;
            VectorXf rotatedX2_centered = rotatedX1_centered;

            for (int r = 0; r < FalconnLite::n_rotate; ++r)
            {
                // Component-wise multiplication with a random sign
                for (int d = 0; d < FalconnLite::fhtDim; ++d)
                {
                    rotatedX1(d) *= (2 * static_cast<float>(FalconnLite::bitHD1[r * FalconnLite::fhtDim + d]) - 1);
                    rotatedX2(d) *= (2 * static_cast<float>(FalconnLite::bitHD2[r * FalconnLite::fhtDim + d]) - 1);

                    rotatedX1_centered(d) *= (2 * static_cast<float>(FalconnLite::bitHD1[r * FalconnLite::fhtDim + d]) - 1);
                    rotatedX2_centered(d) *= (2 * static_cast<float>(FalconnLite::bitHD2[r * FalconnLite::fhtDim + d]) - 1);
                }

                // Multiple with Hadamard matrix by calling FWHT transform
                fht_float(rotatedX1.data(), log2Project);
                fht_float(rotatedX2.data(), log2Project);

                fht_float(rotatedX1_centered.data(), log2Project);
                fht_float(rotatedX2_centered.data(), log2Project);
            }

            // rotateX1 = {x1 * r1, ..., x1 * r_D}
            // rotateX2 = {x1 * s1, ..., x1 * s_D}

            // cout << "We finish random rotating" << endl;

            // This queue is used for finding top-k max hash values and hash index for iProbes on each layer
            priority_queue< IFPair, vector<IFPair>, greater<> > minQueTopQ1, minQueTopI1; // 1st layer
            priority_queue< IFPair, vector<IFPair>, greater<> > minQueTopQ2, minQueTopI2; // 2nd layer

            /**
            We use a priority queue to keep top-max abs projection for each repeat
            Always ensure fhtDim >= n_proj
            **/
            for (int r = 0; r < FalconnLite::n_proj; ++r)
            {
                // Note: 1st rotation

                // For qProbe: we use original data
                int iSign = sgn(rotatedX1(r));
                float fAbsHashValue = iSign * rotatedX1(r);

                int Ri_2D = r; // index of random vector in [2D] after consider the sign
                if (iSign < 0)
                    // iBucketIndex |= 1UL << log2Project; // set bit at position log2(D)
                        Ri_2D += FalconnLite::n_proj; // Be aware the case that n_proj is not 2^(log2Proj)

                // qProbe
                if ((int)minQueTopQ1.size() < FalconnLite::qProbe)
                    minQueTopQ1.emplace(Ri_2D, fAbsHashValue); // emplace is push without creating temp data
                else if (fAbsHashValue > minQueTopQ1.top().m_fValue)
                {
                    minQueTopQ1.pop();
                    minQueTopQ1.emplace(Ri_2D, fAbsHashValue); // No need IFPair()
                }

                // For iProbe: we use centered data for better performance
                iSign = sgn(rotatedX1_centered(r));
                fAbsHashValue = iSign * rotatedX1_centered(r);

                Ri_2D = r; // index of random vector in [2D] after consider the sign
                if (iSign < 0)
                    // iBucketIndex |= 1UL << log2Project; // set bit at position log2(D)
                        Ri_2D += FalconnLite::n_proj; // Be aware the case that n_proj is not 2^(log2Proj)

                // iProbe
                if ((int)minQueTopI1.size() < FalconnLite::iProbe)
                    minQueTopI1.emplace(Ri_2D, fAbsHashValue); // emplace is push without creating temp data
                else if (fAbsHashValue > minQueTopI1.top().m_fValue)
                {
                    minQueTopI1.pop();
                    minQueTopI1.emplace(Ri_2D, fAbsHashValue); // No need IFPair()
                }

                // Note:  2nd rotation

                // For qProbe: we use original data
                iSign = sgn(rotatedX2(r));
                fAbsHashValue = iSign * rotatedX2(r);

                Ri_2D = r;
                if (iSign < 0)
                    // iBucketIndex |= 1UL << log2Project; // set bit at position log2(D)
                        Ri_2D += FalconnLite::n_proj; // set bit at position log2(D)

                // qProbe
                if ((int)minQueTopQ2.size() < FalconnLite::qProbe)
                    minQueTopQ2.emplace(Ri_2D, fAbsHashValue);
                else if (fAbsHashValue > minQueTopQ2.top().m_fValue)
                {
                    minQueTopQ2.pop();
                    minQueTopQ2.emplace(Ri_2D, fAbsHashValue);
                }

                // For iProbe: we use centered data for better performance
                iSign = sgn(rotatedX2_centered(r));
                fAbsHashValue = iSign * rotatedX2_centered(r);

                Ri_2D = r;
                if (iSign < 0)
                    // iBucketIndex |= 1UL << log2Project; // set bit at position log2(D)
                        Ri_2D += FalconnLite::n_proj; // set bit at position log2(D)

                if ((int)minQueTopI2.size() < FalconnLite::iProbe)
                    minQueTopI2.emplace(Ri_2D, fAbsHashValue);
                else if (fAbsHashValue > minQueTopI2.top().m_fValue)
                {
                    minQueTopI2.pop();
                    minQueTopI2.emplace(Ri_2D, fAbsHashValue);
                }
            }

            // Convert to vector
            vector<IFPair> vec_topQ1(FalconnLite::qProbe), vec_topQ2(FalconnLite::qProbe);
            vector<IFPair> vec_topI1(FalconnLite::iProbe), vec_topI2(FalconnLite::iProbe);

            // qProbe
            for (int k = FalconnLite::qProbe - 1; k >= 0; --k)
            {
                vec_topQ1[k] = minQueTopQ1.top();
                minQueTopQ1.pop();

                vec_topQ2[k] = minQueTopQ2.top();
                minQueTopQ2.pop();
            }

            // iProbe-Falconn++
            for (int p = FalconnLite::iProbe - 1; p >= 0; --p)
            {
                vec_topI1[p] = minQueTopI1.top();
                minQueTopI1.pop();

                vec_topI2[p] = minQueTopI2.top();
                minQueTopI2.pop();
            }

            assert(vec_topQ1.size() == FalconnLite::qProbe);
            assert(vec_topQ2.size() == FalconnLite::qProbe);
            assert(vec_topI1.size() == FalconnLite::iProbe);
            assert(vec_topI2.size() == FalconnLite::iProbe);

            // cout << "We finish extracting top-qProbe and top-iProbe for 2 layers." << endl;
            // count << "Now we combine them together." << endl;

            /**
            Use minQue to find the top-qProbe over 2 layers via sum of 2 estimators
            vec1 and vec2 are already sorted, and has length of sOptics::topK
            Note: Heuristic: We consider top-k * top-k pairs for Top-K, and top-p * top-p pairs for Top-M
            Note: We cannot check all combinations due to significant cost
            **/
            priority_queue<IFPair, vector<IFPair>, greater<>> minQueTopQ;

            // qProbe
            for (const auto& ifPair1: vec_topQ1)
            {
                int Ri_2D_1st = ifPair1.m_iIndex;
                float fAbsHashValue1 = ifPair1.m_fValue;

                for (const auto& ifPair2: vec_topQ2)
                {
                    int R2_2D_2nd = ifPair2.m_iIndex;
                    float fAbsSumHash = ifPair2.m_fValue + fAbsHashValue1; // sum of 2 estimators

                    //We have 2D * 2D buckets (i.e. random vectors)
                    int iBucketIndex = Ri_2D_1st * num2D + R2_2D_2nd; // (totally we have 2D * 2D buckets)

                    // assert(iBucketIndex < vectorMinQue_TopM.size());

                    // Push all points into the bucket
                    if ((int)minQueTopQ.size() < FalconnLite::qProbe)
                        minQueTopQ.emplace(iBucketIndex, fAbsSumHash);
                    else if (fAbsSumHash > minQueTopQ.top().m_fValue)
                    {
                        minQueTopQ.pop();
                        minQueTopQ.emplace(iBucketIndex, fAbsSumHash);
                    }
                }
            }

            /** Extract the random vector idx (in the form r1 * 2D + r2) for the point idx n **/
            int k = FalconnLite::qProbe - 1;
            // MinQue has the size TopK
            while (!minQueTopQ.empty())
            {
                IFPair ifPair = minQueTopQ.top(); // index is bucketID, value is sumAbsHash
                minQueTopQ.pop();

                // Be aware of the index shift for different repeat
                // m_iIndex is bucketIdx in [0, numBucketsPerTable)
                FalconnLite::matrix_qProbes(k + repeat * FalconnLite::qProbe, n) = ifPair.m_iIndex + repeat * numBucketsPerRepeat;
                k--;
            }



            // iProbe for Falconn++
            for (const auto& ifPair1: vec_topI1)
            {
                int Ri_2D_1st = ifPair1.m_iIndex;
                float fAbsHashValue1 = ifPair1.m_fValue;

                for (const auto& ifPair2: vec_topI2)
                {
                    int R2_2D_2nd = ifPair2.m_iIndex;
                    float fAbsSumHash = ifPair2.m_fValue + fAbsHashValue1; // sum of 2 estimators

                    //We have 2D * 2D buckets (i.e. random vectors)
                    int iBucketIndex = Ri_2D_1st * num2D + R2_2D_2nd; // (totally we have 2D * 2D buckets)

                    // assert(iBucketIndex < vectorMinQue_TopM.size());

                    omp_set_lock(&locks[iBucketIndex % NUM_LOCKS]);

                    // Note: This implementation is used for controlling the size of dense bucket.

                    if ((int)vectorMinQue_TopM[iBucketIndex].size() < FalconnLite::top_m)
                        vectorMinQue_TopM[iBucketIndex].emplace(n, fAbsSumHash);

                    else if (fAbsSumHash > vectorMinQue_TopM[iBucketIndex].top().m_fValue)
                    {
                        vectorMinQue_TopM[iBucketIndex].pop();
                        vectorMinQue_TopM[iBucketIndex].emplace(n, fAbsSumHash);
                    }

                    // Push all points into the bucket
                    // FalconnLite::vec2D_Buckets[iBucketIndex].push_back(n);

                    omp_unset_lock(&locks[iBucketIndex % NUM_LOCKS]);
                }
            }
        }

        // Extract top-M for each bucketIdx - Falconn++
#pragma omp parallel for
        for (int b = 0; b < numBucketsPerRepeat; ++b)
        {
            // bucket-idx shift for different repeat
            int new_bucketIdx = b + repeat * numBucketsPerRepeat;
            int k = (int)vectorMinQue_TopM[b].size();

            FalconnLite::vec2D_Buckets[new_bucketIdx] = IVector(k);

            while (!vectorMinQue_TopM[b].empty())
            {
                // Be aware of the index shift for different repeat
                FalconnLite::vec2D_Buckets[new_bucketIdx][k-1] = vectorMinQue_TopM[b].top().m_iIndex; // add into the end of vector
                vectorMinQue_TopM[b].pop();
                k--;
            }
        }

        // int iTotalPoints = 0;
        // for (int i = 0; i < FalconnLite::vec2D_Buckets.size(); ++i)
        //     iTotalPoints += FalconnLite::vec2D_Buckets[i].size();
        // cout << "Total points in all buckets: " << iTotalPoints << endl;
    }

    // Destroy locks for Falconn++
#pragma omp parallel for
    for (size_t i = 0; i < NUM_LOCKS; ++i) {
        omp_destroy_lock(&locks[i]);
    }

    // Note: Should not clear when running testcase
    // FalconnLite::matrix_X.resize(0, 0);
    // FalconnLite::matrix_R.resize(0, 0);

}

void FalconnLite::fht_pairIndex2_repeat(int n_repeats)
{
    // Store centering vector
    VectorXf vecCenter = VectorXf::Zero(FalconnLite::n_features);
    if (FalconnLite::centering) {
        for (int n = 0; n < FalconnLite::n_points; n++)
            vecCenter += FalconnLite::matrix_X.row(n);

        vecCenter /= FalconnLite::n_points;
    }

    /** Global parameter **/

    int iFourierEmbed_D = FalconnLite::ker_n_features / 2; // This is becase we need cos() and sin()
    int numBucketsPerRepeat = 4 * FalconnLite::n_proj * FalconnLite::n_proj;
    int num2D = 2 * FalconnLite::n_proj;
    int log2Project = log2(FalconnLite::fhtDim);

    FalconnLite::matrix_qProbes = MatrixXi::Zero(FalconnLite::qProbe * n_repeats, FalconnLite::n_points); // the first topK is for close, the second topK is for far away
    FalconnLite::vec2D_Pair_Buckets = vector<vector<IFPair>> (numBucketsPerRepeat * n_repeats);

    // Note: If NUM_LOCKS is large, we might not have enough stack memory if using array
    // if D = 128 = 2^7, then numBuckets = 2^16 = 65536. We aim at 256 KB memory for locks
    // 16K locks is good for million-point data set though it is not good for small data sets.
    constexpr size_t NUM_LOCKS = 16384;
    vector<omp_lock_t> locks(NUM_LOCKS); // NUM_LOCK = 16K locks = only 256 KB

    // Initialize locks
#pragma omp parallel for
    for (size_t i = 0; i < NUM_LOCKS; i++) {
        omp_init_lock(&locks[i]);
    }

    for (int repeat = 0; repeat < n_repeats; ++repeat)
    {
        // See: https://github.com/hichamjanati/srf/blob/master/RFF-I.ipynb
        if (FalconnLite::distance == "L1")
            FalconnLite::matrix_R = cauchyGenerator(iFourierEmbed_D, FalconnLite::n_features, 0, 1.0 / FalconnLite::ker_sigma, FalconnLite::seed); // K(x, y) = exp(-gamma * L1_dist(X, y))) where gamma = 1/sigma
        else if (FalconnLite::distance == "L2")
            FalconnLite::matrix_R = gaussGenerator(iFourierEmbed_D, FalconnLite::n_features, 0, 1.0 / FalconnLite::ker_sigma, FalconnLite::seed); // std = 1/sigma, K(x, y) = exp(-gamma * L2_dist^2(X, y))) where gamma = 1/2 sigma^2

        vector<priority_queue< IFPair, vector<IFPair>, greater<> >> vectorMinQue_TopM(numBucketsPerRepeat);
        bitHD3Generator2(FalconnLite::fhtDim * FalconnLite::n_rotate, FalconnLite::seed, FalconnLite::bitHD1, FalconnLite::bitHD2);

        /**
        Parallel for each the point Xi: (1) Compute and store dot product, and (2) Extract top-k close/far random vectors
        **/
#pragma omp parallel for
        for (int n = 0; n < FalconnLite::n_points; ++n)
        {
            /**
            Random embedding
            **/
            VectorXf vecX = FalconnLite::matrix_X.row(n);
            if (FalconnLite::centering)
                vecX -= vecCenter;

            VectorXf vecEmbed = VectorXf::Zero(FalconnLite::ker_n_features); // sOptics::ker_n_features >= D

            /// must ensure ker_n_features = n_features on Cosine
            if (FalconnLite::distance == "Cosine")
                vecEmbed.segment(0, FalconnLite::n_features) = vecX;
            else if ((FalconnLite::distance == "L1") || (FalconnLite::distance == "L2"))
            {
                VectorXf vecProject = FalconnLite::matrix_R * vecX;
                vecEmbed.segment(0, iFourierEmbed_D) = vecProject.array().cos();
                vecEmbed.segment(iFourierEmbed_D, iFourierEmbed_D) = vecProject.array().sin(); // start from iEmbbed, copy iEmbed elements
            }
            else if (FalconnLite::distance == "Chi2")
                embedChi2(vecX, vecEmbed, FalconnLite::ker_n_features, FalconnLite::n_features, FalconnLite::ker_intervalSampling);
            else if (FalconnLite::distance == "JS")
                embedJS(vecX, vecEmbed, FalconnLite::ker_n_features, FalconnLite::n_features, FalconnLite::ker_intervalSampling);

            /**
            Random projection
            **/

            VectorXf rotatedX1 = VectorXf::Zero(FalconnLite::fhtDim); // NUM_PROJECT > PARAM_KERNEL_EMBED_D
            rotatedX1.segment(0, FalconnLite::ker_n_features) = vecEmbed;

            VectorXf rotatedX2 = rotatedX1;

            for (int r = 0; r < FalconnLite::n_rotate; ++r)
            {
                // Component-wise multiplication with a random sign
                for (int d = 0; d < FalconnLite::fhtDim; ++d)
                {
                    rotatedX1(d) *= (2 * static_cast<float>(FalconnLite::bitHD1[r * FalconnLite::fhtDim + d]) - 1);
                    rotatedX2(d) *= (2 * static_cast<float>(FalconnLite::bitHD2[r * FalconnLite::fhtDim + d]) - 1);
                }

                // Multiple with Hadamard matrix by calling FWHT transform
                fht_float(rotatedX1.data(), log2Project);
                fht_float(rotatedX2.data(), log2Project);
            }

            // rotateX1 = {x1 * r1, ..., x1 * r_D}
            // rotateX2 = {x1 * s1, ..., x1 * s_D}

            // cout << "We finish random rotating" << endl;

            // This queue is used for finding top-k max hash values and hash index for iProbes on each layer
            priority_queue< IFPair, vector<IFPair>, greater<> > minQueTopQ1, minQueTopI1; // 1st layer
            priority_queue< IFPair, vector<IFPair>, greater<> > minQueTopQ2, minQueTopI2; // 2nd layer

            /**
            We use a priority queue to keep top-max abs projection for each repeat
            Always ensure fhtDim >= n_proj
            **/
            for (int r = 0; r < FalconnLite::n_proj; ++r)
            {
                // 1st rotation
                int iSign = sgn(rotatedX1(r));
                float fAbsHashValue = iSign * rotatedX1(r);

                int Ri_2D = r; // index of random vector in [2D] after consider the sign
                if (iSign < 0)
                    // iBucketIndex |= 1UL << log2Project; // set bit at position log2(D)
                        Ri_2D += FalconnLite::n_proj; // Be aware the case that n_proj is not 2^(log2Proj)

                // qProbe
                if ((int)minQueTopQ1.size() < FalconnLite::qProbe)
                    minQueTopQ1.emplace(Ri_2D, fAbsHashValue); // emplace is push without creating temp data
                else if (fAbsHashValue > minQueTopQ1.top().m_fValue)
                {
                    minQueTopQ1.pop();
                    minQueTopQ1.emplace(Ri_2D, fAbsHashValue); // No need IFPair()
                }

                // iProbe
                if ((int)minQueTopI1.size() < FalconnLite::iProbe)
                    minQueTopI1.emplace(Ri_2D, fAbsHashValue); // emplace is push without creating temp data
                else if (fAbsHashValue > minQueTopI1.top().m_fValue)
                {
                    minQueTopI1.pop();
                    minQueTopI1.emplace(Ri_2D, fAbsHashValue); // No need IFPair()
                }

                // 2nd rotation
                iSign = sgn(rotatedX2(r));
                fAbsHashValue = iSign * rotatedX2(r);

                Ri_2D = r;
                if (iSign < 0)
                    // iBucketIndex |= 1UL << log2Project; // set bit at position log2(D)
                        Ri_2D += FalconnLite::n_proj; // set bit at position log2(D)

                // qProbe
                if ((int)minQueTopQ2.size() < FalconnLite::qProbe)
                    minQueTopQ2.emplace(Ri_2D, fAbsHashValue);
                else if (fAbsHashValue > minQueTopQ2.top().m_fValue)
                {
                    minQueTopQ2.pop();
                    minQueTopQ2.emplace(Ri_2D, fAbsHashValue);
                }

                // iProbe (top-iProbe random vector closest to Xn)
                if ((int)minQueTopI2.size() < FalconnLite::iProbe)
                    minQueTopI2.emplace(Ri_2D, fAbsHashValue);
                else if (fAbsHashValue > minQueTopI2.top().m_fValue)
                {
                    minQueTopI2.pop();
                    minQueTopI2.emplace(Ri_2D, fAbsHashValue);
                }
            }

            // Convert to vector
            vector<IFPair> vec_topQ1(FalconnLite::qProbe), vec_topQ2(FalconnLite::qProbe);
            vector<IFPair> vec_topI1(FalconnLite::iProbe), vec_topI2(FalconnLite::iProbe);

            // qProbe
            for (int k = FalconnLite::qProbe - 1; k >= 0; --k)
            {
                vec_topQ1[k] = minQueTopQ1.top();
                minQueTopQ1.pop();

                vec_topQ2[k] = minQueTopQ2.top();
                minQueTopQ2.pop();
            }

            // iProbe-Falconn++
            for (int p = FalconnLite::iProbe - 1; p >= 0; --p)
            {
                vec_topI1[p] = minQueTopI1.top();
                minQueTopI1.pop();

                vec_topI2[p] = minQueTopI2.top();
                minQueTopI2.pop();
            }

            assert(vec_topQ1.size() == FalconnLite::qProbe);
            assert(vec_topQ2.size() == FalconnLite::qProbe);
            assert(vec_topI1.size() == FalconnLite::iProbe);
            assert(vec_topI2.size() == FalconnLite::iProbe);

            // cout << "We finish extracting top-qProbe and top-iProbe for 2 layers." << endl;
            // count << "Now we combine them together." << endl;

            /**
            Use minQue to find the top-qProbe over 2 layers via sum of 2 estimators
            vec1 and vec2 are already sorted, and has length of sOptics::topK
            Note: Heuristic: We consider top-k * top-k pairs for Top-K, and top-p * top-p pairs for Top-M
            Note: We cannot check all combinations due to significant cost
            **/
            priority_queue<IFPair, vector<IFPair>, greater<>> minQueTopQ;

            // qProbe
            for (const auto& ifPair1: vec_topQ1)
            {
                int Ri_2D_1st = ifPair1.m_iIndex;
                float fAbsHashValue1 = ifPair1.m_fValue;

                for (const auto& ifPair2: vec_topQ2)
                {
                    int R2_2D_2nd = ifPair2.m_iIndex;
                    float fAbsSumHash = ifPair2.m_fValue + fAbsHashValue1; // sum of 2 estimators

                    //We have 2D * 2D buckets (i.e. random vectors)
                    int iBucketIndex = Ri_2D_1st * num2D + R2_2D_2nd; // (totally we have 2D * 2D buckets)

                    // assert(iBucketIndex < vectorMinQue_TopM.size());

                    // Push all points into the bucket
                    if ((int)minQueTopQ.size() < FalconnLite::qProbe)
                        minQueTopQ.emplace(iBucketIndex, fAbsSumHash);
                    else if (fAbsSumHash > minQueTopQ.top().m_fValue)
                    {
                        minQueTopQ.pop();
                        minQueTopQ.emplace(iBucketIndex, fAbsSumHash);
                    }
                }
            }

            /** Extract the random vector idx (in the form r1 * 2D + r2) for the point idx n **/
            int k = FalconnLite::qProbe - 1;
            // MinQue has the size TopK
            while (!minQueTopQ.empty())
            {
                IFPair ifPair = minQueTopQ.top(); // index is bucketID, value is sumAbsHash
                minQueTopQ.pop();

                // Be aware of the index shift for different repeat
                // m_iIndex is bucketIdx in [0, numBucketsPerTable)
                FalconnLite::matrix_qProbes(k + repeat * FalconnLite::qProbe, n) = ifPair.m_iIndex + repeat * numBucketsPerRepeat;
                k--;
            }



            // iProbe for Falconn++
            for (const auto& ifPair1: vec_topI1)
            {
                int Ri_2D_1st = ifPair1.m_iIndex;
                float fAbsHashValue1 = ifPair1.m_fValue;

                for (const auto& ifPair2: vec_topI2)
                {
                    int R2_2D_2nd = ifPair2.m_iIndex;
                    float fAbsSumHash = ifPair2.m_fValue + fAbsHashValue1; // sum of 2 estimators

                    //We have 2D * 2D buckets (i.e. random vectors)
                    int iBucketIndex = Ri_2D_1st * num2D + R2_2D_2nd; // (totally we have 2D * 2D buckets)

                    // assert(iBucketIndex < vectorMinQue_TopM.size());

                    omp_set_lock(&locks[iBucketIndex % NUM_LOCKS]);

                    // Note: This implementation is used for controlling the size of dense bucket.

                    if ((int)vectorMinQue_TopM[iBucketIndex].size() < FalconnLite::top_m)
                        vectorMinQue_TopM[iBucketIndex].emplace(n, fAbsSumHash);

                    else if (fAbsSumHash > vectorMinQue_TopM[iBucketIndex].top().m_fValue)
                    {
                        vectorMinQue_TopM[iBucketIndex].pop();
                        vectorMinQue_TopM[iBucketIndex].emplace(n, fAbsSumHash);
                    }

                    // Push all points into the bucket
                    // FalconnLite::vec2D_Buckets[iBucketIndex].push_back(n);

                    omp_unset_lock(&locks[iBucketIndex % NUM_LOCKS]);
                }
            }
        }

        // Extract top-M for each bucketIdx - Falconn++
#pragma omp parallel for
        for (int b = 0; b < numBucketsPerRepeat; ++b)
        {
            // bucket-idx shift for different repeat
            int new_bucketIdx = b + repeat * numBucketsPerRepeat;
            int k = (int)vectorMinQue_TopM[b].size();


            FalconnLite::vec2D_Pair_Buckets[new_bucketIdx] = vector<IFPair>(k);

            while (!vectorMinQue_TopM[b].empty())
            {
                // Be aware of the index shift for different repeat
                FalconnLite::vec2D_Pair_Buckets[new_bucketIdx][k-1] = vectorMinQue_TopM[b].top(); // add into the end of vector
                vectorMinQue_TopM[b].pop();
                k--;
            }
        }

        // int iTotalPoints = 0;
        // for (int i = 0; i < FalconnLite::vec2D_Pair_Buckets.size(); ++i)
        //     iTotalPoints += FalconnLite::vec2D_Pair_Buckets[i].size();
        // cout << "Total points in all buckets: " << iTotalPoints << endl;

        /**
        For each random vector, extract top-m closest data points
        **/

        // Extract top-M for each bucketIdx - Falconn++
    //#pragma omp parallel for
    //    for (int b = 0; b < numBuckets; ++b)
    //    {
    //        int m = (int)vectorMinQue_bLimit[b].size();
    //
    //        assert(m <= FalconnLite::bucketLimit);
    //
    //        while (!vectorMinQue_bLimit[b].empty())
    //        {
    //            FalconnLite::vec2D_Buckets[b].push_back(vectorMinQue_bLimit[b].top().m_iIndex);
    //            vectorMinQue_bLimit[b].pop();
    //
    //            m--;
    //        }
    //    }


    }

    // Destroy locks for Falconn++
#pragma omp parallel for
    for (size_t i = 0; i < NUM_LOCKS; ++i) {
        omp_destroy_lock(&locks[i]);
    }

    // Note: Should not clear when running testcase
    // FalconnLite::matrix_X.resize(0, 0);
    // FalconnLite::matrix_R.resize(0, 0);

}

void FalconnLite::fht_pairIndex2_asym_centering_repeat(int n_repeats)
{
    // Compute the data center
    VectorXf vecCenter = VectorXf::Zero(FalconnLite::n_features);
    for (int n = 0; n < FalconnLite::n_points; n++)
        vecCenter += FalconnLite::matrix_X.row(n);
    vecCenter /= FalconnLite::n_points;

    /** Global parameter **/

    int iFourierEmbed_D = FalconnLite::ker_n_features / 2; // This is becase we need cos() and sin()
    int numBucketsPerRepeat = 4 * FalconnLite::n_proj * FalconnLite::n_proj;
    int num2D = 2 * FalconnLite::n_proj;
    int log2Project = log2(FalconnLite::fhtDim);

    FalconnLite::matrix_qProbes = MatrixXi::Zero(FalconnLite::qProbe * n_repeats, FalconnLite::n_points); // the first topK is for close, the second topK is for far away
    FalconnLite::vec2D_Pair_Buckets = vector<vector<IFPair> > (numBucketsPerRepeat * n_repeats);

    // Note: If NUM_LOCKS is large, we might not have enough stack memory if using array
    // if D = 128 = 2^7, then numBuckets = 2^16 = 65536. We aim at 256 KB memory for locks
    // 16K locks is good for million-point data set though it is not good for small data sets.
    constexpr size_t NUM_LOCKS = 16384;
    vector<omp_lock_t> locks(NUM_LOCKS); // NUM_LOCK = 16K locks = only 256 KB

    // Initialize locks
#pragma omp parallel for
    for (size_t i = 0; i < NUM_LOCKS; i++) {
        omp_init_lock(&locks[i]);
    }

    // NOTE: n_repeats should not be large, e.g., 5 or 10
    // For each repeat, we have different random vectors
    for (int repeat = 0; repeat < n_repeats; ++repeat)
    {
        // See: https://github.com/hichamjanati/srf/blob/master/RFF-I.ipynb
        if (FalconnLite::distance == "L1")
            FalconnLite::matrix_R = cauchyGenerator(iFourierEmbed_D, FalconnLite::n_features, 0, 1.0 / FalconnLite::ker_sigma, FalconnLite::seed); // K(x, y) = exp(-gamma * L1_dist(X, y))) where gamma = 1/sigma
        else if (FalconnLite::distance == "L2")
            FalconnLite::matrix_R = gaussGenerator(iFourierEmbed_D, FalconnLite::n_features, 0, 1.0 / FalconnLite::ker_sigma, FalconnLite::seed); // std = 1/sigma, K(x, y) = exp(-gamma * L2_dist^2(X, y))) where gamma = 1/2 sigma^2

        vector<priority_queue< IFPair, vector<IFPair>, greater<> >> vectorMinQue_TopM(numBucketsPerRepeat);
        bitHD3Generator2(FalconnLite::fhtDim * FalconnLite::n_rotate, FalconnLite::seed, FalconnLite::bitHD1, FalconnLite::bitHD2);

        /**
        Parallel for each the point Xi: (1) Compute and store dot product, and (2) Extract top-k close/far random vectors
        **/
#pragma omp parallel for
        for (int n = 0; n < FalconnLite::n_points; ++n)
        {
            /**
            Random embedding
            **/
            VectorXf vecX = FalconnLite::matrix_X.row(n); // centering if needed
            VectorXf vecEmbed = VectorXf::Zero(FalconnLite::ker_n_features); // sOptics::ker_n_features >= D

            VectorXf vecX_centered = FalconnLite::matrix_X.row(n) - vecCenter; // centering if needed
            VectorXf vecEmbed_centered = VectorXf::Zero(FalconnLite::ker_n_features); // sOptics::ker_n_features >= D

            /// must ensure ker_n_features = n_features on Cosine
            if (FalconnLite::distance == "Cosine") {
                vecEmbed.segment(0, FalconnLite::n_features) = vecX;
                vecEmbed_centered.segment(0, FalconnLite::n_features) = vecX_centered;
            }
            else if ((FalconnLite::distance == "L1") || (FalconnLite::distance == "L2"))
            {
                VectorXf vecProject = FalconnLite::matrix_R * vecX;
                vecEmbed.segment(0, iFourierEmbed_D) = vecProject.array().cos();
                vecEmbed.segment(iFourierEmbed_D, iFourierEmbed_D) = vecProject.array().sin(); // start from iEmbbed, copy iEmbed elements

                vecProject = FalconnLite::matrix_R * vecX_centered;
                vecEmbed_centered.segment(0, iFourierEmbed_D) = vecProject.array().cos();
                vecEmbed_centered.segment(iFourierEmbed_D, iFourierEmbed_D) = vecProject.array().sin(); // start from iEmbbed, copy iEmbed elements
            }

            else if (FalconnLite::distance == "Chi2") {
                embedChi2(vecX, vecEmbed, FalconnLite::ker_n_features, FalconnLite::n_features, FalconnLite::ker_intervalSampling);
                embedChi2(vecX_centered, vecEmbed_centered, FalconnLite::ker_n_features, FalconnLite::n_features, FalconnLite::ker_intervalSampling);
            }

            else if (FalconnLite::distance == "JS") {
                embedJS(vecX, vecEmbed, FalconnLite::ker_n_features, FalconnLite::n_features, FalconnLite::ker_intervalSampling);
                embedJS(vecX_centered, vecEmbed_centered, FalconnLite::ker_n_features, FalconnLite::n_features, FalconnLite::ker_intervalSampling);
            }


            /**
            Random projection
            **/

            VectorXf rotatedX1 = VectorXf::Zero(FalconnLite::fhtDim); // NUM_PROJECT > PARAM_KERNEL_EMBED_D
            VectorXf rotatedX1_centered = VectorXf::Zero(FalconnLite::fhtDim); // NUM_PROJECT > PARAM_KERNEL_EMBED_D

            rotatedX1.segment(0, FalconnLite::ker_n_features) = vecEmbed;
            rotatedX1_centered.segment(0, FalconnLite::ker_n_features) = vecEmbed_centered;

            VectorXf rotatedX2 = rotatedX1;
            VectorXf rotatedX2_centered = rotatedX1_centered;

            for (int r = 0; r < FalconnLite::n_rotate; ++r)
            {
                // Component-wise multiplication with a random sign
                for (int d = 0; d < FalconnLite::fhtDim; ++d)
                {
                    rotatedX1(d) *= (2 * static_cast<float>(FalconnLite::bitHD1[r * FalconnLite::fhtDim + d]) - 1);
                    rotatedX2(d) *= (2 * static_cast<float>(FalconnLite::bitHD2[r * FalconnLite::fhtDim + d]) - 1);

                    rotatedX1_centered(d) *= (2 * static_cast<float>(FalconnLite::bitHD1[r * FalconnLite::fhtDim + d]) - 1);
                    rotatedX2_centered(d) *= (2 * static_cast<float>(FalconnLite::bitHD2[r * FalconnLite::fhtDim + d]) - 1);
                }

                // Multiple with Hadamard matrix by calling FWHT transform
                fht_float(rotatedX1.data(), log2Project);
                fht_float(rotatedX2.data(), log2Project);

                fht_float(rotatedX1_centered.data(), log2Project);
                fht_float(rotatedX2_centered.data(), log2Project);
            }

            // rotateX1 = {x1 * r1, ..., x1 * r_D}
            // rotateX2 = {x1 * s1, ..., x1 * s_D}

            // cout << "We finish random rotating" << endl;

            // This queue is used for finding top-k max hash values and hash index for iProbes on each layer
            priority_queue< IFPair, vector<IFPair>, greater<> > minQueTopQ1, minQueTopI1; // 1st layer
            priority_queue< IFPair, vector<IFPair>, greater<> > minQueTopQ2, minQueTopI2; // 2nd layer

            /**
            We use a priority queue to keep top-max abs projection for each repeat
            Always ensure fhtDim >= n_proj
            **/
            for (int r = 0; r < FalconnLite::n_proj; ++r)
            {
                // Note: 1st rotation

                // For qProbe: we use original data
                int iSign = sgn(rotatedX1(r));
                float fAbsHashValue = iSign * rotatedX1(r);

                int Ri_2D = r; // index of random vector in [2D] after consider the sign
                if (iSign < 0)
                    // iBucketIndex |= 1UL << log2Project; // set bit at position log2(D)
                        Ri_2D += FalconnLite::n_proj; // Be aware the case that n_proj is not 2^(log2Proj)

                // qProbe
                if ((int)minQueTopQ1.size() < FalconnLite::qProbe)
                    minQueTopQ1.emplace(Ri_2D, fAbsHashValue); // emplace is push without creating temp data
                else if (fAbsHashValue > minQueTopQ1.top().m_fValue)
                {
                    minQueTopQ1.pop();
                    minQueTopQ1.emplace(Ri_2D, fAbsHashValue); // No need IFPair()
                }

                // For iProbe: we use centered data for better performance
                iSign = sgn(rotatedX1_centered(r));
                fAbsHashValue = iSign * rotatedX1_centered(r);

                Ri_2D = r; // index of random vector in [2D] after consider the sign
                if (iSign < 0)
                    // iBucketIndex |= 1UL << log2Project; // set bit at position log2(D)
                        Ri_2D += FalconnLite::n_proj; // Be aware the case that n_proj is not 2^(log2Proj)

                // iProbe
                if ((int)minQueTopI1.size() < FalconnLite::iProbe)
                    minQueTopI1.emplace(Ri_2D, fAbsHashValue); // emplace is push without creating temp data
                else if (fAbsHashValue > minQueTopI1.top().m_fValue)
                {
                    minQueTopI1.pop();
                    minQueTopI1.emplace(Ri_2D, fAbsHashValue); // No need IFPair()
                }

                // Note:  2nd rotation

                // For qProbe: we use original data
                iSign = sgn(rotatedX2(r));
                fAbsHashValue = iSign * rotatedX2(r);

                Ri_2D = r;
                if (iSign < 0)
                    // iBucketIndex |= 1UL << log2Project; // set bit at position log2(D)
                        Ri_2D += FalconnLite::n_proj; // set bit at position log2(D)

                // qProbe
                if ((int)minQueTopQ2.size() < FalconnLite::qProbe)
                    minQueTopQ2.emplace(Ri_2D, fAbsHashValue);
                else if (fAbsHashValue > minQueTopQ2.top().m_fValue)
                {
                    minQueTopQ2.pop();
                    minQueTopQ2.emplace(Ri_2D, fAbsHashValue);
                }

                // For iProbe: we use centered data for better performance
                iSign = sgn(rotatedX2_centered(r));
                fAbsHashValue = iSign * rotatedX2_centered(r);

                Ri_2D = r;
                if (iSign < 0)
                    // iBucketIndex |= 1UL << log2Project; // set bit at position log2(D)
                        Ri_2D += FalconnLite::n_proj; // set bit at position log2(D)

                if ((int)minQueTopI2.size() < FalconnLite::iProbe)
                    minQueTopI2.emplace(Ri_2D, fAbsHashValue);
                else if (fAbsHashValue > minQueTopI2.top().m_fValue)
                {
                    minQueTopI2.pop();
                    minQueTopI2.emplace(Ri_2D, fAbsHashValue);
                }
            }

            // Convert to vector
            vector<IFPair> vec_topQ1(FalconnLite::qProbe), vec_topQ2(FalconnLite::qProbe);
            vector<IFPair> vec_topI1(FalconnLite::iProbe), vec_topI2(FalconnLite::iProbe);

            // qProbe
            for (int k = FalconnLite::qProbe - 1; k >= 0; --k)
            {
                vec_topQ1[k] = minQueTopQ1.top();
                minQueTopQ1.pop();

                vec_topQ2[k] = minQueTopQ2.top();
                minQueTopQ2.pop();
            }

            // iProbe-Falconn++
            for (int p = FalconnLite::iProbe - 1; p >= 0; --p)
            {
                vec_topI1[p] = minQueTopI1.top();
                minQueTopI1.pop();

                vec_topI2[p] = minQueTopI2.top();
                minQueTopI2.pop();
            }

            assert(vec_topQ1.size() == FalconnLite::qProbe);
            assert(vec_topQ2.size() == FalconnLite::qProbe);
            assert(vec_topI1.size() == FalconnLite::iProbe);
            assert(vec_topI2.size() == FalconnLite::iProbe);

            // cout << "We finish extracting top-qProbe and top-iProbe for 2 layers." << endl;
            // count << "Now we combine them together." << endl;

            /**
            Use minQue to find the top-qProbe over 2 layers via sum of 2 estimators
            vec1 and vec2 are already sorted, and has length of sOptics::topK
            Note: Heuristic: We consider top-k * top-k pairs for Top-K, and top-p * top-p pairs for Top-M
            Note: We cannot check all combinations due to significant cost
            **/
            priority_queue<IFPair, vector<IFPair>, greater<>> minQueTopQ;

            // qProbe
            for (const auto& ifPair1: vec_topQ1)
            {
                int Ri_2D_1st = ifPair1.m_iIndex;
                float fAbsHashValue1 = ifPair1.m_fValue;

                for (const auto& ifPair2: vec_topQ2)
                {
                    int R2_2D_2nd = ifPair2.m_iIndex;
                    float fAbsSumHash = ifPair2.m_fValue + fAbsHashValue1; // sum of 2 estimators

                    //We have 2D * 2D buckets (i.e. random vectors)
                    int iBucketIndex = Ri_2D_1st * num2D + R2_2D_2nd; // (totally we have 2D * 2D buckets)

                    // assert(iBucketIndex < vectorMinQue_TopM.size());

                    // Push all points into the bucket
                    if ((int)minQueTopQ.size() < FalconnLite::qProbe)
                        minQueTopQ.emplace(iBucketIndex, fAbsSumHash);
                    else if (fAbsSumHash > minQueTopQ.top().m_fValue)
                    {
                        minQueTopQ.pop();
                        minQueTopQ.emplace(iBucketIndex, fAbsSumHash);
                    }
                }
            }

            /** Extract the random vector idx (in the form r1 * 2D + r2) for the point idx n **/
            int k = FalconnLite::qProbe - 1;
            // MinQue has the size TopK
            while (!minQueTopQ.empty())
            {
                IFPair ifPair = minQueTopQ.top(); // index is bucketID, value is sumAbsHash
                minQueTopQ.pop();

                // Be aware of the index shift for different repeat
                // m_iIndex is bucketIdx in [0, numBucketsPerTable)
                FalconnLite::matrix_qProbes(k + repeat * FalconnLite::qProbe, n) = ifPair.m_iIndex + repeat * numBucketsPerRepeat;
                k--;
            }



            // iProbe for Falconn++
            for (const auto& ifPair1: vec_topI1)
            {
                int Ri_2D_1st = ifPair1.m_iIndex;
                float fAbsHashValue1 = ifPair1.m_fValue;

                for (const auto& ifPair2: vec_topI2)
                {
                    int R2_2D_2nd = ifPair2.m_iIndex;
                    float fAbsSumHash = ifPair2.m_fValue + fAbsHashValue1; // sum of 2 estimators

                    //We have 2D * 2D buckets (i.e. random vectors)
                    int iBucketIndex = Ri_2D_1st * num2D + R2_2D_2nd; // (totally we have 2D * 2D buckets)

                    // assert(iBucketIndex < vectorMinQue_TopM.size());

                    omp_set_lock(&locks[iBucketIndex % NUM_LOCKS]);

                    // Note: This implementation is used for controlling the size of dense bucket.

                    if ((int)vectorMinQue_TopM[iBucketIndex].size() < FalconnLite::top_m)
                        vectorMinQue_TopM[iBucketIndex].emplace(n, fAbsSumHash);

                    else if (fAbsSumHash > vectorMinQue_TopM[iBucketIndex].top().m_fValue)
                    {
                        vectorMinQue_TopM[iBucketIndex].pop();
                        vectorMinQue_TopM[iBucketIndex].emplace(n, fAbsSumHash);
                    }

                    // Push all points into the bucket
                    // FalconnLite::vec2D_Buckets[iBucketIndex].push_back(n);

                    omp_unset_lock(&locks[iBucketIndex % NUM_LOCKS]);
                }
            }
        }

        // Extract top-M for each bucketIdx - Falconn++
#pragma omp parallel for
        for (int b = 0; b < numBucketsPerRepeat; ++b)
        {
            // bucket-idx shift for different repeat
            int new_bucketIdx = b + repeat * numBucketsPerRepeat;
            int k = (int)vectorMinQue_TopM[b].size();

            FalconnLite::vec2D_Pair_Buckets[new_bucketIdx] = vector<IFPair>(k);

            while (!vectorMinQue_TopM[b].empty())
            {
                // Be aware of the index shift for different repeat
                FalconnLite::vec2D_Pair_Buckets[new_bucketIdx][k-1] = vectorMinQue_TopM[b].top(); // add into the end of vector
                vectorMinQue_TopM[b].pop();
                k--;
            }
        }

        // int iTotalPoints = 0;
        // for (int i = 0; i < FalconnLite::vec2D_Pair_Buckets.size(); ++i)
        //     iTotalPoints += FalconnLite::vec2D_Pair_Buckets[i].size();
        // cout << "Total points in all buckets: " << iTotalPoints << endl;

        /**
        For each random vector, extract top-m closest data points
        **/

    }

    // Destroy locks for Falconn++
#pragma omp parallel for
    for (size_t i = 0; i < NUM_LOCKS; ++i) {
        omp_destroy_lock(&locks[i]);
    }

    // Note: Should not clear when running testcase
    // FalconnLite::matrix_X.resize(0, 0);
    // FalconnLite::matrix_R.resize(0, 0);

}

/**
 * Approximate kNN by bucket sampling.
 *
 * For each point n, walk its query probes (column n of matrix_qProbes) in order and draw one
 * random point different from n from every probed bucket of vec2D_Buckets that has more than one
 * entry, until topK neighbors are collected or the probes are exhausted. Unfilled entries of a
 * row stay 0. matrix_qProbes and vec2D_Buckets are released afterwards.
 *
 * @param topK number of sampled neighbors per point
 * @return n_points x topK matrix of sampled neighbor indices (row n = neighbors of point n)
 */
RowMajorMatrixXi FalconnLite::rp_sampleNeighborhood2(int topK)
{
    RowMajorMatrixXi matrix_kNN = RowMajorMatrixXi::Zero(FalconnLite::n_points, topK );

    chrono::steady_clock::time_point begin;
    begin = chrono::steady_clock::now();

    // One base seed per call; every OpenMP thread derives its own engine from it.
    // (A single std::mt19937 shared by all threads would be a data race.)
    const unsigned int baseSeed = std::random_device{}();

#pragma omp parallel
    {
        std::seed_seq seedSeq{baseSeed, static_cast<unsigned int>(omp_get_thread_num())};
        std::mt19937 gen(seedSeq);  // Mersenne Twister engine, private to this thread

#pragma omp for
        for (int n = 0; n < FalconnLite::n_points; ++n)
        {
            int k = 0;
            for (int k1 = 0; k1 < FalconnLite::qProbe && k < topK; ++k1)
            {
                const int bucketIdx = FalconnLite::matrix_qProbes(k1, n); // bucketIdx in [(2D)^2]
                const IVector& bucket = FalconnLite::vec2D_Buckets[bucketIdx]; // no copy

                // Sample one random point from the bucket Ri.
                // The bucket normally contains the point itself, so require at least one other entry.
                if (bucket.size() > 1)
                {
                    std::uniform_int_distribution<> dist(0, static_cast<int>(bucket.size()) - 1);
                    int pointIdx;
                    do
                    {
                        pointIdx = bucket[dist(gen)]; // dist(gen) returns a random position
                    } while (pointIdx == n); // do not put the query point itself

                    matrix_kNN(n, k) = pointIdx;
                    k++;
                }
            }
        }
    }

    FalconnLite::matrix_qProbes.resize(0, 0);
    FalconnLite::vec2D_Buckets.clear();

    if (FalconnLite::verbose)
    {
        cout << "Sampling time = " << chrono::duration_cast<chrono::milliseconds>(chrono::steady_clock::now() - begin).count() << "[ms]" << endl;
    }

    return matrix_kNN;
}

/**
 * Build a single index with fht_Index2() on MATRIX_X and return topK sampled neighbors per
 * point via rp_sampleNeighborhood2().
 *
 * Side effects: copies MATRIX_X into matrix_X, transforms it according to `distance`, and sets
 * the OpenMP thread count to n_threads.
 *
 * @param MATRIX_X input data, one point per row
 * @param topK     number of sampled neighbors per point
 * @return n_points x topK matrix of sampled neighbor indices
 */
RowMajorMatrixXi FalconnLite::bucket_sampling(const Ref<const RowMajorMatrixXf> & MATRIX_X, int topK)
{
    if (FalconnLite::verbose)
    {
        cout << "topK: " << topK << endl;
        cout << "qProbe: " << FalconnLite::qProbe << endl;

        cout << "n_points: " << FalconnLite::n_points << endl;
        cout << "n_features: " << FalconnLite::n_features << endl;
        cout << "n_proj: " << FalconnLite::n_proj << endl;
        cout << "iProbe: " << FalconnLite::iProbe << endl;

        cout << "distance: " << FalconnLite::distance << endl;
        cout << "kernel features: " << FalconnLite::ker_n_features << endl;
        cout << "sigma: " << FalconnLite::ker_sigma << endl;
        cout << "interval sampling: " << FalconnLite::ker_intervalSampling << endl;
        cout << "n_threads: " << FalconnLite::n_threads << endl;
    }

    omp_set_num_threads(FalconnLite::n_threads);

    FalconnLite::matrix_X = MATRIX_X;
    transformData(FalconnLite::matrix_X, FalconnLite::distance);

    chrono::steady_clock::time_point begin = chrono::steady_clock::now();
    fht_Index2();
    cout << "Build index time = " << chrono::duration_cast<chrono::milliseconds>(chrono::steady_clock::now() - begin).count() << "[ms]" << endl;

    return rp_sampleNeighborhood2(topK);
}

/**
 * Collision counting with one table per repeat (legacy variant).
 *
 * For each of n_repeats repeats, rebuild the index with fht_Index2(). Then, for every point n,
 * count in a per-point histogram (point idx -> collision count) how often every other point
 * shares one of n's qProbe buckets. Finally keep the topK most frequent points per row.
 * The histogram is a hash map per point, so memory grows quickly for large n.
 *
 * @param MATRIX_X  input data, one point per row
 * @param topK      number of neighbors returned per point
 * @param n_repeats number of independently built tables
 * @return n_points x topK matrix; row n lists neighbors of n by decreasing collision count.
 *         Rows with fewer than topK candidates are padded with 0 at the end.
 */
RowMajorMatrixXi FalconnLite::freq_counting_old(const Ref<const RowMajorMatrixXf> & MATRIX_X, int topK, int n_repeats)
{
    cout << "topK: " << topK << endl;
    cout << "qProbe: " << FalconnLite::qProbe << endl;

    cout << "n_points: " << FalconnLite::n_points << endl;
    cout << "n_features: " << FalconnLite::n_features << endl;
    cout << "n_proj: " << FalconnLite::n_proj << endl;
    cout << "top_m: " << FalconnLite::top_m << endl;
    cout << "iProbe: " << FalconnLite::iProbe << endl;

    cout << "distance: " << FalconnLite::distance << endl;
    cout << "kernel features: " << FalconnLite::ker_n_features << endl;
    cout << "sigma: " << FalconnLite::ker_sigma << endl;
    cout << "interval sampling: " << FalconnLite::ker_intervalSampling << endl;
    cout << "n_threads: " << FalconnLite::n_threads << endl;

    omp_set_num_threads(FalconnLite::n_threads);

    FalconnLite::matrix_X = MATRIX_X;
    transformData(FalconnLite::matrix_X, FalconnLite::distance);

    chrono::steady_clock::time_point begin;

    // vecHist[n]: point idx -> number of shared buckets with point n (only touched by iteration n)
    vector<tsl::robin_map<int, int>> vecHist(FalconnLite::n_points);

    float constructTime = 0.0, buildHistTime = 0.0, extractTopKTime = 0.0;
    for (int r = 0; r < n_repeats; ++r)
    {
        begin = chrono::steady_clock::now();
        cout << "Repeat idx: " << r + 1 << "/" << n_repeats << endl;

        fht_Index2();
        constructTime += chrono::duration_cast<chrono::milliseconds>(chrono::steady_clock::now() - begin).count();

        begin = chrono::steady_clock::now();
#pragma omp parallel for
        for (int n = 0; n < FalconnLite::n_points; ++n) {

            // For each probed bucket of point n
            for (int k = 0; k < FalconnLite::qProbe; ++k)
            {
                const int bucketIdx = FalconnLite::matrix_qProbes(k, n); // Ri in [(2D)^2]
                const IVector& bucket = FalconnLite::vec2D_Buckets[bucketIdx]; // no copy

                if (bucket.size() > 1) {
                    for (const int pointIdx: bucket) {
                        if (pointIdx != n) // do not put the query point itself
                            ++vecHist[n][pointIdx]; // operator[] value-initializes new counts to 0
                    }
                }
            }
        }

        buildHistTime += chrono::duration_cast<chrono::milliseconds>(chrono::steady_clock::now() - begin).count();
    }

    cout << "Construct index time (ms) = " << constructTime << endl;
    cout << "Build histogram time (ms) = " << buildHistTime << endl;

    // Extract top-K from histogram

    begin = chrono::steady_clock::now();

    RowMajorMatrixXi matrix_kNN = RowMajorMatrixXi::Zero(FalconnLite::n_points, topK );
#pragma omp parallel for
    for (int n = 0; n < FalconnLite::n_points; ++n)
    {
        priority_queue< IFPair, vector<IFPair>, greater<> > minQueTopK;

        for (const auto& [pointIdx, count] : vecHist[n])
        {
            if ((int)minQueTopK.size() < topK)
                minQueTopK.emplace(pointIdx, count); // emplace is push without creating temp data
            else if (!minQueTopK.empty() && count > minQueTopK.top().m_fValue) // empty only if topK <= 0
            {
                minQueTopK.pop();
                minQueTopK.emplace(pointIdx, count); // No need IFPair()
            }
        }

        // Fill best-first from column 0: the queue pops the smallest count first, so start at the
        // last occupied slot. (Starting at topK - 1 left the best neighbor behind 0-padding.)
        int k = static_cast<int>(minQueTopK.size()) - 1;
        while (!minQueTopK.empty())
        {
            matrix_kNN(n, k) = minQueTopK.top().m_iIndex;
            minQueTopK.pop();
            k--;
        }
    }

    extractTopKTime = chrono::duration_cast<chrono::milliseconds>(chrono::steady_clock::now() - begin).count();
    cout << "Extract top-K time (ms) = " << extractTopKTime << endl;

    return matrix_kNN;
}

/**
 * Collision counting over several tables built at the same time.
 *
 * Builds n_repeats tables with fht_Index2_repeat(). Then, for every point n, counts how many of
 * n's probed buckets (qProbe per table, over all tables) each other point falls into. It returns
 * the topK most frequently colliding points.
 *
 * @param MATRIX_X  input data, one point per row
 * @param topK      number of neighbors returned per point
 * @param n_repeats number of tables
 * @return n_points x topK matrix; row n lists neighbors of n by decreasing collision count.
 *         Rows with fewer than topK candidates are padded with 0 at the end.
 */
RowMajorMatrixXi FalconnLite::coll_counting(const Ref<const RowMajorMatrixXf> & MATRIX_X, int topK, int n_repeats)
{
    if (FalconnLite::verbose) {

        cout << "topK: " << topK << endl;
        cout << "qProbe: " << FalconnLite::qProbe << endl;

        cout << "n_points: " << FalconnLite::n_points << endl;
        cout << "n_features: " << FalconnLite::n_features << endl;
        cout << "n_proj: " << FalconnLite::n_proj << endl;
        cout << "top_m: " << FalconnLite::top_m << endl;
        cout << "iProbe: " << FalconnLite::iProbe << endl;

        cout << "distance: " << FalconnLite::distance << endl;
        cout << "fht dimensions: " << FalconnLite::fhtDim << endl;
        cout << "kernel features: " << FalconnLite::ker_n_features << endl;
        cout << "sigma: " << FalconnLite::ker_sigma << endl;
        cout << "interval sampling: " << FalconnLite::ker_intervalSampling << endl;
        cout << "n_threads: " << FalconnLite::n_threads << endl;
    }

    omp_set_num_threads(FalconnLite::n_threads);

    FalconnLite::matrix_X = MATRIX_X;
    transformData(FalconnLite::matrix_X, FalconnLite::distance);

    chrono::steady_clock::time_point begin = chrono::steady_clock::now();
    fht_Index2_repeat(n_repeats);
    cout << "Construct index time (ms) = " << chrono::duration_cast<chrono::milliseconds>(chrono::steady_clock::now() - begin).count() << endl;

    RowMajorMatrixXi matrix_kNN = RowMajorMatrixXi::Zero(FalconnLite::n_points, topK );

    // double, not float: a float sum stops absorbing small increments once it exceeds ~2^24
    double numColl = 0.0;

    // Per-thread timers are indexed by omp_get_thread_num(), which is < omp_get_max_threads() for
    // the region below (also robust when n_threads <= 0 and the call above is ignored).
    const int maxThreads = omp_get_max_threads();
    std::vector<double> constructHistTime_thr(maxThreads, 0.0);
    std::vector<double> extractTopKTime_thr(maxThreads, 0.0);
    std::vector<double> threadWall_thr(maxThreads, 0.0);

    double wall_start = omp_get_wtime();

#pragma omp parallel for reduction(+:numColl)
    for (int n = 0; n < FalconnLite::n_points; ++n) {

        const int tid = omp_get_thread_num();
        double region_start = omp_get_wtime();

        double t0 = omp_get_wtime();

        tsl::robin_map<int, float> tslMap; // point idx -> collision count
        tslMap.reserve(n_repeats * FalconnLite::top_m * FalconnLite::qProbe); // avoid rehashing

        // Column n stacks the qProbe bucket indices of every repeat: row k + repeat * qProbe
        const VectorXi vecProbes = FalconnLite::matrix_qProbes.col(n);

        // Build histogram over all tables
        for (int repeat = 0; repeat < n_repeats; ++repeat)
        {
            for (int k = 0; k < FalconnLite::qProbe; ++k)
            {
                const int bucketIdx = vecProbes(k + repeat * FalconnLite::qProbe); // Ri in [(2D)^2]
                const IVector& bucket = FalconnLite::vec2D_Buckets[bucketIdx]; // no copy
                numColl += bucket.size();

                if (bucket.size() > 1) {
                    for (const auto pointIdx: bucket) {
                        if (pointIdx != n) // do not put the query point itself
                            tslMap[pointIdx] += 1; // operator[] value-initializes new entries to 0
                    }
                }
            }
        }

        constructHistTime_thr[tid] += omp_get_wtime() - t0;

        // Extract top-K from histogram
        double t1 = omp_get_wtime();
        priority_queue< IFPair, vector<IFPair>, greater<> > minQueTopK;
        for (const auto& [pointIdx, est] : tslMap)
        {
            if ((int)minQueTopK.size() < topK)
                minQueTopK.emplace(pointIdx, est); // emplace is push without creating temp data
            else if (!minQueTopK.empty() && est > minQueTopK.top().m_fValue) // empty only if topK <= 0
            {
                minQueTopK.pop();
                minQueTopK.emplace(pointIdx, est); // No need IFPair()
            }
        }

        // Fill best-first from column 0: the queue pops the smallest count first, so start at the
        // last occupied slot. (Starting at topK - 1 left the best neighbor behind 0-padding.)
        int k = static_cast<int>(minQueTopK.size()) - 1;
        while (!minQueTopK.empty())
        {
            matrix_kNN(n, k) = minQueTopK.top().m_iIndex;
            minQueTopK.pop();
            k--;
        }

        extractTopKTime_thr[tid] += omp_get_wtime() - t1;
        threadWall_thr[tid] += omp_get_wtime() - region_start; // this thread's wall time inside region
    }

    double wall_elapsed = omp_get_wtime() - wall_start;

    double constructHistTime_sum =
            std::accumulate(constructHistTime_thr.begin(), constructHistTime_thr.end(), 0.0);
    double extractTopKTime_sum =
            std::accumulate(extractTopKTime_thr.begin(), extractTopKTime_thr.end(), 0.0);
    double threadWall_sum =
            std::accumulate(threadWall_thr.begin(), threadWall_thr.end(), 0.0);

    if (FalconnLite::verbose)
    {
        cout << "Avg number of collisions per point = " << numColl / FalconnLite::n_points << endl;

        cout << "Sum of per-thread histogram construction time (s) = " << constructHistTime_sum << endl;
        cout << "Sum of per-thread top-K time (s) = " << extractTopKTime_sum << endl;
        cout << "Sum of per-thread dist, top-k, and some overhead time (s): " << threadWall_sum << " \n";
    }

    cout << "Extract approximate top-k (s): " << wall_elapsed << " \n";

    return matrix_kNN;
}

/**
 * Approximate top-K neighbours of every point by score estimation only
 * (no exact distance is computed).
 *
 * Copies MATRIX_X into matrix_X, applies transformData(), and builds the pair
 * index with fht_pairIndex2_repeat(n_repeats). For each point n, the buckets of
 * vec2D_Pair_Buckets listed in column n of matrix_qProbes (qProbe buckets per
 * repeat) are scanned, and the m_fValue of every other point found there is
 * summed into a per-query histogram. The topK points with the largest summed
 * value are reported.
 *
 * @param MATRIX_X  input data, one point per row.
 * @param topK      number of neighbours to report per point.
 * @param n_repeats number of index repetitions to build and probe.
 * @return n_points x topK matrix of point indices. Row n is ordered by
 *         decreasing summed score; if fewer than topK candidates were found,
 *         the trailing entries are -1.
 */
RowMajorMatrixXi FalconnLite::dist_estimating(const Ref<const RowMajorMatrixXf> & MATRIX_X, int topK, int n_repeats)
{
    if (FalconnLite::verbose) {
        cout << "topK: " << topK << endl;
        cout << "qProbe: " << FalconnLite::qProbe << endl;

        cout << "n_points: " << FalconnLite::n_points << endl;
        cout << "n_features: " << FalconnLite::n_features << endl;
        cout << "n_proj: " << FalconnLite::n_proj << endl;
        cout << "top_m: " << FalconnLite::top_m << endl;
        cout << "iProbe: " << FalconnLite::iProbe << endl;

        cout << "distance: " << FalconnLite::distance << endl;
        cout << "kernel features: " << FalconnLite::ker_n_features << endl;
        cout << "sigma: " << FalconnLite::ker_sigma << endl;
        cout << "interval sampling: " << FalconnLite::ker_intervalSampling << endl;
        cout << "n_threads: " << FalconnLite::n_threads << endl;
    }

    omp_set_num_threads(FalconnLite::n_threads);

    FalconnLite::matrix_X = MATRIX_X;
    transformData(FalconnLite::matrix_X, FalconnLite::distance);

    chrono::steady_clock::time_point begin = chrono::steady_clock::now();
    fht_pairIndex2_repeat(n_repeats);
    cout << "Construct index time (ms) = " << chrono::duration_cast<chrono::milliseconds>(chrono::steady_clock::now() - begin).count() << endl;

    // -1 marks an empty slot; a zero fill would be indistinguishable from point 0
    // (approx_kNN and approx_join use the same -1 convention).
    RowMajorMatrixXi matrix_kNN = RowMajorMatrixXi::Constant(FalconnLite::n_points, topK, -1);

    // double rather than float: a float sum stops counting exactly beyond 2^24.
    double numColl = 0.0;

    std::vector<double> constructHistTime_thr(FalconnLite::n_threads, 0.0);
    std::vector<double> extractTopKTime_thr(FalconnLite::n_threads, 0.0);
    std::vector<double> threadWall_thr(FalconnLite::n_threads, 0.0);

    double wall_start = omp_get_wtime();

#pragma omp parallel for reduction(+: numColl)
    for (int n = 0; n < FalconnLite::n_points; ++n) {

        const int tid = omp_get_thread_num();
        double region_start = omp_get_wtime();

        double t0 = omp_get_wtime();

        tsl::robin_map<int, float> tslMap; // point idx -> summed score
        tslMap.reserve(n_repeats * FalconnLite::top_m * FalconnLite::qProbe); // avoid rehashing

        // Probe list of point n; it does not depend on `repeat`, so copy it once.
        VectorXi vecProbes = FalconnLite::matrix_qProbes.col(n);

        // Build the histogram over all tables
        for (int repeat = 0; repeat < n_repeats; ++repeat)
        {
            for (int k = 0; k < FalconnLite::qProbe; ++k)
            {
                int bucketIdx = vecProbes(k + repeat * FalconnLite::qProbe); // Ri in [(2D)^2]

                // Bind by reference: the bucket used to be deep-copied on every probe.
                const auto & bucket = FalconnLite::vec2D_Pair_Buckets[bucketIdx];
                numColl += bucket.size();

                // NOTE(review): singleton buckets are skipped even when their only
                // point is not n; kept as-is, confirm this is intended.
                if (bucket.size() > 1) {
                    for (const auto & entry : bucket) {
                        // entry.m_iIndex is the point idx, entry.m_fValue its score
                        if (entry.m_iIndex != n) // do not put the query point itself
                            tslMap[entry.m_iIndex] += entry.m_fValue; // single lookup; a new key starts at 0
                    }
                }
            }
        }

        constructHistTime_thr[tid] += omp_get_wtime() - t0;

        // Extract the top-K largest summed scores with a bounded min-heap
        // (assumes IFPair's ordering compares m_fValue, as the original code did).
        double t1 = omp_get_wtime();
        priority_queue< IFPair, vector<IFPair>, greater<> > minQueTopK;
        if (topK > 0) { // calling top() on an empty heap would be UB when topK == 0
            for (const auto& [pointIdx, est] : tslMap)
            {
                if ((int)minQueTopK.size() < topK)
                    minQueTopK.emplace(pointIdx, est);
                else if (est > minQueTopK.top().m_fValue)
                {
                    minQueTopK.pop();
                    minQueTopK.emplace(pointIdx, est);
                }
            }
        }

        // The heap pops the smallest score first, so fill right-to-left. Starting at
        // size()-1 (not topK-1) keeps a short result left-aligned with -1 in the tail.
        for (int k = (int)minQueTopK.size() - 1; k >= 0; --k)
        {
            matrix_kNN(n, k) = minQueTopK.top().m_iIndex; // point idx
            minQueTopK.pop();
        }

        extractTopKTime_thr[tid] += omp_get_wtime() - t1;
        threadWall_thr[tid] += omp_get_wtime() - region_start; // this thread's wall time inside region
    }

    double wall_elapsed = omp_get_wtime() - wall_start;

    double constructHistTime_sum =
            std::accumulate(constructHistTime_thr.begin(), constructHistTime_thr.end(), 0.0);
    double extractTopKTime_sum =
            std::accumulate(extractTopKTime_thr.begin(), extractTopKTime_thr.end(), 0.0);
    double threadWall_sum =
            std::accumulate(threadWall_thr.begin(), threadWall_thr.end(), 0.0);

    if (FalconnLite::verbose)
    {
        cout << "Avg number of collisions per point = " << numColl / FalconnLite::n_points << endl;

        cout << "Sum of per-thread histogram construction time (s) = " << constructHistTime_sum << endl;
        cout << "Sum of per-thread top-K time (s) = " << extractTopKTime_sum << endl;
        cout << "Sum of per-thread dist, top-k, and some overhead time (s): " << threadWall_sum << " \n";
    }

    cout << "Extract approximate top-k (s): " << wall_elapsed << " \n";

    return matrix_kNN;
}

/**
 * Approximate k-NN of every point using the FHT bucket index plus exact re-ranking.
 *
 * Copies MATRIX_X into matrix_X, applies transformData(), and builds the index with
 * fht_Index2_repeat(n_repeats). For each point n, every point found in the buckets
 * of vec2D_Buckets listed in column n of matrix_qProbes (qProbe buckets per repeat)
 * is a candidate. computeDist() is evaluated once per distinct candidate, and the
 * topK smallest distances are kept.
 *
 * @param MATRIX_X  input data, one point per row.
 * @param topK      number of neighbours to report per point.
 * @param n_repeats number of index repetitions to build and probe.
 * @return (indices, distances), both n_points x topK. Row n is sorted by ascending
 *         distance; if fewer than topK candidates were found, the trailing entries
 *         are -1 / +inf.
 */
tuple<RowMajorMatrixXi, RowMajorMatrixXf> FalconnLite::approx_kNN(const Ref<const RowMajorMatrixXf> & MATRIX_X, int topK, int n_repeats)
{
    if (FalconnLite::verbose) {
        cout << "topK: " << topK << endl;
        cout << "qProbe: " << FalconnLite::qProbe << endl;
        cout << "iProbe: " << FalconnLite::iProbe << endl;
        cout << "n_repeats: " << n_repeats << endl;
        cout << "n_proj: " << FalconnLite::n_proj << endl;
        cout << "top_m: " << FalconnLite::top_m << endl;

        cout << "distance: " << FalconnLite::distance << endl;
        cout << "kernel features: " << FalconnLite::ker_n_features << endl;
        cout << "sigma: " << FalconnLite::ker_sigma << endl;
        cout << "interval sampling: " << FalconnLite::ker_intervalSampling << endl;
        cout << "n_threads: " << FalconnLite::n_threads << endl;
    }

    omp_set_num_threads(FalconnLite::n_threads);

    FalconnLite::matrix_X = MATRIX_X;
    transformData(FalconnLite::matrix_X, FalconnLite::distance);

    // double rather than float: float sums stop counting exactly beyond 2^24.
    double numDist = 0.0, numColl = 0.0;

    std::vector<double> computeDistTime_thr(FalconnLite::n_threads, 0.0);
    std::vector<double> extractTopKTime_thr(FalconnLite::n_threads, 0.0);
    std::vector<double> threadWall_thr(FalconnLite::n_threads, 0.0);

    chrono::steady_clock::time_point begin = chrono::steady_clock::now();
    fht_Index2_repeat(n_repeats);
    cout << "Construct index time (ms) = " <<
                    chrono::duration_cast<chrono::milliseconds>(chrono::steady_clock::now() - begin).count()
                    << endl;

    // -1 / +inf mark slots for which no neighbour was found.
    RowMajorMatrixXi indices = RowMajorMatrixXi::Constant(FalconnLite::n_points, topK, -1);
    RowMajorMatrixXf distances = RowMajorMatrixXf::Constant(FalconnLite::n_points, topK, numeric_limits<float>::infinity());

    double wall_start = omp_get_wtime();

#pragma omp parallel reduction(+: numDist, numColl)
    {
        const int tid = omp_get_thread_num();

        // Per-thread scratch reused by all queries of this thread. Only the bits listed
        // in `visited` are ever set, so clearing them costs O(#candidates) per query
        // instead of allocating and zeroing an n_points-bit set for every query.
        boost::dynamic_bitset<> bitsetHist(FalconnLite::n_points);
        std::vector<int> visited;

#pragma omp for
        for (int n = 0; n < FalconnLite::n_points; ++n)
        {
            double region_start = omp_get_wtime();

            double t0 = omp_get_wtime();

            VectorXi vecProbes = FalconnLite::matrix_qProbes.col(n);
            VectorXf vecPoint = FalconnLite::matrix_X.row(n);

            // Bounded max-heap: top() is the worst of the current best topK
            // (assumes IFPair's operator< compares m_fValue, as the original code did).
            priority_queue< IFPair, vector<IFPair> > maxQueTopK;
            visited.clear();

            for (int repeat = 0; repeat < n_repeats; ++repeat)
            {
                for (int k = 0; k < FalconnLite::qProbe; ++k)
                {
                    int bucketIdx = vecProbes(k + repeat * FalconnLite::qProbe);
                    const auto & bucket = FalconnLite::vec2D_Buckets[bucketIdx];
                    numColl += bucket.size();

                    // Scan every point of bucket Ri; each distinct candidate is evaluated once.
                    for (const int pointIdx : bucket)
                    {
                        if (pointIdx == n || bitsetHist.test(pointIdx)) // skip the query itself and repeats
                            continue;

                        bitsetHist.set(pointIdx);
                        visited.push_back(pointIdx);

                        float dist = computeDist(vecPoint, FalconnLite::matrix_X.row(pointIdx), FalconnLite::distance);

                        if ((int)maxQueTopK.size() < topK)
                            maxQueTopK.emplace(pointIdx, dist);
                        else if (topK > 0 && dist < maxQueTopK.top().m_fValue) // top() on an empty heap is UB
                        {
                            maxQueTopK.pop();
                            maxQueTopK.emplace(pointIdx, dist);
                        }
                    }
                }
            }

            computeDistTime_thr[tid] += omp_get_wtime() - t0;

            numDist += visited.size();
            for (const int idx : visited)
                bitsetHist.reset(idx); // leave an all-zero bitset for the next query

            double t1 = omp_get_wtime();

            // The max-heap pops the farthest first, so fill right-to-left. Starting at
            // size()-1 (not topK-1) keeps a short result left-aligned (nearest in column 0)
            // with -1/+inf in the tail, matching approx_join.
            for (int k = (int)maxQueTopK.size() - 1; k >= 0; --k)
            {
                const IFPair & ifPair = maxQueTopK.top();
                indices(n, k) = ifPair.m_iIndex;
                distances(n, k) = ifPair.m_fValue;
                maxQueTopK.pop();
            }

            extractTopKTime_thr[tid] += omp_get_wtime() - t1;

            threadWall_thr[tid] += omp_get_wtime() - region_start; // this thread's wall time inside region
        }
    }

    double wall_elapsed = omp_get_wtime() - wall_start;

    double computeDistTime_sum =
            std::accumulate(computeDistTime_thr.begin(), computeDistTime_thr.end(), 0.0);
    double extractTopKTime_sum =
            std::accumulate(extractTopKTime_thr.begin(), extractTopKTime_thr.end(), 0.0);
    double threadWall_sum =
            std::accumulate(threadWall_thr.begin(), threadWall_thr.end(), 0.0);

    if (FalconnLite::verbose)
    {
        cout << "Avg distance computation per point = " << numDist / FalconnLite::n_points << endl;
        cout << "Avg number of collisions per point = " << numColl / FalconnLite::n_points << endl;

        cout << "Sum of per-thread dist computation time (s) = " << computeDistTime_sum << endl;
        cout << "Sum of per-thread top-K time (s) = " << extractTopKTime_sum << endl;
        cout << "Sum of per-thread dist, top-k, and some overhead time (s): " << threadWall_sum << " \n" << std::flush;
    }

    cout << "Extract approximate top-k (s): " << wall_elapsed << " \n";

    return make_tuple(indices, distances);
}


/**
 * Approximate k-NN self-join: every pair of points sharing a bucket of vec2D_Buckets
 * is evaluated with computeDist(), and the result is offered to both endpoints'
 * top-K lists.
 *
 * Copies MATRIX_X into matrix_X, applies transformData(), and builds the index with
 * fht_Index2_repeat(n_repeats). Buckets are processed in parallel. The per-point lists
 * are guarded by a striped lock pool (point idx -> locks[idx % NUM_LOCKS]).
 *
 * @param MATRIX_X  input data, one point per row.
 * @param topK      number of neighbours to report per point.
 * @param n_repeats number of index repetitions to build.
 * @return (indices, distances), both n_points x topK. Row n is sorted by ascending
 *         distance; unfilled trailing entries are -1 / +inf.
 */
tuple<RowMajorMatrixXi, RowMajorMatrixXf> FalconnLite::approx_join(
    const Ref<const RowMajorMatrixXf> & MATRIX_X, int topK, int n_repeats) {

    if (FalconnLite::verbose) {
        cout << "topK: " << topK << endl;
        cout << "qProbe: " << FalconnLite::qProbe << endl;
        cout << "iProbe: " << FalconnLite::iProbe << endl;
        cout << "n_repeats: " << n_repeats << endl;
        cout << "n_proj: " << FalconnLite::n_proj << endl;
        cout << "top_m: " << FalconnLite::top_m << endl;
        cout << "distance: " << FalconnLite::distance << endl;
        cout << "n_threads: " << FalconnLite::n_threads << endl;
    }

    omp_set_num_threads(FalconnLite::n_threads);
    FalconnLite::matrix_X = MATRIX_X;
    transformData(FalconnLite::matrix_X, FalconnLite::distance);

    double numDist = 0.0, numColl = 0.0; // use double for accumulation
    // 64-bit: the number of skipped pairs can exceed INT_MAX on large inputs,
    // and signed int overflow is undefined behaviour.
    long long duplcount = 0;

    const int N = (int)FalconnLite::n_points;
    const int slots = std::max(topK, 0);

    // vecTopK: a fixed block of topK slots per point, so no reallocation can happen
    // while another thread is reading a neighbouring list.
    vector<vector<IFPair>> vecTopK(N, vector<IFPair>(slots, IFPair(-1, numeric_limits<float>::infinity())));
    // curSize: number of filled slots in vecTopK[i] (<= topK)
    vector<int> curSize(N, 0);
    // maxVal / maxIdx: largest distance among the filled slots and its slot (for fast replacement)
    vector<float> maxVal(N, numeric_limits<float>::infinity());
    vector<int> maxIdx(N, 0);

    RowMajorMatrixXi indices = RowMajorMatrixXi::Constant(N, topK, -1);
    RowMajorMatrixXf distances = RowMajorMatrixXf::Constant(N, topK, numeric_limits<float>::infinity());

    // Lock pool: point idx is protected by locks[idx % NUM_LOCKS]
    const size_t NUM_LOCKS = 16384;
    vector<omp_lock_t> locks(NUM_LOCKS);
    for (size_t i = 0; i < NUM_LOCKS; ++i) omp_init_lock(&locks[i]);

    chrono::steady_clock::time_point begin = chrono::steady_clock::now();
    fht_Index2_repeat(n_repeats);
    cout << "Construct index time (ms) = "
         << chrono::duration_cast<chrono::milliseconds>(
                chrono::steady_clock::now() - begin).count()
         << endl;

    // True if idx occurs among the first curSz (filled) slots of topkVec. Caller holds the lock.
    auto checkExistLocal = [&](const vector<IFPair> &topkVec, int curSz, int idx) -> bool {
        for (int t = 0; t < curSz; ++t) if (topkVec[t].m_iIndex == idx) return true;
        return false;
    };

    // Offer (idx, dist) to one point's top-K list. Caller holds that point's lock.
    // Returns true if the candidate was stored.
    auto insertIntoTopKLocal = [&](vector<IFPair>& topkVec, int &curSz, int &maxI, float &maxV, int idx, float dist) -> bool {
        if (checkExistLocal(topkVec, curSz, idx)) return false; // already present

        if (curSz < topK) {
            // Free slot: append and keep maxV/maxI pointing at the largest filled slot.
            topkVec[curSz] = IFPair(idx, dist);
            if (curSz == 0 || dist > maxV) {
                maxV = dist;
                maxI = curSz;
            }
            curSz++;
            return true;
        }

        // Full: replace the current maximum if the candidate is closer, then rescan (O(topK)).
        if (dist < maxV) {
            topkVec[maxI] = IFPair(idx, dist);
            maxV = topkVec[0].m_fValue;
            maxI = 0;
            for (int t = 1; t < topK; ++t) {
                if (topkVec[t].m_fValue > maxV) { maxV = topkVec[t].m_fValue; maxI = t; }
            }
            return true;
        }
        return false;
    };

    // Acquire / release the stripes of both endpoints; always ascending order to avoid deadlock.
    auto lockPair = [&](int minL, int maxL) {
        omp_set_lock(&locks[minL]);
        if (maxL != minL) omp_set_lock(&locks[maxL]);
    };
    auto unlockPair = [&](int minL, int maxL) {
        if (maxL != minL) omp_unset_lock(&locks[maxL]);
        omp_unset_lock(&locks[minL]);
    };

    // True if the pair (a, b) is already recorded in either endpoint's list. Caller holds both stripes.
    auto pairKnown = [&](int a, int b) -> bool {
        return checkExistLocal(vecTopK[a], curSize[a], b) ||
               checkExistLocal(vecTopK[b], curSize[b], a);
    };

    // Main loop, double-checked: short locked check -> lock-free distance -> locked re-check + insert.
    // With topK <= 0 there is nothing to fill (and inserting would write past empty slot vectors).
    if (topK > 0) {
        #pragma omp parallel reduction(+: numDist, numColl, duplcount)
        {
            const size_t B = FalconnLite::vec2D_Buckets.size();
            #pragma omp for schedule(dynamic)
            for (size_t b = 0; b < B; ++b) {
                const IVector &bucketPoints = FalconnLite::vec2D_Buckets[b];
                const size_t m = bucketPoints.size();
                numColl += (double)m;

                for (size_t ii = 0; ii < m; ++ii) {
                    const int idx_i = bucketPoints[ii];
                    for (size_t jj = ii + 1; jj < m; ++jj) {
                        const int idx_j = bucketPoints[jj];
                        if (idx_i == idx_j) continue;

                        const int lock_i = idx_i % (int)NUM_LOCKS;
                        const int lock_j = idx_j % (int)NUM_LOCKS;
                        const int minL = std::min(lock_i, lock_j);
                        const int maxL = std::max(lock_i, lock_j);

                        // First check: skip pairs already recorded on either side.
                        lockPair(minL, maxL);
                        const bool exists = pairKnown(idx_i, idx_j);
                        unlockPair(minL, maxL);
                        if (exists) { duplcount++; continue; }

                        // Distance computed without holding any lock.
                        float dist = computeDist(FalconnLite::matrix_X.row(idx_i),
                                                 FalconnLite::matrix_X.row(idx_j),
                                                 FalconnLite::distance);
                        numDist += 1.0;

                        // Second check + insertion: another thread may have inserted the pair meanwhile.
                        bool inserted_any = false;
                        lockPair(minL, maxL);
                        if (!pairKnown(idx_i, idx_j)) {
                            bool ins_i = insertIntoTopKLocal(vecTopK[idx_i], curSize[idx_i], maxIdx[idx_i], maxVal[idx_i], idx_j, dist);
                            bool ins_j = insertIntoTopKLocal(vecTopK[idx_j], curSize[idx_j], maxIdx[idx_j], maxVal[idx_j], idx_i, dist);
                            inserted_any = ins_i || ins_j;
                        }
                        unlockPair(minL, maxL);

                        if (!inserted_any) duplcount++;
                    }
                }
            }
        } // pragma omp parallel
    }

    // Output: sort each point's filled slots by ascending distance and write them back.
    #pragma omp parallel for
    for (int n = 0; n < N; ++n) {
        const int sz = curSize[n];
        if (sz == 0) continue;
        sort(vecTopK[n].begin(), vecTopK[n].begin() + sz,
             [](const IFPair &a, const IFPair &b){ return a.m_fValue < b.m_fValue; });
        for (int k = 0; k < sz; ++k) {
            indices(n, k) = vecTopK[n][k].m_iIndex;
            distances(n, k) = vecTopK[n][k].m_fValue;
        }
        // if sz < topK, remaining entries in indices/distances remain -1/inf
    }

    // cleanup locks
    for (size_t i = 0; i < NUM_LOCKS; ++i) omp_destroy_lock(&locks[i]);

    if (FalconnLite::verbose) {
        size_t total_count = 0;
        for (int i = 0; i < N; ++i) total_count += curSize[i];
        cout << "Total unique neighbor count: " << total_count << endl;
        cout << "Duplicate pairs skipped: " << duplcount << endl;
        cout << "Avg distance computations per point = " << (numDist / (double)N) << endl;
        cout << "Avg number of collisions per point = " << (numColl / (double)N) << endl;
    }

    return make_tuple(indices, distances);
}


//
// tuple<RowMajorMatrixXi, RowMajorMatrixXf> FalconnLite::approx_join(
//     const Ref<const RowMajorMatrixXf> & MATRIX_X, int topK, int n_repeats) {
//
//     if (FalconnLite::verbose) {
//         cout << "topK: " << topK << endl;
//         cout << "qProbe: " << FalconnLite::qProbe << endl;
//         cout << "iProbe: " << FalconnLite::iProbe << endl;
//         cout << "n_repeats: " << n_repeats << endl;
//         cout << "n_proj: " << FalconnLite::n_proj << endl;
//         cout << "top_m: " << FalconnLite::top_m << endl;
//         cout << "distance: " << FalconnLite::distance << endl;
//         cout << "n_threads: " << FalconnLite::n_threads << endl;
//     }
//
//     omp_set_num_threads(FalconnLite::n_threads);
//     FalconnLite::matrix_X = MATRIX_X;
//     transformData(FalconnLite::matrix_X, FalconnLite::distance);
//
//     float numDist = 0.0, numColl = 0.0;
//     int duplcount = 0;
//
//     // 初始化 vecTopK，每个点直接有 topK 个元素
//     vector<vector<IFPair>> vecTopK(FalconnLite::n_points, vector<IFPair>(topK, IFPair(-1, numeric_limits<float>::infinity())));
//     vector<float> maxDistVal(FalconnLite::n_points, numeric_limits<float>::infinity());
//     vector<int> maxDistIdx(FalconnLite::n_points, 0);
//
//     RowMajorMatrixXi indices = RowMajorMatrixXi::Constant(FalconnLite::n_points, topK, -1);
//     RowMajorMatrixXf distances = RowMajorMatrixXf::Constant(FalconnLite::n_points, topK, numeric_limits<float>::infinity());
//
//     // ---------------- omp_lock_t 锁数组 ----------------
//     const size_t NUM_LOCKS = 16384;
//     static omp_lock_t locks[NUM_LOCKS];
//     #pragma omp parallel for
//     for (size_t i = 0; i < NUM_LOCKS; ++i) {
//         omp_init_lock(&locks[i]);
//     }
//
//     chrono::steady_clock::time_point begin = chrono::steady_clock::now();
//     fht_Index2_repeat(n_repeats);
//     cout << "Construct index time (ms) = "
//          << chrono::duration_cast<chrono::milliseconds>(
//                 chrono::steady_clock::now() - begin).count()
//          << endl;
//
//     auto checkExist = [&](const vector<IFPair>& topkVec, int idx) {
//         for (auto &p : topkVec)
//             if (p.m_iIndex == idx)
//                 return true;
//         return false;
//     };
//
//     auto updateTopK = [&](vector<IFPair>& topkVec, int &maxIdx, float &maxVal, int idx, float dist, int topK) {
//         for (auto &p : topkVec)
//             if (p.m_iIndex == idx)
//                 return false;
//
//         if (dist < maxVal) {
//             topkVec[maxIdx] = IFPair(idx, dist);
//             maxVal = topkVec[0].m_fValue;
//             maxIdx = 0;
//             for (int i = 1; i < topK; ++i) {
//                 if (topkVec[i].m_fValue > maxVal) {
//                     maxVal = topkVec[i].m_fValue;
//                     maxIdx = i;
//                 }
//             }
//             return true;
//         }
//         return false;
//     };
//
//     // ---------------- 主循环 ----------------
//     #pragma omp parallel reduction(+:numDist, numColl, duplcount)
//     {
//         #pragma omp for schedule(dynamic)
//         for (int repeat = 0; repeat < n_repeats; ++repeat) {
//             for (size_t b = 0; b < FalconnLite::vec2D_Buckets.size(); ++b) {
//                 IVector &bucketPoints = FalconnLite::vec2D_Buckets[b];
//                 numColl += bucketPoints.size();
//
//                 for (size_t i = 0; i < bucketPoints.size(); ++i) {
//                     int idx_i = bucketPoints[i];
//                     for (size_t j = i + 1; j < bucketPoints.size(); ++j) {
//                         int idx_j = bucketPoints[j];
//                         if (idx_i == idx_j) continue;
//
//                         int lock_i = idx_i % NUM_LOCKS;
//                         int lock_j = idx_j % NUM_LOCKS;
//
//                         // ----------- 计算前检查是否已存在 -----------
//                         bool exists_i = false, exists_j = false;
//                         omp_set_lock(&locks[lock_i]);
//                         exists_i = checkExist(vecTopK[idx_i], idx_j);
//                         omp_unset_lock(&locks[lock_i]);
//
//                         omp_set_lock(&locks[lock_j]);
//                         exists_j = checkExist(vecTopK[idx_j], idx_i);
//                         omp_unset_lock(&locks[lock_j]);
//
//                         if (exists_i || exists_j) {
//                             duplcount++;
//                             continue;
//                         }
//
//                         // ----------- 计算距离 -----------
//                         float dist = computeDist(FalconnLite::matrix_X.row(idx_i),
//                                                  FalconnLite::matrix_X.row(idx_j),
//                                                  FalconnLite::distance);
//                         numDist++;
//
//                         // ----------- 更新两端的 TopK -----------
//                         bool inserted_i = false, inserted_j = false;
//                         omp_set_lock(&locks[lock_i]);
//                         inserted_i = updateTopK(vecTopK[idx_i], maxDistIdx[idx_i], maxDistVal[idx_i], idx_j, dist, topK);
//                         omp_unset_lock(&locks[lock_i]);
//
//                         omp_set_lock(&locks[lock_j]);
//                         inserted_j = updateTopK(vecTopK[idx_j], maxDistIdx[idx_j], maxDistVal[idx_j], idx_i, dist, topK);
//                         omp_unset_lock(&locks[lock_j]);
//
//                         if (!inserted_i && !inserted_j)
//                             duplcount++;
//                     }
//                 }
//             }
//         }
//     }
//
//     // ---------------- 输出整理 ----------------
//     #pragma omp parallel for
//     for (int n = 0; n < FalconnLite::n_points; ++n) {
//         auto &topkVec = vecTopK[n];
//         sort(topkVec.begin(), topkVec.end(),
//              [](const IFPair &a, const IFPair &b) { return a.m_fValue < b.m_fValue; });
//         for (size_t k = 0; k < topkVec.size() && k < (size_t)topK; ++k) {
//             indices(n, k) = topkVec[k].m_iIndex;
//             distances(n, k) = topkVec[k].m_fValue;
//         }
//     }
//
//     // ---------------- 销毁锁 ----------------
//     #pragma omp parallel for
//     for (size_t i = 0; i < NUM_LOCKS; ++i) {
//         omp_destroy_lock(&locks[i]);
//     }
//
//     if (FalconnLite::verbose) {
//         size_t total_count = 0;
//         // for (auto &v : vecTopK) total_count += v.size();
//         // cout << "Total unique neighbor count: " << total_count << endl;
//         cout << "Duplicate pairs skipped: " << duplcount << endl;
//         cout << "Avg distance computations per point = " << numDist / FalconnLite::n_points << endl;
//         cout << "Avg collisions per point = " << numColl / FalconnLite::n_points << endl;
//     }
//
//     return make_tuple(indices, distances);
// }
//
//



// tuple<RowMajorMatrixXi, RowMajorMatrixXf> FalconnLite::approx_join(
//     const Ref<const RowMajorMatrixXf> & MATRIX_X, int topK, int n_repeats) {
//
//     if (FalconnLite::verbose) {
//         cout << "topK: " << topK << endl;
//         cout << "qProbe: " << FalconnLite::qProbe << endl;
//         cout << "iProbe: " << FalconnLite::iProbe << endl;
//         cout << "n_repeats: " << n_repeats << endl;
//         cout << "n_proj: " << FalconnLite::n_proj << endl;
//         cout << "top_m: " << FalconnLite::top_m << endl;
//         cout << "distance: " << FalconnLite::distance << endl;
//         cout << "n_threads: " << FalconnLite::n_threads << endl;
//     }
//
//     omp_set_num_threads(FalconnLite::n_threads);
//     FalconnLite::matrix_X = MATRIX_X;
//     transformData(FalconnLite::matrix_X, FalconnLite::distance);
//
//     float numDist = 0.0, numColl = 0.0;
//     int duplcount = 0;
//
//     // 每个点的堆
//     vector<priority_queue<IFPair, vector<IFPair>>> vecMaxQueue(FalconnLite::n_points);
//
//     // 每个点的访问记录
//     vector<vector<int>> visited(FalconnLite::n_points);
//     for (auto &v : visited) {
//         v.reserve(n_repeats * FalconnLite::top_m );
//     }
//
//     RowMajorMatrixXi indices = RowMajorMatrixXi::Constant(FalconnLite::n_points, topK, -1);
//     RowMajorMatrixXf distances = RowMajorMatrixXf::Constant(FalconnLite::n_points, topK, numeric_limits<float>::infinity());
//
//     // ---------------- 初始化锁数组 ----------------
//     const size_t NUM_LOCKS = 16384;
//     static omp_lock_t locks[NUM_LOCKS];
//     #pragma omp parallel for
//     for (size_t i = 0; i < NUM_LOCKS; ++i) omp_init_lock(&locks[i]);
//
//     chrono::steady_clock::time_point begin = chrono::steady_clock::now();
//     fht_Index2_repeat(n_repeats);
//     cout << "Construct index time (ms) = "
//          << chrono::duration_cast<chrono::milliseconds>(
//                 chrono::steady_clock::now() - begin).count()
//          << endl;
//
//     // ---------------- 主循环 ----------------
//     #pragma omp parallel reduction(+:numDist, numColl, duplcount)
//     {
//         #pragma omp for schedule(dynamic)
//         for (int repeat = 0; repeat < n_repeats; ++repeat) {
//             for (size_t b = 0; b < FalconnLite::vec2D_Buckets.size(); ++b) {
//                 IVector &bucketPoints = FalconnLite::vec2D_Buckets[b];
//                 numColl += bucketPoints.size();
//
//                 for (size_t i = 0; i < bucketPoints.size(); ++i) {
//                     int idx_i = bucketPoints[i];
//                     for (size_t j = i + 1; j < bucketPoints.size(); ++j) {
//                         int idx_j = bucketPoints[j];
//                         if (idx_i == idx_j) continue;
//
//                         int lock_i = idx_i % NUM_LOCKS;
//                         int lock_j = idx_j % NUM_LOCKS;
//
//                         // ---------------- 计算前检查 visited ----------------
//                         bool skip = false;
//                         omp_set_lock(&locks[lock_i]);
//                         for (auto &v : visited[idx_i])
//                             if (v == idx_j) { skip = true; break; }
//                         omp_unset_lock(&locks[lock_i]);
//
//                         omp_set_lock(&locks[lock_j]);
//                         for (auto &v : visited[idx_j])
//                             if (v == idx_i) { skip = true; break; }
//                         omp_unset_lock(&locks[lock_j]);
//
//                         if (skip) {
//                             duplcount++;
//                             continue;
//                         }
//
//                         // ---------------- 计算距离 ----------------
//                         float dist = computeDist(FalconnLite::matrix_X.row(idx_i),
//                                                  FalconnLite::matrix_X.row(idx_j),
//                                                  FalconnLite::distance);
//                         numDist++;
//
//                         // ---------------- 更新堆和 visited ----------------
//                         bool inserted_i = false, inserted_j = false;
//
//                         // ---------------- 更新堆并立即记录 visited ----------------
//                         omp_set_lock(&locks[lock_i]);
//                         if (find(visited[idx_i].begin(), visited[idx_i].end(), idx_j) == visited[idx_i].end()) {
//                             visited[idx_i].push_back(idx_j);  // 一旦算过距离立即记录
//                             if ((int)vecMaxQueue[idx_i].size() < topK) {
//                                 vecMaxQueue[idx_i].emplace(idx_j, dist);
//                             } else if (dist < vecMaxQueue[idx_i].top().m_fValue) {
//                                 vecMaxQueue[idx_i].pop();
//                                 vecMaxQueue[idx_i].emplace(idx_j, dist);
//                             }
//                         }
//                         omp_unset_lock(&locks[lock_i]);
//
//                         omp_set_lock(&locks[lock_j]);
//                         if (find(visited[idx_j].begin(), visited[idx_j].end(), idx_i) == visited[idx_j].end()) {
//                             visited[idx_j].push_back(idx_i);  //  对称更新
//                             if ((int)vecMaxQueue[idx_j].size() < topK) {
//                                 vecMaxQueue[idx_j].emplace(idx_i, dist);
//                             } else if (dist < vecMaxQueue[idx_j].top().m_fValue) {
//                                 vecMaxQueue[idx_j].pop();
//                                 vecMaxQueue[idx_j].emplace(idx_i, dist);
//                             }
//                         }
//                         omp_unset_lock(&locks[lock_j]);
//
//                         if (!inserted_i && !inserted_j) duplcount++;
//                     }
//                 }
//             }
//         }
//     }
//
//     // ---------------- 输出整理 ----------------
//     #pragma omp parallel for
//     for (int n = 0; n < FalconnLite::n_points; ++n) {
//         int k = topK - 1;
//         while (!vecMaxQueue[n].empty()) {
//             IFPair ifPair = vecMaxQueue[n].top();
//             vecMaxQueue[n].pop();
//             indices(n, k) = ifPair.m_iIndex;
//             distances(n, k) = ifPair.m_fValue;
//             k--;
//         }
//     }
//
//     // ---------------- 销毁锁 ----------------
//     #pragma omp parallel for
//     for (size_t i = 0; i < NUM_LOCKS; ++i) omp_destroy_lock(&locks[i]);
//
//     if (FalconnLite::verbose) {
//         size_t total_visited_count = 0;
//         for (size_t i = 0; i < visited.size(); ++i) total_visited_count += visited[i].size();
//         cout << "Total visited element count: " << total_visited_count << endl;
//         cout << "Duplicate distance computations skipped: " << duplcount << endl;
//         cout << "Avg distance computation per point = " << numDist / FalconnLite::n_points << endl;
//         cout << "Avg number of collisions per point = " << numColl / FalconnLite::n_points << endl;
//     }
//
//     return make_tuple(indices, distances);
// }
//
//

