from typing import Optional
import torch
import torch.nn as nn
from pado.core import PadoModuleMixin

# Names exported by ``from <this module> import *``.
__all__ = ["Dropout", "SeqDropout"]


class Dropout(nn.Dropout, PadoModuleMixin):
    """
    Dropout layer combining ``nn.Dropout`` with ``PadoModuleMixin``.

    :param drop_prob:   forwarded to ``nn.Dropout`` as ``p``
    :param inplace:     forwarded to ``nn.Dropout`` as ``inplace``
    """

    def __init__(self, drop_prob: float = 0.5, inplace: bool = False) -> None:
        # Both bases are initialised explicitly: dropout first, then the mixin.
        nn.Dropout.__init__(self, drop_prob, inplace)
        PadoModuleMixin.__init__(self)


class SeqDropout(nn.Dropout, PadoModuleMixin):
    """
    Dropout whose randomly generated mask is shared along one axis of the input.

    The generated mask has size 1 along ``dim`` and the input's size on every
    other axis, so broadcasting applies the same keep/drop decision to every
    position along ``dim``.

    :param drop_prob:   probability of zeroing an element (``nn.Dropout``'s ``p``)
    :param inplace:     if True, ``forward`` multiplies the input tensor in place
    :param dim:         axis along which the mask is shared (default 1)
    """

    def __init__(self,
                 drop_prob: float = 0.5,
                 inplace: bool = False,
                 dim: int = 1) -> None:
        nn.Dropout.__init__(self, p=drop_prob, inplace=inplace)
        PadoModuleMixin.__init__(self)
        self.dim = dim

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Multiply ``x`` by a dropout mask shared along ``self.dim``.

        In eval mode, or when ``self.p <= 0``, ``x`` is returned untouched and
        ``mask`` is ignored.

        :param x:       input tensor, e.g. (batch_size, seq_length, ...) with dim=1
        :param mask:    optional tensor multiplied with ``x`` INSTEAD of a freshly
                        generated dropout mask; it is used as given, so it must
                        broadcast against ``x``.
                        NOTE(review): earlier documentation described this as a
                        (batch_size, seq_length) bool padding mask (T: valid,
                        F: pad); the code neither reshapes it nor combines it
                        with a dropout mask -- confirm the intended contract
                        with callers.
        :return:
                y:      tensor with the same shape as ``x`` (``x`` itself when
                        ``inplace`` is set or when dropout is inactive)
        """
        if (not self.training) or (self.p <= 0):
            return x

        if mask is None:
            mask = self.generate_mask(x)

        if self.inplace:
            x *= mask
        else:
            x = x * mask
        return x

    @torch.no_grad()
    def generate_mask(self, x: torch.Tensor) -> torch.Tensor:
        """
        Sample a scaled keep-mask for ``x`` that is shared along ``self.dim``.

        Entries are ``1 / (1 - p)`` with probability ``1 - p`` and 0 otherwise.
        When ``p == 1`` an all-zero mask is returned.

        :param x:       tensor providing shape, device and dtype for the mask
        :return:
                mask:   same shape as ``x`` except size 1 along ``self.dim``
        """
        mask_shape = list(x.shape)  # e.g. (B, S, D)
        mask_shape[self.dim] = 1  # e.g. (B, 1, D) for dim=1

        keep_p = 1.0 - self.p
        if keep_p <= 0.0:
            # p == 1: everything is dropped. Dividing the all-zero sample by
            # keep_p == 0 would yield NaN, so return zeros directly.
            return torch.zeros(mask_shape, device=x.device, dtype=x.dtype)

        mask = torch.empty(mask_shape, device=x.device, dtype=x.dtype).bernoulli_(keep_p)
        # Rescale kept entries so the expected value of the output matches the input.
        return mask.div_(keep_p)
