import math
import time

import torch
from tqdm import tqdm

from data_generation._graph_generators import SimpleGraphPlusConfig
from karger import karger_stein_repeated, MetaGraph
from util import sum_of_edge_weights


def measure_karger_execution_time(
    num_nodes_per_graph: int = 100,
    num_graphs: int = 100,
    max_karger_runs_per_graph: int = 100,
):
    """
    Benchmark how many single Karger-Stein runs (and how much time) it takes to find a minimum cut.

    For each of `num_graphs` generated graphs (the generator provides a reference cut in `graph.y`),
    single Karger-Stein runs are repeated until one returns a cut whose weight matches the reference
    cut weight (relative tolerance 1e-6), or until `max_karger_runs_per_graph` runs have been made.

    Recorded per graph: total time spent on its runs, number of runs used, and whether the reference
    cut weight was actually reached (so graphs that hit the run cap are not mistaken for successes on
    the last run). The duration of every individual Karger-Stein run is recorded as well.
    All results are written with `torch.save` to a `.pt` file in the current working directory.

    num_nodes_per_graph: number of nodes of each generated graph
    num_graphs: number of graphs to benchmark
    max_karger_runs_per_graph: upper bound on Karger-Stein runs per graph; must be at least 1

    Raises ValueError if `max_karger_runs_per_graph` is smaller than 1.
    """
    if max_karger_runs_per_graph < 1:
        # with zero allowed runs there is no cut to evaluate and nothing meaningful to record
        raise ValueError(f"max_karger_runs_per_graph must be at least 1, got {max_karger_runs_per_graph}")

    graph_generator_config = create_graph_generator_config(num_nodes_per_graph)

    # we store these values in unaggregated form so that we can do some statistics later
    num_runs_until_minimum_found: list[int] = []
    minimum_found: list[bool] = []
    time_per_graph: list[float] = []
    time_per_karger_run: list[float] = []

    for _ in tqdm(range(num_graphs)):
        graph = graph_generator_config.generate_graph()
        graph.meta_graph = MetaGraph.from_pyg(graph)
        ground_truth_cut_size = sum_of_edge_weights(graph, graph.y)

        # perf_counter is monotonic and high-resolution, unlike time.time()
        start_graph = time.perf_counter()

        runs_used = 0
        found = False
        while runs_used < max_karger_runs_per_graph:
            runs_used += 1

            start_karger_run = time.perf_counter()
            cut = karger_stein_repeated(graph, 2, num_runs=1)
            time_per_karger_run.append(time.perf_counter() - start_karger_run)

            found_cut_size = sum_of_edge_weights(graph, cut)
            if math.isclose(found_cut_size, ground_truth_cut_size, rel_tol=1e-6):
                found = True
                break

        time_per_graph.append(time.perf_counter() - start_graph)
        num_runs_until_minimum_found.append(runs_used)
        # False means the run cap was hit without reaching the reference cut weight
        minimum_found.append(found)

    results = {
        "config": {
            "num_nodes_per_graph": num_nodes_per_graph,
            "num_graphs": num_graphs,
            "max_karger_runs_per_graph": max_karger_runs_per_graph,
        },
        "num_runs_until_minimum_found": torch.tensor(num_runs_until_minimum_found),
        "minimum_found": torch.tensor(minimum_found, dtype=torch.bool),
        "time_per_graph": torch.tensor(time_per_graph),
        "time_per_karger_run": torch.tensor(time_per_karger_run),
    }
    results_out_file_name = f"karger-execution-times--simple-graph-plus--n-{num_nodes_per_graph}.pt"
    torch.save(results, results_out_file_name)
    print("Results stored at", results_out_file_name)


def create_graph_generator_config(num_nodes: int):
    """
    Build the SimpleGraphPlus generator configuration used by the benchmark.

    The configuration uses `num_nodes` nodes and 2 clusters. Both the minimum and the maximum number of
    edges between clusters are set to `num_nodes // 3`, which presumably pins the inter-cluster edge
    count to exactly that value (TODO confirm against SimpleGraphPlusConfig).
    """
    inter_cluster_edges = num_nodes // 3
    return SimpleGraphPlusConfig(
        num_nodes=num_nodes,
        num_clusters=2,
        min_edges_between_clusters=inter_cluster_edges,
        max_edges_between_clusters=inter_cluster_edges,
    )


if __name__ == "__main__":
    # run the benchmark with its default configuration when executed as a script
    measure_karger_execution_time()
