from typing import Union, Optional, Any, Type

from torch_geometric.data.hetero_data import HeteroData

from hmpn.abstract.abstract_message_passing_base import AbstractMessagePassingBase
from hmpn.common.hmpn_util import unpack_heterogeneous_features
from hmpn.heterogeneous2.heterogeneous2_graph_assertions import (
    Heterogeneous2GraphAssertions,
)
from hmpn.heterogeneous2.heterogeneous2_input_embedding import (
    Heterogeneous2InputEmbedding,
)
from hmpn.heterogeneous2.heterogeneous2_stack import Heterogeneous2Stack


class Heterogeneous2MessagePassingBase(AbstractMessagePassingBase):
    """
    Graph Neural Network (GNN) base module that processes heterogeneous graph observations of the environment.

    It uses a stack of GNN steps (a ``Heterogeneous2Stack``); each step defines a single GNN pass.
    The input embedding type (``Heterogeneous2InputEmbedding``) and the graph assertion type
    (``Heterogeneous2GraphAssertions``) are supplied to the base class via the ``_get_*`` hooks below.
    """

    def __init__(
        self,
        *,
        in_node_features: dict[str, int],
        in_edge_features: dict[tuple[str, str, str], int],
        in_global_features: Optional[int],
        latent_dimension: int,
        scatter_reduce_strs: Union[list[str], str],
        stack_config: dict[str, Any],
        embedding_config: Optional[dict[str, Any]],
        unpack_output: bool,
        edge_dropout: float = 0.0,
        create_graph_copy: bool = True,
        assert_graph_shapes: bool = True,
        flip_edges_for_nodes: bool = False,
    ):
        """
        Build the base module and its heterogeneous message passing stack.

        Args:
            in_node_features:
                Dictionary {node_type: #node_features} of node_types and their input sizes for a heterogeneous graph.
                Node features may have size 0, in which case an empty input graph of the appropriate shape/batch_size
                is expected and the initial embeddings are learned constants.
            in_edge_features:
                Dictionary {edge_type: #edge_features} of edge_types and their input sizes for a heterogeneous graph.
                Edge types are (source_node_type, relation, target_node_type) triples.
                Edge features may have size 0, in which case an empty input graph of the appropriate shape/batch_size
                is expected and the initial embeddings are learned constants.
            in_global_features:
                If None, no global features will be used (and no GlobalModules created).
                May have size 0, in which case the initial values are a learned constant. This expects (empty) global
                input tensors and will use the GlobalModule.
            latent_dimension:
                Latent dimension of the network. All modules internally operate with latent vectors of this dimension.
            scatter_reduce_strs:
                Names of the scatter reduce(s) used to aggregate messages of the same type.
                A single string or a list of strings out of "sum", "mean", "max", "min", "std".
            stack_config:
                Configuration of the stack of GNN steps. See the documentation of Heterogeneous2Stack for details.
            embedding_config:
                Configuration of the input embedding (may be None, resulting in linear embeddings).
            unpack_output:
                If True, will unpack the processed batch of graphs to a 4-tuple of
                ({node_name: node features}, {edge_name: edge features}, global features, {node_name: batch indices}).
                Else, will return the raw processed batch of graphs.
            edge_dropout:
                Dropout probability for the edges. NOTE: edge dropout is not implemented for this class;
                ``_edge_dropout`` always raises NotImplementedError. Presumably the base class only invokes it
                when edge_dropout > 0 — confirm against AbstractMessagePassingBase.
            create_graph_copy:
                If True, a copy of the input graph is created and modified in-place.
                If False, the input graph itself is modified in-place.
            assert_graph_shapes:
                If True, the input graph is checked for consistency with the input shapes.
                If False, the input graph is not checked.
            flip_edges_for_nodes:
                Forwarded unchanged to Heterogeneous2Stack.
                NOTE(review): presumably controls whether edge directions are flipped when aggregating messages
                for node updates — confirm against Heterogeneous2Stack.
        """
        super().__init__(
            in_node_features=in_node_features,
            in_edge_features=in_edge_features,
            in_global_features=in_global_features,
            latent_dimension=latent_dimension,
            embedding_config=embedding_config,
            scatter_reduce_strs=scatter_reduce_strs,
            edge_dropout=edge_dropout,
            unpack_output=unpack_output,
            create_graph_copy=create_graph_copy,
            assert_graph_shapes=assert_graph_shapes,
        )

        # Global modules are only created when global input features are configured at all
        # (a size of 0 still counts as "use globals", only None disables them).
        use_global_features = in_global_features is not None

        # Create the message passing stack. This must happen after super().__init__(), which sets
        # self._scatter_reducers (and, for a torch Module, enables submodule registration).
        self.message_passing_stack = Heterogeneous2Stack(
            latent_dimension=latent_dimension,
            scatter_reducers=self._scatter_reducers,
            stack_config=stack_config,
            use_global_features=use_global_features,
            flip_edges_for_nodes=flip_edges_for_nodes,
        )

    def _get_assertions(self) -> Type[Heterogeneous2GraphAssertions]:
        """Return the graph assertion class used to validate input graphs of this heterogeneous variant."""
        return Heterogeneous2GraphAssertions

    @staticmethod
    def _get_input_embedding_type() -> Type[Heterogeneous2InputEmbedding]:
        """Return the input embedding class used by this heterogeneous variant."""
        return Heterogeneous2InputEmbedding

    def _edge_dropout(self, graph: HeteroData):
        """
        Edge dropout hook. Not supported for this class.

        Raises:
            NotImplementedError: always.
        """
        # Report the concrete class name so the message stays correct for subclasses, too.
        raise NotImplementedError(
            f"'edge_dropout' not implemented for {type(self).__name__}"
        )

    def unpack_features(self, graph: HeteroData) -> tuple:
        """
        Unpack the important data from a heterogeneous graph.

        Args:
            graph: The input heterogeneous observation.

        Returns:
            The tuple produced by ``unpack_heterogeneous_features(graph)``. Per the ``unpack_output`` documentation
            this is ({node_name: node features}, {edge_name: edge features}, global features,
            {node_name: batch indices}).
            NOTE(review): an earlier docstring described a 5-tuple
            (edge_features, edge_index, node_features, global_features, batch); confirm the exact layout
            against unpack_heterogeneous_features.
        """
        return unpack_heterogeneous_features(graph)
