from collections import namedtuple
from functools import lru_cache
from typing import Union

import numpy as np
import scipy.sparse as sp

ego_graph_nodes_edges = namedtuple('ego_graph', ['nodes', 'edges'])

__all__ = ['ego_graph']


def ego_graph(adj_matrix: sp.csr_matrix, targets: Union[int, list],
              hops: int = 1) -> ego_graph_nodes_edges:
    """Returns the induced subgraph of all nodes within ``hops`` of the
    center node(s) ``targets``.

    Parameters
    ----------
    adj_matrix : sp.csr_matrix
        a Scipy sparse adjacency matrix representing a graph; it is
        converted to CSR format if necessary. The edge de-duplication
        below assumes a symmetric (undirected) matrix.
    targets : Union[int, list]
        center nodes, a single node or a sequence/array of nodes.
        Repeated nodes are used only once.
    hops : int, optional
        Include all neighbors of distance <= hops from the centers.
        Must be non-negative; defaults to 1.

    Returns
    -------
    NamedTuple(nodes, edges):
        nodes: shape [N], the nodes of the subgraph in BFS discovery
            order (centers first, then by increasing distance)
        edges: shape [2, M], the edges of the subgraph, one orientation
            per undirected edge; shape [2, 0] when there are no edges

    Raises
    ------
    ValueError
        if ``hops`` is negative.

    Note
    ----
    This is a faster implementation of
    :class:`networkx.ego_graph` based on scipy sparse matrix and numba.
    numba is only required when the outermost ring (distance == hops)
    is non-empty.

    See Also
    --------
    :class:`networkx.ego_graph`
    :class:`torch_geometric.utils.k_hop_subgraph`

    """
    assert sp.issparse(adj_matrix)
    if hops < 0:
        raise ValueError(
            '`hops` must be non-negative, got {}.'.format(hops))
    adj_matrix = adj_matrix.tocsr(copy=False)

    # Normalise ``targets`` to a flat Python list. Scalars, numpy scalars
    # and 0-d arrays become one-element lists of plain (hashable) ints.
    if isinstance(targets, np.ndarray) or np.ndim(targets) == 0:
        nodes = np.asarray(targets).ravel().tolist()
    else:
        nodes = list(targets)
    # Drop repeated centers (keeping first-occurrence order) so that no node
    # is expanded twice or reported twice in the output.
    nodes = list(dict.fromkeys(nodes))

    indices = adj_matrix.indices
    indptr = adj_matrix.indptr

    # seen[v] is the BFS distance of v from the centers, or -1 if v has not
    # been reached yet.
    seen = np.full(adj_matrix.shape[0], -1.0)
    seen[nodes] = 0

    # Ordered set (dict keys) of edges found while expanding nodes at
    # distance < hops; the reverse check keeps one orientation per edge.
    edges = {}
    start = 0
    for level in range(hops):
        end = len(nodes)  # nodes appended below belong to the next level
        for head in nodes[start:end]:
            for u in indices[indptr[head]:indptr[head + 1]]:
                if seen[u] < 0:
                    nodes.append(u)
                    seen[u] = level + 1
                if (u, head) not in edges:
                    edges[(head, u)] = None
        start = end

    # Nodes at distance exactly ``hops`` were never expanded; the edges
    # among them are collected by the compiled kernel, which is only
    # built/compiled when such nodes exist.
    last_level = nodes[start:]
    if last_level:
        remaining = get_numbafn()(indices, indptr, np.array(last_level),
                                  seen, hops)
    else:
        remaining = []

    all_edges = list(edges) + list(remaining)
    if all_edges:
        edge_index = np.asarray(all_edges).T
    else:
        # Keep the documented [2, M] layout even for an edgeless subgraph.
        edge_index = np.empty((2, 0), dtype=np.int64)

    return ego_graph_nodes_edges(nodes=np.asarray(nodes), edges=edge_index)


def get_numbafn():
    from numba import njit, types
    from numba.typed import Dict

    @njit
    def _get_remaining_edges(indices: np.ndarray, indptr: np.ndarray,
                             last_level: np.ndarray, seen: np.ndarray,
                             hops: int) -> list:
        edges = []
        mapping = Dict.empty(
            key_type=types.int64,
            value_type=types.int64,
        )
        for u in last_level:
            nbrs = indices[indptr[u]:indptr[u + 1]]
            nbrs = nbrs[seen[nbrs] == hops]
            mapping[u] = 1
            for v in nbrs:
                if v not in mapping:
                    edges.append((u, v))
        return edges

    return _get_remaining_edges
