from typing import Optional, Tuple
from hmpn.abstract.abstract_graph_assertions import AbstractGraphAssertions
from torch_geometric.data import Batch


class Heterogeneous2GraphAssertions(AbstractGraphAssertions):
    """
    Shape and key assertions for (batches of) heterogeneous graphs.

    Calling an instance on a batch checks that every edge store provides
    'edge_index' and 'edge_attr', that every node store provides 'x', and that
    the feature dimensions of all node and edge types match the expected
    dimensions given at construction time.
    """

    def __init__(
        self,
        *,
        in_node_features: dict[str, int],
        in_edge_features: dict[Tuple[str, str, str], int],
        in_global_features: Optional[int],
    ) -> None:
        """
        Forwards the expected feature dimensions to the base class.

        Args:
            in_node_features: expected size of the last dimension of 'x' per node type
            in_edge_features: expected size of the last dimension of 'edge_attr' per edge type
            in_global_features: expected global feature dimension, may be None.
                NOTE(review): how this is interpreted is decided by the base class; confirm there.
        """
        super().__init__(
            in_node_features=in_node_features,
            in_edge_features=in_edge_features,
            in_global_features=in_global_features,
        )

    def __call__(self, graph: Batch) -> None:
        """
        Does various shape assertions to make sure that the (batch of) graph(s) is built correctly
        Args:
            graph: batch of heterogeneous graphs

        Returns: None

        Raises:
            AssertionError: if a required key is missing, a shape is inconsistent, or the
                node/edge types or their feature dimensions differ from the expected ones.

        """
        super().__call__(graph)

        # general assertions
        for edge_type, edge_store in zip(graph.edge_types, graph.edge_stores):
            assert "edge_attr" in edge_store.keys(), (
                f"Edge store corresponding to type '{edge_type}' must provide "
                f"key 'edge_attr'. Given key(s) '{edge_store.keys()}' instead."
            )
            assert "edge_index" in edge_store.keys(), (
                f"Edge store corresponding to type '{edge_type}' must provide "
                f"key 'edge_index'. Given key(s) '{edge_store.keys()}' instead."
            )

            index_shape = edge_store.get("edge_index").shape
            feature_shape = edge_store.get("edge_attr").shape
            assert (
                index_shape[0] == 2
            ), f"Edge index must have shape (2, num_edges), given '{index_shape}' instead."
            assert index_shape[1] == feature_shape[0], (
                f"Must provide one edge index per edge feature vector, "
                f"given '{index_shape}' and '{feature_shape}' instead."
            )

        for node_type, node_store in zip(graph.node_types, graph.node_stores):
            assert "x" in node_store.keys(), (
                f"Node store corresponding to type '{node_type}' must provide key 'x'. "
                f"Given key(s) '{node_store.keys()}' instead."
            )

        # node assertions
        # expected dimensions are read from the assertion dict
        # (presumably filled by the base class from the constructor arguments)
        in_node_features = self._assertion_dict.get("in_node_features")
        assert len(graph.node_types) == len(in_node_features.keys()), (
            f"Number of node types does not match. Given '{len(graph.node_types)}', "
            f"expected '{len(in_node_features.keys())}'"
        )
        for node_type, node_store in zip(graph.node_types, graph.node_stores):
            # equal counts do not imply equal names, so check membership before
            # the lookup to fail with a readable message instead of a KeyError
            assert node_type in in_node_features, (
                f"Unexpected node type '{node_type}'. "
                f"Expected one of '{list(in_node_features.keys())}'"
            )
            assert node_store["x"].shape[-1] == in_node_features[node_type], (
                f"Feature dimensions of node type "
                f"{node_type} do not match. Given "
                f"'{node_store['x'].shape[-1]}', "
                f"expected '{in_node_features[node_type]}'"
            )

        # edge assertions
        in_edge_features = self._assertion_dict.get("in_edge_features")
        assert len(graph.edge_types) == len(in_edge_features.keys()), (
            f"Number of edge types does not match. "
            f"Given '{len(graph.edge_types)}', "
            f"expected '{len(in_edge_features.keys())}'"
        )
        for edge_type, edge_store in zip(graph.edge_types, graph.edge_stores):
            # same membership guard as for the node types above
            assert edge_type in in_edge_features, (
                f"Unexpected edge type '{edge_type}'. "
                f"Expected one of '{list(in_edge_features.keys())}'"
            )
            in_features = edge_store["edge_attr"].shape[-1]
            expected_features = in_edge_features[edge_type]
            assert in_features == expected_features, (
                f"Feature dimensions of edge type {edge_type} do not match. "
                f"Given '{in_features}', expected '{expected_features}'"
            )

        # global assertions: every node type needs a non-None 'batch' attribute
        if self._assert_global:
            for node_type in graph.node_types:
                assert hasattr(
                    graph[node_type], "batch"
                ), f"Need batch pointer for graph ids when using batch and global features for node type {node_type}"
                assert (
                    graph[node_type].batch is not None
                ), f"Batch pointer for node type {node_type} is None"
