import torch
from torch.utils.checkpoint import checkpoint

from pado.core import PadoModule
from pado.nn.modules import Linear, Dropout, Add, LayerNorm

# Public API of this module: the names exported by a star-import.
__all__ = ["ProjectionResidual", "ProjectionResidualNorm"]


class ProjectionResidual(PadoModule):
    """
    Linear projection, dropout, then a residual add with a second tensor.

    The forward pass computes ``add(drop(linear(hidden)), identity)``.
    When ``memory_efficient`` is enabled, the linear projection is run through
    ``torch.utils.checkpoint.checkpoint`` while the module is in training mode.
    """

    def __init__(self,
                 input_dim: int,
                 output_dim: int,
                 drop_prob: float = 0.0,
                 bias: bool = True, *,
                 memory_efficient: bool = False) -> None:
        """
        :param input_dim:           input feature size of the linear projection
        :param output_dim:          output feature size of the linear projection
        :param drop_prob:           probability handed to the Dropout layer
        :param bias:                forwarded to the Linear layer as ``bias``
        :param memory_efficient:    if True, checkpoint the linear projection during training
        """
        super().__init__()

        self.linear = Linear(input_dim, output_dim, bias=bias)
        # Dropout is constructed with inplace=True, so it is asked to overwrite
        # the freshly projected tensor instead of allocating a new one.
        self.drop = Dropout(drop_prob, inplace=True)
        self.add = Add()

        self.memory_efficient = memory_efficient

    def forward(self, hidden: torch.Tensor, identity: torch.Tensor) -> torch.Tensor:
        """
        Linear + Residual Add
        :param hidden:      (batch_size, seq_length, input_dim)
        :param identity:    (batch_size, seq_length, output_dim)
        :return:            (batch_size, seq_length, output_dim)
        """
        # Checkpointing is only worthwhile (and only safe) when training with an
        # input that participates in autograd: the reentrant checkpoint
        # implementation detaches its output when no input requires grad, which
        # would silently leave the linear weight/bias without gradients.
        use_checkpoint = (
            self.memory_efficient
            and self.training
            and hidden.requires_grad
        )

        if use_checkpoint:
            hidden = checkpoint(self.linear, hidden)
        else:
            hidden = self.linear(hidden)

        hidden = self.drop(hidden)
        out = self.add(hidden, identity)
        return out

    def extra_repr(self) -> str:
        """
        Summarize the projection sizes, plus the drop probability when it is non-zero.
        :return:    e.g. "in_features, out_features" or "in_features, out_features, drop_prob=p"
        """
        s = f"{self.linear.in_features}, {self.linear.out_features}"
        if self.drop.p > 0:
            s += f", drop_prob={self.drop.p}"
        return s


class ProjectionResidualNorm(ProjectionResidual):
    """
    ProjectionResidual whose result is additionally passed through a LayerNorm.
    """

    def __init__(self,
                 input_dim: int,
                 output_dim: int,
                 drop_prob: float = 0.0,
                 bias: bool = True,
                 eps: float = 1e-5, *,
                 memory_efficient: bool = False) -> None:
        """
        :param input_dim:           forwarded to ProjectionResidual
        :param output_dim:          forwarded to ProjectionResidual; also the size given to LayerNorm
        :param drop_prob:           forwarded to ProjectionResidual
        :param bias:                forwarded to ProjectionResidual
        :param eps:                 forwarded to LayerNorm as ``eps``
        :param memory_efficient:    forwarded to ProjectionResidual
        """
        super().__init__(
            input_dim,
            output_dim,
            drop_prob,
            bias=bias,
            memory_efficient=memory_efficient,
        )
        self.norm = LayerNorm(output_dim, eps=eps)

    def forward(self,
                hidden: torch.Tensor,
                identity: torch.Tensor) -> torch.Tensor:
        """
        Linear + Residual Add + LayerNorm
        :param hidden:      (batch_size, seq_length, input_dim)
        :param identity:    (batch_size, seq_length, output_dim)
        :return:            (batch_size, seq_length, output_dim)
        """
        # The parent class produces the projected + residual tensor; normalize it last.
        residual = super(ProjectionResidualNorm, self).forward(hidden, identity)
        return self.norm(residual)
