
import numpy as np
from scipy import linalg
from scipy.sparse import csr_matrix
from scipy.sparse.linalg import eigsh
from sklearn.preprocessing import normalize
from sklearn.cluster import KMeans

import warnings
# NOTE: installs a catch-all 'ignore' filter at import time, so every warning
# (from any library, any category) is suppressed for the whole process once
# this module has been imported.
warnings.filterwarnings('ignore')

def get_reg_lap_ASCENT(edges, num_nodes, tau):
    '''
    Build the dense regularized graph Laplacian D_tau^{-1/2} A D_tau^{-1/2},
    where D_tau holds the node degrees shifted by the node-wise terms in tau.
    :param edges: edge list; iterable of (src, dst) node-index pairs
    :param num_nodes: number of nodes
    :param tau: node-wise degree-corrected terms (tau[i] is added to the degree of node i)
    :return: (dense) reg Lap mat, a (num_nodes, num_nodes) numpy array
    '''
    # ====================
    # Regularized degree of each node: number of incident edge endpoints + tau
    reg_degs = [0.0] * num_nodes
    for u, v in edges:
        reg_degs[u] += 1.0
        reg_degs[v] += 1.0
    reg_degs = [deg + tau[node] for node, deg in enumerate(reg_degs)]
    # ==========
    # Fill both (u, v) and (v, u) so the result is symmetric
    reg_lap = np.zeros((num_nodes, num_nodes))
    for u, v in edges:
        weight = 1.0 / (np.sqrt(reg_degs[u]) * np.sqrt(reg_degs[v]))
        reg_lap[u, v] = weight
        reg_lap[v, u] = weight

    return reg_lap

def get_reg_lap_sp_ASCENT(edges, num_nodes, tau):
    '''
    Build the regularized graph Laplacian D_tau^{-1/2} A D_tau^{-1/2} as
    coordinate-format triplets suitable for constructing a sparse matrix.
    Every edge yields two entries: (src, dst) and (dst, src).
    :param edges: edge list; iterable of (src, dst) node-index pairs
    :param num_nodes: number of nodes
    :param tau: node-wise degree-corrected terms (tau[i] is added to the degree of node i)
    :return: (row indices, column indices, values) as three parallel lists
    '''
    # ====================
    # Regularized degree of each node: number of incident edge endpoints + tau
    reg_degs = [0.0] * num_nodes
    for u, v in edges:
        reg_degs[u] += 1.0
        reg_degs[v] += 1.0
    reg_degs = [deg + tau[node] for node, deg in enumerate(reg_degs)]
    # ==========
    rows, cols, weights = [], [], []
    for u, v in edges:
        weight = 1.0 / (np.sqrt(reg_degs[u]) * np.sqrt(reg_degs[v]))
        # Emit the entry and its mirror so the assembled matrix is symmetric
        rows.extend((u, v))
        cols.extend((v, u))
        weights.extend((weight, weight))

    return rows, cols, weights

def ASCENT(edges, num_nodes, num_clus, tau, seed=None):
    '''
    ASCENT (w/ dense EVD)
    Embeds the nodes with the leading (num_clus + 1) eigenvectors of the
    regularized Laplacian, each scaled by its eigenvalue, row-normalizes the
    embedding and clusters it with KMeans.
    :param edges: edge list
    :param num_nodes: number of nodes
    :param num_clus: number of clusters (i.e., K)
    :param tau: node-wise degree-corrected terms
    :param seed: random seed (passed to KMeans as random_state)
    :return: clustering result (KMeans label of each node)
    '''
    # ====================
    reg_lap = get_reg_lap_ASCENT(edges, num_nodes, tau)
    # ==========
    # The regularized Laplacian is real symmetric, so use the symmetric solver:
    # it guarantees real eigenvalues and orthonormal eigenvectors, whereas the
    # general solver may return complex/non-orthogonal vectors for repeated
    # eigenvalues.
    eigvals, eigvecs = linalg.eigh(reg_lap)
    # Indices of the eigenvalues from largest to smallest
    order = np.argsort(eigvals)[::-1]
    # ==========
    # Never request more embedding dimensions than there are eigenpairs
    emb_dim = min(num_clus + 1, num_nodes)
    top = order[:emb_dim]
    # Scale every selected eigenvector (column) by its eigenvalue
    spec_emb = eigvecs[:, top] * eigvals[top]
    # ==========
    spec_emb = normalize(spec_emb, norm='l2')
    # ==========
    kmeans = KMeans(n_clusters=num_clus, random_state=seed).fit(spec_emb)
    clus_res = kmeans.labels_

    return clus_res

def ASCENT_sp(edges, num_nodes, num_clus, tau, seed=None):
    '''
    ASCENT (w/ sparse EVD)
    Embeds the nodes with the leading (num_clus + 1) eigenvectors of the
    sparse regularized Laplacian, each scaled by its eigenvalue, row-normalizes
    the embedding and clusters it with KMeans.
    :param edges: edge list
    :param num_nodes: number of nodes
    :param num_clus: number of clusters (i.e., K)
    :param tau: node-wise degree-corrected terms
    :param seed: random seed (passed to KMeans as random_state)
    :return: clustering result (KMeans label of each node)
    '''
    # ====================
    src_idxs, dst_idxs, vals = get_reg_lap_sp_ASCENT(edges, num_nodes, tau)
    reg_lap_sp = csr_matrix((vals, (src_idxs, dst_idxs)), shape=(num_nodes, num_nodes))
    # ==========
    # Request 200 eigenpairs beyond the num_clus + 1 that are actually used;
    # eigsh is called with its default selection rule, and the result is
    # re-sorted by eigenvalue below.
    num_eigs = num_clus + 200 + 1
    if num_eigs < num_nodes:
        eigvals, eigvecs = eigsh(reg_lap_sp, k=num_eigs)
    else:
        # eigsh requires k < N for sparse input and raises otherwise. A graph
        # this small is cheap to decompose densely, which also yields the
        # complete spectrum.
        eigvals, eigvecs = linalg.eigh(reg_lap_sp.toarray())
    # Indices of the eigenvalues from largest to smallest
    order = np.argsort(eigvals)[::-1]
    # ==========
    # Never request more embedding dimensions than there are eigenpairs
    emb_dim = min(num_clus + 1, num_nodes)
    top = order[:emb_dim]
    # Scale every selected eigenvector (column) by its eigenvalue
    spec_emb = eigvecs[:, top] * eigvals[top]
    # ==========
    spec_emb = normalize(spec_emb, norm='l2')
    # ==========
    kmeans = KMeans(n_clusters=num_clus, random_state=seed).fit(spec_emb)
    clus_res = kmeans.labels_

    return clus_res

