import scanpy as sc
import numpy as np
import argparse
import faiss
import warnings
import scanpy as sc
import time
import matplotlib.pyplot as plt
from sklearn import metrics
import seaborn as sns
import random
import numpy as np

# Disease-status labels. Not referenced in this excerpt; presumably these are
# the values of a disease column in the AnnData ``obs`` table -- confirm.
disease_types = [
    'normal',
    'Alzheimer disease',
]

# Cell-type label -> integer code, numbered 0..9 in the order listed below.
# NOTE(review): presumably keyed by the ``obs['CellType']`` values read in
# cell_data_loading -- confirm against the dataset.
mapping = {
    label: code
    for code, label in enumerate((
        'B cell',
        'CD14+ monocyte',
        'CD16+ monocyte',
        'CD4+ T cell',
        'Cytotoxic T cell',
        'Dendritic cell',
        'Megakaryocyte',
        'Natural killer cell',
        'Plasmacytoid dendritic cell',
        'Unassigned',
    ))
}

# Sequencing-protocol names; presumably the values of the ``obs['Method']``
# column read in cell_data_loading -- confirm.
sequencing_methods = [
    '10x Chromium (v2)',
    '10x Chromium (v2) A',
    '10x Chromium (v2) B',
    '10x Chromium (v3)',
    'CEL-Seq2',
    'Drop-seq',
    'Seq-Well',
    'Smart-seq2',
    'inDrops',
]


def cell_data_loading(args):
    """Read the AnnData file at ``args.input_adata`` and extract label columns.

    Parameters
    ----------
    args
        Any object with an ``input_adata`` attribute (presumably an argparse
        Namespace) holding the path passed to ``sc.read``.

    Returns
    -------
    tuple
        ``(adata, methods, unique_methods, cell_types, unique_cell_types)``:
        the loaded object, NumPy arrays built from ``obs['Method']`` and
        ``obs['CellType']``, and the sorted unique values of each
        (via ``np.unique``).
    """
    adata = sc.read(args.input_adata)

    def _column_and_levels(column):
        # Copy one obs column into a NumPy array and list its distinct values.
        values = np.array(adata.obs[column])
        return values, np.unique(values)

    methods, unique_methods = _column_and_levels('Method')
    cell_types, unique_cell_types = _column_and_levels('CellType')
    return adata, methods, unique_methods, cell_types, unique_cell_types


def calculate_matches(sorted_arr1, sorted_arr2):
    """Count the elements two ascending-sorted sequences have in common.

    Walks both sequences with one cursor each (merge style). Equal elements
    are consumed pairwise, so duplicates match one-to-one and the result is
    the size of the multiset intersection. For example, ``[1, 2, 2, 3]`` and
    ``[2, 2, 4]`` share 2 elements. The cost is linear in the combined
    length.

    Both inputs are expected to be sorted ascending; otherwise the count is
    not meaningful.
    """
    end1, end2 = len(sorted_arr1), len(sorted_arr2)
    pos1 = pos2 = 0
    shared = 0
    while pos1 < end1 and pos2 < end2:
        left, right = sorted_arr1[pos1], sorted_arr2[pos2]
        if left == right:
            # Consume one copy from each side so duplicates pair up 1:1.
            shared += 1
            pos1 += 1
            pos2 += 1
        elif left < right:
            # Advance whichever cursor points at the smaller value.
            pos1 += 1
        else:
            pos2 += 1
    return shared


def intersection_union(arr1, arr2, *, empty_value=1.0):
    """Return the Jaccard index ``|A & B| / |A | B|`` of two collections.

    Both inputs are converted to sets, so duplicate elements are ignored and
    the elements must be hashable. Each input is iterated exactly once, so
    generators and other single-pass iterables are handled correctly.

    Parameters
    ----------
    arr1, arr2
        Iterables of hashable elements.
    empty_value : float, keyword-only
        Returned when both inputs are empty, where the ratio would be 0/0.
        Defaults to 1.0, which treats two empty collections as identical.

    Returns
    -------
    float
        A value in ``[0, 1]``, or ``empty_value`` when both inputs are empty.
    """
    set1 = set(arr1)
    set2 = set(arr2)
    union_size = len(set1 | set2)
    if union_size == 0:
        # The union is empty only when both inputs are empty; the ratio is
        # undefined, so return the caller-chosen value instead of dividing.
        return empty_value
    return len(set1 & set2) / union_size
