from typing import Literal, Optional

import torch
from torch import Tensor
from torch_geometric.data import Data

from ._contract_karger_et_al import contract_until_num_nodes_reduced_to_fixed_fraction
from ._cut import Cut
from ._meta_graph import MetaEdgeWeight, MetaGraph, MetaNode
from ._update_best_known_cut import update_best_known_cut


def karger_stein_repeated(
    graph: Data,
    k: int,
    num_runs: int,
    steering_weights: Optional[Tensor] = None,
    mode: Optional[Literal["train", "test"]] = None,
) -> Tensor:
    """
    Run the recursive contraction procedure `num_runs` times and return the edges of the best cut that was found.

    `graph` must carry a `meta_graph` attribute (the `MetaGraph` that is contracted) and an `edge_index` tensor
    whose two rows hold the endpoint node ids of every edge.

    `k` is the number of meta nodes at which the recursive contraction stops and the remaining graph is
    evaluated as a candidate cut (presumably the number of partitions of the cut -- confirm against
    `update_best_known_cut`).

    `num_runs` is the number of independent restarts from the uncontracted meta graph. The best cut is carried
    over from one run to the next, so more runs can only keep or improve the result.

    `steering_weights` should be a tensor of size `[num_edges]` that can be used to influence the probability
    that an edge is contracted.
    A high value makes it more likely that an edge is contracted, and a low value makes it less likely.
    For example, passing `steering_weights=1 - graph.y` would lead the algorithm to only contract edges that aren't
    in the minimum cut, thus always finding the correct solution.
    All entries of `steering_weights` should be non-negative.

    `mode` should be `None` if and only if `steering_weights` is `None`. (This is not validated here; the value is
    forwarded unchanged to `contract_recursive`.)

    Returns a `float32` tensor of size `[num_edges]` on the device of `graph.edge_index`, holding `1.0` for every
    edge whose two endpoints lie in different connected components of the best cut and `0.0` otherwise.

    # Sources

    The Karger-Stein algorithm was first published in
    Karger and Stein, "An Õ(n^2) algorithm for minimum cuts",
    in Proceedings of the Twenty-Fifth Annual ACM Symposium on Theory of Computing, 1993.

    The present implementation is loosely based on the code for the long version of
    Chekuri et al., "Experimental Study of Minimum Cut Algorithms",
    in Proceedings of the Eighth Annual ACM-SIAM Symposium on Discrete Algorithms, 1997.
    The paper's code can be found here: http://www.columbia.edu/~cs2035/codes/cut-src.tar.gz
    """
    meta_graph: MetaGraph = graph.meta_graph
    meta_graph.update_steering_weights(steering_weights)

    # Sentinel "no cut yet": infinite value and no component labels.
    best_known_cut = Cut(connected_components=torch.tensor([]), value=MetaEdgeWeight(float("inf")))

    for _ in range(num_runs):
        best_known_cut = contract_recursive(meta_graph, k, best_known_cut, meta_graph.nodes, mode)

    # Reconstruct the cut edges from the per-node component labels: an edge is part of the cut
    # exactly when its two endpoints carry different labels.
    # This is done with a single vectorized gather instead of a Python loop over all edges.
    edge_index = graph.edge_index
    output_device = edge_index.device
    # `as_tensor` is a no-op for tensors; it only guards against the labels being a plain sequence.
    component_of_node = torch.as_tensor(best_known_cut.connected_components)
    # Index tensors must live on the same device as the tensor they index into.
    endpoints = edge_index.to(device=component_of_node.device, dtype=torch.long)
    is_cut_edge = component_of_node[endpoints[0]] != component_of_node[endpoints[1]]

    # TODO change dtype to bool?
    return is_cut_edge.to(dtype=torch.float32, device=output_device)


def contract_recursive(
    meta_graph: MetaGraph,
    k: int,
    best_known_cut: Cut,
    input_graph_nodes: list[MetaNode],
    mode: Optional[Literal["train", "test"]],
) -> Cut:
    """
    One level of the recursive contraction.

    If `meta_graph` already has exactly `k` nodes, it is handed to `update_best_known_cut` together with
    `best_known_cut`, `input_graph_nodes` and `mode`, and the result of that call is returned.

    Otherwise `meta_graph` is contracted once via `contract_until_num_nodes_reduced_to_fixed_fraction`, and the
    contracted graph is explored by two consecutive recursive calls. The second call starts from the best cut
    reported by the first one.

    `meta_graph` must have at least `k` nodes and at least one edge (no edges would mean the graph is
    disconnected); both conditions are asserted.

    Returns the best cut that was found.
    """
    num_meta_nodes = len(meta_graph.nodes)
    assert num_meta_nodes >= k
    assert len(meta_graph.edges) > 0  # if <= 0, then the graph is disconnected

    # Base case: nothing left to contract, evaluate this graph as a candidate cut.
    if num_meta_nodes == k:
        return update_best_known_cut(meta_graph, best_known_cut, input_graph_nodes, mode)

    # Recursive case: contract once, then branch twice on the same contracted graph.
    contracted_graph = contract_until_num_nodes_reduced_to_fixed_fraction(meta_graph, k)
    for _ in range(2):
        best_known_cut = contract_recursive(contracted_graph, k, best_known_cut, input_graph_nodes, mode)
    return best_known_cut
