import numpy as np
import torch
from scipy import sparse as sp



def edge_index_to_csr(edge_index: np.ndarray, num_nodes: int) -> sp.csr_matrix:
    """Build a binary CSR adjacency matrix from an edge list.

    Args:
        edge_index: array-like whose first row holds source node ids and whose
            second row holds target node ids (one column per edge).
        num_nodes: number of nodes; the result has shape
            ``(num_nodes, num_nodes)``.

    Returns:
        A ``uint8`` CSR matrix ``A`` with ``A[u, v] == 1`` for every edge
        ``(u, v)`` present in ``edge_index``. Repeated edges collapse to a
        single entry of value 1.

    Raises:
        ValueError: raised by SciPy when a node id falls outside
            ``[0, num_nodes)``.
    """
    edge_index = np.asarray(edge_index)
    u = edge_index[0].astype(np.int64)
    v = edge_index[1].astype(np.int64)

    # Accumulate duplicates in a wide integer type: summing them directly in
    # uint8 would wrap around (256 parallel edges -> 0) and hide the edge.
    counts = np.ones(u.shape, dtype=np.int64)
    A = sp.csr_matrix((counts, (u, v)), shape=(num_nodes, num_nodes))
    A.sum_duplicates()

    # Only connectivity matters, so collapse the multiplicities to 0/1 before
    # narrowing back to the compact uint8 representation.
    A.data[:] = 1
    return A.astype(np.uint8)


def k_hop_pairs_batch(
    edge_index: torch.Tensor,
    num_nodes: int,
    k: int,
    undirected: bool = False,
    batch_size: int = 4096,
) -> torch.Tensor:
    """
    Compute all (src, dst) pairs such that dst is reachable from src in at
    most k hops (src itself excluded), using dense frontier blocks.

    Source nodes are processed in blocks of ``batch_size`` so that only
    ``num_nodes x batch_size`` dense work arrays exist at any time, instead of
    full ``num_nodes x num_nodes`` matrices.

    Args:
        edge_index: (2, E) integer tensor of directed edges (src row, dst row).
        num_nodes: total number of nodes.
        k: maximum hop distance; ``k <= 0`` yields no pairs.
        undirected: if True, every edge is also followed in reverse.
        batch_size: number of source nodes expanded together.

    Returns:
        torch.LongTensor of shape (2, M) on ``edge_index``'s device. Pairs are
        ordered by source node, then by target node, both ascending.

    Raises:
        ValueError: if ``batch_size`` is not positive, or (from SciPy) if a
            node id falls outside ``[0, num_nodes)``.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    device = edge_index.device
    edges = edge_index.detach().to(torch.long).cpu().numpy()
    src, dst = edges[0], edges[1]
    if undirected:
        src, dst = np.concatenate((src, dst)), np.concatenate((dst, src))

    # Transposed adjacency (row = dst, col = src): multiplying it with a
    # column-per-seed frontier pushes each seed's frontier along out-edges.
    # Duplicates are summed in int64 and then collapsed to 1, so repeated
    # edges can never wrap a narrow counter to zero.
    adj_t = sp.csr_matrix(
        (np.ones(src.shape, dtype=np.int64), (dst, src)),
        shape=(num_nodes, num_nodes),
    )
    adj_t.sum_duplicates()
    adj_t.data[:] = 1
    adj_t = adj_t.astype(np.int32)

    src_chunks = []
    dst_chunks = []

    for start in range(0, num_nodes, batch_size):
        stop = min(start + batch_size, num_nodes)
        seeds = np.arange(start, stop, dtype=np.int64)
        local = np.arange(stop - start, dtype=np.int64)

        # frontier[i, j] is True when node i was first reached from seed
        # (start + j) in the current hop; initially each seed reaches itself.
        frontier = np.zeros((num_nodes, stop - start), dtype=bool)
        frontier[seeds, local] = True
        visited = frontier.copy()

        for _ in range(k):
            # The product is taken in int32: a uint8 product would wrap to 0
            # for a node with 256 in-frontier neighbours and lose the pair.
            reached = adj_t.dot(frontier.astype(np.int32)) > 0
            frontier = reached & ~visited
            if not frontier.any():
                break
            visited |= frontier

        # Drop the seeds themselves before extracting pairs.
        visited[seeds, local] = False

        # nonzero() walks the transposed view row-major, i.e. seed by seed
        # with ascending targets, which fixes the output ordering.
        local_src, targets = np.nonzero(visited.T)
        if targets.size:
            src_chunks.append(local_src.astype(np.int64) + start)
            dst_chunks.append(targets.astype(np.int64))

    if src_chunks:
        pairs = np.vstack((np.concatenate(src_chunks), np.concatenate(dst_chunks)))
    else:
        pairs = np.empty((2, 0), dtype=np.int64)

    return torch.from_numpy(np.ascontiguousarray(pairs)).long().to(device)


def k_hop_pairs_batch_sparse(
    edge_index: torch.Tensor,
    num_nodes: int,
    k: int,
    batch_size: int = 4096,
    undirected: bool = False,
) -> torch.Tensor:
    """
    Memory-efficient computation of all (src, dst) pairs within <= k hops.
    Runs one breadth-first search per source node over the CSR adjacency and
    never allocates dense NxN arrays.

    Args:
        edge_index: (2, E) integer tensor of directed edges (src row, dst row)
        num_nodes: int, total number of nodes
        k: int, max hop distance; ``k <= 0`` yields no pairs
        batch_size: int, how many source nodes are grouped before their
            results are compacted into one array; it does not affect the
            returned pairs
        undirected: bool, whether to symmetrize the graph

    Returns:
        edge_index_output: torch.LongTensor of shape (2, M) on ``edge_index``'s
        device. Pairs are grouped by ascending source node; within a source
        the targets appear in BFS discovery order (hop by hop, neighbours
        visited in ascending id order).

    Raises:
        ValueError: if ``batch_size`` is not positive, or (from SciPy) if a
            node id falls outside ``[0, num_nodes)``.
    """
    # A non-positive step would either raise inside range() or, when
    # negative, silently produce an empty result.
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    device = edge_index.device
    edges = edge_index.detach().to(torch.long).cpu().numpy()
    heads, tails = edges[0], edges[1]
    if undirected:
        heads, tails = np.concatenate((heads, tails)), np.concatenate((tails, heads))

    # Only the sparsity structure is needed. Duplicates are summed in int64
    # (no wrap-around) and merged, leaving sorted, unique neighbours per row.
    adjacency = sp.csr_matrix(
        (np.ones(heads.shape, dtype=np.int64), (heads, tails)),
        shape=(num_nodes, num_nodes),
    )
    adjacency.sum_duplicates()
    indptr, indices = adjacency.indptr, adjacency.indices

    # One reusable mask; only the entries touched by a search are reset, so
    # the per-source cost depends on the explored neighbourhood, not on N.
    visited = np.zeros(num_nodes, dtype=bool)
    src_chunks = []
    dst_chunks = []

    for start in range(0, num_nodes, batch_size):
        stop = min(start + batch_size, num_nodes)
        batch_src = []
        batch_dst = []

        for src in range(start, stop):
            visited[src] = True
            frontier = np.array([src], dtype=indices.dtype)
            levels = []

            for _ in range(k):
                discovered = []
                for node in frontier:
                    nbrs = indices[indptr[node]:indptr[node + 1]]
                    # Neighbours within a row are unique, so marking them in
                    # bulk matches a one-by-one scan.
                    fresh = nbrs[~visited[nbrs]]
                    if fresh.size:
                        visited[fresh] = True
                        discovered.append(fresh)
                if not discovered:
                    break
                frontier = np.concatenate(discovered)
                levels.append(frontier)

            # Restore the mask for the next source.
            visited[src] = False
            if levels:
                targets = np.concatenate(levels)
                visited[targets] = False
                batch_src.append(np.full(targets.shape, src, dtype=np.int64))
                batch_dst.append(targets.astype(np.int64))

        if batch_src:
            src_chunks.append(np.concatenate(batch_src))
            dst_chunks.append(np.concatenate(batch_dst))

    if not src_chunks:
        return torch.empty((2, 0), dtype=torch.long, device=device)

    pairs = np.vstack((np.concatenate(src_chunks), np.concatenate(dst_chunks)))
    return torch.from_numpy(pairs).to(device)
