# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import time

from faiss.contrib import rpc

import combined_index
import argparse



############################################################
# Server implementation
############################################################


class MyServer(rpc.Server):
    """ Assign version that can be exposed via RPC """
    def __init__(self, s, index):
        rpc.Server.__init__(self, s)
        self.index = index

    def __getattr__(self, f):
        return getattr(self.index, f)

def main():
    parser = argparse.ArgumentParser()

    def aa(*args, **kwargs):
        group.add_argument(*args, **kwargs)

    group = parser.add_argument_group('server options')
    aa('--port', default=12012, type=int, help='server port')
    aa('--when_ready_dir', default=None,
       help='store host:port to this file when ready')
    aa('--ipv4', default=False, action='store_true', help='force ipv4')
    aa('--rank', default=0, type=int,
       help='rank used as index in the client table')

    args = parser.parse_args()

    when_ready = None
    if args.when_ready_dir:
        when_ready = '%s/%d' % (args.when_ready_dir, args.rank)

    print('loading index')

    index = combined_index.CombinedIndexDeep1B()

    print('starting server')
    rpc.run_server(
        lambda s: MyServer(s, index),
        args.port, report_to_file=when_ready,
        v6=not args.ipv4)

if __name__ == '__main__':
    main()


############################################################
# Client implementation
############################################################

from multiprocessing.pool import ThreadPool
import faiss
import numpy as np



class ResultHeap:
    """ Combine query results from a sliced dataset (for k-nn search) """

    def __init__(self, nq, k):
        " nq: number of query vectors, k: number of results per query "
        self.I = np.zeros((nq, k), dtype='int64')
        self.D = np.zeros((nq, k), dtype='float32')
        self.nq, self.k = nq, k
        heaps = faiss.float_maxheap_array_t()
        heaps.k = k
        heaps.nh = nq
        heaps.val = faiss.swig_ptr(self.D)
        heaps.ids = faiss.swig_ptr(self.I)
        heaps.heapify()
        self.heaps = heaps

    def add_batch_result(self, D, I, i0):
        assert D.shape == (self.nq, self.k)
        assert I.shape == (self.nq, self.k)
        I += i0
        self.heaps.addn_with_ids(
            self.k, faiss.swig_ptr(D),
            faiss.swig_ptr(I), self.k)

    def finalize(self):
        self.heaps.reorder()

def distribute_weights(weights, nbin):
    """ assign a set of weights to a smaller set of bins to balance them """
    nw = weights.size
    o = weights.argsort()
    bins = np.zeros(nbin)
    assign = np.ones(nw, dtype=int)
    for i in o[::-1]:
        b = bins.argmin()
        assign[i] = b
        bins[b] += weights[i]
    return bins, assign



class SplitPerListIndex:
    """manages a local index, that does the coarse quantization and a set
    of sub_indexes. The sub_indexes search a subset of the inverted
    lists. The SplitPerListIndex merges results from the sub-indexes"""

    def __init__(self, index, sub_indexes):
        self.index = index
        self.code_size = faiss.extract_index_ivf(index.index).code_size
        self.sub_indexes = sub_indexes
        self.ni = len(self.sub_indexes)
        # pool of threads. Each thread manages one sub-index.
        self.pool = ThreadPool(self.ni)
        self.verbose = False

    def set_nprobe(self, nprobe):
        self.index.set_nprobe(nprobe)
        self.pool.map(
            lambda i: self.sub_indexes[i].set_nprobe(nprobe),
            range(self.ni)
        )

    def set_omp_num_threads(self, nt):
        faiss.omp_set_num_threads(nt)
        self.pool.map(
            lambda idx: idx.set_omp_num_threads(nt),
            self.sub_indexes
        )

    def set_parallel_mode(self, pm):
        self.index.set_parallel_mode(pm)
        self.pool.map(
            lambda idx: idx.set_parallel_mode(pm),
            self.sub_indexes
        )

    def set_prefetch_nthread(self, nt):
        self.index.set_prefetch_nthread(nt)
        self.pool.map(
            lambda idx: idx.set_prefetch_nthread(nt),
            self.sub_indexes
        )

    def balance_lists(self, list_nos):
        big_il = self.index.big_il
        weights = np.array([big_il.list_size(int(i))
                            for i in list_nos.ravel()])
        bins, assign = distribute_weights(weights, self.ni)
        if self.verbose:
            print('bins weight range %d:%d total %d (%.2f MiB)' % (
                bins.min(), bins.max(), bins.sum(),
                bins.sum() * (self.code_size + 8) / 2 ** 20))
        self.nscan = bins.sum()
        return assign.reshape(list_nos.shape)

    def search(self, x, k):
        xqo, list_nos, coarse_dis = self.index.transform_and_assign(x)
        assign = self.balance_lists(list_nos)

        def do_query(i):
            sub_index = self.sub_indexes[i]
            list_nos_i = list_nos.copy()
            list_nos_i[assign != i] = -1
            t0 = time.time()
            Di, Ii = sub_index.ivf_search_preassigned(
                xqo, list_nos_i, coarse_dis, k)
            #print(list_nos_i, Ii)
            if self.verbose:
                print('client %d: %.3f s' % (i, time.time() - t0))
            return Di, Ii

        rh = ResultHeap(x.shape[0], k)

        for Di, Ii in self.pool.imap(do_query, range(self.ni)):
            #print("ADD", Ii, rh.I)
            rh.add_batch_result(Di, Ii, 0)
        rh.finalize()
        return rh.D, rh.I

    def range_search(self, x, radius):
        xqo, list_nos, coarse_dis = self.index.transform_and_assign(x)
        assign = self.balance_lists(list_nos)
        nq = len(x)

        def do_query(i):
            sub_index = self.sub_indexes[i]
            list_nos_i = list_nos.copy()
            list_nos_i[assign != i] = -1
            t0 = time.time()
            limi, Di, Ii = sub_index.ivf_range_search_preassigned(
                xqo, list_nos_i, coarse_dis, radius)
            if self.verbose:
                print('slice %d: %.3f s' % (i, time.time() - t0))
            return limi, Di, Ii

        D = [[] for i in range(nq)]
        I = [[] for i in range(nq)]

        sizes = np.zeros(nq, dtype=int)
        for lims, Di, Ii in self.pool.imap(do_query, range(self.ni)):
            for i in range(nq):
                l0, l1 = lims[i:i + 2]
                D[i].append(Di[l0:l1])
                I[i].append(Ii[l0:l1])
                sizes[i] += l1 - l0
        lims = np.zeros(nq + 1, dtype=int)
        lims[1:] = np.cumsum(sizes)
        D = np.hstack([j for i in D for j in i])
        I = np.hstack([j for i in I for j in i])
        return lims, D, I
