import omegaconf
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from torch_geometric.utils import get_laplacian, to_undirected, to_scipy_sparse_matrix
import logging

# Module-level logger for this file.
logger = logging.getLogger(__name__)

# Node-count cutoff used by compute_pos_enc_statistics when max_freqs is an int:
# graphs with fewer nodes than this are decomposed densely with numpy.linalg,
# larger graphs go to the sparse-solver branch.
SPARSE_THRESHOLD = 100


def compute_pos_enc_statistics(data, cfg, is_undirected):
    """
    Compute Laplacian eigenvalues/eigenvectors of a graph and attach them to ``data``.

    The Laplacian is built from ``data.edge_index`` with the normalization given by
    ``cfg.dataset.eigen.laplacian_norm``. With an int ``max_freqs``, graphs with fewer
    than ``SPARSE_THRESHOLD`` nodes are decomposed densely with numpy; larger graphs
    (and the ``[low, high]`` list form) use ``SLEPc_eigsolve`` from ``src.slepc_utils``.

    Args:
        data: Graph object with an ``edge_index`` attribute (PyG ``Data``-like). If
            ``cfg.dataset.eigen.make_undirected`` is set and the graph is directed,
            ``data.edge_index`` is replaced by its undirected version.
        cfg: Config whose ``dataset.eigen`` section provides ``make_undirected``,
            ``laplacian_norm``, ``max_freqs`` and ``eigvec_norm``. ``max_freqs`` is an
            int (``-1`` meaning "all nodes") or a 2-element list
            ``[num_lowest, num_highest]``.
            NOTE: when ``max_freqs == -1`` it is overwritten in ``cfg`` with this
            graph's node count.
        is_undirected: Whether ``data.edge_index`` already describes an undirected graph.

    Returns:
        ``data`` with ``eigenvals`` and ``eigenvecs`` set from ``get_lap_decomp_stats``.

    Raises:
        NotImplementedError: For unsupported solver configurations (partial spectrum of
            a large graph, or directed graphs outside the dense path).
        ValueError: If ``max_freqs`` is neither an int nor a 2-element list.
        SystemExit: If the SLEPc solver returns no result.
    """
    eigen_cfg = cfg.dataset.eigen

    # Prefer the node count stored on the graph: edge_index.max() + 1 drops isolated
    # nodes whose index exceeds every index in edge_index, which would leave the
    # eigenvectors with fewer rows than the graph has nodes.
    num_nodes = getattr(data, "num_nodes", None)
    if num_nodes is None:
        num_nodes = data.edge_index.max().item() + 1
    num_nodes = int(num_nodes)

    if eigen_cfg.make_undirected and not is_undirected:
        data.edge_index = to_undirected(data.edge_index, num_nodes=num_nodes)
    # Symmetric eigensolvers are valid only once the graph is (or was made) undirected.
    use_undirected = bool(is_undirected or eigen_cfg.make_undirected)

    laplacian_index, laplacian_weight = get_laplacian(
        data.edge_index, normalization=eigen_cfg.laplacian_norm, num_nodes=num_nodes
    )
    # Sparse (num_nodes x num_nodes) Laplacian from the weighted Laplacian edge list.
    L = to_scipy_sparse_matrix(laplacian_index, laplacian_weight, num_nodes=num_nodes)

    max_freqs = eigen_cfg.max_freqs
    if isinstance(max_freqs, int):
        if max_freqs == -1:
            # "-1" means all frequencies; the resolved value is written back to the
            # config so code reading cfg later sees the concrete number.
            cfg.dataset.eigen.max_freqs = num_nodes
            max_freqs = num_nodes

        if num_nodes < SPARSE_THRESHOLD:
            # Small graphs: full dense decomposition with numpy.
            from numpy.linalg import eig, eigh

            eig_fn = eigh if use_undirected else eig
            eigenvals, eigenvecs = eig_fn(L.toarray())
        elif max_freqs < num_nodes:
            raise NotImplementedError(
                f"Computing {max_freqs} of {num_nodes} eigenpairs is not supported for "
                f"graphs with at least {SPARSE_THRESHOLD} nodes"
            )
        elif not use_undirected:
            raise NotImplementedError("Eigendecomposition of large directed graphs is not supported")
        else:
            from src.slepc_utils import SLEPc_eigsolve

            eigenvals, eigenvecs = SLEPc_eigsolve(L, k=max_freqs, which="smallest")
            if eigenvals is None:
                # NOTE(review): SLEPc_eigsolve presumably returns None when this process
                # has no result (solver failure or a non-root MPI rank) — confirm.
                # Same quiet exit as before, without the interactive-only `exit` builtin.
                logger.warning("SLEPc eigensolver returned no result; exiting.")
                raise SystemExit()
    elif isinstance(max_freqs, (omegaconf.listconfig.ListConfig, list)):
        # Two elements: [number of lowest eigenpairs, number of highest eigenpairs].
        assert len(max_freqs) == 2
        if not use_undirected:
            raise NotImplementedError(
                "Eigendecomposition of directed graphs is not supported for [low, high] max_freqs"
            )
        from src.slepc_utils import SLEPc_eigsolve

        low_k, high_k = max_freqs
        logger.info("Computing lower %s eigenvectors", low_k)
        eigenvals_low, eigenvecs_low = SLEPc_eigsolve(L, k=low_k, which="smallest")
        logger.info("Computing higher %s eigenvectors", high_k)
        eigenvals_high, eigenvecs_high = SLEPc_eigsolve(L, k=high_k, which="largest")
        logger.info("Done")

        # Check both solves: a missing high-end result would otherwise surface as an
        # obscure failure inside np.concatenate.
        if eigenvals_low is None or eigenvals_high is None:
            logger.warning("SLEPc eigensolver returned no result; exiting.")
            raise SystemExit()

        eigenvals = np.concatenate((eigenvals_low, eigenvals_high), axis=-1)
        eigenvecs = np.concatenate((eigenvecs_low, eigenvecs_high), axis=-1)
    else:
        raise ValueError(f"Unknown config {cfg.dataset.eigen.max_freqs=}")

    data.eigenvals, data.eigenvecs = get_lap_decomp_stats(
        eigenvals, eigenvecs, max_freqs, eigen_cfg.eigvec_norm
    )

    # Sanity checks; NaN padding added by get_lap_decomp_stats (when fewer than
    # max_freqs eigenpairs exist) also trips these.
    assert np.isfinite(data.eigenvals).all()
    assert np.isfinite(data.eigenvecs).all()

    return data


def get_lap_decomp_stats(evals, evects, max_freqs, eigvec_norm="L2"):
    """Select, normalize and pad a precomputed Laplacian eigen-decomposition.

    Args:
        evals: 1-D numpy array of eigenvalues (only the real part is kept).
        evects: 2-D numpy array with one row per node (including disconnected nodes)
            and one column per eigenvalue in ``evals``.
        max_freqs: Number of smallest eigenpairs to keep, or a 2-element list
            ``[num_lowest, num_highest]`` whose sum is used.
        eigvec_norm: Normalization scheme passed to ``eigvec_normalizer``.

    Returns:
        Tensor (max_freqs,) of eigenvalues in ascending order, clamped at 0 and
            NaN-padded when fewer than ``max_freqs`` eigenpairs are available.
        Tensor (num_nodes, max_freqs) of normalized eigenvector values per node,
            NaN-padded in the same way.
    """
    if isinstance(max_freqs, (omegaconf.listconfig.ListConfig, list)):
        # [num_lowest, num_highest] -> total number of eigenpairs to keep.
        assert len(max_freqs) == 2
        max_freqs = max_freqs[0] + max_freqs[1]

    # Keep up to max_freqs eigenpairs, ordered by ascending eigenvalue.
    idx = evals.argsort()[:max_freqs]
    evals = torch.from_numpy(np.real(evals[idx])).clamp_min(0)

    evects = torch.from_numpy(np.real(evects[:, idx])).float()
    evects = eigvec_normalizer(evects, evals, normalization=eigvec_norm)

    # Pad with NaN up to max_freqs. The pad width depends on how many eigenpairs were
    # actually available rather than on the node count. This keeps the output at
    # exactly max_freqs columns when a partial decomposition returned fewer pairs.
    num_missing = max_freqs - evals.shape[0]
    if num_missing > 0:
        evects = F.pad(evects, (0, num_missing), value=float("nan"))
        evals = F.pad(evals, (0, num_missing), value=float("nan"))
    return evals, evects


def eigvec_normalizer(EigVecs, EigVals, normalization="L2", eps=1e-12):
    """
    Normalize each eigenvector (column of ``EigVecs``) with the chosen scheme.

    Args:
        EigVecs: Tensor (num_nodes, k) with one eigenvector per column.
        EigVals: Tensor (k,) of the matching eigenvalues; only used by the
            ``wavelength*`` schemes. Negative values would give NaN under sqrt.
        normalization: One of "L1", "L2", "abs-max", "wavelength",
            "wavelength-asin", "wavelength-soft".
        eps: Lower bound for every denominator; eigenvalues below it are treated as 0.

    Returns:
        Tensor with the same shape and dtype as ``EigVecs``.

    Raises:
        ValueError: If ``normalization`` is not one of the supported schemes.
    """
    # Row vector (1, k) so eigenvalue terms broadcast over the node dimension. Cast to
    # the eigenvector dtype. Otherwise float64 eigenvalues would promote only the
    # wavelength* results to float64, while the other schemes keep the input dtype.
    EigVals = EigVals.unsqueeze(0).to(EigVecs.dtype)

    if normalization == "L1":
        # L1 normalization: eigvec / sum(abs(eigvec))
        denom = EigVecs.norm(p=1, dim=0, keepdim=True)

    elif normalization == "L2":
        # L2 normalization: eigvec / sqrt(sum(eigvec^2))
        denom = EigVecs.norm(p=2, dim=0, keepdim=True)

    elif normalization == "abs-max":
        # AbsMax normalization: eigvec / max|eigvec|
        denom = torch.max(EigVecs.abs(), dim=0, keepdim=True).values

    elif normalization == "wavelength":
        # AbsMax normalization, followed by wavelength multiplication:
        # eigvec * pi / (2 * max|eigvec| * sqrt(eigval))
        denom = torch.max(EigVecs.abs(), dim=0, keepdim=True).values
        eigval_denom = torch.sqrt(EigVals)
        eigval_denom[EigVals < eps] = 1  # Avoid dividing by sqrt(0)
        denom = denom * eigval_denom * 2 / np.pi

    elif normalization == "wavelength-asin":
        # AbsMax normalization, followed by arcsin and wavelength multiplication:
        # arcsin(eigvec / max|eigvec|)  /  sqrt(eigval)
        denom_temp = torch.max(EigVecs.abs(), dim=0, keepdim=True).values.clamp_min(eps).expand_as(EigVecs)
        EigVecs = torch.asin(EigVecs / denom_temp)
        eigval_denom = torch.sqrt(EigVals)
        eigval_denom[EigVals < eps] = 1  # Avoid dividing by sqrt(0)
        denom = eigval_denom

    elif normalization == "wavelength-soft":
        # AbsSoftmax normalization, followed by wavelength multiplication:
        # eigvec / (softmax|eigvec| * sqrt(eigval))
        denom = (F.softmax(EigVecs.abs(), dim=0) * EigVecs.abs()).sum(dim=0, keepdim=True)
        eigval_denom = torch.sqrt(EigVals)
        eigval_denom[EigVals < eps] = 1  # Avoid dividing by sqrt(0)
        denom = denom * eigval_denom

    else:
        raise ValueError(f"Unsupported normalization `{normalization}`")

    denom = denom.clamp_min(eps).expand_as(EigVecs)
    EigVecs = EigVecs / denom

    return EigVecs


def pre_transform_in_memory(dataset, transform_func, show_progress=False):
    """Pre-transform already loaded PyG dataset object.

    Apply transform function to a loaded PyG dataset object so that
    the transformed result is persistent for the lifespan of the object.
    This means the result is not saved to disk, as what PyG's `pre_transform`
    would do, but also the transform is applied only once and not at each
    data access as what PyG's `transform` hook does.

    Implementation is based on torch_geometric.data.in_memory_dataset.copy

    Args:
        dataset: PyG dataset object to modify
        transform_func: transformation function to apply to each data example;
            returning None drops that example
        show_progress: show tqdm progress bar

    Returns:
        The same ``dataset`` object, modified in place (also when
        ``transform_func`` is None).
    """
    if transform_func is None:
        return dataset

    num_examples = len(dataset)
    progress = tqdm(
        range(num_examples),
        disable=not show_progress,
        mininterval=10,
        miniters=num_examples // 20,
    )
    # NOTE(review): dataset.get(i) reads raw storage positions 0..len-1; if the
    # dataset is an index-filtered view, confirm this is the intended selection.
    transformed = (transform_func(dataset.get(i)) for i in progress)
    # Drop only examples the transform rejected with None. Truthiness is not used
    # because PyG data objects define __len__ and could be falsy.
    data_list = [data for data in transformed if data is not None]

    dataset._indices = None
    dataset._data_list = data_list
    dataset.data, dataset.slices = dataset.collate(data_list)
    return dataset


def typecast_x(data, type_str):
    """Cast the node feature tensor ``data.x`` to another dtype.

    Args:
        data: Object with a tensor attribute ``x``; the attribute is replaced in place.
        type_str: ``"float"`` (applies ``Tensor.float()``) or ``"long"``
            (applies ``Tensor.long()``).

    Returns:
        The same ``data`` object.

    Raises:
        ValueError: If ``type_str`` is neither ``"float"`` nor ``"long"``.
    """
    if type_str not in ("float", "long"):
        raise ValueError(f"Unexpected type '{type_str}'.")
    # The accepted names coincide with the tensor conversion method names.
    converter = getattr(data.x, type_str)
    data.x = converter()
    return data
