import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.utils import (
    get_laplacian,
    to_scipy_sparse_matrix,
    to_undirected,
)

# Node-count threshold used by compute_posenc_stats: graphs with fewer nodes
# are eigen-decomposed densely with numpy.linalg, larger graphs use a
# scipy.sparse.linalg solver that only computes the few smallest eigenpairs.
SPARSE_THRESHOLD: int = 100


def _edge_weight_from_attr(edge_attr):
    """Derive a 1-D Laplacian edge weight from (symmetrized) edge attributes.

    Args:
        edge_attr: Edge attribute tensor of shape (E,) or (E, D), or None.

    Returns:
        A tensor of shape (E,), or None when no scalar weight can be derived
        (no attributes, or multi-dimensional features that are not one-hot),
        in which case the Laplacian is built without explicit edge weights.
    """
    if edge_attr is None:
        return None

    if edge_attr.dim() == 1 or edge_attr.size(1) == 1:
        # Scalar attributes are used directly as weights. Flatten (E, 1) to
        # (E,), and cast integer attributes (e.g. categorical bond types) to
        # float so the Laplacian degree normalization does not run on an
        # integer tensor. Floating-point attributes keep their dtype.
        weight = edge_attr.reshape(-1)
        return weight if weight.is_floating_point() else weight.float()

    row_sum = edge_attr.sum(dim=1)
    looks_one_hot = (
        edge_attr.size(1) > 0
        and bool(torch.all((row_sum == 1) | (row_sum == 0)))
        and bool(torch.all((edge_attr == 0) | (edge_attr == 1)))
    )
    if looks_one_hot:
        # One-hot categories become their integer category index as weight.
        # NOTE(review): category 0 and all-zero rows both map to weight 0,
        # which effectively removes those edges from the Laplacian - confirm
        # this is intended.
        return torch.argmax(edge_attr, dim=1).float()

    # Genuinely multi-dimensional edge features: no scalar weight.
    return None


def compute_posenc_stats(data, pe_types, is_undirected, cfg):
    """Precompute positional encodings for the given graph.

    Supported PE statistics to precompute, selected by `pe_types`:
    'SignNet': Laplacian eigen-decomposition.

    Args:
        data: PyG graph
        pe_types: Iterable of positional encoding types to precompute
            statistics for. Currently only 'SignNet' is supported.
        is_undirected: True if the graph is expected to be undirected. If
            False, the edges (and edge attributes) are symmetrized with
            `to_undirected` before building the Laplacian.
        cfg: Main configuration node. Reads `cfg.posenc_SignNet.eigen`
            (`laplacian_norm`, `max_freqs`, `eigvec_norm`).

    Returns:
        The same PyG Data object, extended in place with `eigvals_sn`,
        `eigvecs_sn` and `pos_type`.

    Raises:
        ValueError: If `pe_types` contains an unsupported PE type.
    """
    # Verify PE types.
    supported_types = {"SignNet"}
    for t in pe_types:
        if t not in supported_types:
            raise ValueError(f"Unexpected PE stats selection {t} in {pe_types}")

    # Number of nodes, including disconnected nodes. Prefer the explicit count
    # (e.g. ogbg-ppa), falling back to the node-feature matrix.
    N = getattr(data, "num_nodes", None)
    if N is None:
        N = data.x.shape[0]

    eigen_cfg = cfg.posenc_SignNet.eigen
    norm_type = eigen_cfg.laplacian_norm.lower()
    if norm_type == "none":
        norm_type = None

    # Read edge_attr defensively so objects without the attribute still work.
    edge_attr = getattr(data, "edge_attr", None)
    if is_undirected:
        undir_edge_index, undir_edge_attr = data.edge_index, edge_attr
    else:
        undir_edge_index, undir_edge_attr = to_undirected(
            data.edge_index, edge_attr=edge_attr, num_nodes=N
        )
    edge_weight = _edge_weight_from_attr(undir_edge_attr)

    L = to_scipy_sparse_matrix(
        *get_laplacian(
            undir_edge_index,
            normalization=norm_type,
            num_nodes=N,
            edge_weight=edge_weight,
        ),
        num_nodes=N,
    )

    # eigh / eigsh assume a symmetric matrix. With symmetric edges, D - A and
    # the "sym" Laplacian are symmetric, but the random-walk Laplacian
    # I - D^-1 A is not, so it needs the general eig / eigs solvers.
    symmetric = norm_type != "rw"

    max_freqs = eigen_cfg.max_freqs
    k = max_freqs + 1
    # The sparse ARPACK solvers cannot return (almost) all eigenpairs:
    # eigsh needs k < N and eigs needs k < N - 1. Use the dense path then.
    sparse_k_limit = N if symmetric else N - 1
    if N < SPARSE_THRESHOLD or k >= sparse_k_limit:
        dense_L = L.toarray()
        if symmetric:
            evals_sn, evects_sn = np.linalg.eigh(dense_L)
        else:
            evals_sn, evects_sn = np.linalg.eig(dense_L)
    else:
        from scipy.sparse.linalg import eigs, eigsh

        if symmetric:
            evals_sn, evects_sn = eigsh(
                L, k=k, which="SA", return_eigenvectors=True
            )
        else:
            evals_sn, evects_sn = eigs(
                L, k=k, which="SR", return_eigenvectors=True
            )

    data.eigvals_sn, data.eigvecs_sn = get_lap_decomp_stats(
        evals=evals_sn,
        evects=evects_sn,
        max_freqs=max_freqs,
        eigvec_norm=eigen_cfg.eigvec_norm,
    )

    # One positional-type id (1) per row of the eigenvector matrix.
    data.pos_type = torch.full((len(data.eigvecs_sn),), 1, dtype=torch.long)

    return data


def get_lap_decomp_stats(evals, evects, max_freqs, eigvec_norm="L2"):
    """Compute Laplacian eigen-decomposition-based PE stats of the given graph.

    Args:
        evals, evects: Precomputed eigen-decomposition; `evects[:, j]` is the
            eigenvector for `evals[j]`. This may be a full decomposition (one
            eigenpair per node) or a partial one with fewer eigenpairs, e.g.
            from a sparse solver.
        max_freqs: Maximum number of top smallest frequencies / eigenvecs to use
        eigvec_norm: Normalization for the eigen vectors of the Laplacian,
            forwarded to `eigvec_normalizer`.

    Returns:
        Tensor (num_nodes, max_freqs, 1) eigenvalues repeated for each node
        Tensor (num_nodes, max_freqs) of eigenvector values per node
        When fewer than `max_freqs` eigenpairs are available, the missing
        columns of both tensors are filled with NaN.
    """
    # One row per node. This must come from the eigenvectors, not from
    # len(evals): a partial decomposition returns fewer eigenvalues than nodes.
    num_nodes = evects.shape[0]

    # Keep up to the maximum desired number of frequencies (smallest first).
    idx = evals.argsort()[:max_freqs]
    num_kept = len(idx)
    evals = torch.from_numpy(np.real(evals[idx])).clamp_min(0)
    evects = torch.from_numpy(np.real(evects[:, idx])).float()

    # Normalize, then NaN-pad both tensors to exactly max_freqs columns.
    evects = eigvec_normalizer(evects, evals, normalization=eigvec_norm)
    missing = max_freqs - num_kept
    if missing > 0:
        evects = F.pad(evects, (0, missing), value=float("nan"))
        evals = F.pad(evals, (0, missing), value=float("nan"))

    # Repeat the eigenvalue row once per node -> (num_nodes, max_freqs, 1).
    EigVals = evals.unsqueeze(0).repeat(num_nodes, 1).unsqueeze(2)
    return EigVals, evects


def eigvec_normalizer(EigVecs, EigVals, normalization="L2", eps=1e-12):
    """Implement different eigenvector normalizations.

    Each column of `EigVecs` is normalized independently (all reductions run
    over dim=0).

    Args:
        EigVecs: Tensor (num_nodes, num_eigvecs); column j is eigenvector j.
        EigVals: 1-D tensor (num_eigvecs,) of matching eigenvalues. Only the
            "wavelength*" normalizations read it.
        normalization: One of "L1", "L2", "abs-max", "wavelength",
            "wavelength-asin", "wavelength-soft".
        eps: Floor applied to every denominator. Eigenvalues below `eps` are
            treated as zero.

    Returns:
        Normalized eigenvectors with the same shape and dtype as `EigVecs`.

    Raises:
        ValueError: If `normalization` is not a supported name.
    """
    # Row vector so eigenvalue j broadcasts against column j. Cast to the
    # eigenvector dtype so float64 eigenvalues cannot silently promote
    # float32 eigenvectors to float64 in the wavelength variants.
    EigVals = EigVals.unsqueeze(0).to(EigVecs.dtype)

    def eigval_factor():
        # sqrt(eigval), with 1 substituted for (near-)zero eigenvalues so the
        # trivial eigenvector(s) are not divided by zero.
        factor = torch.sqrt(EigVals)
        factor[EigVals < eps] = 1
        return factor

    if normalization == "L1":
        # L1 normalization: eigvec / sum(abs(eigvec))
        denom = EigVecs.norm(p=1, dim=0, keepdim=True)

    elif normalization == "L2":
        # L2 normalization: eigvec / sqrt(sum(eigvec^2))
        denom = EigVecs.norm(p=2, dim=0, keepdim=True)

    elif normalization == "abs-max":
        # AbsMax normalization: eigvec / max|eigvec|
        denom = torch.max(EigVecs.abs(), dim=0, keepdim=True).values

    elif normalization == "wavelength":
        # AbsMax normalization, followed by wavelength multiplication:
        # eigvec * pi / (2 * max|eigvec| * sqrt(eigval))
        abs_max = torch.max(EigVecs.abs(), dim=0, keepdim=True).values
        denom = abs_max * eigval_factor() * 2 / np.pi

    elif normalization == "wavelength-asin":
        # AbsMax normalization, followed by arcsin and wavelength multiplication:
        # arcsin(eigvec / max|eigvec|)  /  sqrt(eigval)
        abs_max = (
            torch.max(EigVecs.abs(), dim=0, keepdim=True)
            .values.clamp_min(eps)
            .expand_as(EigVecs)
        )
        EigVecs = torch.asin(EigVecs / abs_max)
        denom = eigval_factor()

    elif normalization == "wavelength-soft":
        # AbsSoftmax normalization, followed by wavelength multiplication:
        # eigvec / (softmax|eigvec| * sqrt(eigval))
        abs_vecs = EigVecs.abs()
        soft_max = (F.softmax(abs_vecs, dim=0) * abs_vecs).sum(dim=0, keepdim=True)
        denom = soft_max * eigval_factor()

    else:
        raise ValueError(f"Unsupported normalization `{normalization}`")

    denom = denom.clamp_min(eps).expand_as(EigVecs)
    return EigVecs / denom
