from dataclasses import dataclass
import multiprocessing
import multiprocessing.pool
from typing import Optional

from ruamel.yaml import YAML, yaml_object
import torch
from torch import Tensor
from torch_geometric.data import Data
from typing_extensions import override

from util import sum_of_edge_weights
from ._decoder import Decoder
from ._minimum_k_cut import MinimumKCutDecoder
from ._tsp import TSPDecoder


@yaml_object(YAML())
@dataclass()
class RepeatedParallel(MinimumKCutDecoder, TSPDecoder):
    """
    Runs the given decoder `num_runs` times in parallel, then returns the best result.
    Should be used in conjunction with a probabilistic decoder.

    Note: Result quality is evaluated using `sum_of_edge_weights()`, where smaller is better.

    The runs are executed in a lazily created multiprocessing pool that uses the "spawn" start method,
    so the decoder, the graph and the steering weights are pickled and sent to the worker processes.

    When done using this decoder, `close_pool()` should be used to avoid errors when the program exits.
    """

    decoder: Decoder  # Decoder that is executed once per run.
    num_runs: int  # Number of runs per call; must be at least 1.

    def __getstate__(self) -> dict[str, object]:
        """
        Returns the instance state used when pickling or copying this decoder.

        The internal pool (if one exists) is excluded because pool objects cannot be pickled.
        A copy restored from this state creates its own pool lazily on first use.
        """
        state = self.__dict__.copy()
        state.pop("_pool", None)
        return state

    def close_pool(self) -> None:
        """
        Closes and deletes the internal multiprocessing pool, waiting for its workers to exit.
        If the decoder is used again after calling this method, a new pool is automatically created.
        Does nothing if no pool currently exists.
        """
        pool = getattr(self, "_pool", None)
        if pool is None:
            return
        # Drop the reference before shutting down, so that a failure in close()/join()
        # never leaves an already-closed pool behind for `_run` to use.
        del self._pool
        pool.close()
        pool.join()

    @override
    def _run(self, graph: Data, steering_weights: Optional[Tensor]) -> Tensor:
        """
        Decodes `graph` `num_runs` times in the worker pool and returns the result whose
        `sum_of_edge_weights()` penalty is the smallest.

        Raises ValueError if `num_runs` is smaller than 1.
        """
        # Validate before creating the pool; without this check an empty run list
        # only fails later with an unhelpful unpacking error.
        if self.num_runs < 1:
            raise ValueError(f"num_runs must be at least 1, got {self.num_runs}")

        pool = self._get_multiprocessing_pool()
        # Every run receives the same arguments; each task is pickled separately for its worker.
        inputs = [(self.decoder, graph, steering_weights)] * self.num_runs
        individual_run_outputs = pool.starmap(_single_run, inputs)

        results, penalties = zip(*individual_run_outputs)
        best_index = int(torch.stack(penalties).argmin())
        return results[best_index]

    def _get_multiprocessing_pool(self) -> multiprocessing.pool.Pool:
        """
        Returns the internal multiprocessing pool, creating it first if it doesn't exist yet.
        The pool uses the "spawn" start method, so workers start from a fresh interpreter
        rather than a fork of the current process.
        """
        pool = getattr(self, "_pool", None)
        if pool is None:
            pool = multiprocessing.get_context("spawn").Pool()
            self._pool = pool
        return pool


def _single_run(decoder: Decoder, graph: Data, steering_weights: Optional[Tensor]) -> tuple[Tensor, Tensor]:
    """
    Performs a single decoding run (executed inside a pool worker) and scores it.

    This must remain a module-level function: the multiprocessing pool pickles the
    callable by its qualified name when dispatching tasks to worker processes.

    Returns a `(result, penalty)` pair, where the penalty is `sum_of_edge_weights(graph, result)`
    and smaller values are better.
    """
    decoded = decoder(graph, steering_weights)
    return decoded, sum_of_edge_weights(graph, decoded)
