from typing import Optional, Sequence, Tuple

from torch import Tensor, nn

from meta_diffusion.model.embedder.base import MetaDiffEmbedder
from meta_diffusion.model.embedder.time import timestep_embedding
from meta_diffusion.model.embedder.utils import (
    ScalarEmbeddingSine1D, ScalarEmbeddingSine2D, ScalarEmbeddingSine3D
)


class MISEmbedder(MetaDiffEmbedder):
    """Embedder for MIS graph inputs (MIS presumably = Maximum Independent Set).

    Lifts scalar node values ``x``, scalar edge features and, optionally, the
    diffusion timestep ``t`` into ``hidden_dim``-sized vectors. The node and
    edge embedders are each a ``ScalarEmbeddingSine*`` module followed by a
    ``Linear(hidden_dim, hidden_dim)`` layer.

    Which ``ScalarEmbeddingSine*`` variant is used depends on ``sparse``:
    1D/1D for sparse inputs, and 2D (nodes) / 3D (edges) for dense inputs.
    """

    def __init__(self, hidden_dim: int, sparse: bool, time_flag: bool):
        super().__init__(hidden_dim, sparse, time_flag)

        # ``self.sparse`` is read from the base class, which presumably stores
        # the ``sparse`` argument. Per the forward docstrings, sparse inputs
        # are flat (V,) / (E,) tensors, while dense inputs are batched (B, V)
        # node and (B, V, V) edge tensors. Hence the different variants here.
        if self.sparse:
            node_cls, edge_cls = ScalarEmbeddingSine1D, ScalarEmbeddingSine1D
        else:
            node_cls, edge_cls = ScalarEmbeddingSine2D, ScalarEmbeddingSine3D

        # Build node first, then edge. This keeps module registration order
        # (and therefore parameter init order and state_dict keys such as
        # ``node_embed.1.weight``) the same as the original inline construction.
        self.node_embed = self._build_embed(node_cls, hidden_dim)
        self.edge_embed = self._build_embed(edge_cls, hidden_dim)

    @staticmethod
    def _build_embed(scalar_embed_cls: type, hidden_dim: int) -> nn.Sequential:
        """Return ``Sequential(scalar_embed_cls(hidden_dim), Linear(hidden_dim, hidden_dim))``."""
        return nn.Sequential(
            scalar_embed_cls(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def _embed(
        self, x: Tensor, edges_feature: Tensor, t: Optional[Tensor]
    ) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
        """Shared computation for ``sparse_forward`` and ``dense_forward``.

        The sparse and dense layouts differ only in which scalar-embedding
        modules ``__init__`` built, so the forward pass itself is identical.

        Args:
            x: node values to embed with ``self.node_embed``.
            edges_feature: edge features to embed with ``self.edge_embed``.
            t: timestep tensor, or ``None`` to skip the time embedding.
        Return:
            ``(node_embedding, edge_embedding, time_embedding_or_None)``.
        """
        x = self.node_embed(x)
        e = self.edge_embed(edges_feature)
        if t is not None:
            # ``self.time_embedder`` and ``self.hidden_dim`` come from the base
            # class, which is not visible here.
            # NOTE(review): the original comments give this result's shape as
            # (H,). If ``timestep_embedding`` returns (N, H) for an (N,)
            # input, it is actually (1, H) -- TODO confirm.
            t = self.time_embedder(timestep_embedding(t, self.hidden_dim))
        return x, e, t

    def sparse_forward(
        self, nodes_feature: Tensor, x: Tensor, edges_feature: Tensor,
        e: Tensor, t: Optional[Tensor]
    ) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
        """Embed a graph given in sparse layout.

        Args:
            nodes_feature: (V,). Unused; accepted for interface compatibility.
            x: (V,) node values to embed.
            edges_feature: (E,) edge features to embed.
            e: (E,). Unused; edges are embedded from ``edges_feature``.
            t: (1,) timestep, or ``None`` to skip the time embedding.
        Return:
            x: (V, H)
            e: (E, H)
            t: time embedding, or ``None`` if ``t`` was ``None``.
        """
        return self._embed(x, edges_feature, t)

    def dense_forward(
        self, nodes_feature: Tensor, x: Tensor, edges_feature: Tensor,
        e: Tensor, t: Optional[Tensor]
    ) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
        """Embed a batch of graphs given in dense layout.

        Args:
            nodes_feature: (B, V). Unused; accepted for interface compatibility.
            x: (B, V) node values to embed.
            edges_feature: (B, V, V) edge features to embed.
            e: (B, V, V). Unused; edges are embedded from ``edges_feature``.
            t: (1,) timestep, or ``None`` to skip the time embedding.
        Return:
            x: (B, V, H)
            e: (B, V, V, H)
            t: time embedding, or ``None`` if ``t`` was ``None``.
        """
        return self._embed(x, edges_feature, t)