/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cstdio>
#include <cstdlib>
#include <random>

#include <gtest/gtest.h>

#include <faiss/IndexFlat.h>
#include <faiss/IndexIVFPQ.h>

TEST(IVFPQ, accuracy) {
    // dimension of the vectors to index
    int d = 64;

    // size of the database we plan to index
    size_t nb = 1000;

    // make a set of nt training vectors in the unit cube
    // (could be the database)
    size_t nt = 1500;

    // make the index object and train it
    faiss::IndexFlatL2 coarse_quantizer(d);

    // a reasonable number of cetroids to index nb vectors
    int ncentroids = 25;

    faiss::IndexIVFPQ index(&coarse_quantizer, d, ncentroids, 16, 8);

    // index that gives the ground-truth
    faiss::IndexFlatL2 index_gt(d);

    std::mt19937 rng;
    std::uniform_real_distribution<> distrib;

    { // training

        std::vector<float> trainvecs(nt * d);
        for (size_t i = 0; i < nt * d; i++) {
            trainvecs[i] = distrib(rng);
        }
        index.verbose = true;
        index.train(nt, trainvecs.data());
    }

    { // populating the database

        std::vector<float> database(nb * d);
        for (size_t i = 0; i < nb * d; i++) {
            database[i] = distrib(rng);
        }

        index.add(nb, database.data());
        index_gt.add(nb, database.data());
    }

    int nq = 200;
    int n_ok;

    { // searching the database

        std::vector<float> queries(nq * d);
        for (size_t i = 0; i < nq * d; i++) {
            queries[i] = distrib(rng);
        }

        std::vector<faiss::idx_t> gt_nns(nq);
        std::vector<float> gt_dis(nq);

        index_gt.search(nq, queries.data(), 1, gt_dis.data(), gt_nns.data());

        index.nprobe = 5;
        int k = 5;
        std::vector<faiss::idx_t> nns(k * nq);
        std::vector<float> dis(k * nq);

        index.search(nq, queries.data(), k, dis.data(), nns.data());

        n_ok = 0;
        for (int q = 0; q < nq; q++) {
            for (int i = 0; i < k; i++)
                if (nns[q * k + i] == gt_nns[q])
                    n_ok++;
        }
        EXPECT_GT(n_ok, nq * 0.4);
    }
}
