'''
Implementation of the Mixture Stochastic Block Model (MSBM) for multi-group community
detection in multiplex graphs. The adjacency matrices of all layers are stored together
as one dense np.array block. We know that a better implementation is possible using
sparse representations. The code that will be published with the paper, if it is
accepted, will be more optimized.
'''

import numpy as np 
import os 
from tqdm import tqdm
import matplotlib.pyplot as plt 
import sys 
import random as rd
from itertools import product
from initialization_algo import Multi_Group_GCM
from utils import _argmax_function


def _initial_latent_variable(X:np.array,
                             init_paramter:str,
                             random_state,
                             n_groups:np.int32,
                             n_block:np.array):
    '''
    Build the initial layer-to-group (y) and vertex-to-block (z) assignment variables.

    Args:
        X (np.array): The multiplex graph, stored as [layer,nb_vertices,nb_vertices]
        init_paramter (str): Initialization mode, either 'random' or 'K_centroid'. Any other
                             value leaves y and z filled with zeros.
        random_state: Callable invoked as random_state(type_='uniform', size=(rows, cols));
                      only used in 'random' mode
        n_groups (np.int32): The number of groups in the multiplex graph
        n_block (np.array): The number of blocks of each group (one entry per group)
    Returns:
        y (np.array): the group assignment of each layer
        z (dict): maps each group index to the vertex-to-block assignment of that group
    '''
    layer_count, vertex_count = X.shape[0], X.shape[1]
    assert n_block.shape[0] == n_groups, 'The differrent block must math with number of group '

    # Zero-filled defaults; they are returned untouched when the mode is not recognised.
    y = np.zeros((layer_count, n_groups), dtype=np.float128)
    z = {group: np.zeros((vertex_count, n_block[group]), dtype=np.float128)
         for group in range(n_groups)}

    if init_paramter == 'random':
        # Uniform scores are handed to _argmax_function, first for the layers,
        # then group by group for the vertices.
        y = _argmax_function(random_state(type_='uniform', size=(layer_count, n_groups)))
        for group in range(n_groups):
            scores = random_state(type_='uniform', size=(vertex_count, n_block[group]))
            z[group] = _argmax_function(scores)
    elif init_paramter == 'K_centroid':
        layer_labels, vertex_labels = Multi_Group_GCM(data=X,
                                                      g=n_groups,
                                                      c_=n_block)
        # layer_labels holds one column index (group) per layer.
        y[np.arange(y.shape[0]), layer_labels] = 1
        vertex_labels = vertex_labels.astype(int)
        for group in range(n_groups):
            # Column `group` of vertex_labels holds one block index per vertex.
            labels = vertex_labels[:, group]
            z[group][np.arange(labels.shape[0]), labels.tolist()] = 1
    return y, z

def init_parameters(X:np.array,
                    y:np.array,
                    z:dict):
    '''
    Compute the model parameters that correspond to given latent variables (y and z).

    Args:
        X (np.array): The multiplex graph, stored as [layer,nb_vertices,nb_vertices]
        y (np.array): the group assignment of each layer
        z (dict): the communities (vertex-to-block assignment) of each group

    Returns:
        beta (np.array): The mixture parameter of the MSBM model (column sums of y divided
                         by the number of layers)
        alpha (dict): The mixture parameters of each group's SBM
        pi (dict): The Bernoulli parameters of each group's SBM

    For every group key of z, the layers whose entry in y equals 1 are selected and passed,
    with the group's assignment, to estimate_sbm_params.
    '''
    layer_count = y.shape[0]
    beta = y.sum(axis=0) / layer_count

    alpha, pi = {}, {}
    for group, assignment in z.items():
        # Layers currently attached to this group.
        member_layers = np.nonzero(y[:, group] == 1)[0]
        group_alpha, group_pi = estimate_sbm_params(X[member_layers, :, :], assignment)
        alpha[group] = group_alpha
        pi[group] = group_pi
    return beta, alpha, pi

def estimate_sbm_params(data_used,
                        z_):
    '''
    Estimate the parameters of one SBM from its layers and its vertex assignments.
    The layers are treated as independent draws of the same SBM.

    Args:
        data_used (np.array): The layers of the group, stored as [layer,nb_vertices,nb_vertices]
        z_ (np.array): The assignment variables of the SBM, shape [nb_vertices,nb_blocks].
                       Integer as well as floating-point arrays are accepted; the input
                       is never modified.

    Returns:
        alpha (np.array): The mixture parameters of the SBM (column sums of z_ divided by
                          the number of vertices)
        pi (np.array) : The Bernoulli parameters of the SBM, shape [nb_blocks,nb_blocks],
                        extended precision. pi[q,l] is the sum of the entries A[i,j]
                        weighted by z_[i,q]*z_[j,l] over all layers, divided by
                        (sum of z_[:,q]) * (sum of z_[:,l]) * nb_layers. An empty block or
                        an empty layer set gives 0/0, i.e. nan entries.
    '''
    z_ = np.asarray(z_)
    block_sizes = z_.sum(axis=0)
    # True division allocates a float result, so integer one-hot assignments work
    # (an in-place `/=` on an integer array raises a casting error).
    alpha = block_sizes / z_.shape[0]

    # sum_l z^T A_l z == z^T (sum_l A_l) z: one matrix product replaces the
    # Python loop over block pairs and layers.
    edge_totals = z_.T @ data_used.sum(axis=0) @ z_
    pair_totals = np.outer(block_sizes, block_sizes) * data_used.shape[0]
    # np.longdouble is the portable name of the extended-precision float type.
    pi = (edge_totals / pair_totals).astype(np.longdouble)
    return alpha, pi


def _sbm_m_step(data_used, z_):
    '''Closed-form (alpha, pi) for fixed assignments, kept strictly inside (0, 1).'''
    eps = np.finfo(float).eps
    block_sizes = z_.sum(axis=0)
    # True division returns a new float array (safe for integer assignments).
    alpha = np.asarray(block_sizes / z_.shape[0])
    alpha[alpha == 0] = eps
    # sum_l z^T A_l z == z^T (sum_l A_l) z
    edge_totals = z_.T @ data_used.sum(axis=0) @ z_
    pair_totals = np.outer(block_sizes, block_sizes) * data_used.shape[0]
    with np.errstate(divide='ignore', invalid='ignore'):
        pi = np.asarray(edge_totals / pair_totals, dtype=float)
    # Empty blocks give 0/0; map every non-finite entry to 0, then clamp so that
    # both log(pi) and log(1 - pi) stay finite.
    pi = np.nan_to_num(pi, nan=0.0, posinf=0.0, neginf=0.0)
    pi = np.clip(pi, eps, 1 - eps)
    return alpha, pi


def single_sbm_log(data_used:np.array,
                   z_:np.array,
                   iteration:np.int32=100):
    '''
    Infer the parameters of one SBM and the vertex-to-block assignments by alternating a
    parameter update and an assignment update, based on variational estimation [1].

    Args:
        data_used (np.array): The layers of a given group, stored as
                              [layer,nb_vertices,nb_vertices]
        z_ (np.array): The initialization of the vertex-to-block variables for this group
        iteration (np.int32): The maximum number of assignment updates performed when no
                              fixed point is reached earlier. With iteration <= 0 the
                              parameters of the initial z_ are returned.
    Returns:
        z_ (np.array): The assignment variable of each vertex to a given block
        alpha (np.array): The inferred mixture parameters of the SBM
        pi (np.array) : The inferred Bernoulli parameters of the SBM

    The returned alpha and pi are always the ones estimated from the returned z_.
    Zero entries of alpha are replaced by machine epsilon and pi is clamped to
    [eps, 1 - eps], so every logarithm used in the assignment update is finite.

    References:
        [1] Celisse, A., Daudin, J. J., & Pierre, L. (2012).
        Consistency of maximum-likelihood and variational estimators in the stochastic block model.
    '''
    z_ = np.asarray(z_)
    alpha, pi = _sbm_m_step(data_used, z_)
    if iteration <= 0:
        # No refinement requested.
        return z_, alpha, pi

    # For every vertex pair, the number of layers with / without an edge. These counts
    # do not depend on the assignments, so they are computed once.
    ones_per_pair = (data_used == 1).sum(axis=0)
    zeros_per_pair = (data_used == 0).sum(axis=0)

    for _ in tqdm(range(iteration)):
        # members[j, q] is 1 when vertex j currently belongs to block q.
        members = (z_ != 0).astype(float)
        # scores[i, b] = sum_q ones[i, q] * log(pi[b, q])
        #              + sum_q zeros[i, q] * log(1 - pi[b, q]) + log(alpha[b])
        scores = ((ones_per_pair @ members) @ np.log(pi).T
                  + (zeros_per_pair @ members) @ np.log(1 - pi).T
                  + np.log(alpha))
        z_new = _argmax_function(np.asarray(scores, dtype=float))
        if (z_new == z_).all():
            print('attent to a fixed point')
            break
        z_ = z_new
        alpha, pi = _sbm_m_step(data_used, z_)
    return z_, alpha, pi

def y_estimation_with_log(data:np.array,
                          z_:dict,
                          beta:np.array,
                          alpha:dict,
                          pi:dict):
    '''
    Estimate the layer-to-group assignment of every layer from the multiplex graph and
    the current parameters, using log-likelihood scores.

    Args:

        data (np.array): The multiplex graph, stored as [layer,nb_vertices,nb_vertices]
        z_ (dict): The vertex-to-block assignment of each group
        beta (np.array): The mixture parameter of the MSBM model. Entries equal to 0 are
                         overwritten IN PLACE with machine epsilon before the logarithm
                         is taken.
        alpha (dict): The mixture parameters of each group's SBM
        pi (dict): The Bernoulli parameters of each group's SBM

    Returns:

        y_new (np.array): The result of _argmax_function applied to the
                          [layer,group] score matrix
    '''
    group_count = len(z_.keys())
    layer_count, _, _ = data.shape
    scores = np.zeros((layer_count, group_count))
    # Avoid log(0) for groups whose mixture weight is exactly zero.
    tiny = np.finfo(float).eps
    beta[beta == 0] = tiny
    for layer in range(layer_count):
        for group in range(group_count):
            # Log-probability of this layer under the SBM of `group` ...
            scores[layer, group] = proba_of_block_single_layer_log(data=data[layer],
                                                                   z_=z_[group],
                                                                   pi_=pi[group],
                                                                   alpha_=alpha[group])
            # ... plus the log mixture weight of the group.
            scores[layer, group] += np.log(beta[group])
    return _argmax_function(scores)


def z_estimation_log(data_used:np.array,
                     z_:np.array,
                     pi:np.array,
                     alpha:np.array):
    '''
    Re-estimate the vertex-to-block assignment of one group from log-likelihood scores.

    Args:

        data_used (np.array): The layers of the group, stored as
                              [layer,nb_vertices,nb_vertices]
        z_ (np.array): The previous vertex-to-block assignment
        pi (np.array): The Bernoulli parameters of the SBM
        alpha (np.array): The mixture parameters of the SBM

    Returns:

        z_new (np.array): The result of _argmax_function applied to the
                          [vertex,block] score matrix
    '''
    vertex_total = data_used.shape[-1]
    block_total = pi.shape[0]
    scores = np.zeros((vertex_total, block_total))
    for vertex, block in tqdm(product(range(z_.shape[0]), range(z_.shape[1]))):
        log_score = 0
        for other in range(z_.shape[1]):
            # Vertices currently assigned to block `other`.
            members = np.where(z_[:, other] != 0)[0]
            for layer in range(data_used.shape[0]):
                row = data_used[layer, vertex, members]
                edge_count = np.count_nonzero(row == 1)
                non_edge_count = np.count_nonzero(row == 0)
                log_score += edge_count*np.log(pi[block, other]) + non_edge_count * np.log(1 - pi[block, other])
        log_score += np.log(alpha[block])
        scores[vertex, block] = log_score
    return _argmax_function(scores)

def bernouli_distribution_block_single_layer_log(data:np.array,
                                                z_:np.array,
                                                pi:np.array):
    '''
    Sum the Bernoulli log-likelihood of a single layer over every pair of blocks.

    Args:
        data (np.array): The adjacency matrix of a single layer
        z_ (np.array): The vertex-to-block assignment variables of the group of this layer
        pi (np.array): The Bernoulli parameters of the SBM associated with the group of
                       this layer

    Returns:
        bernouli_block_log (np.float64): The accumulated log-likelihood; for each block pair
                                         the sub-matrix of data restricted to the vertices
                                         of the two blocks is scored with
                                         bernouli_distribution_log
    '''
    total_log = 0
    for row_block in range(pi.shape[0]):
        for col_block in range(pi.shape[1]):
            # Vertices whose assignment entry equals 1 for each of the two blocks.
            row_vertices = np.where(z_[:, row_block] == 1)[0]
            col_vertices = np.where(z_[:, col_block] == 1)[0]
            # Rows first, then columns: the (row_block, col_block) sub-matrix.
            sub_matrix = data[row_vertices, :][:, col_vertices]
            total_log += bernouli_distribution_log(sub_matrix, pi[row_block, col_block])
    return total_log


def proba_of_block_single_layer_log(data:np.array,
                                    z_:np.array,                                    
                                    alpha_:np.array,
                                    pi_:np.array):
    '''
    Log-probability of one layer under an SBM, given the assignment variables and the
    SBM parameters.

    Args:

        data (np.array): The adjacency matrix of the layer
        z_ (np.array): The assignment variable of each vertex to a given block
        alpha_ (np.array): The mixture parameters of the SBM
        pi_ (np.array): The Bernoulli parameters of the SBM

    Returns:
        prob_dist_log (np.float64): the edge term from
                                    bernouli_distribution_block_single_layer_log plus the
                                    assignment term from alpha_distribution_block_log
    '''
    # Edge term first, then assignment term; sum() starts from 0 like the accumulator.
    terms = (bernouli_distribution_block_single_layer_log(data, z_, pi_),
             alpha_distribution_block_log(z_, alpha_))
    return sum(terms)

def proba_of_block_sbm_group_log(data:np.array,
                                 z_:np.array,
                                 pi_:np.array,
                                 alpha_:np.array):
    '''
    Log-probability of a whole group of layers under one SBM. The layers are assumed to be
    i.i.d. draws from the SBM, so their edge terms add up while the assignment term is
    counted once.

    Args:

        data (np.array): The layers of the group, stored as [layer,nb_vertices,nb_vertices]
        z_ (np.array): The assignment variable of each vertex to a given block
        pi_ (np.array): The Bernoulli parameters of the SBM
        alpha_ (np.array): The mixture parameters of the SBM

    Returns:
        prob_dist_log (np.float64): the log-likelihood of the SBM for the group of layers
    '''
    per_layer_terms = (bernouli_distribution_block_single_layer_log(layer, z_, pi_)
                       for layer in data)
    group_log = sum(per_layer_terms)
    group_log += alpha_distribution_block_log(z_, alpha_)
    return group_log


def log_likelihood(data:np.array,
                   y_:np.array,
                   z_:dict,
                   beta_:np.array,
                   alpha_:dict,
                   pi_:dict
                   ):
    '''
    The log-likelihood of the multiplex graph given the latent variables and
    the parameters of the SBM of each group.

    Args:

        data (np.array): The multiplex graph, stored as [layer,nb_vertices,nb_vertices]
        y_ (np.array): the group assignment of each layer
        z_ (dict): The vertex-to-block assignment of each group
        beta_ (np.array): The mixture parameter of the MSBM model
        alpha_ (dict): The mixture parameters of the SBM of each group
        pi_ (dict): The Bernoulli parameters of the SBM of each group

    Returns:
        l_log (np.float64): The log-likelihood of the Mixture Stochastic Block Model given
                            the data, the latent variables and the parameters.

    A group without any assigned layer contributes no mixture term: its weight is
    typically 0, and 0 * log(0) would otherwise turn the whole result into nan.
    '''
    l_log = 0
    for group in z_:
        layers_inde = np.where(y_[:, group] == 1)[0]
        l_log += proba_of_block_sbm_group_log(data[layers_inde],
                                              z_=z_[group],
                                              pi_=pi_[group],
                                              alpha_=alpha_[group])
        n_assigned = layers_inde.shape[0]
        if n_assigned:
            # One log(beta) per layer assigned to the group.
            l_log += np.log(beta_[group]) * n_assigned
    return l_log


def alpha_distribution_block_log(z_:np.array,
                                 alpha_:np.array):
    '''
    Log-likelihood of the vertex-to-block assignment under the mixture parameter alpha.

    Args:

        z_ (np.array): The assignment variable of each vertex to a given block,
                       shape [nb_vertices,nb_blocks]
        alpha_ (np.array): The mixture parameters of the given SBM

    Returns:
        alpha_dist_log (np.float64): sum over blocks of
                                     (number of vertices with z_ == 1) * log(alpha_).
                                     Blocks without vertices are skipped, so an empty
                                     block with alpha_ == 0 contributes 0 instead of
                                     0 * log(0) = nan.
    '''
    alpha_dist_log = 0
    for block in range(z_.shape[1]):
        member_count = np.count_nonzero(z_[:, block] == 1)
        if member_count == 0:
            # Convention 0 * log(alpha) == 0, also when alpha is exactly 0.
            continue
        alpha_dist_log += member_count * np.log(alpha_[block])
    return alpha_dist_log


def bernouli_distribution_log(data:np.array, 
                              pi:np.array):
    '''
    Compute the Bernoulli log-likelihood of data.

    Args:

        data (np.array): The entries that follow the same Bernoulli distribution; only
                         entries equal to 1 or to 0 are counted
        pi (np.float64): The probability of having an edge between the individuals

    Returns:
        bernouli_dist_log (np.float64): (#ones) * log(pi) + (#zeros) * log(1 - pi).
                                        A term whose count is 0 is skipped, so pi == 0
                                        without any edge (or pi == 1 without any
                                        non-edge) gives 0 instead of 0 * log(0) = nan.
    '''
    ones_number = np.count_nonzero(data == 1)
    zeros_number = np.count_nonzero(data == 0)
    bernouli_dist_log = 0.0
    if ones_number:
        bernouli_dist_log += ones_number * np.log(pi)
    if zeros_number:
        bernouli_dist_log += zeros_number * np.log(1 - pi)
    return bernouli_dist_log


def random_st(type_:str,
              size:tuple,
              random_func=rd,
              a:np.float64=0,
              b:np.float64=1):
    '''
    Fill a 2-D array with values drawn one by one from random_func.

    Args:

        type_ (str): Name of the attribute of random_func that is called to draw each
                     value, as getattr(random_func, type_)(a, b)
        size (tuple): The size of the array to generate; size[0] rows and size[1] columns
        random_func: the object providing the drawing function (defaults to the
                     standard `random` module)
        a (np.float64): First argument of the drawing function (the lower bound)
        b (np.float64): Second argument of the drawing function (the higher bound)

    Returns:

        vector (np.array): the random values, of shape [size[0],size[1]]
    '''
    n_rows, n_cols = size[0], size[1]
    placeholder = np.zeros((n_rows, n_cols))
    # The cell value is ignored; every cell is replaced by a fresh draw.
    draw_cell = np.vectorize(lambda _cell: getattr(random_func, type_)(a, b))
    return draw_cell(placeholder)


def algorithm_MSBM(data:np.array,
                   y_:np.array,
                   z_:dict,
                   beta:np.array,
                   alpha:dict,
                   pi:dict,
                   iteration:np.int64):
    '''
    Alternate between estimating the latent variables and maximizing the parameters.

    Each pass re-estimates the layer-to-group assignment with y_estimation_with_log, then
    refits the SBM of every group with single_sbm_log on the layers assigned to it, and
    recomputes beta as the column sums of the new assignment divided by the number of
    layers. The loop stops as soon as the layer-to-group assignment no longer changes,
    and at the latest after `iteration` passes.

    Args:
        data (np.array): The multiplex graph, stored as [layer,nb_vertices,nb_vertices]
        y_ (np.array): The initial group assignment of each layer
        z_ (dict): The initial communities of each group
        beta (np.array): The initial mixture parameter of the MSBM model
        alpha (dict): The initial mixture parameters of each group's SBM
        pi (dict): The initial Bernoulli parameters of each group's SBM
        iteration (np.int64): The maximum number of passes. With iteration <= 0 the
                              initial y_ and z_ are returned unchanged.

    Returns:
        y_new (np.array): the group assignment of each layer
        z_new (dict): the communities of each group
    '''
    if iteration <= 0:
        # Nothing to optimise: hand back the initial latent variables.
        return y_, z_

    y_new, z_new = y_, z_
    # The pass counter is owned by the for-loop, so the loop over groups below can no
    # longer overwrite it and the iteration cap is always honoured.
    for _ in tqdm(range(iteration)):
        y_new = y_estimation_with_log(data=data, z_=z_, beta=beta, alpha=alpha, pi=pi)
        z_new = dict()
        alpha_new = dict()
        pi_new = dict()
        for group in range(y_new.shape[1]):
            layers = np.where(y_new[:, group] == 1)[0]
            z_new[group], alpha_new[group], pi_new[group] = single_sbm_log(data_used=data[layers],
                                                                         z_=z_[group])
        beta_new = y_new.sum(axis=0) / y_new.shape[0]
        if (y_new == y_).all():
            print('fixe point attended')
            break
        y_ = y_new
        z_ = z_new
        beta = beta_new
        alpha = alpha_new
        pi = pi_new
    return y_new, z_new