#! /usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import print_function
import numpy as np
import faiss
import time
import os
import argparse


parser = argparse.ArgumentParser()

def aa(*args, **kwargs):
    group.add_argument(*args, **kwargs)

group = parser.add_argument_group('dataset options')
aa('--dim', type=int, default=64)
aa('--nb', type=int, default=int(1e6))
aa('--subset_len', type=int, default=int(1e5))
aa('--key', default='IVF1000,Flat')
aa('--nprobe', type=int, default=640)
aa('--no_intcallback', default=False, action='store_true')
aa('--twostage', default=False, action='store_true')
aa('--nt', type=int, default=-1)


args = parser.parse_args()
print("args:", args)


d = args.dim  # dimension
nb = args.nb  # database size
nq = 1000  # nb of queries
nt = 100000
subset_len = args.subset_len


np.random.seed(1234)  # make reproducible
xb = np.random.random((nb, d)).astype('float32')
xq = np.random.random((nq, d)).astype('float32')
xt = np.random.random((nt, d)).astype('float32')
k = 100

if args.no_intcallback:
    faiss.InterruptCallback.clear_instance()

if args.nt != -1:
    faiss.omp_set_num_threads(args.nt)

nprobe = args.nprobe
key = args.key
#key = 'IVF1000,Flat'
# key = 'IVF1000,PQ64'
# key = 'IVF100_HNSW32,PQ64'

# faiss.omp_set_num_threads(1)

pf = 'dim%d_' % d
if d == 64:
    pf = ''

basename = '/tmp/base%s%s.index' % (pf, key)

if os.path.exists(basename):
    print('load', basename)
    index_1 = faiss.read_index(basename)
else:
    print('train + write', basename)
    index_1 = faiss.index_factory(d, key)
    index_1.train(xt)
    faiss.write_index(index_1, basename)

print('add')
index_1.add(xb)

print('set nprobe=', nprobe)
faiss.ParameterSpace().set_index_parameter(index_1, 'nprobe', nprobe)

class ResultHeap:
    """ Combine query results from a sliced dataset """

    def __init__(self, nq, k):
        " nq: number of query vectors, k: number of results per query "
        self.I = np.zeros((nq, k), dtype='int64')
        self.D = np.zeros((nq, k), dtype='float32')
        self.nq, self.k = nq, k
        heaps = faiss.float_maxheap_array_t()
        heaps.k = k
        heaps.nh = nq
        heaps.val = faiss.swig_ptr(self.D)
        heaps.ids = faiss.swig_ptr(self.I)
        heaps.heapify()
        self.heaps = heaps

    def add_batch_result(self, D, I, i0):
        assert D.shape == (self.nq, self.k)
        assert I.shape == (self.nq, self.k)
        I += i0
        self.heaps.addn_with_ids(
            self.k, faiss.swig_ptr(D),
            faiss.swig_ptr(I), self.k)

    def finalize(self):
        self.heaps.reorder()

stats = faiss.cvar.indexIVF_stats
stats.reset()

print('index size', index_1.ntotal,
      'imbalance', index_1.invlists.imbalance_factor())
start = time.time()
Dref, Iref = index_1.search(xq, k)
print('time of searching: %.3f s = %.3f + %.3f ms' % (
    time.time() - start, stats.quantization_time, stats.search_time))

indexes = {}
if args.twostage:

    for i in range(0, nb, subset_len):
        index = faiss.read_index(basename)
        faiss.ParameterSpace().set_index_parameter(index, 'nprobe', nprobe)
        print("add %d:%d" %(i, i+subset_len))
        index.add(xb[i:i + subset_len])
        indexes[i] = index

rh = ResultHeap(nq, k)
sum_time = tq = ts = 0
for i in range(0, nb, subset_len):
    if not args.twostage:
        index = faiss.read_index(basename)
        faiss.ParameterSpace().set_index_parameter(index, 'nprobe', nprobe)
        print("add %d:%d" %(i, i+subset_len))
        index.add(xb[i:i + subset_len])
    else:
        index = indexes[i]

    stats.reset()
    start = time.time()
    Di, Ii = index.search(xq, k)
    sum_time = sum_time + time.time() - start
    tq += stats.quantization_time
    ts += stats.search_time
    rh.add_batch_result(Di, Ii, i)

print('time of searching separately: %.3f s = %.3f + %.3f ms' %
      (sum_time, tq, ts))

rh.finalize()

print('diffs: %d / %d'  % ((Iref != rh.I).sum(), Iref.size))
