#from pandas.core.arrays.categorical import factorize_from_iterable
import torch
import torch_geometric.utils as utils
from scipy.optimize import linear_sum_assignment
import scipy.sparse as sp
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
import os 
import pickle
import copy

import numpy as np

class RBSCLUSTER():
    def __init__(self, args, qvectors=None):
        self.args = args
        self.threshold = 0.5
        self.flag_cluster_type = 2 # 1: cdhit, 2: K-means # clustering type
        self.n_cluster    = self.args.n_cluster
        self.flag_reduced = False
        self.sim_type     = 4
        self.default_node = self.args.n_codes

        self.qvectors = qvectors
        self.qvectors0 = copy.deepcopy(qvectors)
        
    def forward(self, sequences, prev_centroid=None, qvectors=None):
        """Cluster ``sequences`` by delegating to ``f_cluster``.

        Args:
            sequences: Passed through unchanged to ``f_cluster``.
            prev_centroid: Optional centroids forwarded to ``f_cluster``.
            qvectors: Forwarded to ``f_cluster`` as ``cur_qvectors``.

        Returns:
            The five values produced by ``f_cluster``, as a tuple:
            ``(clusters, sequences, cluster_labels, centroids, meanVq)``.
        """
        outputs = self.f_cluster(
            sequences,
            prev_centroid=prev_centroid,
            cur_qvectors=qvectors,
        )
        groups, seqs, labels, centers, mean_vectors = outputs
        return groups, seqs, labels, centers, mean_vectors
    
    # Cluster sequences with K-means on the mean of their code vectors
    def f_cluster(self, sequences, prev_centroid=None, cur_qvectors=None):        

        if prev_centroid is not None: 
            kmeans = KMeans(n_clusters=self.n_cluster, init=prev_centroid, n_init=1)
        else:
            kmeans = KMeans(n_clusters=self.n_cluster, init='k-means++')
        self.qvectors = cur_qvectors
        valid_idx = (sequences != self.default_node).cpu().numpy()
        input_sequences = sequences
        meanVq = []                 
        for k in range(len(input_sequences)):
            ide = sum(valid_idx[k]).item()
            if ide ==0: ide=1 # to prevent anomaly
            ndx = input_sequences[k,:ide].cpu().numpy().astype(np.int32) 
                
            meanVq.append(np.mean(self.qvectors[ndx,:], axis=0))
        meanVq = np.array(meanVq)    

        cluster_labels = kmeans.fit_predict(meanVq) 

        # Group sequences by cluster
        clusters = {}
        for seq_idx, cluster_id in enumerate(cluster_labels):
            if cluster_id not in clusters:
                clusters[cluster_id] = []
            clusters[cluster_id].append(seq_idx)
    
        sim_centroids = kmeans.cluster_centers_
        return clusters, input_sequences, cluster_labels, sim_centroids, meanVq
    
    def sequence_reduction(self, sequences):
        num_sequences = sequences.size(0)
        max_seq_len   = sequences.size(1)
        #output_seq = []
   
        max_len = 0
        #output_seq_th = torch.zeros((num_sequences, max_seq_len))
        output_seq_th = self.default_node*torch.ones((num_sequences, max_seq_len))
        
        for i in range(num_sequences):
            input_seq = sequences[i]    
            idx = (torch.nonzero( input_seq[:-1] - input_seq[1:] ) + 1).squeeze(-1) 
            reduced_seq = torch.cat( (input_seq[[0]], input_seq[idx]) , dim=0)
            #output_seq.append(reduced_seq)
            output_seq_th[i,:len(reduced_seq)] = reduced_seq
            max_len = max(max_len, len(reduced_seq))
    
        # idx = (torch.nonzero( sequences[:,:-1] - sequences[:,1:] ) + 1).squeeze(-1) 
        # output_seq = torch.cat( (input_seq[[0]], input_seq[idx]) , dim=0)
        output_seq_th = output_seq_th[:,:max_len]
        return output_seq_th    

    def match_vectors_brute_force(self, vector1, vector2, similarity_matrix):
        k = len(vector1)
        matching = []

        # Loop through each element in vector1
        for i in range(k):
            max_similarity = -np.inf
            matched_index = -1
        
            # Compare with each element in vector2
            for j in range(k):
                similarity = similarity_matrix[i][j]
            
                # If similarity is higher than previous max, update
                if similarity > max_similarity:
                    max_similarity = similarity
                    matched_index = j
        
            # Add matched index to the matching list
            matching.append((i, matched_index))

        return matching
        
    def match_vectors_Hungarian(self, seq1, seq2, match_type=None):
        
        # compute similairty matrix
        if match_type is None: # seq        
            similarity_matrix = self.sequence_similarity(seq1, seq2).detach().cpu().numpy()
        else: # Euclidean distance
            #similarity_matrix = torch.matmul(seq1, seq2.t()) / (seq1.norm(dim=1)[:, None] * seq2.norm(dim=1)).detach().cpu().numpy()
            if np.shape(seq1)[1] != np.shape(seq2)[1]: # if size is different -> add zero_pad
                if np.shape(seq1)[1] > np.shape(seq2)[1]:
                    diff_len = np.shape(seq1)[1] - np.shape(seq2)[1]
                    seq2 = np.concatenate((seq2, np.zeros((self.n_cluster, diff_len))), axis=1)                    
                else:
                    diff_len = np.shape(seq2)[1] - np.shape(seq1)[1]
                    seq1 = np.concatenate((seq1, np.zeros((self.n_cluster, diff_len))), axis=1)
                
            similarity_matrix = np.matmul(seq1, seq2.T) / (np.linalg.norm(seq1, axis=1)[:, None] * np.linalg.norm(seq2, axis=1))
    
        # check anomaly of NaN
        if np.any(np.isnan(similarity_matrix)): # there is non-lement
            matching = None
        else:
            # Convert similarity matrix to cost matrix
            cost_matrix = np.max(similarity_matrix) - similarity_matrix
    
            # Apply Hungarian algorithm
            row_indices, col_indices = linear_sum_assignment(cost_matrix)
    
            # Extract matching pairs
            matching = list(zip(row_indices, col_indices))
    
        return matching

    def rearrange_indices(self, seq_cur, seq_prev):
        # compute matching pair         
        matching = self.match_vectors_Hungarian(seq_prev, seq_cur)
        
        # Get the indices from seq1 that correspond to the matched elements in seq2
        rearranged_indices = [matching[i][1] for i in range(len(matching))]
    
        # Rearrange seq1 according to the rearranged indices
        rearranged_seq1 = [seq_cur[i] for i in rearranged_indices]
    
        return torch.stack(rearranged_seq1, dim=0)
    
    def rearrange_label_indices(self, seq_cur, seq_prev, seq_labels):
        # compute matching pair   

        # matching = self.match_vectors_Hungarian(seq_prev, seq_cur, "Euclidean") # 24.08.07 matching error 
        matching = self.match_vectors_Hungarian(seq_cur, seq_prev, "Euclidean")
        
        # Get the indices from seq1 that correspond to the matched elements in seq2
        rearranged_indices = copy.deepcopy(seq_labels) # if non --> return original indices
        if matching is not None:
            for i in range(len(matching)):
                target_indices = np.equal(seq_labels,i)            
                rearranged_indices[target_indices] = matching[i][1]
        
        return rearranged_indices
    
    def compute_centroid_topK(self, clusters, sequences):
        n_cluster = len(clusters)

        centroids = []
        centroid_indices =[]
        valid_indices = []
        existing_cluster_indices = list(clusters.keys())
        
        for k in range(n_cluster):
            #cur_sequence_set = sequences[clusters[k],:]
            cur_sequence_set = sequences[clusters[existing_cluster_indices[k]],:]

            # input sequence reduction
            if self.flag_reduced:
                input_sequences = self.sequence_reduction(cur_sequence_set)
            else:
                input_sequences = cur_sequence_set

            similarity_matrix_np = self.pairwise_similarity(input_sequences).cpu().numpy() # based on binary similarity     
            #mean_similarity = np.mean(similarity_matrix_np, axis=1)
            #max_idx = np.argmax(mean_similarity)
            sum_similarity = np.sum(similarity_matrix_np, axis=1)
            
            # Sort indices in descending order to get top-10 indices
            top_indices = (np.argsort(sum_similarity)[::-1][:self.args.num_centroid_sample]).astype(np.int32)
                
            # Store the indices of centroids
            centroid_indices.append(top_indices)

            # Append corresponding sequences to centroids list
            centroids.extend(cur_sequence_set[top_indices])

        # Convert centroids and indices to torch tensors
        centroids = torch.stack(centroids, dim=0)       # [ n_cluster x n_sample, n_dim ]
        #centroids = centroids.reshape(self.args.n_cluster, self.args.num_centroid_sample, -1)
        centroid_indices = np.array(centroid_indices)   # [ n_cluster x n_sample ]

        return centroids, centroid_indices

    def predict_label_topK(self, input_sequences, input_centroid):    
        # Compute similarity based on the specified similarity type
        
        #.. size check
        len_seq = input_sequences.size(0)
        #.. concatenation
        concatenated_seq = torch.concat((input_sequences, input_centroid), dim=0 )

        #.. conduct sequence reduciton
        if self.flag_reduced == True:
            reduced_concatenated_seq = self.sequence_reduction(concatenated_seq)
            sequences  = reduced_concatenated_seq[:len_seq,:]
            centroid   = reduced_concatenated_seq[len_seq:,:]
        else:
            sequences = input_sequences
            centroid  = input_centroid
            
        # Convert sequence2 to a set (since it's fixed)
        n_clst   = self.args.n_cluster
        n_sample = self.args.num_centroid_sample
        #total_common_elements = []
            
        # Initialize a list to store the count of common elements for each row of sequence1
        common_elements_per_row = []
        set2 = {}
        for k in range(n_clst):                
            sequence2 = centroid[n_sample*k:n_sample*(k+1),:]                
            set2[k] = set(np.unique(sequence2.flatten()))

        # Iterate over each row of sequence1
        for row in sequences:
            # Convert the current row of sequence1 to a set
            set1_row = set(np.unique(row))
            # Find the common elements between the current row of sequence1 and sequence2
            
            num_common = np.zeros(n_clst, dtype=np.int32)
            for k in range(n_clst):                
                # sequence2 = centroid[n_sample*k:n_sample*(k+1),:]                
                # set2 = set(sequence2.flatten())
                common_elements_row = set1_row.intersection(set2[k])
                num_common[k] = len(common_elements_row)
                
            # Append the count of common elements for the current row to the list
            common_elements_per_row.append(num_common.tolist())
            
        similarity_th_sum = torch.tensor(common_elements_per_row)

        # Get the index of the centroid with the highest similarity for each sequence
        max_similarity_indices = torch.argmax(similarity_th_sum, dim=1) 
    
        return max_similarity_indices

    def compute_centroid(self, clusters, sequences):
        n_cluster = len(clusters)

        centroids = []
        buf_max_idx = np.zeros(n_cluster, dtype=np.int32)
        
        for k in range(n_cluster):
            cur_sequence_set = sequences[clusters[k],:]

            # input sequence reduction
            if self.flag_reduced:
                input_sequences = self.sequence_reduction(cur_sequence_set)
            else:
                input_sequences = cur_sequence_set

            similarity_matrix_np = self.pairwise_similarity(input_sequences).cpu().numpy() # based on binary similarity     
            #mean_similarity = np.mean(similarity_matrix_np, axis=1)
            #max_idx = np.argmax(mean_similarity)
            sum_similarity = np.sum(similarity_matrix_np, axis=1)
            max_idx = np.argmax(sum_similarity)
            buf_max_idx[k] = int(max_idx)
            centroids.append(cur_sequence_set[max_idx]) # before sequence reduction
            
        #.. make array    
        centroids = torch.stack(centroids, dim=0)    
        
        return centroids, buf_max_idx 