from dataclasses import dataclass
from typing import Optional

from ruamel.yaml import YAML, yaml_object
from torch import Tensor
from torch_geometric.data import Data
from typing_extensions import override

from christofides_steered import christofides_steered
from ._tsp_decoder import TSPDecoder


@yaml_object(YAML())
@dataclass()
class Christofides(TSPDecoder):
    """
    TSP decoder that steers the Christofides algorithm through re-weighted edges.

    Before the Christofides algorithm is run, every edge weight of the graph is scaled by
    `1 - torch.sigmoid(steering_weights)`. The steering weights are expected to have size `[graph.num_edges]`.

    The result is a Tensor of size `[graph.num_edges]` that marks, per edge, whether the edge belongs to the
    TSP tour.
    """

    @override
    def _run(self, graph: Data, steering_weights: Optional[Tensor]) -> Tensor:
        """
        Delegate decoding to `christofides_steered` and hand back its result unchanged.

        :param graph: the graph to decode a tour for.
        :param steering_weights: per-edge steering weights, forwarded as-is.
        :return: whatever `christofides_steered` returns for these inputs.
        """
        # TODO move christofides code here, maybe inline the function
        # NOTE(review): the signature permits `steering_weights=None`; how the helper treats that case is not
        # visible from this file -- confirm before relying on it.
        # The helper is called positionally: steering weights first, graph second.
        tour_edge_mask = christofides_steered(steering_weights, graph)
        return tour_edge_mask
