import math
from typing import Optional

from matplotlib import pyplot as plt
import torch
from torch import Tensor
from torch.nn import Module
from torch_geometric.data import Data
from tqdm import tqdm

from constants import TQDM_OPTIONS
from data_generation import Dataset
from karger import karger_stein_repeated
from util import sum_of_edge_weights
from ._util import prepare_model


# NOTE: a lot of this code is duplicated from karger_average_cut_value.py


@torch.no_grad()
def count_karger_runs_until_minimum_found(
    dataset_name: str = "simple-graph-plus--n-100--k-2--max-cut-value-30",
    checkpoint_path: str = "../logs/important/2023-10-10T16-33-01--simple-gnn--on--"
                           "simple-graph-plus--n-100--k-2--max-cut-value-30/5-checkpoint--new-format.pt",
):
    """Compare run counts of Karger-Stein with and without a model and plot them.

    Loads the dataset called ``dataset_name``, takes its validation graphs and,
    for every graph, counts the runs needed until the ground-truth cut value
    is reached -- once without a model and once with the model obtained from
    ``prepare_model(checkpoint_path)``. Both count tensors are then handed to
    ``_plot_data`` together with the name stored in the dataset's config.

    Args:
        dataset_name: name passed to ``Dataset.load``.
        checkpoint_path: path passed to ``prepare_model``.
    """
    loaded = Dataset.load(dataset_name)
    graphs = loaded.val_graphs

    # Prepare the model before any counting starts, so a bad checkpoint path
    # fails early rather than after the baseline pass.
    gnn = prepare_model(checkpoint_path)

    baseline_counts = _count_runs_on_dataset(graphs)
    steered_counts = _count_runs_on_dataset(graphs, gnn)

    _plot_data(baseline_counts, steered_counts, loaded.config.name)


def _plot_data(counts_vanilla: Tensor, counts_with_gnn: Tensor, dataset_name: str):
    """Draw a grouped histogram of both run-count tensors and save it as SVG.

    The figure is written to
    ``karger-num-runs-until-minimum-found--on-<dataset_name>.svg`` in the
    current working directory.

    Args:
        counts_vanilla: run counts plotted under the "Vanilla Karger-Stein" label.
        counts_with_gnn: run counts plotted under the "Karger-Stein + GNN" label.
            Must have the same shape as ``counts_vanilla`` (they are stacked).
        dataset_name: inserted into the plot title and the output file name.
    """
    labels = ("Vanilla Karger-Stein", "Karger-Stein + GNN")
    max_count = int(torch.stack((counts_vanilla, counts_with_gnn)).max().item())

    # One tick (and one bar group) per integer run count 1..max_count.
    ticks = list(range(1, max_count + 1))
    # A sequence passed as `bins` is interpreted as bin *edges*, and only the
    # last bin is closed on both sides. N integer values therefore need N + 1
    # edges; with just N edges the values `max_count - 1` and `max_count` would
    # be merged into a single bar (and `max_count == 1` would yield no bin).
    bin_edges = list(range(1, max_count + 2))

    fig, ax = plt.subplots()
    fig.set_size_inches(9, 4)

    # align="left" centres every bar group on its left edge, i.e. on the
    # integer run count it represents.
    ax.hist((counts_vanilla, counts_with_gnn), bins=bin_edges, align="left", label=labels)

    ax.set_xticks(ticks)
    ax.spines["top"].set_color("none")
    ax.spines["right"].set_color("none")

    ax.set_title(f'Karger-Stein Algorithm on "{dataset_name}"', pad=25)
    ax.set_xlabel("number of runs until minimum cut was found")
    ax.set_ylabel("number of graphs", labelpad=10)
    ax.legend()

    # Save this figure explicitly instead of relying on pyplot's current figure.
    fig.savefig(f"karger-num-runs-until-minimum-found--on-{dataset_name}.svg")


@torch.no_grad()
def _count_runs_on_dataset(graphs: list[Data], model: Optional[Module] = None) -> Tensor:
    """Return a tensor holding the run count of every graph, in input order.

    Each graph is passed to ``_count_runs_on_graph`` together with ``model``;
    progress is reported through ``tqdm`` configured with ``TQDM_OPTIONS``.
    """
    progress = tqdm(graphs, **TQDM_OPTIONS)
    return torch.tensor([_count_runs_on_graph(item, model) for item in progress])


@torch.no_grad()
def _count_runs_on_graph(graph: Data, model: Optional[Module] = None, *, max_runs: int = 31) -> int:
    """Count the Karger-Stein runs needed to reach the ground-truth cut value.

    The algorithm is re-run until the weight of the cut it returns matches the
    weight of the reference cut ``graph.y`` (relative tolerance ``1e-6``), or
    until ``max_runs`` runs have been executed.

    Args:
        graph: graph to cut; ``graph.y`` is used as the reference cut.
        model: optional module. When given, ``1 - sigmoid(model(graph))`` is
            passed to the algorithm as steering weights; otherwise no steering
            weights are used.
        max_runs: hard upper bound on the number of runs. The default of 31
            reproduces the previous fixed cut-off.

    Returns:
        The number of runs executed. A result equal to ``max_runs`` is
        ambiguous: the reference value was either matched on the last
        permitted run or not matched at all.
    """
    ground_truth_cut_size = sum_of_edge_weights(graph, graph.y)

    if model is not None:
        # NOTE(review): presumably the model scores how likely an edge belongs
        # to the cut, so the complement is used for steering -- confirm
        # against karger_stein_repeated.
        steering_weights = 1 - torch.sigmoid(model(graph))
    else:
        steering_weights = None

    num_runs = 0
    found_cut_size = -1  # sentinel: nothing found yet
    # The run budget is checked first, so no comparison happens once it is spent.
    while num_runs < max_runs and not math.isclose(found_cut_size, ground_truth_cut_size, rel_tol=1e-6):
        num_runs += 1
        cut = karger_stein_repeated(graph, 2, 1, steering_weights=steering_weights, mode="test")
        found_cut_size = sum_of_edge_weights(graph, cut)

    return num_runs


# Run the experiment with its default dataset and checkpoint when this module
# is executed directly.
if __name__ == "__main__":
    count_karger_runs_until_minimum_found()
