
from meth.ASCENT import *
from scipy.sparse import csr_matrix
from utils import *
import time

def get_sup_sp(edges, degs):
    """Return COO-style triplets for a degree-normalised, self-looped adjacency.

    The diagonal entry of node ``i`` is ``1 / (degs[i] + 1)``. Every edge
    ``(src, dst)`` contributes two off-diagonal entries:
    ``(src, dst)`` with value ``1 / (degs[src] + 1)`` and ``(dst, src)`` with
    value ``1 / (degs[dst] + 1)``. Triplets are emitted diagonal-first (nodes
    in index order), then two per edge in ``edges`` order.

    Coordinates are not de-duplicated here: a self-loop or a repeated edge
    yields several triplets for the same cell, which scipy's COO/CSR
    constructors sum. If ``degs[i]`` counts every edge endpoint at ``i``,
    each row of the summed matrix adds up to 1.

    Args:
        edges: iterable of ``(src, dst)`` pairs of integer node indices.
        degs: per-node degrees indexable by node index; ``len(degs)`` is
            taken as the number of nodes.

    Returns:
        A tuple ``(src_idxs, dst_idxs, vals)`` of equally long lists holding
        row indices, column indices and entry values.
    """
    n_nodes = len(degs)
    rows = list(range(n_nodes))
    cols = list(range(n_nodes))
    weights = [1.0 / (degs[node] + 1.0) for node in range(n_nodes)]
    for src, dst in edges:
        # Store the edge in both directions; each direction is weighted by
        # the inverse (degree + 1) of its row node.
        rows.extend((src, dst))
        cols.extend((dst, src))
        weights.extend((1.0 / (degs[src] + 1.0), 1.0 / (degs[dst] + 1.0)))
    return rows, cols, weights

# ====================
# Experiment configuration.
data_name = 'wiki'                      # edges are read from data/<data_name>.edgelist
rand_seed_list = [0, 1, 10, 100, 1000]  # one clustering run per seed
Ks = [2, 8, 32]                         # cluster counts to evaluate

# tau settings: L = number of multiplications by sup_sp, theta = final scale.
theta, L = 0.1, 4

# ====================
# Load the edge list: one "src dst" pair of integer node ids per line.
# Fields may be separated by any whitespace and extra columns are ignored.
# Blank lines (e.g. a trailing newline) are skipped rather than crashing on
# int(''). The context manager guarantees the file is closed even if parsing
# fails.
edges = []
with open('data/%s.edgelist' % (data_name), 'r', encoding='utf-8') as f_input:
    for line in f_input:
        rec = line.split()
        if not rec:
            continue
        src = int(rec[0])
        dst = int(rec[1])
        edges.append((src, dst))
num_edges = len(edges)
if not edges:
    # np.max on an empty list used to fail here with an opaque reduction error.
    raise ValueError('no edges found in data/%s.edgelist' % (data_name))
# num_nodes is the largest id + 1, i.e. ids are treated as 0-based indices.
# Computed as a plain Python int rather than numpy's int64.
num_nodes = max(max(src, dst) for (src, dst) in edges) + 1
# ==========
print('DATA %s #NODES %d #EDGES %d #CLUS N/A'
      % (data_name, num_nodes, num_edges))

# ====================
# Node degrees: each edge (src, dst) adds one to both endpoints, so a
# self-loop adds two to its node.
degs = [0] * int(num_nodes)
for u, v in edges:
    for endpoint in (u, v):
        degs[endpoint] += 1

# ====================
# Assemble the sparse matrix from get_sup_sp's triplets (duplicates summed),
# then start tau from the degree column vector and multiply it by sup_sp L
# times. Finally scale by theta and flatten to a per-node list.
src_idxs, dst_idxs, vals = get_sup_sp(edges, degs)
sup_sp = csr_matrix((vals, (src_idxs, dst_idxs)), shape=(num_nodes, num_nodes))
tau = np.array(degs).reshape(-1, 1)
for _ in range(L):
    tau = sup_sp @ tau
tau = list((tau * theta)[:, 0])

# ====================
# For each cluster count K, run ASCENT_sp once per random seed. Then report
# the mean and (population) standard deviation of the get_cond_mtc score
# (printed as COND) and of the clustering run time in seconds.
# (The dense variant ASCENT from meth.ASCENT accepts the same arguments.)
print('%s L=%d theta=%.1f' % (data_name, L, theta))
for num_clus in Ks:
    cond_mtc_list = []
    time_list = []
    for rand_seed in rand_seed_list:
        # perf_counter() is monotonic and high-resolution. Unlike a difference
        # of time.time() readings, the interval cannot be distorted by system
        # clock adjustments.
        time_s = time.perf_counter()
        clus_res = ASCENT_sp(edges, num_nodes, num_clus, tau, seed=rand_seed)
        run_time = time.perf_counter() - time_s
        # Only the clustering call is timed; the metric is evaluated afterwards.
        cond_mtc = get_cond_mtc(edges, clus_res, num_clus)
        cond_mtc_list.append(cond_mtc)
        time_list.append(run_time)
    cond_mean = np.mean(cond_mtc_list)
    cond_std = np.std(cond_mtc_list)
    time_mean = np.mean(time_list)
    time_std = np.std(time_list)
    print('K=%d COND %.4f~(%.4f) TIME %.4f~(%.4f)' %
          (num_clus, cond_mean, cond_std, time_mean, time_std))
