from typing import Dict, Union, Optional, Any, List, Type

from torch_geometric.data.batch import Batch

from hmpn.abstract.abstract_message_passing_base import AbstractMessagePassingBase
from hmpn.common.hmpn_util import unpack_homogeneous_features
from hmpn.homogeneous.homogeneous_graph_assertions import HomogeneousGraphAssertions
from hmpn.homogeneous.homogeneous_input_embedding import HomogeneousInputEmbedding
from hmpn.homogeneous.homogeneous_stack import HomogeneousStack


class HomogeneousMessagePassingBase(AbstractMessagePassingBase):
    """
        Graph Neural Network (GNN) Base module processes the graph observations of the environment.
        It uses a stack of GNN Steps. Each step defines a single GNN pass.
    """

    def __init__(self, *, in_node_features: int,
                 in_edge_features: int,
                 in_global_features: Optional[int],
                 latent_dimension: int,
                 scatter_reduce_strs: Union[List[str], str],
                 stack_config: Dict[str, Any],
                 embedding_config: Dict[str, Any],
                 unpack_output: bool,
                 edge_dropout: float = 0.0,
                 create_graph_copy: bool = True,
                 assert_graph_shapes: bool = True,
                 flip_edges_for_nodes: bool = False,
                 node_name: str = "node"):
        """
        Args:
            in_node_features:
                Node feature input size for a homogeneous graph.
                Node features may have size 0, in which case an empty input graph of the appropriate shape/batch_size
                is expected and the initial embeddings are learned constants
            in_edge_features:
                Edge feature input size for a homogeneous graph.
                Edge features may have size 0, in which case an empty input graph of the appropriate shape/batch_size
                is expected and the initial embeddings are learned constants
            in_global_features:
                If None, no global features will be used (and no GlobalModules created)
                May have size 0, in which case the initial values are a learned constant. This expects (empty) global
                 input tensors and will use the GlobalModule
            latent_dimension:
                Latent dimension of the network. All modules internally operate with latent vectors of this dimension
            scatter_reduce_strs:
                Names of the scatter reduce to use to aggregate messages of the same type.
                Can be multiple of "sum", "mean", "max", "min", "std"
                e.g. ["sum","max"]
            stack_config:
                Configuration of the stack of GNN steps. See the documentation of the stack for more information.
            embedding_config:
                Configuration of the embedding stack (can be empty by choosing None, resulting in linear embeddings).
            edge_dropout:
                Dropout probability for the edges. Removes edges from the graph with the given probability.
            unpack_output:
                If true, the output of the forward pass is unpacked into a tuple of (node_features, edge_features,
                global_features).
                If false, the output of the forward pass is the raw graph.
            create_graph_copy:
                If True, a copy of the input graph is created and modified in-place.
                If False, the input graph is modified in-place.
            assert_graph_shapes:
                If True, the input graph is checked for consistency with the input shapes.
                If False, the input graph is not checked for consistency with the input shapes.
            node_name:
                Name of the node. Used to convert the output of the forward pass to a dictionary
        """
        super().__init__(in_node_features=in_node_features,
                         in_edge_features=in_edge_features,
                         in_global_features=in_global_features,
                         latent_dimension=latent_dimension,
                         embedding_config=embedding_config,
                         scatter_reduce_strs=scatter_reduce_strs,
                         unpack_output=unpack_output,
                         edge_dropout=edge_dropout,
                         create_graph_copy=create_graph_copy,
                         assert_graph_shapes=assert_graph_shapes
                         )

        use_global_features = in_global_features is not None

        self._node_name = node_name

        # create message passing stack
        self.message_passing_stack = HomogeneousStack(stack_config=stack_config,
                                                      latent_dimension=latent_dimension,
                                                      scatter_reducers=self._scatter_reducers,
                                                      use_global_features=use_global_features,
                                                      flip_edges_for_nodes=flip_edges_for_nodes)

    def _get_assertions(self) -> Type[HomogeneousGraphAssertions]:
        return HomogeneousGraphAssertions

    @staticmethod
    def _get_input_embedding_type() -> Type[HomogeneousInputEmbedding]:
        return HomogeneousInputEmbedding

    def unpack_features(self, graph: Batch) -> Batch:
        return unpack_homogeneous_features(graph, node_name=self._node_name)
