from typing import Dict, Optional, Tuple, Any

from torch import nn
from torch_geometric.data.batch import Batch

from hmpn.abstract.abstract_input_embedding import AbstractInputEmbedding
from hmpn.common.embedding import Embedding
from hmpn.common.hmpn_util import tuple_to_string


class HeterogeneousInputEmbedding(AbstractInputEmbedding):
    """
    Input embedding for heterogeneous graphs.

    Holds one Embedding per node type and one per edge type. Every embedding is built with
    the same latent dimension and the same embedding configuration.
    """

    def __init__(
        self,
        *,
        in_node_features: Dict[str, int],
        in_edge_features: Dict[Tuple[str, str, str], int],
        in_global_features: Optional[int],
        embedding_config: Optional[Dict[str, Any]],
        latent_dimension: int
    ):
        """
        Builds the per-type embeddings for nodes and edges.

        Args:
            in_node_features: Maps each node type to the number of input features of that node type
            in_edge_features: Maps each edge type to the number of input features of that edge type
            in_global_features: Number of input features for the global features. Forwarded to the parent class
            embedding_config: Embedding configuration dictionary, shared by all embeddings
            latent_dimension: Dimension of the latent space, shared by all embeddings
        """
        super().__init__(
            in_global_features=in_global_features,
            latent_dimension=latent_dimension,
            embedding_config=embedding_config,
        )

        def build_embedding(in_features: int) -> Embedding:
            # Only the number of input features differs between the individual embeddings.
            return Embedding(
                in_features=in_features,
                latent_dimension=latent_dimension,
                embedding_config=embedding_config,
            )

        # Node embeddings are created first, edge embeddings second, each in dictionary order.
        node_embeddings = {}
        for node_type, feature_count in in_node_features.items():
            node_embeddings[node_type] = build_embedding(feature_count)
        self.node_input_embeddings = nn.ModuleDict(node_embeddings)

        edge_embeddings = {}
        for edge_type, feature_count in in_edge_features.items():
            # nn.ModuleDict only accepts string keys, so the edge type tuple is converted first.
            edge_key = tuple_to_string(edge_type)
            edge_embeddings[edge_key] = build_embedding(feature_count)
        self.edge_input_embeddings = nn.ModuleDict(edge_embeddings)

    def forward(self, graph: Batch):
        """
        Embeds the node and edge features of the given graph batch.

        The "x" entry of every node store and the "edge_attr" entry of every edge store is
        overwritten with the output of the embedding registered for the respective type.
        Afterwards, the forward pass of the parent class is called on the same graph.

        Args:
            graph: Batch object of pytorch geometric. Represents a batch of heterogeneous graphs

        Returns: None. In-place modification of the graph object.
        """
        # Types and stores are paired by their position in the graph.
        for node_type, node_store in zip(graph.node_types, graph.node_stores):
            embed_nodes = self.node_input_embeddings[node_type]
            node_store["x"] = embed_nodes(node_store["x"])

        for edge_type, edge_store in zip(graph.edge_types, graph.edge_stores):
            embed_edges = self.edge_input_embeddings[tuple_to_string(edge_type)]
            edge_store["edge_attr"] = embed_edges(edge_store["edge_attr"])

        # NOTE(review): presumably the parent handles the global features - confirm in AbstractInputEmbedding.
        super().forward(graph=graph)