###############
#   Package   #
###############
import os

import numpy as np
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor
from copy import deepcopy
from typing import Tuple, List

#######################
#   Embedding Layer   #
#######################
class SNE(nn.Module):
    """Embedding layer that scales each value's learned vector by the value.

    For every entry, an embedding vector is looked up by the entry's index and
    multiplied by the entry's value, i.e. ``x_emb[..., k, :] = x[..., k] *
    embedding[x_idx[..., k]]``.

    Args:
        d_model: Size of each embedding vector.
        value_LEN: Number of distinct value indices; must be a positive integer.
        independent_padding: If True, the embedding table gets one extra row at
            index 0 that is used as padding (``padding_idx=0``). Entries whose
            mask is 0 are routed to that row and their value is set to 0.
            If False, the table has exactly ``value_LEN`` rows and the mask is
            ignored.
        max_norm: Forwarded to ``nn.Embedding``. If given, looked-up vectors
            whose norm exceeds it are renormalized.

    Raises:
        ValueError: If ``value_LEN`` is None or not positive.
    """

    def __init__(self,
                 d_model: int = 128,
                 value_LEN: int = None,
                 independent_padding: bool = True,
                 max_norm: float = None,
                 ):
        super().__init__()
        # Each value is connected to its embedding vector by multiplication
        # (value * vector) rather than by concatenation or addition.
        # Explicit check instead of `assert`: asserts are stripped under -O and
        # the original assert never actually raised the intended ValueError.
        if value_LEN is None or value_LEN <= 0:
            raise ValueError("value_LEN should be positive integer.")
        self.value_LEN = value_LEN
        self.independent_padding = independent_padding

        if independent_padding:
            # One extra row for missing values: index 0 is the padding row,
            # which nn.Embedding initializes to zeros and never updates.
            # Presumably real indices are 1..value_LEN here — confirm with caller.
            self.value_embedding = nn.Embedding(num_embeddings=value_LEN + 1,
                                                embedding_dim=d_model,
                                                padding_idx=0,
                                                max_norm=max_norm,
                                                )
        else:
            # No padding row: the table has exactly value_LEN rows.
            self.value_embedding = nn.Embedding(num_embeddings=value_LEN,
                                                embedding_dim=d_model,
                                                max_norm=max_norm,
                                                )

    def forward(self, x_idx: Tensor, x: Tensor, x_mask: Tensor) -> Tensor:
        """Embed each value as ``value * embedding[index]``.

        Args:
            x_idx: Integer index of each value into the embedding table.
                dim = (batch size, summarization times, value_LEN)
            x: The value of each entry.
                dim = (batch size, summarization times, value_LEN)
            x_mask: Missing-value mask; 0 = missing value, non-zero = present.
                Bool, integer or float masks are accepted. Only used when
                ``independent_padding`` is True.
                dim = (batch size, summarization times, value_LEN)

        Returns:
            x_emb: (batch size, summarization times, value_LEN, d_model)
        """
        if self.independent_padding:
            missing = x_mask == 0
            # masked_fill instead of multiplying by the mask:
            #  * the index tensor keeps its integer dtype even for a float mask
            #    (nn.Embedding rejects float indices);
            #  * NaN placeholders at missing positions become 0
            #    (NaN * 0 would still be NaN and poison the output).
            embedding_idx = x_idx.masked_fill(missing, 0)
            embedding_value = x.masked_fill(missing, 0)
        else:
            embedding_idx = x_idx
            embedding_value = x
        x_vec = self.value_embedding(embedding_idx)
        # (..., value_LEN, 1) * (..., value_LEN, d_model) -> (..., value_LEN, d_model)
        return embedding_value.unsqueeze(-1) * x_vec

if __name__ == '__main__':
    # Script entry point: intentionally does nothing when run directly.
    ...
