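# Implementation of the RSC (regularized spectral clustering) method:
# (1) construct the regularized graph Laplacian
# (2) eigenvalue decomposition (EVD) of the regularized graph Laplacian
# (3) arrange the spectral embedding from the leading eigenvectors
# (4) normalize each row of the spectral embedding (L2)
# (5) cluster the normalized embedding with KMeans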

from .meth_utils import *
from scipy import linalg
from scipy.sparse import csr_matrix
from scipy.sparse.linalg import eigsh
from sklearn.preprocessing import normalize
from sklearn.cluster import KMeans

def RSC(edges, num_nodes, num_clus, tau, seed=None):
    '''
    Function to implement RSC alg (via dense EVD)
    :param edges: edge list
    :param num_nodes: number of nodes
    :param num_clus: number of clusters (i.e., K)
    :param tau: degree-corrected term
    :param seed: random seed (passed to KMeans as random_state)
    :return: clustering result (KMeans labels, one per row of the embedding)
    '''
    # ====================
    # (1) Regularized graph Laplacian as a dense square matrix.
    reg_lap = get_reg_lap(edges, num_nodes, tau)
    # ==========
    # (2) Full dense EVD. The regularized Laplacian is symmetric (RSC_sp feeds
    # the same operator to eigsh, which requires symmetry), so use the
    # symmetric solver: it is faster than the general `eig`, returns real
    # eigenpairs directly, and yields an orthonormal basis even inside
    # repeated eigenspaces.
    _, eigvecs = linalg.eigh(reg_lap)
    # ==========
    # (3) eigh sorts eigenvalues in ascending order, so reversing the columns
    # puts them in descending order; keep the eigenvectors of the num_clus
    # largest (algebraic) eigenvalues as the spectral embedding.
    spec_emb = eigvecs[:, ::-1][:, :num_clus]
    # ==========
    # (4) Normalize every row of the embedding to unit L2 norm.
    spec_emb = normalize(spec_emb, 'l2')
    # ==========
    # (5) KMeans on the normalized embedding.
    kmeans = KMeans(n_clusters=num_clus, random_state=seed).fit(spec_emb)
    return kmeans.labels_

def RSC_sp(edges, num_nodes, num_clus, tau, seed=None):
    '''
    Function to implement RSC alg (via sparse EVD)
    :param edges: edge list
    :param num_nodes: number of nodes
    :param num_clus: number of clusters (i.e., K)
    :param tau: degree-corrected term
    :param seed: random seed (passed to KMeans as random_state)
    :return: clustering result (KMeans labels, one per row of the embedding)
    '''
    # ====================
    # (1) Regularized graph Laplacian, assembled as CSR from (row, col, value)
    # triplets.
    src_idxs, dst_idxs, vals = get_reg_lap_sp(edges, num_nodes, tau)
    reg_lap_sp = csr_matrix((vals, (src_idxs, dst_idxs)), shape=(num_nodes, num_nodes))
    # ==========
    # (2) Partial EVD. eigsh with its default which='LM' returns the
    # largest-MAGNITUDE eigenpairs, while the embedding needs the largest
    # ALGEBRAIC ones; requesting 200 extra eigenpairs is a heuristic so the
    # num_clus largest algebraic eigenvalues are very likely among them.
    num_eig = num_clus + 200
    if num_eig >= num_nodes - 1:
        # ARPACK (eigsh) requires k < num_nodes and raises for sparse input
        # when k >= num_nodes, so small graphs would crash. A full dense EVD is
        # cheap at this size and exact.
        eigvals, eigvecs = linalg.eigh(reg_lap_sp.toarray())
    else:
        eigvals, eigvecs = eigsh(reg_lap_sp, k=num_eig)
    # ==========
    # (3) Order eigenpairs by descending eigenvalue (stable, so ties keep their
    # solver order) and keep the leading num_clus eigenvectors.
    sort_idx = np.argsort(-eigvals, kind='stable')
    spec_emb = eigvecs[:, sort_idx[:num_clus]]
    # ==========
    # (4) Normalize every row of the embedding to unit L2 norm.
    spec_emb = normalize(spec_emb, 'l2')
    # ==========
    # (5) KMeans on the normalized embedding.
    kmeans = KMeans(n_clusters=num_clus, random_state=seed).fit(spec_emb)
    return kmeans.labels_