import math
import random
from typing import Final

from ._compact_meta_graph import compact_meta_graph
from ._meta_graph import MetaGraph


# Precomputed 1 / sqrt(2). `_calculate_target_num_meta_nodes` multiplies by it to obtain n / sqrt(2) for the
# Karger-Stein contraction target of 1 + ceil(n / sqrt(2)) meta nodes.
_ONE_OVER_SQRT_2: Final[float] = 1 / math.sqrt(2)


def contract_until_num_nodes_reduced_to_fixed_fraction(meta_graph: MetaGraph, k: int) -> MetaGraph:
    """
    Contraction algorithm as described by
    Karger and Stein, "A New Approach to the Minimum Cut Problem", Journal of the ACM, 1996.

    The algorithm repeatedly samples a meta edge that has not been sampled before and contracts it (merges its two
    meta nodes via `union`).
    This is repeated until the number of meta nodes in the graph is reduced from `n` to `1 + ceil(n / sqrt(2))`.
    If `n <= 6`, this would lead to zero contractions, so it instead contracts meta edges until there are `k` meta
    nodes left. It never contracts below `k` meta nodes.

    Sampling: among the edges not sampled yet, the probability that an edge is selected is proportional to
    `meta_edge.steering_weight`. Once no unsampled edge has a positive steering weight, the unsampled edges are
    sampled uniformly at random instead.
    NOTE(review): earlier documentation described the probability as proportional to
    `meta_edge.weight * meta_edge.steering_weight`, but only `steering_weight` is read here; presumably it already
    incorporates `weight` — confirm.

    If every edge has been sampled and the target still has not been reached (e.g. the graph has more connected
    components than the target, or no edges at all), contraction stops early instead of looping forever. In that
    case the result has more meta nodes than the target.

    :param meta_graph: the graph to contract; `reset_root()` is called on every meta node before contracting.
    :param k: lower bound on the number of meta nodes left after contraction.
    :return: the result of `compact_meta_graph` applied to the contracted graph and its final meta node count.
    """
    for meta_node in meta_graph.nodes:
        meta_node.reset_root()

    num_meta_nodes = len(meta_graph.nodes)
    target_num_meta_nodes = _calculate_target_num_meta_nodes(num_meta_nodes, k)
    meta_edge_indices = list(range(len(meta_graph.edges)))

    # weights of sampled edges are set to 0 so the weighted draw never picks them again
    sample_weights = [meta_edge.steering_weight for meta_edge in meta_graph.edges]
    # a weight of 0 alone cannot tell "already sampled" from "steering weight was 0", so track sampling explicitly
    sampled = [False] * len(meta_edge_indices)
    num_unsampled = len(meta_edge_indices)
    # equivalent to `max(sample_weights) > 0`, maintained incrementally instead of an O(E) scan per iteration
    num_positive_unsampled = sum(1 for weight in sample_weights if weight > 0)

    # contract random edges until there are target_num_meta_nodes meta nodes left (or no edges remain to try)
    while num_meta_nodes > target_num_meta_nodes and num_unsampled > 0:
        if num_positive_unsampled > 0:
            meta_edge_index = random.choices(meta_edge_indices, sample_weights, k=1)[0]
        else:
            # the remaining sample weights are all 0: sample uniformly among edges not drawn yet, so already
            # contracted edges are never re-drawn (re-drawing them was a no-op and could loop forever)
            unsampled_indices = [index for index in meta_edge_indices if not sampled[index]]
            meta_edge_index = random.choices(unsampled_indices, k=1)[0]

        # contract the meta edge
        meta_edge = meta_graph.edges[meta_edge_index]
        meta_nodes_were_not_already_merged = meta_edge.nodes[0].union(meta_edge.nodes[1])

        if meta_nodes_were_not_already_merged:
            num_meta_nodes -= 1

        # don't sample that edge again; the guard keeps the counters exact even if a zero-weight edge
        # were ever returned by the weighted draw (floating-point edge case in random.choices)
        if not sampled[meta_edge_index]:
            sampled[meta_edge_index] = True
            num_unsampled -= 1
            if sample_weights[meta_edge_index] > 0:
                num_positive_unsampled -= 1
            sample_weights[meta_edge_index] = 0

    return compact_meta_graph(meta_graph, num_meta_nodes)


def _calculate_target_num_meta_nodes(num_meta_nodes: int, k: int) -> int:
    """
    Return how many meta nodes should remain once contraction stops.

    With more than 6 meta nodes this is the Karger-Stein target `1 + ceil(n / sqrt(2))`, raised to `k` if it
    would otherwise fall below `k`. With 6 or fewer meta nodes the target is simply `k`.
    """
    if num_meta_nodes <= 6:
        # special case: 1 + ceil(n / sqrt(2)) >= n for every n <= 6 (e.g. 1 + ceil(6 / sqrt(2)) == 6),
        # so using the formula here would make the algorithm recurse endlessly without merging nodes
        return k

    fraction_target = 1 + math.ceil(num_meta_nodes * _ONE_OVER_SQRT_2)
    # never contract below k meta nodes
    return max(fraction_target, k)
