import numpy as np
from scipy.linalg import null_space, eigh, sqrtm
from sklearn.cluster import k_means


def kmeans(features: np.ndarray, num_clusters: int, normalize_rows: bool = False) -> np.ndarray:
    """
    Cluster the rows of a feature matrix with k-means.

    When ``normalize_rows`` is set, every row is scaled to unit Euclidean norm first. Rows whose norm is
    exactly zero are left unchanged (all zeros) instead of being divided by zero, so no NaN/inf values are
    handed to k-means.

    :param features: (num_nodes, num_features) Feature matrix
    :param num_clusters: Number of clusters to discover
    :param normalize_rows: Whether to normalize rows of feature matrix or not
    :return clusters: (num_nodes,) Discovered clusters (second element of the tuple returned by ``k_means``)
    """
    if normalize_rows:
        row_norms = np.linalg.norm(features, axis=1, keepdims=True)
        # A zero row has no direction: divide it by 1 so it stays at the origin rather than becoming NaN
        safe_norms = np.where(row_norms == 0, 1.0, row_norms)
        features = features / safe_norms
    _, clusters, _ = k_means(features, num_clusters)
    return clusters


def compute_top_eigen(mat: np.ndarray, k: int) -> np.ndarray:
    """
    Return the eigenvectors of a symmetric matrix that belong to its k smallest eigenvalues.

    :param mat: (num_nodes, num_nodes) Symmetric matrix for which eigenvectors must be computed
    :param k: Number of eigenvectors to return (one per cluster to discover)
    :return evec: (num_nodes, k) Eigenvectors of mat for the smallest k eigenvalues, one per column
    """
    num_rows = mat.shape[0]
    assert num_rows >= k, 'Insufficient number of eigenvectors'
    # Indices are inclusive, so [0, k - 1] selects exactly k eigenpairs; the eigenvalues are discarded
    eigen_pairs = eigh(mat, subset_by_index=[0, k - 1])
    return eigen_pairs[1]


def compute_laplacian(adj_mat: np.ndarray, normalize_laplacian: bool = False) -> np.ndarray:
    """
    Build the graph Laplacian L = D - A, optionally normalized symmetrically as D^{-1/2} L D^{-1/2}.

    The normalization is applied by rescaling rows and columns with the vector of inverse square-root
    degrees, which avoids inverting and multiplying dense (num_nodes, num_nodes) diagonal matrices.
    A small constant (1e-6) is added to every degree before inversion so that nodes with degree zero do
    not cause a division by zero.

    :param adj_mat: (num_nodes, num_nodes) Adjacency matrix of the observed graph
    :param normalize_laplacian: Whether to use normalized Laplacian or not
    :return laplacian: (num_nodes, num_nodes) The laplacian of the graph
    """
    degrees = np.sum(adj_mat, axis=1)
    laplacian = np.diag(degrees) - adj_mat
    if normalize_laplacian:
        inv_sqrt_degrees = np.sqrt(1.0 / (degrees + 1e-6))
        # Scale columns first, then rows: entry (i, j) becomes s_i * (L_ij * s_j)
        laplacian = inv_sqrt_degrees[:, np.newaxis] * (laplacian * inv_sqrt_degrees[np.newaxis, :])
    return laplacian


def fair_sc(adj_mat: np.ndarray, fair_mat: np.ndarray, num_clusters: int, normalize_laplacian: bool = False,
            normalize_evec: bool = False) -> np.ndarray:
    """
    Spectral clustering restricted to the null space of a constraint matrix derived from ``fair_mat``.

    The Laplacian is projected onto a basis of that null space, the eigenvectors for the smallest
    eigenvalues of the projected matrix are mapped back to node space, and k-means is run on the result.

    :param adj_mat: (num_nodes, num_nodes) Adjacency matrix of the observed graph
    :param fair_mat: (num_nodes, num_nodes) A graph specifying which node can represent which other nodes
    :param num_clusters: Number of clusters to discover
    :param normalize_laplacian: Whether to use normalized Laplacian or not
    :param normalize_evec: Whether to normalize the rows of eigenvector matrix before running k-means
    :return clusters: (num_nodes,) The cluster assignment for each node
    """
    num_nodes = adj_mat.shape[0]

    # Constraint matrix: fair_mat times (I - ones / num_nodes); solutions must lie in its null space
    centering = np.eye(num_nodes) - np.ones(adj_mat.shape) / num_nodes
    constraint = fair_mat @ centering
    basis = null_space(constraint)
    assert basis.shape[1] >= num_clusters, 'Rank of c_mat is too high'

    # Always start from the unnormalized Laplacian; normalization, if requested, is folded into the basis
    graph_laplacian = compute_laplacian(adj_mat, normalize_laplacian=False)

    if normalize_laplacian:
        degrees = np.diag(adj_mat.sum(axis=1))
        # sqrtm may return a complex array; only the real part is kept
        q_sqrt = np.real(sqrtm(basis.T @ degrees @ basis))
        # The 1e-6 ridge keeps the matrix invertible
        regularized = 1e-6 * np.eye(q_sqrt.shape[0]) + q_sqrt
        basis = basis @ np.linalg.inv(regularized)

    # Project the Laplacian onto the (possibly rescaled) basis and take its bottom eigenvectors
    projected = basis.T @ (graph_laplacian @ basis)
    embedding = basis @ compute_top_eigen(projected, num_clusters)

    # Cluster the nodes in the embedding space
    return kmeans(embedding, num_clusters, normalize_evec)


def normal_sc(adj_mat: np.ndarray, num_clusters: int, normalize_laplacian: bool = False, normalize_evec: bool = False) \
        -> np.ndarray:
    """
    Standard spectral clustering: Laplacian -> bottom eigenvectors -> k-means.

    :param adj_mat: (num_nodes, num_nodes) Adjacency matrix of the observed graph
    :param num_clusters: Number of clusters to discover
    :param normalize_laplacian: Whether to use normalized Laplacian or not
    :param normalize_evec: Whether to normalize the rows of eigenvector matrix before running k-means
    :return clusters: (num_nodes,) The cluster assignment for each node
    """
    # Spectral embedding: one eigenvector of the Laplacian per requested cluster
    graph_laplacian = compute_laplacian(adj_mat, normalize_laplacian)
    embedding = compute_top_eigen(graph_laplacian, num_clusters)

    # Each row of the embedding is one node; k-means groups the nodes
    return kmeans(embedding, num_clusters, normalize_evec)
