from typing import Dict, Union, Tuple, Any, List, Callable

from torch import nn
from torch_geometric.data.batch import Batch

from hmpn.abstract.abstract_step import AbstractStep
from hmpn.common.hmpn_util import noop
from hmpn.common.latent_mlp import LatentMLP
from hmpn.heterogeneous.heterogeneous_modules import HeterogeneousEdgeModule, \
    HeterogeneousNodeModule, HeterogeneousGlobalModule


class HeterogeneousStep(AbstractStep):
    """
    Defines a single Message Passing Step that takes a heterogeneous observation graph and updates its node and edge
    features using different modules (Edge, Node, Global).
    It first updates the edge-features. The node-features are updated next using the new edge-features. Finally,
    it updates the global features using the new edge- & node-features. The updates are done through MLPs.

    Besides building the per-type modules, this class provides type-aware helpers for snapshotting, residual-adding
    and layer-normalizing node/edge features. These helpers are not called within this class; presumably the base
    class AbstractStep drives them — verify there.
    """

    def __init__(self,
                 in_node_features: Dict[str, int],
                 in_edge_features: Dict[Tuple[str, str, str], int],
                 node_mlps: nn.ModuleDict,
                 num_edge_types: Dict[str, int],
                 edge_mlps: nn.ModuleDict,
                 latent_dimension: int,
                 scatter_reducers: Union[Callable, List[Callable]],
                 stack_config: Dict[str, Any],
                 use_global_features: bool,
                 flip_edges_for_nodes: bool,
                 global_mlp: LatentMLP):
        """
        Initializes the HeterogeneousStep.

        Args:
            in_node_features: Dictionary mapping node types to their input feature dimension. Only the keys are
                used here, to create one LayerNorm per node type when layer norm is enabled.
            in_edge_features: Dictionary mapping edge-type triples (presumably PyG's (source, relation, target))
                to their input feature dimension. Only the keys are used here, to create one LayerNorm per edge
                type when layer norm is enabled.
            node_mlps: Dictionary of node MLPs. The keys are the node types and the values are the MLPs.
            num_edge_types: Dictionary of edge types and the number of edge types. Forwarded to the node module.
            edge_mlps: Dictionary of edge MLPs. The keys are the edge types and the values are the MLPs.
            latent_dimension: Dimension of the latent space. Also the normalized size of every LayerNorm.
            scatter_reducers: Function or list of functions from torch_scatter to use for scatter operations.
            stack_config: Dictionary of stack configuration. Forwarded to AbstractStep.
            use_global_features: Whether to use global features or not.
            flip_edges_for_nodes: Forwarded unchanged to the node module (see HeterogeneousNodeModule for its
                exact meaning).
            global_mlp: MLP for global features. Only used when use_global_features is True.
        """
        super().__init__(stack_config=stack_config,
                         latent_dimension=latent_dimension,
                         use_global_features=use_global_features)

        # edge module
        self.edge_module = HeterogeneousEdgeModule(mlps=edge_mlps,
                                                   scatter_reducers=scatter_reducers,
                                                   use_global_features=use_global_features,
                                                   latent_dimension=latent_dimension)

        # node module
        self.node_module = HeterogeneousNodeModule(mlps=node_mlps,
                                                   num_edge_types=num_edge_types,
                                                   latent_dimension=latent_dimension,
                                                   scatter_reducers=scatter_reducers,
                                                   use_global_features=use_global_features,
                                                   flip_edges_for_nodes=flip_edges_for_nodes)

        # global module (optional)
        if use_global_features:
            self.global_module = HeterogeneousGlobalModule(
                mlps=nn.ModuleDict({"u": global_mlp}),
                scatter_reducers=scatter_reducers,
                use_global_features=use_global_features,
                latent_dimension=latent_dimension
            )
            # NOTE(review): nn.Module registers this same module under both attribute names, so its weights
            # likely appear twice in state_dict. Kept as-is so existing checkpoints keep loading.
            self.maybe_global = self.global_module
        else:
            self.global_module = None
            # noop stands in for the global update; presumably lets callers invoke maybe_global unconditionally.
            self.maybe_global = noop

        # self.use_layer_norm is not set in this class; it is inherited from AbstractStep.
        if self.use_layer_norm:
            self._node_layer_norms = nn.ModuleDict({node_type: nn.LayerNorm(normalized_shape=latent_dimension)
                                                    for node_type in in_node_features})
            self._edge_layer_norms = nn.ModuleDict({self._edge_key(edge_type):
                                                        nn.LayerNorm(normalized_shape=latent_dimension)
                                                    for edge_type in in_edge_features})
        else:
            self._node_layer_norms = None
            self._edge_layer_norms = None

        self.reset_parameters()

    @staticmethod
    def _edge_key(edge_type: Tuple[str, str, str]) -> str:
        """
        Returns the nn.ModuleDict key under which an edge type's LayerNorm is stored.

        The parts of the edge type are concatenated without a separator (this also keeps "." out of the key,
        which nn.ModuleDict forbids). Distinct edge types whose parts concatenate to the same string would share
        a single LayerNorm. The format is nevertheless kept because the keys are part of the parameter names
        in saved checkpoints.
        """
        return "".join(edge_type)

    def _store_nodes(self, graph: Batch):
        """
        Snapshots the "x" entry (None if absent) of every node store into self._old_graph["node_stores"].

        The snapshot follows graph.node_stores order, and _add_node_residual relies on that order. Only
        references are stored, not copies. NOTE(review): the residual is therefore only correct if the update
        assigns a new tensor to "x" rather than modifying it in place.
        """
        self._old_graph["node_stores"] = [{"x": node_store.get("x")}
                                          for node_store in graph.node_stores]

    def _store_edges(self, graph: Batch):
        """
        Snapshots the "edge_attr" entry (None if absent) of every edge store into self._old_graph["edge_stores"].

        The snapshot follows graph.edge_stores order, and _add_edge_residual relies on that order. Only
        references are stored, not copies (see _store_nodes).
        """
        self._old_graph["edge_stores"] = [{"edge_attr": edge_store.get("edge_attr")}
                                          for edge_store in graph.edge_stores]

    def _add_node_residual(self, graph: Batch):
        """
        Adds the node features saved by _store_nodes to the current node features, store by store.

        Raises:
            KeyError: if no node snapshot was stored beforehand.
        """
        # Direct indexing gives a clear KeyError if _store_nodes was never called, instead of the previous
        # "'NoneType' object is not subscriptable" TypeError.
        old_node_stores = self._old_graph["node_stores"]
        for position, node_store in enumerate(graph.node_stores):
            node_store["x"] = node_store["x"] + old_node_stores[position]["x"]

    def _add_edge_residual(self, graph: Batch):
        """
        Adds the edge features saved by _store_edges to the current edge features, store by store.

        Raises:
            KeyError: if no edge snapshot was stored beforehand.
        """
        old_edge_stores = self._old_graph["edge_stores"]
        for position, edge_store in enumerate(graph.edge_stores):
            edge_store["edge_attr"] = edge_store["edge_attr"] + old_edge_stores[position]["edge_attr"]

    def _node_layer_norm(self, graph: Batch) -> None:
        """
        Applies each node type's LayerNorm to that type's "x" features, writing the result back into the graph.

        Requires layer norm to be enabled; otherwise self._node_layer_norms is None.
        """
        # graph.node_types and graph.node_stores are listed in the same order, one entry per node type.
        for node_type, node_store in zip(graph.node_types, graph.node_stores):
            node_store["x"] = self._node_layer_norms[node_type](node_store["x"])

    def _edge_layer_norm(self, graph: Batch) -> None:
        """
        Applies each edge type's LayerNorm to that type's "edge_attr" features, writing the result back into
        the graph.

        Requires layer norm to be enabled; otherwise self._edge_layer_norms is None.
        """
        # graph.edge_types and graph.edge_stores are listed in the same order, one entry per edge type.
        for edge_type, edge_store in zip(graph.edge_types, graph.edge_stores):
            layer_norm = self._edge_layer_norms[self._edge_key(edge_type)]
            edge_store["edge_attr"] = layer_norm(edge_store["edge_attr"])
