import numpy as np
from sklearn.metrics import confusion_matrix, accuracy_score
from munkres import Munkres
from sklearn.preprocessing import normalize
from scipy.sparse.linalg import svds
from sklearn.cluster import SpectralClustering

class Utils:
    """Helpers for subspace clustering built around a coefficient matrix.

    The instance only stores ``num_clusters``, which is used by
    :meth:`spectral_clustering`, :meth:`get_one_hot` and as the default
    cluster count for :meth:`refine_labels` / :meth:`clustering_error`.
    """

    def __init__(self, num_clusters=2):
        self.num_clusters = num_clusters

    def thrC(self, C, ro):
        """Keep, per column, the largest-magnitude entries of ``C``.

        For every column the entries are visited in order of decreasing
        absolute value and kept until their accumulated magnitude strictly
        exceeds ``ro`` times the column's L1 norm; all other entries are
        zeroed.

        Args:
            C: 2-D coefficient matrix. The result is allocated as
                ``(C.shape[1], C.shape[1])``, so ``C`` is expected to be
                square.
            ro: Fraction of each column's L1 norm to retain. When ``ro`` is
                not below 1 no thresholding happens.

        Returns:
            A new float array with the thresholded coefficients, or ``C``
            itself (not a copy) when ``ro`` is not below 1.
        """
        if ro < 1:
            N = C.shape[1]
            Cp = np.zeros((N, N))
            S = np.abs(np.sort(-np.abs(C), axis=0))
            Ind = np.argsort(-np.abs(C), axis=0)
            n_rows = S.shape[0]
            for i in range(N):
                cL1 = np.sum(S[:, i]).astype(float)
                threshold = ro * cL1
                csum = 0
                # Bounded scan: a column whose cumulative magnitude never
                # exceeds the threshold (e.g. an all-zero column) is left as
                # zeros instead of indexing past the last row.
                for t in range(n_rows):
                    csum = csum + S[t, i]
                    if csum > threshold:
                        keep = Ind[:t + 1, i]
                        Cp[keep, i] = C[keep, i]
                        break
        else:
            Cp = C
        return Cp

    def post_proC(self, C, K, d, alpha, with_diag):
        """Turn a coefficient matrix into an affinity matrix and cluster it.

        Args:
            C: Square coefficient matrix.
            K: Number of clusters; only used here to choose the SVD rank
                (the clustering itself uses ``self.num_clusters``).
            d: Dimension of each subspace.
            alpha: Exponent applied element-wise to the affinity.
            with_diag: When true the diagonal of the symmetrised matrix is
                replaced by ones, otherwise by zeros.

        Returns:
            Tuple ``(labels, L)`` where ``labels`` comes from
            :meth:`spectral_clustering` and ``L`` is the symmetric affinity
            matrix scaled by its maximum.
        """
        n = C.shape[0]
        A = 0.5 * (np.abs(C) + np.abs(C.T))
        A = A - np.diag(np.diag(A))
        if with_diag:
            # for sparse A, this step will make the algorithm more numerically stable
            A = A + np.eye(n, n)
        # svds needs a rank strictly below the matrix size, so clamp it.
        r = min(d * K + 1, n - 1)
        U, S, _ = svds(A, r, v0=np.ones(n))
        # svds returns singular values in ascending order; flip to descending.
        U = U[:, ::-1]
        S = np.sqrt(S[::-1])
        S = np.diag(S)
        U = U.dot(S)
        U = normalize(U, norm='l2', axis=1)
        Z = U.dot(U.T)
        Z = Z * (Z > 0)  # drop negative similarities
        L = np.abs(Z ** alpha)
        L = L / L.max()
        L = 0.5 * (L + L.T)
        labels = self.spectral_clustering(L)
        return labels, L

    def map_by_hungarian(self, data_now, data_last):
        """Relabel ``data_now`` so that it agrees best with ``data_last``.

        A confusion matrix between the two labelings is built and the
        assignment maximising the matched counts is found with the Hungarian
        algorithm (Munkres).

        Args:
            data_now: Labels to be renamed.
            data_last: Reference labels, same length as ``data_now``.

        Returns:
            ``int32`` array holding ``data_now`` expressed with the matched
            label values.
        """
        # Fix the label order explicitly so matrix indices can be translated
        # back to label values; the assignment works on indices, not labels.
        labels = np.union1d(data_now, data_last)
        conf_mat = confusion_matrix(data_now, data_last, labels=labels)
        cost = (conf_mat.max() - conf_mat).tolist()
        label_values = labels.tolist()
        mappings = {
            label_values[row]: label_values[col]
            for row, col in Munkres().compute(cost)
        }
        # Best effort: a label without an assignment is passed through as is.
        data_mapped = [mappings.get(label, label) for label in data_now]
        return np.asarray(data_mapped, dtype=np.int32)

    def spectral_clustering(self, affinity_matrix):
        """Cluster a precomputed affinity matrix into ``self.num_clusters`` groups.

        Args:
            affinity_matrix: Square affinity matrix handed to scikit-learn's
                ``SpectralClustering`` with ``affinity="precomputed"``.

        Returns:
            ``int32`` array of cluster labels.
        """
        clustering = SpectralClustering(
            n_clusters=self.num_clusters,
            eigen_solver="arpack",
            random_state=42,
            affinity="precomputed",
            assign_labels="kmeans",
            n_jobs=-1,
        ).fit(affinity_matrix)
        return np.asarray(clustering.labels_, dtype=np.int32)

    def refine_labels(self, weights, last_labels, ro=0.5, d=12, alpha=8):
        """Cluster ``weights`` and align the result with ``last_labels``.

        Args:
            weights: Coefficient matrix.
            last_labels: Reference labels used to name the new clusters.
            ro: Threshold passed to :meth:`thrC`.
            d: Subspace dimension passed to :meth:`post_proC`.
            alpha: Affinity exponent passed to :meth:`post_proC`.

        Returns:
            ``int32`` array of labels matched to ``last_labels``.
        """
        c_thr = self.thrC(weights, ro)
        cluster_labels, _ = self.post_proC(
            c_thr, self.num_clusters, d, alpha, with_diag=True)
        return self.map_by_hungarian(cluster_labels, last_labels)

    def get_one_hot(self, targets):
        """One-hot encode integer class indices.

        Args:
            targets: Array-like of integer indices in
                ``[0, self.num_clusters)``; any shape.

        Returns:
            ``uint8`` array of shape ``targets.shape + (self.num_clusters,)``.
        """
        # Convert once so lists and tuples work as well as arrays.
        targets = np.asarray(targets)
        res = np.eye(self.num_clusters)[targets.reshape(-1)]
        res = res.reshape(list(targets.shape) + [self.num_clusters])
        return np.asarray(res, dtype=np.uint8)

    def clustering_error(self, c_matrix, labels_orig, d=12, alpha=8):
        """Return the clustering error (1 - accuracy) of ``c_matrix``.

        Args:
            c_matrix: Coefficient matrix to cluster.
            labels_orig: Ground-truth labels.
            d: Subspace dimension passed to :meth:`post_proC`.
            alpha: Affinity exponent passed to :meth:`post_proC`.

        Returns:
            Fraction of samples whose matched label differs from
            ``labels_orig``.
        """
        labels, _ = self.post_proC(
            c_matrix, self.num_clusters, d, alpha, with_diag=True)
        labels_ref = self.map_by_hungarian(labels, labels_orig)
        return 1 - accuracy_score(labels_orig, labels_ref)

    def balanced_sampling(self, data, classes, n_per_class, chunk_num, random_state=None):
        """Draw one chunk with up to ``n_per_class`` samples of every class.

        The data is shuffled, then for each class the ``chunk_num``-th slice
        of ``n_per_class`` samples (in shuffled order) is taken.

        Args:
            data: Array of samples, indexed along the first axis.
            classes: Array of class labels aligned with ``data``.
            n_per_class: Maximum number of samples per class in the chunk.
            chunk_num: Zero-based index of the chunk to extract.
            random_state: Optional seed. When given (``0`` included) the
                global NumPy random state is seeded with it.

        Returns:
            Tuple ``(X_balanced, y_balanced)`` grouped by ascending class.
        """
        # `is not None` so that a seed of 0 is honoured too.
        if random_state is not None:
            np.random.seed(random_state)
        idx = np.random.permutation(data.shape[0])
        x, y = data[idx], classes[idx]
        balanced_indices = []
        start = chunk_num * n_per_class
        stop = (chunk_num + 1) * n_per_class
        for label in np.unique(y):
            balanced_indices.extend(np.where(y == label)[0][start:stop])
        balanced_indices = np.asarray(balanced_indices, dtype=np.intp)
        return x[balanced_indices], y[balanced_indices]

    def sort_dataset(self, data_in, labels_in):
        """Reorder samples so that equal labels are contiguous.

        Groups appear in ascending label order and samples keep their
        relative order inside a group. For label arrays with more than one
        dimension the first column is the grouping key.

        Args:
            data_in: Array of samples, indexed along the first axis.
            labels_in: Array of labels aligned with ``data_in``.

        Returns:
            Tuple ``(dataset_new, labels_new)``; ``dataset_new`` is a float
            array shaped like ``data_in`` and ``labels_new`` a ``uint8``
            array shaped like ``labels_in``.
        """
        dataset_new = np.empty(data_in.shape)
        labels_new = np.empty(labels_in.shape, dtype=np.uint8)
        # Count on the same key that selects the rows; counting every element
        # of a multi-column label array would not match the row counts.
        keys = labels_in if len(labels_in.shape) == 1 else labels_in[:, 0]
        group, counts = np.unique(keys, return_counts=True)
        counter = 0
        for g, c in zip(group, counts):
            inds = keys == g
            dataset_new[counter:counter + c] = data_in[inds]
            labels_new[counter:counter + c] = labels_in[inds]
            counter += c
        return dataset_new, labels_new
