###############
#   Package   #
###############
import os

import numpy as np
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor
from copy import deepcopy
from typing import Tuple, List

############################
#   Positional Embedding   #
############################
# REF: https://blog.csdn.net/qq_41897800/article/details/114777064
class SinusoidalEmbedding3d(nn.Module):
    """
    Fixed (non-trainable) sinusoidal positional encoding followed by dropout.

    A table of shape (1, max_seq_len, 1, d_model) is precomputed once and
    stored as the buffer `pe`. For position `pos` and frequency index `k`:
        pe[pos, 2k]     = sin(pos * 10000 ** (-2k / d_model))
        pe[pos, 2k + 1] = cos(pos * 10000 ** (-2k / d_model))

    param:
        d_model: size of the last (feature) dimension. Odd values are
            supported; the final sine column then has no cosine partner.
        dropout: dropout probability applied to the output of `forward`.
        max_seq_len: number of positions held in the precomputed table.
    """
    def __init__(self,
                d_model: int,
                dropout: float = 0.1,
                max_seq_len: int = 5000,
                ):
        super(SinusoidalEmbedding3d, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(dim=1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_seq_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # There are only floor(d_model / 2) odd columns, so drop the surplus
        # frequency when d_model is odd (a no-op slice when d_model is even).
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        # Add singleton axes for the batch (dim 0) and the dim-2 axis so the
        # table broadcasts against 4-D inputs.
        pe = pe.unsqueeze(0).unsqueeze(2)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        param:
            x: expected to be (Batch Size, seq_len, num_value, d_model).
                Dimension 1 is the position axis and must not exceed
                max_seq_len; the encoding is broadcast over dims 0 and 2.

        output:
            Tensor of the same shape as x: x plus the positional encoding,
            passed through dropout.
        """
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


###########################
#   Transformer Encoder   #
###########################
class EncoderLayer(nn.Module):
    """
    One post-norm encoder layer: self-attention, then a position-wise
    feed-forward network, each wrapped as norm(residual + dropout(sublayer)).

    param:
        d_model: feature size of the input and output.
        num_heads: number of attention heads.
        ff_dim: hidden width of the feed-forward network.
        ff_dropout: dropout rate used inside and after the feed-forward network.
        attn_dropout: dropout rate used inside and after the attention.
        norm_type: "LayerNorm" (over d_model) or "BatchNorm1d"
            (with seq_len * num_value features).
        seq_len: required despite the None default (asserted).
        num_value: required despite the None default (asserted).
    """
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        ff_dim: int,
        ff_dropout: float = 0.1,
        attn_dropout: float = 0.1,
        norm_type: str = "LayerNorm",
        seq_len: int = None,
        num_value: int = None,
    ):
        super().__init__()
        # Both sizes are mandatory; the None defaults only exist so they can
        # be passed by keyword after the optional arguments.
        assert seq_len is not None, ValueError("seq_len assigned uncorrectly.")
        assert num_value is not None, ValueError("num_value assigned uncorrectly.")
        self.seq_len, self.num_value, self.num_heads = seq_len, num_value, num_heads

        # Self-attention over inputs laid out as (batch, tokens, d_model).
        self.enc_self_attn = nn.MultiheadAttention(
            d_model,
            num_heads,
            dropout=attn_dropout,
            batch_first=True,
        )

        # Position-wise feed-forward network: expand, GELU, dropout, project back.
        feed_forward = [
            nn.Linear(d_model, ff_dim),
            nn.GELU(),
            nn.Dropout(ff_dropout),
            nn.Linear(ff_dim, d_model),
        ]
        self.ff = nn.Sequential(*feed_forward)

        assert norm_type in ["LayerNorm", "BatchNorm1d"], "Wrong Normalization Type."

        def build_norm():
            # LayerNorm normalizes the feature axis; the alternative is a
            # BatchNorm1d with one feature per (seq_len * num_value) slot.
            if norm_type == "LayerNorm":
                return nn.LayerNorm(d_model)
            return nn.BatchNorm1d(self.seq_len * self.num_value)

        self.attn_norm = build_norm()
        self.ff_norm = build_norm()
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.ff_dropout = nn.Dropout(ff_dropout)

    def forward(self,
                src: Tensor,
                src_mask: Tensor,
                src_key_padding_mask: Tensor,
                ):
        """
        param:
            src: input used as query, key and value of the self-attention.
            src_mask: forwarded to the attention as attn_mask (may be None).
            src_key_padding_mask: forwarded to the attention as
                key_padding_mask (may be None).

        output:
            (output, attn_weight): the transformed input and the attention
            weights returned by the attention module.
        """
        attended, attn_weight = self.enc_self_attn(
            src,
            src,
            src,
            key_padding_mask=src_key_padding_mask,
            attn_mask=src_mask,
        )
        # Attention sub-block: residual connection, then normalization.
        hidden = self.attn_norm(src + self.attn_dropout(attended))
        # Feed-forward sub-block: residual connection, then normalization.
        hidden = self.ff_norm(hidden + self.ff_dropout(self.ff(hidden)))
        return hidden, attn_weight

class Encoder(nn.Module):
    """
    Stack of `num_layers` independent deep copies of one encoder layer.

    param:
        encoder_layer: template layer. It must expose `seq_len`, `num_value`
            and `num_heads`, and its forward must accept `src`, `src_mask`
            and `src_key_padding_mask` and return `(output, attn)`.
        num_layers: number of copies to stack. `forward` stacks one attention
            tensor per layer, so at least one layer is needed to run it.
        only_mask_first: if True, only the first layer receives the key
            padding mask; otherwise every layer receives it.
    """
    def __init__(self,
                encoder_layer: "EncoderLayer",
                num_layers: int,
                only_mask_first: bool = False,
                ):
        super(Encoder, self).__init__()
        self.seq_len = encoder_layer.seq_len
        self.num_value = encoder_layer.num_value
        self.num_heads = encoder_layer.num_heads
        self.only_mask_first = only_mask_first

        # Deep copies, so the layers do not share parameters.
        self.layers = nn.ModuleList(
            [deepcopy(encoder_layer) for _ in range(num_layers)]
            )

    def get_key_padding_mask(self, mask: Tensor) -> Tensor:
        """
        param:
            mask: (Batch Size, seq_len, num_value)
                Mask related to missing values.
                NOTE(review): the mask is inverted below, so non-zero entries
                stay attendable and zero entries become padded keys; confirm
                this matches the caller's mask convention.

        output:
            key_padding_mask: (Batch Size, seq_len * num_value)
                Boolean tensor that is True where `mask` is zero/False.
        """
        # reshape (rather than view) also accepts non-contiguous masks,
        # e.g. one produced by a transpose or permute.
        flatten_mask = mask.reshape(mask.size(0), -1)
        return ~flatten_mask.bool()

    def forward(self,
                src: Tensor,
                mask: Tensor
                ) -> Tuple[Tensor, Tensor]:
        """
        param:
            src: input passed to the first layer.
            mask: (Batch Size, seq_len, num_value), see `get_key_padding_mask`.

        output:
            (output, attns_all): the output of the last layer and the
            per-layer attention tensors stacked along a new leading axis.
        """
        src_key_padding_mask = self.get_key_padding_mask(mask) # create key_padding_mask
        attns_all = []

        output = src
        for index, layer in enumerate(self.layers):
            # With only_mask_first, layers after the first run unmasked.
            use_padding_mask = index == 0 or not self.only_mask_first
            output, attn = layer(
                src=output,
                src_mask=None,
                src_key_padding_mask=src_key_padding_mask if use_padding_mask else None,
                )
            attns_all.append(attn)

        attns_all = torch.stack(attns_all)
        return output, attns_all

# Script entry point: intentionally a no-op, so running this module directly
# only defines the classes above.
if __name__ == "__main__":
    pass
