


from numpy.linalg import norm

import torch
from torch_geometric.utils import degree, to_dense_adj

def smoothing_metric(X, pi):
    """Return the norm of X projected onto the orthogonal complement of pi.

    With ``p = pi / ||pi||_2``, the projection is ``(I - p p^T) X``: every
    column of ``X`` has its component along ``p`` removed.  The result is the
    Frobenius norm of that residual.

    Args:
        X: 2-D tensor of shape (N, d); its first dimension must match ``pi``.
        pi: 1-D tensor of length N (any nonzero vector; it is normalised here).

    Returns:
        float: ``||X - p (X^T p)^T||_F``.
    """
    pi_norm = pi / torch.linalg.norm(pi)
    # p (X^T p)^T : the part of each column of X that lies along pi.
    residual = X - torch.outer(pi_norm, X.T @ pi_norm)
    # Stay in torch: numpy.linalg.norm would convert to an ndarray, which
    # fails for CUDA tensors and for tensors that require grad.
    return torch.linalg.norm(residual, ord='fro').item()

def compute_randomwalk(data):
    """Build a dense random-walk transition matrix and a degree-based distribution.

    Args:
        data: graph object exposing ``edge_index`` (two rows: source and
            target node indices) and ``x``; only ``x.size(0)`` is used, as
            the number of nodes N.

    Returns:
        P: (N, N) tensor with ``P[i, j] = A[i, j] / sum_k A[i, k]`` where
            ``A = to_dense_adj(edge_index)``.  Rows of nodes with no outgoing
            edges are all zeros instead of NaN.
        pi: (N,) tensor proportional to each node's in-degree
            (``degree(edge_index[1])``), normalised to sum to 1.  For an
            undirected graph (every edge stored in both directions) this is
            the stationary distribution of the walk defined by ``P``.
    """
    edge_index = data.edge_index
    num_nodes = data.x.size(0)

    deg = degree(edge_index[1], num_nodes)
    pi = deg / deg.sum()

    # Without max_num_nodes, to_dense_adj sizes the matrix from
    # max(edge_index) + 1, which is too small when the last nodes are isolated.
    adj = to_dense_adj(edge_index, max_num_nodes=num_nodes)[0]

    # Normalise each row by its own sum (the out-degree) so P is
    # row-stochastic for directed graphs too; for undirected graphs this
    # matches `deg`.  Row sums are non-negative edge counts, so clamping to 1
    # only changes empty rows, turning 0/0 into 0.
    row_sums = adj.sum(dim=1, keepdim=True)
    P = adj / row_sums.clamp(min=1)

    return P, pi