import scanpy as sc
import numpy as np
import argparse
import faiss
import warnings
import scanpy as sc
import time
import matplotlib.pyplot as plt
from sklearn import metrics
import seaborn as sns
import random
import numpy as np

def cell_type_similarity_plot(actual, predicted, method1, method2, method, types_unique, types_unique_set, normalization, out_data):
    """Plot and save a row-normalized confusion matrix of original vs. retrieved cell types.

    Each row is an original (queried) cell type and each column a retrieved
    cell type; a cell holds the fraction of that row's samples that were
    retrieved as the column's type.

    Args:
        actual: Original cell-type label of each query (``y_true`` for
            ``sklearn.metrics.confusion_matrix``).
        predicted: Retrieved cell-type label of each query (``y_pred``).
        method1: Name of the source queried from; shown in the title and,
            with spaces removed, used in the output filename.
        method2: Name of the source retrieved from; used like ``method1``.
        method: Embedding method name; capitalized in the title and used as
            the filename prefix.
        types_unique: Candidate cell-type labels, in the order rows/columns
            should be displayed.
        types_unique_set: Container of the cell types to keep; entries of
            ``types_unique`` not in it are dropped.
        normalization: ``'a'``, ``'b'`` or ``'c'`` appends
            "(Z-score Normalization)", "(Whitening)" or "(Optimal Transport)"
            to the title; any other value adds no suffix.
        out_data: Directory the PNG is written to, as
            ``<out_data>/<method>_<method1>_<method2>.png``.
    """
    import os

    # Keep only the requested types, preserving the order of ``types_unique``.
    # Converting with np.asarray first makes the filter work for plain lists too.
    types_arr = np.asarray(types_unique)
    keep = np.array([t in types_unique_set for t in types_arr], dtype=bool)
    labels = types_arr[keep]

    # Index the matrix by exactly these labels so its row/column order matches
    # the tick labels below (without ``labels=`` sklearn uses its own sorted
    # label order, which need not agree with ``types_unique``).
    cm = metrics.confusion_matrix(actual, predicted, labels=labels)
    # Normalize each row to sum to 1. A type with no original samples yields a
    # 0/0 row: it stays NaN (left blank by the heatmap) without a RuntimeWarning.
    with np.errstate(divide='ignore', invalid='ignore'):
        cmn = cm.astype(float) / cm.sum(axis=1, keepdims=True)

    fig, ax = plt.subplots(figsize=(12, 12))
    sns.heatmap(cmn, annot=True, fmt='.2f', xticklabels=labels, yticklabels=labels, ax=ax)
    ax.set_ylabel('Original Cell Type')
    ax.set_xlabel('Retrieved Cell Type')

    suffixes = {'a': 'Z-score Normalization', 'b': 'Whitening', 'c': 'Optimal Transport'}
    title = (f'Queried from {method1}, Retrieved from {method2}, '
             f'Embeddings generated using {method.capitalize()}')
    if normalization in suffixes:
        title += f'\n ({suffixes[normalization]})'
    ax.set_title(title)

    filename = f"{method}_{method1.replace(' ', '')}_{method2.replace(' ', '')}.png"
    out_path = os.path.join(out_data, filename)

    # Lay out before both saving and displaying so they match.
    fig.tight_layout()
    fig.savefig(out_path)
    # Report only after the file has actually been written.
    print(out_path + ' saved')
    # Non-blocking display. NOTE(review): the figure is left open; callers that
    # plot many method pairs may want plt.close(fig) to avoid piling up figures.
    plt.show(block=False)


def summary_plot(args, ratios, method, out_data, sequencing_methods):
    """Draw a heatmap of retrieval ratios between sequencing methods and save it as a PNG.

    Rows are labelled as the original sequencing method and columns as the
    retrieved one, both using ``sequencing_methods`` as tick labels. The
    colour scale is fixed to [0, 1] and values are annotated to 2 decimals.

    Args:
        args: Namespace whose ``normalization`` attribute (``'a'``, ``'b'`` or
            ``'c'``) selects a title suffix; any other value adds none.
        ratios: Matrix of values to plot; presumably square with one row and
            column per entry of ``sequencing_methods`` -- confirm with caller.
        method: Embedding method name; capitalized in the title and used as
            the filename prefix.
        out_data: Directory the PNG is written to, as
            ``<out_data>/<method>_summary.png``.
        sequencing_methods: Tick labels for both axes.
    """
    figure, axes = plt.subplots(figsize=(10, 10))
    sns.heatmap(
        ratios,
        square=True,
        cmap='rocket_r',
        annot=True,
        fmt='.2f',
        xticklabels=sequencing_methods,
        yticklabels=sequencing_methods,
        vmin=0,
        vmax=1,
        ax=axes,
    )
    axes.set_ylabel('Original Sequencing Method')
    axes.set_xlabel('Retrieved Sequencing Method')

    heading = method.capitalize()
    suffix = {
        'a': 'Z-score Normalization',
        'b': 'Whitening',
        'c': 'Optimal Transport',
    }.get(args.normalization)
    axes.set_title(heading if suffix is None else f'{heading} ({suffix})')

    plt.show(block=False)
    destination = out_data + f'/{method}_summary.png'
    figure.tight_layout()
    figure.savefig(destination)
    print(destination + ' saved')