from dataclasses import dataclass
from typing import Optional, Literal

from ruamel.yaml import YAML, yaml_object
import torch
from torch import Tensor
from torch_geometric.data import Data
from torch_geometric.utils import to_dense_adj
from typing_extensions import override

from util import convert_tsp_tour_from_node_list_to_edge_index
from .._tsp_decoder import TSPDecoder
from ._insertion_heuristic import run_insertion


@yaml_object(YAML())
@dataclass()
class Insertion(TSPDecoder):
    """
    TSP decoder that builds a tour with an insertion heuristic on a re-weighted copy of the graph.

    When steering weights are supplied, every edge weight in `graph.edge_attr` is scaled by
    `1 - torch.sigmoid(steering_weights)` before the heuristic runs; when they are `None`, the
    edge weights are used as they are.

    `mode` is forwarded to the insertion heuristic and picks the variant to run
    (farthest insertion/nearest insertion/random insertion).
    """

    # Variant of the insertion heuristic, passed through to `run_insertion` unchanged.
    mode: Literal["farthest", "nearest", "random"]

    @override
    def _run(self, graph: Data, steering_weights: Optional[Tensor]) -> Tensor:
        """
        Run the insertion heuristic on `graph` and return the resulting tour.

        :param graph: graph whose `edge_index` and `edge_attr` describe the TSP instance.
        :param steering_weights: optional per-edge steering values; `None` leaves the edge weights untouched.
            Presumably shaped so that it multiplies element-wise with `graph.edge_attr` (TODO confirm).
        :return: the result of `convert_tsp_tour_from_node_list_to_edge_index` for the predicted tour.
        """
        if steering_weights is None:
            edge_costs = graph.edge_attr
        else:
            # TODO: the sigmoid should probably be applied much earlier, so that it is possible to
            # backpropagate through it (the documentation needs to be adjusted when that happens).
            damping = 1 - torch.sigmoid(steering_weights)
            edge_costs = graph.edge_attr * damping

        # `squeeze()` without arguments removes every singleton dimension of the dense result.
        dense_costs = to_dense_adj(graph.edge_index, edge_attr=edge_costs).squeeze()

        # Only the second element of the heuristic's result (the tour as a node list) is used.
        _, tour_nodes = run_insertion(dense_costs, self.mode)
        return convert_tsp_tour_from_node_list_to_edge_index(graph, tour_nodes)
