import math
from typing import Optional

import torch
from torch import nn


class SinusoidalPositionEncoding(nn.Module):
    """Fixed sine/cosine positional encoding concatenated to the input.

    A ``(max_len, d_model)`` table is precomputed once and kept as a
    non-persistent buffer, so it follows the module across devices but is not
    written to the ``state_dict``.
    """

    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.):
        """Positional encoding with sine and cosine.

        Args:
            d_model: Width of the positional encoding attached to each node.
                Both even and odd values are supported.
            max_len: Maximum length of a sequence to expect.
            dropout: Dropout probability applied to the concatenated output.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine columns,
        # so the last frequency is dropped to keep the shapes aligned.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe, persistent=False)

    def forward(self, x):
        """Concatenate the positional encodings to ``x`` and apply dropout.

        Args:
            x: Tensor of shape (batch_size, num_nodes, num_samples, feat_dim),
                or (batch_size, num_nodes, feat_dim) without a sample axis.
                ``num_nodes`` must not exceed ``max_len``.

        Returns:
            Tensor shaped like ``x`` whose last dimension is
            ``feat_dim + d_model``: the encoding of node ``j`` is appended to
            every feature vector of node ``j``.
        """
        # Only the first num_nodes rows are needed; passing the full table
        # would make the concatenation fail unless num_nodes == max_len.
        x = concat_embeddings(x, self.pe[: x.size(1)])
        return self.dropout(x)
    

class LinearPositionalEncoding(nn.Module):
    """Positional encoding with linear embeddings.

    Every position owns its own ``nn.Linear(1, d_model)`` layer; the layer of
    position ``i`` is applied to the ``i``-th row of ``features_id``.

    Args:
        d_model: Hidden dimensionality of the produced embedding.
        max_len: Maximum length of a sequence to expect.
    """

    def __init__(self, d_model: int, max_len: int = 5000) -> None:
        super().__init__()

        self.max_len = max_len

        self.identity_embedding = nn.ModuleList([
            nn.Linear(1, d_model)
            for _ in range(max_len)
        ])

    def forward(self, x: torch.Tensor, features_id: Optional[torch.Tensor] = None):
        """Compute the embeddings and concatenate them to the input.

        Args:
            x: Tensor of shape (batch_size, n_nodes, n_samples, feat_dim), or
                (batch_size, n_nodes, feat_dim) without a sample axis.
                ``n_nodes`` must not exceed ``max_len``.
            features_id: Optional tensor whose row ``i`` is fed to the
                embedding module of position ``i``. When omitted, the column
                of indices ``0 .. n_nodes - 1`` (cast to ``x``'s dtype and
                device) is used. Rows beyond ``n_nodes`` are ignored.

        Returns:
            ``x`` with the embedding of node ``j`` appended to every feature
            vector of node ``j``.
        """
        n_nodes = x.size(1)
        if features_id is None:
            features_id = torch.arange(n_nodes, dtype=x.dtype, device=x.device).unsqueeze(-1)

        # Embed only the positions present in x; zip stops after n_nodes
        # modules. Resulting shape: (n_nodes, d_model).
        embedded_indices = torch.stack([
            embed(feature)
            for feature, embed in zip(features_id[:n_nodes], self.identity_embedding)
        ], dim=0)

        return concat_embeddings(x, embedded_indices)
    


class NonlinearPositionalEncoding(LinearPositionalEncoding):
    """Positional encoding where each position is embedded by its own MLP.

    The per-position linear layers created by the parent constructor are
    replaced with a three-layer perceptron
    (``input_dim -> d_model -> 4 * d_model -> d_model``) that applies
    ``LayerNorm`` after every linear layer and ``GELU`` between layers.

    Args:
        d_model: Hidden dimensionality of the input.
        max_len: Maximum length of a sequence to expect.
        input_dim: The length of the input element to embed (default 1, embedding a scalar).
    """
    def __init__(self, d_model: int, max_len: int = 5000, input_dim: int = 1) -> None:
        super().__init__(d_model, max_len)
        hidden_dim = 4 * d_model

        def build_mlp() -> nn.Sequential:
            # Layer order is kept fixed so parameter names and initialisation
            # order inside each Sequential stay the same.
            layers = [
                nn.Linear(input_dim, d_model),
                nn.LayerNorm(d_model),
                nn.GELU(),
                nn.Linear(d_model, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, d_model),
                nn.LayerNorm(d_model),
            ]
            return nn.Sequential(*layers)

        per_position = []
        for _ in range(max_len):
            per_position.append(build_mlp())
        self.identity_embedding = nn.ModuleList(per_position)



class NonlinearSinusoidalPositionalEncoding(NonlinearPositionalEncoding):
    """Sinusoidal positional encoding with nonlinear transformations embeddings.

    A fixed ``(max_len, d_model)`` sine/cosine table is precomputed and its
    ``i``-th row is passed through the MLP of position ``i`` (built by the
    parent class with ``input_dim=d_model``).

    Args:
        d_model: Hidden dimensionality of the input. Both even and odd values
            are supported.
        max_len: Maximum length of a sequence to expect.
    """
    def __init__(self, d_model: int, max_len: int = 5000) -> None:
        super().__init__(d_model, max_len, input_dim=d_model)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine columns,
        # so the last frequency is dropped to keep the shapes aligned.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Non-persistent: the table is recomputed on construction and is not
        # stored in the state_dict.
        self.register_buffer("pe", pe, persistent=False)

    def forward(self, x: torch.Tensor):
        """Embed the sinusoidal table and concatenate the result to ``x``.

        Args:
            x: Input tensor forwarded unchanged to the parent ``forward``.

        Returns:
            The output of the parent ``forward`` called with the sinusoidal
            table as ``features_id``.
        """
        return super().forward(x, features_id=self.pe)


def concat_embeddings(x: torch.Tensor, embeddings: torch.Tensor) -> torch.Tensor:
    """Concatenate per-node embeddings to the last axis of ``x``.

    Args:
        x: Tensor of shape (batch_size, n_nodes, n_samples, feat_dim), or
            (batch_size, n_nodes, feat_dim) when there is no sample axis.
        embeddings: Tensor of shape (n_nodes, embed_dim). Row ``j`` is
            appended to every feature vector of node ``j``, for every batch
            element and every sample.

    Returns:
        Tensor shaped like ``x`` except that the last dimension is
        ``feat_dim + embed_dim``. An empty sample axis yields an empty result
        of the corresponding shape.
    """
    # Add sample dimension for decoder input
    remove_sample_dim = x.dim() == 3
    if remove_sample_dim:
        x = x.unsqueeze(2)

    # Broadcast the (n_nodes, embed_dim) table over the batch and sample axes
    # as a view; the single cat below materialises the result in one pass.
    pe = embeddings[None, :, None, :].expand(x.size(0), -1, x.size(2), -1)
    x = torch.cat([x, pe], dim=-1)

    # Remove sample dimension for decoder output
    if remove_sample_dim:
        x = x.squeeze(2)

    return x