/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

// -*- c++ -*-

#include <faiss/IndexIVFPQ.h>

#include <array>
#include <cassert>
#include <cinttypes>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <algorithm>

#include <faiss/utils/Heap.h>
#include <faiss/utils/distances.h>
#include <faiss/utils/utils.h>

#include <faiss/Clustering.h>

#include <faiss/utils/hamming.h>

#include <faiss/impl/FaissAssert.h>

#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/IDSelector.h>

#include <faiss/impl/ProductQuantizer.h>

#include <faiss/impl/code_distance/code_distance.h>

namespace faiss {

/*****************************************
 * IndexIVFPQ implementation
 ******************************************/

/// Build an untrained IVFPQ index.
///
/// The IndexIVF base is constructed with a code size of 0; the actual code
/// size is taken from the product quantizer once it has been built.
///
/// @param quantizer      coarse quantizer, forwarded to IndexIVF
/// @param d              vector dimension
/// @param nlist          number of inverted lists
/// @param M              number of PQ sub-quantizers
/// @param nbits_per_idx  number of bits per PQ sub-quantizer index
/// @param metric         metric, forwarded to IndexIVF
/// @param own_invlists   forwarded to IndexIVF; when true, the code size of
///                       the inverted lists is set to the PQ code size
IndexIVFPQ::IndexIVFPQ(
        Index* quantizer,
        size_t d,
        size_t nlist,
        size_t M,
        size_t nbits_per_idx,
        MetricType metric,
        bool own_invlists)
        : IndexIVF(quantizer, d, nlist, 0, metric, own_invlists),
          pq(d, M, nbits_per_idx) {
    // polysemous codes are disabled by default
    polysemous_ht = 0;
    do_polysemous_training = false;
    polysemous_training = nullptr;

    // search-time defaults
    scan_table_threshold = 0;
    use_precomputed_table = 0;
    by_residual = true;

    // the PQ has to be trained before the index can be used
    is_trained = false;

    // propagate the real code size (the base class was given 0)
    code_size = pq.code_size;
    if (own_invlists) {
        invlists->code_size = code_size;
    }
}

/****************************************************************
 * training                                                     */

/// Train the product quantizer on the n vectors in x, optionally run the
/// polysemous optimization, and refresh the precomputed tables when the
/// index encodes residuals. The assign argument is not used here.
void IndexIVFPQ::train_encoder(idx_t n, const float* x, const idx_t* assign) {
    pq.train(n, x);

    if (do_polysemous_training) {
        if (verbose) {
            printf("doing polysemous training for PQ\n");
        }
        // fall back on a default-constructed trainer when none was supplied
        PolysemousTraining fallback_trainer;
        PolysemousTraining* trainer = &fallback_trainer;
        if (polysemous_training) {
            trainer = polysemous_training;
        }
        trainer->optimize_pq_for_hamming(pq, n, x);
    }

    if (by_residual) {
        precompute_table();
    }
}

/// Number of training vectors requested for the encoder: the clustering
/// parameter max_points_per_centroid times the number of PQ centroids per
/// sub-quantizer.
idx_t IndexIVFPQ::train_encoder_num_vectors() const {
    const auto& clustering_params = pq.cp;
    return clustering_params.max_points_per_centroid * pq.ksub;
}

/****************************************************************
 * IVFPQ as codec                                               */

/* produce a binary signature based on the residual vector */
/// Encode a single vector x assigned to inverted list `key` into `code`.
/// With by_residual, the residual w.r.t. the coarse quantizer is encoded;
/// otherwise x itself is encoded.
void IndexIVFPQ::encode(idx_t key, const float* x, uint8_t* code) const {
    if (!by_residual) {
        pq.compute_code(x, code);
        return;
    }

    std::vector<float> residual(d);
    quantizer->compute_residual(x, residual.data(), key);
    pq.compute_code(residual.data(), code);
}

void IndexIVFPQ::encode_multiple(
        size_t n,
        idx_t* keys,
        const float* x,
        uint8_t* xcodes,
        bool compute_keys) const {
    if (compute_keys)
        quantizer->assign(n, x, keys);

    encode_vectors(n, x, keys, xcodes);
}

/// Decode n PQ codes into x (n * d floats). With by_residual, the
/// reconstructed coarse centroid of keys[i] is added back to vector i.
void IndexIVFPQ::decode_multiple(
        size_t n,
        const idx_t* keys,
        const uint8_t* xcodes,
        float* x) const {
    pq.decode(xcodes, x, n);
    if (!by_residual) {
        return;
    }

    // scratch buffer reused for every centroid
    std::vector<float> centroid(d);
    float* row = x;
    for (size_t i = 0; i < n; i++, row += d) {
        quantizer->reconstruct(keys[i], centroid.data());
        for (size_t j = 0; j < d; j++) {
            row[j] += centroid[j];
        }
    }
}

/****************************************************************
 * add                                                          */

/// Add n vectors with precomputed coarse assignments. Delegates to
/// add_core_o without requesting the second-level residuals.
void IndexIVFPQ::add_core(
        idx_t n,
        const float* x,
        const idx_t* xids,
        const idx_t* coarse_idx,
        void* inverted_list_context) {
    // no residual output buffer on this path
    float* const no_residuals_2 = nullptr;
    add_core_o(n, x, xids, no_residuals_2, coarse_idx, inverted_list_context);
}

static std::unique_ptr<float[]> compute_residuals(
        const Index* quantizer,
        idx_t n,
        const float* x,
        const idx_t* list_nos) {
    size_t d = quantizer->d;
    std::unique_ptr<float[]> residuals(new float[n * d]);
    // TODO: parallelize?
    for (size_t i = 0; i < n; i++) {
        if (list_nos[i] < 0)
            memset(residuals.get() + i * d, 0, sizeof(float) * d);
        else
            quantizer->compute_residual(
                    x + i * d, residuals.get() + i * d, list_nos[i]);
    }
    return residuals;
}

/// Encode n vectors into `codes`.
///
/// The PQ codes are first written contiguously (code_size bytes each). When
/// include_listnos is set, each entry is then expanded in place to
/// [encoded list number | PQ code], so `codes` must be large enough for
/// n * (coarse_code_size() + code_size) bytes.
void IndexIVFPQ::encode_vectors(
        idx_t n,
        const float* x,
        const idx_t* list_nos,
        uint8_t* codes,
        bool include_listnos) const {
    if (!by_residual) {
        pq.compute_codes(x, codes, n);
    } else {
        std::unique_ptr<float[]> residuals =
                compute_residuals(quantizer, n, x, list_nos);
        pq.compute_codes(residuals.get(), codes, n);
    }

    if (!include_listnos) {
        return;
    }

    const size_t coarse_size = coarse_code_size();
    // walk backwards so that expanding entry i never overwrites a packed
    // code that has not been moved yet
    for (idx_t i = n; i-- > 0;) {
        uint8_t* entry = codes + i * (coarse_size + code_size);
        memmove(entry + coarse_size, codes + i * code_size, code_size);
        encode_listno(list_nos[i], entry);
    }
}

/// Decode n standalone codes, each laid out as
/// [encoded list number | PQ code], into x (n * d floats). With
/// by_residual, the reconstructed coarse centroid is added back.
void IndexIVFPQ::sa_decode(idx_t n, const uint8_t* codes, float* x) const {
    const size_t coarse_size = coarse_code_size();

#pragma omp parallel
    {
        // declared inside the parallel region: one scratch buffer per thread
        std::vector<float> centroid(d);

#pragma omp for
        for (idx_t i = 0; i < n; i++) {
            const uint8_t* entry = codes + i * (code_size + coarse_size);
            const int64_t list_no = decode_listno(entry);
            float* out = x + i * d;
            pq.decode(entry + coarse_size, out);
            if (!by_residual) {
                continue;
            }
            quantizer->reconstruct(list_no, centroid.data());
            for (size_t j = 0; j < d; j++) {
                out[j] += centroid[j];
            }
        }
    }
}

// Block size used in IndexIVFPQ::add_core_o: when more vectors than this are
// passed in one call, the input is split into consecutive chunks of at most
// this many vectors and each chunk is added by a separate recursive call.
int index_ivfpq_add_core_o_bs = 32768;

/// Add n vectors to the index, optionally returning second-level residuals.
///
/// Inputs larger than index_ivfpq_add_core_o_bs are processed in blocks by
/// recursive calls.
///
/// @param n                number of vectors
/// @param x                vectors to add, n * d floats
/// @param xids             ids of the vectors, or nullptr to number them
///                         sequentially from ntotal
/// @param residuals_2      if non-null, output of n * d floats: for every
///                         vector, the difference between what was encoded
///                         (the residual when by_residual is set, the raw
///                         vector otherwise) and its PQ reconstruction;
///                         all zeros for vectors that are not stored
/// @param precomputed_idx  coarse assignment of each vector, or nullptr to
///                         compute it with the quantizer
/// @param inverted_list_context  opaque context forwarded to
///                         invlists->add_entry
///
/// Vectors whose coarse assignment is negative are not stored in any list;
/// they are registered in the direct map with list number -1.
void IndexIVFPQ::add_core_o(
        idx_t n,
        const float* x,
        const idx_t* xids,
        float* residuals_2,
        const idx_t* precomputed_idx,
        void* inverted_list_context) {
    idx_t bs = index_ivfpq_add_core_o_bs;
    if (n > bs) {
        for (idx_t i0 = 0; i0 < n; i0 += bs) {
            idx_t i1 = std::min(i0 + bs, n);
            if (verbose) {
                printf("IndexIVFPQ::add_core_o: adding %" PRId64 ":%" PRId64
                       " / %" PRId64 "\n",
                       i0,
                       i1,
                       n);
            }
            add_core_o(
                    i1 - i0,
                    x + i0 * d,
                    xids ? xids + i0 : nullptr,
                    residuals_2 ? residuals_2 + i0 * d : nullptr,
                    precomputed_idx ? precomputed_idx + i0 : nullptr,
                    inverted_list_context);
        }
        return;
    }

    InterruptCallback::check();

    direct_map.check_can_add(xids);

    FAISS_THROW_IF_NOT(is_trained);
    double t0 = getmillisecs();

    // stage 1: coarse assignment (owned by del_idx only if computed here)
    const idx_t* idx;
    std::unique_ptr<idx_t[]> del_idx;

    if (precomputed_idx) {
        idx = precomputed_idx;
    } else {
        del_idx.reset(new idx_t[n]);
        quantizer->assign(n, x, del_idx.get());
        idx = del_idx.get();
    }

    double t1 = getmillisecs();

    // stage 2: PQ encoding of the residuals or of the raw vectors
    std::unique_ptr<uint8_t[]> xcodes(new uint8_t[n * code_size]);

    const float* to_encode = nullptr;
    std::unique_ptr<const float[]> del_to_encode;

    if (by_residual) {
        del_to_encode = compute_residuals(quantizer, n, x, idx);
        to_encode = del_to_encode.get();
    } else {
        to_encode = x;
    }
    pq.compute_codes(to_encode, xcodes.get(), n);

    double t2 = getmillisecs();

    // stage 3: store the codes in the inverted lists
    // TODO: parallelize?
    size_t n_ignore = 0;
    for (size_t i = 0; i < n; i++) {
        idx_t key = idx[i];
        idx_t id = xids ? xids[i] : ntotal + i;
        if (key < 0) {
            direct_map.add_single_id(id, -1, 0);
            n_ignore++;
            if (residuals_2) {
                // zero the row of *this* vector; the offset i * d is
                // required, otherwise row 0 would be overwritten instead
                memset(residuals_2 + i * d, 0, sizeof(*residuals_2) * d);
            }
            continue;
        }

        uint8_t* code = xcodes.get() + i * code_size;
        size_t offset =
                invlists->add_entry(key, id, code, inverted_list_context);

        if (residuals_2) {
            // residual of the PQ stage: encoded input minus its decoding
            float* res2 = residuals_2 + i * d;
            const float* xi = to_encode + i * d;
            pq.decode(code, res2);
            for (int j = 0; j < d; j++) {
                res2[j] = xi[j] - res2[j];
            }
        }

        direct_map.add_single_id(id, key, offset);
    }

    double t3 = getmillisecs();
    if (verbose) {
        char comment[100] = {0};
        if (n_ignore > 0) {
            // n_ignore is a size_t, hence %zu
            snprintf(comment, 100, "(%zu vectors ignored)", n_ignore);
        }
        printf(" add_core times: %.3f %.3f %.3f %s\n",
               t1 - t0,
               t2 - t1,
               t3 - t2,
               comment);
    }
    ntotal += n;
}

/// Reconstruct the vector stored at (list_no, offset) into recons (d
/// floats): the PQ decoding of the stored code, plus the reconstructed
/// coarse centroid when by_residual is set.
void IndexIVFPQ::reconstruct_from_offset(
        int64_t list_no,
        int64_t offset,
        float* recons) const {
    const uint8_t* stored_code = invlists->get_single_code(list_no, offset);
    pq.decode(stored_code, recons);

    if (!by_residual) {
        return;
    }

    std::vector<float> centroid(d);
    quantizer->reconstruct(list_no, centroid.data());

    float* out = recons;
    for (const float c : centroid) {
        *out++ += c;
    }
}

/// PQ-decode the code stored at (list_no, offset) into recons, without
/// adding any coarse centroid.
void IndexIVFPQ::decode_from_offset(
        int64_t list_no,
        int64_t offset,
        float* recons) const {
    pq.decode(invlists->get_single_code(list_no, offset), recons);
}

/// Maximum size in bytes of the per-list precomputed table that
/// initialize_IVFPQ_precomputed_table selects automatically; above this
/// size no table is precomputed.
/// 2G (2^31 bytes) by default, accommodates tables up to PQ32 w/ 65536
/// centroids.
size_t precomputed_table_max_bytes = ((size_t)1) << 31;

/** Precomputed tables for residuals
 *
 * During IVFPQ search with by_residual, we compute
 *
 *     d = || x - y_C - y_R ||^2
 *
 * where x is the query vector, y_C the coarse centroid, y_R the
 * refined PQ centroid. The expression can be decomposed as:
 *
 *    d = || x - y_C ||^2 + || y_R ||^2 + 2 * (y_C|y_R) - 2 * (x|y_R)
 *        ---------------   ---------------------------       -------
 *             term 1                 term 2                   term 3
 *
 * When using multiprobe, we use the following decomposition:
 * - term 1 is the distance to the coarse centroid, that is computed
 *   during the 1st stage search.
 * - term 2 can be precomputed, as it does not involve x. However,
 *   because of the PQ, it needs nlist * M * ksub storage. This is why
 *   use_precomputed_table is off by default
 * - term 3 is the classical non-residual distance table.
 *
 * Since y_R is defined by a product quantizer, it is split across
 * subvectors and stored separately for each subvector. If the coarse
 * quantizer is a MultiIndexQuantizer then the table can be stored
 * more compactly.
 *
 * At search time, the tables for term 2 and term 3 are added up. This
 * is faster when the length of the lists is > ksub * M.
 */

/// Select the type of precomputed table and fill it.
///
/// @param use_precomputed_table  in/out table type:
///        -1: tables disabled, the table is emptied;
///         0: choose automatically (result is written back: stays 0 when
///            no table is built, else 1 or 2);
///         1: one table of M * ksub floats per inverted list;
///         2: compact table, requires a MultiIndexQuantizer coarse
///            quantizer whose number of sub-quantizers divides pq.M
/// @param quantizer          coarse quantizer (ntotal gives nlist)
/// @param pq                 product quantizer of the index
/// @param precomputed_table  output table (term 2 of the decomposition
///                           described above)
/// @param by_residual        whether the index encodes residuals
/// @param verbose            print progress / decisions
void initialize_IVFPQ_precomputed_table(
        int& use_precomputed_table,
        const Index* quantizer,
        const ProductQuantizer& pq,
        AlignedTable<float>& precomputed_table,
        bool by_residual,
        bool verbose) {
    size_t nlist = quantizer->ntotal;
    size_t d = quantizer->d;
    FAISS_THROW_IF_NOT(d == pq.d);

    if (use_precomputed_table == -1) {
        precomputed_table.resize(0);
        return;
    }

    if (use_precomputed_table == 0) { // then choose the type of table
        if (!(quantizer->metric_type == METRIC_L2 && by_residual)) {
            if (verbose) {
                printf("IndexIVFPQ::precompute_table: precomputed "
                       "tables needed only for L2 metric and by_residual is enabled\n");
            }
            precomputed_table.resize(0);
            return;
        }
        const MultiIndexQuantizer* miq =
                dynamic_cast<const MultiIndexQuantizer*>(quantizer);
        if (miq && pq.M % miq->pq.M == 0) {
            use_precomputed_table = 2;
        } else {
            size_t table_size = pq.M * pq.ksub * nlist * sizeof(float);
            if (table_size > precomputed_table_max_bytes) {
                if (verbose) {
                    // both arguments are size_t, hence %zu
                    printf("IndexIVFPQ::precompute_table: not precomputing table, "
                           "it would be too big: %zu bytes (max %zu)\n",
                           table_size,
                           precomputed_table_max_bytes);
                }
                // leave the index in the same "no table" state as the other
                // early exits, independently of the verbosity level
                use_precomputed_table = 0;
                precomputed_table.resize(0);
                return;
            }
            use_precomputed_table = 1;
        }
    } // otherwise assume user has set appropriate flag on input

    if (verbose) {
        printf("precomputing IVFPQ tables type %d\n", use_precomputed_table);
    }

    // squared norms of the PQ centroids
    std::vector<float> r_norms(pq.M * pq.ksub, NAN);
    for (size_t m = 0; m < pq.M; m++) {
        for (size_t j = 0; j < pq.ksub; j++) {
            r_norms[m * pq.ksub + j] =
                    fvec_norm_L2sqr(pq.get_centroids(m, j), pq.dsub);
        }
    }

    if (use_precomputed_table == 1) {
        precomputed_table.resize(nlist * pq.M * pq.ksub);
        std::vector<float> centroid(d);

        for (size_t i = 0; i < nlist; i++) {
            quantizer->reconstruct(i, centroid.data());

            float* tab = &precomputed_table[i * pq.M * pq.ksub];
            pq.compute_inner_prod_table(centroid.data(), tab);

            // compute c := a + bf * b for a, b and c tables
            // i.e. tab := ||y_R||^2 + 2 * (y_C|y_R)
            fvec_madd(pq.M * pq.ksub, r_norms.data(), 2.0, tab, tab);
        }
    } else if (use_precomputed_table == 2) {
        const MultiIndexQuantizer* miq =
                dynamic_cast<const MultiIndexQuantizer*>(quantizer);
        FAISS_THROW_IF_NOT(miq);
        const ProductQuantizer& cpq = miq->pq;
        FAISS_THROW_IF_NOT(pq.M % cpq.M == 0);

        precomputed_table.resize(cpq.ksub * pq.M * pq.ksub);

        // reorder PQ centroid table: row i holds the concatenation of
        // centroid i of every coarse sub-quantizer
        std::vector<float> centroids(d * cpq.ksub, NAN);

        for (size_t m = 0; m < cpq.M; m++) {
            for (size_t i = 0; i < cpq.ksub; i++) {
                memcpy(centroids.data() + i * d + m * cpq.dsub,
                       cpq.get_centroids(m, i),
                       sizeof(*centroids.data()) * cpq.dsub);
            }
        }

        pq.compute_inner_prod_tables(
                cpq.ksub, centroids.data(), precomputed_table.data());

        for (size_t i = 0; i < cpq.ksub; i++) {
            float* tab = &precomputed_table[i * pq.M * pq.ksub];
            fvec_madd(pq.M * pq.ksub, r_norms.data(), 2.0, tab, tab);
        }
    }
}

void IndexIVFPQ::precompute_table() {
    initialize_IVFPQ_precomputed_table(
            use_precomputed_table,
            quantizer,
            pq,
            precomputed_table,
            by_residual,
            verbose);
}

namespace {

// Cycle-counting helpers. Both expect a variable named t0 to be declared in
// the enclosing scope: TIC stores the current cycle counter in t0, TOC
// expands to the number of cycles elapsed since then.
#define TIC t0 = get_cycles()
#define TOC get_cycles() - t0

/** QueryTables manages the various ways of searching an
 * IndexIVFPQ. The code contains a lot of branches, depending on:
 * - metric_type: are we computing L2 or Inner product similarity?
 * - by_residual: do we encode raw vectors or residuals?
 * - use_precomputed_table: are x_R|x_C tables precomputed?
 * - polysemous_ht: are we filtering with polysemous codes?
 */
struct QueryTables {
    /*****************************************************
     * General data from the IVFPQ
     *****************************************************/

    const IndexIVFPQ& ivfpq;
    const IVFSearchParameters* params;

    // copied from IndexIVFPQ for easier access
    int d;
    const ProductQuantizer& pq;
    MetricType metric_type;
    bool by_residual;
    int use_precomputed_table;
    int polysemous_ht;

    // pre-allocated data buffers
    float *sim_table, *sim_table_2;
    float *residual_vec, *decoded_vec;

    // single data buffer
    std::vector<float> mem;

    // for table pointers
    std::vector<const float*> sim_table_ptrs;

    explicit QueryTables(
            const IndexIVFPQ& ivfpq,
            const IVFSearchParameters* params)
            : ivfpq(ivfpq),
              d(ivfpq.d),
              pq(ivfpq.pq),
              metric_type(ivfpq.metric_type),
              by_residual(ivfpq.by_residual),
              use_precomputed_table(ivfpq.use_precomputed_table) {
        mem.resize(pq.ksub * pq.M * 2 + d * 2);
        sim_table = mem.data();
        sim_table_2 = sim_table + pq.ksub * pq.M;
        residual_vec = sim_table_2 + pq.ksub * pq.M;
        decoded_vec = residual_vec + d;

        // for polysemous
        polysemous_ht = ivfpq.polysemous_ht;
        if (auto ivfpq_params =
                    dynamic_cast<const IVFPQSearchParameters*>(params)) {
            polysemous_ht = ivfpq_params->polysemous_ht;
        }
        if (polysemous_ht != 0) {
            q_code.resize(pq.code_size);
        }
        init_list_cycles = 0;
        sim_table_ptrs.resize(pq.M);
    }

    /*****************************************************
     * What we do when query is known
     *****************************************************/

    // field specific to query
    const float* qi;

    // query-specific initialization
    void init_query(const float* qi) {
        this->qi = qi;
        if (metric_type == METRIC_INNER_PRODUCT)
            init_query_IP();
        else
            init_query_L2();
        if (!by_residual && polysemous_ht != 0)
            pq.compute_code(qi, q_code.data());
    }

    void init_query_IP() {
        // precompute some tables specific to the query qi
        pq.compute_inner_prod_table(qi, sim_table);
    }

    void init_query_L2() {
        if (!by_residual) {
            pq.compute_distance_table(qi, sim_table);
        } else if (use_precomputed_table) {
            pq.compute_inner_prod_table(qi, sim_table_2);
        }
    }

    /*****************************************************
     * When inverted list is known: prepare computations
     *****************************************************/

    // fields specific to list
    idx_t key;
    float coarse_dis;
    std::vector<uint8_t> q_code;

    uint64_t init_list_cycles;

    /// once we know the query and the centroid, we can prepare the
    /// sim_table that will be used for accumulation
    /// and dis0, the initial value
    float precompute_list_tables() {
        float dis0 = 0;
        uint64_t t0;
        TIC;
        if (by_residual) {
            if (metric_type == METRIC_INNER_PRODUCT)
                dis0 = precompute_list_tables_IP();
            else
                dis0 = precompute_list_tables_L2();
        }
        init_list_cycles += TOC;
        return dis0;
    }

    // copy of precompute_list_tables with cache dest ptr
    float precompute_list_tables_panorama(float* sim_table_ptr) {
        float dis0 = 0;
        uint64_t t0;
        TIC;
        if (by_residual) {
            if (metric_type == METRIC_INNER_PRODUCT)
                dis0 = precompute_list_tables_IP();
            else
                dis0 = precompute_list_tables_L2_panorama(sim_table_ptr);
        }
        init_list_cycles += TOC;
        return dis0;
    }

    float precompute_list_table_pointers() {
        float dis0 = 0;
        uint64_t t0;
        TIC;
        if (by_residual) {
            if (metric_type == METRIC_INNER_PRODUCT)
                FAISS_THROW_MSG("not implemented");
            else
                dis0 = precompute_list_table_pointers_L2();
        }
        init_list_cycles += TOC;
        return dis0;
    }

    /*****************************************************
     * compute tables for inner prod
     *****************************************************/

    float precompute_list_tables_IP() {
        // prepare the sim_table that will be used for accumulation
        // and dis0, the initial value
        ivfpq.quantizer->reconstruct(key, decoded_vec);
        // decoded_vec = centroid
        float dis0 = fvec_inner_product(qi, decoded_vec, d);

        if (polysemous_ht) {
            for (int i = 0; i < d; i++) {
                residual_vec[i] = qi[i] - decoded_vec[i];
            }
            pq.compute_code(residual_vec, q_code.data());
        }
        return dis0;
    }

    /*****************************************************
     * compute tables for L2 distance
     *****************************************************/

    float precompute_list_tables_L2() {
        float dis0 = 0;

        if (use_precomputed_table == 0 || use_precomputed_table == -1) {
            ivfpq.quantizer->compute_residual(qi, residual_vec, key);
            pq.compute_distance_table(residual_vec, sim_table);

            if (polysemous_ht != 0) {
                pq.compute_code(residual_vec, q_code.data());
            }

        } else if (use_precomputed_table == 1) {
            dis0 = coarse_dis;

            fvec_madd(
                    pq.M * pq.ksub,
                    ivfpq.precomputed_table.data() + key * pq.ksub * pq.M,
                    -2.0,
                    sim_table_2,
                    sim_table);

            if (polysemous_ht != 0) {
                ivfpq.quantizer->compute_residual(qi, residual_vec, key);
                pq.compute_code(residual_vec, q_code.data());
            }

        } else if (use_precomputed_table == 2) {
            dis0 = coarse_dis;

            const MultiIndexQuantizer* miq =
                    dynamic_cast<const MultiIndexQuantizer*>(ivfpq.quantizer);
            FAISS_THROW_IF_NOT(miq);
            const ProductQuantizer& cpq = miq->pq;
            int Mf = pq.M / cpq.M;

            const float* qtab = sim_table_2; // query-specific table
            float* ltab = sim_table;         // (output) list-specific table

            long k = key;
            for (int cm = 0; cm < cpq.M; cm++) {
                // compute PQ index
                int ki = k & ((uint64_t(1) << cpq.nbits) - 1);
                k >>= cpq.nbits;

                // get corresponding table
                const float* pc = ivfpq.precomputed_table.data() +
                        (ki * pq.M + cm * Mf) * pq.ksub;

                if (polysemous_ht == 0) {
                    // sum up with query-specific table
                    fvec_madd(Mf * pq.ksub, pc, -2.0, qtab, ltab);
                    ltab += Mf * pq.ksub;
                    qtab += Mf * pq.ksub;
                } else {
                    for (int m = cm * Mf; m < (cm + 1) * Mf; m++) {
                        q_code[m] = fvec_madd_and_argmin(
                                pq.ksub, pc, -2, qtab, ltab);
                        pc += pq.ksub;
                        ltab += pq.ksub;
                        qtab += pq.ksub;
                    }
                }
            }
        }

        return dis0;
    }

    float precompute_list_tables_L2_panorama(float* sim_table_ptr) {
        float dis0 = 0;

        if (use_precomputed_table == 0 || use_precomputed_table == -1) {
            FAISS_ASSERT(false); // this path not supported
            ivfpq.quantizer->compute_residual(qi, residual_vec, key);
            pq.compute_distance_table(residual_vec, sim_table);

            if (polysemous_ht != 0) {
                pq.compute_code(residual_vec, q_code.data());
            }

        } else if (use_precomputed_table == 1) {
            // only supported path for panorama
            dis0 = coarse_dis;

            const size_t n = pq.M * pq.ksub;
            const float bf = -2.0;
            const float* b = sim_table_2;
            float* c = sim_table_ptr;

            const size_t n16 = n / 16;
            const size_t n_for_masking = n % 16;

            const __m512 bfmm = _mm512_set1_ps(bf);

            size_t idx = 0;
            for (idx = 0; idx < n16 * 16; idx += 16) {
                const __m512 bx = _mm512_loadu_ps(b + idx);
                const __m512 abmul = _mm512_mul_ps(bfmm, bx);
                _mm512_storeu_ps(c + idx, abmul);
            }

            if (n_for_masking > 0) {
                const __mmask16 mask = (1 << n_for_masking) - 1;
                const __m512 bx = _mm512_maskz_loadu_ps(mask, b + idx);
                const __m512 abmul = _mm512_mul_ps(bfmm, bx);
                _mm512_mask_storeu_ps(c + idx, mask, abmul);
            }

            // point to the cached sim_table
            sim_table = sim_table_ptr;

            if (polysemous_ht != 0) {
                ivfpq.quantizer->compute_residual(qi, residual_vec, key);
                pq.compute_code(residual_vec, q_code.data());
            }
        } else if (use_precomputed_table == 2) {
            FAISS_ASSERT(false); // this path not supported
            dis0 = coarse_dis;

            const MultiIndexQuantizer* miq =
                    dynamic_cast<const MultiIndexQuantizer*>(ivfpq.quantizer);
            FAISS_THROW_IF_NOT(miq);
            const ProductQuantizer& cpq = miq->pq;
            int Mf = pq.M / cpq.M;

            const float* qtab = sim_table_2; // query-specific table
            float* ltab = sim_table;         // (output) list-specific table

            long k = key;
            for (int cm = 0; cm < cpq.M; cm++) {
                // compute PQ index
                int ki = k & ((uint64_t(1) << cpq.nbits) - 1);
                k >>= cpq.nbits;

                // get corresponding table
                const float* pc = ivfpq.precomputed_table.data() +
                        (ki * pq.M + cm * Mf) * pq.ksub;

                if (polysemous_ht == 0) {
                    // sum up with query-specific table
                    fvec_madd(Mf * pq.ksub, pc, -2.0, qtab, ltab);
                    ltab += Mf * pq.ksub;
                    qtab += Mf * pq.ksub;
                } else {
                    for (int m = cm * Mf; m < (cm + 1) * Mf; m++) {
                        q_code[m] = fvec_madd_and_argmin(
                                pq.ksub, pc, -2, qtab, ltab);
                        pc += pq.ksub;
                        ltab += pq.ksub;
                        qtab += pq.ksub;
                    }
                }
            }
        }
        return dis0;
    }

    float precompute_list_table_pointers_L2() {
        float dis0 = 0;

        if (use_precomputed_table == 1) {
            dis0 = coarse_dis;

            const float* s =
                    ivfpq.precomputed_table.data() + key * pq.ksub * pq.M;
            for (int m = 0; m < pq.M; m++) {
                sim_table_ptrs[m] = s;
                s += pq.ksub;
            }
        } else if (use_precomputed_table == 2) {
            dis0 = coarse_dis;

            const MultiIndexQuantizer* miq =
                    dynamic_cast<const MultiIndexQuantizer*>(ivfpq.quantizer);
            FAISS_THROW_IF_NOT(miq);
            const ProductQuantizer& cpq = miq->pq;
            int Mf = pq.M / cpq.M;

            long k = key;
            int m0 = 0;
            for (int cm = 0; cm < cpq.M; cm++) {
                int ki = k & ((uint64_t(1) << cpq.nbits) - 1);
                k >>= cpq.nbits;

                const float* pc = ivfpq.precomputed_table.data() +
                        (ki * pq.M + cm * Mf) * pq.ksub;

                for (int m = m0; m < m0 + Mf; m++) {
                    sim_table_ptrs[m] = pc;
                    pc += pq.ksub;
                }
                m0 += Mf;
            }
        } else {
            FAISS_THROW_MSG("need precomputed tables");
        }

        if (polysemous_ht) {
            FAISS_THROW_MSG("not implemented");
            // Not clear that it makes sense to implemente this,
            // because it costs M * ksub, which is what we wanted to
            // avoid with the tables pointers.
        }

        return dis0;
    }
};

// Addendum to only add if not already in the heap
/// k-NN result collector that never stores the same id twice: when an id
/// is already present in the heap its distance is updated in place,
/// otherwise the entry replaces the heap top.
///
/// C is the heap comparator, use_sel tells whether `sel` must be consulted.
template <class C, bool use_sel>
struct KnnSearchResultsPanorama {
    idx_t key;
    const idx_t* ids;
    const IDSelector* sel;

    // heap params
    size_t k;
    float* heap_sim;
    idx_t* heap_ids;

    // number of heap updates
    size_t nup;

    /// true when entry j of the list is rejected by the selector
    inline bool skip_entry(idx_t j) {
        return use_sel && !sel->is_member(ids[j]);
    }

    /// true when dis would enter the heap
    inline bool should_keep(float dis) {
        return C::cmp(heap_sim[0], dis);
    }

    /// current heap top, i.e. the threshold a candidate has to beat
    inline float top() {
        return heap_sim[0];
    }

    // Find ID in heap and return index, or k if not found
    inline size_t find_in_heap(idx_t id) {
        for (size_t i = 0; i < k; i++) {
            if (heap_ids[i] == id) {
                return i;
            }
        }
        return k; // not found
    }

    inline void add(idx_t j, float dis) {
        // Fast path: if distance is too large, bail early
        if (!C::cmp(heap_sim[0], dis)) {
            return;
        }

        idx_t id = ids ? ids[j] : lo_build(key, j);
        size_t pos = find_in_heap(id);

        if (pos < k) {
            // ID already in heap: update its distance and restore the heap
            // order. Use the comparator C of this collector (as
            // heap_replace_top does below) rather than a hard-coded
            // CMax, which is only correct when C is CMax<float, idx_t>.
            heap_sim[pos] = dis;
            heap_fix_single_element<C>(k, heap_sim, heap_ids, pos);
            nup++;
        } else {
            // ID not in heap - normal insertion (we already know distance
            // qualifies)
            heap_replace_top<C>(k, heap_sim, heap_ids, dis, id);
            nup++;
        }
    }
};

// This way of handling the selector is not optimal since all distances
// are computed even if the id would filter it out.
/// k-NN result collector backed by a heap of size k.
/// C is the heap comparator, use_sel tells whether `sel` must be consulted.
template <class C, bool use_sel>
struct KnnSearchResults {
    idx_t key;
    const idx_t* ids;
    const IDSelector* sel;

    // heap params
    size_t k;
    float* heap_sim;
    idx_t* heap_ids;

    // number of heap updates
    size_t nup;

    /// true when entry j of the list is rejected by the selector
    inline bool skip_entry(idx_t j) {
        if (!use_sel) {
            return false;
        }
        return !sel->is_member(ids[j]);
    }

    /// push (entry j, dis) if it beats the current heap top
    inline void add(idx_t j, float dis) {
        if (!C::cmp(heap_sim[0], dis)) {
            return;
        }
        // without an id table, build the id from (list number, offset)
        const idx_t id = ids ? ids[j] : lo_build(key, j);
        heap_replace_top<C>(k, heap_sim, heap_ids, dis, id);
        nup++;
    }
};

/// Range-search result collector: keeps every entry whose distance passes
/// the radius test of comparator C.
template <class C, bool use_sel>
struct RangeSearchResults {
    idx_t key;
    const idx_t* ids;
    const IDSelector* sel;

    // wrapped result structure
    float radius;
    RangeQueryResult& rres;

    /// true when entry j of the list is rejected by the selector
    inline bool skip_entry(idx_t j) {
        if (!use_sel) {
            return false;
        }
        return !sel->is_member(ids[j]);
    }

    /// record (entry j, dis) if it is within the radius
    inline void add(idx_t j, float dis) {
        if (!C::cmp(radius, dis)) {
            return;
        }
        // without an id table, build the id from (list number, offset)
        const idx_t id = ids ? ids[j] : lo_build(key, j);
        rres.add(dis, id);
    }
};

/*****************************************************
 * Scanning the codes.
 * The scanning functions call their favorite precompute_*
 * function to precompute the tables they need.
 *****************************************************/
template <typename IDType, MetricType METRIC_TYPE, class PQDecoder>
struct IVFPQScannerT : QueryTables {
    const uint8_t* list_codes;
    const IDType* list_ids;
    size_t list_size;

    /// Forwards the index and the search parameters to the QueryTables base.
    /// The compile-time METRIC_TYPE is expected to equal the runtime
    /// metric_type; this is only verified through assert(), so the check
    /// disappears when NDEBUG is defined.
    IVFPQScannerT(const IndexIVFPQ& ivfpq, const IVFSearchParameters* params)
            : QueryTables(ivfpq, params) {
        assert(METRIC_TYPE == metric_type);
    }

    /// Per-list constant term. It is assigned by init_list() and
    /// init_list_panorama() and added to the table-based distance of each
    /// code in the table/pointer scan functions.
    float dis0;

    /// Select the inverted list to scan and refresh dis0 for it.
    /// @param list_no    stored in key
    /// @param coarse_dis stored in coarse_dis
    /// @param mode       2: dis0 comes from precompute_list_tables();
    ///                   1: dis0 comes from precompute_list_table_pointers();
    ///                   any other value leaves dis0 untouched.
    void init_list(idx_t list_no, float coarse_dis, int mode) {
        this->key = list_no;
        this->coarse_dis = coarse_dis;

        switch (mode) {
            case 2:
                dis0 = precompute_list_tables();
                break;
            case 1:
                dis0 = precompute_list_table_pointers();
                break;
            default:
                break;
        }
    }

    /// Panorama variant of init_list() with a caller-owned dis0 cache.
    /// @param list_no    stored in key
    /// @param coarse_dis stored in coarse_dis
    /// @param mode       2: use the cached value in *dis0_ptr (see update);
    ///                   1: dis0 comes from precompute_list_table_pointers();
    ///                   any other value leaves dis0 untouched.
    /// @param sim_table  handed to precompute_list_tables_panorama()
    /// @param dis0_ptr   cache slot read (and, when update is true, first
    ///                   rewritten) in mode 2
    /// @param update     mode 2 only: recompute the tables and the cached
    ///                   value before reading it
    void init_list_panorama(
            idx_t list_no,
            float coarse_dis,
            int mode,
            float* sim_table,
            float* dis0_ptr,
            bool update) {
        this->key = list_no;
        this->coarse_dis = coarse_dis;

        switch (mode) {
            case 2:
                if (update) {
                    // refresh the cached value before it is consumed
                    *dis0_ptr = precompute_list_tables_panorama(sim_table);
                }
                dis0 = *dis0_ptr;
                break;
            case 1:
                dis0 = precompute_list_table_pointers();
                break;
            default:
                break;
        }
    }

    /*****************************************************
     * Scanning the codes: simple PQ scan.
     *****************************************************/

    // This is the baseline version of scan_list_with_table().
    // It demonstrates what this function actually does.
    //
    // /// version of the scan where we use precomputed tables.
    // template <class SearchResultType>
    // void scan_list_with_table(
    //         size_t ncode,
    //         const uint8_t* codes,
    //         SearchResultType& res) const {
    //
    //     for (size_t j = 0; j < ncode; j++, codes += pq.code_size) {
    //         if (res.skip_entry(j)) {
    //             continue;
    //         }
    //         float dis = dis0 + distance_single_code<PQDecoder>(
    //             pq, sim_table, codes);
    //         res.add(j, dis);
    //     }
    // }

    // This is the modified version of scan_list_with_table().
    // Manually unrolling the loop that uses distance_single_code() was
    //    observed to speed up the computation.

    /// version of the scan where we use precomputed tables.
    template <class SearchResultType>
    void scan_list_with_table(
            size_t ncode,
            const uint8_t* codes,
            SearchResultType& res) const {
        int counter = 0;

        size_t saved_j[4] = {0, 0, 0, 0};
        for (size_t j = 0; j < ncode; j++) {
            if (res.skip_entry(j)) {
                continue;
            }

            saved_j[0] = (counter == 0) ? j : saved_j[0];
            saved_j[1] = (counter == 1) ? j : saved_j[1];
            saved_j[2] = (counter == 2) ? j : saved_j[2];
            saved_j[3] = (counter == 3) ? j : saved_j[3];

            counter += 1;
            if (counter == 4) {
                float distance_0 = 0;
                float distance_1 = 0;
                float distance_2 = 0;
                float distance_3 = 0;
                distance_four_codes<PQDecoder>(
                        pq.M,
                        pq.nbits,
                        sim_table,
                        codes + saved_j[0] * pq.code_size,
                        codes + saved_j[1] * pq.code_size,
                        codes + saved_j[2] * pq.code_size,
                        codes + saved_j[3] * pq.code_size,
                        distance_0,
                        distance_1,
                        distance_2,
                        distance_3);

                res.add(saved_j[0], dis0 + distance_0);
                res.add(saved_j[1], dis0 + distance_1);
                res.add(saved_j[2], dis0 + distance_2);
                res.add(saved_j[3], dis0 + distance_3);
                counter = 0;
            }
        }

        if (counter >= 1) {
            float dis = dis0 +
                    distance_single_code<PQDecoder>(
                                pq.M,
                                pq.nbits,
                                sim_table,
                                codes + saved_j[0] * pq.code_size);
            res.add(saved_j[0], dis);
        }
        if (counter >= 2) {
            float dis = dis0 +
                    distance_single_code<PQDecoder>(
                                pq.M,
                                pq.nbits,
                                sim_table,
                                codes + saved_j[1] * pq.code_size);
            res.add(saved_j[1], dis);
        }
        if (counter >= 3) {
            float dis = dis0 +
                    distance_single_code<PQDecoder>(
                                pq.M,
                                pq.nbits,
                                sim_table,
                                codes + saved_j[2] * pq.code_size);
            res.add(saved_j[2], dis);
        }
    }
    
    /// Core of the Panorama table scan. One call handles up to four
    /// consecutive PQ code columns (codes_offset[m] .. codes_offset[m + m_rem
    /// - 1]) of one inverted list and adds their look-up table entries to
    /// exact_distances. The implementation uses AVX-512 intrinsics only;
    /// there is no scalar fallback in this function.
    ///
    /// Template parameters:
    ///  - ScanMode selects how candidates are fed to the distance kernel:
    ///      Dense:  bitset is read 64 bytes at a time and the codes whose
    ///              bitset byte is non-zero are packed into batch_offsets[i].
    ///      Sparse: codes are fetched through indices[0 .. *num_active) and
    ///              packed into batch_offsets[i].
    ///      Full:   the code columns are read in place, all list_size entries
    ///              in a single pass; the batch buffers are not used.
    ///  - execute_pruning: when true, each candidate is additionally tested
    ///    with bound = 2 * query_cum_norm * cum_sums[index]:
    ///      * dis0 + exact + bound is offered to res when it is below
    ///        res.top();
    ///      * the candidate is kept only if dis0 + exact - epsilon * bound is
    ///        below res.top() (evaluated after the insertions above).
    ///    Kept candidates are compacted to the front of indices and
    ///    exact_distances, dropped ones get bitset[index] = 0, and
    ///    *num_active is overwritten with the number of kept candidates.
    ///    When false, distances are only accumulated in place.
    ///  - m_rem: number of code columns processed by this call (1 to 4).
    ///
    /// Column i reads its table at sim_table_ptr + 256 * i, i.e. the code
    /// assumes 256 table entries per sub-quantizer (presumably 8-bit PQ
    /// codes - confirm against the caller).
    ///
    /// m_level_width and batch_storage are not read here, and codes is only
    /// copied into a local that is never used afterwards.
    ///
    /// NOTE(review): in Dense mode nothing is consumed in an outer iteration
    /// when batch_size / 4 < 64 while at least 64 entries remain, so the
    /// outer loop would not terminate; callers apparently must pass
    /// batch_size >= 256 - confirm.
    template <class SearchResultType, faiss::PanoramaScanMode ScanMode, bool execute_pruning, uint8_t m_rem>
    void scan_list_with_table_panorama_core(
            // Important metadata.
            size_t list_size,
            size_t m_level_width,
            float epsilon,
            // Input codes.
            const uint8_t* codes,
            uint8_t* const* codes_offset,
            // Output data.
            uint8_t* bitset,
            SearchResultType& res,
            // Cumsums.
            const float* cum_sums,
            float query_cum_norm,
            // Batch allocations (zeroed out).
            uint8_t* batch_storage,
            uint8_t** batch_offsets,
            size_t batch_size, // Unit is bytes.
            // Indices
            size_t* num_active,
            uint32_t* indices,
            float* exact_distances,
            float* sim_table_ptr,
            size_t m) const {
        // We reset here for a new `M` level.
        // Notice how it does not get reset for the last level.
        uint32_t num_processed = 0;
        const uint8_t* bitset_ptr = bitset;
        size_t rem_list_size = list_size;
        // Sparse mode only: next position to read in indices.
        size_t remainder_idx = 0;

        const uint8_t* codes_ptr = codes;
        __m512i zero = _mm512_setzero_si512();

        // Base pointers of the code columns handled by this call; columns
        // beyond m_rem stay null and are never dereferenced.
        uint8_t* codes_ptr_0 = codes_offset[m + 0];
        uint8_t* codes_ptr_1 = nullptr;
        if constexpr (m_rem > 1) {
            codes_ptr_1 = codes_offset[m + 1];
        }
        uint8_t* codes_ptr_2 = nullptr;
        if constexpr (m_rem > 2) {
            codes_ptr_2 = codes_offset[m + 2];
        }
        uint8_t* codes_ptr_3 = nullptr;
        if constexpr (m_rem > 3) {
            codes_ptr_3 = codes_offset[m + 3];
        }

        // Write cursors for the compacted survivors (pruning only). They
        // never run ahead of the read position.
        uint32_t* indices_dest_ptr = indices;
        float* exact_distances_dest_ptr = exact_distances;

        __m512 query_cum_norm_broadcast = _mm512_setzero_ps();
        __m512 dis0_broadcast = _mm512_setzero_ps();
        __m512 epsilon_broadcast = _mm512_set1_ps(epsilon);

        size_t new_num_active = 0;

        if constexpr (execute_pruning) {
            query_cum_norm_broadcast = _mm512_set1_ps(query_cum_norm);
            dis0_broadcast = _mm512_set1_ps(dis0);
            // Pre-multiply by 2 so that a single multiplication with
            // cum_sums yields 2 * query_cum_norm * cum_sums[index].
                query_cum_norm_broadcast =
                    _mm512_mul_ps(query_cum_norm_broadcast, _mm512_set1_ps(2.0f));
        }

        while (rem_list_size > 0) {
            // The following for loop fully populates the batch as much as
            // it can in batches of 64 elements (starting the batch from
            // scratch). We piggy-back the following values so that they get
            // reset for every `M` level.
            // Each column buffer may receive at most batch_size / 4 codes.
            size_t rem_batch_size = batch_size / 4;

            uint8_t* batch_storage_ptr_0 = batch_offsets[0];
            uint8_t* batch_storage_ptr_1 = batch_offsets[1];
            uint8_t* batch_storage_ptr_2 = batch_offsets[2];
            uint8_t* batch_storage_ptr_3 = batch_offsets[3];

            if constexpr (ScanMode == faiss::PanoramaScanMode::Dense) {
                while (rem_list_size >= 64 && rem_batch_size >= 64) {
                    // One mask bit per entry: set when its bitset byte is
                    // non-zero (entry still active).
                    __m512i bytes =
                            _mm512_loadu_si512((__m512i*)bitset_ptr);
                    __mmask64 mask = _mm512_cmpneq_epi8_mask(bytes, zero);

                    size_t num_active_real = _mm_popcnt_u64(mask);

                    __m512i codes_batch_0 = _mm512_loadu_si512(codes_ptr_0);
                    __m512i codes_batch_1;
                    if constexpr (m_rem > 1) {
                        codes_batch_1 = _mm512_loadu_si512(codes_ptr_1);
                    }
                    __m512i codes_batch_2;
                    if constexpr (m_rem > 2) {
                        codes_batch_2 = _mm512_loadu_si512(codes_ptr_2);
                    }
                    __m512i codes_batch_3;
                    if constexpr (m_rem > 3) {
                        codes_batch_3 = _mm512_loadu_si512(codes_ptr_3);
                    }

                    // Pack the active codes to the front. A full 64-byte
                    // vector is stored, but the cursor only advances by the
                    // number of active codes.
                    __m512i compressed_0 =
                            _mm512_maskz_compress_epi8(mask, codes_batch_0);
                    _mm512_storeu_si512(batch_storage_ptr_0, compressed_0);
                    if constexpr (m_rem > 1) {
                        __m512i compressed_1 =
                                _mm512_maskz_compress_epi8(mask, codes_batch_1);
                        _mm512_storeu_si512(batch_storage_ptr_1, compressed_1);
                    }
                    if constexpr (m_rem > 2) {
                        __m512i compressed_2 =
                                _mm512_maskz_compress_epi8(mask, codes_batch_2);
                        _mm512_storeu_si512(batch_storage_ptr_2, compressed_2);
                    }
                    if constexpr (m_rem > 3) {
                        __m512i compressed_3 =
                                _mm512_maskz_compress_epi8(mask, codes_batch_3);
                        _mm512_storeu_si512(batch_storage_ptr_3, compressed_3);
                    }

                    batch_storage_ptr_0 += num_active_real;
                    batch_storage_ptr_1 += num_active_real;
                    batch_storage_ptr_2 += num_active_real;
                    batch_storage_ptr_3 += num_active_real;

                    codes_ptr_0 += 64;
                    codes_ptr_1 += 64;
                    codes_ptr_2 += 64;
                    codes_ptr_3 += 64;

                    bitset_ptr += 64;

                    rem_list_size -= 64;
                    rem_batch_size -= num_active_real;
                }

                // Check if < 64 elements remain and put them into current
                // batch. This is guaranteed to happen at most once, for the
                // very last batch.
                // bench this if statement
                if (rem_list_size < 64 && rem_list_size < rem_batch_size) {
                    // Here we need `curr_batch_size` as we do *not* start
                    // from the beginning of the batch anymore.
                    for (size_t i = 0; i < rem_list_size; i++) {
                        if (bitset_ptr[i]) {
                            *batch_storage_ptr_0 = codes_ptr_0[i];
                            if constexpr (m_rem > 1) {
                                *batch_storage_ptr_1 = codes_ptr_1[i];
                            }
                            if constexpr (m_rem > 2) {
                                *batch_storage_ptr_2 = codes_ptr_2[i];
                            }
                            if constexpr (m_rem > 3) {
                                *batch_storage_ptr_3 = codes_ptr_3[i];
                            }

                            batch_storage_ptr_0++;
                            batch_storage_ptr_1++;
                            batch_storage_ptr_2++;
                            batch_storage_ptr_3++;
                        }
                    }
                    rem_list_size = 0;
                }
            } else if constexpr (
                    ScanMode == faiss::PanoramaScanMode::Sparse) {
                // Fetch the codes of the active entries one by one through
                // the index list until the batch is full or the list ends.
                while (remainder_idx < *num_active && rem_batch_size > 0) {
                    *batch_storage_ptr_0 =
                            codes_ptr_0[indices[remainder_idx]];
                    if constexpr (m_rem > 1) {
                        *batch_storage_ptr_1 =
                                codes_ptr_1[indices[remainder_idx]];
                    }
                    if constexpr (m_rem > 2) {
                        *batch_storage_ptr_2 =
                                codes_ptr_2[indices[remainder_idx]];
                    }
                    if constexpr (m_rem > 3) {
                        *batch_storage_ptr_3 =
                                codes_ptr_3[indices[remainder_idx]];
                    }

                    batch_storage_ptr_0++;
                    batch_storage_ptr_1++;
                    batch_storage_ptr_2++;
                    batch_storage_ptr_3++;

                    rem_batch_size--;
                    remainder_idx++;
                }

                // All active indices consumed: make this the last outer
                // iteration (rem_list_size is not otherwise decremented in
                // Sparse mode).
                if (remainder_idx == *num_active) { // tricky stopping condition
                    rem_list_size = 0;
                }
            }

            // Compute distances for current batch (size is number of points
            // in batch)
            size_t curr_batch_size = batch_storage_ptr_0 - batch_offsets[0];

            // Rewind the batch cursors: from here on they are read cursors.
            float* exact_distances_ptr = exact_distances + num_processed;
            batch_storage_ptr_0 = batch_offsets[0];
            batch_storage_ptr_1 = batch_offsets[1];
            batch_storage_ptr_2 = batch_offsets[2];
            batch_storage_ptr_3 = batch_offsets[3];

            if constexpr (ScanMode == faiss::PanoramaScanMode::Full) {
                // Full mode: read the code columns in place and cover the
                // whole list in this single outer iteration.
                exact_distances_ptr = exact_distances;

                batch_storage_ptr_0 = codes_ptr_0;
                batch_storage_ptr_1 = codes_ptr_1;
                batch_storage_ptr_2 = codes_ptr_2;
                batch_storage_ptr_3 = codes_ptr_3;

                curr_batch_size = list_size;

                rem_list_size = 0;
            }

            size_t batch_offset = 0;
            for (; batch_offset + 15 < curr_batch_size;
                    batch_offset += 16) {
                // 1: Load 16 uint8 codes per column
                __m128i codes_batch_0 =
                        _mm_loadu_si128((__m128i*)(batch_storage_ptr_0));
                __m128i codes_batch_1;
                if constexpr (m_rem > 1) {
                    codes_batch_1 =
                            _mm_loadu_si128((__m128i*)(batch_storage_ptr_1));
                }
                __m128i codes_batch_2;
                if constexpr (m_rem > 2) {
                    codes_batch_2 =
                            _mm_loadu_si128((__m128i*)(batch_storage_ptr_2));
                }
                __m128i codes_batch_3;
                if constexpr (m_rem > 3) {
                    codes_batch_3 =
                            _mm_loadu_si128((__m128i*)(batch_storage_ptr_3));
                }

                // 2: Zero-extend to int32 (gather offsets)
                __m512i codes_batch_int_0 =
                        _mm512_cvtepu8_epi32(codes_batch_0);
                __m512i codes_batch_int_1;
                if constexpr (m_rem > 1) {
                    codes_batch_int_1 =
                            _mm512_cvtepu8_epi32(codes_batch_1);
                }
                __m512i codes_batch_int_2;
                if constexpr (m_rem > 2) {
                    codes_batch_int_2 =
                            _mm512_cvtepu8_epi32(codes_batch_2);
                }
                __m512i codes_batch_int_3;
                if constexpr (m_rem > 3) {
                    codes_batch_int_3 =
                            _mm512_cvtepu8_epi32(codes_batch_3);
                }

                // 3: Gather 16 table entries per column; column i uses the
                // table slice starting at sim_table_ptr + 256 * i
                __m512 m_dist_0 = _mm512_i32gather_ps(
                        codes_batch_int_0, sim_table_ptr, sizeof(float));
                __m512 m_dist_1;
                if constexpr (m_rem > 1) {
                    m_dist_1 = _mm512_i32gather_ps(
                            codes_batch_int_1,
                            sim_table_ptr + 256,
                            sizeof(float));
                }
                __m512 m_dist_2;
                if constexpr (m_rem > 2) {
                    m_dist_2 = _mm512_i32gather_ps(
                            codes_batch_int_2,
                            sim_table_ptr + 512,
                            sizeof(float));
                }
                __m512 m_dist_3;
                if constexpr (m_rem > 3) {
                    m_dist_3 = _mm512_i32gather_ps(
                            codes_batch_int_3,
                            sim_table_ptr + 768,
                            sizeof(float));
                }

                // 4: Vertically add to exact distances
                __m512 exact_distances_batch =
                        _mm512_loadu_ps(exact_distances_ptr);

                exact_distances_batch =
                        _mm512_add_ps(exact_distances_batch, m_dist_0);
                if constexpr (m_rem > 1) {
                    exact_distances_batch =
                            _mm512_add_ps(exact_distances_batch, m_dist_1);
                }
                if constexpr (m_rem > 2) {
                    exact_distances_batch =
                            _mm512_add_ps(exact_distances_batch, m_dist_2);
                }
                if constexpr (m_rem > 3) {
                    exact_distances_batch =
                            _mm512_add_ps(exact_distances_batch, m_dist_3);
                }

                if constexpr (execute_pruning) {
                    // Load active indices into register
                    __m512i indices_batch = _mm512_loadu_si512(
                            indices + num_processed + batch_offset);

                    __m512 cum_sums_batch = _mm512_i32gather_ps(
                            indices_batch, cum_sums, sizeof(float));

                    // 2 * query_cum_norm * cum_sums[index]
                    __m512 cauchy_schwarz_bound = _mm512_mul_ps(
                            query_cum_norm_broadcast, cum_sums_batch);

                    __m512 exact_distances_batch_dis0 = _mm512_add_ps(
                            exact_distances_batch, dis0_broadcast);

                    __m512 heap_max = _mm512_set1_ps(res.top());

                    __m512 upper_bound = _mm512_add_ps(
                            exact_distances_batch_dis0, cauchy_schwarz_bound);

                    // Lanes whose upper bound is below the current
                    // res.top() are offered to the result handler.
                    __mmask16 mask_heap_insert = _mm512_cmp_ps_mask(
                            upper_bound, heap_max, _CMP_LT_OQ);

                    __m512i indices_batch_compressed =
                            _mm512_maskz_compress_epi32(
                                    mask_heap_insert, indices_batch);

                    __m512 upper_bound_compressed = _mm512_maskz_compress_ps(
                            mask_heap_insert, upper_bound);

                    // Store compressed upper bounds
                    alignas(64) float upper_bound_array[16];

                    _mm512_store_ps(upper_bound_array, upper_bound_compressed);

                    // Store compressed indices
                    alignas(64) uint32_t indices_array[16];

                    _mm512_store_epi32(indices_array, indices_batch_compressed);

                    size_t mask_heap_insert_popcount =
                            _mm_popcnt_u32(mask_heap_insert);

                    for (size_t j = 0; j < mask_heap_insert_popcount; j++) {
                        res.add(indices_array[j], upper_bound_array[j]);
                    }

                    // Calculate lower bounds
                    // (res.top() is re-read because the insertions above
                    // may have changed it)
                    heap_max = _mm512_set1_ps(res.top());

                    __m512 cauchy_schwarz_bound_epsilon = _mm512_mul_ps(cauchy_schwarz_bound, epsilon_broadcast);

                    __m512 lower_bound = _mm512_sub_ps(
                            exact_distances_batch_dis0, cauchy_schwarz_bound_epsilon);

                    __mmask16 mask_should_keep = _mm512_cmp_ps_mask(
                            lower_bound, heap_max, _CMP_LT_OQ);

                    size_t mask_should_keep_popcount =
                            _mm_popcnt_u32(mask_should_keep);

                    // Store compressed indices & exact distances.
                    // Full 16-lane stores; only the kept lanes are
                    // meaningful and the cursors advance by the kept count.
                    indices_batch_compressed = _mm512_maskz_compress_epi32(
                            mask_should_keep, indices_batch);
                    _mm512_storeu_si512(
                            indices_dest_ptr, indices_batch_compressed);
                    __m512 exact_distances_batch_compressed =
                            _mm512_maskz_compress_ps(
                                    mask_should_keep, exact_distances_batch);
                    _mm512_storeu_ps(
                            exact_distances_dest_ptr,
                            exact_distances_batch_compressed);

                    // Now we reverse compress it for the bitset
                    // (collect the dropped indices and clear their bytes)
                    indices_batch_compressed = _mm512_maskz_compress_epi32(
                            ~mask_should_keep, indices_batch);
                    _mm512_store_epi32(indices_array, indices_batch_compressed);

                    const size_t prune_count = 16 - mask_should_keep_popcount;
                    for (size_t j = 0; j < prune_count; ++j) {
                        bitset[indices_array[j]] = 0;
                    }

                    new_num_active += mask_should_keep_popcount;
                    indices_dest_ptr += mask_should_keep_popcount;
                    exact_distances_dest_ptr += mask_should_keep_popcount;
                } else {
                    // No pruning: write the accumulated distances back in
                    // place.
                    _mm512_storeu_ps(
                        exact_distances_ptr, exact_distances_batch);
                }

                batch_storage_ptr_0 += 16;
                batch_storage_ptr_1 += 16;
                batch_storage_ptr_2 += 16;
                batch_storage_ptr_3 += 16;

                exact_distances_ptr += 16;
            }

            // Calculate remainder distances
            // (scalar path for the last curr_batch_size % 16 entries)
            // bench this loop
            for (; batch_offset < curr_batch_size; batch_offset++) {
                if constexpr (ScanMode != faiss::PanoramaScanMode::Full) {
                    // Dense / Sparse: codes come from the packed batch
                    // buffers.
                    if constexpr (m_rem == 1) {
                        exact_distances[num_processed + batch_offset] +=
                                sim_table_ptr[batch_offsets[0][batch_offset]];
                    } else if constexpr (m_rem == 2) {
                        exact_distances[num_processed + batch_offset] +=
                                sim_table_ptr[batch_offsets[0][batch_offset]] +
                                (sim_table_ptr +
                                    256)[batch_offsets[1][batch_offset]];
                    } else if constexpr (m_rem == 3) {
                        exact_distances[num_processed + batch_offset] +=
                                sim_table_ptr[batch_offsets[0][batch_offset]] +
                                (sim_table_ptr +
                                    256)[batch_offsets[1][batch_offset]] +
                                (sim_table_ptr +
                                    512)[batch_offsets[2][batch_offset]];
                    } else if constexpr (m_rem == 4) {
                        exact_distances[num_processed + batch_offset] +=
                                sim_table_ptr[batch_offsets[0][batch_offset]] +
                                (sim_table_ptr +
                                    256)[batch_offsets[1][batch_offset]] +
                                (sim_table_ptr +
                                    512)[batch_offsets[2][batch_offset]] +
                                (sim_table_ptr +
                                    768)[batch_offsets[3][batch_offset]];
                    }
                } else {
                    // Full: codes are read directly from the code columns.
                    if constexpr (m_rem == 1) {
                        exact_distances[num_processed + batch_offset] +=
                                sim_table_ptr[codes_ptr_0[batch_offset]];
                    } else if constexpr (m_rem == 2) {
                        exact_distances[num_processed + batch_offset] +=
                                sim_table_ptr[codes_ptr_0[batch_offset]] +
                                (sim_table_ptr +
                                    256)[codes_ptr_1[batch_offset]];
                    } else if constexpr (m_rem == 3) {
                        exact_distances[num_processed + batch_offset] +=
                                sim_table_ptr[codes_ptr_0[batch_offset]] +
                                (sim_table_ptr +
                                    256)[codes_ptr_1[batch_offset]] +
                                (sim_table_ptr +
                                    512)[codes_ptr_2[batch_offset]];
                    } else if constexpr (m_rem == 4) {
                        exact_distances[num_processed + batch_offset] +=
                                sim_table_ptr[codes_ptr_0[batch_offset]] +
                                (sim_table_ptr +
                                    256)[codes_ptr_1[batch_offset]] +
                                (sim_table_ptr +
                                    512)[codes_ptr_2[batch_offset]] +
                                (sim_table_ptr +
                                    768)[codes_ptr_3[batch_offset]];
                    }
                }

                uint32_t index = indices[num_processed + batch_offset];

                // Load exact distances
                float exact_distance =
                        exact_distances[num_processed + batch_offset];

                if constexpr (execute_pruning) {
                    // Calculate bounds
                    float cauchy_schwarz_bound =
                            2.0f * query_cum_norm * cum_sums[index];
                    float lower_bound =
                            dis0 + exact_distance - cauchy_schwarz_bound * epsilon;
                    float upper_bound =
                            dis0 + exact_distance + cauchy_schwarz_bound;

                    // Compare upper bound with heap max
                    if (upper_bound < res.top()) {
                        res.add(index, upper_bound);
                    }

                    // Pruning
                    bool should_keep = lower_bound < res.top();

                    // Compress indices
                    // (written unconditionally; the slot is only retained
                    // when the counters below advance)
                    indices[new_num_active] = index;
                    exact_distances[new_num_active] = exact_distance;
                    if (!should_keep) {
                        bitset[index] = 0;
                    }

                    // Update num_active
                    new_num_active += should_keep;
                    indices_dest_ptr += should_keep;
                    exact_distances_dest_ptr += should_keep;
                }
            }

            num_processed += curr_batch_size;
        }

        // Publish the survivor count only when pruning ran.
        if constexpr (execute_pruning) {
            *num_active = new_num_active;
        }
    }

    /// Panorama scan of one level made of m_level_width PQ sub-quantizers.
    ///
    /// The sub-quantizers are handed to scan_list_with_table_panorama_core()
    /// in groups of four, starting at the beginning of sim_table. Every
    /// group except the last one only accumulates distances
    /// (execute_pruning = false); the last group, which holds the remaining
    /// 1 to 4 sub-quantizers, also runs the pruning step
    /// (execute_pruning = true).
    ///
    /// All arguments are forwarded unchanged to the core routine; see its
    /// parameter list for their roles. A level with m_level_width == 0 is a
    /// no-op.
    template <class SearchResultType, faiss::PanoramaScanMode ScanMode>
    inline void scan_list_with_table_panorama(
            // Important metadata.
            size_t list_size,
            size_t m_level_width,
            float epsilon,
            // Input codes.
            const uint8_t* codes,
            uint8_t* const* codes_offset,
            // Output data.
            uint8_t* bitset,
            SearchResultType& res,
            // Cumsums.
            const float* cum_sums,
            float query_cum_norm,
            // Batch allocations (zeroed out).
            uint8_t* batch_storage,
            uint8_t** batch_offsets,
            size_t batch_size, // Unit is bytes.
            // Indices
            size_t* num_active,
            uint32_t* indices,
            float* exact_distances) const {
        // Nothing to scan for an empty level. Without this guard the
        // unsigned expression (m_level_width - 1) below wraps around and the
        // accumulation loop would walk far past the end of codes_offset and
        // sim_table.
        if (m_level_width == 0) {
            return;
        }

        // Number of floats spanned by one group of four look-up tables
        // (256 entries each, matching the offsets used by the core routine).
        constexpr size_t group_table_stride = 4 * 256;

        float* sim_table_ptr = sim_table;

        // Index of the first sub-quantizer of the last group.
        const size_t last_m = ((m_level_width - 1) / 4) * 4;

        // Accumulate-only pass over every full group preceding the last one.
        size_t m = 0;
        while (m < last_m) {
            scan_list_with_table_panorama_core<
                    SearchResultType,
                    ScanMode,
                    false,
                    4>(
                    list_size,
                    m_level_width,
                    epsilon,
                    codes,
                    codes_offset,
                    bitset,
                    res,
                    cum_sums,
                    query_cum_norm,
                    batch_storage,
                    batch_offsets,
                    batch_size,
                    num_active,
                    indices,
                    exact_distances,
                    sim_table_ptr,
                    m);
            sim_table_ptr += group_table_stride;
            m += 4;
        }

        // Last group: its width is a compile-time parameter of the core
        // routine, so dispatch on the remainder (0 means a full group of 4).
        switch (m_level_width % 4) {
            case 1:
                scan_list_with_table_panorama_core<
                        SearchResultType,
                        ScanMode,
                        true,
                        1>(
                        list_size,
                        m_level_width,
                        epsilon,
                        codes,
                        codes_offset,
                        bitset,
                        res,
                        cum_sums,
                        query_cum_norm,
                        batch_storage,
                        batch_offsets,
                        batch_size,
                        num_active,
                        indices,
                        exact_distances,
                        sim_table_ptr,
                        m);
                break;
            case 2:
                scan_list_with_table_panorama_core<
                        SearchResultType,
                        ScanMode,
                        true,
                        2>(
                        list_size,
                        m_level_width,
                        epsilon,
                        codes,
                        codes_offset,
                        bitset,
                        res,
                        cum_sums,
                        query_cum_norm,
                        batch_storage,
                        batch_offsets,
                        batch_size,
                        num_active,
                        indices,
                        exact_distances,
                        sim_table_ptr,
                        m);
                break;
            case 3:
                scan_list_with_table_panorama_core<
                        SearchResultType,
                        ScanMode,
                        true,
                        3>(
                        list_size,
                        m_level_width,
                        epsilon,
                        codes,
                        codes_offset,
                        bitset,
                        res,
                        cum_sums,
                        query_cum_norm,
                        batch_storage,
                        batch_offsets,
                        batch_size,
                        num_active,
                        indices,
                        exact_distances,
                        sim_table_ptr,
                        m);
                break;
            case 0:
                scan_list_with_table_panorama_core<
                        SearchResultType,
                        ScanMode,
                        true,
                        4>(
                        list_size,
                        m_level_width,
                        epsilon,
                        codes,
                        codes_offset,
                        bitset,
                        res,
                        cum_sums,
                        query_cum_norm,
                        batch_storage,
                        batch_offsets,
                        batch_size,
                        num_active,
                        indices,
                        exact_distances,
                        sim_table_ptr,
                        m);
                break;
            default:
                break;
        }
    }

    /// Scan for the case where the tables are not precomputed per list but
    /// pointers to the relevant X_c|x_r tables are available
    /// (sim_table_ptrs). For every accepted entry the distance starts at
    /// dis0 and receives, per sub-quantizer m and decoded index ci,
    /// sim_table_ptrs[m][ci] - 2 * sim_table_2[m * pq.ksub + ci].
    /// @param ncode number of codes in the list
    /// @param codes start of the codes, pq.code_size bytes per entry
    /// @param res   result handler (skip_entry / add)
    template <class SearchResultType>
    void scan_list_with_pointer(
            size_t ncode,
            const uint8_t* codes,
            SearchResultType& res) const {
        const size_t stride = pq.code_size;

        for (size_t j = 0; j < ncode; j++) {
            if (res.skip_entry(j)) {
                continue;
            }

            PQDecoder decoder(codes + j * stride, pq.nbits);
            float dis = dis0;

            for (size_t m = 0; m < pq.M; m++) {
                const int ci = decoder.decode();
                // slice of sim_table_2 that belongs to sub-quantizer m
                const float* tab = sim_table_2 + m * pq.ksub;
                dis += sim_table_ptrs[m][ci] - 2 * tab[ci];
            }

            res.add(j, dis);
        }
    }

    /// Scan without any precomputed table: each accepted code is decoded
    /// with pq.decode() and compared to the query directly.
    ///  - inner product: distance = base term + <decoded, qi>, where the
    ///    base term is <reconstructed centroid, qi> when by_residual is set
    ///    and 0 otherwise;
    ///  - otherwise: distance = fvec_L2sqr(decoded, dvec), where dvec is the
    ///    query residual when by_residual is set and qi otherwise.
    /// The base term is local to this function; the dis0 member is not used.
    /// @param ncode number of codes in the list
    /// @param codes start of the codes, pq.code_size bytes per entry
    /// @param res   result handler (skip_entry / add)
    template <class SearchResultType>
    void scan_on_the_fly_dist(
            size_t ncode,
            const uint8_t* codes,
            SearchResultType& res) const {
        constexpr bool is_ip = (METRIC_TYPE == METRIC_INNER_PRODUCT);

        float base_dis = 0;
        const float* dvec = qi;
        if (by_residual) {
            if (is_ip) {
                ivfpq.quantizer->reconstruct(key, residual_vec);
                base_dis = fvec_inner_product(residual_vec, qi, d);
            } else {
                ivfpq.quantizer->compute_residual(qi, residual_vec, key);
            }
            dvec = residual_vec;
        }

        const size_t stride = pq.code_size;
        for (size_t j = 0; j < ncode; j++) {
            if (res.skip_entry(j)) {
                continue;
            }
            pq.decode(codes + j * stride, decoded_vec);

            const float dis = is_ip
                    ? base_dis + fvec_inner_product(decoded_vec, qi, d)
                    : fvec_L2sqr(decoded_vec, dvec, d);
            res.add(j, dis);
        }
    }

    /*****************************************************
     * Scanning codes with polysemous filtering
     *****************************************************/

    // This is the baseline version of scan_list_polysemous_hc().
    // It demonstrates what this function actually does.

    //     template <class HammingComputer, class SearchResultType>
    //     void scan_list_polysemous_hc(
    //             size_t ncode,
    //             const uint8_t* codes,
    //             SearchResultType& res) const {
    //         int ht = ivfpq.polysemous_ht;
    //         size_t n_hamming_pass = 0, nup = 0;
    //
    //         int code_size = pq.code_size;
    //
    //         HammingComputer hc(q_code.data(), code_size);
    //
    //         for (size_t j = 0; j < ncode; j++, codes += code_size) {
    //             if (res.skip_entry(j)) {
    //                 continue;
    //             }
    //             const uint8_t* b_code = codes;
    //             int hd = hc.hamming(b_code);
    //             if (hd < ht) {
    //                 n_hamming_pass++;
    //
    //                 float dis =
    //                         dis0 +
    //                         distance_single_code<PQDecoder>(
    //                             pq, sim_table, codes);
    //
    //                 res.add(j, dis);
    //             }
    //         }
    // #pragma omp critical
    //         { indexIVFPQ_stats.n_hamming_pass += n_hamming_pass; }
    //     }

    // This is the modified version of scan_list_with_tables().
    // It was observed that doing manual unrolling of the loop that
    //    utilizes distance_single_code() speeds up the computations.
    //
    // Polysemous scan: a code is only scored with the PQ tables when the
    // Hamming distance between it and the query code (q_code) is below
    // ivfpq.polysemous_ht. Codes are tested four at a time; the indices
    // of the codes that pass are queued in saved_j and scored in groups
    // of four with distance_four_codes(). Whatever is still queued after
    // the unrolled loop is scored one by one, followed by the
    // (ncode % 4) trailing codes.
    //
    // The number of codes that passed the Hamming test is added to
    // indexIVFPQ_stats.n_hamming_pass inside an OpenMP critical section.
    //
    // ncode: number of codes in the list
    // codes: packed codes, pq.code_size bytes per entry
    // res:   result handler; must provide skip_entry() and add()

    template <class HammingComputer, class SearchResultType>
    void scan_list_polysemous_hc(
            size_t ncode,
            const uint8_t* codes,
            SearchResultType& res) const {
        int ht = ivfpq.polysemous_ht;
        size_t n_hamming_pass = 0;

        int code_size = pq.code_size;

        // Queue of indices that passed the Hamming test. At most 3 entries
        // are pending when an iteration starts and at most 4 are added, so
        // slots 0..6 can be written. The flush below nevertheless copies
        // slots 4..7 down, so the array is zero-initialized to make sure
        // no indeterminate value is ever read.
        size_t saved_j[8] = {};
        int counter = 0;

        // Hamming distance assigned to skipped entries:
        // 99999999 is just an arbitrary large number
        const int hd_skipped = 99999999;

        HammingComputer hc(q_code.data(), code_size);

        for (size_t j = 0; j < (ncode / 4) * 4; j += 4) {
            const uint8_t* b_code = codes + j * code_size;

            // Unrolling is a key. Basically, doing multiple popcount
            // operations one after another speeds things up.

            int hd0 = (res.skip_entry(j + 0))
                    ? hd_skipped
                    : hc.hamming(b_code + 0 * code_size);
            int hd1 = (res.skip_entry(j + 1))
                    ? hd_skipped
                    : hc.hamming(b_code + 1 * code_size);
            int hd2 = (res.skip_entry(j + 2))
                    ? hd_skipped
                    : hc.hamming(b_code + 2 * code_size);
            int hd3 = (res.skip_entry(j + 3))
                    ? hd_skipped
                    : hc.hamming(b_code + 3 * code_size);

            // Each index is stored unconditionally; the slot is only kept
            // (counter advances) when the code passed the threshold,
            // otherwise the next store overwrites it.
            saved_j[counter] = j + 0;
            counter = (hd0 < ht) ? (counter + 1) : counter;
            saved_j[counter] = j + 1;
            counter = (hd1 < ht) ? (counter + 1) : counter;
            saved_j[counter] = j + 2;
            counter = (hd2 < ht) ? (counter + 1) : counter;
            saved_j[counter] = j + 3;
            counter = (hd3 < ht) ? (counter + 1) : counter;

            if (counter >= 4) {
                // process four codes at the same time
                n_hamming_pass += 4;

                // NOTE(review): the outputs are seeded with dis0 and dis0
                // is added again in res.add() below. That is only right
                // if distance_four_codes() overwrites its outputs instead
                // of accumulating into them -- confirm.
                float distance_0 = dis0;
                float distance_1 = dis0;
                float distance_2 = dis0;
                float distance_3 = dis0;
                distance_four_codes<PQDecoder>(
                        pq.M,
                        pq.nbits,
                        sim_table,
                        codes + saved_j[0] * pq.code_size,
                        codes + saved_j[1] * pq.code_size,
                        codes + saved_j[2] * pq.code_size,
                        codes + saved_j[3] * pq.code_size,
                        distance_0,
                        distance_1,
                        distance_2,
                        distance_3);

                res.add(saved_j[0], dis0 + distance_0);
                res.add(saved_j[1], dis0 + distance_1);
                res.add(saved_j[2], dis0 + distance_2);
                res.add(saved_j[3], dis0 + distance_3);

                // drop the four processed entries and move the (up to
                // three) remaining ones to the front of the queue
                counter -= 4;
                saved_j[0] = saved_j[4];
                saved_j[1] = saved_j[5];
                saved_j[2] = saved_j[6];
                saved_j[3] = saved_j[7];
            }
        }

        // score the entries still queued after the unrolled loop
        for (int kk = 0; kk < counter; kk++) {
            n_hamming_pass++;

            float dis = dis0 +
                    distance_single_code<PQDecoder>(
                                pq.M,
                                pq.nbits,
                                sim_table,
                                codes + saved_j[kk] * pq.code_size);

            res.add(saved_j[kk], dis);
        }

        // process leftovers
        for (size_t j = (ncode / 4) * 4; j < ncode; j++) {
            if (res.skip_entry(j)) {
                continue;
            }
            const uint8_t* b_code = codes + j * code_size;
            int hd = hc.hamming(b_code);
            if (hd < ht) {
                n_hamming_pass++;

                float dis = dis0 +
                        distance_single_code<PQDecoder>(
                                    pq.M,
                                    pq.nbits,
                                    sim_table,
                                    codes + j * code_size);

                res.add(j, dis);
            }
        }

        // the statistics object is a shared global, hence the critical
        // section
#pragma omp critical
        {
            indexIVFPQ_stats.n_hamming_pass += n_hamming_pass;
        }
    }

    template <class SearchResultType>
    struct Run_scan_list_polysemous_hc {
        using T = void;
        template <class HammingComputer, class... Types>
        void f(const IVFPQScannerT* scanner, Types... args) {
            scanner->scan_list_polysemous_hc<HammingComputer, SearchResultType>(
                    args...);
        }
    };

    template <class SearchResultType>
    void scan_list_polysemous(
            size_t ncode,
            const uint8_t* codes,
            SearchResultType& res) const {
        Run_scan_list_polysemous_hc<SearchResultType> r;
        dispatch_HammingComputer(pq.code_size, r, this, ncode, codes, res);
    }
};

/* Scanner over IVFPQ inverted lists, exposed through the
 * InvertedListScanner interface.
 *
 * We put as many parameters as possible in template. Hopefully the
 * gain in runtime is worth the code bloat.
 *
 * METRIC_TYPE is the metric the scores are computed for.
 *
 * C is the comparator < or >, it is directly related to METRIC_TYPE.
 *
 * PQDecoder is the decoder type used to read the PQ codes.
 *
 * use_sel: store or ignore the IDSelector
 *
 * precompute_mode (a runtime member, not a template argument) is how much
 * we precompute (2 = precompute distance tables,
 * 1 = precompute pointers to distances, 0 = compute distances one by one).
 * Currently only 2 is supported
 */
template <MetricType METRIC_TYPE, class C, class PQDecoder, bool use_sel>
struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>,
                      InvertedListScanner {
    /// how much is precomputed per list (see the struct comment)
    int precompute_mode;
    /// selector handed to the result handlers; not owned
    const IDSelector* sel;

    /// @param ivfpq            index whose lists are scanned
    /// @param store_pairs      when set, result handlers get no id array
    /// @param precompute_mode  0, 1 or 2 (see the struct comment)
    /// @param sel              optional selector passed to result handlers
    IVFPQScanner(
            const IndexIVFPQ& ivfpq,
            bool store_pairs,
            int precompute_mode,
            const IDSelector* sel)
            : IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>(ivfpq, nullptr),
              precompute_mode(precompute_mode),
              sel(sel) {
        this->store_pairs = store_pairs;
        this->keep_max = is_similarity_metric(METRIC_TYPE);
        this->code_size = this->pq.code_size;
    }

    /// Prepare the scanner for a new query vector.
    void set_query(const float* query) override {
        this->init_query(query);
    }

    /// Select the inverted list to scan next.
    void set_list(idx_t list_no, float coarse_dis) override {
        this->list_no = list_no;
        this->init_list(list_no, coarse_dis, precompute_mode);
    }

    /// Panorama flavour of set_list(): records the list number and
    /// forwards everything, plus precompute_mode, to init_list_panorama().
    void set_list_panorama(
            idx_t list_no,
            float coarse_dis,
            float* sim_table,
            float* dis0_ptr,
            bool update) override {
        this->list_no = list_no;
        this->init_list_panorama(
                list_no,
                coarse_dis,
                precompute_mode,
                sim_table,
                dis0_ptr,
                update);
    }

    /// Install an externally provided similarity table and base term.
    /// Only the pointer is stored; the table is not copied.
    void set_sim_table(float* sim_table, float dis0) override {
        this->sim_table = sim_table;
        this->dis0 = dis0;
    }

    /// Score a single code against the current sim_table.
    /// Requires precompute_mode == 2 (checked with assert only).
    /// @return dis0 plus the table-based distance of the code
    float distance_to_code(const uint8_t* code) const override {
        assert(precompute_mode == 2);
        float dis = this->dis0 +
                distance_single_code<PQDecoder>(
                            this->pq.M, this->pq.nbits, this->sim_table, code);
        return dis;
    }

    /// Panorama k-NN scan of one list. Requires precompute_mode == 2.
    ///
    /// The scan mode is chosen from the share of active entries
    /// (*num_active relative to list_size, read once on entry):
    /// Full when every entry is active, Dense when more than 5% are
    /// active, Sparse otherwise. All remaining arguments are forwarded
    /// unchanged to scan_list_with_table_panorama().
    ///
    /// @return number of heap updates reported by the result handler
    size_t scan_codes_panorama(
            size_t list_size,
            size_t m_level_width,
            float epsilon,
            const uint8_t* codes,
            uint8_t* const* codes_offset,
            uint8_t* bitset,
            float* exact_distances,
            const idx_t* ids, // ids in the cluster
            float* heap_sim,
            idx_t* heap_ids,
            size_t k,
            const float* cum_sums,
            float query_cum_norm,
            uint8_t* batch_storage,
            uint8_t** batch_offsets,
            uint32_t* indices,
            size_t* num_active,
            size_t batch_size) const override {
        KnnSearchResultsPanorama<C, use_sel> res = {
                /* key */ this->key,
                /* ids */ this->store_pairs ? nullptr : ids,
                /* sel */ this->sel,
                /* k */ k,
                /* heap_sim */ heap_sim,
                /* heap_ids */ heap_ids,
                /* nup */ 0};

        FAISS_ASSERT(precompute_mode == 2);

        const size_t n_active = *num_active;

        // "Everything active" is tested on the integers. A float ratio
        // rounds to exactly 1.0 for large lists (above 2^24 entries) even
        // when a few entries are inactive, which would wrongly pick the
        // Full mode. An empty list keeps falling through to the last
        // branch, as the NaN ratio did before.
        const bool all_active = list_size != 0 && n_active == list_size;
        const float occupancy = static_cast<float>(n_active) / list_size;

        if (all_active) {
            this->template scan_list_with_table_panorama<
                    KnnSearchResultsPanorama<C, use_sel>,
                    faiss::PanoramaScanMode::Full>(
                    list_size,
                    m_level_width,
                    epsilon,
                    codes,
                    codes_offset,
                    bitset,
                    res,
                    cum_sums,
                    query_cum_norm,
                    batch_storage,
                    batch_offsets,
                    batch_size,
                    num_active,
                    indices,
                    exact_distances);
        } else if (occupancy > 0.05) {
            this->template scan_list_with_table_panorama<
                    KnnSearchResultsPanorama<C, use_sel>,
                    faiss::PanoramaScanMode::Dense>(
                    list_size,
                    m_level_width,
                    epsilon,
                    codes,
                    codes_offset,
                    bitset,
                    res,
                    cum_sums,
                    query_cum_norm,
                    batch_storage,
                    batch_offsets,
                    batch_size,
                    num_active,
                    indices,
                    exact_distances);
        } else {
            this->template scan_list_with_table_panorama<
                    KnnSearchResultsPanorama<C, use_sel>,
                    faiss::PanoramaScanMode::Sparse>(
                    list_size,
                    m_level_width,
                    epsilon,
                    codes,
                    codes_offset,
                    bitset,
                    res,
                    cum_sums,
                    query_cum_norm,
                    batch_storage,
                    batch_offsets,
                    batch_size,
                    num_active,
                    indices,
                    exact_distances);
        }

        return res.nup;
    }

    /// Shared body of scan_codes() and scan_codes_range(): picks the scan
    /// routine. A positive polysemous_ht selects the polysemous scan
    /// (which expects precompute_mode == 2); otherwise precompute_mode
    /// decides between table, pointer and on-the-fly scanning.
    /// Throws through FAISS_THROW_MSG on any other precompute_mode.
    template <class SearchResultType>
    void scan_with_precompute_mode(
            size_t ncode,
            const uint8_t* codes,
            SearchResultType& res) const {
        if (this->polysemous_ht > 0) {
            assert(precompute_mode == 2);
            this->scan_list_polysemous(ncode, codes, res);
        } else if (precompute_mode == 2) {
            this->scan_list_with_table(ncode, codes, res);
        } else if (precompute_mode == 1) {
            this->scan_list_with_pointer(ncode, codes, res);
        } else if (precompute_mode == 0) {
            this->scan_on_the_fly_dist(ncode, codes, res);
        } else {
            FAISS_THROW_MSG("bad precomp mode");
        }
    }

    /// k-NN scan of one list into the heap (heap_sim, heap_ids) of size k.
    /// @return number of heap updates reported by the result handler
    size_t scan_codes(
            size_t ncode,
            const uint8_t* codes,
            const idx_t* ids,
            float* heap_sim,
            idx_t* heap_ids,
            size_t k) const override {
        KnnSearchResults<C, use_sel> res = {
                /* key */ this->key,
                /* ids */ this->store_pairs ? nullptr : ids,
                /* sel */ this->sel,
                /* k */ k,
                /* heap_sim */ heap_sim,
                /* heap_ids */ heap_ids,
                /* nup */ 0};

        scan_with_precompute_mode(ncode, codes, res);
        return res.nup;
    }

    /// Range scan of one list: matches within `radius` go to `rres`.
    void scan_codes_range(
            size_t ncode,
            const uint8_t* codes,
            const idx_t* ids,
            float radius,
            RangeQueryResult& rres) const override {
        RangeSearchResults<C, use_sel> res = {
                /* key */ this->key,
                /* ids */ this->store_pairs ? nullptr : ids,
                /* sel */ this->sel,
                /* radius */ radius,
                /* rres */ rres};

        scan_with_precompute_mode(ncode, codes, res);
    }
};

/// Instantiate the scanner matching the index metric: inner product pairs
/// with CMin, L2 with CMax. The scanner is created with precompute mode 2.
///
/// @return a scanner allocated with new, or nullptr when the metric is
///         neither METRIC_INNER_PRODUCT nor METRIC_L2
template <class PQDecoder, bool use_sel>
InvertedListScanner* get_InvertedListScanner1(
        const IndexIVFPQ& index,
        bool store_pairs,
        const IDSelector* sel) {
    switch (index.metric_type) {
        case METRIC_INNER_PRODUCT:
            return new IVFPQScanner<
                    METRIC_INNER_PRODUCT,
                    CMin<float, idx_t>,
                    PQDecoder,
                    use_sel>(index, store_pairs, 2, sel);
        case METRIC_L2:
            return new IVFPQScanner<
                    METRIC_L2,
                    CMax<float, idx_t>,
                    PQDecoder,
                    use_sel>(index, store_pairs, 2, sel);
        default:
            return nullptr;
    }
}

/// Pick the PQ decoder from the number of bits per sub-code: dedicated
/// decoders for 8 and 16 bits, the generic decoder for anything else.
/// The result of get_InvertedListScanner1() is returned as is.
template <bool use_sel>
InvertedListScanner* get_InvertedListScanner2(
        const IndexIVFPQ& index,
        bool store_pairs,
        const IDSelector* sel) {
    switch (index.pq.nbits) {
        case 8:
            return get_InvertedListScanner1<PQDecoder8, use_sel>(
                    index, store_pairs, sel);
        case 16:
            return get_InvertedListScanner1<PQDecoder16, use_sel>(
                    index, store_pairs, sel);
        default:
            return get_InvertedListScanner1<PQDecoderGeneric, use_sel>(
                    index, store_pairs, sel);
    }
}

} // anonymous namespace

/// Build a scanner for this index.
///
/// The presence of a selector picks the use_sel template flavour. The
/// search-parameters argument is unnamed and ignored.
///
/// @param store_pairs  forwarded to the scanner
/// @param sel          optional selector; may be nullptr
/// @return a scanner allocated with new (presumably owned by the caller
///         -- confirm), or nullptr when the index metric is neither
///         inner product nor L2
InvertedListScanner* IndexIVFPQ::get_InvertedListScanner(
        bool store_pairs,
        const IDSelector* sel,
        const IVFSearchParameters*) const {
    if (sel) {
        return get_InvertedListScanner2<true>(*this, store_pairs, sel);
    }
    return get_InvertedListScanner2<false>(*this, store_pairs, sel);
}

// Global statistics object. scan_list_polysemous_hc() adds its count of
// codes that passed the Hamming test to n_hamming_pass, inside an OpenMP
// critical section.
IndexIVFPQStats indexIVFPQ_stats;

void IndexIVFPQStats::reset() {
    memset(this, 0, sizeof(*this));
}

/// Default constructor: only gives the runtime search/training settings
/// their neutral values.
IndexIVFPQ::IndexIVFPQ() {
    // table precomputation and scanning
    scan_table_threshold = 0;
    use_precomputed_table = 0;

    // polysemous codes: disabled
    polysemous_training = nullptr;
    polysemous_ht = 0;
    do_polysemous_training = false;
}

/// Compares fixed-size codes stored back to back in `tab`; codes are
/// referred to by their position in the table.
struct CodeCmp {
    const uint8_t* tab; ///< start of the code table (not owned)
    size_t code_size;   ///< size of one code in bytes

    /// Ordering predicate: true when code `lhs` compares greater than
    /// code `rhs`, i.e. sorting with it yields descending byte order.
    bool operator()(int lhs, int rhs) const {
        const int order = cmp(lhs, rhs);
        return order > 0;
    }

    /// Three-way byte-wise comparison of codes `lhs` and `rhs`
    /// (memcmp semantics: negative, zero or positive).
    int cmp(int lhs, int rhs) const {
        const uint8_t* lhs_code = tab + lhs * code_size;
        const uint8_t* rhs_code = tab + rhs * code_size;
        return memcmp(lhs_code, rhs_code, code_size);
    }
};

/// Find groups of entries that share the same inverted list and have
/// byte-identical codes.
///
/// Group g occupies dup_ids[lims[g] .. lims[g + 1]); lims[0] is 0.
///
/// @param dup_ids  output: ids of the duplicated entries, group by group
/// @param lims     output: group boundaries, one more entry than groups
/// @return number of groups found
size_t IndexIVFPQ::find_duplicates(idx_t* dup_ids, size_t* lims) const {
    size_t group_count = 0;
    lims[0] = 0;

    for (size_t list_no = 0; list_no < nlist; list_no++) {
        const size_t list_len = invlists->list_size(list_no);

        // positions of the list entries, to be sorted by code content so
        // that identical codes end up next to each other
        std::vector<int> perm(list_len);
        for (size_t pos = 0; pos < list_len; pos++) {
            perm[pos] = static_cast<int>(pos);
        }

        InvertedLists::ScopedCodes codes(invlists, list_no);
        CodeCmp by_code = {codes.get(), code_size};
        std::sort(perm.begin(), perm.end(), by_code);

        InvertedLists::ScopedIds list_ids(invlists, list_no);

        // first sorted position of the current run of equal codes
        int run_start = -1;
        for (size_t i = 0; i < list_len; i++) {
            const bool extends_run = run_start >= 0 &&
                    by_code.cmp(perm[run_start], perm[i]) == 0;

            if (!extends_run) {
                // a different code: it starts a new run
                run_start = static_cast<int>(i);
                continue;
            }

            if (static_cast<size_t>(run_start) + 1 == i) {
                // second member of the run: open a group and record the
                // first member as well
                group_count++;
                lims[group_count] = lims[group_count - 1];
                dup_ids[lims[group_count]++] = list_ids[perm[run_start]];
            }
            dup_ids[lims[group_count]++] = list_ids[perm[i]];
        }
    }
    return group_count;
}

} // namespace faiss
