# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Common functions to load datasets and compute their ground-truth
"""

import time
import numpy as np
import faiss

from faiss.contrib import datasets as faiss_datasets

print("path:", faiss_datasets.__file__)

faiss_datasets.dataset_basedir = '/checkpoint/matthijs/simsearch/'

def sanitize(x):
    return np.ascontiguousarray(x, dtype='float32')


#################################################################
# Dataset
#################################################################

class DatasetCentroids(faiss_datasets.Dataset):

    def __init__(self, ds, indexfile):
        self.d = ds.d
        self.metric = ds.metric
        self.nq = ds.nq
        self.xq = ds.get_queries()

        # get the xb set
        src_index = faiss.read_index(indexfile)
        src_quant = faiss.downcast_index(src_index.quantizer)
        centroids = faiss.vector_to_array(src_quant.xb)
        self.xb = centroids.reshape(-1, self.d)
        self.nb = self.nt = len(self.xb)

    def get_queries(self):
        return self.xq

    def get_database(self):
        return self.xb

    def get_train(self, maxtrain=None):
        return self.xb

    def get_groundtruth(self, k=100):
        return faiss.knn(
            self.xq, self.xb, k,
            faiss.METRIC_L2 if self.metric == 'L2' else faiss.METRIC_INNER_PRODUCT
        )[1]






def load_dataset(dataset='deep1M', compute_gt=False, download=False):

    print("load data", dataset)

    if dataset == 'sift1M':
        return faiss_datasets.DatasetSIFT1M()

    elif dataset.startswith('bigann'):

        dbsize = 1000 if dataset == "bigann1B" else int(dataset[6:-1])

        return faiss_datasets.DatasetBigANN(nb_M=dbsize)

    elif dataset.startswith("deep_centroids_"):
        ncent = int(dataset[len("deep_centroids_"):])
        centdir = "/checkpoint/matthijs/bench_all_ivf/precomputed_clusters"
        return DatasetCentroids(
            faiss_datasets.DatasetDeep1B(nb=1000000),
            f"{centdir}/clustering.dbdeep1M.IVF{ncent}.faissindex"
        )

    elif dataset.startswith("deep"):

        szsuf = dataset[4:]
        if szsuf[-1] == 'M':
            dbsize = 10 ** 6 * int(szsuf[:-1])
        elif szsuf == '1B':
            dbsize = 10 ** 9
        elif szsuf[-1] == 'k':
            dbsize = 1000 * int(szsuf[:-1])
        else:
            assert False, "did not recognize suffix " + szsuf
        return faiss_datasets.DatasetDeep1B(nb=dbsize)

    elif dataset == "music-100":
        return faiss_datasets.DatasetMusic100()

    elif dataset == "glove":
        return faiss_datasets.DatasetGlove(download=download)

    else:
        assert False


#################################################################
# Evaluation
#################################################################


def evaluate_DI(D, I, gt):
    nq = gt.shape[0]
    k = I.shape[1]
    rank = 1
    while rank <= k:
        recall = (I[:, :rank] == gt[:, :1]).sum() / float(nq)
        print("R@%d: %.4f" % (rank, recall), end=' ')
        rank *= 10


def evaluate(xq, gt, index, k=100, endl=True):
    t0 = time.time()
    D, I = index.search(xq, k)
    t1 = time.time()
    nq = xq.shape[0]
    print("\t %8.4f ms per query, " % (
        (t1 - t0) * 1000.0 / nq), end=' ')
    rank = 1
    while rank <= k:
        recall = (I[:, :rank] == gt[:, :1]).sum() / float(nq)
        print("R@%d: %.4f" % (rank, recall), end=' ')
        rank *= 10
    if endl:
        print()
    return D, I
