from dataclasses import dataclass
from typing import Optional

from ruamel.yaml import YAML, yaml_object
import torch
from torch import Tensor
from torch.distributions import Gumbel
from torch_geometric.data import Data
from typing_extensions import override

from ._decoder import Decoder
from ._minimum_k_cut import MinimumKCutDecoder
from ._tsp import TSPDecoder


@yaml_object(YAML())
@dataclass()
class Noisy(MinimumKCutDecoder, TSPDecoder):
    """
    Adds noise to the steering weights, then runs the given decoder on the modified weights.
    This is useful for turning a deterministic decoder into a probabilistic one.

    The noise is drawn from a Gumbel distribution with location 0.
    Its scale is the standard deviation of the steering weights, multiplied by `6 / pi**2` and by `scale_factor`,
    so that the noise is in a similar order of magnitude as the steering weights.
    If the steering weights provide no usable spread (fewer than two weights, all weights equal, or a non-finite
    variance), `scale_factor` is used directly as the scale instead.

    If no steering weights are given, the noise is instead drawn from a Gumbel distribution with location 1 and scale
    `scale_factor`.
    Then the noise is used directly as steering weights for the inner decoder.

    If this is used in conjunction with `Repeated`, e.g.
    `Repeated(Noisy(Christofides(), scale_factor=0.022), num_runs=50)`,
    then the scale factor should be adjusted based on `num_runs`.
    If the number of decoder runs increases, then the scale factor for the noise can also be increased for better
    results.
    """

    # Inner decoder that is run on the noisy steering weights.
    decoder: Decoder
    # Multiplier for the noise scale. Presumably must be positive: Gumbel's `scale` argument is constrained to be > 0.
    scale_factor: float

    @override
    def _run(self, graph: Data, steering_weights: Optional[Tensor]) -> Tensor:
        """
        Run the inner decoder on Gumbel-perturbed steering weights.

        graph: the graph handed unchanged to the inner decoder.
        steering_weights: optional steering weights. If None, Gumbel(1, scale_factor) noise with one sample per
            edge (`graph.num_edges`) is used as the steering weights instead.
        Returns whatever the inner decoder returns.
        """
        if steering_weights is None:
            # assume steering weights are all 1. let scale_factor control the scale directly
            noise = Gumbel(1, self.scale_factor).sample([graph.num_edges]).to(graph.edge_index.device)
            return self.decoder(graph, noise)

        # The unbiased variance is undefined (NaN) for fewer than two elements, so only compute it when it can exist.
        std = steering_weights.var().sqrt() if steering_weights.numel() > 1 else None
        if std is not None and torch.isfinite(std) and std > 0:
            # NOTE(review): Gumbel(0, b) has standard deviation b * pi / sqrt(6), so matching the weights' standard
            # deviation exactly would use sqrt(6) / pi (~0.78) rather than 6 / pi**2 (~0.61). The constant is kept
            # unchanged so that already tuned `scale_factor` values keep their meaning.
            scale = self.scale_factor * std * 6 / (torch.pi ** 2)
        else:
            # A single weight, identical weights, or a non-finite variance leave no spread to calibrate against, and
            # Gumbel rejects a zero/NaN scale when argument validation is on (the PyTorch default). Fall back to
            # using scale_factor directly as the scale, as in the branch without steering weights.
            scale = torch.as_tensor(self.scale_factor, dtype=steering_weights.dtype, device=steering_weights.device)
        # `scale` lives on the device of steering_weights, so the sampled noise does too; no explicit move needed.
        noise = Gumbel(0, scale).sample(steering_weights.size())
        return self.decoder(graph, steering_weights + noise)
