# postponed evaluation of annotations (helps avoid circular import)
from __future__ import annotations
from typing import TYPE_CHECKING

import python_tsp.distances
import python_tsp.exact
from torch import Tensor
from torch_geometric.data import Data

from util import convert_tsp_tour_from_node_list_to_edge_index

if TYPE_CHECKING:  # avoid circular import
    from ._config import TSPSolverConfig


def solve_tsp_exact(config: TSPSolverConfig, graph: Data) -> Tensor:
    """
    Runs an exact TSP solver on the given graph and returns the optimal TSP tour.
    This is a Tensor of size `[graph.num_edges]` that indicates whether an edge is in the optimal tour.

    The rows of `graph.x` are treated as node positions: a Euclidean distance matrix is built from them
    (every feature column counts as a coordinate — assumes `graph.x` holds only coordinates; verify with
    callers) and solved with python_tsp's dynamic-programming solver. The node-list tour it returns is
    converted to the per-edge form by `convert_tsp_tour_from_node_list_to_edge_index`.

    `config` is not read by this solver; it is kept in the signature (presumably so all TSP solvers can be
    called the same way — confirm against the dispatcher in `_config`).
    """
    coordinates = graph.x
    if isinstance(coordinates, Tensor):
        # python_tsp operates on NumPy arrays. Handing it the tensor directly breaks for CUDA tensors
        # (NumPy cannot convert them), breaks for tensors requiring grad, and makes every scalar lookup
        # in the Python-level DP solver a torch op. float64 keeps tour-length sums precise when
        # comparing near-equal tours.
        coordinates = coordinates.detach().cpu().double().numpy()
    distance_matrix = python_tsp.distances.euclidean_distance_matrix(coordinates)
    # this library has another exact solver that uses branch & bound, but that takes extremely long on some graphs
    optimal_tsp_tour, _tour_length = python_tsp.exact.solve_tsp_dynamic_programming(distance_matrix)
    return convert_tsp_tour_from_node_list_to_edge_index(graph, optimal_tsp_tour)
