/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cstdio>
#include <cstdlib>

#include <memory>
#include <random>
#include <vector>

#include <gtest/gtest.h>

#include <faiss/AutoTune.h>
#include <faiss/IVFlib.h>
#include <faiss/IndexIVF.h>
#include <faiss/index_factory.h>

using namespace faiss;

namespace {

// dimension of the vectors to index
int d = 32;

// nb of training vectors
size_t nt = 5000;

// size of the database points per window step
size_t nb = 1000;

// nb of queries
size_t nq = 200;

std::mt19937 rng;

std::vector<float> make_data(size_t n) {
    std::vector<float> database(n * d);
    std::uniform_real_distribution<> distrib;

    for (size_t i = 0; i < n * d; i++) {
        database[i] = distrib(rng);
    }
    return database;
}

std::unique_ptr<Index> make_trained_index(const char* index_type) {
    auto index = std::unique_ptr<Index>(index_factory(d, index_type));
    auto xt = make_data(nt * d);
    index->train(nt, xt.data());
    ParameterSpace().set_index_parameter(index.get(), "nprobe", 4);
    return index;
}

std::vector<idx_t> search_index(Index* index, const float* xq) {
    int k = 10;
    std::vector<idx_t> I(k * nq);
    std::vector<float> D(k * nq);
    index->search(nq, xq, k, D.data(), I.data());
    return I;
}

/*************************************************************
 * Test functions for a given index type
 *************************************************************/

struct EncapsulateInvertedLists : InvertedLists {
    const InvertedLists* il;

    EncapsulateInvertedLists(const InvertedLists* il)
            : InvertedLists(il->nlist, il->code_size), il(il) {}

    static void* memdup(const void* m, size_t size) {
        if (size == 0)
            return nullptr;
        return memcpy(malloc(size), m, size);
    }

    size_t list_size(size_t list_no) const override {
        return il->list_size(list_no);
    }

    const uint8_t* get_codes(size_t list_no) const override {
        return (uint8_t*)memdup(
                il->get_codes(list_no), list_size(list_no) * code_size);
    }

    const idx_t* get_ids(size_t list_no) const override {
        return (idx_t*)memdup(
                il->get_ids(list_no), list_size(list_no) * sizeof(idx_t));
    }

    void release_codes(size_t, const uint8_t* codes) const override {
        free((void*)codes);
    }

    void release_ids(size_t, const idx_t* ids) const override {
        free((void*)ids);
    }

    const uint8_t* get_single_code(size_t list_no, size_t offset)
            const override {
        return (uint8_t*)memdup(
                il->get_single_code(list_no, offset), code_size);
    }

    size_t add_entries(size_t, size_t, const idx_t*, const uint8_t*) override {
        assert(!"not implemented");
        return 0;
    }

    void update_entries(size_t, size_t, size_t, const idx_t*, const uint8_t*)
            override {
        assert(!"not implemented");
    }

    void resize(size_t, size_t) override {
        assert(!"not implemented");
    }

    ~EncapsulateInvertedLists() override {}
};

int test_dealloc_invlists(const char* index_key) {
    std::unique_ptr<Index> index = make_trained_index(index_key);
    IndexIVF* index_ivf = ivflib::extract_index_ivf(index.get());

    auto xb = make_data(nb * d);
    index->add(nb, xb.data());

    auto xq = make_data(nq * d);

    auto ref_res = search_index(index.get(), xq.data());

    EncapsulateInvertedLists eil(index_ivf->invlists);

    index_ivf->own_invlists = false;
    index_ivf->replace_invlists(&eil, false);

    // TEST: this could crash or leak mem
    auto new_res = search_index(index.get(), xq.data());

    // delete explicitly
    delete eil.il;

    // just to make sure
    EXPECT_EQ(ref_res, new_res);
    return 0;
}

} // anonymous namespace

/*************************************************************
 * Test entry points
 *************************************************************/

TEST(TestIvlistDealloc, IVFFlat) {
    test_dealloc_invlists("IVF32,Flat");
}

TEST(TestIvlistDealloc, IVFSQ) {
    test_dealloc_invlists("IVF32,SQ8");
}

TEST(TestIvlistDealloc, IVFPQ) {
    test_dealloc_invlists("IVF32,PQ4np");
}
