import scanpy as sc
import numpy as np
import argparse
import faiss
import warnings
import scanpy as sc
import time
import matplotlib.pyplot as plt
from sklearn import metrics

def find_id(genes, reference):
    """Return the position of the last entry of ``reference`` equal to ``genes``.

    Every position of ``reference`` is compared against ``genes``, so when
    several entries match, the highest index wins.

    Returns ``-1`` when nothing matches, which includes an empty ``reference``.
    """
    matches = [pos for pos in range(len(reference)) if reference[pos] == genes]
    return matches[-1] if matches else -1

# Silence every warning category for the remainder of the run.
warnings.filterwarnings('ignore')

# Command-line interface; every option is mandatory.
parser = argparse.ArgumentParser(description='Process single-cell data.')
parser.add_argument('--input_adata', type=str, required=True, help='Path to input .h5ad file')
parser.add_argument('--input_embeddings', type=str, required=True, help='Path to input .npy file')
parser.add_argument('--reference_adata', type=str, required=True, help='Path to reference .h5ad file')
parser.add_argument('--reference_embeddings', type=str, required=True, help='Path to reference .npy file')
parser.add_argument('--method', type=str, required=True, help='Method used for embeddings')
parser.add_argument('--retrieved_for_each_cell', type=int, required=True,
                    help='Number of cells to retrieve for each cell')
parser.add_argument('--output_adata', type=str, required=True, help='Path to output .png file')
args = parser.parse_args()

# Load the input dataset and show its per-cell annotation table.
adams = sc.read(args.input_adata)
print(adams.obs)
# Interactive pause so the table above can be inspected. input() raises
# EOFError when stdin is closed or redirected (batch jobs, pipelines); in that
# case carry on instead of aborting the whole run.
try:
    input('Press enter to continue')
except EOFError:
    pass

# One entry per cell, taken from the 'Method' annotation column.
sequencing = np.array(adams.obs['Method'])
print("Sequencing: ", sequencing)
print("Sequencing shape: ", sequencing.shape)
# Sorted distinct values of the 'Method' column.
labels = np.unique(sequencing)
print("Sequencing unique: ", labels)
# Sorted distinct values of the 'Type' column (only reported here).
types = np.unique(np.array(adams.obs['Type']))
print("Types: ", types)
# Embedding matrix loaded from the .npy file.
# NOTE(review): presumably one row per cell of the input dataset - confirm.
embeddings = np.load(args.input_embeddings)
print("Embeddings shape: ", embeddings.shape)

# Integer code for each known value of the 'Method' column. Codes follow the
# order in which the names are listed here. Indexing this table with a name
# that is not listed raises KeyError.
mapping = {
    '10x Chromium (v2)': 0,
    '10x Chromium (v2) A': 1,
    '10x Chromium (v2) B': 2,
    '10x Chromium (v3)': 3,
    'CEL-Seq2': 4,
    'Drop-seq': 5,
    'Seq-Well': 6,
    'Smart-seq2': 7,
    'inDrops': 8
}

# Accumulators for the true / retrieved label codes. They start empty and are
# extended with np.append, which flattens them to 1-D.
actual = np.empty(shape=[0, 0])

predicted = np.empty(shape=[0, 0])

# Tag every cell of the input dataset and remember how many cells it has.
adams.obs['domain'] = 'input'
input_size = adams.n_obs
reference = sc.read(args.reference_adata)
method = args.method
method_full = 'X_' + method

# faiss works on C-contiguous float32 matrices, while np.load may hand back
# float64 or a non-contiguous array. Convert once and reuse the same matrix
# for building the index and for querying it.
genes_encoded = np.ascontiguousarray(embeddings, dtype=np.float32)
query = genes_encoded

# Length of each embedding vector.
DIMENSION = genes_encoded.shape[1]

# One more than requested: a vector searched against an index that contains
# it normally comes back as its own nearest neighbour.
NEIGHBORS = args.retrieved_for_each_cell + 1

# Number of inverted lists (coarse clusters) of the IVF index.
NLIST = 64

# Inverted-file index with exact L2 distances inside each list; the flat L2
# index acts as the coarse quantizer that assigns vectors to lists.
quantizer = faiss.IndexFlatL2(DIMENSION)
index = faiss.IndexIVFFlat(quantizer, DIMENSION, NLIST, faiss.METRIC_L2)
index.train(genes_encoded)
index.add(genes_encoded)

# Number of inverted lists visited per search.
index.nprobe = 16

# Indices listed here are never reported as neighbours (empty by default).
sample_indices = []

# Every retrieved neighbour index, across all queries, in retrieval order.
similar_cells = list()

# Number of neighbours to keep per query cell; the query itself is not counted.
wanted = args.retrieved_for_each_cell
excluded = set(sample_indices)

# Label codes are gathered in plain lists and appended to the arrays once at
# the end, instead of re-allocating the arrays with np.append for every entry.
actual_codes = []
predicted_codes = []

for query_index, q in enumerate(query):
    q = np.expand_dims(q, axis=0)
    print('Query: ', q)

    # Ask for one extra hit because the query normally matches itself.
    NEIGHBORS = wanted + 1
    while True:
        distances, neighbor_ids = index.search(q, NEIGHBORS)
        tmp = neighbor_ids[0].tolist()
        print("IVF: ", tmp)

        # Keep the closest `wanted` usable hits. Dropped entries: the -1
        # padding faiss emits when it finds fewer results than requested,
        # the query cell itself, and anything explicitly excluded.
        hits = [
            cell for cell in tmp
            if cell >= 0 and cell != query_index and cell not in excluded
        ][:wanted]

        # Stop when enough neighbours were found. Also stop once the result
        # is padded with -1: a larger request cannot return more hits then.
        if len(hits) == wanted or any(cell < 0 for cell in tmp):
            break
        NEIGHBORS += wanted + 1

    # Record one (true label, neighbour label) pair per retrieved neighbour so
    # both sequences always have the same length, even when fewer than
    # `wanted` neighbours were available.
    true_code = mapping[sequencing[query_index]]
    for cell in hits:
        similar_cells.append(cell)
        actual_codes.append(true_code)
        predicted_codes.append(mapping[sequencing[cell]])
        print("Similar cell: ", cell)
        print(reference.obs.index)

# np.append flattens its first argument, so both results are 1-D arrays.
actual = np.append(actual, actual_codes)
predicted = np.append(predicted, predicted_codes)

# Report the collected true and retrieved label codes, then plot how they
# co-occur as a confusion matrix.
print(actual)
print(predicted)
confusion_matrix = metrics.confusion_matrix(actual, predicted)
# NOTE(review): the matrix rows/columns follow the sorted codes found in
# actual/predicted, while the tick labels come from `labels`; they line up
# only if both orders agree - confirm for datasets with unusual method names.
cm_display = metrics.ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=labels)
cm_display.plot()
plt.title('Confusion Matrix')
# Save before showing: savefig writes the current figure, and once show()
# returns after the window is closed that figure may no longer exist, so
# saving afterwards can produce an empty image.
plt.savefig(args.output_adata)
plt.show()