import numpy as np
import torch
from torch_geometric.typing import SparseTensor
from torch_geometric.utils import (
    get_laplacian,
    to_scipy_sparse_matrix,
)
from torch_geometric.utils.num_nodes import maybe_num_nodes
from scipy.sparse.linalg import eigs, eigsh


def get_edge_weight(data):
    """Return the per-edge weights stored on ``data``, or ``None``.

    ``data.edge_weight`` is preferred and ``data.edge_attr`` is the fallback.
    An attribute that is missing *or* set to ``None`` counts as absent. This
    matters for containers that always expose these names and return ``None``
    when unset: a plain ``hasattr`` check would stop at the empty
    ``edge_weight`` and never consult ``edge_attr``.

    NOTE(review): callers use the result as a scalar weight per edge; a
    multi-column ``edge_attr`` is returned unchanged — confirm that is intended.

    Args:
        data: Graph container; only ``edge_weight`` and ``edge_attr`` are read.

    Returns:
        The first non-``None`` value of ``edge_weight`` / ``edge_attr``, or
        ``None`` if neither is set.
    """
    edge_weight = getattr(data, 'edge_weight', None)
    if edge_weight is not None:
        return edge_weight
    return getattr(data, 'edge_attr', None)


def laplacian_eigenvector_pe(data, k: int, is_undirected: bool = False):
    """Compute Laplacian-eigenvector positional encodings for a graph.

    Steps:

    1. Build the symmetrically normalised Laplacian of ``data``. Edge weights
       come from :func:`get_edge_weight`.
    2. Take the ``k + 1`` eigenvectors with the smallest eigenvalues.
    3. Drop the first of them.
    4. Multiply each remaining column by a random sign.

    Args:
        data: Graph container exposing ``num_nodes`` and ``edge_index``.
        k: Number of eigenvector columns to return.
        is_undirected: If ``True``, the Laplacian is treated as symmetric and
            a Hermitian solver (``eigsh`` / ``numpy.linalg.eigh``) is used.
            Otherwise a general solver (``eigs`` / ``numpy.linalg.eig``) is
            used.

    Returns:
        A tensor of shape ``[num_nodes, k]`` built from a numpy array. When
        the graph is too small to provide ``k`` eigenvectors beyond the first,
        the missing trailing columns are zero.
    """
    num_nodes = data.num_nodes

    edge_index, edge_weight = get_laplacian(
        data.edge_index,
        edge_weight=get_edge_weight(data),
        normalization='sym',
        num_nodes=num_nodes,
    )
    L = to_scipy_sparse_matrix(edge_index, edge_weight, num_nodes)

    # ARPACK only accepts k < N - 1 (eigs) or k < N (eigsh). Beyond that,
    # scipy raises TypeError for sparse input, so small graphs are
    # decomposed densely instead.
    max_sparse_k = num_nodes - 1 if is_undirected else num_nodes - 2
    if k + 1 <= max_sparse_k:
        eig_fn = eigs if not is_undirected else eigsh
        eig_vals, eig_vecs = eig_fn(
            L,
            k=k + 1,
            which='SR' if not is_undirected else 'SA',
            return_eigenvectors=True,
        )
    else:
        dense_fn = np.linalg.eigh if is_undirected else np.linalg.eig
        eig_vals, eig_vecs = dense_fn(L.toarray())

    # Order columns by eigenvalue and keep the real part. eigs / eig may
    # return complex arrays.
    eig_vecs = np.real(eig_vecs[:, eig_vals.argsort()])
    # Skip column 0, the eigenvector of the smallest eigenvalue.
    pe = torch.from_numpy(eig_vecs[:, 1: k + 1])
    if pe.size(1) < k:
        # Fewer than k + 1 nodes: zero-pad so the output is always
        # [num_nodes, k] and matches the sign vector below.
        pe = torch.cat([pe, pe.new_zeros(pe.size(0), k - pe.size(1))], dim=1)
    # Eigenvectors are only defined up to sign, so flip each column at random.
    sign = -1 + 2 * torch.randint(0, 2, (k, ))
    pe *= sign
    return pe



def random_walk_pe(data, walk_length):
    """Compute random-walk positional encodings for a graph.

    The graph is turned into a row-normalised transition matrix ``D^{-1} A``.
    For every power of that matrix from 1 to ``walk_length``, the diagonal
    (each node's return value) is collected.

    Args:
        data: Graph container exposing ``num_nodes`` and ``edge_index``.
            Optional edge weights are read via :func:`get_edge_weight`.
        walk_length: Highest matrix power to evaluate. At least one power is
            always computed.

    Returns:
        The per-power diagonals stacked along the last dimension, one column
        per power.
    """
    n = data.num_nodes
    adj = SparseTensor.from_edge_index(
        data.edge_index,
        get_edge_weight(data),
        sparse_sizes=(n, n),
    )

    # Row-normalise. Nodes with zero out-degree would give 1/0 = inf; those
    # rows are left as all-zero instead.
    inv_deg = 1.0 / adj.sum(dim=1)
    inv_deg[inv_deg == float('inf')] = 0
    transition = adj * inv_deg.view(-1, 1)

    def _diagonal(mat):
        # Extract the self-loop (diagonal) entries of a sparse matrix as a
        # dense per-node tensor.
        r, c, v = mat.coo()
        return get_self_loop_attr((r, c), v, n)

    power = transition
    diagonals = [_diagonal(power)]
    for _ in range(walk_length - 1):
        power = power @ transition
        diagonals.append(_diagonal(power))

    return torch.stack(diagonals, dim=-1)



def get_self_loop_attr(edge_index, edge_attr=None, num_nodes=None):
    """Scatter self-loop attributes into a dense per-node tensor.

    Args:
        edge_index: Source/target indices, read as ``edge_index[0]`` and
            ``edge_index[1]`` (for example a ``(row, col)`` tuple of tensors).
        edge_attr: Optional per-edge attributes whose first dimension lines up
            with the edges. When ``None``, each self-loop gets ``1.0``.
        num_nodes: Number of nodes. Passed to ``maybe_num_nodes`` together
            with ``edge_index``.

    Returns:
        A tensor with ``num_nodes`` rows, followed by the trailing dimensions
        of the loop attributes. Each node's row holds its self-loop attribute,
        or zeros if the node has no self-loop.
    """
    src, dst = edge_index[0], edge_index[1]
    is_loop = src == dst
    loop_nodes = src[is_loop]

    loop_values = (
        torch.ones_like(loop_nodes, dtype=torch.float)
        if edge_attr is None
        else edge_attr[is_loop]
    )

    n = maybe_num_nodes(edge_index, num_nodes)
    # Allocate with the dtype/device of the loop values, then fill in the
    # rows that have a self-loop.
    result = loop_values.new_zeros((n, *loop_values.size()[1:]))
    result[loop_nodes] = loop_values
    return result