# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import unittest
import faiss
import tempfile
import os
import io
import sys
import pickle
import platform
from multiprocessing.pool import ThreadPool
from common_faiss_tests import get_dataset_2


# Sizes handed to get_dataset_2(d, nt, nb, nq), which yields (xt, xb, xq):
# vector dimension, then the number of training, database and query vectors.
d, nt, nb, nq = 32, 2000, 1000, 200

class TestIOVariants(unittest.TestCase):
    """Error reporting of the file-based index reader."""

    def test_io_error(self):
        """Reading a truncated index file must raise a RuntimeError whose
        message mentions the offending file name."""
        d, n = 32, 1000
        x = np.random.uniform(size=(n, d)).astype('float32')
        index = faiss.IndexFlatL2(d)
        index.add(x)
        fd, fname = tempfile.mkstemp()
        os.close(fd)
        try:
            faiss.write_index(index, fname)

            # the intact file should load fine
            faiss.read_index(fname)

            with open(fname, 'rb') as f:
                data = f.read()
            # now damage the file: keep only its first half
            with open(fname, 'wb') as f:
                f.write(data[:len(data) // 2])

            # should make a nice readable exception that mentions the
            # filename. assertRaises reports a clear failure when nothing is
            # raised (a bare `raise` with no active exception would only
            # produce a confusing "No active exception to reraise" error).
            with self.assertRaises(RuntimeError) as cm:
                faiss.read_index(fname)
            self.assertIn(fname, str(cm.exception))

        finally:
            if os.path.exists(fname):
                os.unlink(fname)


class TestCallbacks(unittest.TestCase):
    """Round-trip indexes through Python-callback IOReader/IOWriter objects,
    optionally wrapped in their buffered variants, and through a pipe."""

    def do_write_callback(self, bsz):
        """Serialize a flat index through a PyCallbackIOWriter into a BytesIO.

        When `bsz` > 0 the callback writer is wrapped in a BufferedIOWriter
        with block size `bsz`. The written bytes are deserialized and compared
        with the original index. Finally, a writer built on a non-callable
        object must make write_index raise.
        """
        d, n = 32, 1000
        x = np.random.uniform(size=(n, d)).astype('float32')
        index = faiss.IndexFlatL2(d)
        index.add(x)

        f = io.BytesIO()
        # test with small block size
        writer = faiss.PyCallbackIOWriter(f.write, 1234)

        if bsz > 0:
            writer = faiss.BufferedIOWriter(writer, bsz)

        faiss.write_index(index, writer)
        del writer   # make sure all writes committed

        # The module is Python 3 only (it relies on the ThreadPool context
        # manager, for example), so the former Python 2 branch was dead code.
        buf = f.getbuffer()

        index2 = faiss.deserialize_index(np.frombuffer(buf, dtype='uint8'))

        self.assertEqual(index.d, index2.d)
        np.testing.assert_array_equal(
            faiss.vector_to_array(index.codes),
            faiss.vector_to_array(index2.codes)
        )

        # This is not a callable function: should raise an exception
        writer = faiss.PyCallbackIOWriter("blabla")
        self.assertRaises(
            Exception,
            faiss.write_index, index, writer
        )

    def test_buf_read(self):
        """Read raw bytes from a file through a BufferedIOReader wrapping a
        PyCallbackIOReader and check they match what was written."""
        x = np.random.uniform(size=20)

        fd, fname = tempfile.mkstemp()
        os.close(fd)
        try:
            x.tofile(fname)

            with open(fname, 'rb') as f:
                reader = faiss.PyCallbackIOReader(f.read, 1234)

                bsz = 123
                reader = faiss.BufferedIOReader(reader, bsz)

                y = np.zeros_like(x)
                # read y.nbytes bytes straight into y's buffer
                reader(faiss.swig_ptr(y), y.nbytes, 1)

            np.testing.assert_array_equal(x, y)
        finally:
            if os.path.exists(fname):
                os.unlink(fname)

    def do_read_callback(self, bsz):
        """Read an index file through a PyCallbackIOReader over `f.read`.

        When `bsz` > 0 the reader is wrapped in a BufferedIOReader with block
        size `bsz`. Also checks that a reader built on a non-callable object
        makes read_index raise.
        """
        d, n = 32, 1000
        x = np.random.uniform(size=(n, d)).astype('float32')
        index = faiss.IndexFlatL2(d)
        index.add(x)

        fd, fname = tempfile.mkstemp()
        os.close(fd)
        try:
            faiss.write_index(index, fname)

            with open(fname, 'rb') as f:
                reader = faiss.PyCallbackIOReader(f.read, 1234)

                if bsz > 0:
                    reader = faiss.BufferedIOReader(reader, bsz)

                index2 = faiss.read_index(reader)

            self.assertEqual(index.d, index2.d)
            np.testing.assert_array_equal(
                faiss.vector_to_array(index.codes),
                faiss.vector_to_array(index2.codes)
            )

            # This is not a callable function: should raise an exception
            reader = faiss.PyCallbackIOReader("blabla")
            self.assertRaises(
                Exception,
                faiss.read_index, reader
            )
        finally:
            if os.path.exists(fname):
                os.unlink(fname)

    def test_write_callback(self):
        self.do_write_callback(0)

    def test_write_buffer(self):
        self.do_write_callback(123)
        self.do_write_callback(2345)

    def test_read_callback(self):
        self.do_read_callback(0)

    def test_read_callback_buffered(self):
        self.do_read_callback(123)
        self.do_read_callback(12345)

    def test_read_buffer(self):
        """Read an index through a BufferedIOReader over a FileIOReader."""
        d, n = 32, 1000
        x = np.random.uniform(size=(n, d)).astype('float32')
        index = faiss.IndexFlatL2(d)
        index.add(x)

        fd, fname = tempfile.mkstemp()
        os.close(fd)
        # Bound before `try` so the `del` in `finally` cannot raise
        # UnboundLocalError (and hide the real failure) when write_index or
        # the reader construction fails.
        reader = None
        try:
            faiss.write_index(index, fname)

            reader = faiss.BufferedIOReader(
                faiss.FileIOReader(fname), 1234)

            index2 = faiss.read_index(reader)

            self.assertEqual(index.d, index2.d)
            np.testing.assert_array_equal(
                faiss.vector_to_array(index.codes),
                faiss.vector_to_array(index2.codes)
            )

        finally:
            # drop the reader before unlinking the file it reads from
            del reader
            if os.path.exists(fname):
                os.unlink(fname)

    def test_transfer_pipe(self):
        """ transfer an index through a Unix pipe """

        d, n = 32, 1000
        x = np.random.uniform(size=(n, d)).astype('float32')
        index = faiss.IndexFlatL2(d)
        index.add(x)
        Dref, Iref = index.search(x, 10)

        rf, wf = os.pipe()

        # runs in a worker thread: deserializes the index from the read end
        def index_from_pipe():
            reader = faiss.PyCallbackIOReader(lambda size: os.read(rf, size))
            return faiss.read_index(reader)

        try:
            with ThreadPool(1) as pool:
                fut = pool.apply_async(index_from_pipe, ())

                try:
                    # write to pipe
                    writer = faiss.PyCallbackIOWriter(
                        lambda b: os.write(wf, b))
                    faiss.write_index(index, writer)
                finally:
                    # Always close the write end. If write_index fails, the
                    # reader thread then sees EOF instead of blocking forever
                    # in os.read, which would make the pool's exit hang.
                    # Data already written stays readable after the close.
                    os.close(wf)

                index2 = fut.get()
        finally:
            os.close(rf)

        Dnew, Inew = index2.search(x, 10)

        np.testing.assert_array_equal(Iref, Inew)
        np.testing.assert_array_equal(Dref, Dnew)


class PyOndiskInvertedLists:
    """Wraps an OnDisk object for use from C++.

    Serves per-list sizes, codes and ids. Codes and ids are read straight
    from the file named by `oil.filename`, using the offset, size and
    capacity recorded for each list in `oil.lists`.
    """

    def __init__(self, oil):
        self.oil = oil

    def _read_span(self, offset, nbytes):
        # Open the backing file afresh for every query and read one span.
        with open(self.oil.filename, 'rb') as handle:
            handle.seek(offset)
            return handle.read(nbytes)

    def list_size(self, list_no):
        """Number of entries in list `list_no` (delegated to the wrapped object)."""
        return self.oil.list_size(list_no)

    def get_codes(self, list_no):
        """Raw code bytes of list `list_no`: `size * code_size` bytes starting
        at the list's offset."""
        src = self.oil
        assert 0 <= list_no < src.lists.size()
        entry = src.lists.at(list_no)
        return self._read_span(entry.offset, entry.size * src.code_size)

    def get_ids(self, list_no):
        """Raw id bytes of list `list_no`, read as 8 bytes per entry."""
        src = self.oil
        assert 0 <= list_no < src.lists.size()
        entry = src.lists.at(list_no)
        # The ids region starts after the codes region, whose length is set
        # by the list's capacity rather than its current size.
        ids_start = entry.offset + entry.capacity * src.code_size
        return self._read_span(ids_start, entry.size * 8)


class TestPickle(unittest.TestCase):
    """Indexes must survive a pickle round-trip with identical search results."""

    def dump_load_factory(self, fs):
        """Build the factory index `fs` on small random data, pickle it
        through an in-memory stream, unpickle it and compare searches."""
        queries = faiss.randn((25, 10), 123)
        database = faiss.randn((25, 10), 124)

        original = faiss.index_factory(10, fs)
        original.train(database)
        original.add(database)
        D_before, I_before = original.search(queries, 4)

        stream = io.BytesIO()
        pickle.dump(original, stream)
        stream.seek(0)
        restored = pickle.load(stream)

        D_after, I_after = restored.search(queries, 4)

        np.testing.assert_array_equal(I_before, I_after)
        np.testing.assert_array_equal(D_before, D_after)

    def test_flat(self):
        self.dump_load_factory("Flat")

    def test_hnsw(self):
        self.dump_load_factory("HNSW32")

    def test_ivf(self):
        self.dump_load_factory("IVF5,Flat")


class Test_IO_VectorTransform(unittest.TestCase):
    """
    test write_VectorTransform / read_VectorTransform, mixing IOWriter /
    IOReader objects with plain file names
    """

    def test_write_vector_transform(self):
        """
        test write_VectorTransform using IOWriter Pointer
        and read_VectorTransform using file name
        """
        d, n = 32, 1000
        x = np.random.uniform(size=(n, d)).astype('float32')
        quantizer = faiss.IndexFlatL2(d)
        index = faiss.IndexIVFSpectralHash(quantizer, d, n, 8, 1.0)
        index.train(x)
        index.add(x)
        fd, fname = tempfile.mkstemp()
        os.close(fd)
        try:

            writer = faiss.FileIOWriter(fname)
            faiss.write_VectorTransform(index.vt, writer)
            # release the writer before reading the same file by name
            del writer

            vt = faiss.read_VectorTransform(fname)

            self.assertEqual(vt.d_in, index.vt.d_in)
            self.assertEqual(vt.d_out, index.vt.d_out)
            self.assertTrue(vt.is_trained)

        finally:
            if os.path.exists(fname):
                os.unlink(fname)

    def test_read_vector_transform(self):
        """
        test write_VectorTransform using file name
        and read_VectorTransform using IOReader Pointer
        """
        d, n = 32, 1000
        x = np.random.uniform(size=(n, d)).astype('float32')
        quantizer = faiss.IndexFlatL2(d)
        index = faiss.IndexIVFSpectralHash(quantizer, d, n, 8, 1.0)
        index.train(x)
        index.add(x)
        fd, fname = tempfile.mkstemp()
        os.close(fd)
        try:

            faiss.write_VectorTransform(index.vt, fname)

            reader = faiss.FileIOReader(fname)
            vt = faiss.read_VectorTransform(reader)
            # release the reader before the file is unlinked in `finally`
            del reader

            self.assertEqual(vt.d_in, index.vt.d_in)
            self.assertEqual(vt.d_out, index.vt.d_out)
            self.assertTrue(vt.is_trained)
        finally:
            if os.path.exists(fname):
                os.unlink(fname)


class Test_IO_PQ(unittest.TestCase):
    """
    test read and write PQ.
    """
    def test_io_pq(self):
        """Write a trained ProductQuantizer to a file, read it back and
        compare its parameters and centroids with the original."""
        xt = get_dataset_2(d, nt, nb, nq)[0]
        pq_index = faiss.IndexPQ(d, 4, 4)
        pq_index.train(xt)
        pq = pq_index.pq

        fd, fname = tempfile.mkstemp()
        os.close(fd)

        try:
            faiss.write_ProductQuantizer(pq, fname)

            loaded = faiss.read_ProductQuantizer(fname)

            for attr in ("M", "nbits", "dsub", "ksub"):
                self.assertEqual(getattr(pq, attr), getattr(loaded, attr))
            np.testing.assert_array_equal(
                faiss.vector_to_array(pq.centroids),
                faiss.vector_to_array(loaded.centroids)
            )

        finally:
            if os.path.exists(fname):
                os.unlink(fname)


class Test_IO_IndexLSH(unittest.TestCase):
    """
    test read and write IndexLSH.
    """
    def test_io_lsh(self):
        """Write an IndexLSH to disk, read it back through a BufferedIOReader
        and check that codes and search results are unchanged."""
        xt, xb, xq = get_dataset_2(d, nt, nb, nq)
        lsh = faiss.IndexLSH(d, 32, True, True)
        lsh.train(xt)
        lsh.add(xb)
        D_ref, I_ref = lsh.search(xq, 10)

        fd, fname = tempfile.mkstemp()
        os.close(fd)

        try:
            faiss.write_index(lsh, fname)

            buffered = faiss.BufferedIOReader(faiss.FileIOReader(fname), 1234)
            lsh_back = faiss.read_index(buffered)
            # Delete reader to prevent [WinError 32] The process cannot
            # access the file because it is being used by another process
            del buffered

            self.assertEqual(lsh.d, lsh_back.d)
            np.testing.assert_array_equal(
                faiss.vector_to_array(lsh.codes),
                faiss.vector_to_array(lsh_back.codes)
            )

            D_new, I_new = lsh_back.search(xq, 10)
            np.testing.assert_array_equal(D_ref, D_new)
            np.testing.assert_array_equal(I_ref, I_new)

        finally:
            if os.path.exists(fname):
                os.unlink(fname)


class Test_IO_IndexIVFSpectralHash(unittest.TestCase):
    """
    test read and write IndexIVFSpectralHash.
    """
    def test_io_ivf_spectral_hash(self):
        """Write an IndexIVFSpectralHash to disk, read it back through a
        BufferedIOReader and compare its parameters and search results."""
        nlist = 1000
        xt, xb, xq = get_dataset_2(d, nt, nb, nq)
        coarse = faiss.IndexFlatL2(d)
        index = faiss.IndexIVFSpectralHash(coarse, d, nlist, 8, 1.0)
        index.train(xt)
        index.add(xb)
        D_ref, I_ref = index.search(xq, 10)

        fd, fname = tempfile.mkstemp()
        os.close(fd)

        try:
            faiss.write_index(index, fname)

            buffered = faiss.BufferedIOReader(faiss.FileIOReader(fname), 1234)
            loaded = faiss.read_index(buffered)
            # release the reader before the file is unlinked in `finally`
            del buffered

            for field in ("d", "nbit", "period", "threshold_type"):
                self.assertEqual(getattr(index, field), getattr(loaded, field))

            D_new, I_new = loaded.search(xq, 10)
            np.testing.assert_array_equal(D_ref, D_new)
            np.testing.assert_array_equal(I_ref, I_new)

        finally:
            if os.path.exists(fname):
                os.unlink(fname)

class TestIVFPQRead(unittest.TestCase):
    """Reading an IVFPQ index with IO_FLAG_SKIP_PRECOMPUTE_TABLE."""

    def test_reader(self):
        """An IVFPQ index read with IO_FLAG_SKIP_PRECOMPUTE_TABLE must return
        the same search results and sa_encode codes as one read normally."""
        d, n = 32, 1000
        xq = np.random.uniform(size=(n, d)).astype('float32')
        xb = np.random.uniform(size=(n, d)).astype('float32')

        # Use `d` rather than repeating the literal 32, so the index
        # dimension cannot drift from the data dimension.
        index = faiss.index_factory(d, "IVF32,PQ16np", faiss.METRIC_L2)
        index.train(xb)
        index.add(xb)
        fd, fname = tempfile.mkstemp()
        os.close(fd)

        try:
            faiss.write_index(index, fname)

            index_a = faiss.read_index(fname)
            index_b = faiss.read_index(
                fname, faiss.IO_FLAG_SKIP_PRECOMPUTE_TABLE)

            Da, Ia = index_a.search(xq, 10)
            Db, Ib = index_b.search(xq, 10)
            np.testing.assert_array_equal(Ia, Ib)
            np.testing.assert_almost_equal(Da, Db, decimal=5)

            codes_a = index_a.sa_encode(xq)
            codes_b = index_b.sa_encode(xq)
            np.testing.assert_array_equal(codes_a, codes_b)

        finally:
            if os.path.exists(fname):
                os.unlink(fname)



class TestIOFlatMMap(unittest.TestCase):
    """Loading a serialized SQfp16 index from a memory-mapped file and from
    a zero-copy in-memory buffer."""

    @unittest.skipIf(
        platform.system() not in ["Windows", "Linux"],
        "supported OSes only"
    )
    def test_mmap(self):
        """Read an index with IO_FLAG_MMAP_IFC and check that its search
        results match the in-memory original."""
        xt, xb, xq = get_dataset_2(32, 0, 100, 50)
        index = faiss.index_factory(32, "SQfp16", faiss.METRIC_L2)
        # does not need training
        index.add(xb)
        Dref, Iref = index.search(xq, 10)

        fd, fname = tempfile.mkstemp()
        os.close(fd)

        index2 = None
        try:
            faiss.write_index(index, fname)
            index2 = faiss.read_index(fname, faiss.IO_FLAG_MMAP_IFC)
            Dnew, Inew = index2.search(xq, 10)
            np.testing.assert_array_equal(Iref, Inew)
            np.testing.assert_array_equal(Dref, Dnew)
        finally:
            del index2

            if os.path.exists(fname):
                # Best-effort cleanup. On Windows, index2 may still hold the
                # file handle until it is garbage-collected, so the unlink
                # can fail. Only OS-level errors are ignored: a bare `except:`
                # would also swallow KeyboardInterrupt and SystemExit.
                try:
                    os.unlink(fname)
                except OSError:
                    pass

    def test_zerocopy(self):
        """Read an index from its serialized bytes via ZeroCopyIOReader and
        check that its search results match the original."""
        xt, xb, xq = get_dataset_2(32, 0, 100, 50)
        index = faiss.index_factory(32, "SQfp16", faiss.METRIC_L2)
        # does not need training
        index.add(xb)
        Dref, Iref = index.search(xq, 10)

        # `serialized_index` stays referenced for the rest of the test, so
        # the buffer the reader points into remains alive.
        serialized_index = faiss.serialize_index(index)
        reader = faiss.ZeroCopyIOReader(
            faiss.swig_ptr(serialized_index), serialized_index.size)
        index2 = faiss.read_index(reader)
        Dnew, Inew = index2.search(xq, 10)
        np.testing.assert_array_equal(Iref, Inew)
        np.testing.assert_array_equal(Dref, Dnew)
