import torch
from torch_geometric import utils
import numpy as np


class LaplacianEncoding(object):
    """Attach Laplacian-eigenvector positional encodings to a graph.

    Calling an instance on a graph ``data`` object builds the graph Laplacian
    with ``torch_geometric.utils.get_laplacian``, takes the eigenpairs with
    the smallest eigenvalues (skipping the very smallest one), and stores:

    * ``data.EigVecs``: float tensor of shape ``(num_nodes, dim)`` whose
      columns are eigenvectors in ascending eigenvalue order;
    * ``data.EigVals``: float tensor of shape ``(num_nodes, dim)`` in which
      every row holds the same ``dim`` eigenvalues.

    When fewer than ``dim`` eigenpairs are available, the remaining columns
    are zero. Graphs with fewer than ``SPARSE_THRESHOLD`` nodes use a dense
    NumPy eigendecomposition; larger graphs use SciPy's ARPACK wrappers.

    Args:
        dim: Number of eigenvectors to keep. ``None`` or a value ``<= 0``
            disables the transform (``data`` is returned unchanged).
        normalization: Laplacian normalization, forwarded to
            ``get_laplacian``.
        is_undirected: If ``True`` the Laplacian is treated as symmetric and
            ``eigh``/``eigsh`` are used; otherwise ``eig``/``eigs``.
        sign_flip: If ``True``, each of the ``dim`` columns of ``EigVecs`` is
            multiplied by a random sign drawn fresh on every call.
        attr_name: Name of the attribute on ``data`` holding the edge index.
            Falls back to ``data.edge_index`` if that attribute is missing
            or ``None``.
    """

    SPARSE_THRESHOLD: int = 200

    def __init__(
        self,
        dim,
        normalization='sym',
        is_undirected=True,
        sign_flip=True,
        attr_name="edge_index",
    ):
        self.dim = dim
        self.normalization = normalization
        self.is_undirected = is_undirected
        self.sign_flip = sign_flip
        self.attr_name = attr_name

    def __call__(self, data):
        """Compute the encodings, set ``EigVecs``/``EigVals`` on ``data``, return it."""
        if self.dim is None or self.dim <= 0:
            return data
        num_nodes = data.num_nodes

        # Fall back to the standard edge_index tensor (not a string) when the
        # configured attribute is absent.
        edge_index = getattr(data, self.attr_name, None)
        if edge_index is None:
            edge_index = data.edge_index
        # Edge weights are optional; pass None when the graph has none.
        edge_weight = getattr(data, "edge_weight", None)

        edge_index, edge_weight = utils.get_laplacian(
            edge_index,
            edge_weight,
            normalization=self.normalization,
            num_nodes=num_nodes,
        )
        L = utils.to_scipy_sparse_matrix(edge_index, edge_weight, num_nodes)

        eig_vals, eig_vecs = self._eigendecompose(L, num_nodes)
        # Number of eigenpairs actually available (can be < dim on tiny graphs).
        k = min(self.dim, eig_vecs.shape[1])

        pe = torch.zeros((num_nodes, self.dim))
        pe[:, :k] = torch.from_numpy(eig_vecs[:, :k])
        if self.sign_flip:
            # Eigenvectors are only defined up to sign; randomize it per column.
            sign = -1 + 2 * torch.randint(0, 2, (self.dim, ))
            pe *= sign
        data.EigVecs = pe.float()

        eig_vals_row = torch.zeros((1, self.dim))
        eig_vals_row[0, :k] = torch.from_numpy(eig_vals[:k]).float()
        data.EigVals = eig_vals_row.repeat(num_nodes, 1)
        return data

    def _eigendecompose(self, L, num_nodes):
        """Return the eigenpairs of ``L`` in ascending order, minus the smallest.

        Args:
            L: ``num_nodes x num_nodes`` SciPy sparse Laplacian matrix.
            num_nodes: Number of rows/columns of ``L``.

        Returns:
            ``(eig_vals, eig_vecs)`` as real NumPy arrays. The eigenvalues are
            sorted ascending and the eigenvectors are the matching columns;
            the smallest eigenpair is dropped. The dense path returns
            ``num_nodes - 1`` pairs; the sparse path returns ``self.dim``.
        """
        # ARPACK needs k < n - 1 (eigs) / k < n (eigsh) with k = dim + 1. When
        # that cannot hold, (nearly) all eigenpairs are needed, so go dense.
        use_dense = (
            num_nodes < self.SPARSE_THRESHOLD or self.dim + 2 >= num_nodes
        )
        if use_dense:
            from numpy.linalg import eig, eigh
            eig_fn = eigh if self.is_undirected else eig
            eig_vals, eig_vecs = eig_fn(L.toarray())
        else:
            from scipy.sparse.linalg import eigs, eigsh
            eig_fn = eigsh if self.is_undirected else eigs
            # Positional encodings use the *smallest* eigenvalues ('SA' for
            # symmetric, 'SR' otherwise). Request one extra pair because the
            # smallest one is discarded below.
            eig_vals, eig_vecs = eig_fn(  # type: ignore
                L,
                k=self.dim + 1,
                which='SA' if self.is_undirected else 'SR',
                return_eigenvectors=True,
            )

        order = eig_vals.argsort()
        eig_vals = np.real(eig_vals[order])[1:]
        eig_vecs = np.real(eig_vecs[:, order])[:, 1:]
        return eig_vals, eig_vecs
