from typing import Dict, Any, List, Callable

from torch import nn

from hmpn.abstract.abstract_stack import AbstractStack
from hmpn.homogeneous.homogeneous_step import HomogeneousStep


class HomogeneousStack(AbstractStack):
    """
    Message Passing module that acts on both node and edge features.
    Internally stacks multiple instances of MessagePassingSteps.
    This implementation is used for homogeneous observation graphs.
    """

    def __init__(self,
                 stack_config: Dict[str, Any],
                 latent_dimension: int,
                 scatter_reducers: List[Callable],
                 use_global_features: bool = False,
                 flip_edges_for_nodes: bool = False):
        """
        Args:
            stack_config: Dictionary specifying the way that the message passing network base should look like.
                num_steps: how many steps this stack should have
                residual_connections: which kind of residual connections to use. null/None for no connections,
                "outer" for residuals around each full message passing step, "inner" for residuals after each message
            latent_dimension: the latent dimension of all vectors used in this stack
            scatter_reducers: functions of torch_scatter: min,max,mean,std,etc, as a list of functions
            use_global_features: whether to use global features
        """
        super().__init__(latent_dimension=latent_dimension, stack_config=stack_config)
        self._message_passing_steps = nn.ModuleList([HomogeneousStep(stack_config=stack_config,
                                                                     latent_dimension=latent_dimension,
                                                                     scatter_reducers=scatter_reducers,
                                                                     use_global_features=use_global_features,
                                                                     flip_edges_for_nodes=flip_edges_for_nodes)
                                                     for _ in range(self._num_steps)])
