# Implementation of the SCORE+ method
# (1) construct the regularized graph Laplacian
# (2) eigen-decomposition (EVD) of the normalized graph Laplacian
# (3) arrange the spectral embedding (determine whether to use K+1 eigenvalues/vectors)
# (4) normalize the spectral embedding (entry-wise ratios against the leading eigenvector)
# (5) KMeans clustering

from .meth_utils import *
from scipy import linalg
from scipy.sparse import csr_matrix
from scipy.sparse.linalg import eigsh
#from sklearn.preprocessing import normalize
from sklearn.cluster import KMeans

def SCORE_plus(edges, num_nodes, num_clus, tau, seed=None, t=0.1):
    '''
    Function to implement SCORE+ alg (via dense EVD)
    :param edges: edge list (format as expected by get_reg_lap)
    :param num_nodes: number of nodes
    :param num_clus: number of clusters (i.e., K)
    :param tau: degree-corrected (regularization) term passed to get_reg_lap
    :param seed: random seed for KMeans
    :param t: threshold on the relative eigen-gap 1 - lambda_{K+1}/lambda_K;
        above it only the first K-1 ratio columns (K eigenvectors) are kept,
        otherwise all K ratio columns (K+1 eigenvectors) are used
    :return: clustering result (KMeans labels, one per node)
    :raises ValueError: if the graph has fewer than num_clus + 1 nodes
    '''
    if num_clus + 1 > num_nodes:
        raise ValueError('SCORE+ needs at least num_clus + 1 nodes')
    # ====================
    reg_lap = get_reg_lap(edges, num_nodes, tau)
    # ==========
    # The normalized Laplacian is treated as symmetric (SCORE_plus_sp relies
    # on the same property via eigsh), so eigh returns real eigenvalues and
    # orthonormal eigenvectors directly -- no complex output to truncate.
    eigvals, eigvecs = linalg.eigh(reg_lap)
    # Keep the K+1 largest eigenvalues by signed value, in descending order;
    # 'stable' keeps tied eigenvalues in solver order.
    order = np.argsort(-eigvals, kind='stable')[:num_clus+1]
    eigvals = eigvals[order]
    eigvecs = eigvecs[:, order]
    # ==========
    # Scale every eigenvector by its eigenvalue (broadcast over columns).
    spec_emb = eigvecs * eigvals
    base_emb = spec_emb[:, 0].copy()
    # Avoid division by zero (e.g. isolated nodes); same guard value as the
    # sparse variant SCORE_plus_sp.
    base_emb[base_emb == 0] = 1e-3
    # Entry-wise ratios of the remaining K scaled eigenvectors to the leading one.
    spec_emb = spec_emb[:, 1:] / base_emb[:, None]
    # ==========
    # Large relative gap between lambda_K and lambda_{K+1}: drop the
    # (K+1)-th eigenvector, i.e. keep only K-1 ratio columns.
    if (1 - eigvals[num_clus]/eigvals[num_clus-1]) > t:
        spec_emb = spec_emb[:, :num_clus-1]
    # ==========
    kmeans = KMeans(n_clusters=num_clus, random_state=seed).fit(spec_emb)
    return kmeans.labels_

def SCORE_plus_sp(edges, num_nodes, num_clus, tau, seed=None, t=0.1):
    '''
    Function to implement SCORE+ alg (via sparse EVD; dense EVD fallback for
    graphs too small for eigsh)
    :param edges: edge list (format as expected by get_reg_lap_sp)
    :param num_nodes: number of nodes
    :param num_clus: number of clusters (i.e., K)
    :param tau: degree-corrected (regularization) term passed to get_reg_lap_sp
    :param seed: random seed for KMeans
    :param t: threshold on the relative eigen-gap 1 - lambda_{K+1}/lambda_K;
        above it only the first K-1 ratio columns (K eigenvectors) are kept,
        otherwise all K ratio columns (K+1 eigenvectors) are used
    :return: clustering result (KMeans labels, one per node)
    :raises ValueError: if the graph has fewer than num_clus + 1 nodes
    '''
    if num_clus + 1 > num_nodes:
        raise ValueError('SCORE+ needs at least num_clus + 1 nodes')
    # ====================
    src_idxs, dst_idxs, vals = get_reg_lap_sp(edges, num_nodes, tau)
    reg_lap_sp = csr_matrix((vals, (src_idxs, dst_idxs)), shape=(num_nodes, num_nodes))
    # eigsh (default which='LM') returns the largest-magnitude eigenpairs;
    # the 200 extra pairs beyond the K+1 needed are presumably there so the
    # K+1 largest *signed* eigenvalues are among them.
    num_eig = num_clus + 200 + 1
    if num_eig < num_nodes:
        eigvals, eigvecs = eigsh(reg_lap_sp, k=num_eig)
    else:
        # eigsh raises TypeError for a sparse matrix when k >= N; such small
        # graphs are cheap to decompose densely.
        eigvals, eigvecs = linalg.eigh(reg_lap_sp.toarray())
    # Keep the K+1 largest eigenvalues by signed value, in descending order;
    # 'stable' keeps tied eigenvalues in solver order.
    order = np.argsort(-eigvals, kind='stable')[:num_clus+1]
    eigvals = eigvals[order]
    eigvecs = eigvecs[:, order]
    # ==========
    # Scale every eigenvector by its eigenvalue (broadcast over columns).
    spec_emb = eigvecs * eigvals
    base_emb = spec_emb[:, 0].copy()
    # Avoid division by zero (e.g. isolated nodes).
    base_emb[base_emb == 0] = 1e-3
    # Entry-wise ratios of the remaining K scaled eigenvectors to the leading one.
    spec_emb = spec_emb[:, 1:] / base_emb[:, None]
    # ==========
    # Large relative gap between lambda_K and lambda_{K+1}: drop the
    # (K+1)-th eigenvector, i.e. keep only K-1 ratio columns.
    if (1 - eigvals[num_clus]/eigvals[num_clus-1]) > t:
        spec_emb = spec_emb[:, :num_clus-1]
    # ==========
    kmeans = KMeans(n_clusters=num_clus, random_state=seed).fit(spec_emb)
    return kmeans.labels_
