from typing import Optional

import torch
from torch import Tensor
from torch_geometric.data import Data
from tqdm import tqdm

from constants import TQDM_OPTIONS
from data_generation import Dataset
from decoders import *
from util import sum_of_edge_weights
from ._util import optimality_gap, prepare_model


@torch.no_grad()
def measure_average_optimality_gap(
    checkpoint_path: Optional[str] = None,
    decoder: Optional[TSPDecoder] = None,
    dataset_name: str = "tsp--n-20",
) -> float:
    """
    Runs the GNN -> decoder pipeline on the given dataset's validation set and measures the optimality gap.

    Parameters:
    - `checkpoint_path`: The path to the model to evaluate.
                         If `None` is given, the decoder (e.g. Christofides) is used without being steered by a model.
    - `decoder`: The non-learned decoding algorithm to use to translate the model's output into a valid solution.
                 If `None` (the default), a fresh `RepeatedParallel(Insertion(mode="random"), num_runs=20)`
                 is created for this call.
                 If the decoder has a `close_pool` method, it is called when evaluation finishes (also on error).
    - `dataset_name`: The name of the dataset to evaluate the model on.
                      The dataset must be in the `data/` directory, in a `.pt` file of the same name.
                      Only the validation set is used.

    Returns:
    - The average optimality gap (in %) over the validation set, as a Python float (it is also printed).

    Raises:
    - `ValueError`: If the validation set contains no graphs.
    """
    if decoder is None:
        # Built per call rather than as a default argument: a default would be created once at import
        # time and shared by every call, yet its pool is closed at the end of each call.
        decoder = RepeatedParallel(Insertion(mode="random"), num_runs=20)

    print("Dataset:", dataset_name)
    print("Checkpoint:", checkpoint_path)
    print("Decoder:", decoder)

    try:
        dataset = Dataset.load(dataset_name)
        val_graphs = dataset.val_graphs

        if checkpoint_path is not None:
            model = prepare_model(checkpoint_path)
        else:
            # Without a checkpoint the decoder receives no steering weights.
            def model(_: Data) -> Optional[Tensor]:
                return None

        optimality_gaps = []

        for graph in tqdm(val_graphs, **TQDM_OPTIONS):
            steering_weights = model(graph)

            predicted_tour = decoder(graph, steering_weights)

            # `graph.y` holds the reference tour used as the optimum.
            optimal_tour_length = sum_of_edge_weights(graph, graph.y)
            predicted_tour_length = sum_of_edge_weights(graph, predicted_tour)

            optimality_gaps.append(optimality_gap(optimal_tour_length, predicted_tour_length))

        if not optimality_gaps:
            raise ValueError(f"Dataset {dataset_name!r} has an empty validation set; nothing to evaluate.")

        # float() accepts both single-element tensors and plain Python numbers.
        average_optimality_gap = float(sum(optimality_gaps) / len(optimality_gaps))
        print("Average optimality gap (%):", average_optimality_gap)

        return average_optimality_gap
    finally:
        # Release worker pools even if loading, inference, or decoding raised.
        if hasattr(decoder, "close_pool"):
            decoder.close_pool()


if __name__ == "__main__":
    # Script entry point: evaluate with no model checkpoint, using the function's
    # default decoder and default dataset.
    measure_average_optimality_gap(checkpoint_path=None)
