from typing import Optional

from matplotlib import pyplot as plt
import torch
from torch import Tensor
from torch.nn import Module
from torch_geometric.data import Data
from tqdm import tqdm

from constants import TQDM_OPTIONS
from data_generation import Dataset
from karger import karger_stein_repeated
from util import sum_of_edge_weights
from ._util import prepare_model


# NOTE: a lot of this code is duplicated in karger_num_runs_until_minimum_found.py


@torch.no_grad()
def measure_karger_average_cut_size(
    num_runs_per_graph: int = 1,
    dataset_name: str = "simple-graph-plus--n-100--k-2--max-cut-value-30",
    checkpoint_path: str = "../logs/important/2023-10-10T16-33-01--simple-gnn--on--"
                           "simple-graph-plus--n-100--k-2--max-cut-value-30/5-checkpoint--new-format.pt",
):
    """Compare Karger-Stein with and without a GNN on a dataset's validation graphs.

    Loads the dataset called ``dataset_name`` and the model stored at
    ``checkpoint_path``, evaluates every validation graph ``num_runs_per_graph``
    times without a model and then again with the model, and finally hands both
    error tensors to ``_plot_data`` together with the dataset's configured name.

    Args:
        num_runs_per_graph: Number of evaluation runs performed per graph.
        dataset_name: Identifier passed to ``Dataset.load``.
        checkpoint_path: Checkpoint location passed to ``prepare_model``.
    """
    dataset = Dataset.load(dataset_name)
    validation_graphs = dataset.val_graphs

    gnn = prepare_model(checkpoint_path)

    # Baseline first (no model), then the GNN-steered variant -- same order as
    # the plot's legend entries.
    baseline_errors, steered_errors = [
        _evaluate_on_dataset(validation_graphs, num_runs_per_graph, steering_model)
        for steering_model in (None, gnn)
    ]

    _plot_data(baseline_errors, steered_errors, dataset.config.name)


def _plot_data(errors_vanilla: Tensor, errors_with_gnn: Tensor, dataset_name: str):
    """Save a side-by-side histogram of both error distributions as an SVG file.

    The figure is written to ``karger-average-cut-value--on-<dataset_name>.svg``
    in the current working directory and is closed afterwards so that repeated
    calls do not accumulate open figures in pyplot.

    Args:
        errors_vanilla: Errors (found cut value - optimal cut value) of the
            runs without a model.
        errors_with_gnn: Errors of the runs that used the model.
        dataset_name: Name shown in the plot title and used in the file name.
    """
    labels = ("Vanilla Karger-Stein", "Karger-Stein + GNN")
    # Bin edges 0, 50, ..., 450.
    # NOTE(review): errors above the last edge appear not to be counted by the
    # histogram -- confirm this cap is intended for the datasets in use.
    bin_edges = range(0, 500, 50)

    fig, ax = plt.subplots()
    try:
        fig.set_size_inches(9, 4)

        _, bin_edges, _ = ax.hist(
            (errors_vanilla, errors_with_gnn),
            bins=bin_edges,
            align="left",
            label=labels,
        )

        # With align="left" the bars sit on the left bin edges, so tick exactly
        # those (every edge except the last one).
        ax.set_xticks(bin_edges[:-1])

        # Hide the top and right frame lines.
        ax.spines["top"].set_color("none")
        ax.spines["right"].set_color("none")

        ax.set_title(f'Karger-Stein Algorithm on "{dataset_name}"', pad=25)
        ax.set_xlabel("found cut value - optimal cut value")
        ax.set_ylabel("number of graphs", labelpad=10)
        ax.legend()

        fig.savefig(f"karger-average-cut-value--on-{dataset_name}.svg")
    finally:
        # Release the figure even if plotting or saving fails; pyplot keeps
        # every figure alive until it is explicitly closed.
        plt.close(fig)


@torch.no_grad()
def _evaluate_on_dataset(
    graphs: list[Data],
    num_runs_per_graph: int,
    model: Optional[Module] = None,
) -> Tensor:
    """Collect the errors of all runs on all graphs into one flat tensor.

    Args:
        graphs: Graphs to evaluate; iterated with a tqdm progress bar.
        num_runs_per_graph: Number of runs forwarded to ``_evaluate_on_graph``.
        model: Optional model forwarded to ``_evaluate_on_graph``; ``None``
            means no model is used.

    Returns:
        A 1-D tensor holding the per-run errors of every graph, in graph order.
    """
    per_graph_errors = []
    for graph in tqdm(graphs, **TQDM_OPTIONS):
        run_errors = _evaluate_on_graph(graph, num_runs_per_graph, model)
        per_graph_errors.append(run_errors)

    # Nested list (graphs x runs) -> tensor -> single dimension.
    stacked = torch.tensor(per_graph_errors)
    return stacked.flatten()


@torch.no_grad()
def _evaluate_on_graph(graph: Data, num_runs: int, model: Optional[Module] = None) -> list[float]:
    """Run Karger-Stein ``num_runs`` times on ``graph`` and report each run's error.

    The error of a run is the edge-weight sum of the cut it found minus the
    edge-weight sum of the cut stored in ``graph.y``, which serves as ground truth.

    Args:
        graph: Graph to evaluate; its ``y`` attribute holds the reference cut.
        num_runs: Number of independent runs.
        model: Optional model. When given, ``1 - sigmoid(model(graph))`` is
            passed to the algorithm as steering weights; otherwise no steering
            weights are used.

    Returns:
        One Python number per run, in run order.
    """
    reference_value = sum_of_edge_weights(graph, graph.y)

    # The model is evaluated once per graph; all runs share the same weights.
    steering_weights = None if model is None else 1 - torch.sigmoid(model(graph))

    def _single_run_error():
        # NOTE(review): the positional arguments 2 and 1 are passed through
        # unchanged -- confirm their meaning against karger_stein_repeated.
        found_cut = karger_stein_repeated(graph, 2, 1, steering_weights=steering_weights, mode="test")
        difference = sum_of_edge_weights(graph, found_cut) - reference_value
        return difference.item()

    return [_single_run_error() for _ in range(num_runs)]


# Script entry point: run the measurement with its default arguments.
if __name__ == "__main__":
    measure_karger_average_cut_size()
