from sklearn.cluster import KMeans, SpectralClustering
from sklearn.metrics import pairwise_distances
from sklearn_extra.cluster import KMedoids
from sentence_transformers import SentenceTransformer
import numpy as np
from FlagEmbedding import BGEM3FlagModel


def kmedoid_clustering(chunks, embedder, n_clusters, ablation_type=None):
    """Embed every chunk individually and cluster the resulting vectors.

    Args:
        chunks: Iterable of items passed one at a time to ``embedder.encode``.
        embedder: Object exposing ``encode``. How it is called depends on
            ``ablation_type`` (see below).
        n_clusters: Number of clusters handed to the clustering estimator.
        ablation_type: Optional switch controlling embedding and clustering:
            * contains ``'norm_cluster2'``: encode with
              ``normalize_embeddings=True``;
            * contains ``'lexical'``, ``'multivec'`` or ``'hybrid_search'``:
              take the ``'dense_vecs'`` entry of the encode output;
            * contains ``'norm_cluster'``: divide the embedding by its norm;
            * equals ``'kmeans'`` / ``'spectral'``: use ``KMeans`` /
              ``SpectralClustering`` instead of ``KMedoids``.

    Returns:
        Tuple ``(cluster_labels, cluster_medoids, chunk_embeddings)``.
        ``cluster_medoids`` is ``None`` for the ``'kmeans'`` and
        ``'spectral'`` ablations, which expose no medoid indices.
    """
    # An empty tuple makes every membership test below False when no
    # ablation is requested, without special-casing None at each check.
    flags = () if ablation_type is None else ablation_type

    def embed(text):
        # 'norm_cluster2' is tested first because the substring
        # 'norm_cluster' would match it as well.
        if 'norm_cluster2' in flags:
            return embedder.encode(text, normalize_embeddings=True)
        if any(tag in flags for tag in ('lexical', 'multivec', 'hybrid_search')):
            return embedder.encode(text)['dense_vecs']
        vector = embedder.encode(text)
        if 'norm_cluster' in flags:
            vector = vector / np.linalg.norm(vector)
        return vector

    chunk_embeddings = [embed(chunk) for chunk in chunks]

    uses_medoids = ablation_type not in ('kmeans', 'spectral')
    if ablation_type == 'kmeans':
        estimator = KMeans(n_clusters=n_clusters)
    elif ablation_type == 'spectral':
        estimator = SpectralClustering(n_clusters=n_clusters)
    else:
        # Original author's note: 'pam' is more accurate than the alternative.
        estimator = KMedoids(n_clusters=n_clusters, method='pam')

    # The estimator is fitted directly on the embeddings.
    estimator.fit(chunk_embeddings)

    labels = estimator.labels_
    medoids = estimator.medoid_indices_ if uses_medoids else None
    return labels, medoids, chunk_embeddings



def get_similarities(chunks, query, embedding_model):
    """Return the pairwise similarity matrix of the chunks (plus the query).

    Args:
        chunks: List of texts to embed.
        query: Optional extra text. When not ``None`` it is appended after
            the chunks, so it occupies the last row/column of the result.
        embedding_model: Object exposing ``encode`` and ``similarity``.

    Returns:
        Whatever ``embedding_model.similarity`` yields when the embeddings
        are compared against themselves.
    """
    # Concatenation builds a new list, so the caller's ``chunks`` is untouched.
    texts = chunks if query is None else chunks + [query]
    vectors = embedding_model.encode(texts)
    return embedding_model.similarity(vectors, vectors)


def construct_subgraphs(chunks, query, embedding_model, goa_cluster_size, is_batch=False, search_method='greedy', ablation_type='None'):
    """
    Split chunks into clusters and optionally order each cluster greedily.

    Args:
        chunks: List of text chunks to be clustered.
        query: Optional query string. When given, it is the anchor from which
            the greedy ordering of every cluster starts.
        embedding_model: Model forwarded to ``kmedoid_clustering`` and
            ``get_similarities``.
        goa_cluster_size: Requested number of clusters. Converted with
            ``int()`` and capped at ``len(chunks)``; values of 1 or less put
            all chunks in a single cluster without running any clustering.
        is_batch: Must be falsy; batch processing is not implemented.
        search_method: ``'greedy'`` to order the chunks of each cluster, or
            ``'cluster_only'`` to return the clusters unordered.
        ablation_type: Forwarded unchanged to ``kmedoid_clustering``.

    Returns:
        Tuple ``(result, actual_num_cluster)``. ``result`` is a flat list of
        chunks when there is exactly one cluster, otherwise a list with one
        list of chunks per cluster. When medoids are available, the medoid is
        the first chunk of its cluster before any ordering is applied.

    Raises:
        NotImplementedError: If ``is_batch`` is truthy, or ``search_method``
            is neither ``'greedy'`` nor ``'cluster_only'``.
    """
    if is_batch:
        raise NotImplementedError("Batch processing is not supported for this function.")

    num_cluster = min(int(goa_cluster_size), len(chunks))

    cluster_chunks = []
    if num_cluster > 1:
        cluster_idx, cluster_medoids, _ = kmedoid_clustering(chunks, embedding_model, num_cluster, ablation_type)
        print(f"Cluster indices: {cluster_idx}")
        print(f"Cluster medoids: {cluster_medoids}")

        for i in range(num_cluster):
            if cluster_medoids is None:
                allocated = [chunk for chunk, label in zip(chunks, cluster_idx) if label == i]
            else:
                medoid = cluster_medoids[i]
                # Put the medoid at the front, followed by the other members.
                allocated = [chunks[medoid]] + [
                    chunks[j] for j, label in enumerate(cluster_idx)
                    if label == i and j != medoid
                ]
            if allocated:
                cluster_chunks.append(allocated)
        # Count the clusters actually built so the reported number always
        # matches the length of the returned list.
        actual_num_cluster = len(cluster_chunks)
    else:
        actual_num_cluster = 1
        cluster_chunks.append(chunks)

    if search_method == 'cluster_only':
        # Pure clustering: no query-based ordering is needed.
        if actual_num_cluster == 1:
            return cluster_chunks[0], actual_num_cluster
        return cluster_chunks, actual_num_cluster

    # Reject unsupported methods before paying for any embedding work.
    if search_method != 'greedy':
        raise NotImplementedError("Currently only greedy search is supported")

    ## This part is used for non-contextual search, which predetermines the search order
    cluster_ordered_chunks = []
    for cs in cluster_chunks:
        if len(cs) == 0:
            # Nothing to embed or order.
            cluster_ordered_chunks.append([])
            continue

        dist = get_similarities(cs, query, embedding_model)
        ordered_idxs = greedy_search(dist)
        if query is None:
            # greedy_search starts from the last row and leaves that start
            # node out of its path. With a query the last row is the query
            # itself, but without one it is a real chunk, so it is restored
            # as the first element instead of being silently dropped.
            ordered_idxs = [len(cs) - 1] + ordered_idxs

        cluster_ordered_chunks.append(arg_sort_chunk(ordered_idxs, cs))

    if actual_num_cluster == 1:
        # Only one cluster: return its ordered chunks directly.
        return cluster_ordered_chunks[0], actual_num_cluster
    # Multiple clusters: one ordered list per cluster.
    return cluster_ordered_chunks, actual_num_cluster


def greedy_search(similarity_matrix):
    """Perform a greedy walk through the similarity matrix.

    The walk starts at the last node and repeatedly moves to the unvisited
    node with the highest similarity to the current one. Ties go to the
    lowest index. If no remaining score compares greater than ``-inf``
    (for example all NaN), the lowest unvisited index is taken so the path
    always consists of valid indices.

    Args:
        similarity_matrix: A square matrix of similarities between nodes,
            indexable as ``similarity_matrix[row][col]``.

    Returns:
        List of node indices in visiting order. The start node (the last
        index) is not included, so the list has ``len(similarity_matrix) - 1``
        entries; an empty or 1x1 matrix yields ``[]``.
    """
    size = len(similarity_matrix)
    if size == 0:
        return []

    current_node = size - 1
    # Ascending list of nodes still to visit; scanning only these avoids a
    # membership test against the visited nodes for every candidate.
    unvisited = list(range(size - 1))
    path = []

    while unvisited:
        row = similarity_matrix[current_node]
        next_node = None
        max_similarity_score = -float('inf')

        for i in unvisited:
            if row[i] > max_similarity_score:
                max_similarity_score = row[i]
                next_node = i

        if next_node is None:
            # No score compared greater than -inf; fall back to the lowest
            # unvisited index rather than recording None in the path.
            next_node = unvisited[0]

        unvisited.remove(next_node)
        path.append(next_node)
        current_node = next_node

    return path


def arg_sort_chunk(idx_seq, chunks):
    """Return the elements of ``chunks`` picked in the order given by ``idx_seq``.

    Args:
        idx_seq: Iterable of indices into ``chunks``.
        chunks: Indexable collection of chunks.

    Returns:
        New list with ``chunks[i]`` for every ``i`` in ``idx_seq``.
    """
    return [chunks[position] for position in idx_seq]


def get_embedding_model(name, device, ablation_type=None):
    """Load the ``BAAI/bge-m3`` SentenceTransformer on the given device.

    Args:
        name: Accepted but not used; the checkpoint name is fixed.
        device: Passed to ``SentenceTransformer`` as ``device``.
        ablation_type: Accepted but not used.

    Returns:
        The constructed ``SentenceTransformer`` instance.
    """
    return SentenceTransformer('BAAI/bge-m3', device=device)
