from datetime import datetime
from typing import Any

from ruamel.yaml import YAML
from ruamel.yaml.compat import StringIO
import torch
from torch import Tensor
from torch_geometric.data import Data


def sum_of_edge_weights(graph: Data, edges: Tensor) -> Tensor:
    """
    Calculates the sum of edge weights for the given set of edges (e.g. the weight of a k-cut or length of a TSP tour).

    `edges` should be of size `[graph.num_edges]`.
    If an entry is 1, it indicates that the corresponding edge is in the set, 0 indicates that it isn't.
    `graph.edge_attr` holds one weight per edge and may be of size `[graph.num_edges]` or `[graph.num_edges, 1]`.
    It must have the same dtype as `edges`, because `torch.dot` does not promote types.
    Returns the sum as a Tensor of size `[]`.

    This is differentiable.
    """
    # torch.dot only accepts 1-D tensors, so flatten a `[num_edges, 1]` column of weights first
    edge_weights = graph.edge_attr.reshape(-1)
    # divide by two because torch_geometric stores each edge twice, once in each direction
    return torch.dot(edges, edge_weights) / 2


def convert_tsp_tour_from_node_list_to_edge_index(graph: Data, tsp_tour: list[int]) -> Tensor:
    """
    Converts a TSP tour from a list of nodes in the order in which they are visited to a Tensor of size
    `[graph.num_edges]`.
    Each entry of the Tensor is 1, if the corresponding edge is in the TSP tour, 0 otherwise.

    An edge is in the tour if its two endpoints are next to each other in `tsp_tour`.
    Because the tour is a closed cycle, the first and last element also count as neighbours.
    Both stored directions of an edge are therefore marked.

    Raises a `ValueError` if an endpoint of an edge in `graph.edge_index` does not appear in `tsp_tour`.
    """
    # Look up each node's position once instead of a linear `list.index` scan for every edge endpoint.
    # `setdefault` keeps the first occurrence, matching `list.index` if a node were listed twice.
    position_in_tour: dict[int, int] = {}
    for position, node in enumerate(tsp_tour):
        position_in_tour.setdefault(node, position)

    # since the tour is a cycle, the first and last element of the list are also next to each other (distance n - 1)
    neighbour_distances = {1, len(tsp_tour) - 1}

    tsp_tour_edges = torch.zeros(graph.num_edges)

    # `.tolist()` yields plain Python ints, avoiding element-wise comparisons against 0-d tensors
    for i, (node_1, node_2) in enumerate(graph.edge_index.t().tolist()):
        try:
            distance = abs(position_in_tour[node_1] - position_in_tour[node_2])
        except KeyError as err:
            # keep the exception type that `list.index` raised previously
            raise ValueError(f"node {err.args[0]} of graph.edge_index is not in tsp_tour") from None
        if distance in neighbour_distances:
            tsp_tour_edges[i] = 1

    return tsp_tour_edges


def get_timestamp() -> str:
    """
    Returns the current local time as an ISO-8601-like string with second precision.

    Sub-second digits are dropped.
    Colons are replaced by dashes, e.g. `2024-01-31T13-45-07`, presumably so the value is safe to use in file names.
    """
    # zeroing the microseconds makes isoformat() stop at the seconds field
    current = datetime.now().replace(microsecond=0)
    iso_string = current.isoformat()
    return iso_string.replace(":", "-")


def yaml_object_to_string(yaml_data: Any) -> str:
    """
    Serializes `yaml_data` with a default `YAML()` dumper and returns the resulting YAML text.

    Intended for instances of classes registered via `@yaml_object(YAML())`.
    """
    dumper = YAML()
    buffer = StringIO()
    # YAML.dump writes into a stream, so collect the output in memory and read it back
    dumper.dump(yaml_data, buffer)
    serialized = buffer.getvalue()
    return serialized
