from typing import Any, Callable

from torch import nn

from hmpn.abstract.abstract_stack import AbstractStack
from hmpn.heterogeneous2.heterogeneous2_step import Heterogeneous2Step


class Heterogeneous2Stack(AbstractStack):
    """
    Message Passing module that acts on both node and edge features.
    Internally stacks multiple Heterogeneous2Step instances, one per message passing step.
    This implementation is used for heterogeneous observation graphs.
    """

    def __init__(
        self,
        latent_dimension: int,
        scatter_reducers: list[Callable],
        stack_config: dict[str, Any],
        use_global_features: bool = False,
        flip_edges_for_nodes: bool = False,
    ):
        """
        Builds a heterogeneous message passing stack.

        Args:
            latent_dimension: Dimensionality of the latent space of the MLPs in the network.
            scatter_reducers: List of functions used to aggregate messages for nodes (and potentially global
                information). Must be permutation invariant. Examples include sum, mean, min, std, max.
                Uses the torch.scatter implementation of these functions.
            stack_config: Dictionary specifying what the message passing network base should look like.
                The full dictionary is also forwarded to every Heterogeneous2Step.
                num_steps: How many steps this stack should have.
                residual_connections: Which kind of residual connections to use. null/None for no connections,
                    "outer" for residuals around each full message passing step, "inner" for residuals inside
                    each step (presumably after each message/update; see Heterogeneous2Step).
                active_edges: Optional mapping from edge name to a collection of step indices in which that edge
                    type is used. A value containing -1 activates the edge type in every step. If the key is
                    missing, every step receives an empty list of active edges.
            use_global_features: Whether to use global features. Forwarded to each Heterogeneous2Step.
            flip_edges_for_nodes: Flag forwarded unchanged to each Heterogeneous2Step.
        """
        super().__init__(latent_dimension=latent_dimension, stack_config=stack_config)

        # Read the per-edge step schedule once instead of once per step.
        edge_schedule = stack_config.get("active_edges", {})

        self._message_passing_steps = nn.ModuleList([])
        # NOTE(review): self._num_steps is assumed to be set by AbstractStack.__init__ from stack_config["num_steps"].
        for step in range(self._num_steps):
            # Each active edge type connects node type "x" to itself, giving ("x", edge_name, "x") triples.
            # An edge type is active if its schedule lists this step, or -1 (meaning "every step").
            active_edges = [
                ("x", edge_name, "x")
                for edge_name, active_steps in edge_schedule.items()
                if step in active_steps or -1 in active_steps
            ]
            self._message_passing_steps.append(
                Heterogeneous2Step(
                    stack_config=stack_config,
                    latent_dimension=latent_dimension,
                    scatter_reducers=scatter_reducers,
                    use_global_features=use_global_features,
                    flip_edges_for_nodes=flip_edges_for_nodes,
                    active_edges=active_edges,
                )
            )
