from typing import Literal, Optional

import torch

from ._cut import Cut
from ._meta_graph import MetaGraph, MetaNode


def update_best_known_cut(
    meta_graph: MetaGraph,
    best_known_cut: Cut,
    input_graph_nodes: list[MetaNode],
    mode: Optional[Literal["train", "test"]],
) -> Cut:
    """
    Given the currently best known cut and a new cut (represented as a meta graph), check whether the new cut is better
    than the current best.
    Returns the better of the two cuts.
    If the new cut is better, it's converted to a `Cut` object.

    Parameters:

    - `meta_graph`: The meta graph that represents the new cut.
                    Each meta node represents a connected component resulting from the cut.
                    The nodes that were merged into a meta node are the nodes in that connected component.
                    The meta graph should therefore have exactly `k` nodes (`k` is an input to Karger-Stein).
    - `best_known_cut`: The currently best known cut.
    - `input_graph_nodes`: The nodes in the original graph that was given to the Karger-Stein algorithm.
    - `mode`: Selects which edge attribute is summed to obtain the value of the new cut.
              `"train"` sums `meta_edge.steering_weight`; `"test"` and `None` sum `meta_edge.weight`
              (`None` means that the steering weights are the same as the weights).

    Returns: The better of the two cuts, as a `Cut` object. A new cut is only preferred if its value is strictly
             smaller than `best_known_cut.value`; on ties the previously known cut is kept.
    """
    # there are two options for what to use here: (1) meta_edge.steering_weight and (2) meta_edge.weight.
    # using the steering weight (1) would mean that vanilla karger-stein is running on a modified version of the
    # graph, where the edge weights are completely replaced with the steering weights.
    # however, this has led to issues when testing models trained with direct gradient, because for some GNN outputs
    # this makes it impossible or very unlikely that the minimum cut is found.
    # for this reason, i am modifying karger-stein to instead use the original edge weight (2) when testing.
    # (i'm calling this a modification because the vanilla karger-stein algorithm uses the same weights for this
    # and for choosing which edge to contract)
    if mode == "train":
        cut_value = sum(meta_edge.steering_weight for meta_edge in meta_graph.edges)
    else:
        # mode is "test" or None (None means that the steering weights are the same as the weights)
        cut_value = sum(meta_edge.weight for meta_edge in meta_graph.edges)

    if cut_value >= best_known_cut.value:
        # not a better cut => return the same cut that we previously had
        return best_known_cut

    # each remaining MetaNode in the MetaGraph represents a connected component resulting from the cut
    components = [_find_nodes_in_connected_component(component, input_graph_nodes) for component in meta_graph.nodes]

    num_nodes = len(input_graph_nodes)
    # label of each input node: 1-based index of the component it belongs to
    components_combined = torch.zeros(num_nodes)
    # number of components each input node was found in (must be exactly 1 for the labels to be meaningful)
    membership_count = torch.zeros(num_nodes)
    for label, component in enumerate(components, start=1):
        components_combined += label * component
        membership_count += component

    assert components_combined.min() == 1
    # a node counted in several components would carry the sum of their labels, which the min-check above cannot detect
    assert bool((membership_count == 1).all()), "every input node must belong to exactly one connected component"

    return Cut(components_combined, cut_value)


def _find_nodes_in_connected_component(
    component_meta_node: MetaNode,
    input_graph_nodes: list[MetaNode]
) -> torch.Tensor:
    """
    Finds all nodes in a connected component.
    The connected component is represented by a single `MetaNode`, into which all of the nodes of that connected
    component have been merged.

    Parameters:

    - `component_meta_node`: The meta node representing the connected component.
    - `input_graph_nodes`: The nodes of the original input graph.

    Returns a Tensor of size `[len(input_graph_nodes)]`, whose entries are 1 if the corresponding node is in the
    connected component and 0 otherwise.
    """
    # two nodes are in the same component iff `find_root_safe()` yields the same root for both
    component_meta_node_root = component_meta_node.find_root_safe()

    nodes_in_component = torch.zeros(len(input_graph_nodes))

    for i, node in enumerate(input_graph_nodes):
        if node.find_root_safe() == component_meta_node_root:
            nodes_in_component[i] = 1

    return nodes_in_component
