import os

# Thread-count hints for the native math backends. They are set here, ahead
# of the faiss/numpy imports below (presumably because these libraries read
# such variables when they initialise -- confirm per backend).
os.environ.update(dict.fromkeys(
    (
        "MKL_NUM_THREADS",
        "OPENBLAS_NUM_THREADS",
        "NUMEXPR_NUM_THREADS",
        "VECLIB_MAXIMUM_THREADS",
        "OMP_NUM_THREADS",
        "FAISS_NUM_THREADS",
    ),
    "32",
))

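# Disabled alternative: pin every backend to a single thread.
# Uncomment these lines (and comment out the "32" settings above) to use it.
# os.environ["MKL_NUM_THREADS"] = "1"
# os.environ["OPENBLAS_NUM_THREADS"] = "1"
# os.environ["NUMEXPR_NUM_THREADS"] = "1"
# os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
# os.environ["OMP_NUM_THREADS"] = "1"
# os.environ["FAISS_NUM_THREADS"] = "1"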

import faiss

import numpy as np
import timeit

def mmap_bin(bin_path, num_rows, num_cols, dtype=np.float32):
    """
    Memory-map a raw binary file as a 2-D array.

    Parameters:
    - bin_path: path of the binary file to map
    - num_rows, num_cols: shape given to the mapped array
    - dtype: element type used to interpret the file (default: np.float32)

    Returns:
    - np.memmap of shape (num_rows, num_cols)
    """
    layout = (num_rows, num_cols)
    # mode='c' is copy-on-write: in-memory modifications are NOT written
    # back to the file ('r+' would be read/write, 'r' read-only).
    mapped = np.memmap(bin_path, dtype=dtype, mode='c', shape=layout)
    return mapped


def getAcc(exact, approx):
    """
    Average per-row overlap between two neighbor-id tables.

    For every row, the number of distinct ids shared by exact[i] and
    approx[i] is divided by k (the column count of `exact`); the mean of
    these fractions over all rows is returned.

    Parameters:
    - exact: 2-D array of reference ids, shape (n, k)
    - approx: array of candidate ids, indexed by the same row numbers

    Returns:
    - float in [0, 1] when rows of `exact` hold distinct ids
    """
    num_rows, width = np.shape(exact)
    total = 0
    for row in range(num_rows):
        shared = np.intersect1d(exact[row], approx[row])
        # Accumulate the per-row fraction (not the raw count) so the
        # floating-point result matches row-by-row averaging.
        total += shared.size / width
    return total / num_rows

# Faiss BF
def faissBF(X, Q, k, numThreads, dist='ip'):
    """
    Exact k-nearest-neighbor search with a flat (brute-force) Faiss index.

    Parameters:
    - X: 2-D array of database vectors; its second dimension is the index dimension
    - Q: query vectors passed to index.search
    - k: number of neighbors to return per query
    - numThreads: thread count handed to faiss.omp_set_num_threads
    - dist: 'l2' builds IndexFlatL2; any other value builds IndexFlatIP

    Returns:
    - the neighbor-id array returned by index.search (distances are discarded)

    Index and query wall-clock times are printed.
    """
    _, dim = np.shape(X)
    faiss.omp_set_num_threads(numThreads)

    # Index construction: timed from before the index object is created.
    build_start = timeit.default_timer()
    index_cls = faiss.IndexFlatL2 if dist == 'l2' else faiss.IndexFlatIP
    flat_index = index_cls(dim)
    flat_index.add(X)
    build_elapsed = timeit.default_timer() - build_start
    print('Faiss bruteforce index time: {}'.format(build_elapsed))

    # Query phase.
    query_start = timeit.default_timer()
    _, neighbor_ids = flat_index.search(Q, k)
    query_elapsed = timeit.default_timer() - query_start
    print('Faiss bruteforce query time: {}'.format(query_elapsed))

    # Cross-check bf of the first query
    # exactDOT = np.matmul(X, Q[0, :].transpose())  # Exact dot products
    # topK = np.argsort(-exactDOT)[:k]  # topK MIPS indexes

    return neighbor_ids

def faissIVF(exact_kNN, X, Q, k=10, n_list=100, n_probe=10, n_threads=8, dist='ip'):
    """
    Build a Faiss IVF-Flat index over X, query it with Q and print the
    construction time, query time and accuracy against exact_kNN.

    Parameters:
    - exact_kNN: reference neighbor ids, one row per query (see getAcc)
    - X: np.ndarray of shape (n, d), the database vectors
    - Q: query vectors passed to index.search
    - k: number of nearest neighbors (default: 10)
    - n_list: number of inverted lists (coarse clusters)
    - n_probe: value assigned to index.nprobe for the search
    - n_threads: thread count handed to faiss.omp_set_num_threads
    - dist: 'l2' for L2 distance; any other value selects inner product

    Returns:
    - None; results are printed.
    """

    X = X.astype(np.float32)
    _, d = X.shape

    # 1. Create FAISS index
    faiss.omp_set_num_threads(n_threads)
    nlist = n_list  # the number of clusters
    print("nlist = ", nlist)

    t1 = timeit.default_timer()
    if dist == 'l2':
        quantizer = faiss.IndexFlatL2(d)  # coarse quantizer
        metric = faiss.METRIC_L2
    else:
        quantizer = faiss.IndexFlatIP(d)  # coarse quantizer
        metric = faiss.METRIC_INNER_PRODUCT

    # IndexIVFFlat defaults to METRIC_L2 when no metric is given, regardless
    # of the quantizer type, so the metric must be passed explicitly for
    # inner-product search to rank candidates by inner product.
    index = faiss.IndexIVFFlat(quantizer, d, nlist, metric)
    index.train(X)
    index.add(X)

    t2 = timeit.default_timer()
    print('Construction time of Faiss IVF (s): {}'.format(t2 - t1))

    # Single-iteration loop kept as scaffolding for sweeping nprobe.
    for i in range(1):

        index.nprobe = n_probe + i * 10
        print("nprobe = ", index.nprobe)

        t1 = timeit.default_timer()
        # Distances are not used; avoid overwriting the `dist` parameter.
        _, approx_kNN = index.search(Q, k=k)
        print('\tFaiss-IVF query time: {}'.format(timeit.default_timer() - t1))
        print("\tFaiss-IVF Accuracy: ", getAcc(exact_kNN, approx_kNN))


def faiss_approx_kNN_IVFPQ(exact_kNN, X, Q, k=10, n_list=100, n_subquantizer=8, n_probe=10, n_threads=8, dist='ip'):
    """
    Build a Faiss IVF-PQ index over X, query it with Q and print the
    construction time, query time and accuracy against exact_kNN.

    Parameters:
    - exact_kNN: reference neighbor ids, one row per query (see getAcc)
    - X: np.ndarray of shape (n, d), the database vectors
    - Q: query vectors passed to index.search
    - k: number of nearest neighbors (default: 10)
    - n_list: number of coarse centroids (IVF)
    - n_subquantizer: number of PQ subquantizers
      (NOTE(review): Faiss is expected to require d to be divisible by
      this value -- confirm for the datasets used)
    - n_probe: value assigned to index.nprobe for the search
    - n_threads: thread count handed to faiss.omp_set_num_threads
    - dist: 'l2' for L2 distance; any other value selects inner product

    Returns:
    - None; results are printed.
    """

    X = X.astype(np.float32)
    _, d = X.shape

    # 1. Create FAISS index
    faiss.omp_set_num_threads(n_threads)
    nlist = n_list  # number of coarse centroids (IVF)
    print("nlist = ", nlist)

    m = n_subquantizer  # number of PQ subquantizers
    nbits = 8  # each sub-vector is encoded as 8 bits

    t1 = timeit.default_timer()
    if dist == 'l2':
        quantizer = faiss.IndexFlatL2(d)  # coarse quantizer
        metric = faiss.METRIC_L2
    else:
        quantizer = faiss.IndexFlatIP(d)  # coarse quantizer
        metric = faiss.METRIC_INNER_PRODUCT

    # IndexIVFPQ defaults to METRIC_L2 when no metric is given, regardless
    # of the quantizer type, so the metric is passed explicitly.
    index = faiss.IndexIVFPQ(quantizer, d, nlist, m, nbits, metric)

    index.train(X)
    index.add(X)

    t2 = timeit.default_timer()
    print('Construction time of Faiss IVFPQ: {}'.format(t2 - t1))

    # Single-iteration loop kept as scaffolding for sweeping nprobe.
    for i in range(1):

        index.nprobe = n_probe + i * 10
        print("nprobe = ", index.nprobe)

        t1 = timeit.default_timer()
        # Distances are not used; avoid overwriting the `dist` parameter.
        _, approx_kNN = index.search(Q, k=k)
        print('\tFaiss-IVFPQ query time: {}'.format(timeit.default_timer() - t1))
        print("\tFaiss-IVFPQ Accuracy: ", getAcc(exact_kNN, approx_kNN))
    
# HNSW
def hnswMIPS(exact_kNN, X, Q, k, efSearch=100, n_threads=8, dist='ip', hnsw_m=64, efConstruction=16):
    """
    Build an hnswlib index over X, query it with Q and print the index
    time, query time and accuracy against exact_kNN.

    Parameters:
    - exact_kNN: reference neighbor ids, one row per query (see getAcc)
    - X: 2-D array of database vectors, shape (n, d)
    - Q: query vectors passed to index.knn_query
    - k: number of neighbors per query
    - efSearch: value passed to index.set_ef before querying
    - n_threads: value passed to index.set_num_threads
    - dist: hnswlib `space` name used to create the index (default: 'ip')
    - hnsw_m: `M` passed to init_index (default 64, the previously
      hard-coded value; the original note says this is typically 32)
    - efConstruction: `ef_construction` passed to init_index (default 16,
      the previously hard-coded value)

    Returns:
    - None; results are printed.
    """

    n, d = np.shape(X)

    # Imported lazily so the module can be used without hnswlib installed.
    import hnswlib

    index = hnswlib.Index(space=dist, dim=d)
    print("m = %d, ef = %d" % (hnsw_m, efConstruction))

    index.set_num_threads(n_threads)
    t1 = timeit.default_timer()
    index.init_index(max_elements=n, ef_construction=efConstruction, M=hnsw_m)
    index.add_items(X)
    t2 = timeit.default_timer()
    print('Hnswlib index time: {}'.format(t2 - t1))

    # Single-iteration loop kept as scaffolding for sweeping efSearch.
    for i in range(1):

        new_efSearch = efSearch + i * 10
        index.set_ef(new_efSearch)
        print("Hnsw efSearch: ", new_efSearch)
        t1 = timeit.default_timer()
        # Distances are not used; avoid overwriting the `dist` parameter.
        approx_kNN, _ = index.knn_query(Q, k=k)
        print('\tHnsw query time: {}'.format(timeit.default_timer() - t1))
        print("\tHnsw Accuracy: ", getAcc(exact_kNN, approx_kNN))

# scann
def scannMIPS(exact_kNN, X, Q, k):
    """
    Build a ScaNN searcher over X with the "dot_product" measure, query it
    with Q and print the index time, query time and accuracy against
    exact_kNN.

    Parameters:
    - exact_kNN: reference neighbor ids, one row per query (see getAcc)
    - X: 2-D array of database vectors; its row count is used as the
      training sample size
    - Q: query vectors passed to search_batched
    - k: number of neighbors given to the builder

    Returns:
    - None; results are printed.
    """
    num_points, _ = np.shape(X)

    import scann

    print('Constructing the Scann')
    build_start = timeit.default_timer()
    # Same configuration as a single chained builder call, one stage per line.
    builder = scann.scann_ops_pybind.builder(X, k, "dot_product")
    builder = builder.tree(
        num_leaves=5000,
        num_leaves_to_search=100,
        training_sample_size=num_points,
    )
    builder = builder.score_ah(2, anisotropic_quantization_threshold=0.2)
    builder = builder.reorder(100)
    searcher = builder.build()
    build_elapsed = timeit.default_timer() - build_start
    print('Scann index time (s): {}'.format(build_elapsed))

    leaves_range = 100
    # Single-iteration loop kept as scaffolding for sweeping the leaf count.
    for _ in range(1):
        # leaves = 50 + j * 10

        leaves = leaves_range
        query_start = timeit.default_timer()
        approx_kNN, _scores = searcher.search_batched(
            Q, leaves_to_search=leaves, pre_reorder_num_neighbors=500)
        query_elapsed = timeit.default_timer() - query_start
        print("\tScann querying with {0} leaves has time: {1: .2f}".format(leaves, query_elapsed))
        print("\tScann Accuracy: ", getAcc(exact_kNN, approx_kNN))

# Script entry point: memory-map the database and query matrices.
if __name__ == '__main__':

    # Directory holding the raw binary files (absolute, machine-specific path).
    path = "/Dataset/ANNS/CEOs/"

    # Database matrix X: n rows of d values, mapped with mmap_bin's default
    # dtype (float32). The shape mirrors the numbers in the file name.
    n = 17770
    d = 300
    bin_file = path + 'Netflix_X_17770_300.bin'
    X = mmap_bin(bin_file, n, d)

    # Query matrix Q: q rows with the same dimensionality d as X.
    q = 999
    bin_file = path + "Netflix_Q_999_300.bin"
    Q = mmap_bin(bin_file, q, d)











