'''
Implementation of the centroid graph technique. It finds, at the same time,
the layers belonging to each group and the blocks of each group. Here it is used
as an initialization technique for the MSBM. It can also be used on its own to compute
Multi-Group Community Detection in multiplex graphs.
'''
import numpy as np 
from sklearn.cluster import KMeans
import random as rd
from itertools import product


def euclidean_distance(x1:np.array,
                       x2:np.array):
    '''
    Return the euclidean distance separating two arrays of the same shape.

    The arrays are subtracted element-wise, every difference is squared, all
    squares are added together (over every axis) and the square root of that
    total is returned.

    Args:
        x1 (np.array): The first array
        x2 (np.array): The second array

    Returns:
        np.float64: The euclidean distance between x1 and x2
    '''
    delta = x1 - x2
    squared_total = np.sum(np.square(delta))
    return np.sqrt(squared_total)


def L_2_norm(M1:np.array,
             M2:np.array):
    '''
    Compute the pairwise squared euclidean distances between the columns of two
    matrices, using the expansion
    ||a-b||^2 = ||a||^2 + ||b||^2 - 2*a.T.dot(b)

    Args:
        M1 (np.array): The first matrix, one vector per column, shape [d, n1]
        M2 (np.array): The second matrix, one vector per column, shape [d, n2]

    Returns:
        np.array : Matrix of shape [n1, n2] whose entry (i, j) is the squared
                   distance between column i of M1 and column j of M2. The main
                   diagonal is always forced to zero.
    '''
    # Squared norm of every column of each input.
    M1_n = (M1 * M1).sum(axis=0)
    M2_n = (M2 * M2).sum(axis=0)
    # Spread the column norms over an [n1, n2] grid; the grid must be [n1, n2]
    # (not [n2, n1]) so that inputs with different column counts are supported.
    ones = np.ones((M1_n.shape[0], M2_n.shape[0]))
    m1m2 = M1.T.dot(M2)
    norm_ = ones * M1_n.reshape(-1, 1) + ones * M2_n.reshape(1, -1) - 2 * m1m2
    # Round-off can leave tiny negative values; the absolute value removes them.
    norm_ = np.abs(norm_)
    # Zero the main diagonal (also valid when the result is not square).
    np.fill_diagonal(norm_, 0)
    return norm_


def Graph_embedding(q0,
                    m=1):
    '''
    Optimization of a vertex representation of the unified Graph

    The rows of q0 are summed, divided by m and shifted so that the entries add
    up to one. If that vector has no negative entry it is returned directly.
    Otherwise a Newton iteration searches for a shift lambda that is a root of
    f(lambda) = sum(max(lambda - p0, 0))/n - lambda, and max(p0 - lambda, 0) is
    returned. The iteration is capped at 101 steps; the returned vector uses the
    lambda of the last evaluated step.

    Args:
        q0 (np.array): Previous representation of the vertex from the unified graph,
                       shape [nb_rows, n]
        m (int): normalization factor, in this case it represents the number of layers

    Returns:
        np.array: The new representation of the vertex that optimize its quadratic function.
                  An empty array is returned when q0 has no column.
    '''
    n = q0.shape[1]
    if n == 0:
        # Nothing to embed (e.g. a vertex without any neighbour); avoids 1/0 below.
        return np.zeros(0)

    col_sum = q0.sum(axis=0)
    p0 = col_sum/m - col_sum.mean()/m + 1/n
    if not p0.min() < 0:
        return p0

    lambda_m = 0
    # At most 101 Newton steps, stopping early once |f| drops to 1e-10 or below.
    for _ in range(101):
        v1 = lambda_m - p0
        posidx = np.argwhere(v1 > 0)
        npos = posidx.shape[0]
        # Derivative of f with respect to lambda; never allowed to be exactly 0.
        g = npos/n - 1
        if g == 0:
            g = np.finfo(float).eps
        f = v1[posidx].sum()/n - lambda_m
        lambda_m = lambda_m - f/g
        if not abs(f) > 1e-10:
            break

    # v1 holds lambda - p0 for the last evaluated lambda; flip and clip at zero.
    v1 = -v1
    v1[v1 < 0] = 0
    return v1
    

def _laplacian_eigh(graph):
    '''
    Eigen-decomposition of the Laplacian of the symmetrised graph, diagonal removed.

    Returns the eigenvalues in ascending order and the matching eigenvectors (as columns).
    '''
    sym = (graph + graph.T)/2
    sym -= np.diag(np.diag(sym))
    laplacian = np.diag(sym.sum(axis=1)) - sym
    # The Laplacian is symmetric: eigh guarantees real, ascending eigenvalues,
    # whereas eig may return complex values when eigenvalues are repeated.
    return np.linalg.eigh(laplacian)


def GCM(data:np.array,
        c:np.int32,
        lambda_val:np.float64 =1,
        iter_N:np.int32 =100):
    '''
    GCM (Graph Centroid Model) computes one Unified graph that represents the multiplex graph which has 'c' components. This technique
    is inspired from [1].

    Each layer is first row-normalised. The unified graph U is initialised as the
    row-normalised mean of the layers, then U, the layer weights and the embedding F
    are updated alternately. After every update lambda_val is doubled when the sum of
    the c smallest Laplacian eigenvalues is above 1e-11, halved (keeping the previous F)
    when the sum of the c+1 smallest ones is below 1e-11, and the loop stops otherwise.

    Args:
        data (np.array): The multiplex graph that is constructed by [nb_layer, nb_vertices,nb_vertices]
        c (np.int32): The constraint of having c component in the unified graph
        lambda_val (np.float64): The lagrangian multiplier, which is used to include the constraint of c
                                 component in the cost function
        iter_N (int32): The maximum number of iterations performed when the algorithm
                        does not reach convergence

    Returns:
        U (np.array): The unified graph from the model
        F (np.array): Embedding of the unified graph, it represents the eigenvectors associated to
                      'c' smallest eigenvalues of the Laplacian matrix of the unified graph
                      (or the previous embedding when the last step halved lambda_val).

    References:
        [1] Wang, Hao & Yang, Yan & Liu, Bing. (2019). GMC: Graph-based Multi-view Clustering.
            IEEE Transactions on Knowledge and Data Engineering. PP. 1-1. 10.1109/TKDE.2019.2903810.

    '''
    zr = 1e-11
    # Row-normalise every layer. Rows summing to zero give NaN, mapped to 0 by nan_to_num.
    data_norm = data.sum(axis=2)
    with np.errstate(divide='ignore', invalid='ignore'):
        data = data / data_norm.reshape(data.shape[0], data.shape[1], 1)
    data = np.nan_to_num(data)
    n = data.shape[1]
    l = data.shape[0]

    # Initial unified graph: mean of the layers, rows rescaled to sum to one.
    # Rows that are entirely zero are left at zero instead of becoming NaN.
    U = data.sum(axis=0)
    U = U/l
    row_sum = U.sum(axis=1).reshape(-1, 1)
    U = np.divide(U, row_sum, out=np.zeros_like(U), where=row_sum != 0)

    eig_val, eig_vec = _laplacian_eigh(U)
    F = eig_vec[:, :c]
    w = np.ones(l) / l

    # For every vertex, the columns that are non-zero in at least one layer.
    # This only depends on data, so it is computed once.
    neighbours = [np.unique(np.where(data[:, i, :] != 0)[1]) for i in range(n)]

    ## Optimization computation
    for _ in range(iter_N):
        # Layer weights: inverse of the distance between each layer and U.
        for i in range(l):
            US = U - data[i,:,:]
            dist_us = np.power(np.linalg.norm(US,'fro'),2)
            if dist_us == 0:
                dist_us = np.finfo(float).eps
            w[i] = 0.5/np.sqrt(dist_us)

        # Squared distances between the embedded vertices (rows of F).
        dist_ = L_2_norm(F.T,F.T)

        # Rebuild U row by row, restricted to each vertex's neighbours.
        U = np.zeros((n,n))
        for i in range(n):
            node_n = neighbours[i]
            if node_n.size == 0:
                # Vertex isolated in every layer: its row of U stays at zero.
                continue
            q = np.zeros((l,node_n.shape[0]))
            for j in range(l):
                S1 = data[j,i,node_n]
                dist1 = dist_[i,node_n]
                lw = l*w[j]
                lmw = lambda_val / lw
                q[j,:] = S1 - 0.5*lmw*dist1
            U[i,node_n] = Graph_embedding(q,l)

        F_old = F
        eig_val, eig_vec = _laplacian_eigh(U)
        F = eig_vec[:, :c]
        fn1 = eig_val[:c].sum()
        fn2 = eig_val[:c+1].sum()
        if fn1 > zr:
            lambda_val = 2*lambda_val
        elif fn2 < zr:
            lambda_val = lambda_val /2
            F = F_old
        else:
            break

    return U,F



def kmeans_plus_plus(X:np.array,
                     K:np.int32):
    '''
    Initialize the centroids using the Kmeans++ technique [1]: the first centroid is
    drawn uniformly from the data, every following one is drawn with a probability
    proportional to its squared distance to the closest centroid already chosen.

    Args:
        X (np.array): The data to be clustered, where X.shape[0] is the number of
                      individuals. Individuals may have any number of dimensions.
        K (np.int32): The number of centroids to find

    Returns:
        centroids (np.array): the centroids from Kmeans++ technique, shape [K, *X.shape[1:]]

    References:
        [1]Arthur, D.; Vassilvitskii, S. (2007). "k-means++: the advantages of careful seeding"

    '''
    X = np.asarray(X)
    n_samples = X.shape[0]
    centroids = np.zeros((K,) + X.shape[1:])
    # One flat float row per individual, so distances reduce over a single axis.
    flat = X.reshape(n_samples, -1).astype(float)

    # Step 1: Randomly choose the first centroid from the data points
    centroid_idx = np.random.randint(n_samples)
    centroids[0] = X[centroid_idx]

    # Step 2: Select the remaining K-1 centroids using k-means++ logic
    # closest_sq[i] is the squared distance from individual i to its nearest
    # chosen centroid; only the newest centroid can lower it.
    closest_sq = np.full(n_samples, np.inf)
    for k in range(1, K):
        diff = flat - flat[centroid_idx]
        closest_sq = np.minimum(closest_sq, np.sum(diff * diff, axis=1))
        total = closest_sq.sum()
        if total > 0:
            centroid_idx = np.random.choice(n_samples, p=closest_sq / total)
        else:
            # Every individual coincides with a chosen centroid: the weights would
            # be 0/0, so fall back to a uniform draw.
            centroid_idx = np.random.choice(n_samples)
        centroids[k] = X[centroid_idx]

    return centroids


def Multi_Group_assign(data:np.array,
                       centroid:np.array):
    '''
    Assign every layer to the group whose centroid graph is the closest one,
    according to euclidean_distance.

    Args:
        data (np.array): The multiplex graph, laid out as [layer,nb_vertices,nb_vertices]
        centroid (np.array): The centroid graph of each group, laid out as [centroids,nb_vertices, nb_vertices]

    Returns:
        label_ (np.array): for each layer, the index of its closest centroid
    '''
    def closest_group(layer):
        # Distance from this layer to every centroid, then the index of the smallest.
        gaps = np.array([euclidean_distance(layer, group_graph) for group_graph in centroid],
                        dtype=float)
        return np.argmin(gaps)

    return np.array([closest_group(layer) for layer in data])


def Multi_Group_GCM(data:np.array,
                    g:np.int32,
                    c_:np.array,
                    lambda_val:np.array=None,
                    iter_N:np.int32=100):
    '''
    Alternate between assigning layers to groups and recomputing one centroid graph per
    group with GCM, until the assignment stops changing or iter_N rounds have been done.
    Because the groups may not necessarily have c_ disconnected components, the communities
    of each group are identified by applying a Kmeans algorithm to the embedding vectors.

    Args:
        data (np.array): The multiplex graph, laid out as [layer,nb_vertices,nb_vertices]
        g (np.int32): The number of groups that we want to find
        c_ (np.array): The number of blocks for each group, where c_.shape[0] == g
        lambda_val (np.array): The Lagrangian variable for each group (a single scalar is
                               applied to every group). Defaults to 1 for every group.
        iter_N (np.int32): The maximum number of assignment rounds performed when the
                           algorithm does not reach convergence. It is not forwarded to GCM.

    Returns:
        y_label (np.array): the group of each layer
        z_block (np.array): the communities of each group, one column per group

    Raises:
        ValueError: if iter_N is smaller than 1
    '''
    if iter_N < 1:
        raise ValueError("iter_N must be at least 1")

    n_vertices = data.shape[1]
    if lambda_val is None:
        lambdas = np.ones(g)
    else:
        lambdas = np.broadcast_to(np.asarray(lambda_val, dtype=float), (g,))

    # No previous assignment yet; a sentinel avoids a false convergence when the
    # very first assignment happens to put every layer in group 0.
    y_old = None
    z_block = np.zeros((n_vertices,g))
    centroids = kmeans_plus_plus(data,g)
    for _ in range(iter_N):
        y_label = Multi_Group_assign(data=data,
                                     centroid=centroids)
        centroids_tmp = np.zeros((g,data.shape[1],data.shape[2]))
        for j in range(g):
            index_layer = np.where(y_label == j)[0]
            if index_layer.size == 0:
                # No layer in this group: keep its centroid and blocks unchanged
                # rather than running GCM on an empty stack of layers.
                centroids_tmp[j] = centroids[j]
                continue
            U_,F_ = GCM(data=data[index_layer,:,:],
                        c=c_[j],
                        lambda_val=lambdas[j])
            centroids_tmp[j] = U_
            kmeans = KMeans(n_clusters=c_[j])
            z_block[:,j] = kmeans.fit(F_).labels_
        if y_old is not None and np.array_equal(y_label, y_old):
            break
        y_old = y_label
        centroids = centroids_tmp
    return y_label, z_block