# Code adapted from https://github.com/GTmac/FastRP

import numpy as np

from sklearn import random_projection
from sklearn.preprocessing import normalize, scale
from scipy.sparse import csr_matrix, csc_matrix, spdiags, issparse


# projection method: choose from Gaussian and Sparse
# input matrix: choose from adjacency and transition matrix
# alpha adjusts the weighting of nodes according to their degree
def fastrp_projection(
    A,
    q=3,
    dim=128,
    projection_method="gaussian",
    input_matrix="adj",
    alpha=None,
    random_state=42,
):
    """Randomly project the first ``q`` powers of a graph matrix (FastRP).

    Parameters
    ----------
    A : scipy sparse matrix or ndarray
        Graph adjacency matrix. Assumed square (N x N): N x N diagonal
        matrices are built from ``A.shape[0]`` -- TODO confirm for callers.
    q : int
        Highest matrix power to project; powers 1..q are produced.
    dim : int
        Number of random-projection components (embedding width).
    projection_method : {"gaussian", "sparse"}
        Selects sklearn's ``GaussianRandomProjection`` or
        ``SparseRandomProjection``.
    input_matrix : {"adj", "trans"}
        "adj" uses ``A`` directly. "trans" row-normalizes it first
        (``D^-1 A`` with ``D`` the row sums). Rows whose sum is zero stay zero
        rather than becoming inf/NaN.
    alpha : float or None
        If given, column j of the projection matrix is scaled by
        ``deg_j ** alpha``, where ``deg`` is the row sum of ``A``. For
        negative ``alpha``, non-positive degrees are replaced by 1 first so
        that ``0 ** alpha`` does not produce inf.
    random_state : int, RandomState instance or None
        Seed forwarded to the sklearn projection. Defaults to 42 so results
        are reproducible.

    Returns
    -------
    list
        ``[M @ R.T, M @ (M @ R.T), ...]`` with ``max(q, 1)`` entries. ``M`` is
        the selected input matrix and ``R`` the (optionally degree-weighted)
        projection matrix. Each entry may be sparse or dense, depending on
        what the sklearn transformer and the products return.

    Raises
    ------
    ValueError
        If ``input_matrix`` or ``projection_method`` is not a supported option.
    """
    # Explicit checks instead of ``assert``: asserts vanish under ``python -O``.
    if input_matrix not in ("adj", "trans"):
        raise ValueError(
            f"input_matrix must be 'adj' or 'trans', got {input_matrix!r}"
        )
    if projection_method not in ("gaussian", "sparse"):
        raise ValueError(
            f"projection_method must be 'gaussian' or 'sparse', got {projection_method!r}"
        )

    N = A.shape[0]
    # Row sums of A as a flat float vector. ``A.sum`` works for sparse matrices,
    # sparse arrays and dense ndarrays alike (unlike the unbound csc_matrix.sum).
    deg = np.asarray(A.sum(axis=1), dtype=np.float64).ravel()

    if input_matrix == "adj":
        M = A
    else:
        # Transition matrix D^-1 A. Zero-sum rows get factor 0 instead of
        # 1/0 = inf, which avoids the divide-by-zero warning and NaN rows.
        inv_deg = np.zeros_like(deg)
        np.divide(1.0, deg, out=inv_deg, where=deg != 0)
        M = spdiags(inv_deg, 0, N, N) @ A

    if projection_method == "gaussian":
        transformer = random_projection.GaussianRandomProjection(
            n_components=dim, random_state=random_state
        )
    else:
        transformer = random_projection.SparseRandomProjection(
            n_components=dim, random_state=random_state
        )
    # fit() draws the projection matrix (stored in ``components_``).
    transformer.fit(M)

    if alpha is not None:
        base = deg
        if alpha < 0:
            # Avoid 0 ** negative -> inf for isolated nodes.
            base = np.where(deg > 0, deg, 1.0)
        node_weights = np.power(base, alpha)
        # Right-multiplying by diag(w) scales column j of components_ by w_j.
        transformer.components_ = transformer.components_ @ spdiags(
            node_weights, 0, N, N
        )

    cur_U = transformer.transform(M)
    U_list = [cur_U]
    # Each further power reuses the previous product: M^i R^T = M (M^{i-1} R^T).
    for i in range(2, q + 1):
        print("Computing power", i)
        cur_U = M @ cur_U
        U_list.append(cur_U)
    return U_list


# When weights is None, concatenate the embeddings from different powers of A instead of linearly combining them
def fastrp_merge(U_list, weights, normalization=False):
    """Merge the per-power embeddings into a single embedding matrix.

    Parameters
    ----------
    U_list : list
        Embeddings for successive matrix powers. Sparse entries are densified
        with ``toarray()``; everything else goes through ``np.asarray``, so
        ``np.matrix`` inputs become plain ndarrays.
    weights : sequence of float or None
        Per-power coefficients. If None, the dense embeddings are concatenated
        along axis 1 and returned as-is (no cleanup, no scaling). Otherwise
        embeddings and weights are paired with ``zip``, so surplus entries in
        the longer sequence are ignored.
    normalization : bool
        When truthy, L2-normalize every row of each embedding before merging.

    Returns
    -------
    numpy.ndarray
        The concatenation when ``weights`` is None. Otherwise the weighted
        sum, with NaN/+Inf/-Inf replaced by 0 and then standardized per column
        with ``sklearn.preprocessing.scale``.
    """

    def _to_dense(block):
        # Plain ndarray only: never a sparse type or np.matrix.
        if issparse(block):
            return block.toarray()
        return np.asarray(block)

    dense_blocks = list(map(_to_dense, U_list))
    if normalization:
        dense_blocks = [normalize(block, norm="l2", axis=1) for block in dense_blocks]

    if weights is None:
        return np.concatenate(dense_blocks, axis=1)

    merged = np.zeros_like(dense_blocks[0], dtype=float)
    for block, coeff in zip(dense_blocks, weights):
        merged += block * coeff

    # Safety net: zero out any NaN/Inf left over from numerical issues.
    merged = np.nan_to_num(merged, nan=0.0, posinf=0.0, neginf=0.0)
    return scale(merged)


# A is always the adjacency matrix
# the choice between adj matrix and trans matrix is decided in the conf
def fastrp_wrapper(A, conf):
    """Run FastRP end to end: project powers of ``A`` and merge them.

    ``conf`` must provide ``"weights"``, ``"dim"``, ``"projection_method"``,
    ``"input_matrix"``, ``"alpha"`` and ``"normalization"``. These are
    forwarded to ``fastrp_projection`` and ``fastrp_merge``.

    The number of matrix powers ``q`` is ``len(conf["weights"])``. When
    ``conf["weights"]`` is None (``fastrp_merge`` then concatenates the
    per-power embeddings), ``q`` cannot be inferred and is read from
    ``conf["q"]`` instead.
    """
    weights = conf["weights"]
    # len(None) raises TypeError, so concatenation mode needs an explicit q.
    q = len(weights) if weights is not None else conf["q"]
    U_list = fastrp_projection(
        A,
        q=q,
        dim=conf["dim"],
        projection_method=conf["projection_method"],
        input_matrix=conf["input_matrix"],
        alpha=conf["alpha"],
    )
    return fastrp_merge(U_list, weights, conf["normalization"])
