from typing import Optional

from networkx.algorithms.approximation import christofides as christofides_nx
import torch
from torch import Tensor
from torch_geometric.data import Data

from util import convert_tsp_tour_from_node_list_to_edge_index


def christofides_steered(steering_weights: Optional[Tensor], graph: Data) -> Tensor:
    """
    Run the Christofides algorithm on `graph` after optionally rescaling its edge weights.

    When `steering_weights` is given, every edge weight in `graph.edge_attr` is multiplied by
    `1 - torch.sigmoid(steering_weights)`; when it is `None`, the weights are used unchanged.
    Steering weights should have size `[graph.num_edges]`.

    Side effect: the resulting weight of each edge listed in `graph.edge_index` is stored under the
    `"modified_weight"` attribute of the matching edge in `graph.networkx`.

    Returns a Tensor indicating for each edge whether it's in the TSP tour. Size `[graph.num_edges]`
    """
    edge_weights = graph.edge_attr
    if steering_weights is not None:
        # The sigmoid lies in (0, 1), so the factor only ever scales a weight down; a larger steering
        # weight yields a smaller factor.
        edge_weights = edge_weights * (1 - torch.sigmoid(steering_weights))

    # One [source, target] pair per column of edge_index, paired with that edge's weight.
    edge_endpoints = graph.edge_index.t().tolist()
    weight_values = edge_weights.tolist()

    nx_graph = graph.networkx
    for endpoints, weight in zip(edge_endpoints, weight_values):
        nx_graph.edges[endpoints]["modified_weight"] = weight

    closed_tour = christofides_nx(nx_graph, weight="modified_weight")
    # networkx repeats the start node at the end of the tour, but the conversion expects each node
    # to appear only once, so the trailing duplicate is dropped.
    open_tour = closed_tour[:-1]

    return convert_tsp_tour_from_node_list_to_edge_index(graph, open_tour)
