from typing import Dict, Union, Tuple, Any, List, Callable

from torch import nn
from torch_geometric.data.batch import Batch

from hmpn.abstract.abstract_step import AbstractStep
from hmpn.common.hmpn_util import noop
from hmpn.common.latent_mlp import LatentMLP
from hmpn.heterogeneous.heterogeneous_modules import (
    HeterogeneousEdgeModule,
    HeterogeneousNodeModule,
    HeterogeneousGlobalModule,
)


class HeterogeneousStep(AbstractStep):
    """
    One message passing step on a heterogeneous observation graph.

    The step owns an edge module, a node module and (optionally) a global module, each driven by MLPs. The intended
    update order is: edge features first, then node features using the new edge features, and finally the global
    features using the new edge and node features.

    Besides constructing these modules, this class provides the per-type helpers that snapshot node/edge features,
    add them back as residuals and apply layer normalization. These helpers are presumably invoked by the forward
    pass of AbstractStep, which is not part of this class -- confirm against the base class.
    """

    def __init__(
        self,
        in_node_features: Dict[str, int],
        in_edge_features: Dict[Tuple[str, str, str], int],
        node_mlps: nn.ModuleDict,
        num_edge_types: Dict[str, int],
        edge_mlps: nn.ModuleDict,
        latent_dimension: int,
        scatter_reducers: Union[Callable, List[Callable]],
        stack_config: Dict[str, Any],
        use_global_features: bool,
        flip_edges_for_nodes: bool,
        global_mlp: LatentMLP,
    ):
        """
        Builds the edge, node and (optional) global modules as well as the optional layer norms.

        Args:
            in_node_features: Dictionary keyed by node type. Only the keys are read here: one LayerNorm is created
                per node type when layer norm is enabled.
            in_edge_features: Dictionary keyed by edge type, given as a tuple of three strings. Only the keys are
                read here: one LayerNorm is created per edge type and registered under the concatenation of the
                three strings.
            node_mlps: Dictionary of node MLPs. The keys are the node types and the values are the MLPs.
            num_edge_types: Dictionary that is forwarded unchanged to the node module. Presumably maps a node type
                to the number of edge types relevant for it -- confirm against HeterogeneousNodeModule.
            edge_mlps: Dictionary of edge MLPs. The keys are the edge types and the values are the MLPs.
            latent_dimension: Dimension of the latent space. Also used as the normalized shape of every LayerNorm.
            scatter_reducers: Function or list of functions (e.g. from torch_scatter) to use for scatter
                operations. Passed to all three modules.
            stack_config: Dictionary of stack configuration, forwarded to AbstractStep.
            use_global_features: Whether to use global features or not. If False, no global module is built and
                `maybe_global` is set to `noop`.
            flip_edges_for_nodes: Flag that is forwarded unchanged to the node module; its exact effect is defined
                by HeterogeneousNodeModule.
            global_mlp: MLP for global features. Only used if `use_global_features` is True, in which case it is
                registered under the key "u".
        """
        super().__init__(
            stack_config=stack_config,
            latent_dimension=latent_dimension,
            use_global_features=use_global_features,
        )

        # keyword arguments that every sub-module receives
        common_kwargs = {
            "latent_dimension": latent_dimension,
            "scatter_reducers": scatter_reducers,
            "use_global_features": use_global_features,
        }

        # NOTE: the order in which sub-modules are assigned to `self` determines the order of parameters and
        # state_dict entries, so it is kept as edge -> node -> global -> layer norms.
        self.edge_module = HeterogeneousEdgeModule(mlps=edge_mlps, **common_kwargs)

        self.node_module = HeterogeneousNodeModule(
            mlps=node_mlps,
            num_edge_types=num_edge_types,
            flip_edges_for_nodes=flip_edges_for_nodes,
            **common_kwargs,
        )

        if not use_global_features:
            # without global features the global update is replaced by the `noop` helper
            self.global_module = None
            self.maybe_global = noop
        else:
            global_module = HeterogeneousGlobalModule(
                mlps=nn.ModuleDict({"u": global_mlp}),
                **common_kwargs,
            )
            self.global_module = global_module
            # alias of the same module; callers can always invoke `maybe_global`
            self.maybe_global = global_module

        # `use_layer_norm` is not set in this class; presumably AbstractStep derives it from `stack_config`
        if not self.use_layer_norm:
            self._node_layer_norms = None
            self._edge_layer_norms = None
        else:
            # one LayerNorm per node type, registered under the node type name
            self._node_layer_norms = nn.ModuleDict(
                {node_type: nn.LayerNorm(latent_dimension) for node_type in in_node_features}
            )
            # ModuleDict keys must be strings, so the edge type tuple is concatenated into a single key
            self._edge_layer_norms = nn.ModuleDict(
                {"".join(edge_type): nn.LayerNorm(latent_dimension) for edge_type in in_edge_features}
            )

        self.reset_parameters()

    def _store_nodes(self, graph: Batch):
        """
        Remembers the current node features "x" of every node store for a later residual connection.

        Args:
            graph: Batch whose node stores are read. A store without "x" is remembered as None.
        """
        remembered = []
        for store in graph.node_stores:
            remembered.append({"x": store.get("x")})
        self._old_graph["node_stores"] = remembered

    def _store_edges(self, graph: Batch):
        """
        Remembers the current edge features "edge_attr" of every edge store for a later residual connection.

        Args:
            graph: Batch whose edge stores are read. A store without "edge_attr" is remembered as None.
        """
        remembered = []
        for store in graph.edge_stores:
            remembered.append({"edge_attr": store.get("edge_attr")})
        self._old_graph["edge_stores"] = remembered

    def _add_node_residual(self, graph: Batch):
        """
        Adds the node features remembered by `_store_nodes` to the current node features, in place on `graph`.

        Stores are matched by their position in `graph.node_stores`.
        """
        for index, store in enumerate(graph.node_stores):
            current = store["x"]
            previous = self._old_graph.get("node_stores")[index]["x"]
            store["x"] = current + previous

    def _add_edge_residual(self, graph: Batch):
        """
        Adds the edge features remembered by `_store_edges` to the current edge features, in place on `graph`.

        Stores are matched by their position in `graph.edge_stores`.
        """
        for index, store in enumerate(graph.edge_stores):
            current = store["edge_attr"]
            previous = self._old_graph.get("edge_stores")[index]["edge_attr"]
            store["edge_attr"] = current + previous

    def _node_layer_norm(self, graph: Batch) -> None:
        """
        Applies the LayerNorm of each node type to the node features "x" of the matching store, in place on `graph`.
        """
        # node_types and node_stores are position-aligned
        for node_type, store in zip(graph.node_types, graph.node_stores):
            layer_norm = self._node_layer_norms[node_type]
            store["x"] = layer_norm(store["x"])

    def _edge_layer_norm(self, graph: Batch) -> None:
        """
        Applies the LayerNorm of each edge type to the edge features "edge_attr" of the matching store, in place
        on `graph`.
        """
        # edge_types and edge_stores are position-aligned; the norm key is the concatenated edge type tuple
        for edge_type, store in zip(graph.edge_types, graph.edge_stores):
            layer_norm = self._edge_layer_norms["".join(edge_type)]
            store["edge_attr"] = layer_norm(store["edge_attr"])
