import numpy as np
import argparse
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from sklearn.cluster import SpectralCoclustering
from sklearn.metrics import consensus_score
from ...general_utils.normalization import Normalization_A
from ...general_utils.helper import (mapping, sequencing_methods, cell_data_loading, calculate_matches, intersection_union)
from ...general_utils.plotting import summary_plot, cell_type_similarity_plot
from ..retreival_utils.faiss_retreival import similarity_search
import time

# When USE_QUERY_CELL_COUNT_LIMITER is True, the retrieval loop further down
# truncates the query set to QUERY_CELL_COUNT cells and the reference set to
# REFERENCE_CELL_COUNT cells.
USE_QUERY_CELL_COUNT_LIMITER = False
QUERY_CELL_COUNT = 500
REFERENCE_CELL_COUNT = 6240

# Command-line interface. Every option is required.
parser = argparse.ArgumentParser(description='Process single-cell data.')
parser.add_argument('--input_adata', type=str, required=True,
                    help='Path to input .h5ad file')
parser.add_argument('--reference_adata', type=str, required=True,
                    help='Path to reference .h5ad file')
parser.add_argument('--method', type=str, required=True,
                    help='Method used for embeddings; all if you want to run all methods')
parser.add_argument('--retrieved_for_each_cell', type=int, required=True,
                    help='Number of cells to retrieve for each cell')
parser.add_argument('--normalization', type=str, required=True,
                    help='Normalization method')
parser.add_argument('--faiss_search', type=str, required=True,
                    help='Faiss search method')
# Parsed as int so that a non-numeric value is rejected immediately by
# argparse instead of failing at the clustering step after all retrievals
# have already been computed.
parser.add_argument('--n_clusters', type=int, required=True,
                    help='Number of Clusters for Spectral Co-Clustering')
args = parser.parse_args()

# Number of neighbours retrieved per query cell; reused in output file names.
retreived_count = args.retrieved_for_each_cell

# NOTE(review): the exact types returned by cell_data_loading are not visible
# in this file. Below, `sequencing` and `types` are indexed with integer index
# arrays and `sequencing` is compared element-wise to a sequencing method.
adams, sequencing, labels, types, types_unique = cell_data_loading(args)
print(adams.obs)

# One row per embedding method: (internal key, display label, embeddings file).
# Keeping the three fields side by side means the derived lists below are
# index-aligned by construction.
# The 'pca' and 'cellfish.jl' paths are not the ones actually read by the
# retrieval loop in this file (it builds its own per-sequencing-method paths
# for those two methods); they are kept so the lists stay the same length.
_METHOD_TABLE = (
    ('scimilarity', 'Scimilarity', 'embeds/embeds_scimilarity.npy'),
    ('pca', 'PCA', 'embeds/embeds_pca.npy'),
    ('cellfish.jl', 'CellFishing.jl', 'embeds/embeds_cellfishing'),
    ('cell_blast', 'Cell BLAST', 'embeds/embeds_cell_blast.npy'),
    ('linearscvi', 'Linear Scvi', 'embeds/embeds_linearscvi.npy'),
    ('scvi', 'Scvi', 'embeds/embeds_scvi.npy'),
    ('scgpt', 'ScGPT', 'embeds/embeds_scgpt.npy'),
    ('geneformer', 'Geneformer', 'embeds/embeds_geneformer.npy'),
    ('uce', 'UCE', 'embeds/embeds_uce.npy'),
    ('scMulan', 'ScMulan', 'embeds/embeds_scMulan.npy'),
    ('scfoundation', 'ScFoundation', 'embeds/embeds_scfoundation.npy'),
    ('cellplm', 'CellPLM', 'embeds/embeds_cellplm.npy'),
)

# Internal method keys, in the order used to index every per-method array.
methods = [key for key, _, _ in _METHOD_TABLE]

# Human-readable labels used in printed tables and on the heatmap axes.
array_labels = [label for _, label, _ in _METHOD_TABLE]

# Number of embedding methods being compared.
M = len(methods)

# Path of the stored embeddings for each method, aligned with `methods`.
input_embeddings_path = [path for _, _, path in _METHOD_TABLE]
# cells[k][j] ends up holding, for sequencing method k and embedding method j,
# one list of retrieved cell indices per query cell.
cells = []

# One dict per cell, handed to similarity_search and replaced by the list it
# returns.
# NOTE(review): a fresh batch of dicts is appended once per sequencing method
# (when j == 0), so this list grows on every outer iteration - confirm that
# similarity_search expects this layout.
ranking = []

for k, sequencing_method in enumerate(sequencing_methods):
    cells.append([])
    for j, method in enumerate(methods):
        cells[k].append([])
        if method == 'cellfish.jl':
            # Neighbours for this method are read from a TSV file (columns
            # n1..nK, one row per query) rather than searched for here, so
            # this branch never touches `ranking`.
            data = pd.read_csv(f'cellfishing/neighbors_index_{k+1}_100.tsv', sep='\t')
            neighbor_columns = [f'n{i}' for i in range(1, args.retrieved_for_each_cell + 1)]
            # Vectorised equivalent of casting every cell with int() row by row.
            indexes_by_query = data[neighbor_columns].astype(int).values.tolist()
        else:
            print(f'====== {method.capitalize()} ======')
            # PCA embeddings are stored per sequencing method; every other
            # method has a single embeddings file.
            if method == 'pca':
                input_path = f'pca/method_{k+1}.npy'
            else:
                input_path = input_embeddings_path[j]

            start = time.time()
            embeddings = np.load(input_path)
            end = time.time()
            print("Loading time: ", end - start)

            embedding_dim = embeddings.shape[1]

            start = time.time()
            if j == 0:
                # Only the first embedding method allocates the ranking dicts
                # for this sequencing method.
                if USE_QUERY_CELL_COUNT_LIMITER:
                    limit = REFERENCE_CELL_COUNT
                else:
                    limit = len(sequencing)
                ranking.extend(dict() for _ in range(limit))

            # Queries are the cells of the current sequencing method; the
            # reference pool is every other cell.
            matching_indices = np.where(sequencing == sequencing_method)[0]
            non_matching_indices = np.where(sequencing != sequencing_method)[0]

            # Appending onto an empty default-dtype array is kept deliberately
            # so the resulting dtypes are the same as before this rewrite.
            query_cells = np.append(np.empty(shape=[0, embedding_dim]),
                                    embeddings[matching_indices].reshape(-1, embedding_dim),
                                    axis=0)
            typesA = np.append(np.empty(shape=[0, 0]), types[matching_indices])
            retreive_cells = np.append(np.empty(shape=[0, embedding_dim]),
                                       embeddings[non_matching_indices].reshape(-1, embedding_dim),
                                       axis=0)
            typesB = np.append(np.empty(shape=[0, 0]), types[non_matching_indices])
            end = time.time()
            print("Selecting queries: ", end - start)

            if USE_QUERY_CELL_COUNT_LIMITER:
                # Truncate both sets to the configured sizes.
                query_cells = query_cells[:QUERY_CELL_COUNT]
                typesA = typesA[:QUERY_CELL_COUNT]
                retreive_cells = retreive_cells[:REFERENCE_CELL_COUNT]
                typesB = typesB[:REFERENCE_CELL_COUNT]
                print('Query cells: ', len(query_cells))
                print('Reference cells: ', len(retreive_cells))
            else:
                print('FULL Query and Reference')

            start = time.time()
            indices_ot_mapping = None
            query_cells = Normalization_A(query_cells)
            retreive_cells = Normalization_A(retreive_cells)
            end = time.time()
            print("Normalization: ", end - start)

            indexA = 0
            indexB = 0
            # Only indexes_by_query and ranking are used after this call; the
            # other return values are kept under their original names.
            (actual, predicted, types_unique_set, indexes,
             indexes_by_query, ranking) = similarity_search(
                args, retreive_cells, embeddings, query_cells, types, typesA, typesB,
                indexA, indexB, mapping, ranking, indices_ot_mapping)
        cells[k][j] = indexes_by_query
# Accumulators, all indexed first by sequencing method:
#   match_matrix[k, i, j]  - mean intersection_union score between the
#                            retrievals of embedding methods i and j
#   cell_vote[k, i]        - mean score of method i against the consensus list
#   average_overlap[k, i]  - mean of match_matrix[k, i, j] over all j != i
n_sequencing = len(sequencing_methods)
n_methods = len(cells[0])
match_matrix = np.zeros((n_sequencing, n_methods, n_methods))
cell_vote = np.zeros((n_sequencing, n_methods))
average_overlap = np.zeros((n_sequencing, n_methods))
top_k = args.retrieved_for_each_cell

for k in range(n_sequencing):
    arrays = list(cells[k])
    # The first method's query count is used for every method.
    query_count = len(arrays[0])

    # Consensus retrieval per query: the top_k keys of ranking[q], ordered by
    # descending value.
    # NOTE(review): ranking is indexed from 0 for every sequencing method k -
    # confirm this matches how similarity_search fills it.
    S_best = [
        sorted(ranking[q], key=ranking[q].get, reverse=True)[:top_k]
        for q in range(query_count)
    ]

    # Pairwise agreement between every pair of embedding methods.
    for i in range(n_methods):
        for j in range(n_methods):
            for q in range(query_count):
                # Best effort: a query that cannot be scored is reported and
                # contributes nothing to the sum.
                try:
                    match_matrix[k, i, j] += intersection_union(arrays[i][q], arrays[j][q])
                except Exception as e:
                    print(e)

            match_matrix[k, i, j] /= query_count
            if i != j:
                # Average over the other M - 1 methods (self-overlap excluded).
                average_overlap[k, i] += float(match_matrix[k, i, j]) / (M - 1.0)

    # Agreement of each method with the consensus list of the SAME query.
    # This used to index S_best with the method index instead of the query
    # index, comparing every query against an unrelated consensus list.
    for i in range(n_methods):
        for q in range(query_count):
            cell_vote[k, i] += intersection_union(arrays[i][q], S_best[q])
        cell_vote[k, i] /= query_count

    # k indexes sequencing methods, so the section headers name the sequencing
    # method. They used to show an unrelated embedding-method label.
    print(f'==========AVERAGE VOTE {sequencing_methods[k]}==========')
    for label, overlap in zip(array_labels, average_overlap[k]):
        print(f"{label: <15}: {overlap:.2f}")

    print(f'==========CELL VOTE {sequencing_methods[k]}=============')
    for label, overlap in zip(array_labels, cell_vote[k]):
        print(f"{label: <15}: {overlap:.2f}")

# Average the pairwise overlap matrices over all sequencing methods and store
# the result (in the original method order) next to the figure.
ratios = np.mean(match_matrix, axis=0)
print(ratios)
output_stem = f'overlap_{retreived_count}_retrieved_cells_{args.n_clusters}_clusters_leave_one_out'
np.save(f'{output_stem}.npy', ratios)

# Co-cluster rows and columns, then sort each axis by its cluster label so
# members of the same cluster end up adjacent in the heatmap.
model = SpectralCoclustering(n_clusters=int(args.n_clusters), random_state=0)
model.fit(ratios)
row_order = np.argsort(model.row_labels_)
col_order = np.argsort(model.column_labels_)
ratios_reordered = ratios[row_order][:, col_order]

# Rows and columns are permuted independently, so each axis needs the labels
# in its own order. Reusing the row-ordered labels for the columns mislabels
# them whenever the two permutations differ.
array_labels_reordered = np.array(array_labels)[row_order]
column_labels_reordered = np.array(array_labels)[col_order]

ratios_df = pd.DataFrame(ratios_reordered, index=array_labels_reordered,
                         columns=column_labels_reordered)

fig, ax = plt.subplots(figsize=(14, 12))
sns.heatmap(ratios_df, square=True, cmap='viridis', annot=True, fmt='.2f',
            xticklabels=column_labels_reordered, yticklabels=array_labels_reordered,
            vmin=0, vmax=1, annot_kws={"size": 17}, ax=ax)
ax.set_xticklabels(ax.get_xticklabels(), fontsize=20)
ax.set_yticklabels(ax.get_yticklabels(), fontsize=20)
ax.tick_params(axis='x', rotation=90)
ax.tick_params(axis='y', rotation=0)

plt.tight_layout()
plt.savefig(f'{output_stem}.png')
plt.show()