import scanpy as sc
import numpy as np
import argparse
import faiss
import warnings
import scanpy as sc
import time
import matplotlib.pyplot as plt
from sklearn import metrics
import seaborn as sns
import random
import numpy as np
import ot
import time
import numpy as np
from sklearn.metrics import accuracy_score

def compare_search(args, genes_encoded, embeddings, query, types_retrieved, mapping, indices_ot_mapping=None, NLIST=32, N_PROBE=8):
    """Search an IVF index at ``N_PROBE`` and at full probe depth; return both results.

    An ``faiss.IndexIVFFlat`` with ``NLIST`` inverted lists is trained on and
    populated with ``genes_encoded``. ``query`` is then searched twice for
    ``args.retrieved_for_each_cell`` neighbours each:

    * probing ``N_PROBE`` lists -- the approximate result (``predicted``);
    * probing all ``NLIST`` lists, i.e. every inverted list is scanned -- the
      reference result (``actual``).

    Args:
        args: Namespace-like object. Reads ``faiss_search`` (only ``'L2'`` is
            supported) and ``retrieved_for_each_cell`` (neighbours per query).
        genes_encoded: Vectors used both to train and to fill the index.
        embeddings: Only ``embeddings.shape[1]`` is read, as the index
            dimension. NOTE(review): presumably equals the vector width of
            ``genes_encoded`` and ``query`` -- confirm with callers.
        query: Query vectors to search.
        types_retrieved: Unused; kept for signature compatibility.
        mapping: Unused; kept for signature compatibility.
        indices_ot_mapping: Unused; kept for signature compatibility.
        NLIST: Number of inverted lists (IVF cells) in the index.
        N_PROBE: Number of lists probed for the approximate search.

    Returns:
        ``(actual, predicted)``: flat Python lists of neighbour ids, row-major
        over the queries (``retrieved_for_each_cell`` ids per query).

    Raises:
        ValueError: if ``args.faiss_search`` is not ``'L2'``.
    """
    if args.faiss_search != 'L2':
        # Without this check ``index`` would be unbound below and the call
        # would die with an unhelpful UnboundLocalError.
        raise ValueError(
            f"Unsupported faiss_search metric {args.faiss_search!r}; only 'L2' is supported"
        )

    dimension = embeddings.shape[1]
    neighbors = args.retrieved_for_each_cell

    quantizer = faiss.IndexFlatL2(dimension)
    index = faiss.IndexIVFFlat(quantizer, dimension, NLIST, faiss.METRIC_L2)
    index.train(genes_encoded)
    index.add(genes_encoded)

    # Approximate search: only the N_PROBE nearest cells are scanned.
    index.nprobe = N_PROBE
    _, approx_ids = index.search(query, neighbors)
    predicted = approx_ids.reshape(-1).tolist()

    # Reference search: probing all NLIST cells scans every stored vector.
    index.nprobe = NLIST
    _, reference_ids = index.search(query, neighbors)
    actual = reference_ids.reshape(-1).tolist()

    return actual, predicted

# Sweep grid used by ablation_analysis: candidate numbers of IVF inverted lists.
nlist_values = [
    4, 6, 8,
    16, 32, 64, 128,
]
# Candidate numbers of inverted lists probed per query.
nprobe_values = [
    1, 2, 3, 4, 6, 8,
    16, 24, 32,
]

def ablation_analysis(args, genes_encoded, embeddings, query, types_retrieved, mapping, method, out_method, indices_ot_mapping=None, nlists=None, nprobes=None):
    """Sweep IVF ``(nlist, nprobe)`` settings and record time and agreement.

    For every pair drawn from the nlist and nprobe grids with
    ``nprobe < nlist``, this calls ``compare_search``. It then scores the
    approximate ids (``predicted``) against the full-probe ids (``actual``)
    with ``sklearn.metrics.accuracy_score``. That is the fraction of
    positions at which the two flat id lists hold the same id.

    NOTE(review): this metric is positional, not set-based recall. A single
    missed neighbour can shift later ids and lower the score more than a
    recall@k metric would. Confirm this is the intended measure.

    Args:
        args, genes_encoded, embeddings, query, types_retrieved, mapping,
        indices_ot_mapping: Forwarded unchanged to ``compare_search``.
        method: Label inserted into the output file name.
        out_method: Second label inserted into the output file name.
        nlists: Optional iterable of nlist values. Defaults to the
            module-level ``nlist_values``.
        nprobes: Optional iterable of nprobe values. Defaults to the
            module-level ``nprobe_values``.

    Returns:
        List of ``(nlist, nprobe, elapsed_seconds, accuracy)`` tuples. The
        same data is saved with ``np.save`` to
        ``ablation_analysis_{method}_out_{out_method}.npy`` in the current
        working directory.
    """
    nlist_grid = nlist_values if nlists is None else nlists
    nprobe_grid = nprobe_values if nprobes is None else nprobes

    results = []
    for nlist in nlist_grid:
        for nprobe in nprobe_grid:
            # Pairs with nprobe >= nlist are skipped: probing every list is
            # exactly what compare_search already does for its reference.
            if nprobe >= nlist:
                continue
            # perf_counter is monotonic and high-resolution, unlike time.time().
            # The measured span covers the whole compare_search call, i.e.
            # index training/adding plus both searches, not the probe alone.
            start_time = time.perf_counter()
            actual, predicted = compare_search(args, genes_encoded, embeddings, query, types_retrieved, mapping, indices_ot_mapping, nlist, nprobe)
            elapsed_time = time.perf_counter() - start_time
            accuracy = accuracy_score(actual, predicted)
            print(f"NLIST: {nlist}")
            print(f'NPROBE: {nprobe}')
            print(f'Elapsed Time: {elapsed_time}')
            print(f'Accuracy: {accuracy}')

            results.append((nlist, nprobe, elapsed_time, accuracy))
    np.save(f'ablation_analysis_{method}_out_{out_method}.npy', results)
    return results