from dataclasses import dataclass
from typing import Optional

from ruamel.yaml import YAML, yaml_object
from torch import Tensor
from torch_geometric.data import Data
from torch_geometric.utils import to_dense_adj
from typing_extensions import override

from util import convert_tsp_tour_from_node_list_to_edge_index
from .._tsp_decoder import TSPDecoder
from ._beam_search_function import beamsearch_tour_nodes_shortest


@yaml_object(YAML())
@dataclass()
class BeamSearch(TSPDecoder):
    """
    TSP decoder that constructs a tour with beam search.

    The search is steered by per-edge scores (e.g. logits predicted by a model). If no scores are
    supplied, the negated edge attributes of the graph are used as scores instead.
    """

    # Forwarded to the beam search routine as its beam size.
    beam_size: int

    @override
    def _run(self, graph: Data, steering_weights: Optional[Tensor]) -> Tensor:
        """
        Decode a tour for a single graph.

        Args:
            graph: The TSP instance; its ``edge_index``, ``edge_attr`` and ``num_nodes`` are read.
            steering_weights: Per-edge scores aligned with ``graph.edge_index``, or ``None`` to
                fall back to ``-graph.edge_attr``.

        Returns:
            Whatever ``convert_tsp_tour_from_node_list_to_edge_index`` produces for the found node
            order (presumably the tour expressed as an edge index -- confirm against that helper).
        """
        # Fallback: negate the edge attributes (presumably edge lengths, so that shorter edges
        # receive larger scores -- confirm against the dataset).
        edge_scores = -graph.edge_attr if steering_weights is None else steering_weights

        # Densify both per-edge tensors over the same edge_index.
        # Note for a future batched version: to_dense_adj also accepts a `batch` argument for the
        # case of multiple graphs; here exactly one graph is processed.
        dense_edge_attr = to_dense_adj(graph.edge_index, edge_attr=graph.edge_attr)
        dense_scores = to_dense_adj(graph.edge_index, edge_attr=edge_scores)

        tour_nodes = beamsearch_tour_nodes_shortest(
            dense_scores,
            dense_edge_attr,
            self.beam_size,
            batch_size=1,
            num_nodes=graph.num_nodes,
        )

        # Remove all size-1 dimensions and hand the node order over as a plain Python list.
        node_order = tour_nodes.squeeze().tolist()
        return convert_tsp_tour_from_node_list_to_edge_index(graph, node_order)
