from abc import ABC, abstractmethod
from typing import Optional

from torch import Tensor
from torch_geometric.data import Data


class Decoder(ABC):
    """
    Abstract base for decoders/heuristics that build a TSP tour on a graph.

    A decoder is invoked as ``decoder(graph, steering_weights)``. When steering
    weights are given, they are expected to have size ``[graph.num_edges]`` and
    are used to initialise the decoder/heuristic.

    The result is a TSP tour encoded as a tensor of size ``[graph.num_edges]``.
    A 1 at a position marks the corresponding edge as part of the tour.

    Concrete subclasses implement ``_run``. They should be decorated with
    ``@yaml_object(YAML())`` and ``@dataclass()``.
    """

    @abstractmethod
    def _run(self, graph: Data, steering_weights: Optional[Tensor]) -> Tensor:
        """
        Compute the tour edge indicator for ``graph``.

        ``steering_weights`` is ``None`` when the caller supplied none.
        Subclasses must override this hook.
        """

    def __call__(self, graph: Data, steering_weights: Optional[Tensor] = None) -> Tensor:
        """
        Run the decoder and return its result on the device of ``graph.edge_index``.

        The device is looked up only after ``_run`` has finished, so it reflects
        ``graph.edge_index`` as it stands once the decoder is done.
        """
        tour = self._run(graph, steering_weights)
        destination = graph.edge_index.device
        return tour.to(destination)
