/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

// -*- c++ -*-

#include <faiss/impl/PolysemousTraining.h>

#include <omp.h>
#include <stdint.h>

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <memory>

#include <faiss/utils/distances.h>
#include <faiss/utils/hamming.h>
#include <faiss/utils/random.h>
#include <faiss/utils/utils.h>

#include <faiss/impl/FaissAssert.h>

/*****************************************
 * Mixed PQ / Hamming
 ******************************************/

namespace faiss {

/****************************************************
 * Optimization code
 ****************************************************/

// Generic fallback for the cost variation caused by exchanging entries iw
// and jw of perm: the full cost is evaluated on perm and on a scratch copy
// with the two entries exchanged, and (cost after) - (cost before) is
// returned.
double PermutationObjective::cost_update(const int* perm, int iw, int jw)
        const {
    const double cost_before = compute_cost(perm);

    // scratch copy of the permutation with the two entries exchanged
    std::vector<int> exchanged(n);
    std::copy(perm, perm + n, exchanged.begin());
    exchanged[iw] = perm[jw];
    exchanged[jw] = perm[iw];

    const double cost_after = compute_cost(exchanged.data());
    return cost_after - cost_before;
}

/// Builds an optimizer for the objective `obj` using the annealing
/// parameters `p` (copied into the base class).
///
/// @param obj  objective to minimize; only the pointer is stored
/// @param p    annealing parameters, p.seed seeds the random generator
/// @throws     (via FAISS_THROW_IF_NOT) if obj->n is not in [0, 100000)
SimulatedAnnealingOptimizer::SimulatedAnnealingOptimizer(
        PermutationObjective* obj,
        const SimulatedAnnealingParameters& p)
        : SimulatedAnnealingParameters(p),
          obj(obj),
          n(obj->n),
          logfile(nullptr) {
    // Validate before allocating: the destructor does not run when a
    // constructor throws, so allocating first would leak the generator.
    FAISS_THROW_IF_NOT(n < 100000 && n >= 0);
    rnd = new RandomGenerator(p.seed);
}

// Releases the random generator allocated by the constructor.
// logfile is not closed here.
SimulatedAnnealingOptimizer::~SimulatedAnnealingOptimizer() {
    delete rnd;
}

// Runs the annealing n_redo times, each time from a fresh starting
// permutation (identity, shuffled when init_random is set), and keeps the
// permutation with the lowest final cost.
//
// best_perm: output buffer of n ints, receives the best permutation found.
//            It is left untouched if no run ends with a cost below 1e30
//            (in particular when n_redo <= 0).
// returns:   the lowest cost found, or 1e30 if there was none.
double SimulatedAnnealingOptimizer::run_optimization(int* best_perm) {
    double min_cost = 1e30;

    for (int it = 0; it < n_redo; it++) {
        std::vector<int> perm(n);
        for (int i = 0; i < n; i++)
            perm[i] = i;
        if (init_random) {
            // shuffle: position i is exchanged with a random position j >= i
            for (int i = 0; i < n; i++) {
                int j = i + rnd->rand_int(n - i);
                std::swap(perm[i], perm[j]);
            }
        }
        // keep the cost in double: narrowing it to float would make the
        // comparison between runs and the returned value imprecise
        const double cost = optimize(perm.data());
        if (logfile)
            fprintf(logfile, "\n");
        if (verbose > 1) {
            printf("    optimization run %d: cost=%g %s\n",
                   it,
                   cost,
                   cost < min_cost ? "keep" : "");
        }
        if (cost < min_cost) {
            std::copy(perm.begin(), perm.end(), best_perm);
            min_cost = cost;
        }
    }
    return min_cost;
}

// Performs one annealing run, starting from and modifying the permutation
// in-place.
//
// At each of the n_iter iterations a pair of positions (iw, jw) is drawn and
// the swap is accepted when it lowers the cost, or otherwise with
// probability `temperature` (which is multiplied by temperature_decay at
// every iteration). init_cost is set to the cost of the input permutation.
//
// perm:    permutation of size n, updated in place
// returns: the cost after the last iteration, tracked incrementally from
//          the objective's cost_update values
double SimulatedAnnealingOptimizer::optimize(int* perm) {
    double cost = init_cost = obj->compute_cost(perm);

    // smallest log2n such that n <= 2^log2n
    int log2n = 0;
    while ((1 << log2n) < n)
        log2n++;

    // A swap needs two distinct positions: with fewer than 2 entries there is
    // nothing to draw from (and any drawn index would be out of range), so
    // no iteration is performed.
    const int n_steps = n < 2 ? 0 : n_iter;

    double temperature = init_temperature;
    int n_swap = 0, n_hot = 0;
    for (int it = 0; it < n_steps; it++) {
        temperature = temperature * temperature_decay;
        int iw, jw;
        if (only_bit_flips) {
            // jw differs from iw by exactly one bit. When n is a power of 2
            // the first draw is always valid; otherwise redraw until jw
            // falls inside the permutation.
            do {
                iw = rnd->rand_int(n);
                jw = iw ^ (1 << rnd->rand_int(log2n));
            } while (jw >= n);
        } else {
            // Drawing scheme kept as is so that results for a given seed do
            // not change: jw is only bumped when it collides with iw.
            iw = rnd->rand_int(n);
            jw = rnd->rand_int(n - 1);
            if (jw == iw)
                jw++;
        }
        double delta_cost = obj->cost_update(perm, iw, jw);
        if (delta_cost < 0 || rnd->rand_float() < temperature) {
            std::swap(perm[iw], perm[jw]);
            cost += delta_cost;
            n_swap++;
            // "hot" swap = accepted although it does not improve the cost
            if (delta_cost >= 0)
                n_hot++;
        }
        if (verbose > 2 || (verbose > 1 && it % 10000 == 0)) {
            printf("      iteration %d cost %g temp %g n_swap %d "
                   "(%d hot)     \r",
                   it,
                   cost,
                   temperature,
                   n_swap,
                   n_hot);
            fflush(stdout);
        }
        if (logfile) {
            fprintf(logfile,
                    "%d %g %g %d %d\n",
                    it,
                    cost,
                    temperature,
                    n_swap,
                    n_hot);
        }
    }
    if (verbose > 1)
        printf("\n");
    return cost;
}

/****************************************************
 * Cost functions: ReproduceDistanceTable
 ****************************************************/

/// Hamming distance between two 64-bit words, i.e. the number of bit
/// positions in which a and b differ.
static inline int hamming_dis(uint64_t a, uint64_t b) {
    // The `ll` builtin operates on unsigned long long, which is at least 64
    // bits wide everywhere; the `l` variant takes unsigned long, which is
    // only 32 bits on LLP64 targets and would drop the upper half.
    return __builtin_popcountll(a ^ b);
}

namespace {

/// Optimize a permutation so that the Hamming distances between permuted
/// indices reproduce a given distance table.
///
/// The input table is mapped affinely to target distances (see
/// set_affine_target_dis) and the cost of a permutation is the weighted
/// squared difference between target and Hamming distances.
struct ReproduceWithHammingObjective : PermutationObjective {
    int nbits;                // number of bits of the codes, n = 2^nbits
    double dis_weight_factor; // decay rate used by dis_weight

    static double sqr(double x) {
        return x * x;
    }

    // weighting of distances: it is more important to reproduce small
    // distances well
    double dis_weight(double x) const {
        return exp(-dis_weight_factor * x);
    }

    std::vector<double> target_dis; // wanted distances (size n^2)
    std::vector<double> weights;    // weights for each distance (size n^2)

    // cost = weighted quadratic difference between the wanted distance of
    // (i, j) and the Hamming distance between perm[i] and perm[j]
    double compute_cost(const int* perm) const override {
        double cost = 0;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                double wanted = target_dis[i * n + j];
                double w = weights[i * n + j];
                double actual = hamming_dis(perm[i], perm[j]);
                cost += w * sqr(wanted - actual);
            }
        }
        return cost;
    }

    // what would the cost update be if iw and jw were swapped?
    // computed in O(n) instead of O(n^2) for the full re-computation:
    // only rows iw and jw and columns iw and jw of the table are affected
    double cost_update(const int* perm, int iw, int jw) const override {
        // position whose current entry sits at position `pos` after the swap
        const auto after_swap = [iw, jw](int pos) {
            return pos == iw ? jw : pos == jw ? iw : pos;
        };

        double delta_cost = 0;

        for (int i = 0; i < n; i++) {
            if (i == iw || i == jw) {
                // swapped row: every cell of the row changes
                const int i_new = after_swap(i);
                for (int j = 0; j < n; j++) {
                    double wanted = target_dis[i * n + j],
                           w = weights[i * n + j];
                    double actual = hamming_dis(perm[i], perm[j]);
                    delta_cost -= w * sqr(wanted - actual);
                    double new_actual =
                            hamming_dis(perm[i_new], perm[after_swap(j)]);
                    delta_cost += w * sqr(wanted - new_actual);
                }
            } else {
                // other rows: only the two swapped columns change
                const int cols[2] = {iw, jw};
                for (int j : cols) {
                    double wanted = target_dis[i * n + j],
                           w = weights[i * n + j];
                    double actual = hamming_dis(perm[i], perm[j]);
                    delta_cost -= w * sqr(wanted - actual);
                    double new_actual =
                            hamming_dis(perm[i], perm[after_swap(j)]);
                    delta_cost += w * sqr(wanted - new_actual);
                }
            }
        }

        return delta_cost;
    }

    /// @param nbits             code size in bits, the permutation has
    ///                          n = 2^nbits entries
    /// @param dis_table         n * n table of distances to reproduce
    /// @param dis_weight_factor decay rate of the distance weighting
    /// @throws (via FAISS_THROW_IF_NOT) if dis_table does not have n * n
    ///         entries
    ReproduceWithHammingObjective(
            int nbits,
            const std::vector<double>& dis_table,
            double dis_weight_factor)
            : nbits(nbits), dis_weight_factor(dis_weight_factor) {
        n = 1 << nbits;
        FAISS_THROW_IF_NOT(dis_table.size() == size_t(n) * n);
        set_affine_target_dis(dis_table);
    }

    /// Fills target_dis with an affine mapping of dis_table and weights with
    /// the corresponding dis_weight values.
    void set_affine_target_dis(const std::vector<double>& dis_table) {
        double sum = 0, sum2 = 0;
        int n2 = n * n;
        for (int i = 0; i < n2; i++) {
            sum += dis_table[i];
            sum2 += dis_table[i] * dis_table[i];
        }
        double mean = sum / n2;
        double stddev = sqrt(sum2 / n2 - (sum / n2) * (sum / n2));

        // The Hamming distance between two independent uniformly random
        // nbits-bit codes has mean nbits / 2 and variance nbits / 4. These
        // must be computed in floating point: the integer divisions are
        // wrong whenever nbits is not a multiple of 4.
        const double hamming_mean = nbits / 2.0;
        const double hamming_stddev = sqrt(nbits / 4.0);

        target_dis.resize(n2);
        weights.resize(n2);

        for (int i = 0; i < n2; i++) {
            // the mapping function
            double td = (dis_table[i] - mean) / stddev * hamming_stddev +
                    hamming_mean;
            target_dis[i] = td;
            // compute a weight
            weights[i] = dis_weight(td);
        }
    }

    ~ReproduceWithHammingObjective() override {}
};

} // anonymous namespace

// Weighting of distances: small distances matter most, so the weight
// decreases exponentially with x at rate dis_weight_factor.
double ReproduceDistancesObjective::dis_weight(double x) const {
    const double exponent = -dis_weight_factor * x;
    return exp(exponent);
}

// Entry (i, j) of the n x n source distance table, stored row by row.
double ReproduceDistancesObjective::get_source_dis(int i, int j) const {
    const int flat_index = i * n + j;
    return source_dis[flat_index];
}

// Cost of a permutation: for every cell (row, col), the weighted squared
// difference between the target distance of the cell and the source
// distance between perm[row] and perm[col].
double ReproduceDistancesObjective::compute_cost(const int* perm) const {
    double total = 0;
    for (int row = 0; row < n; row++) {
        const int row_offset = row * n;
        const int src_row = perm[row];
        for (int col = 0; col < n; col++) {
            const int cell = row_offset + col;
            const double gap =
                    target_dis[cell] - get_source_dis(src_row, perm[col]);
            total += weights[cell] * sqr(gap);
        }
    }
    return total;
}

// what would the cost update be if iw and jw were swapped?
// computed in O(n) instead of O(n^2) for the full re-computation
double ReproduceDistancesObjective::cost_update(const int* perm, int iw, int jw)
        const {
    double delta_cost = 0;
    for (int i = 0; i < n; i++) {
        if (i == iw) {
            for (int j = 0; j < n; j++) {
                double wanted = target_dis[i * n + j], w = weights[i * n + j];
                double actual = get_source_dis(perm[i], perm[j]);
                delta_cost -= w * sqr(wanted - actual);
                double new_actual = get_source_dis(
                        perm[jw],
                        perm[j == iw           ? jw
                                     : j == jw ? iw
                                               : j]);
                delta_cost += w * sqr(wanted - new_actual);
            }
        } else if (i == jw) {
            for (int j = 0; j < n; j++) {
                double wanted = target_dis[i * n + j], w = weights[i * n + j];
                double actual = get_source_dis(perm[i], perm[j]);
                delta_cost -= w * sqr(wanted - actual);
                double new_actual = get_source_dis(
                        perm[iw],
                        perm[j == iw           ? jw
                                     : j == jw ? iw
                                               : j]);
                delta_cost += w * sqr(wanted - new_actual);
            }
        } else {
            int j = iw;
            {
                double wanted = target_dis[i * n + j], w = weights[i * n + j];
                double actual = get_source_dis(perm[i], perm[j]);
                delta_cost -= w * sqr(wanted - actual);
                double new_actual = get_source_dis(perm[i], perm[jw]);
                delta_cost += w * sqr(wanted - new_actual);
            }
            j = jw;
            {
                double wanted = target_dis[i * n + j], w = weights[i * n + j];
                double actual = get_source_dis(perm[i], perm[j]);
                delta_cost -= w * sqr(wanted - actual);
                double new_actual = get_source_dis(perm[i], perm[iw]);
                delta_cost += w * sqr(wanted - new_actual);
            }
        }
    }
    return delta_cost;
}

// Builds the objective for a permutation of size n.
//
// n:                 number of entries of the permutation
// source_dis_in:     n * n table of source distances; it is rescaled into
//                    the source_dis member by set_affine_target_dis
// target_dis_in:     n * n table of target distances; only the pointer is
//                    stored here, so the caller presumably has to keep the
//                    table alive while the objective is used
// dis_weight_factor: decay rate used by dis_weight
ReproduceDistancesObjective::ReproduceDistancesObjective(
        int n,
        const double* source_dis_in,
        const double* target_dis_in,
        double dis_weight_factor)
        : dis_weight_factor(dis_weight_factor), target_dis(target_dis_in) {
    this->n = n;
    // target_dis must be set before this call: it is read there
    set_affine_target_dis(source_dis_in);
}

// Computes the mean and the (population) standard deviation of the n2
// values in tab.
//
// tab:        input values
// n2:         number of values (the results are not finite for n2 == 0)
// mean_out:   receives the mean
// stddev_out: receives sqrt(E[x^2] - E[x]^2)
void ReproduceDistancesObjective::compute_mean_stdev(
        const double* tab,
        size_t n2,
        double* mean_out,
        double* stddev_out) {
    double sum = 0, sum2 = 0;
    // index with the same type as n2: an int index mixes signedness in the
    // comparison and overflows for tables larger than INT_MAX
    for (size_t i = 0; i < n2; i++) {
        sum += tab[i];
        sum2 += tab[i] * tab[i];
    }
    double mean = sum / n2;
    double stddev = sqrt(sum2 / n2 - (sum / n2) * (sum / n2));
    *mean_out = mean;
    *stddev_out = stddev;
}

void ReproduceDistancesObjective::set_affine_target_dis(
        const double* source_dis_in) {
    int n2 = n * n;

    double mean_src, stddev_src;
    compute_mean_stdev(source_dis_in, n2, &mean_src, &stddev_src);

    double mean_target, stddev_target;
    compute_mean_stdev(target_dis, n2, &mean_target, &stddev_target);

    printf("map mean %g std %g -> mean %g std %g\n",
           mean_src,
           stddev_src,
           mean_target,
           stddev_target);

    source_dis.resize(n2);
    weights.resize(n2);

    for (int i = 0; i < n2; i++) {
        // the mapping function
        source_dis[i] =
                (source_dis_in[i] - mean_src) / stddev_src * stddev_target +
                mean_target;

        // compute a weight
        weights[i] = dis_weight(target_dis[i]);
    }
}

/****************************************************
 * Cost functions: RankingScore
 ****************************************************/

/// Maintains a 3D table of elementary costs.
/// Accumulates elements based on Hamming distance comparisons.
///
/// Ttab is the type of the table entries, Taccu the type used to accumulate
/// them.
template <typename Ttab, typename Taccu>
struct Score3Computer : PermutationObjective {
    int nc;

    // cost matrix of size nc * nc * nc
    // n_gt (i,j,k) = count of d_gt(x, y-) < d_gt(x, y+)
    // where x has PQ code i, y- PQ code j and y+ PQ code k
    std::vector<Ttab> n_gt;

    /// the cost is a triple loop on the nc * nc * nc matrix of entries:
    /// entry (i, j, k) is accumulated when perm[i] is closer (in Hamming
    /// distance) to perm[j] than to perm[k].
    Taccu compute(const int* perm) const {
        Taccu accu = 0;
        const Ttab* p = n_gt.data();
        for (int i = 0; i < nc; i++) {
            int ip = perm[i];
            for (int j = 0; j < nc; j++) {
                int jp = perm[j];
                for (int k = 0; k < nc; k++) {
                    int kp = perm[k];
                    if (hamming_dis(ip, jp) < hamming_dis(ip, kp)) {
                        accu += *p; // n_gt [ ( i * nc + j) * nc + k];
                    }
                    p++;
                }
            }
        }
        return accu;
    }

    /** cost update if entries iw and jw of the permutation would be
     * swapped.
     *
     * The computation is optimized by avoiding elements in the
     * nc*nc*nc cube that are known not to change. For nc=256, this
     * reduces the nb of cells to visit to about 6/256 th of the
     * cells. Practical speedup is about 8x, and the code is quite
     * complex :-/
     *
     * iw and jw must be different (checked with an assert).
     */
    Taccu compute_update(const int* perm, int iw, int jw) const {
        assert(iw != jw);
        if (iw > jw)
            std::swap(iw, jw);

        Taccu accu = 0;
        const Ttab* n_gt_i = n_gt.data();
        for (int i = 0; i < nc; i++) {
            // value at position i before (ip0) and after (ip) the swap
            int ip0 = perm[i];
            int ip = perm[i == iw ? jw : i == jw ? iw : i];

            // reference (non-optimized) version of the two calls below:
            // accu += update_i (perm, iw, jw, ip0, ip, n_gt_i);

            accu += update_i_cross(perm, iw, jw, ip0, ip, n_gt_i);

            if (ip != ip0)
                accu += update_i_plane(perm, iw, jw, ip0, ip, n_gt_i);

            n_gt_i += nc * nc;
        }

        return accu;
    }

    /// Straightforward update of one i-slice of the cube: visits all the
    /// (j, k) cells and accumulates the difference between the score after
    /// and before the swap. Not called by compute_update, which relies on
    /// update_i_cross and update_i_plane instead.
    Taccu update_i(
            const int* perm,
            int iw,
            int jw,
            int ip0,
            int ip,
            const Ttab* n_gt_i) const {
        Taccu accu = 0;
        const Ttab* n_gt_ij = n_gt_i;
        for (int j = 0; j < nc; j++) {
            int jp0 = perm[j];
            int jp = perm[j == iw ? jw : j == jw ? iw : j];
            for (int k = 0; k < nc; k++) {
                int kp0 = perm[k];
                int kp = perm[k == iw ? jw : k == jw ? iw : k];
                // keep the table type: an int would truncate fractional
                // entries (e.g. Ttab = float)
                Ttab ng = n_gt_ij[k];
                if (hamming_dis(ip, jp) < hamming_dis(ip, kp)) {
                    accu += ng;
                }
                if (hamming_dis(ip0, jp0) < hamming_dis(ip0, kp0)) {
                    accu -= ng;
                }
            }
            n_gt_ij += nc;
        }
        return accu;
    }

    // 2 inner loops for the case ip0 != ip: visits the cells where neither
    // j nor k is one of the swapped positions
    Taccu update_i_plane(
            const int* perm,
            int iw,
            int jw,
            int ip0,
            int ip,
            const Ttab* n_gt_i) const {
        Taccu accu = 0;
        const Ttab* n_gt_ij = n_gt_i;

        for (int j = 0; j < nc; j++) {
            if (j != iw && j != jw) {
                int jp = perm[j];
                for (int k = 0; k < nc; k++) {
                    if (k != iw && k != jw) {
                        int kp = perm[k];
                        Ttab ng = n_gt_ij[k];
                        if (hamming_dis(ip, jp) < hamming_dis(ip, kp)) {
                            accu += ng;
                        }
                        if (hamming_dis(ip0, jp) < hamming_dis(ip0, kp)) {
                            accu -= ng;
                        }
                    }
                }
            }
            n_gt_ij += nc;
        }
        return accu;
    }

    /// used for the 8 cells where the 3 indices are swapped: update of the
    /// single cell k of the line n_gt_ij
    inline Taccu update_k(
            const int* perm,
            int iw,
            int jw,
            int ip0,
            int ip,
            int jp0,
            int jp,
            int k,
            const Ttab* n_gt_ij) const {
        Taccu accu = 0;
        int kp0 = perm[k];
        int kp = perm[k == iw ? jw : k == jw ? iw : k];
        Ttab ng = n_gt_ij[k];
        if (hamming_dis(ip, jp) < hamming_dis(ip, kp)) {
            accu += ng;
        }
        if (hamming_dis(ip0, jp0) < hamming_dis(ip0, kp0)) {
            accu -= ng;
        }
        return accu;
    }

    /// compute update on a line of k's, where i and j are swapped
    /// (the cells k = iw and k = jw are skipped)
    Taccu update_j_line(
            const int* perm,
            int iw,
            int jw,
            int ip0,
            int ip,
            int jp0,
            int jp,
            const Ttab* n_gt_ij) const {
        Taccu accu = 0;
        for (int k = 0; k < nc; k++) {
            if (k == iw || k == jw)
                continue;
            int kp = perm[k];
            Ttab ng = n_gt_ij[k];
            if (hamming_dis(ip, jp) < hamming_dis(ip, kp)) {
                accu += ng;
            }
            if (hamming_dis(ip0, jp0) < hamming_dis(ip0, kp)) {
                accu -= ng;
            }
        }
        return accu;
    }

    /// considers the 2 pairs of crossing lines j = iw or jw and k = iw or jw
    Taccu update_i_cross(
            const int* perm,
            int iw,
            int jw,
            int ip0,
            int ip,
            const Ttab* n_gt_i) const {
        Taccu accu = 0;
        const Ttab* n_gt_ij = n_gt_i;

        for (int j = 0; j < nc; j++) {
            int jp0 = perm[j];
            int jp = perm[j == iw ? jw : j == jw ? iw : j];

            accu += update_k(perm, iw, jw, ip0, ip, jp0, jp, iw, n_gt_ij);
            accu += update_k(perm, iw, jw, ip0, ip, jp0, jp, jw, n_gt_ij);

            // jp differs from jp0 only when j is one of the swapped
            // positions: the rest of the line then changes as well
            if (jp != jp0)
                accu += update_j_line(perm, iw, jw, ip0, ip, jp0, jp, n_gt_ij);

            n_gt_ij += nc;
        }
        return accu;
    }

    /// PermutationObjective implementation (just negates the scores
    /// for minimization)

    double compute_cost(const int* perm) const override {
        return -compute(perm);
    }

    double cost_update(const int* perm, int iw, int jw) const override {
        double ret = -compute_update(perm, iw, jw);
        return ret;
    }

    ~Score3Computer() override {}
};

/// Comparator that orders indices by the value they point to in `tab`.
struct IndirectSort {
    const float* tab; // values being compared; not owned

    /// true iff tab[a] < tab[b]. Const so that the comparator can also be
    /// invoked through a const object or reference.
    bool operator()(int a, int b) const {
        return tab[a] < tab[b];
    }
};

/// Ranking objective: the n_gt table of Score3Computer is filled from
/// ground-truth distances between nq queries and nb database elements, each
/// of them described by its code.
struct RankingScore2 : Score3Computer<float, double> {
    int nbits;
    int nq, nb;
    const uint32_t *qcodes, *bcodes; // codes of the queries / db elements
    const float* gt_distances;       // nq * nb ground-truth distances

    /// @param nbits         code size in bits, nc = n = 2^nbits
    /// @param nq            number of queries
    /// @param nb            number of database elements
    /// @param qcodes        nq query codes
    /// @param bcodes        nb database codes
    /// @param gt_distances  nq * nb ground-truth distances, one row per query
    ///
    /// The pointers are only read during construction (by init_n_gt).
    RankingScore2(
            int nbits,
            int nq,
            int nb,
            const uint32_t* qcodes,
            const uint32_t* bcodes,
            const float* gt_distances)
            : nbits(nbits),
              nq(nq),
              nb(nb),
              qcodes(qcodes),
              bcodes(bcodes),
              gt_distances(gt_distances) {
        n = nc = 1 << nbits;
        n_gt.resize(nc * nc * nc);
        init_n_gt();
    }

    /// weight of rank r (0-based): 1 / (r + 1)
    double rank_weight(int r) const {
        return 1.0 / (r + 1);
    }

    /// count nb of i, j in a x b st. i < j
    /// a and b should be sorted on input
    /// they are the ranks of j and k respectively.
    /// specific version for diff-of-rank weighting, cannot optimized
    /// with a cumulative table
    ///
    /// Each pair (ai, bk) with ai < bk contributes
    /// rank_weight(ai) * rank_weight(bk - ai).
    double accum_gt_weight_diff(
            const std::vector<int>& a,
            const std::vector<int>& b) const {
        const auto nb_2 = b.size();
        const auto na = a.size();

        double accu = 0;
        size_t j = 0;
        for (size_t i = 0; i < na; i++) {
            const auto ai = a[i];
            // skip the entries of b that are not strictly above ai; since a
            // is sorted, j never needs to move backwards
            while (j < nb_2 && ai >= b[j]) {
                j++;
            }

            double accu_i = 0;
            for (auto k = j; k < nb_2; k++) {
                accu_i += rank_weight(b[k] - ai);
            }

            accu += rank_weight(ai) * accu_i;
        }
        return accu;
    }

    /// Accumulates into n_gt the contribution of every query. Prints a
    /// progress line on stdout for each query.
    void init_n_gt() {
        for (int q = 0; q < nq; q++) {
            // offset computed in size_t: q * nb overflows an int for large
            // nq * nb
            const float* gtd = gt_distances + size_t(q) * nb;
            const uint32_t* cb = bcodes; // all same codes
            float* n_gt_q = &n_gt[qcodes[q] * nc * nc];

            printf("init gt for q=%d/%d    \r", q, nq);
            fflush(stdout);

            std::vector<int> rankv(nb);
            int* ranks = rankv.data();

            // elements in each code bin, ordered by rank within each bin
            std::vector<std::vector<int>> tab(nc);

            { // build rank table: db indices sorted by increasing distance
                IndirectSort s = {gtd};
                for (int j = 0; j < nb; j++)
                    ranks[j] = j;
                std::sort(ranks, ranks + nb, s);
            }

            // ranks are visited in increasing order, so each bin is sorted
            for (int rank = 0; rank < nb; rank++) {
                int i = ranks[rank];
                tab[cb[i]].push_back(rank);
            }

            // this is very expensive. Any suggestion for improvement
            // welcome.
            for (int i = 0; i < nc; i++) {
                std::vector<int>& di = tab[i];
                for (int j = 0; j < nc; j++) {
                    std::vector<int>& dj = tab[j];
                    n_gt_q[i * nc + j] += accum_gt_weight_diff(di, dj);
                }
            }
        }
    }
};

/*****************************************
 * PolysemousTraining
 ******************************************/

// Default settings: affine reproduction of distances, no training points
// for the ranking objective, distance weights decaying with rate log(2),
// and a 20 GiB memory cap.
PolysemousTraining::PolysemousTraining() {
    optimization_type = OT_ReproduceDistances_affine;
    ntrain_permutation = 0;
    dis_weight_factor = log(2);

    // max 20 G RAM
    const size_t one_gib = size_t(1024) * 1024 * 1024;
    max_memory = 20 * one_gib;
}

// Reorders the centroids of every sub-quantizer of pq so that the Hamming
// distance between centroid indices reproduces the L2 distance between the
// centroids. One simulated annealing is run per sub-quantizer, in parallel.
//
// Throws (via FAISS_THROW_IF_NOT_FMT) when a single thread would already
// need more than max_memory bytes; otherwise the number of threads is
// reduced to stay below max_memory.
// NOTE(review): the log file check throws from inside the OpenMP loop;
// confirm that this is acceptable for the callers.
void PolysemousTraining::optimize_reproduce_distances(
        ProductQuantizer& pq) const {
    int dsub = pq.dsub;

    int n = pq.ksub;
    int nbits = pq.nbits;

    size_t mem1 = memory_usage_per_thread(pq);
    int nt = std::min(omp_get_max_threads(), int(pq.M));
    FAISS_THROW_IF_NOT_FMT(
            mem1 < max_memory,
            "Polysemous training will use %zu bytes per thread, while the max is set to %zu",
            mem1,
            max_memory);

    if (mem1 * nt > max_memory) {
        nt = max_memory / mem1;
        fprintf(stderr,
                "Polysemous training: WARN, reducing number of threads to %d to save memory\n",
                nt);
    }

#pragma omp parallel for num_threads(nt)
    for (int m = 0; m < pq.M; m++) {
        // printf ("Optimizing quantizer %d\n", m);

        float* centroids = pq.get_centroids(m, 0);

        // n * n table of squared L2 distances between the centroids
        std::vector<double> dis_table;
        dis_table.reserve(size_t(n) * n);
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                dis_table.push_back(fvec_L2sqr(
                        centroids + i * dsub, centroids + j * dsub, dsub));
            }
        }

        std::vector<int> perm(n);
        ReproduceWithHammingObjective obj(nbits, dis_table, dis_weight_factor);

        SimulatedAnnealingOptimizer optim(&obj, *this);

        if (log_pattern.size()) {
            char fname[256];
            snprintf(fname, sizeof(fname), log_pattern.c_str(), m);
            printf("opening log file %s\n", fname);
            optim.logfile = fopen(fname, "w");
            FAISS_THROW_IF_NOT_MSG(optim.logfile, "could not open logfile");
        }
        double final_cost = optim.run_optimization(perm.data());

        if (verbose > 0) {
            printf("SimulatedAnnealingOptimizer for m=%d: %g -> %g\n",
                   m,
                   optim.init_cost,
                   final_cost);
        }

        if (log_pattern.size())
            fclose(optim.logfile);

        // apply the permutation: centroid i moves to slot perm[i]. A copy is
        // needed because the move is done in place.
        const std::vector<float> centroids_copy(
                centroids, centroids + size_t(dsub) * n);

        for (int i = 0; i < n; i++)
            memcpy(centroids + perm[i] * dsub,
                   centroids_copy.data() + i * dsub,
                   dsub * sizeof(centroids[0]));
    }
}

// Reorders the centroids of every sub-quantizer of pq by optimizing a
// ranking objective (RankingScore2) with simulated annealing.
//
// pq: product quantizer, must use 8 bits per sub-quantizer
// n:  number of training vectors. When n > 0 the first n / 4 vectors are
//     used as queries and the remaining ones as database; when n == 0 the
//     centroids themselves play both roles, with distances taken from the
//     SDC table of pq.
// x:  the n training vectors
void PolysemousTraining::optimize_ranking(
        ProductQuantizer& pq,
        size_t n,
        const float* x) const {
    // Checked before any work is done: the codes are read below as one byte
    // per sub-quantizer.
    FAISS_THROW_IF_NOT(pq.nbits == 8);

    int dsub = pq.dsub;
    int nbits = pq.nbits;

    std::vector<uint8_t> all_codes(pq.code_size * n);

    pq.compute_codes(x, all_codes.data(), n);

    if (n == 0) {
        pq.compute_sdc_table();
    }

#pragma omp parallel for
    for (int m = 0; m < pq.M; m++) {
        size_t nq, nb;
        std::vector<uint32_t> codes;     // query codes, then db codes
        std::vector<float> gt_distances; // nq * nb matrix of distances

        if (n > 0) {
            // sub-vectors of the training set for sub-quantizer m; size_t
            // indices so that i * dsub cannot overflow an int
            std::vector<float> xtrain(n * dsub);
            for (size_t i = 0; i < n; i++)
                memcpy(xtrain.data() + i * dsub,
                       x + i * pq.d + m * dsub,
                       sizeof(float) * dsub);

            codes.resize(n);
            for (size_t i = 0; i < n; i++)
                codes[i] = all_codes[i * pq.code_size + m];

            nq = n / 4;
            nb = n - nq;
            const float* xq = xtrain.data();
            const float* xb = xq + nq * dsub;

            gt_distances.resize(nq * nb);

            pairwise_L2sqr(dsub, nq, xq, nb, xb, gt_distances.data());
        } else {
            nq = nb = pq.ksub;
            codes.resize(2 * nq);
            for (size_t i = 0; i < nq; i++)
                codes[i] = codes[i + nq] = i;

            gt_distances.resize(nq * nb);

            memcpy(gt_distances.data(),
                   pq.sdc_table.data() + m * nq * nb,
                   sizeof(float) * nq * nb);
        }

        double t0 = getmillisecs();

        std::unique_ptr<PermutationObjective> obj(new RankingScore2(
                nbits,
                nq,
                nb,
                codes.data(),
                codes.data() + nq,
                gt_distances.data()));

        if (verbose > 0) {
            printf("   m=%d, nq=%zd, nb=%zd, initialize RankingScore "
                   "in %.3f ms\n",
                   m,
                   nq,
                   nb,
                   getmillisecs() - t0);
        }

        SimulatedAnnealingOptimizer optim(obj.get(), *this);

        if (log_pattern.size()) {
            char fname[256];
            snprintf(fname, sizeof(fname), log_pattern.c_str(), m);
            printf("opening log file %s\n", fname);
            optim.logfile = fopen(fname, "w");
            FAISS_THROW_IF_NOT_FMT(
                    optim.logfile, "could not open logfile %s", fname);
        }

        std::vector<int> perm(pq.ksub);

        double final_cost = optim.run_optimization(perm.data());
        printf("SimulatedAnnealingOptimizer for m=%d: %g -> %g\n",
               m,
               optim.init_cost,
               final_cost);

        if (log_pattern.size())
            fclose(optim.logfile);

        float* centroids = pq.get_centroids(m, 0);

        // apply the permutation: centroid i moves to slot perm[i]. A copy is
        // needed because the move is done in place.
        const std::vector<float> centroids_copy(
                centroids, centroids + dsub * pq.ksub);

        for (size_t i = 0; i < pq.ksub; i++)
            memcpy(centroids + perm[i] * dsub,
                   centroids_copy.data() + i * dsub,
                   dsub * sizeof(centroids[0]));
    }
}

// Entry point of the polysemous training: runs the optimization selected by
// optimization_type on pq, then recomputes its SDC table.
//
// n, x: training vectors, only forwarded to optimize_ranking
void PolysemousTraining::optimize_pq_for_hamming(
        ProductQuantizer& pq,
        size_t n,
        const float* x) const {
    switch (optimization_type) {
        case OT_None:
            // centroids are left in their current order
            break;
        case OT_ReproduceDistances_affine:
            optimize_reproduce_distances(pq);
            break;
        default:
            // every other type goes through the ranking optimization
            optimize_ranking(pq, n, x);
            break;
    }

    pq.compute_sdc_table();
}

// Estimate, in bytes, of the memory one optimization thread needs for the
// current optimization_type, with n = pq.ksub:
//  - OT_None: nothing
//  - OT_ReproduceDistances_affine: 3 tables of n * n doubles
//  - OT_Ranking_weighted_diff: one table of n * n * n floats
// Throws (via FAISS_THROW_MSG) for any other optimization type.
size_t PolysemousTraining::memory_usage_per_thread(
        const ProductQuantizer& pq) const {
    size_t n = pq.ksub;

    switch (optimization_type) {
        case OT_None:
            return 0;
        case OT_ReproduceDistances_affine:
            return n * n * sizeof(double) * 3;
        case OT_Ranking_weighted_diff:
            return n * n * n * sizeof(float);
    }

    FAISS_THROW_MSG("Invalid optimization type");
    // not reached, keeps compilers that cannot see the throw quiet
    return 0;
}

} // namespace faiss
