import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

# pylint: disable=no-member


class VanillaTransformerEmbeddings(nn.Module):
    """Embed token ids as the sum of a word embedding and a learned position embedding.

    ``config`` is read for ``vocab_size``, ``hidden_size``, ``pad_token_id``,
    ``max_position_embeddings`` and ``hidden_dropout_prob``. Dropout is applied
    to the summed embeddings.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        hidden = config.hidden_size
        self.word_embeddings = nn.Embedding(
            config.vocab_size,
            hidden,
            padding_idx=config.pad_token_id,
        )
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, hidden)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Overwrite nn.Embedding's own initialisation with the scaled one below.
        self.reset_parameters()

    def forward(self, input_ids=None, position_ids=None):
        """Return the dropped-out sum of word and position embeddings.

        Args:
            input_ids: integer tensor of token ids; its second dimension is
                used as the sequence length.
            position_ids: optional integer tensor of position indices. When
                omitted, positions ``0 .. seq_length - 1`` are used for every
                row of ``input_ids``.
        """
        ids_shape = input_ids.size()
        n_positions = ids_shape[1]
        target_device = input_ids.device

        if position_ids is None:
            # One row of consecutive positions, viewed (not copied) across the batch.
            arange_row = torch.arange(n_positions, dtype=torch.long, device=target_device)
            position_ids = arange_row.unsqueeze(0).expand(ids_shape)

        token_vectors = self.word_embeddings(input_ids)
        position_vectors = self.position_embeddings(position_ids)
        return self.dropout(token_vectors + position_vectors)

    def get_word_embedding(self) -> torch.Tensor:
        """Return the word embedding weight matrix (the parameter itself, not a copy)."""
        return self.word_embeddings.weight

    def reset_parameters(self):
        """Initialise both tables from N(0, hidden_size**-0.5) and zero the padding row."""
        scale = self.config.hidden_size**-0.5
        # Word table first, then position table: the order fixes the RNG draws.
        for table in (self.word_embeddings, self.position_embeddings):
            nn.init.normal_(table.weight, mean=0, std=scale)

        pad_id = self.config.pad_token_id
        if pad_id is not None:
            # The normal_ call above also overwrote the padding row; reset it to zero.
            nn.init.constant_(self.word_embeddings.weight[pad_id], 0)


class SinusoidalPositionalEmbedding(nn.Module):
    """Fixed (non-learned) sinusoidal position encoding.

    For each position ``p`` in ``pos_seq`` and each inverse frequency ``f_i =
    1 / 10000**(2i / demb)``, the encoding holds ``sin(p * f_i)`` and
    ``cos(p * f_i)``. All sine components come first, followed by all cosine
    components (the two halves are concatenated, not interleaved).

    Args:
        demb: embedding width used to build the frequency table. The output's
            last dimension is ``2 * len(inv_freq)``, which equals ``demb``
            when ``demb`` is even.
    """

    def __init__(self, demb):
        super().__init__()
        self.demb = demb

        # One inverse frequency per (sin, cos) pair.
        inv_freq = 1 / (10000**(torch.arange(0.0, demb, 2.0) / demb))
        # Registered as a buffer so it moves with the module but is not a trainable parameter.
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, pos_seq, bsz=None):
        """Encode a 1-D tensor of positions.

        Args:
            pos_seq: 1-D tensor of positions (length ``L``).
            bsz: optional batch size. When given, the result is expanded
                (as a view, without copying) along the second dimension.

        Returns:
            Tensor of shape ``(L, 1, D)`` when ``bsz`` is None, otherwise
            ``(L, bsz, D)``, with ``D = 2 * len(self.inv_freq)``.
        """
        # torch.outer is the supported name; torch.ger is its deprecated alias.
        sinusoid_inp = torch.outer(pos_seq, self.inv_freq)
        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)

        # Insert the batch axis in the middle: (L, D) -> (L, 1, D).
        pos_emb = pos_emb[:, None, :]
        if bsz is not None:
            return pos_emb.expand(-1, bsz, -1)
        return pos_emb
