import math
from multiprocessing import Pool
import random
from typing import Final
import unittest

import torch
from torch_geometric.data import Data
from tqdm import tqdm

from constants import TQDM_OPTIONS
from data_generation import SimpleGraphPlusConfig
from karger import karger_stein_repeated, MetaGraph
from util import sum_of_edge_weights


# how many random graphs the suite checks, and the run budget handed to karger_stein_repeated() for each of them
NUM_TESTS: Final[int] = 100
NUM_KARGER_RUNS_PER_TEST: Final[int] = 20

# upper bounds on the size of a generated test graph.
# Keep MAX_NUM_NODES >= MAX_NUM_CLUSTERS * (MAX_NUM_EDGES_BETWEEN_CLUSTERS + 2), otherwise the node-count range
# sampled in _generate_test_graph() can become empty.
MAX_NUM_CLUSTERS: Final[int] = 5
MAX_NUM_EDGES_BETWEEN_CLUSTERS: Final[int] = 30
MAX_NUM_NODES: Final[int] = 350


class TestKarger(unittest.TestCase):
    def test(self):
        """
        Checks Karger's algorithm on `NUM_TESTS` independently generated random graphs.

        The individual checks run in a process pool. A failure inside a worker is re-raised here
        as soon as its result is pulled from the pool's result iterator, which fails this test.
        """
        test_indices = range(NUM_TESTS)
        with Pool() as workers:
            # imap_unordered() is lazy: nothing is computed unless its results are consumed,
            # so drain them completely, wrapped in a progress bar
            outcomes = workers.imap_unordered(_test_on_random_graph, test_indices)
            list(tqdm(outcomes, total=NUM_TESTS, **TQDM_OPTIONS))


# The function must take one parameter so that Pool.imap_unordered() can map it over range(NUM_TESTS); the value is ignored.
def _test_on_random_graph(_):
    """
    Generates a graph with a known minimum k-cut, then checks whether Karger's algorithm finds that minimum cut.

    Args:
        _: Ignored. It exists only so that the function can be mapped over ``range(NUM_TESTS)``
           with ``Pool.imap_unordered()``.

    Raises:
        AssertionError: If the cut found by ``karger_stein_repeated`` has a different total edge weight than the
            known minimum cut ``graph.y`` (relative tolerance 1e-6), or if its labels are not exactly equal to
            ``graph.y``.
    """
    graph, k = _generate_test_graph()
    cut = karger_stein_repeated(graph, k, NUM_KARGER_RUNS_PER_TEST)

    found_weight = float(sum_of_edge_weights(graph, cut))
    expected_weight = float(sum_of_edge_weights(graph, graph.y))

    # Explicit raises instead of bare `assert`: they are not stripped under `python -O`, and the message is what
    # reaches the parent process when the pool re-raises the failure.
    if not math.isclose(found_weight, expected_weight, rel_tol=1e-6):
        raise AssertionError(
            f'Karger-Stein found a cut of weight {found_weight}, but the minimum {k}-cut has weight '
            f'{expected_weight} (graph with {graph.num_nodes} nodes)'
        )

    # NOTE(review): exact label equality assumes karger_stein_repeated() numbers the parts the same way graph.y
    # does; a permuted but otherwise identical partition would fail here. Confirm against karger.py.
    if not torch.equal(cut, graph.y):
        raise AssertionError(
            f'Karger-Stein found a cut with the minimum weight {found_weight}, but its labels differ from the '
            f'known minimum {k}-cut (graph with {graph.num_nodes} nodes)'
        )


def _generate_test_graph() -> tuple[Data, int]:
    """
    Builds a random clustered graph whose minimum k-cut is known by construction.

    The number of clusters, the number of edges between clusters and the number of nodes are drawn uniformly at
    random, bounded by the module-level MAX_* limits.

    Returns:
        A pair ``(graph, k)``:
        1. ``graph``: the generated graph. The known minimum k-cut solution is stored in ``graph.y``, and
           the meta graph built by ``MetaGraph.from_pyg`` is stored in ``graph.meta_graph``.
        2. ``k``: the number of clusters.
    """
    k = random.randint(2, MAX_NUM_CLUSTERS)
    inter_cluster_edges = random.randint(2, MAX_NUM_EDGES_BETWEEN_CLUSTERS)

    # generate_graph() needs at least k * (inter_cluster_edges + 2) nodes, so sample the node count from there
    fewest_nodes = k * (inter_cluster_edges + 2)
    node_count = random.randint(fewest_nodes, MAX_NUM_NODES)

    # the same value is passed as min and max so the generated graph uses exactly inter_cluster_edges
    graph = SimpleGraphPlusConfig(
        num_nodes=node_count,
        num_clusters=k,
        min_edges_between_clusters=inter_cluster_edges,
        max_edges_between_clusters=inter_cluster_edges,
    ).generate_graph()
    graph.meta_graph = MetaGraph.from_pyg(graph)

    return graph, k


# Lets the file be run directly as a script, not only through a test runner. The guard also stops pool worker
# processes that re-import this module (e.g. under the 'spawn' start method) from starting the suite themselves.
if __name__ == '__main__':
    unittest.main()
