# postponed evaluation of annotations (use a class name as type hint inside of the same class)
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional

from torch import Tensor
from torch_geometric.data import Data

from ._meta_edge import MetaEdge
from ._meta_edge_weight import MetaEdgeWeight
from ._meta_node import MetaNode


@dataclass()
class MetaGraph:
    """
    Undirected graph of `MetaNode`s connected by `MetaEdge`s, e.g. built from a pytorch geometric graph
    with `MetaGraph.from_pyg`.
    """

    edges: list[MetaEdge]
    nodes: list[MetaNode]
    # indices (columns of the source `edge_index`) of the edges kept in `edges`, in the same order;
    # filled in by `from_pyg` and required by `update_steering_weights`
    _indices_of_edges_without_duplicates: Optional[list[int]] = None

    @classmethod
    def from_pyg(cls, graph: Data) -> MetaGraph:
        """
        Converts a pytorch geometric `Data` object to a `MetaGraph`.

        The graph is assumed to be undirected: an edge (b, a) is dropped when its reverse (a, b) has
        already been kept. If the two directions carry different (steering) weights, the weight of
        the later edge is discarded.

        Args:
            graph: graph providing `num_nodes`, an `edge_index` with one column per edge, and an
                `edge_attr` where `edge_attr[i].item()` yields the scalar weight of edge `i`.

        Returns:
            A `MetaGraph` with one `MetaNode` per node (named by its integer index), the
            deduplicated edges, and the `edge_index` column of every kept edge.
        """
        meta_nodes = [MetaNode(name) for name in range(graph.num_nodes)]

        meta_edges: list[MetaEdge] = []
        indices_of_edges_without_duplicates: list[int] = []
        # node-name pairs of the edges kept so far; a set turns the reverse-edge check into an O(1)
        # lookup instead of a scan over all kept edges (which made the conversion O(E^2))
        kept_node_pairs: set[tuple[int, ...]] = set()
        # convert once to a list of [source, target] rows instead of slicing one column per edge
        edge_node_names: list[list[int]] = graph.edge_index.t().tolist()

        for i, meta_edge_node_names in enumerate(edge_node_names):
            # skip this edge if its reverse was already kept, because we assume the graph to be
            # undirected (exact repeats of an edge in the same direction are still kept)
            if tuple(reversed(meta_edge_node_names)) in kept_node_pairs:
                continue

            meta_edge_nodes = [meta_nodes[name] for name in meta_edge_node_names]
            meta_edge_weight = MetaEdgeWeight(graph.edge_attr[i].item())
            # the initial steering weight is the edge weight itself (the same object is passed twice)
            meta_edges.append(MetaEdge(meta_edge_nodes, meta_edge_weight, meta_edge_weight))
            indices_of_edges_without_duplicates.append(i)
            kept_node_pairs.add(tuple(meta_edge_node_names))

        return cls(meta_edges, meta_nodes, indices_of_edges_without_duplicates)

    def update_steering_weights(self, steering_weights: Optional[Tensor]) -> None:
        """
        Sets the steering weight of every edge.

        Args:
            steering_weights: per-edge factors indexed by the columns of the original `edge_index`
                (i.e. including the reverse edges dropped by `from_pyg`). Each kept edge gets
                `weight * steering_weights[i].item()`. If `None`, every steering weight is reset to
                the edge's plain weight.
        """
        assert self._indices_of_edges_without_duplicates is not None, (
            "update_steering_weights needs the edge indices recorded by MetaGraph.from_pyg"
        )

        if steering_weights is None:
            for meta_edge in self.edges:
                meta_edge.steering_weight = meta_edge.weight
            return

        # edges and their original edge_index columns are stored in the same order
        for i, meta_edge in zip(self._indices_of_edges_without_duplicates, self.edges):
            meta_edge.steering_weight = meta_edge.weight * steering_weights[i].item()
