""" Functions for processing edges and edge features in a graph """

from collections import defaultdict
from typing import List
import numpy as np
import numpy.typing as npt

# Marker strings for the two edge directions. NOTE(review): they are not used in
# this part of the file; presumably they tag forward/backward edges elsewhere —
# confirm with callers.
FORWARD_TOKEN: str = "<f>"
BACKWARD_TOKEN: str = "<b>"


def reindex_edge_index(
    node_idxs: npt.NDArray[np.int_],
    edge_index: npt.NDArray[np.int_]
) -> npt.NDArray[np.int_]:
    """Replace every node id in ``edge_index`` with its position in ``node_idxs``.

    Each entry ``v`` of ``edge_index`` becomes the index ``i`` such that
    ``node_idxs[i] == v``. If an id occurs more than once in ``node_idxs``, its
    last position is used (the same rule as ``{id: i for i, id in enumerate(...)}``).

    Args:
        node_idxs: 1-D sequence of node ids that defines the new ordering.
        edge_index: array of node ids (any shape) to be reindexed.

    Returns:
        Integer array with the same shape as ``edge_index``.

    Raises:
        KeyError: if ``edge_index`` contains an id that is not in ``node_idxs``.
    """
    node_idxs = np.asarray(node_idxs)
    edge_index = np.asarray(edge_index)
    if edge_index.size == 0:
        # Nothing to look up; keep the (empty) shape.
        return np.zeros(edge_index.shape, dtype=int)
    if node_idxs.size == 0:
        raise KeyError("edge_index references nodes but node_idxs is empty")

    # Vectorised dictionary lookup: sort the ids once, then binary-search every
    # entry of edge_index. A stable sort combined with side="right" lands on the
    # last occurrence of a duplicated id, matching dict-comprehension semantics.
    order = np.argsort(node_idxs, kind="stable")
    sorted_ids = node_idxs[order]
    pos = np.searchsorted(sorted_ids, edge_index, side="right") - 1
    # pos < 0 means the id is smaller than every known id; otherwise the id is
    # known only if the element found is an exact match.
    missing = (pos < 0) | (sorted_ids[np.maximum(pos, 0)] != edge_index)
    if missing.any():
        unknown = np.unique(edge_index[missing]).tolist()
        raise KeyError(f"node ids not present in node_idxs: {unknown}")
    return order[pos]


def get_disconnected_edge_indexes(edge_index: npt.NDArray[np.int_]) -> List[npt.NDArray[np.int_]]:
    """Split an edge list into the edge sets of its connected components.

    Edges are treated as undirected when determining connectivity. Components are
    returned in the order in which their first edge appears in ``edge_index``, and
    edges keep their original relative order within each component.

    Args:
        edge_index: ``(num_edges, 2)`` array of ``(source, target)`` node indices.
            Indices must be non-negative.

    Returns:
        List of ``(k, 2)`` arrays, one per connected component. The list is empty
        if ``edge_index`` contains no edges.

    Raises:
        ValueError: if any node index is negative.
    """
    edge_index = np.asarray(edge_index)
    if edge_index.size == 0:
        return []
    if edge_index.min() < 0:
        # Negative ids would silently wrap around when indexing node_colors.
        raise ValueError("edge_index must contain only non-negative node indices")

    src, dst = edge_index[:, 0], edge_index[:, 1]
    # Label propagation: every node starts with its own index as its colour, and
    # the larger colour is pushed across each edge in both directions until all
    # edges join equally coloured nodes. At that fixed point every node carries
    # the largest node index of its component, which is unique per component.
    node_colors = np.arange(edge_index.max() + 1)
    while not np.array_equal(node_colors[src], node_colors[dst]):
        np.maximum.at(node_colors, dst, node_colors[src])
        np.maximum.at(node_colors, src, node_colors[dst])

    # Group edge positions by component colour. Dicts preserve insertion order,
    # so components come out in order of first appearance.
    groups = defaultdict(list)
    for position, color in enumerate(node_colors[src].tolist()):
        groups[color].append(position)
    return [edge_index[positions] for positions in groups.values()]
