# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import time
import numpy as np
import faiss
from datasets import load_sift1M


print("load data")

# Base vectors (added to the index), queries (searched), training vectors
# (used for train) and ground-truth neighbour ids for the queries.
xb, xq, xt, gt = load_sift1M()
nq, d = xq.shape

# Number of inverted lists passed to every IVF index built below.
ncent = 256

# All quantizer types exposed on faiss.ScalarQuantizer, as (name, value) pairs.
variants = [
    (attr, getattr(faiss.ScalarQuantizer, attr))
    for attr in dir(faiss.ScalarQuantizer)
    if attr.startswith('QT_')
]

# Flat L2 coarse quantizer; the same object is handed to each IVF index.
quantizer = faiss.IndexFlatL2(d)

# Disabled baseline: one IVF index per quantizer type plus an uncompressed
# IndexIVFFlat reference ('flat'). Prints the cumulative elapsed time at the
# start of each phase, then recall at ranks 1, 10 and 100. Change the
# condition to True to run it.
if False:
    for name, qtype in [('flat', 0)] + variants:

        print("============== test", name)
        start = time.time()

        if name != 'flat':
            index = faiss.IndexIVFScalarQuantizer(
                quantizer, d, ncent, qtype, faiss.METRIC_L2)
        else:
            index = faiss.IndexIVFFlat(
                quantizer, d, ncent, faiss.METRIC_L2)
        index.nprobe = 16

        print("[%.3f s] train" % (time.time() - start))
        index.train(xt)
        print("[%.3f s] add" % (time.time() - start))
        index.add(xb)
        print("[%.3f s] search" % (time.time() - start))
        D, I = index.search(xq, 100)
        print("[%.3f s] eval" % (time.time() - start))

        # Fraction of queries whose first ground-truth id (gt[:, 0]) shows up
        # among the first `rank` returned ids.
        recalls = ["%.4f" % ((I[:, :rank] == gt[:, :1]).sum() / float(nq))
                   for rank in (1, 10, 100)]
        print(' '.join(recalls) + ' ')

# Range-statistic sweep: for every scalar-quantizer type, build an IVF index
# for each (range statistic, argument) pair, then print recall at ranks
# 1, 10 and 100 and the search time.
if True:
    for name, qtype in variants:

        print("============== test", name)

        for rsname, vals in [('RS_minmax',
                              [-0.4, -0.2, -0.1, -0.05, 0.0, 0.1, 0.5]),
                             ('RS_meanstd', [0.8, 1.0, 1.5, 2.0, 3.0, 5.0, 10.0]),
                             ('RS_quantiles', [0.02, 0.05, 0.1, 0.15]),
                             ('RS_optim', [0.0])]:
            # Loop-invariant: resolve the enum value once per statistic.
            rangestat = getattr(faiss.ScalarQuantizer, rsname)
            for val in vals:
                print("%-15s %5g    " % (rsname, val), end=' ')
                index = faiss.IndexIVFScalarQuantizer(quantizer, d, ncent,
                                                      qtype, faiss.METRIC_L2)
                index.nprobe = 16
                index.sq.rangestat = rangestat
                # The argument belongs on the ScalarQuantizer (index.sq), just
                # like rangestat. IndexIVFScalarQuantizer has no rangestat_arg
                # member, so setting it on `index` never reached the quantizer
                # and every value in `vals` trained with the default argument.
                index.sq.rangestat_arg = val

                index.train(xt)
                index.add(xb)
                t0 = time.time()
                D, I = index.search(xq, 100)
                t1 = time.time()

                # Fraction of queries whose first ground-truth id (gt[:, 0])
                # appears among the first `rank` returned ids.
                for rank in 1, 10, 100:
                    n_ok = (I[:, :rank] == gt[:, :1]).sum()
                    print("%.4f" % (n_ok / float(nq)), end=' ')
                print("   %.3f s" % (t1 - t0))
