
import numpy as np
import scipy.sparse as sp 
import time
from utils.data_utils import CoarsenedGraph, Graph, Config
from typing import Optional
from copy import deepcopy
from sortedcontainers import SortedList
from numba import jit
from tqdm import tqdm




def compute_rsa_exact(Gc:CoarsenedGraph, G:Graph, laplacian_name:Optional[str] = "combinatorial",
                      R_is_L:Optional[str] = True, other_R:Optional[sp.csr_matrix] = None) -> float:
    """
    Compute the RSA constant exactly through a dense eigendecomposition.

    The value returned is the spectral norm (``ord=2``) of
    ``L^{1/2} (I - Q P) R R^T L^{-1/2}`` when ``R_is_L`` is falsy, and of
    ``L^{1/2} (I - Q P) R diag(lambda[:K]^{-1/2})`` otherwise, where ``L`` is
    ``G.get_laplacian(laplacian_name)``, ``lambda`` are its eigenvalues in
    ascending order and ``K`` is the number of columns of ``R``. Eigenvalues
    below ``1e-5`` are treated as zero (their inverse square root is set to 0).

    Args:
        Gc: coarsened graph; provides ``P``, ``Q`` and, unless ``other_R`` is
            given, ``R``.
        G: original graph; only ``get_laplacian`` is called on it.
        laplacian_name: name forwarded to ``G.get_laplacian``.
        R_is_L: truthy when the columns of R are scaled with the first K
            Laplacian eigenvalues instead of using ``R R^T L^{-1/2}``.
        other_R: matrix used as R instead of ``Gc.R`` when not None.

    Returns:
        The spectral norm described above.

    Raises:
        ValueError: if neither ``other_R`` nor ``Gc.R`` is available.

    Note:
        The global NumPy random seed is reset to 42 as a side effect.
    """
    R = other_R if other_R is not None else Gc.R
    if R is None:
        raise ValueError("R is not defined, please provide a R matrix or a coarsened graph with a R matrix")

    # Projector onto the complement of the range of Q @ P.
    projector = Gc.Q @ Gc.P
    complement = sp.eye(projector.shape[0]) - projector
    laplacian = G.get_laplacian(laplacian_name)
    K = R.shape[1]
    np.random.seed(42)

    spectrum, modes = np.linalg.eigh(laplacian.toarray())
    spectrum = spectrum.ravel()
    near_zero = spectrum < 1e-5

    # Square root of the Laplacian with the (numerically) null eigenvalues zeroed.
    spectrum[near_zero] = 0
    sqrt_spectrum = np.sqrt(spectrum)
    laplacian_sqrt = modes @ sp.diags(sqrt_spectrum) @ modes.T

    # Pseudo-inverse square root: null eigenvalues are set to 1 only to avoid
    # dividing by zero, then their contribution is removed again.
    spectrum[near_zero] = 1
    inv_sqrt_spectrum = 1/np.sqrt(spectrum)
    inv_sqrt_spectrum[near_zero] = 0
    laplacian_inv_sqrt = modes @ sp.diags(inv_sqrt_spectrum) @ modes.T

    laplacian_sqrt = sp.csr_matrix(laplacian_sqrt)
    laplacian_inv_sqrt = sp.csr_matrix(laplacian_inv_sqrt)

    if R_is_L:
        operator = laplacian_sqrt @ complement @ R @ sp.diags(inv_sqrt_spectrum[:K])
    else:
        operator = laplacian_sqrt @ complement @ R @ R.T @ laplacian_inv_sqrt

    return np.linalg.norm(operator.toarray(), ord=2)




def coarsen_algo_inspired_loukas(original_graph: Graph, config: Config, R_files: Optional[str] = None, graph_symmetric: bool = True, graph_name:Optional[str]= None)-> CoarsenedGraph:
    """
    Coarsen a graph over several levels by contracting low-cost node sets.

    At the first level the smallest eigenpairs of the chosen Laplacian are
    computed and define the preserved space ``R``. Every level then scores
    candidate contractions (edges or neighborhoods) against that space,
    contracts the selected ones and accumulates the lifting / coarsening
    operators.

    Args:
        original_graph (Graph): graph to coarsen. Reads ``csr_adj``,
            ``features``, ``labels`` and the train/val/test masks. Self loops
            are removed before coarsening.
        config (Config): configuration for the coarsening. Fields read:
            config.r (float): reduction ratio, clipped to [0, 0.9999]; the
                target size is ``ceil((1 - r) * N)`` nodes
            config.K (int): number of eigenvectors to preserve
            config.max_levels (int): maximum number of levels allowed
            config.max_level_r (float): upper bound of the per-level reduction ratio
            config.laplacian_norm (str): "combinatorial", "normalized" or
                "normalized_self_loop"
            config.laplacian_preserved (str): must equal config.laplacian_norm
            config.space_preserved (str): "eigenvectors",
                "eigenvectors_delta_in" or "eigenvectors_non_null"
            config.method (str): "edges" or "neighborhood"
            config.method_cost (str): cost name forwarded to the contraction routine
            config.n_e (int): maximum number of contractions per level
            config.n_e_percent (float): per-level contraction budget as a
                fraction of the current number of nodes
            config.delete_diagonal (bool): drop the diagonal of each coarsened adjacency
            ``config.method_name`` is overwritten before returning.
        R_files (str): precomputed preserved space; not implemented, must be None.
        graph_symmetric (bool): if True, symmetric eigensolvers are used and
            each coarsened adjacency is symmetrized.
        graph_name (str): currently unused.

    Returns:
        CoarsenedGraph: built from the final adjacency, ``P = P_loukas``,
        ``Q = Q_lifting``, the coarsened features (``P_loukas @ features`` or
        None), ``R`` and the additional operators ``P_rao``, ``P_mp`` and
        ``P_loukas``.

    Raises:
        ValueError: if ``n_e`` and ``n_e_percent`` are both None, if the
            estimated number of levels exceeds ``config.max_levels``, if
            ``R_files`` is given, or if a Laplacian / preserved space /
            method name is not supported.
    """
    adjacency_sparse_original = original_graph.csr_adj
    original_features = original_graph.features
    init_time_2 = time.time()
    if not sp.isspmatrix_csr(adjacency_sparse_original):
        print("Converting the adjacency matrix to csr, was originally in following format")
        adjacency_sparse_original = adjacency_sparse_original.tocsr()

    if sum(adjacency_sparse_original.diagonal()) > 0:
        print("The graph contains self loops, we will remove them; the version to add them at final step is not implemented yet")
        print("This is due to the difficulty to deal with laplacian for self looped graphs")
        non_null_diag = adjacency_sparse_original.diagonal()
        adjacency_sparse_original = adjacency_sparse_original - sp.diags_array(non_null_diag)

    # Keep the reduction ratio inside [0, 0.9999] so that log(1 - r) stays finite.
    r = config.r
    if r < 0:
        r = 0
    if r > 0.9999:
        r = 0.9999
    N = adjacency_sparse_original.shape[0]
    n, n_target = N, np.ceil((1 - r) * N)  # current and target graph sizes
    K = config.K  # number of eigenvectors to preserve

    if config.n_e is None and config.n_e_percent is None:
        # Without any budget the number of levels cannot be estimated; fail
        # here instead of crashing later on an unbound variable.
        raise ValueError("n_e and n_e_percent are both None, at least one of them must be set")
    elif config.n_e is None and config.n_e_percent is not None:
        print("n_e is None, we will use the percentage of nodes to remove adaptative")
        # num_levels = floor(log(1 - r) / log(1 - n_e_percent))
        num_levels_estimated = int(np.floor(np.log(1 - r) / np.log(1 - config.n_e_percent)))
        if config.n_e_percent < 0.08:
            print("n_e_percent is too low, we will add a 200 buffer to fix the number of levels")
            num_levels_estimated += 200
    elif config.n_e is not None and config.n_e_percent is None:
        print("n_e_percent is None, we will use the fixed number of nodes to remove")
        num_levels_estimated = int(np.floor((r*N/config.n_e)))
    else:
        # num_levels = ceil(log(1 - r) / log(1 - n_e_percent))
        num_levels_estimated = int(np.ceil(np.log(1 - r) / np.log(1 - config.n_e_percent)))
        print("n_e and n_e_percent are both not None, we will use the adaptative number of nodes to remove")

    P_loukas = sp.eye(N, format='csr')  # accumulated coarsening matrix
    Q_lifting = sp.eye(N, format='csr')  # accumulated (normalized) lifting matrix
    if not (csr_is_symmetric(adjacency_sparse_original)):
        print("warning the original graph is not symmetric")

    if np.any(np.sum(adjacency_sparse_original, axis=1) == 0):
        print("a row of the adjacency matrix is full of zeros, this will lead to a non connected graph")

    adjacency_intermediate = deepcopy(adjacency_sparse_original)

    if num_levels_estimated > config.max_levels:
        raise ValueError("error the number of levels estimated is higher than the maximum level")

    for level in tqdm(range(1, num_levels_estimated + 2)):

        # 1-D degree vector so that it can be fed to sp.diags_array
        degree_intermediate = np.array(adjacency_intermediate.sum(axis=1)).squeeze()
        L_intermediate = sp.diags_array(degree_intermediate) - adjacency_intermediate
        if config.laplacian_norm == "normalized_self_loop":
            # normalization matching A + I
            D_N = degree_intermediate + 1
        elif config.laplacian_norm == "normalized":
            # normalization matching A
            D_N = degree_intermediate
        elif config.laplacian_norm == "combinatorial" :
            # no normalization
            D_N = np.ones(adjacency_intermediate.shape[0])
        else :
            raise ValueError("error no laplacian preserved")

        D_N_inv = D_N**(-0.5)
        D_N_inv_matrix = sp.diags_array(D_N_inv)
        L_norm = D_N_inv_matrix @ L_intermediate @ D_N_inv_matrix
        if config.space_preserved == "eigenvectors_delta_in":
            delta = 1e-6
            L_norm = L_norm + delta * sp.eye(L_norm.shape[0], format='csr')

        # how much more we need to reduce the current graph (at most, at this level)
        r_cur = np.clip(1 - (n_target / n), 0.0, config.max_level_r)
        if level == 1:
            if config.laplacian_preserved != config.laplacian_norm:
                raise ValueError("error the laplacian preserved is supposed to be the same as the laplacian norm ")
            L_space_preserved = L_norm

            if R_files is not None:
                raise ValueError("error R_files is not None, this is not implemented yet")

            print("computing R")
            # Skipping the first eigenvector requires one extra eigenpair so
            # that K columns (and K matching eigenvalues) are left afterwards.
            skip_first = config.space_preserved == "eigenvectors_non_null"
            n_eig = K + 1 if skip_first else K
            if config.laplacian_preserved == "combinatorial":
                # larger omega so that the shift dominates the spectrum in the eigensolver
                omega = 2 * (np.max(L_norm.diagonal()) + 1)
                eigenvalues_preserved, eigenvectors_preserved = compute_sp_min_eig_fast(L_space_preserved, n_eig=n_eig, omega=omega, matrix_symmetric=graph_symmetric)
            else :
                eigenvalues_preserved, eigenvectors_preserved = compute_sp_min_eig_fast(L_space_preserved, n_eig=n_eig, matrix_symmetric=graph_symmetric)
            # threshold the (numerically) null eigenvalues: their inverse square root is set to 0
            mask = eigenvalues_preserved < 1e-5
            eigenvalues_preserved[mask] = 1
            dinvsqrt = eigenvalues_preserved ** (-1 / 2)
            dinvsqrt[mask] = 0
            if config.space_preserved == "eigenvectors" or config.space_preserved == "eigenvectors_delta_in":
                R = eigenvectors_preserved[:, :K]
                scaling_preserved = dinvsqrt[:K]
            elif skip_first:
                R = eigenvectors_preserved[:, 1:K+1]
                scaling_preserved = dinvsqrt[1:K+1]
            else:
                raise ValueError("error space preserved not implemented")
            print(f"computing R done, time needed: {time.time() - init_time_2}", flush=True)

            # R is dense; it is wrapped as sparse only to keep the sparse operator algebra below
            R_sparse = sp.csr_matrix(R)
            B0 = R_sparse @ sp.diags_array(scaling_preserved)
            A_cost = B0
            B = B0
            P_rao_intermediary = deepcopy(L_space_preserved)

        else:

            if not (csr_is_symmetric(L_norm)):
                if np.max(np.abs(L_norm - L_norm.T)) > 1e-5:
                    print("warning L_normalized is not symmetric", flush=True)
                    raise ValueError("the max diff > 1e-5")
            product = B.T @ L_norm @ B
            # convert to dense to use numpy eigensolvers
            product = product.todense()
            if graph_symmetric:
                d, V = np.linalg.eigh(product)
            else :
                print("the graph is not symmetric, we will use eig instead of eigh")
                d, V = np.linalg.eig(product)
                d = np.real(d)
                V = np.real(V)
            mask = d < 1e-5
            d[mask] = 1
            dinvsqrt = d ** (-1 / 2)
            dinvsqrt[mask] = 0
            V = sp.csr_matrix(V)
            A_cost = B  @ V  @ sp.diags_array(dinvsqrt, format ='csr') @ V.T

        if config.method == "edges":
            A_cost_numpy = A_cost.toarray()
            coarsening_list, epsilon_step  = contract_variation_edges(adjacency_intermediate, r=r_cur, method_cost= config.method_cost,
                                                                A_cost=A_cost_numpy,
                                                                n_e_percent=config.n_e_percent,
                                                                n_e = config.n_e)

        elif config.method == "neighborhood":
            A_cost_numpy = A_cost.toarray()
            coarsening_list, epsilon_step = contract_variation_neighborhood(adjacency_intermediate, r=r_cur, method_cost= config.method_cost,
                                                                A_cost=A_cost_numpy,
                                                                n_e_percent=config.n_e_percent,
                                                                n_e = config.n_e)
        else:
            raise ValueError(f"config.method error not implemented which was {config.method}")

        if len(coarsening_list) == 0:
            break

        # binary lifting matrix of this level (entries equal to one)
        intermediate_Q_combi = get_lifting_matrix(adjacency_intermediate, coarsening_list)
        adjacency_intermediate_ended = coarsen_matrix_Laplacian(adjacency_intermediate, intermediate_Q_combi)
        if config.delete_diagonal:
            non_null_diag = adjacency_intermediate_ended.diagonal()
            adjacency_intermediate_ended = adjacency_intermediate_ended - sp.diags_array(non_null_diag)
        if graph_symmetric:
            # enforce symmetry to avoid complex eigenvalues at the next level
            adjacency_intermediate = (adjacency_intermediate_ended + adjacency_intermediate_ended.T) / 2
        else :
            adjacency_intermediate = adjacency_intermediate_ended
            print("remember no symmetry enforced")

        degree_after_c = np.array(adjacency_intermediate.sum(axis=1)).squeeze()
        # An unsupported laplacian_norm already raised earlier in this iteration.
        if config.laplacian_norm == "normalized_self_loop":
            D_n = degree_after_c + 1
        elif config.laplacian_norm == "normalized":
            D_n = degree_after_c
        elif config.laplacian_norm == "combinatorial" :
            D_n = np.ones(adjacency_intermediate.shape[0])
        D_N_sqrt = D_N**(0.5)
        D_N_sqrt_matrix = sp.diags_array(D_N_sqrt)
        D_n_inv_sqrt = D_n**(-0.5)
        D_n_inv_sqrt_matrix = sp.diags_array(D_n_inv_sqrt)
        # lifting matrix rescaled by the fine (left) and coarse (right) normalizations
        intermediate_Q_norm = D_N_sqrt_matrix @ intermediate_Q_combi @ D_n_inv_sqrt_matrix
        Q_lifting = Q_lifting @ intermediate_Q_norm

        P_intermediate = moore_penrose_lifting_matrix(intermediate_Q_norm)
        P_loukas = P_intermediate @ P_loukas

        P_rao_intermediary =  intermediate_Q_norm.T @ P_rao_intermediary

        P_intermediate_combi = moore_penrose_lifting_matrix(intermediate_Q_combi)
        B = P_intermediate_combi @ B

        n = adjacency_intermediate.shape[0]

        if n <= n_target:
            break

    laplacian_final = get_laplacian(adjacency_intermediate, laplacian_name=config.laplacian_preserved)
    laplacian_inv_final= inv_Laplacian(laplacian_final)
    P_rao_true = laplacian_inv_final @ P_rao_intermediary
    P_mp = moore_penrose_lifting_matrix(Q_lifting)

    if original_features is not None:
        features_coarsened = P_loukas @ original_features
        print( "By default loukas based coarsened features")
    else:
        features_coarsened = None

    P_chosen = P_loukas

    R_sparse = sp.csr_matrix(R)

    config.method_name = "intermediate_adapt_loukas_normalized"
    print(f"Time needed for coarsening: {time.time() - init_time_2}")

    Gc = CoarsenedGraph(csr_adj=adjacency_intermediate,P=P_chosen, Q=Q_lifting, features=features_coarsened,
                        R=R_sparse, method_name = "intermediate_adapt_loukas", method_config = config, test_mask = original_graph.test_mask,
                        val_mask = original_graph.val_mask, train_mask = original_graph.train_mask, labels = original_graph.labels,
                        P_rao = P_rao_true, P_mp = P_mp, P_loukas = P_loukas)
    return Gc

def csr_is_symmetric(matrix):
    """Return True when the sparse matrix has no entry differing from its transpose."""
    mismatches = matrix != matrix.transpose()
    return mismatches.nnz == 0


def moore_penrose_lifting_matrix(Q):
    """Compute the Moore-Penrose pseudo-inverse of a "well-partitionnned" lifting matrix.
    Args:
        Q (sparse csr matrix): lifting matrix
    Returns:
        sparse csr matrix: coarsening matrix
    """
    #compute Q^TQ which is supposed to be diagonal then invert it and multiply by Q^T

    Q_bis = deepcopy(Q)
    product = Q_bis.T @ Q_bis
    #check if product is diagonal, if nnz is equal to the number of rows
    if product.nnz != product.shape[0]:
        raise ValueError("The lifting matrix is not well-partitioned, Q^TQ is not diagonal")
    
    diagonal = product.diagonal()
    diagonal_inv = diagonal**(-1)
    matrix_diagonal_inv = sp.diags_array(diagonal_inv)
    return matrix_diagonal_inv @ Q_bis.T


def compute_sp_min_eig_fast(matrix_sparse : sp.csr_matrix, n_eig : int, omega : Optional[float] = 3, matrix_symmetric : Optional[bool] = True):
    """
    Compute the ``n_eig`` smallest eigenpairs of a sparse matrix with a spectral shift.

    The largest-magnitude eigenvalues of ``omega * I - matrix_sparse`` are
    computed (which is quicker than asking the solver for the smallest ones)
    and shifted back. ``omega`` must be large enough for the smallest
    eigenvalues of ``matrix_sparse`` to become the largest in magnitude after
    the shift.

    Args:
        matrix_sparse: square sparse matrix.
        n_eig: number of eigenpairs to compute.
        omega: shift applied to the spectrum.
        matrix_symmetric: if True ``eigsh`` is used, otherwise ``eigs`` (a
            warning is printed and only the real parts are kept).

    Returns:
        (eigenvalues, eigenvectors): eigenvalues sorted in ascending order and
        the eigenvectors as columns in the same order.
    """
    shifted_matrix = omega * sp.eye(matrix_sparse.shape[0], format='csr') - matrix_sparse
    if matrix_symmetric:
        eigenvalues_shifted, eigenvectors = sp.linalg.eigsh(shifted_matrix, k=n_eig, which='LM')
    else:
        print("warning matrix not symmetric", flush=True)
        eigenvalues_shifted, eigenvectors = sp.linalg.eigs(shifted_matrix, k=n_eig, which='LM')
        eigenvalues_shifted = np.real(eigenvalues_shifted)
        eigenvectors = np.real(eigenvectors)
    eigenvalues = (omega - eigenvalues_shifted)[::-1]
    eigenvectors = eigenvectors[:, ::-1]
    # Callers index the result by position, so make the ascending order
    # explicit instead of relying on the order the solver returns. The stable
    # sort leaves an already ascending result untouched.
    order = np.argsort(eigenvalues, kind="stable")
    return eigenvalues[order], eigenvectors[:, order]


def get_laplacian(adjacency, laplacian_name="combinatorial"):
    """Compute the Laplacian matrix of the graph.
    Args:
        adjacency (sparse csr matrix): adjacency matrix
        laplacian_name (str): type of Laplacian to compute. Can be 'combinatorial', 'normalized' or 'normalized_self_loop'
    Returns:
        sparse csr matrix: Laplacian matrix
    """
    if laplacian_name == "combinatorial":
        #print("Computing combinatorial Laplacian: D-A")
        degree_matrix = np.array(adjacency.sum(axis=1)).squeeze()
        laplacian_combinatorial = sp.diags_array(degree_matrix) - adjacency
        #the type of the laplacian is csr_matrix
        return laplacian_combinatorial
    elif laplacian_name == "normalized":
        #print("Computing normalized Laplacian D^(-0.5) * (D-A) * D^(-0.5)")
        degree_matrix = np.array(adjacency.sum(axis=1)).squeeze()
        #degree_matrix = np.diag(np.sum(self.csr_adj, axis=1))
        degree_matrix_inv = degree_matrix**(-0.5)
        degree_matrix_inv = sp.diags_array(degree_matrix_inv)
        normalized_laplacian = degree_matrix_inv @ (sp.diags_array(degree_matrix) -adjacency) @ degree_matrix_inv
        return normalized_laplacian
    elif laplacian_name == "normalized_self_loop":
        #print("Computing normalized Laplacian with self loop : (D+I)^(-0.5) (A+I) (D+I)^(-0.5)")
        degree_matrix = np.array(adjacency.sum(axis=1)).squeeze()
        #degree_matrix = np.diag(np.sum(self.csr_adj, axis=1))
        degree_matrix_self_loop = degree_matrix + 1
        degree_matrix_self_loop_inv = degree_matrix_self_loop**(-0.5)
        degree_matrix_self_loop_inv = sp.diags_array(degree_matrix_self_loop_inv)
        normalized_self_loop_laplacian = degree_matrix_self_loop_inv @ (sp.diags_array(degree_matrix) - adjacency) @ degree_matrix_self_loop_inv
        return normalized_self_loop_laplacian
    else:
        raise ValueError(f"Unknown Laplacian type: {laplacian_name}, consider using 'combinatorial', 'normalized' or 'normalized_self_loop' or implement the new type of Laplacian")
    

def inv_Laplacian(L):
    """Pseudo-invert a Laplacian through its dense eigendecomposition.

    Eigenvalues below ``1e-6`` are treated as zero: their inverse is set to 0
    instead of being divided.

    Args:
        L (sparse csr matrix): Laplacian matrix
    Returns:
        sparse csr matrix: (pseudo-)inverse of the Laplacian matrix
    """
    spectrum, basis = np.linalg.eigh(L.todense())
    negligible = spectrum < 1e-6
    # Temporarily replace the null eigenvalues by 1 to avoid a division by zero.
    spectrum[negligible] = 1
    inverse_spectrum = spectrum ** (-1)
    inverse_spectrum[negligible] = 0
    return sp.csr_matrix(basis @ np.diag(inverse_spectrum) @ basis.T)




def contract_variation_edges(adjacency_normalized, r, method_cost, A_cost, n_e_percent=None, n_e = None, print_tqdm = False):
    """Select disjoint edges to contract, cheapest first.

    Every edge of the upper triangle of the adjacency matrix is scored with
    ``method_cost``. Edges are then taken by increasing cost, skipping those
    that share a node with an already selected edge, until one of the budgets
    is exhausted.

    Args:
        adjacency_normalized (sparse csr matrix): adjacency matrix of the current graph
        r (float): reduction ratio; at most ``int(N * r)`` nodes are removed
        method_cost (string): "Spectral" or "Spectral_np"
        A_cost (array): cost matrix, one row per node
        n_e_percent (float, optional): budget of contractions as a fraction of
            the number of nodes; only used when ``n_e`` is None (at least one
            contraction is then allowed)
        n_e (int, optional): maximum number of edges to contract
        print_tqdm (bool, optional): show a progress bar while scoring the edges
    Returns:
        list of [int, int]: the selected edges (pairs of nodes merged together)
        float: total cost of the coarsening (twice the cost of each selected edge)
    Raises:
        ValueError: if ``n_e`` and ``n_e_percent`` are both None, or if
            ``method_cost`` is not implemented
    """
    N = adjacency_normalized.shape[0]
    if n_e is None:
        if n_e_percent is None:
            raise ValueError("both n_e and n_e_percent is None")
        n_e = int(n_e_percent * N)  # adaptive budget
        if n_e < 1:
            print("n_e is 0 with percentage we add 1", flush=True)
            n_e = 1
    degree_normalized = np.array(adjacency_normalized.sum(axis=1), dtype=np.float64 ).squeeze()

    def cost_function (edge):
        if method_cost == "Spectral":
            cost = subgraph_cost_spectral(edge,adjacency_normalized, A_cost, degree_normalized)
        elif method_cost == "Spectral_np":
            W_restrict = adjacency_normalized[edge, :][:, edge]
            W_restrict_array = W_restrict.toarray().astype(np.float64)
            W_restrict_array = np.ascontiguousarray(W_restrict_array)
            edge_array = np.array(edge)
            cost = subgraph_cost_spectral_np(edge_array,W_restrict_array, A_cost, degree_normalized)
        else :
            raise ValueError("error your method_cost is not implemented")
        return edge, cost

    # Sorted by decreasing cost so that pop() returns the cheapest candidate.
    family_candidate = SortedList(key=lambda x: -x[1])
    adjacency_coo = adjacency_normalized.tocoo()
    edge_indices = range(adjacency_normalized.nnz)
    if print_tqdm:
        edge_indices = tqdm(edge_indices, desc="Computing edges costs")
    for i in edge_indices:
        index_row = adjacency_coo.row[i]
        index_col = adjacency_coo.col[i]
        # the matrix is symmetric: count each edge once (upper triangle)
        if index_row < index_col:
            family_candidate.add(cost_function([index_row, index_col]))

    marked = np.zeros(N, dtype=bool)
    epsilon_step = 0
    coarsening_list = []
    n_reduce = int(N * r)  # number of nodes that may still be removed
    edge_took = 0
    # Each contraction removes exactly one node, so nothing can be accepted
    # any more once n_reduce reaches 0: stop instead of draining the candidates.
    while (len(family_candidate) > 0) and (edge_took < n_e) and (n_reduce > 0):
        edge, cost = family_candidate.pop()
        if marked[edge].any():
            continue  # one endpoint already belongs to a selected edge
        marked[edge] = True
        coarsening_list.append(edge)
        epsilon_step += cost * 2
        edge_took += 1
        n_reduce -= 1
    return coarsening_list, epsilon_step



def contract_variation_neighborhood(adjacency_normalized, r, method_cost, A_cost, n_e_percent=None, n_e = None):
    """Select disjoint neighborhoods to contract, cheapest first.

    The candidate of node ``i`` is the set made of its neighbors followed by
    ``i`` itself. Candidates are taken by increasing cost. A candidate that
    overlaps an already selected set is shrunk to its unmarked nodes, scored
    again and put back in the queue when more than one node is left.

    Args:
        adjacency_normalized (sparse csr matrix): adjacency matrix of the
            current graph (its ``indices`` / ``indptr`` arrays are read directly)
        r (float): reduction ratio; about ``int(N * r)`` nodes are removed (a
            set may overshoot the remaining budget by up to 10 nodes)
        method_cost (string): "Spectral" or "Spectral_np"
        A_cost (array): cost matrix, one row per node
        n_e_percent (float, optional): budget of removed nodes as a fraction of
            the number of nodes; only used when ``n_e`` is None (at least one
            is then allowed)
        n_e (int, optional): maximum number of nodes to remove
    Returns:
        list of arrays of int: the selected sets of nodes merged together
        float: total cost of the coarsening (cost of each set times its size)
    Raises:
        ValueError: if ``n_e`` and ``n_e_percent`` are both None, or if
            ``method_cost`` is not implemented
    """
    N = adjacency_normalized.shape[0]
    if n_e is None:
        if n_e_percent is None:
            # Exceptions take no keyword argument: passing ``flush=True`` here
            # used to raise a TypeError instead of the intended ValueError.
            raise ValueError("both n_e and n_e_percent is None")
        n_e = int(n_e_percent * N)  # adaptive budget
        if n_e < 1:
            print("n_e is 0 with percentage we add 1")
            n_e = 1
    degree_normalized = np.array(adjacency_normalized.sum(axis=1), dtype=np.float64 ).flatten()

    def cost_function (node_set):
        if method_cost == "Spectral":
            cost = subgraph_cost_spectral(node_set,adjacency_normalized, A_cost, degree_normalized)
        elif method_cost == "Spectral_np":
            W_restrict = adjacency_normalized[node_set, :][:, node_set]
            W_restrict_array = W_restrict.toarray().astype(np.float64)
            node_array = np.array(node_set)
            cost = subgraph_cost_spectral_np(node_array,W_restrict_array, A_cost, degree_normalized)
        else :
            raise ValueError(f"error your method_cost is not implemented {method_cost}")
        return node_set, cost

    # Sorted by decreasing cost so that pop() returns the cheapest candidate.
    family_candidate = SortedList(key=lambda x: -x[1])
    indices = adjacency_normalized.indices
    indptr = adjacency_normalized.indptr
    for i in range(N):
        neighbors_of_i = indices[indptr[i]:indptr[i+1]]
        family_candidate.add(cost_function(np.append(neighbors_of_i, i)))

    marked = np.zeros(N, dtype=bool)
    epsilon_step = 0
    coarsening_list = []
    n_reduce = int(N * r)  # number of nodes that may still be removed
    edge_took = 0
    while (len(family_candidate) > 0) and (edge_took < n_e):
        node_set, cost = family_candidate.pop()
        node_marked = marked[node_set]

        if node_marked.any():
            # Keep only the free nodes and queue the shrunk set again.
            remaining = node_set[~node_marked]
            if len(remaining) > 1:
                family_candidate.add(cost_function(remaining))
            continue

        n_gain = len(node_set) - 1
        if n_gain > n_reduce +10:
            continue  # this helps avoid over-reducing
        marked[node_set] = True
        coarsening_list.append(node_set)
        epsilon_step += cost * len(node_set)
        edge_took += n_gain
        n_reduce -= n_gain
        if n_reduce <= 0:
            break
    return coarsening_list, epsilon_step



def subgraph_cost_spectral(nodes, W, A_cost, degree_matrix):
    """Compute the cost of contracting a subgraph using the spectral method.

    With ``W_c`` the adjacency restricted to ``nodes``, ``d`` their degrees and
    ``nc`` their number, the cost is
    ``|| B^T L B ||_F / (nc - 1)`` where
    ``L = diag(2 d - W_c 1) - W_c`` and
    ``B = (I - 1 1^T / nc) A_cost[nodes, :]``.

    Args:
        nodes (list or array of int): nodes of the subgraph
        W (sparse matrix): adjacency matrix of the whole graph
        A_cost (dense array or sparse matrix): cost matrix, one row per node
        degree_matrix (np array): degree of every node of the whole graph
    Returns:
        float: cost of the subgraph (``1e7`` for a single node)
    """
    nc = len(nodes)
    if nc == 1:
        return 1e7  # large cost so that singletons are never preferred

    # The subgraph is small, so everything is computed densely. This also
    # makes the function work with a dense A_cost, for which the sparse norm
    # used previously raised a TypeError.
    W_restrict = W[nodes, :][:, nodes]
    if sp.issparse(W_restrict):
        W_restrict = W_restrict.toarray()
    W_restrict = np.asarray(W_restrict, dtype=np.float64)

    A_selected = A_cost[nodes, :]
    if sp.issparse(A_selected):
        A_selected = A_selected.toarray()
    A_selected = np.asarray(A_selected, dtype=np.float64)

    ones = np.ones(nc)
    L = np.diag(2 * np.asarray(degree_matrix)[nodes] - W_restrict @ ones) - W_restrict
    # Projection removing the mean of the selected rows.
    B = (np.eye(nc) - np.outer(ones, ones) / nc) @ A_selected

    cost = np.linalg.norm(B.T @ L @ B, ord='fro') / (nc - 1)
    return cost

@jit(nopython=True)
def frobenius_norm_manual(matrix):
    """Return the Frobenius norm of a 2-D array, accumulated entry by entry."""
    n_rows, n_cols = matrix.shape
    squared_sum = 0
    for row in range(n_rows):
        for col in range(n_cols):
            entry = matrix[row, col]
            squared_sum += entry * entry

    return np.sqrt(squared_sum)

@jit(nopython=True)
def subgraph_cost_spectral_np(nodes, W_restrict, A_cost, degree_matrix):
    """Dense, jit-compiled spectral cost of contracting a subgraph.

    With ``d`` the degrees of ``nodes`` and ``nc`` their number, the cost is
    ``|| B^T L B ||_F / (nc - 1)`` where
    ``L = diag(2 d - W_restrict 1) - W_restrict`` and
    ``B = (I - 1 1^T / nc) A_cost[nodes, :]``.

    Args:
        nodes (array of int): nodes of the subgraph
        W_restrict (2-D float array): adjacency restricted to ``nodes``
        A_cost (2-D float array): cost matrix, one row per node
        degree_matrix (1-D float array): degree of every node of the whole graph
    Returns:
        float: cost of the subgraph (``1e7`` for a single node)
    """
    nc = len(nodes)
    if nc == 1:
        return 1e7 #to avoid singleton 
    ones = np.ones(nc)
    A_cost_selected = A_cost[nodes,:]
    A_cost_selected = np.ascontiguousarray(A_cost_selected)
    
    B = (np.eye(nc) - np.outer(ones, ones) / nc) @ A_cost_selected
    # The local Laplacian includes the off-diagonal "- W_restrict" term, as in
    # the sparse implementation of the same cost; it was missing here, which
    # made the two cost methods disagree.
    L = np.diag((2 * (degree_matrix[nodes]) - (W_restrict @ ones))) - W_restrict
    cost = np.linalg.norm(np.dot(B.T, L @ B)) / (nc - 1) #frobenius norm by default
    return cost


def get_lifting_matrix(adjacency_intermediate, coarsening_list):
    """
    Build the binary lifting matrix from the list of node sets merged together.

    Each set becomes one coarse node, represented by its first element; every
    node that appears in no set stays a coarse node on its own. Coarse nodes
    are ordered by the index of their representative. The sets may have any
    size (edges or larger neighborhoods).

    Args:
        adjacency_intermediate (sparse csr matrix): adjacency matrix; only its
            number of rows ``N`` is used
        coarsening_list (list of sequences of int): node sets to merge

    Returns:
       sparse csr matrix: ``N x n`` matrix with a one at ``(i, c)`` when fine
       node ``i`` belongs to coarse node ``c``
    """
    num_nodes = adjacency_intermediate.shape[0]
    # LIL supports cheap item assignment; assigning into a CSR matrix rebuilds
    # its structure at every insertion and emits a SparseEfficiencyWarning.
    Q = sp.eye(num_nodes, format='lil')
    mask_preserve = np.ones(num_nodes, dtype=bool)
    for subgraph in coarsening_list:
        Q[subgraph[0], subgraph] = 1.0
        mask_preserve[subgraph[1:]] = False  # drop the rows of the merged nodes
    # Keeping only the representatives gives the n x N matrix; transpose it to N x n.
    Q = Q.tocsr()[mask_preserve, :]
    return Q.T.tocsr()

def coarsen_matrix_Laplacian(A_intermediate, Q_intermediate):
    """Project a matrix onto the coarse graph defined by a lifting matrix.

    Args:
        A_intermediate (csr matrix): matrix of the current (fine) graph
        Q_intermediate (csr matrix): well-partitioned lifting matrix
    Returns:
        csr matrix: ``Q^T A Q``, the coarsened matrix
    """
    rows_merged = Q_intermediate.T @ A_intermediate
    return rows_merged @ Q_intermediate