import torch
import torch.nn.functional as F
from typing import Sequence
from torch import Tensor, nn
from meta_diffusion.model.embedder.base import MetaDiffEmbedder
from meta_diffusion.model.embedder.time import timestep_embedding
from meta_diffusion.model.embedder.utils import (
    ScalarEmbeddingSine1D, ScalarEmbeddingSine3D, PositionEmbeddingSine
)


class ATSPEmbedder(MetaDiffEmbedder):
    """Embedder producing node, distance, edge-state and time embeddings.

    The edge state ``e`` and the raw edge feature ``edges_feature`` are first
    fused into a single scalar per edge by a small two-layer MLP
    (2 -> ``hidden_dim`` -> 1, ReLU in between) whose weights are the
    ``mix1_*`` / ``mix2_*`` parameters.  The fused scalar and the raw edge
    feature are then embedded separately by ``edge_embed_e`` and
    ``edge_embed_d``.

    ``self.sparse``, ``self.hidden_dim`` and ``self.time_embedder`` are read
    here but not assigned in this class; presumably they are set up by
    ``MetaDiffEmbedder.__init__`` -- TODO confirm against the base class.
    """

    def __init__(self, hidden_dim: int, sparse: bool, time_flag: bool):
        """Build the sub-embedders and the score-mixing parameters.

        Args:
            hidden_dim: width of every produced embedding.
            sparse: forwarded to the base class; selects the 1D scalar
                embedding (sparse) or the 3D one (dense) for edges.
            time_flag: forwarded unchanged to the base class.
        """
        super().__init__(hidden_dim, sparse, time_flag)

        # The sparse and dense variants differ only in the scalar embedding
        # applied to edge values.
        scalar_embed_cls = (
            ScalarEmbeddingSine1D if self.sparse else ScalarEmbeddingSine3D
        )

        # Creation order (node, edge_e, edge_d, mix params) is kept so that
        # seeded initialisation and state_dict layout stay unchanged.
        # node embedder
        self.node_embed = nn.Sequential(
            PositionEmbeddingSine(hidden_dim // 2),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # edge embedder for the fused (mixed) edge score
        self.edge_embed_e = nn.Sequential(
            scalar_embed_cls(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # edge embedder for the raw edge feature
        self.edge_embed_d = nn.Sequential(
            scalar_embed_cls(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Uniform init bounds: sqrt(1/2) for the first layer, sqrt(1/16) for
        # the second.
        # NOTE(review): these look like 1/sqrt(fan_in) with fan_in 2 and 16;
        # the second does not track hidden_dim -- confirm this is intended.
        mix1_init = (1 / 2) ** (1 / 2)
        mix2_init = (1 / 16) ** (1 / 2)

        mix1_dist = torch.distributions.Uniform(low=-mix1_init, high=mix1_init)
        self.mix1_weight = nn.Parameter(mix1_dist.sample((2, hidden_dim)))
        self.mix1_bias = nn.Parameter(mix1_dist.sample((hidden_dim,)))

        mix2_dist = torch.distributions.Uniform(low=-mix2_init, high=mix2_init)
        self.mix2_weight = nn.Parameter(mix2_dist.sample((hidden_dim, 1)))
        self.mix2_bias = nn.Parameter(mix2_dist.sample((1,)))

    def _mix_scores(self, e: Tensor, edges_feature: Tensor) -> Tensor:
        """Fuse two same-shaped edge tensors into one tensor of that shape.

        Works for any rank: the pair is stacked on a new trailing axis and
        the biases rely on ordinary broadcasting instead of hand-counted
        ``None`` axes.
        """
        two_scores = torch.stack([e, edges_feature], dim=-1)  # (..., 2)
        ms1 = torch.matmul(two_scores, self.mix1_weight)  # (..., H)
        ms1 = ms1 + self.mix1_bias  # (..., H)
        ms1_activated = F.relu(ms1)  # (..., H)
        ms2 = torch.matmul(ms1_activated, self.mix2_weight)  # (..., 1)
        ms2 = ms2 + self.mix2_bias  # (..., 1)
        return ms2.squeeze(-1)  # (...)

    def _embed_time(self, t: Tensor) -> Tensor:
        """Embed the timestep, passing ``None`` through unchanged."""
        if t is None:
            return None
        return self.time_embedder(timestep_embedding(t, self.hidden_dim))

    def sparse_forward(
        self, nodes_feature: Tensor, x: Tensor, edges_feature: Tensor,
        e: Tensor, t: Tensor
    ) -> Sequence[Tensor]:
        """
        Args:
            nodes_feature: (V, 2)
            x: (V,) [not use]
            edges_feature: (E,)
            e: (E,)
            t: (1,) or None
        Return:
            x: (V, H) node embedding
            d: embedding of ``edges_feature``
            e: embedding of the mixed edge score, (E, H)
            t: (H) time embedding, or None when ``t`` is None
        """
        # node embedding
        x = self.node_embed(nodes_feature)

        # Align ``e`` with ``edges_feature`` before mixing.  The previous
        # code unsqueezed ``e`` to (E, 1) and stacked it with the (E,)
        # ``edges_feature``, which torch.stack rejects.  Reshaping is a no-op
        # for the documented (E,)/(E,) inputs and still accepts callers that
        # pass ``edges_feature`` as (E, 1).
        e = e.reshape(edges_feature.shape)
        e = self._mix_scores(e, edges_feature)  # same shape as edges_feature

        # edge embeddings
        d = self.edge_embed_d(edges_feature)
        e = self.edge_embed_e(e)  # (E, H)

        t = self._embed_time(t)  # (H,)
        return x, d, e, t

    def dense_forward(
        self, nodes_feature: Tensor, x: Tensor, edges_feature: Tensor,
        e: Tensor, t: Tensor
    ) -> Sequence[Tensor]:
        """
        Args:
            nodes_feature: (B, V, 2)
            x: (B, V) [not use]
            edges_feature: (B, V, V)
            e: (B, V, V)
            t: (1,) or None
        Return:
            x: node embedding of ``nodes_feature``
            d: embedding of ``edges_feature``
            e: embedding of the mixed edge score, (B, V, V, H)
            t: (H) time embedding, or None when ``t`` is None
        """
        # node embedding
        x = self.node_embed(nodes_feature)

        # fuse current edge state and raw edge feature into one score
        e = self._mix_scores(e, edges_feature)  # (B, V, V)

        # edge embeddings
        d = self.edge_embed_d(edges_feature)
        e = self.edge_embed_e(e)  # (B, V, V, H)

        t = self._embed_time(t)  # (H,)
        return x, d, e, t