# Adapted from OpenFold
# Copyright 2021 AlQuraishi Laboratory
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import torch.nn as nn
from functools import partialmethod
from typing import Union, List


class Dropout(nn.Module):
    """
        Implementation of dropout with the ability to share the dropout mask
        along a particular dimension.

        The mask is built with size 1 along every dimension listed in
        batch_dim and is broadcast against the input when multiplied, so
        all entries along those dimensions are kept or dropped together.

        If not in training mode, this module computes the identity function.
    """
    def __init__(self, r: float, batch_dim: Union[int, List[int]]):
        """
            Args:
                r:
                    Dropout rate, forwarded unchanged to torch.nn.Dropout
                batch_dim:
                    Dimension(s) along which the dropout mask is shared.
                    A single int (or int subclass) is wrapped in a list;
                    any other iterable of ints is copied into a new list.
                    None means the mask is not shared along any dimension.
        """
        super(Dropout, self).__init__()

        self.r = r
        if isinstance(batch_dim, int):
            # isinstance (rather than an exact type comparison) so that int
            # subclasses such as IntEnum members are treated as one dimension
            batch_dim = [batch_dim]
        elif batch_dim is not None:
            # Private copy: later mutation of the caller's sequence must not
            # change which dimensions share the mask
            batch_dim = list(batch_dim)
        self.batch_dim = batch_dim
        self.dropout = nn.Dropout(self.r)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
            Args:
                x:
                    Tensor to which dropout is applied. Can have any shape
                    compatible with self.batch_dim
            Returns:
                A tensor with the same shape as x, equal to x multiplied
                by the (broadcast) dropout mask
        """
        # The mask has the shape of x, collapsed to 1 along each shared dim
        shape = list(x.shape)
        if self.batch_dim is not None:
            for bd in self.batch_dim:
                shape[bd] = 1
        # new_ones keeps the mask on the same device and dtype as x
        mask = x.new_ones(shape, requires_grad=False)
        # nn.Dropout follows this module's train/eval mode, so in eval mode
        # the mask stays all ones
        mask = self.dropout(mask)
        # Broadcasting expands the size-1 dimensions back to the shape of x
        x = x * mask
        return x


class DropoutRowwise(Dropout):
    """
        Rowwise dropout as described in subsection 1.11.6: a Dropout whose
        batch_dim defaults to -3. The default can only be overridden by
        keyword.
    """
    def __init__(self, r: float, *, batch_dim: Union[int, List[int]] = -3):
        """
            Args:
                r:
                    Dropout rate
                batch_dim:
                    Keyword-only override of the default dimension (-3)
        """
        Dropout.__init__(self, r, batch_dim=batch_dim)


class DropoutColumnwise(Dropout):
    """
        Columnwise dropout as described in subsection 1.11.6: a Dropout
        whose batch_dim defaults to -2. The default can only be overridden
        by keyword.
    """
    def __init__(self, r: float, *, batch_dim: Union[int, List[int]] = -2):
        """
            Args:
                r:
                    Dropout rate
                batch_dim:
                    Keyword-only override of the default dimension (-2)
        """
        Dropout.__init__(self, r, batch_dim=batch_dim)
