
import networkx as nx
from networkx.algorithms.community import greedy_modularity_communities,_naive_greedy_modularity_communities
from sklearn.metrics.cluster import normalized_mutual_info_score
import numpy as np
import matplotlib.pyplot as plt
from utils import *

class modularity_max:
    """Detect communities of states by greedy modularity maximisation.

    A sequence of one-hot state vectors is turned into a row-normalised
    transition-probability matrix, the matrix is converted into a graph and
    the graph is partitioned with ``greedy_modularity_communities``.
    """

    def __init__(self):
        self.name = "Modularity MAX"
        # Community id per state index; filled in by `input`.
        self.label = []

    @staticmethod
    def _transition_matrix(x):
        """Return the row-normalised transition matrix of the sequence `x`.

        Args:
            x: 2-D array-like; each row is one observation and each column
                one state. Only rows whose maximum equals 1 are used; the
                state of such a row is the index of its maximum.

        Returns:
            A square ``numpy`` array with one row/column per state. Row ``i``
            holds the relative frequency of transitions from state ``i``;
            rows of states that were never left stay all-zero.
        """
        x = np.asarray(x)
        n_states = x.shape[1]
        counts = np.zeros([n_states, n_states])

        # The validity check is applied to *every* row, including the first,
        # so an invalid leading row cannot seed a spurious transition.
        prev_state = None
        for row in x:
            if np.max(row) != 1:
                continue
            state = np.argmax(row)
            if prev_state is not None:
                counts[prev_state][state] += 1
            prev_state = state

        # Normalise each row by its total; rows with no outgoing transitions
        # are left at zero instead of dividing by zero.
        row_totals = counts.sum(axis=1, keepdims=True)
        tp_matrix = np.zeros_like(counts)
        np.divide(counts, row_totals, out=tp_matrix, where=row_totals != 0)
        return tp_matrix

    @staticmethod
    def _labels_from_communities(communities, n_states):
        """Return a per-state list of community ids.

        Args:
            communities: iterable of node collections; the position of a
                collection in the iterable is used as its community id.
            n_states: length of the returned list.

        Returns:
            A list where element ``k`` is the id of the community containing
            node ``k``, or ``-1`` if node ``k`` is in no community.
        """
        # Assumes graph nodes are integer state indices in [0, n_states)
        # -- TODO confirm against utils.matrix_to_graph.
        labels = [-1] * n_states
        for community_id, members in enumerate(communities):
            for node in members:
                labels[node] = community_id
        return labels

    def input(self, x):
        """Cluster the states visited by `x` and return their community ids.

        Args:
            x: 2-D array-like of one-hot rows (observations x states).

        Returns:
            A list with one entry per state (column of `x`): the index of the
            community the state was assigned to, or ``-1`` if the state is
            not part of any community. The list is also stored in
            ``self.label``.
        """
        x = np.asarray(x)
        n_states = x.shape[1]

        tp_matrix = self._transition_matrix(x)

        # NOTE(review): `utils` is only bound here if `from utils import *`
        # exports a name called `utils` -- confirm against the utils module.
        graph = utils.matrix_to_graph(tp_matrix)
        communities = list(greedy_modularity_communities(graph))

        # Index labels by node so that labels[i] describes state i; appending
        # in community order would misalign them with `true_label`.
        labels = self._labels_from_communities(communities, n_states)

        self.label = labels
        return labels

    def evaluation(self, true_label):
        """Score the stored labels against `true_label`.

        Args:
            true_label: reference labels, one per state.

        Returns:
            The normalised mutual information between ``self.label`` and
            `true_label`, or ``None`` if no labels have been stored yet.
        """
        if len(self.label) == 0:
            return None
        return normalized_mutual_info_score(self.label, true_label)