from typing import Dict, Optional, Tuple

from torch_geometric.data.data import Data

from hmpn.abstract.abstract_graph_assertions import AbstractGraphAssertions


class HomogeneousGraphAssertions(AbstractGraphAssertions):
    """Shape assertions for a (batch of) homogeneous graph(s).

    Calling an instance runs the parent class' checks first and then verifies
    the edge index layout and the node/edge feature dimensions of the graph.
    """

    def __init__(
        self,
        *,
        in_node_features: Dict[str, int],
        in_edge_features: Dict[Tuple[str, str, str], int],
        in_global_features: Optional[int],
    ):
        """
        Forwards the expected feature dimensions to the parent class.

        NOTE(review): the annotations declare dictionaries, but ``__call__``
        compares the stored "in_node_features"/"in_edge_features" entries
        against integer feature dimensions. Presumably plain ints are expected
        for homogeneous graphs -- confirm against AbstractGraphAssertions.

        Args:
            in_node_features:
                number of input features for nodes
            in_edge_features:
                number of input features for edges
            in_global_features:
                Number of input global features, None if no global features are used
        """
        super().__init__(
            in_node_features=in_node_features,
            in_edge_features=in_edge_features,
            in_global_features=in_global_features,
        )

    def __call__(self, tensor: Data) -> None:
        """
        Does various shape assertions to make sure that the (batch of) graph(s) is built correctly
        Args:
            tensor: (batch of) homogeneous graph(s)

        Returns:
            None. An AssertionError is raised as soon as one check fails.

        """
        super().__call__(tensor)

        edge_index_shape = tensor.edge_index.shape
        edge_attr_shape = tensor.edge_attr.shape

        # The edge index holds one (source, target) column per edge.
        assert edge_index_shape[0] == 2, (
            "Edge index must have shape (2, num_edges), "
            f"given '{edge_index_shape}' instead."
        )
        # Every edge needs exactly one feature vector.
        assert edge_index_shape[1] == edge_attr_shape[0], (
            "Must provide one edge index per edge "
            "feature vector, given "
            f"'{edge_index_shape}' and "
            f"'{edge_attr_shape}' instead."
        )

        # self._assertion_dict is not defined in this class; presumably it is
        # populated by the parent constructor from the keyword arguments.
        in_edge_features = edge_attr_shape[1]
        expected_edge_features = self._assertion_dict.get("in_edge_features")
        assert in_edge_features == expected_edge_features, (
            "Feature dimensions of edges do not match. "
            f"Given '{in_edge_features}', "
            f"expected '{expected_edge_features}'"
        )

        in_node_features = tensor.x.shape[1]
        expected_node_features = self._assertion_dict.get("in_node_features")
        assert in_node_features == expected_node_features, (
            "Feature dimensions of nodes do not match. "
            f"Given '{in_node_features}', "
            f"expected '{expected_node_features}'"
        )

        if self._assert_global:
            # A missing attribute and an attribute set to None are both
            # rejected with the same message, so one lookup covers both cases.
            assert (
                getattr(tensor, "batch", None) is not None
            ), "Need batch pointer for graph ids when using batch and global features"
