from typing import Optional, Tuple, Any

from torch import nn
from torch_geometric.data.hetero_data import HeteroData

from hmpn.abstract.abstract_input_embedding import AbstractInputEmbedding
from hmpn.common.embedding import Embedding


class Heterogeneous2InputEmbedding(AbstractInputEmbedding):
    """
    Input feature embedding for Heterogeneous Graphs.

    Holds one ``Embedding`` per node type and one per edge type. When
    ``embedding_config["edges_shared_embedding"]`` is truthy, a single
    ``Embedding`` instance is shared by all edge types instead. The forward pass
    replaces ``x`` of every node store and ``edge_attr`` of every edge store in
    place, then delegates to the parent class's ``forward``.
    """

    def __init__(
        self,
        *,
        in_node_features: dict[str, int],
        in_edge_features: dict[Tuple[str, str, str], int],
        in_global_features: Optional[int],
        embedding_config: Optional[dict[str, Any]],
        latent_dimension: int
    ):
        """
        Initializes the heterogeneous input embedding

        Args:
            in_node_features: Keys are node types and values are the number of input features for that node type
            in_edge_features: Keys are edge types and values are the number of input features for that edge type
            in_global_features: Number of input features for the global features
            embedding_config: Embedding configuration dictionary. May be None, in which case
                ``edges_shared_embedding`` is treated as False. It is forwarded unchanged to
                every ``Embedding``.
            latent_dimension: Dimension of the latent space

        Raises:
            ValueError: If ``edges_shared_embedding`` is enabled but the edge types do not
                all have the same number of input features.
        """
        super().__init__(
            in_global_features=in_global_features,
            latent_dimension=latent_dimension,
            embedding_config=embedding_config,
        )

        def get_embed(in_dim):
            """Build an Embedding from ``in_dim`` features to ``latent_dimension``."""
            return Embedding(in_dim, latent_dimension, embedding_config)

        self.node_embed = nn.ModuleDict(
            {name: get_embed(in_dim) for name, in_dim in in_node_features.items()}
        )

        # embedding_config is Optional: read options from an empty dict when it is
        # None instead of failing with AttributeError on None.get(...).
        config = embedding_config or {}
        share_edge_embedding = config.get("edges_shared_embedding", False)

        # nn.ModuleDict only accepts string keys, while edge types are tuples,
        # so edge types are stored under str(edge_type). forward() uses the same key.
        if share_edge_embedding and in_edge_features:
            edge_dims = set(in_edge_features.values())
            # A real exception (not an assert) so the check survives `python -O`.
            if len(edge_dims) != 1:
                raise ValueError(
                    "edges_shared_embedding requires all edge types to have the same "
                    f"number of input features, got {dict(in_edge_features)}"
                )
            shared_edge_embed = get_embed(next(iter(edge_dims)))
            # The same module object is registered under every key, so all edge
            # types share one set of parameters.
            self.edge_embed = nn.ModuleDict(
                {str(name): shared_edge_embed for name in in_edge_features}
            )
        else:
            self.edge_embed = nn.ModuleDict(
                {
                    str(name): get_embed(in_dim)
                    for name, in_dim in in_edge_features.items()
                }
            )

    def forward(self, graph: HeteroData):
        """
        Computes the forward pass for this heterogeneous input embedding
        Args:
            graph: Batch object of pytorch geometric. Represents a batch of heterogeneous graphs

        Returns: None. In-place modification of the graph object.
        """
        # Every node/edge type present in the graph must have been declared in
        # __init__; otherwise the ModuleDict lookup raises KeyError.
        for node_type, node_store in graph.node_items():
            node_store.x = self.node_embed[node_type](node_store.x)
        for edge_type, edge_store in graph.edge_items():
            edge_store.edge_attr = self.edge_embed[str(edge_type)](edge_store.edge_attr)

        super().forward(graph)
