from dataclasses import dataclass
from typing import Optional

from ruamel.yaml import YAML, yaml_object
from torch import Tensor
from torch_geometric.data import Data
from typing_extensions import override

from util import sum_of_edge_weights
from ._decoder import Decoder
from ._minimum_k_cut import MinimumKCutDecoder
from ._tsp import TSPDecoder


@yaml_object(YAML())
@dataclass()
class Repeated(MinimumKCutDecoder, TSPDecoder):
    """
    Runs the given decoder `num_runs` times, then returns the best result.
    Should be used in conjunction with a probabilistic decoder.

    Note: Result quality is evaluated using `sum_of_edge_weights()`, where smaller is better.
    Ties are resolved in favour of the earliest run with the lowest penalty.
    """

    # Decoder invoked once per run as `decoder(graph, steering_weights)`.
    decoder: Decoder
    # How many times to invoke `decoder`; must be at least 1.
    num_runs: int

    @override
    def _run(self, graph: Data, steering_weights: Optional[Tensor]) -> Tensor:
        """
        Call the wrapped decoder `num_runs` times and return the lowest-penalty result.

        :param graph: graph passed unchanged to every decoder run and to the scoring function.
        :param steering_weights: optional weights passed unchanged to every decoder run.
        :return: the result with the smallest `sum_of_edge_weights()`. If no run scores
            below +inf (e.g. every penalty is inf or NaN), the last run's result is returned
            rather than None.
        :raises ValueError: if `num_runs` is less than 1, since no result could be produced.
        """
        # Validated here rather than in __post_init__: instances loaded via ruamel.yaml
        # may be built without running the dataclass __init__.
        if self.num_runs < 1:
            raise ValueError(f"num_runs must be at least 1, got {self.num_runs}")

        best_result: Optional[Tensor] = None
        best_penalty = float("inf")
        last_result: Optional[Tensor] = None

        for _ in range(self.num_runs):
            last_result = self.decoder(graph, steering_weights)
            penalty = sum_of_edge_weights(graph, last_result)

            # Strict `<` keeps the earliest run among equally good ones.
            if penalty < best_penalty:
                best_result = last_result
                best_penalty = penalty

        # A run whose penalty is inf or NaN never passes the strict comparison above, so
        # best_result can still be None here; return the final run instead of None.
        return best_result if best_result is not None else last_result
