from ._meta_graph import MetaEdge, MetaGraph, MetaNode


def compact_meta_graph(meta_graph: MetaGraph, num_meta_nodes_in_new_graph: int) -> MetaGraph:
    """
    Contracts every disjoint set of meta nodes into its root and returns the resulting, smaller meta graph.

    Edges whose endpoints belong to the same set are dropped. Edges that connect the same pair of roots are
    merged into a single edge whose `weight` and `steering_weight` are the sums of the merged edges.

    # Parameters

    - `meta_graph`: The meta graph to compact. The `counting_sort_index` of each of its nodes is overwritten
    - `num_meta_nodes_in_new_graph`: The number of root meta nodes in `meta_graph`, i.e. the number of meta
      nodes in the compacted graph

    # Returns

    A new `MetaGraph` whose nodes are the root meta nodes and whose edges connect distinct roots. If every edge
    of `meta_graph` lies inside a single set, the returned graph has no edges.
    """
    meta_nodes_merged = _merge_nodes(meta_graph.nodes, num_meta_nodes_in_new_graph)

    # Re-point every edge at the roots of its endpoints. `find_root()` is used (rather than reading
    # `disjoint_set_parent` directly) so the endpoints are guaranteed to be the roots that `_merge_nodes`
    # returned and indexed, even if a node's parent is not its root.
    # The edges are visited in reverse, and merged in reverse again below; this order is kept so the output
    # edge order and the summation order of merged weights stay the same as before.
    contracted_edges: list[MetaEdge] = []
    for meta_edge in reversed(meta_graph.edges):
        root_a = meta_edge.nodes[0].find_root()
        root_b = meta_edge.nodes[1].find_root()
        if root_a == root_b:
            # Both endpoints were merged into the same meta node: the edge disappears.
            continue
        # Order the endpoints canonically by name so parallel edges get identical `nodes` tuples.
        nodes = (root_a, root_b) if root_a.name < root_b.name else (root_b, root_a)
        contracted_edges.append(
            MetaEdge(nodes, weight=meta_edge.weight, steering_weight=meta_edge.steering_weight)
        )

    # Two stable counting sorts (secondary key first) group edges with identical endpoints together.
    sorted_edges = _sort_meta_edges(contracted_edges, num_meta_nodes_in_new_graph, meta_edge_side=1)
    sorted_edges = _sort_meta_edges(sorted_edges, num_meta_nodes_in_new_graph, meta_edge_side=0)

    # Merge each run of edges with the same endpoints by summing their weights. The edges were created above,
    # so mutating them does not modify `meta_graph`. Starting from an empty list (instead of seeding it with
    # the last edge) avoids an IndexError when no edges survive the contraction.
    merged_edges: list[MetaEdge] = []
    for meta_edge in reversed(sorted_edges):
        if merged_edges and merged_edges[-1].nodes == meta_edge.nodes:
            merged_edges[-1].weight += meta_edge.weight
            merged_edges[-1].steering_weight += meta_edge.steering_weight
        else:
            merged_edges.append(meta_edge)

    return MetaGraph(edges=merged_edges, nodes=meta_nodes_merged)


def _merge_nodes(meta_nodes: list[MetaNode], num_meta_nodes_in_new_graph: int) -> list[MetaNode]:
    """
    Collects the root meta nodes and gives every meta node the counting sort index of its root.

    Roots are numbered in reverse order of appearance: the first root in `meta_nodes` gets index
    `num_meta_nodes_in_new_graph - 1` and the last one gets index 0. The returned list is ordered by that index,
    so `result[i].counting_sort_index == i`.

    # Parameters

    - `meta_nodes`: All meta nodes of the graph. The `counting_sort_index` of each of them is overwritten
    - `num_meta_nodes_in_new_graph`: The number of roots among `meta_nodes`; checked with an assertion

    # Returns

    The root meta nodes, ordered by their new `counting_sort_index`.
    """
    meta_nodes_merged: list[MetaNode] = []
    next_index = num_meta_nodes_in_new_graph

    for meta_node in meta_nodes:
        if meta_node.is_root():
            next_index -= 1
            meta_node.counting_sort_index = next_index
            meta_nodes_merged.append(meta_node)

    assert next_index == 0, (
        f"expected {num_meta_nodes_in_new_graph} root meta nodes, "
        f"found {num_meta_nodes_in_new_graph - next_index}"
    )

    # Roots were collected with decreasing indices; reversing once orders the list by index. This gives the same
    # result as inserting each root at the front, but in O(n) instead of O(n^2).
    meta_nodes_merged.reverse()

    # This must be a separate pass: a non-root node may appear before its root in `meta_nodes`, in which case
    # the root's index has not been assigned yet during the loop above.
    for meta_node in meta_nodes:
        meta_node.counting_sort_index = meta_node.find_root().counting_sort_index

    return meta_nodes_merged


# called "CountingSort" in the original implementation by Chekuri et al.
def _sort_meta_edges(meta_edges: list[MetaEdge], num_meta_nodes: int, meta_edge_side: int) -> list[MetaEdge]:
    """
    Sorts the meta edges by the `counting_sort_index` of their first (if meta_edge_side is 0) or second
    (if meta_edge_side is 1) node.

    This is a stable counting sort running in O(len(meta_edges) + num_meta_nodes): edges with the same key keep
    their relative order. Because of that, sorting by side 1 first and then by side 0 completely sorts the meta
    edges by (first node, second node), so multiple meta edges with the same endpoints end up grouped together.

    # Parameters

    - `meta_edges`: List of meta edges to sort. The original list is not changed
    - `num_meta_nodes`: The total number of meta nodes in the meta graph. Every key must be in
      `range(num_meta_nodes)`
    - `meta_edge_side`: Whether to sort by the first (0) or second (1) node of the meta_edge. Must be 0 or 1

    # Returns

    A new list containing the same meta edges in sorted order.
    """
    # Histogram of keys: counts[i] is the number of edges whose key is i.
    counts = [0] * num_meta_nodes
    for meta_edge in meta_edges:
        counts[meta_edge.nodes[meta_edge_side].counting_sort_index] += 1

    # Exclusive prefix sum: next_slot[i] is the number of edges whose key is < i, which is the first output
    # position for edges with key i.
    next_slot = [0] * num_meta_nodes
    edges_before = 0
    for key, count in enumerate(counts):
        next_slot[key] = edges_before
        edges_before += count

    # Scatter the edges into their slots. Visiting them in input order is what makes the sort stable.
    meta_edges_sorted: list[MetaEdge] = [None] * len(meta_edges)
    for meta_edge in meta_edges:
        key = meta_edge.nodes[meta_edge_side].counting_sort_index
        meta_edges_sorted[next_slot[key]] = meta_edge
        next_slot[key] += 1

    return meta_edges_sorted
