
import numpy as np
from munkres import Munkres
from scipy.optimize import linear_sum_assignment
from sklearn.metrics import f1_score
from sklearn.metrics.cluster import normalized_mutual_info_score

def get_NMI_mtc(gnd, clus_res):
    '''
    Compute the normalized mutual information (NMI) between a ground-truth
    labelling and a clustering result.
    :param gnd: ground-truth label assignment
    :param clus_res: label assignment produced by the clustering
    :return: NMI score from sklearn's normalized_mutual_info_score (default
        average_method)
    '''
    # ====================
    nmi = normalized_mutual_info_score(labels_true=gnd, labels_pred=clus_res)
    return nmi

def get_AC_mtc(gnd_seq, pred_seq):
    '''
    Compute the clustering accuracy (AC) of a partitioning result.
    The predicted cluster labels are first relabelled onto the ground-truth
    labels with best_map; the score is then the micro-averaged F1 over all
    labels, which for single-label predictions equals plain accuracy.
    :param gnd_seq: ground-truth label sequence
    :param pred_seq: label sequence of the partitioning (clustering) result
    :return: AC metric
    '''
    # ====================
    pred_arr = np.array(pred_seq)
    # best_map uses .shape on its second argument, hence the array conversion
    mapped_pred = best_map(gnd_seq, pred_arr)
    return f1_score(y_true=gnd_seq, y_pred=mapped_pred, average='micro')

def best_map(L1, L2):
    '''
    Function to get the best membership map from label sequence L1 to L2 for AC metric.

    Every distinct label of L2 (a cluster) is relabelled with a label of L1
    (a reference class). The cluster-to-class assignment is one-to-one and
    maximizes the total number of samples whose new L2 label equals their L1
    label (maximum-weight bipartite matching on the contingency table).

    :param L1: reference label sequence (numeric labels; list or array)
    :param L2: label sequence to remap, same length as L1 (list or array)
    :return: float array with the shape of L2 holding the remapped labels.
        If L2 has more distinct labels than L1, the unmatched L2 labels are
        mapped to distinct values larger than every label of L1, so those
        samples never count as matches.
    '''
    # ====================
    L1 = np.asarray(L1)
    L2 = np.asarray(L2)
    Label1, inv1 = np.unique(L1, return_inverse=True)
    Label2, inv2 = np.unique(L2, return_inverse=True)
    nClass1 = len(Label1)
    nClass2 = len(Label2)
    if nClass2 == 0:
        # Nothing to remap (empty L2)
        return np.zeros(L2.shape)
    # G[i, j] = number of samples with L1 label Label1[i] and L2 label Label2[j]
    G = np.zeros((nClass1, nClass2))
    np.add.at(G, (inv1.ravel(), inv2.ravel()), 1)
    # ==========
    # Default for every L2 label: a fresh value outside Label1, so L2 labels
    # left unmatched (when nClass2 > nClass1) are always counted as wrong.
    base = Label1.max() + 1 if nClass1 > 0 else 0
    lookup = base + np.arange(nClass2, dtype=float)
    if nClass1 > 0:
        # Minimizing -G maximizes the overlap; rectangular G is supported and
        # yields min(nClass1, nClass2) (class, cluster) pairs.
        row_ind, col_ind = linear_sum_assignment(-G)
        lookup[col_ind] = Label1[row_ind]
    newL2 = lookup[inv2.ravel()].reshape(L2.shape)

    return newL2

def get_cond_mtc(edges, clus_res, num_clus):
    '''
    Compute the average conductance of a clustering result.
    For each cluster c: vol(c) counts edge endpoints that fall in c, and
    cut(c) counts endpoints in c of edges whose two endpoints lie in
    different clusters. The per-cluster score is cut(c) / vol(c), or 1.0
    when vol(c) is 0. The result is the mean of these scores over all
    num_clus clusters.
    :param edges: edge list (undirected & 0-base node indices)
    :param clus_res: clustering result (cluster label of each node, indexed by node id)
    :param num_clus: number of clusters
    :return: average conductance over the clusters
    '''
    # ====================
    cut_cnt = [0.0] * num_clus
    vol_cnt = [0.0] * num_clus
    # ==========
    for u, v in edges:
        lbl_u, lbl_v = clus_res[u], clus_res[v]
        vol_cnt[lbl_u] += 1.0
        vol_cnt[lbl_v] += 1.0
        if lbl_u == lbl_v:
            continue  # intra-cluster edge: contributes to volume only
        cut_cnt[lbl_u] += 1.0
        cut_cnt[lbl_v] += 1.0
    # ==========
    # Clusters with zero volume (e.g. empty ones) are penalized with 1.0
    per_clus = (
        cut_cnt[k] / vol_cnt[k] if vol_cnt[k] != 0 else 1.0
        for k in range(num_clus)
    )
    return sum(per_clus) / num_clus