import scanpy as sc
import numpy as np
import argparse
import faiss
import warnings
import scanpy as sc
import time
import matplotlib.pyplot as plt
from sklearn import metrics
import seaborn as sns
import random
import numpy as np

# Upper bound on the index dimensionality when args.normalization == 'b'
# (similarity_search caps embeddings.shape[1] at this value).
N_COMPONENTS: int = 4096
# Number of inverted lists passed to faiss.IndexIVFFlat.
NLIST: int = 128
# Value assigned to the IVF index's ``nprobe`` before searching.
N_PROBE: int = 16

def _build_ivf_index(genes_encoded, dimension, metric):
    """Build, train and fill a FAISS IVF-flat index over ``genes_encoded``.

    Args:
        genes_encoded: database vectors, one per row; used both to train the
            coarse quantizer and as the searchable contents of the index.
        dimension: vector dimensionality shared by the quantizer and the IVF
            index (the two must agree or FAISS rejects the quantizer).
        metric: ``'IP'`` for inner-product search or ``'L2'`` for Euclidean.

    Returns:
        The populated ``faiss.IndexIVFFlat`` with ``nprobe`` set to ``N_PROBE``.

    Raises:
        ValueError: if ``metric`` is neither ``'IP'`` nor ``'L2'``.
    """
    if metric == 'IP':
        quantizer = faiss.IndexFlatIP(dimension)
        faiss_metric = faiss.METRIC_INNER_PRODUCT
    elif metric == 'L2':
        quantizer = faiss.IndexFlatL2(dimension)
        faiss_metric = faiss.METRIC_L2
    else:
        raise ValueError(f"Unsupported faiss_search {metric!r}; expected 'IP' or 'L2'")

    # FAISS k-means cannot train more lists than there are training points,
    # so shrink the list count for small databases instead of crashing.
    nlist = max(1, min(NLIST, len(genes_encoded)))
    index = faiss.IndexIVFFlat(quantizer, dimension, nlist, faiss_metric)
    index.train(genes_encoded)
    index.add(genes_encoded)
    index.nprobe = N_PROBE
    return index


def _as_label_array(labels):
    """Return the collected ``labels`` as a flat NumPy array.

    The labels are appended onto ``np.empty((0, 0))`` in a single call, so
    numeric labels are promoted against float64, and an empty input returns
    that empty ``(0, 0)`` array unchanged.
    """
    empty = np.empty(shape=[0, 0])
    if not labels:
        return empty
    return np.append(empty, labels)


def similarity_search(args, genes_encoded, embeddings, query, types, typesA, typesB, indexA, indexB,
                      mapping, ranking, indices_ot_mapping=None):
    """Retrieve nearest database cells for each query cell with a FAISS IVF index.

    Args:
        args: options object; reads ``normalization``, ``retrieved_for_each_cell``
            (neighbours requested per query) and ``faiss_search`` ('IP' or 'L2').
        genes_encoded: database vectors that are indexed and searched.
            Presumably float32, as FAISS expects. TODO confirm with callers.
        embeddings: only ``embeddings.shape[1]`` is used, as the index dimension
            (capped at ``N_COMPONENTS`` when ``normalization == 'b'``).
        query: iterable of query vectors, aligned by position with ``typesA``.
        types: unused; kept for interface compatibility.
        typesA: label of each query, indexed by query position.
        typesB: label of each database row, indexed by FAISS result id.
        indexA, indexB: identifiers compared for equality. When equal, a hit
            whose id equals the query position is treated as the query itself
            and skipped.
        mapping: maps a label to the value stored in ``actual`` / ``predicted``.
        ranking: indexable by query position; each entry is a dict-like
            ``{database id: score}`` that is updated in place and returned.
        indices_ot_mapping: used only when ``normalization == 'c'`` and
            ``indexA != indexB``. It maps a query position to the row of
            ``genes_encoded`` searched in place of the query vector.

    Returns:
        Tuple ``(actual, predicted, types_unique_set, indexes,
        indexes_by_query, ranking)``:
        - ``actual[i]`` / ``predicted[i]`` are the mapped labels of the query
          and the database hit for the i-th kept (query, hit) pair.
        - ``indexes`` lists the kept hit ids.
        - ``indexes_by_query`` holds each raw FAISS result row, which may
          contain ``-1`` for missing results.

    Raises:
        ValueError: if ``args.faiss_search`` is neither 'IP' nor 'L2'.
    """
    dimension = embeddings.shape[1]
    if args.normalization == 'b':
        dimension = min(dimension, N_COMPONENTS)
    neighbors = args.retrieved_for_each_cell
    same_dataset = indexA == indexB

    start = time.time()
    index = _build_ivf_index(genes_encoded, dimension, args.faiss_search)
    print('Building index tree: ', time.time() - start)

    indexes = []
    indexes_by_query = []
    types_unique_set = set()
    actual_labels = []
    predicted_labels = []

    start = time.time()
    for query_index, q in enumerate(query):
        query_type = typesA[query_index]
        # Queries whose label never occurs in the database cannot be matched.
        if query_type not in typesB:
            continue
        if args.normalization == 'c' and not same_dataset:
            q = genes_encoded[indices_ot_mapping[query_index]]
        types_unique_set.add(query_type)
        query_label = mapping[query_type]
        query_ranking = ranking[query_index]

        _, I = index.search(np.expand_dims(q, axis=0), neighbors)
        for row in I:
            hits = row.tolist()
            indexes_by_query.append(hits)
            for rank, hit in enumerate(hits):
                # FAISS pads with -1 when fewer than `neighbors` results are found.
                if hit < 0:
                    continue
                # Searching a dataset against itself: drop the query's own hit.
                if same_dataset and hit == query_index:
                    continue
                indexes.append(hit)
                score = neighbors - rank
                if hit in query_ranking:
                    query_ranking[hit] += score
                else:
                    query_ranking[hit] = score
                hit_type = typesB[hit]
                # Append pairwise so actual[i] always corresponds to predicted[i].
                actual_labels.append(query_label)
                predicted_labels.append(mapping[hit_type])
                types_unique_set.add(hit_type)
    print('Query time: ', time.time() - start)

    actual = _as_label_array(actual_labels)
    predicted = _as_label_array(predicted_labels)
    return actual, predicted, types_unique_set, indexes, indexes_by_query, ranking