from typing import Dict, Tuple, Any, List, Callable

from torch import nn

from hmpn.abstract.abstract_stack import AbstractStack
from hmpn.common.hmpn_util import count_in_node_features
from hmpn.common.latent_mlp import LatentMLP
from hmpn.heterogeneous.heterogeneous_step import HeterogeneousStep


class HeterogeneousStack(AbstractStack):
    """
    Stack of message passing steps for heterogeneous observation graphs.
    Every step is a HeterogeneousStep that updates node and edge features and receives
    its own freshly built node, edge and (optionally) global MLPs, i.e., no MLP instance
    is reused between steps.
    """

    def __init__(
        self,
        in_node_features: Dict[str, int],
        in_edge_features: Dict[Tuple[str, str, str], int],
        latent_dimension: int,
        scatter_reducers: List[Callable],
        stack_config: Dict[str, Any],
        use_global_features: bool = False,
        flip_edges_for_nodes: bool = False,
    ):
        """
        Creates the per-step MLPs and assembles the stack of HeterogeneousSteps.

        Args:
            in_node_features:
                Mapping {node_type: #node_features} with the input size of each node type of the
                heterogeneous graph. A size of 0 is allowed; an empty input graph of the appropriate
                shape/batch_size is then expected and the initial embeddings are learned constants.
            in_edge_features:
                Mapping {edge_type: #edge_features} with the input size of each edge type of the
                heterogeneous graph. A size of 0 is allowed; an empty input graph of the appropriate
                shape/batch_size is then expected and the initial embeddings are learned constants.
            latent_dimension: Dimensionality of the latent space of the MLPs in the network.
            scatter_reducers:
                List of functions used to aggregate messages for nodes (and potentially global
                information). Must be permutation invariant, e.g., sum, mean, min, std, max.
                Uses the torch.scatter implementation of these functions. The number of reducers
                enters the input sizes of the node and global MLPs.
            stack_config: Dictionary describing how the message passing network base should look like.
                num_steps: how many steps this stack should have
                residual_connections: which kind of residual connections to use. null/None for no
                    connections, "outer" for residuals around each full message passing step,
                    "inner" for residuals after each message
                mlp: configuration handed as `config` to every LatentMLP built here
            use_global_features:
                Whether to use global features. If True, every step gets a global MLP and the
                node and edge MLPs take `latent_dimension` additional input features.
            flip_edges_for_nodes:
                Forwarded unchanged to every HeterogeneousStep.
                NOTE(review): semantics are defined by HeterogeneousStep, not visible here.
        """
        super().__init__(latent_dimension=latent_dimension, stack_config=stack_config)

        # Per node type input size of the node MLPs and the number of edge types,
        # as computed by count_in_node_features from the raw feature sizes.
        n_scatter_ops = len(scatter_reducers)
        node_mlp_inputs, num_edge_types = count_in_node_features(
            in_edge_features=in_edge_features,
            in_node_features=in_node_features,
            latent_dimension=latent_dimension,
            n_scatter_ops=n_scatter_ops,
        )

        mlp_config = stack_config.get("mlp")

        def build_mlp(input_size):
            # All MLPs of the stack share the latent dimension and config, only the input size varies.
            return LatentMLP(
                in_features=input_size,
                latent_dimension=latent_dimension,
                config=mlp_config,
            )

        # The sizes below are identical for every step, so they are determined once up front.
        if use_global_features:
            global_features = latent_dimension
            # One aggregate per (node or edge type, reducer) pair plus one extra latent block.
            # NOTE(review): the extra block is presumably the previous global feature - confirm.
            num_types = len(in_edge_features) + len(node_mlp_inputs)
            global_in_features = latent_dimension * (num_types * n_scatter_ops + 1)
        else:
            global_features = 0
            global_in_features = None

        # Edge MLP input: three latent blocks plus the optional global feature.
        # NOTE(review): presumably edge latent + both adjacent node latents - confirm in HeterogeneousStep.
        edge_mlp_inputs = 3 * latent_dimension + global_features

        self._message_passing_steps = nn.ModuleList([])
        for _ in range(self._num_steps):
            # Modules are created in the order global -> node -> edge -> step on purpose;
            # keep it when editing, since it fixes the order in which modules are instantiated.
            global_mlp = build_mlp(global_in_features) if use_global_features else None

            node_mlps = nn.ModuleDict(
                {
                    node_type: build_mlp(input_size + global_features)
                    for node_type, input_size in node_mlp_inputs.items()
                }
            )

            # Edge types are tuples of strings; their concatenation serves as the ModuleDict key.
            edge_mlps = nn.ModuleDict(
                {
                    "".join(edge_type): build_mlp(edge_mlp_inputs)
                    for edge_type in in_edge_features.keys()
                }
            )

            step = HeterogeneousStep(
                in_node_features=node_mlp_inputs,
                in_edge_features=in_edge_features,
                node_mlps=node_mlps,
                edge_mlps=edge_mlps,
                num_edge_types=num_edge_types,
                stack_config=stack_config,
                latent_dimension=latent_dimension,
                scatter_reducers=scatter_reducers,
                use_global_features=use_global_features,
                flip_edges_for_nodes=flip_edges_for_nodes,
                global_mlp=global_mlp,
            )
            self._message_passing_steps.append(step)
