from typing import Dict, Union, Optional, Tuple, Any, List, Type

from torch_geometric.data.hetero_data import HeteroData

from hmpn.abstract.abstract_message_passing_base import AbstractMessagePassingBase
from hmpn.common.hmpn_util import unpack_heterogeneous_features
from hmpn.heterogeneous.heterogeneous_graph_assertions import HeterogeneousGraphAssertions
from hmpn.heterogeneous.heterogeneous_input_embedding import HeterogeneousInputEmbedding
from hmpn.heterogeneous.heterogeneous_stack import HeterogeneousStack


class HeterogeneousMessagePassingBase(AbstractMessagePassingBase):
    """
        Message passing base module for heterogeneous graphs, i.e., graphs with several node and edge types.
        It wires the heterogeneous variants of the input embedding, the graph assertions and the message passing
        stack into the generic AbstractMessagePassingBase. The stack consists of GNN steps, each of which defines a
        single GNN pass.
    """

    def __init__(self, *, in_node_features: Dict[str, int],
                 in_edge_features: Dict[Tuple[str, str, str], int],
                 in_global_features: Optional[int],
                 latent_dimension: int,
                 scatter_reduce_strs: Union[List[str], str],
                 stack_config: Dict[str, Any],
                 embedding_config: Dict[str, Any],
                 unpack_output: bool,
                 edge_dropout: float = 0.0,
                 create_graph_copy: bool = True,
                 assert_graph_shapes: bool = True,
                 flip_edges_for_nodes: bool = False):
        """
        Builds the heterogeneous message passing base. All arguments are keyword-only.

        Args:
            in_node_features:
                Dictionary {node_type: #node_features} of node_types and their input sizes for a heterogeneous graph.
                Node features may have size 0, in which case an empty input graph of the appropriate shape/batch_size
                is expected and the initial embeddings are learned constants
            in_edge_features:
                Dictionary {edge_type: #edge_features} of edge_types and their input sizes for a heterogeneous graph.
                Edge features may have size 0, in which case an empty input graph of the appropriate shape/batch_size
                is expected and the initial embeddings are learned constants
            in_global_features:
                If None, no global features will be used (and no GlobalModules created)
                May have size 0, in which case the initial values are a learned constant. This expects (empty) global
                input tensors and will use the GlobalModule
            latent_dimension:
                Latent dimension of the network. All modules internally operate with latent vectors of this dimension
            scatter_reduce_strs:
                Names of the scatter reduce to use to aggregate messages of the same type.
                Either a single string or a list of strings out of "sum", "mean", "max", "min", "std"
            stack_config:
                Configuration of the stack of GNN steps. See the documentation of the stack for more information.
            embedding_config:
                Configuration of the embedding stack (can be empty by choosing None, resulting in linear embeddings).
            unpack_output:
                If True, will unpack the processed batch of graphs to a 4-tuple of
                ({node_name: node features}, {edge_name: edge features}, global features, {node_name: batch indices}).
                Else, will return the raw processed batch of graphs
            edge_dropout:
                Dropout probability for the edges. Removes edges from the graph with the given probability.
                NOTE: '_edge_dropout' raises a NotImplementedError for this class, so edge dropout is not available
                for heterogeneous graphs. Presumably only the default of 0.0 is usable here - confirm when the
                base class invokes '_edge_dropout'.
            create_graph_copy:
                If True, a copy of the input graph is created and this copy is modified, leaving the input untouched.
                If False, the input graph itself is modified in-place.
            assert_graph_shapes:
                If True, the input graph is checked for consistency with the input shapes.
                If False, the input graph is not checked for consistency with the input shapes.
            flip_edges_for_nodes:
                Forwarded unchanged to the HeterogeneousStack. See the documentation of the stack for its exact
                meaning (presumably it controls the edge direction used when aggregating onto nodes - confirm).
        """
        super().__init__(in_node_features=in_node_features,
                         in_edge_features=in_edge_features,
                         in_global_features=in_global_features,
                         latent_dimension=latent_dimension,
                         embedding_config=embedding_config,
                         scatter_reduce_strs=scatter_reduce_strs,
                         edge_dropout=edge_dropout,
                         unpack_output=unpack_output,
                         create_graph_copy=create_graph_copy,
                         assert_graph_shapes=assert_graph_shapes)

        # global features are switched on by *providing* a size. A size of 0 still enables them, only None disables
        use_global_features = in_global_features is not None

        # create message passing stack
        # 'self._scatter_reducers' is not assigned in this class and is expected to be set up by the base class
        # constructor above, which is why the stack can only be created after the call to super().__init__()
        self.message_passing_stack = HeterogeneousStack(in_node_features=in_node_features,
                                                        in_edge_features=in_edge_features,
                                                        latent_dimension=latent_dimension,
                                                        scatter_reducers=self._scatter_reducers,
                                                        stack_config=stack_config,
                                                        use_global_features=use_global_features,
                                                        flip_edges_for_nodes=flip_edges_for_nodes)

    def _get_assertions(self) -> Type[HeterogeneousGraphAssertions]:
        """
        Returns:
            The assertion class (not an instance) to use for heterogeneous input graphs

        """
        return HeterogeneousGraphAssertions

    @staticmethod
    def _get_input_embedding_type() -> Type[HeterogeneousInputEmbedding]:
        """
        Returns:
            The input embedding class (not an instance) to use for heterogeneous input graphs

        """
        return HeterogeneousInputEmbedding

    def _edge_dropout(self, graph: HeteroData):
        """
        Edge dropout is not supported for heterogeneous graphs. This method never returns.

        Params:
            graph – The input heterogeneous observation. Unused.

        Raises:
            NotImplementedError: Always

        """
        raise NotImplementedError("'edge_dropout' not implemented for HeterogeneousMessagePassingBase")

    def unpack_features(self, graph: HeteroData) -> Tuple[Any, ...]:
        """
        Unpacking important data from heterogeneous graphs by delegating to 'unpack_heterogeneous_features'.

        Params:
            graph – The input heterogeneous observation

        Returns:
            The result of 'unpack_heterogeneous_features(graph)'. Following the description of 'unpack_output' in
            the constructor, this is expected to be the 4-tuple
            ({node_name: node features}, {edge_name: edge features}, global features, {node_name: batch indices})
            - confirm against the helper.

        """
        return unpack_heterogeneous_features(graph)
