
from meth.ASCENT import *
from scipy.sparse import csr_matrix
from utils import *
import time
# Explicit imports kept after the star imports so that no star-exported name
# can shadow them; np was previously only available via the star imports.
import numpy as np

def get_sup_sp(edges, degs):
    """Return COO triplets (rows, cols, values) of a degree-normalized
    adjacency matrix with self-loops.

    Every entry stored in row ``i`` has value ``1 / (degs[i] + 1)``:
    one diagonal entry per node, plus, for each edge ``(src, dst)``, the
    entries ``(src, dst)`` and ``(dst, src)``. When ``degs[i]`` counts the
    edge endpoints at node ``i``, each row therefore sums to 1.

    Args:
        edges: iterable of ``(src, dst)`` integer node-id pairs.
        degs: per-node degree values, indexable by node id; its length
            defines the number of nodes.

    Returns:
        Three parallel lists ``(src_idxs, dst_idxs, vals)``. Diagonal
        triplets come first in node order, followed by two triplets per
        edge. Repeated coordinates are left as-is; scipy's ``csr_matrix``
        constructor sums duplicates.
    """
    n = len(degs)
    # Shared weight for every entry of row k.
    inv = [1.0 / (degs[k] + 1.0) for k in range(n)]

    # Diagonal (self-loop) entries, one per node.
    rows = list(range(n))
    cols = list(range(n))
    weights = list(inv)

    # Symmetric pair of entries per edge, each weighted by its row node.
    for u, v in edges:
        rows += (u, v)
        cols += (v, u)
        weights += (inv[u], inv[v])

    return rows, cols, weights

# ====================
# Experiment configuration.
data_name = 'caltech'  # reads data/<data_name>.edgelist and data/<data_name>.gnd
rand_seed_list = [0, 1, 10, 100, 1000]  # one clustering run per seed

theta = 0.1  # multiplier applied to the propagated degree vector tau
L = 3  # number of sup_sp propagation steps applied to tau

# ====================
# Edge list: each non-blank line holds "src dst" as two integer node ids.
# split() with no argument tolerates tabs and repeated spaces, and blank lines
# (e.g. an extra trailing newline) are skipped instead of crashing int().
edges = []
with open('data/%s.edgelist' % (data_name), 'r') as f_input:
    for line in f_input:
        rec = line.split()
        if not rec:
            continue
        src = int(rec[0])
        dst = int(rec[1])
        edges.append((src, dst))
num_edges = len(edges)
# ==========
# Ground-truth labels: one integer label per non-blank line. The node count is
# taken to be the number of labels, so label k presumably belongs to node k.
gnd = []
with open('data/%s.gnd' % (data_name), 'r') as f_input:
    for line in f_input:
        line = line.strip()
        if line:
            gnd.append(int(line))
if not gnd:
    raise ValueError('no labels found in data/%s.gnd' % (data_name))
num_nodes = len(gnd)
# Assumes labels are 0-based cluster ids, so the cluster count is max + 1.
num_clus = np.max(gnd) + 1
# Node ids are used as indices into per-node lists of length num_nodes
# (e.g. the degree list), so reject out-of-range ids up front with a clear
# message; negative ids would otherwise silently wrap around.
for (src, dst) in edges:
    if not (0 <= src < num_nodes and 0 <= dst < num_nodes):
        raise ValueError('edge (%d, %d) references a node outside [0, %d)'
                         % (src, dst, num_nodes))
# ==========
print('DATA %s #NODES %d #EDGES %d #CLUS %d'
      % (data_name, num_nodes, num_edges, num_clus))

# ====================
# Node degrees: every endpoint of every edge adds one to its node, so a
# self-loop (src == dst) contributes two.
degs = [0] * num_nodes
for endpoint in (node for edge in edges for node in edge):
    degs[endpoint] += 1

# ====================
# Per-node thresholds tau: start from the degree column vector, apply the
# normalized support matrix from get_sup_sp L times (each step replaces a
# node's value with the average of its own and its neighbours' values, since
# all d_i + 1 entries of row i equal 1 / (d_i + 1)), then scale by theta.
src_idxs, dst_idxs, vals = get_sup_sp(edges, degs)
sup_sp = csr_matrix((vals, (src_idxs, dst_idxs)), shape=(num_nodes, num_nodes))
tau = np.array(degs).reshape((-1, 1))
for _ in range(L):
    tau = sup_sp.dot(tau)
tau = list((tau * theta).ravel())

# ====================
# Run ASCENT_sp once per random seed. For every run, record the NMI and AC
# scores against gnd (presumably normalized mutual information and clustering
# accuracy), the conductance metric from get_cond_mtc, and the runtime in
# seconds. Then report mean and (population) standard deviation of each.
NMI_mtc_list = []
AC_mtc_list = []
cond_mtc_list = []
time_list = []
# ==========
for rand_seed in rand_seed_list:
    # ==========
    # perf_counter() is a monotonic, high-resolution clock meant for measuring
    # durations; time.time() follows the system clock and can jump.
    time_s = time.perf_counter()
    # Alternative: ASCENT(edges, num_nodes, num_clus, tau, seed=rand_seed)
    # (presumably the non-sparse variant) can be swapped in for comparison.
    clus_res = ASCENT_sp(edges, num_nodes, num_clus, tau, seed=rand_seed)
    time_e = time.perf_counter()
    run_time = time_e - time_s
    # ==========
    NMI_mtc = get_NMI_mtc(gnd, clus_res)
    AC_mtc = get_AC_mtc(gnd, clus_res)
    cond_mtc = get_cond_mtc(edges, clus_res, num_clus)
    # ==========
    NMI_mtc_list.append(NMI_mtc)
    AC_mtc_list.append(AC_mtc)
    cond_mtc_list.append(cond_mtc)
    time_list.append(run_time)
# ==========
# np.std uses ddof=0 (population standard deviation) by default.
NMI_mean = np.mean(NMI_mtc_list)
NMI_std = np.std(NMI_mtc_list)
AC_mean = np.mean(AC_mtc_list)
AC_std = np.std(AC_mtc_list)
cond_mean = np.mean(cond_mtc_list)
cond_std = np.std(cond_mtc_list)
time_mean = np.mean(time_list)
time_std = np.std(time_list)
print('%s L=%d theta=%.1f '
      'NMI %.4f~(%.4f) AC %.4f~(%.4f) '
      'COND %.4f~(%.4f) TIME %.4f~(%.4f)' %
      (data_name, L, theta,
       NMI_mean, NMI_std, AC_mean, AC_std,
       cond_mean, cond_std, time_mean, time_std))
