'''
    Generates layers from the same distribution but with a different assignment for each group.
    Computes the layer-to-group and vertex-to-block variables using MSBM and compares them to
    KMeans and SBM for each task.
'''
from MSBM import _initial_latent_variable, init_parameters, algorithm_MSBM
from utils import generate_layers, join_all_layer, nmi, _argmax_function, bar_plot_figures, random_st
from sklearn.cluster import KMeans
import numpy as np

def block_labels(block_sizes):
    '''
        Build the ground-truth vertex-to-block labels of one group.

        block_sizes : sequence of int, the number of vertices in each block.

        Returns a float array of length sum(block_sizes): the first
        block_sizes[0] entries are 1, the next block_sizes[1] entries are 2,
        and so on (blocks are numbered from 1).
    '''
    sizes = np.asarray(block_sizes, dtype=int)
    return np.repeat(np.arange(1, sizes.size + 1), sizes).astype(float)


def first_seen_order(labels):
    '''
        Return the distinct values of `labels` in order of first appearance.
    '''
    # dict keys keep insertion order, so this de-duplicates without sorting.
    return list(dict.fromkeys(labels))


if __name__ == '__main__':
    np.random.seed(1)

    # Block sizes of every group; the key order defines the group index
    # (0, 1, 2) used by the truth labels below.
    n_dict = {
        '1': np.array([25, 25, 25, 25]),
        '2': np.array([20, 25, 25, 30]),
        '3': np.array([30, 30, 20, 20]),
    }
    # All groups share the same connectivity matrix.
    p = np.array([[0.5, 0.3, 0.3, 0.3],
                  [0.3, 0.5, 0.3, 0.3],
                  [0.3, 0.3, 0.5, 0.3],
                  [0.3, 0.3, 0.3, 0.5]])
    p_dict = {key: p for key in n_dict}

    nb_layer = 10
    n_groups = len(n_dict)
    n_blocks = np.array([len(n_dict[key]) for key in n_dict])

    # Generate `nb_layer` layers per group, then stack them group after group.
    G_multiplex = join_all_layer(*[generate_layers(n_=n_dict[key],
                                                   p_=p_dict[key],
                                                   nb_layers=nb_layer)
                                   for key in p_dict])
    print('the layers are generated')

    # Set the layer-to-group truth labels: layer i belongs to group
    # i // nb_layer because the groups were joined in order.
    nb_layer_multiplex = G_multiplex.shape[0]
    y_truth = (np.arange(nb_layer_multiplex) // nb_layer).astype(float)

    # Set the vertex-to-block truth labels, one array per group index.
    # NOTE(review): assumes generate_layers orders vertices block by block,
    # as the original hand-written labels did -- confirm in utils.
    z_truth = {group: block_labels(n_dict[key])
               for group, key in enumerate(n_dict)}

    # We assume the number of groups and the number of blocks of each group
    # are known.
    y, z = _initial_latent_variable(X=G_multiplex,
                                    init_paramter='K_centroid',
                                    random_state=random_st,
                                    n_groups=n_groups,
                                    n_block=n_blocks)
    print('the initilizations are ',y)
    print('also ', z)
    beta, alpha, pi = init_parameters(G_multiplex, y, z)
    y, z = algorithm_MSBM(data=G_multiplex,
                          y_=y,
                          z_=z,
                          beta=beta,
                          alpha=alpha,
                          pi=pi,
                          iteration=20)
    print('the MSBM is done')

    ## MSBM performance
    y_predict = y.argmax(axis=1)
    msbm_y_performance = nmi(y_truth, y_predict)

    # Predicted group labels are arbitrary, so the k-th distinct predicted
    # label (in order of first appearance over the layers) is compared with
    # the k-th true group. Scores stay keyed by the predicted label.
    z_predict_msbm_scores = dict()
    for truth, predict in enumerate(first_seen_order(y_predict)):
        z_predict = z[predict].argmax(axis=1)
        z_predict_msbm_scores[predict] = nmi(z_truth[truth], z_predict)
    print('the performances are ',msbm_y_performance,z_predict_msbm_scores)

    ## Kmeans on layers (each layer flattened to one feature vector)
    kmeans_layers = KMeans(n_clusters=n_groups)
    labels = kmeans_layers.fit(G_multiplex.reshape(nb_layer_multiplex, -1)).labels_
    kmeans_y_performance = nmi(y_truth, labels)
    print('the performance of kmeans are ',kmeans_y_performance)

    ## Single SBM: one group only, so every layer shares the same blocks
    y, z = _initial_latent_variable(X=G_multiplex,
                                    init_paramter='random',
                                    random_state=random_st,
                                    n_groups=1,
                                    n_block=np.array([n_blocks[0]]))
    beta, alpha, pi = init_parameters(G_multiplex, y, z)
    y, z = algorithm_MSBM(data=G_multiplex,
                          y_=y,
                          z_=z,
                          beta=beta,
                          alpha=alpha,
                          pi=pi,
                          iteration=20)

    # The single block assignment is scored against the truth of every group.
    z_predict = z[0].argmax(axis=1)
    z_predict_sbm_scores = {truth: nmi(z_truth[truth], z_predict)
                            for truth in range(n_groups)}
    print('the performances of sbm are ', z_predict_sbm_scores)

    bar_plot_figures(values=[kmeans_y_performance, msbm_y_performance],
                     labels=['Kmeans', 'MSBM'],
                     x_axis_label='Methods',
                     y_axis_label='Normalized Mutuel Information')

    # Average each method over its own scores. Dividing by the number of true
    # groups means a group MSBM failed to recover counts as 0, and the SBM
    # mean no longer depends on which labels MSBM happened to predict.
    mean_z_msbm_performance = sum(z_predict_msbm_scores.values()) / n_groups
    mean_z_sbm_performance = sum(z_predict_sbm_scores.values()) / n_groups
    bar_plot_figures(values=[mean_z_sbm_performance, mean_z_msbm_performance],
                     labels=['SBM', 'MSBM'],
                     x_axis_label='Methods',
                     y_axis_label='Mean Normalized Mutuel Information')