from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F

from pado.core import PadoModule
from pado.nn.parameter import ParameterModule

__all__ = ["Embedding"]


class Embedding(PadoModule):
    """Embedding lookup table with optional word-level dropout.

    The lookup itself is delegated to ``F.embedding``; ``padding_idx``,
    ``max_norm``, ``norm_type``, ``scale_grad_by_freq`` and ``sparse`` are
    forwarded to it unchanged.

    Word dropout: while ``self.training`` is true and ``word_drop_prob > 0``,
    each row of the weight table is zeroed independently with probability
    ``word_drop_prob`` before the lookup. A dropped row is therefore zero for
    every occurrence of that index in the input, not per position.

    Args:
        num_embeddings: Number of rows in the weight table.
        embedding_dim: Number of columns in the weight table (embedding size).
        padding_idx: Optional padding row. Negative values are counted from the
            end (``padding_idx + num_embeddings``). The row is zero-initialised
            and forwarded to ``F.embedding``.
        max_norm: Forwarded to ``F.embedding``.
        norm_type: Forwarded to ``F.embedding``.
        word_drop_prob: Probability in ``[0, 1]`` of dropping each row while
            training. ``0`` disables word dropout.
        scale_grad_by_freq: Forwarded to ``F.embedding``.
        sparse: Forwarded to ``F.embedding``.
        word_drop_rescale: If true, surviving rows are divided by
            ``1 - word_drop_prob`` (inverted dropout), so the expected
            training-time embedding matches evaluation. Defaults to ``False``,
            which keeps the historical unscaled behaviour.

    Raises:
        ValueError: If ``padding_idx`` is outside
            ``[-num_embeddings, num_embeddings)`` or ``word_drop_prob`` is
            outside ``[0, 1]``.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 padding_idx: Optional[int] = None,
                 max_norm: Optional[float] = None,
                 norm_type: float = 2.0,
                 *, word_drop_prob: float = 0.0,
                 scale_grad_by_freq: bool = False,
                 sparse: bool = False,
                 word_drop_rescale: bool = False) -> None:
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        if padding_idx is not None:
            # Allow Python-style negative indexing into the table.
            if padding_idx < 0:
                padding_idx += num_embeddings
            if not (0 <= padding_idx < self.num_embeddings):
                raise ValueError("Padding index is not in range.")
        # Fail fast: a negative probability would otherwise silently disable
        # dropout, and one above 1 (or NaN) would only fail inside forward().
        if not (0.0 <= word_drop_prob <= 1.0):
            raise ValueError(f"word_drop_prob must be in [0, 1], got {word_drop_prob}.")
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.sparse = sparse
        self.word_drop_prob = word_drop_prob
        self.word_drop_rescale = word_drop_rescale

        self.weight = ParameterModule(torch.empty(num_embeddings, embedding_dim))

        self._initialize_parameters()

    def _initialize_parameters(self):
        """Fill the table with standard-normal samples and zero the padding row."""
        nn.init.normal_(self.weight.data)
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight.data[self.padding_idx].fill_(0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Look up the embeddings for the integer indices in ``x``.

        Word dropout (when enabled and in training mode) is applied to the
        whole table before the lookup.
        """
        weight = self._embedding_word_dropout(self.weight(), self.training,
                                              self.word_drop_prob, self.word_drop_rescale)
        # NOTE(review): F.embedding renormalises rows in place when max_norm is
        # set. With word dropout active, `weight` is a fresh masked tensor, so
        # that renormalisation does not reach the tensor returned by
        # self.weight(). Confirm whether this is intended.
        return F.embedding(x, weight, self.padding_idx, self.max_norm,
                           self.norm_type, self.scale_grad_by_freq, self.sparse)

    @staticmethod
    def _embedding_word_dropout(weight: torch.Tensor,
                                training: bool,
                                word_drop_prob: float,
                                rescale: bool = False) -> torch.Tensor:
        """Zero whole rows of ``weight`` with probability ``word_drop_prob``.

        Args:
            weight: Embedding table; dim 0 indexes the rows that are dropped.
            training: Dropout is applied only when true.
            word_drop_prob: Per-row drop probability.
            rescale: If true, divide surviving rows by ``1 - word_drop_prob``.
                This is skipped when that is 0, because every row is dropped
                anyway.

        Returns:
            ``weight`` unchanged when dropout is inactive; otherwise a new
            masked tensor. The input is never modified in place.
        """
        if not (training and word_drop_prob > 0):
            return weight
        keep_p = 1.0 - word_drop_prob
        # One Bernoulli(keep_p) draw per row, shaped (rows, 1) so it broadcasts
        # across the embedding dimension. Same dtype/device as `weight`.
        word_mask = weight.new_empty((weight.shape[0], 1)).bernoulli_(keep_p)
        if rescale and keep_p > 0:
            word_mask = word_mask / keep_p
        return weight * word_mask
