'''
    Generate layers from different distributions with a known (exact) vertex-to-block assignment.
    Estimate the layer-to-group and vertex-to-block variables with MSBM and compare them to
    KMeans (layer-to-group) and to a single-group SBM (vertex-to-block) for each task.
'''
from MSBM import _initial_latent_variable, init_parameters, algorithm_MSBM
from utils import generate_layers, join_all_layer, nmi, _argmax_function, bar_plot_figures, random_st
from sklearn.cluster import KMeans
import numpy as np

def group_truth_labels(n_groups, nb_layer):
    """Return the true layer-to-group label of every layer in the multiplex.

    Assumes the layers are stacked group after group, ``nb_layer`` layers per
    group, which is the order in which the main script generates and joins them.

    Parameters
    ----------
    n_groups : int
        Number of groups (distributions) used to generate the layers.
    nb_layer : int
        Number of layers generated per group.

    Returns
    -------
    numpy.ndarray
        Float array of length ``n_groups * nb_layer``; entry ``i`` equals
        ``i // nb_layer``.
    """
    return np.repeat(np.arange(n_groups, dtype=float), nb_layer)


def block_truth_labels(block_sizes):
    """Return 1-based vertex-to-block labels for vertices ordered block by block.

    Parameters
    ----------
    block_sizes : array_like of int
        Number of vertices in each block, e.g. ``[25, 25, 25, 25]``.

    Returns
    -------
    numpy.ndarray
        Float array of length ``sum(block_sizes)``: the first ``block_sizes[0]``
        entries are 1, the next ``block_sizes[1]`` entries are 2, and so on.
    """
    block_sizes = np.asarray(block_sizes, dtype=int)
    return np.repeat(np.arange(1, block_sizes.size + 1, dtype=float), block_sizes)


def match_predicted_groups(y_truth, y_predict):
    """Map every true group to the predicted group holding most of its layers.

    Cluster ids are arbitrary, so a predicted group must be matched to a true
    group before their vertex-to-block assignments are compared. A majority vote
    per true group stays correct when a few layers are misclassified, unlike
    ordering the predicted labels by first appearance.

    Parameters
    ----------
    y_truth : array_like
        True group of each layer (integer-valued, may be stored as floats).
    y_predict : array_like of non-negative int
        Predicted group of each layer, same length as ``y_truth``.

    Returns
    -------
    dict
        ``{true_group: predicted_group}`` with one entry per distinct true
        group. Two true groups can map to the same predicted group if the
        prediction merged them. Ties go to the smallest predicted label.
    """
    y_truth = np.asarray(y_truth).astype(int)
    y_predict = np.asarray(y_predict).astype(int)
    return {
        int(group): int(np.bincount(y_predict[y_truth == group]).argmax())
        for group in np.unique(y_truth)
    }


if __name__ == '__main__':
    np.random.seed(3)

    # --- Simulation settings ------------------------------------------------
    # Every group uses 4 blocks of 25 vertices (100 vertices per layer).
    n_ = np.array([25, 25, 25, 25])
    n_dict = {
        '1': n_,
        '2': n_,
        '3': n_,
    }
    # Groups 1 and 2: within-block probability (diagonal) is higher than the
    # between-block probability; group 2 has the sharper contrast.
    p_1 = np.array([[0.5, 0.3, 0.3, 0.3],
                    [0.3, 0.5, 0.3, 0.3],
                    [0.3, 0.3, 0.5, 0.3],
                    [0.3, 0.3, 0.3, 0.5]])
    p_2 = np.array([[0.5, 0.1, 0.1, 0.1],
                    [0.1, 0.5, 0.1, 0.1],
                    [0.1, 0.1, 0.5, 0.1],
                    [0.1, 0.1, 0.1, 0.5]])
    # p_3 uses the same probability everywhere, i.e. a random graph without
    # communities. We want to know whether it is recognized as such.
    p_3 = np.array([[0.5, 0.5, 0.5, 0.5],
                    [0.5, 0.5, 0.5, 0.5],
                    [0.5, 0.5, 0.5, 0.5],
                    [0.5, 0.5, 0.5, 0.5]])
    p_dict = {
        '1': p_1,
        '2': p_2,
        '3': p_3,
    }
    n_groups = len(p_dict)
    nb_layer = 10  # layers generated per group

    # Generate nb_layer layers per group (in p_dict order) and stack them.
    G_multiplex = list()
    for key in p_dict:
        G_multiplex.append(generate_layers(n_=n_dict[key],
                                           p_=p_dict[key],
                                           nb_layers=nb_layer))
    G_multiplex = join_all_layer(*G_multiplex)

    print('the layers are generated')

    # --- Ground truth -------------------------------------------------------
    nb_layer_multiplex = G_multiplex.shape[0]
    y_truth = group_truth_labels(n_groups, nb_layer)
    if y_truth.size != nb_layer_multiplex:
        raise ValueError(
            'expected {} layers ({} groups x {} layers), got {}'.format(
                y_truth.size, n_groups, nb_layer, nb_layer_multiplex))

    # Vertex-to-block truth per group, keyed by 0-based group index in p_dict
    # order. Each entry is its own array, so the labels never alias.
    z_truth = {
        0: block_truth_labels(n_dict['1']),
        1: block_truth_labels(n_dict['2']),
        # Group 3 has no communities: every vertex is in the same block.
        2: np.zeros(int(np.sum(n_dict['3']))),
    }

    # --- MSBM ---------------------------------------------------------------
    # The number of groups and the number of blocks per group are assumed
    # known.
    # NMI cannot meaningfully compare a partition with a single-class truth
    # (group 3). In most runs, group 3 either puts all vertices in one block or
    # puts only a few random vertices outside the dominant block. Precision,
    # recall or F1 are better suited there. Asking for a single block for
    # group 3 (n_block=np.array([4, 4, 1])) makes NMI work perfectly.
    y, z = _initial_latent_variable(X=G_multiplex,
                                    init_paramter='K_centroid',
                                    random_state=random_st,
                                    n_groups=n_groups,
                                    n_block=np.array([4, 4, 4]))
    print('the initializations are ', y)
    print('also ', z)
    beta, alpha, pi = init_parameters(G_multiplex, y, z)

    y, z = algorithm_MSBM(data=G_multiplex,
                          y_=y,
                          z_=z,
                          beta=beta,
                          alpha=alpha,
                          pi=pi,
                          iteration=20)
    print('the MSBM is done')

    # MSBM layer-to-group performance.
    y_predict = y.argmax(axis=1)
    msbm_y_performance = nmi(y_truth, y_predict)

    # MSBM vertex-to-block performance. Scores are keyed by TRUE group so they
    # line up with the SBM scores below.
    z_perdict_msbm_dict = dict()
    for truth, predict in match_predicted_groups(y_truth, y_predict).items():
        z_predict = z[predict].argmax(axis=1)
        print('the z predicted for the group ', truth, ' are ')
        print(z_predict)
        z_perdict_msbm_dict[truth] = nmi(z_truth[truth], z_predict)
    print('the performances are ', msbm_y_performance, z_perdict_msbm_dict)

    # --- KMeans on layers ---------------------------------------------------
    # n_init is explicit: its default changed in scikit-learn 1.4.
    kmeans_layers = KMeans(n_clusters=n_groups, n_init=10)
    labels = kmeans_layers.fit(G_multiplex.reshape(nb_layer_multiplex, -1)).labels_
    kmeans_y_performance = nmi(y_truth, labels)
    print('the performance of kmeans are ', kmeans_y_performance)

    # --- Single SBM (one group shared by every layer) -----------------------
    y, z = _initial_latent_variable(X=G_multiplex,
                                    init_paramter='random',
                                    random_state=random_st,
                                    n_groups=1,
                                    n_block=np.array([4]))

    beta, alpha, pi = init_parameters(G_multiplex, y, z)
    y, z = algorithm_MSBM(data=G_multiplex,
                          y_=y,
                          z_=z,
                          beta=beta,
                          alpha=alpha,
                          pi=pi,
                          iteration=20)

    # The single SBM partition is compared with the truth of every group.
    z_predict = z[0].argmax(axis=1)
    z_perdict_sbm_dict = {truth: nmi(z_truth[truth], z_predict)
                          for truth in z_truth}
    print('the performances of sbm are ', z_perdict_sbm_dict)

    # --- Figures ------------------------------------------------------------
    bar_plot_figures(values=[kmeans_y_performance, msbm_y_performance],
                     labels=['Kmeans', 'MSBM'],
                     x_axis_label='Methods',
                     y_axis_label='Normalized Mutual Information')

    # Both dicts hold one score per true group, so the means are comparable.
    mean_z_msbm_performance = np.mean(list(z_perdict_msbm_dict.values()))
    mean_z_sbm_performance = np.mean(list(z_perdict_sbm_dict.values()))
    bar_plot_figures(values=[mean_z_sbm_performance, mean_z_msbm_performance],
                     labels=['SBM', 'MSBM'],
                     x_axis_label='Methods',
                     y_axis_label='Mean Normalized Mutual Information')