
import torch
from torch_geometric.data import Data
from torch_geometric.utils import to_scipy_sparse_matrix, degree
from scipy.sparse.csgraph import shortest_path


def landmark_spd_features(data: Data, num_landmarks: int = 16) -> Data:
    """Append landmark shortest-path-distance encodings to ``data.x``.

    The ``num_landmarks`` nodes with the highest degree (counted over the
    first row of ``data.edge_index``) are chosen as landmarks. For every node
    the unweighted, undirected hop distance ``d`` to each landmark is
    computed, mapped to ``1 / (1 + d)`` and standardized per landmark column
    (zero mean, unit population std over the nodes). The resulting
    ``(num_nodes, num_landmarks)`` matrix is concatenated to ``data.x`` along
    the last dimension.

    Args:
        data: Graph to augment. It is modified in place.
        num_landmarks: Number of landmark columns to append. If the graph has
            fewer nodes than this, every node is used as a landmark and the
            remaining columns are zero-padded so the feature width stays
            constant across graphs.

    Returns:
        The same ``data`` object with the encoding appended to ``data.x``
        (or used as ``data.x`` when the graph had no node features).
    """
    num_nodes = data.num_nodes
    # torch.topk raises when k exceeds the number of elements, so clamp it.
    k = min(num_landmarks, num_nodes)

    deg = degree(data.edge_index[0], num_nodes=num_nodes)
    landmarks = torch.topk(deg, k=k, largest=True).indices

    adj_matrix = to_scipy_sparse_matrix(data.edge_index, num_nodes=num_nodes).tocsr()
    dist_matrix = shortest_path(
        adj_matrix,
        directed=False,
        unweighted=True,
        indices=landmarks.cpu().numpy(),
    )
    # SciPy returns one row per landmark; transpose to one row per node.
    dist_matrix = torch.from_numpy(dist_matrix).float().t().contiguous()

    # Unreachable pairs come back as +inf. Replace them with a finite
    # sentinel strictly larger than any real distance in the graph.
    is_inf = torch.isinf(dist_matrix)
    if is_inf.any():
        # Mask the infinite entries (not the finite ones) so the max is the
        # largest *finite* distance.
        finite_max = dist_matrix.masked_fill(is_inf, 0.0).max()
        fill_val = max(num_nodes, int(finite_max.item()) + 1)
        dist_matrix[is_inf] = float(fill_val)

    position_encoding = 1.0 / (1.0 + dist_matrix)
    mean = position_encoding.mean(dim=0, keepdim=True)
    # clamp_min keeps constant columns from dividing by zero; they become 0.
    std = position_encoding.std(dim=0, unbiased=False, keepdim=True).clamp_min(1e-12)
    position_encoding = (position_encoding - mean) / std

    if k < num_landmarks:
        # Zero is the column mean after standardization, i.e. a neutral value.
        padding = position_encoding.new_zeros(num_nodes, num_landmarks - k)
        position_encoding = torch.cat([position_encoding, padding], dim=-1)

    if data.x is None:
        # No existing node features: the encoding becomes the feature matrix.
        data.x = position_encoding.to(data.edge_index.device)
    else:
        position_encoding = position_encoding.to(data.x.device, dtype=data.x.dtype)
        data.x = torch.cat([data.x, position_encoding], dim=-1)

    return data
