# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import absolute_import, division, print_function

import numpy as np
import unittest
import faiss


def make_binary_dataset(d, nb, nt, nq):
    assert d % 8 == 0
    rs = np.random.RandomState(123)
    x = rs.randint(256, size=(nb + nq + nt, int(d / 8))).astype('uint8')
    return x[:nt], x[nt:-nq], x[-nq:]


def binary_to_float(x):
    n, d = x.shape
    x8 = x.reshape(n * d, -1)
    c8 = 2 * ((x8 >> np.arange(8)) & 1).astype('int8') - 1
    return c8.astype('float32').reshape(n, d * 8)


class TestIndexBinaryFromFloat(unittest.TestCase):
    """Use a binary index backed by a float index"""

    def test_index_from_float(self):
        d = 256
        nt = 0
        nb = 1500
        nq = 500
        (xt, xb, xq) = make_binary_dataset(d, nb, nt, nq)

        index_ref = faiss.IndexFlatL2(d)
        index_ref.add(binary_to_float(xb))

        index = faiss.IndexFlatL2(d)
        index_bin = faiss.IndexBinaryFromFloat(index)
        index_bin.add(xb)

        D_ref, I_ref = index_ref.search(binary_to_float(xq), 10)
        D, I = index_bin.search(xq, 10)

        np.testing.assert_allclose((D_ref / 4.0).astype('int32'), D)

    def test_wrapped_quantizer(self):
        d = 256
        nt = 150
        nb = 1500
        nq = 500
        (xt, xb, xq) = make_binary_dataset(d, nb, nt, nq)

        nlist = 16
        quantizer_ref = faiss.IndexBinaryFlat(d)
        index_ref = faiss.IndexBinaryIVF(quantizer_ref, d, nlist)
        index_ref.train(xt)

        index_ref.add(xb)

        unwrapped_quantizer = faiss.IndexFlatL2(d)
        quantizer = faiss.IndexBinaryFromFloat(unwrapped_quantizer)
        index = faiss.IndexBinaryIVF(quantizer, d, nlist)

        index.train(xt)

        index.add(xb)

        D_ref, I_ref = index_ref.search(xq, 10)
        D, I = index.search(xq, 10)

        np.testing.assert_array_equal(D_ref, D)

    def test_wrapped_quantizer_IMI(self):
        d = 256
        nt = 3500
        nb = 10000
        nq = 500
        (xt, xb, xq) = make_binary_dataset(d, nb, nt, nq)

        index_ref = faiss.IndexBinaryFlat(d)

        index_ref.add(xb)

        nlist_exp = 6
        nlist = 2 ** (2 * nlist_exp)
        float_quantizer = faiss.MultiIndexQuantizer(d, 2, nlist_exp)
        wrapped_quantizer = faiss.IndexBinaryFromFloat(float_quantizer)
        wrapped_quantizer.train(xt)

        assert nlist == float_quantizer.ntotal

        index = faiss.IndexBinaryIVF(wrapped_quantizer, d,
                                     float_quantizer.ntotal)
        index.nprobe = 2048
        assert index.is_trained

        index.add(xb)

        D_ref, I_ref = index_ref.search(xq, 10)
        D, I = index.search(xq, 10)

        recall = sum(gti[0] in Di[:10] for gti, Di in zip(D_ref, D)) \
                 / float(D_ref.shape[0])

        assert recall > 0.82, "recall = %g" % recall

    def test_wrapped_quantizer_HNSW(self):

        def bin2float2d(v):
            n, d = v.shape
            vf = ((v.reshape(-1, 1) >> np.arange(8)) & 1).astype("float32")
            vf *= 2
            vf -= 1
            return vf.reshape(n, d * 8)

        d = 256
        nt = 12800
        nb = 10000
        nq = 500
        (xt, xb, xq) = make_binary_dataset(d, nb, nt, nq)

        index_ref = faiss.IndexBinaryFlat(d)

        index_ref.add(xb)

        nlist = 256
        clus = faiss.Clustering(d, nlist)
        clus_index = faiss.IndexFlatL2(d)

        xt_f = bin2float2d(xt)
        clus.train(xt_f, clus_index)

        centroids = faiss.vector_to_array(clus.centroids).reshape(-1, clus.d)
        hnsw_quantizer = faiss.IndexHNSWFlat(d, 32)
        hnsw_quantizer.add(centroids)
        hnsw_quantizer.is_trained = True
        wrapped_quantizer = faiss.IndexBinaryFromFloat(hnsw_quantizer)

        assert nlist == hnsw_quantizer.ntotal
        assert nlist == wrapped_quantizer.ntotal
        assert wrapped_quantizer.is_trained

        index = faiss.IndexBinaryIVF(wrapped_quantizer, d,
                                     hnsw_quantizer.ntotal)
        index.nprobe = 128

        assert index.is_trained

        index.add(xb)

        D_ref, I_ref = index_ref.search(xq, 10)
        D, I = index.search(xq, 10)

        recall = sum(gti[0] in Di[:10] for gti, Di in zip(D_ref, D)) \
                 / float(D_ref.shape[0])

        assert recall >= 0.77, "recall = %g" % recall


class TestOverrideKmeansQuantizer(unittest.TestCase):

    def test_override(self):
        d = 256
        nt = 3500
        nb = 10000
        nq = 500
        (xt, xb, xq) = make_binary_dataset(d, nb, nt, nq)

        def train_and_get_centroids(override_kmeans_index):
            index = faiss.index_binary_factory(d, "BIVF10")
            index.verbose = True

            if override_kmeans_index is not None:
                index.clustering_index = override_kmeans_index

            index.train(xt)

            centroids = faiss.downcast_IndexBinary(index.quantizer).xb
            return faiss.vector_to_array(centroids).reshape(-1, d // 8)

        centroids_ref = train_and_get_centroids(None)

        # should do the exact same thing
        centroids_new = train_and_get_centroids(faiss.IndexFlatL2(d))

        assert np.all(centroids_ref == centroids_new)

        # will do less accurate assignment... Sanity check that the
        # index is indeed used by kmeans
        centroids_new = train_and_get_centroids(faiss.IndexLSH(d, 16))

        assert not np.all(centroids_ref == centroids_new)
