# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from multiprocessing.pool import ThreadPool
import faiss
from typing import List, Tuple

from . import rpc

############################################################
# Server implementation
############################################################


class SearchServer(rpc.Server):
    """Wrap a faiss index so that it can be exposed via RPC.

    A few operations are implemented explicitly; every other attribute
    lookup is forwarded to the wrapped index (see ``__getattr__``).
    """

    def __init__(self, s: int, index: faiss.Index):
        """Set up the server around ``index``.

        Args:
            s: passed unchanged to ``rpc.Server.__init__`` (presumably the
                connection handed over by the RPC layer -- confirm in rpc).
            index: the index served by this object. It must contain an IVF
                index that ``faiss.extract_index_ivf`` can locate.
        """
        rpc.Server.__init__(self, s)
        self.index = index
        # keep a handle on the IVF part so that nprobe can be changed later
        self.index_ivf = faiss.extract_index_ivf(index)

    def set_nprobe(self, nprobe: int) -> None:
        """Set the nprobe field of the underlying IVF index."""
        self.index_ivf.nprobe = nprobe

    def set_omp_num_threads(self, nt: int) -> None:
        """Set the number of OpenMP threads faiss uses in this process.

        NOTE(review): clients call ``set_omp_num_threads`` on their
        connections; assuming RPC calls are dispatched by method name, this
        used to fall through to the index object, which is not expected to
        provide such a method -- confirm against the rpc module.
        """
        faiss.omp_set_num_threads(nt)

    def get_ntotal(self) -> int:
        """Return the ``ntotal`` field of the wrapped index."""
        return self.index.ntotal

    def __getattr__(self, f):
        # Only called when normal attribute lookup fails. If ``index`` itself
        # is missing (lookup before __init__ assigned it, or during
        # copy/unpickling), reading ``self.index`` below would re-enter this
        # method forever, so fail with the regular AttributeError instead.
        if f == "index":
            raise AttributeError(f)
        # all other functions get forwarded to the index
        return getattr(self.index, f)


def run_index_server(index: faiss.Index, port: int, v6: bool = False):
    """Serve requests for ``index`` on ``port``.

    The call is delegated to ``rpc.run_server``, which is expected to keep
    serving forever (confirm in the rpc module).

    Args:
        index: index shared by every ``SearchServer`` created here.
        port: port number handed to ``rpc.run_server``.
        v6: forwarded to ``rpc.run_server`` as the ``v6`` keyword.
    """

    def make_server(s):
        # factory invoked by rpc.run_server; all servers wrap the same index
        return SearchServer(s, index)

    rpc.run_server(make_server, port, v6=v6)


############################################################
# Client implementation
############################################################

class ClientIndex:
    """Manage a set of remote sub-indexes reached over RPC.

    The sub-indexes are expected to each search a subset of the inverted
    lists. Calls are sent to all sub-indexes through a thread pool, and
    search results are merged afterwards.
    """

    def __init__(self, machine_ports: List[Tuple[str, int]], v6: bool = False):
        """Connect to a series of (host, port) pairs.

        Args:
            machine_ports: one (host, port) pair per server to connect to.
            v6: passed unchanged to ``rpc.Client``.

        Raises:
            ValueError: if ``machine_ports`` is empty.
        """
        self.sub_indexes = [
            rpc.Client(machine, port, v6) for machine, port in machine_ports
        ]
        self.ni = len(self.sub_indexes)
        if self.ni == 0:
            # ThreadPool(0) would raise a ValueError about "processes";
            # report the actual problem instead.
            raise ValueError("machine_ports must contain at least one server")
        # pool of threads. Each thread manages one sub-index.
        self.pool = ThreadPool(self.ni)
        # test connection...
        self.ntotal = self.get_ntotal()
        self.verbose = False

    def _call_all(self, method: str, *args) -> list:
        """Call ``method(*args)`` on every sub-index through the pool.

        Returns the results as a list, in the order of ``self.sub_indexes``.
        """
        return self.pool.map(
            lambda idx: getattr(idx, method)(*args),
            self.sub_indexes
        )

    def set_nprobe(self, nprobe: int) -> None:
        """Call ``set_nprobe(nprobe)`` on every sub-index."""
        self._call_all("set_nprobe", nprobe)

    def set_omp_num_threads(self, nt: int) -> None:
        """Call ``set_omp_num_threads(nt)`` on every sub-index."""
        self._call_all("set_omp_num_threads", nt)

    def get_ntotal(self) -> int:
        """Return the sum of ``get_ntotal()`` over all sub-indexes."""
        return sum(self._call_all("get_ntotal"))

    def search(self, x, k: int):
        """Search all sub-indexes for ``x`` and merge their results.

        Args:
            x: query vectors; ``x.shape[0]`` is used as the number of queries.
            k: number of results to keep per query.

        Returns:
            The ``(D, I)`` pair held by the merged result heap.
        """
        # NOTE(review): ResultHeap is built with its default ordering, which
        # presumably keeps the smallest distances -- confirm this matches the
        # metric of the served indexes.
        rh = faiss.ResultHeap(x.shape[0], k)

        # imap yields the per-sub-index results in sub-index order while the
        # remaining searches are still running on the pool
        for Di, Ii in self.pool.imap(
                lambda idx: idx.search(x, k), self.sub_indexes):
            rh.add_result(Di, Ii)
        rh.finalize()
        return rh.D, rh.I
