import torch
import torch.nn as nn
from functools import partialmethod
from typing import Union, List


class Dropout(nn.Module):
    """Dropout whose mask is shared (broadcast) along selected dimensions.

    A mask of ones is built with the same shape as the input, except that
    every dimension listed in ``batch_dim`` is collapsed to size 1. The mask
    is passed through ``nn.Dropout`` and then multiplied into the input, so
    all entries along the collapsed dimensions are kept or dropped together.
    """

    def __init__(self, r: float, batch_dim: Union[int, List[int], None]):
        """
        Args:
            r: Dropout probability, forwarded to ``nn.Dropout``.
            batch_dim: Dimension index, or sequence of indices, along which
                the mask is shared. ``None`` disables sharing, in which case
                the mask has the full shape of the input.
        """
        super().__init__()

        self.r = r
        if batch_dim is not None:
            if isinstance(batch_dim, int):
                batch_dim = [batch_dim]
            else:
                # Copy so later changes to the caller's sequence cannot
                # silently alter this module's configuration.
                batch_dim = list(batch_dim)
        self.batch_dim = batch_dim
        self.dropout = nn.Dropout(self.r)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the shared-mask dropout to ``x``.

        Args:
            x: Input tensor. It is never modified in place.

        Returns:
            ``x`` itself when the inner dropout is in eval mode (the mask
            would be all ones); otherwise a new tensor equal to ``x``
            multiplied by the broadcast dropout mask.
        """
        # In eval mode nn.Dropout is the identity, so the mask is all ones;
        # skip the allocation and the multiply.
        if not self.dropout.training:
            return x

        mask_shape = list(x.shape)
        if self.batch_dim is not None:
            for bd in self.batch_dim:
                mask_shape[bd] = 1
        mask = self.dropout(x.new_ones(mask_shape))

        # Out-of-place multiply: an in-place `x *= mask` would overwrite the
        # caller's tensor and raises for leaf tensors that require grad.
        return x * mask


class DropoutRowwise(Dropout):
    """``Dropout`` preset whose ``batch_dim`` defaults to -3."""

    def __init__(self, r: float, *, batch_dim: Union[int, List[int]] = -3):
        # batch_dim is keyword-only so it can be overridden by name but not
        # supplied positionally.
        super().__init__(r, batch_dim=batch_dim)


class DropoutColumnwise(Dropout):
    """``Dropout`` preset whose ``batch_dim`` defaults to -2."""

    def __init__(self, r: float, *, batch_dim: Union[int, List[int]] = -2):
        # batch_dim is keyword-only so it can be overridden by name but not
        # supplied positionally.
        super().__init__(r, batch_dim=batch_dim)
