from typing import Dict, Optional, Any

from torch_geometric.data.data import Data

from hmpn.abstract.abstract_input_embedding import AbstractInputEmbedding
from hmpn.common.embedding import Embedding


class HomogeneousInputEmbedding(AbstractInputEmbedding):
    """
    Input embedding for homogeneous graphs. Holds one Embedding for the node features, one for the
    edge features and, if global features are used, one for the global features.
    """

    def __init__(self,
                 *,
                 in_node_features: int,
                 in_edge_features: int,
                 in_global_features: Optional[int],
                 embedding_config: Optional[Dict[str, Any]],
                 latent_dimension: int):
        """
        Sets up the node, edge and (optional) global input embeddings of a homogeneous graph.
        Args:
            in_node_features:
                number of input node features
            in_edge_features:
                number of input edge features
            in_global_features:
                number of input global features. None if no global features are used.
            embedding_config:
                configuration that is handed unchanged to every Embedding created here and to the parent class.
            latent_dimension:
                dimension of the latent space.
        """
        super().__init__(in_global_features=in_global_features,
                         latent_dimension=latent_dimension,
                         embedding_config=embedding_config)

        # All embeddings share the output dimension and the configuration; only the input size differs.
        shared_kwargs = dict(latent_dimension=latent_dimension, embedding_config=embedding_config)

        self.node_input_embedding = Embedding(in_features=in_node_features, **shared_kwargs)
        self.edge_input_embedding = Embedding(in_features=in_edge_features, **shared_kwargs)

        # The global embedding is only built when global features are used.
        if in_global_features is None:
            return
        self.global_input_embedding = Embedding(in_features=in_global_features, **shared_kwargs)

    def forward(self, graph: Data):
        """
        Embeds the node and edge features of the given graph inplace and then delegates to the parent class.
        Args:
            graph: torch_geometric.data.Batch, represents a batch of homogeneous graphs
        Returns:
            None
        """
        # Overwrite the raw node features with their embedding before touching the edges.
        embedded_nodes = self.node_input_embedding(graph.x)
        setattr(graph, "x", embedded_nodes)

        embedded_edges = self.edge_input_embedding(graph.edge_attr)
        setattr(graph, "edge_attr", embedded_edges)

        # Remaining processing is done by the parent class on the same graph object.
        super().forward(graph=graph)
