# postponed evaluation of annotations (helps avoid circular import)
from __future__ import annotations
from typing import TYPE_CHECKING

from concorde.tsp import TSPSolver  # noqa
import numpy as np
from torch import Tensor
from torch_geometric.data import Data

from util import convert_tsp_tour_from_node_list_to_edge_index

if TYPE_CHECKING:  # avoid circular import
    from ._config import TSPSolverConfig


def concorde_solve_tsp_exact(config: TSPSolverConfig, graph: Data) -> Tensor:
    """
    Runs the exact Concorde TSP solver on the given graph and returns the optimal TSP tour.
    This is a Tensor of size `[graph.num_edges]` that indicates whether an edge is in the optimal tour
    (as produced by `convert_tsp_tour_from_node_list_to_edge_index`).

    `config` is not used by this solver. NOTE(review): presumably it is accepted so that all TSP solvers share one
    call signature; confirm.
    Columns 0 and 1 of `graph.x` are passed to Concorde as the x- and y-coordinates of the nodes.

    Raises `RuntimeError` if Concorde does not find a tour, or if the returned tour is not a permutation of all
    `graph.num_nodes` nodes.
    """
    # Choice of norm. Intuitively, the correct norm would be "EUC_2D" (2D euclidean distances between the cities),
    # but it gave suboptimal TSP tours here (compared to another exact solver). A likely reason is that TSPLIB's
    # EUC_2D rounds every distance to the nearest integer. With small coordinates (e.g. in the unit square;
    # NOTE(review): confirm the coordinate range of our datasets), almost all distances collapse to 0 or 1.
    # Joshi et al. use "GEO" in their repo (https://github.com/chaitjo/learning-tsp/blob/master/data/tsp/generate_tsp.py).
    # Using "GEO" allowed Concorde to find the optimal TSP tour in our comparison. GEO interprets the coordinates as
    # DDD.MM latitude/longitude and measures integer great-circle kilometres. That is why Concorde's own opinion of
    # the tour length is wildly off; only the tour order is used below.
    # NOTE(review): GEO distances are integer-rounded too, so in rare near-tie cases "optimal" is with respect to the
    # rounded GEO metric rather than the true euclidean one.
    solver = TSPSolver.from_data(graph.x[:, 0], graph.x[:, 1], norm="GEO")
    solution = solver.solve()

    # Explicit checks instead of `assert`: asserts are stripped when Python runs with -O, and an invalid tour must
    # never be converted silently.
    if not solution.found_tour:
        raise RuntimeError("Concorde did not find a TSP tour")

    tour = solution.tour
    # np.array_equal returns False on a length mismatch; `==` on differently-sized arrays would raise or
    # produce a scalar instead of an element-wise comparison.
    if not np.array_equal(np.sort(tour), np.arange(graph.num_nodes)):
        raise RuntimeError(
            f"Concorde returned a tour with {len(tour)} entries that is not a permutation "
            f"of the {graph.num_nodes} graph nodes"
        )

    return convert_tsp_tour_from_node_list_to_edge_index(graph, tour.tolist())
