from typing import Optional
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

from pado.core import PadoModule
from pado.nn.modules import Linear, ReLU, Dropout, Add, LayerNorm

# Names exported by a star-import of this module. The underscore-prefixed base
# class is listed explicitly, so it is exported as well.
__all__ = ["FeedForwardResidualNorm", "NormFeedForwardResidual", "FeedForwardResidual",
           "_FeedForwardResidualBase"]


class _FeedForwardResidualBase(PadoModule):
    """Shared construction logic for the feed-forward residual blocks.

    Builds the feed-forward stack ``linear1 -> drop1 -> act -> linear2 -> drop2``
    as ``self.layers``, a ``LayerNorm`` over ``input_dim`` as ``self.norm`` and
    an ``Add`` module as ``self.add``. Subclasses implement :meth:`forward` and
    decide where (or whether) the normalization is applied.

    :param input_dim:               input feature size of ``linear1`` and output size of ``linear2``
    :param feedforward_dim:         feature size between the two linear layers
    :param drop_prob:               probability given to the final dropout (``drop2``)
    :param feedforward_drop_prob:   probability given to the inner dropout (``drop1``);
                                    falls back to ``drop_prob`` when None
    :param eps:                     epsilon forwarded to ``LayerNorm``
    :param add_weight:              factor stored on the module; subclasses in this file
                                    scale the feed-forward output by it before the residual add
    :param act_layer:               zero-argument callable returning the activation module;
                                    ``ReLU`` is used when None
    :param memory_efficient:        flag stored on the module; subclasses in this file use it
                                    to run ``self.layers`` through ``checkpoint_sequential``
                                    while training
    """

    def __init__(self,
                 input_dim: int,
                 feedforward_dim: int,
                 drop_prob: float = 0.0,
                 feedforward_drop_prob: Optional[float] = None,
                 eps: float = 1e-5, *,
                 add_weight: float = 1.0,
                 act_layer=None,
                 memory_efficient: bool = False) -> None:
        super().__init__()

        if act_layer is None:
            act_layer = ReLU
        if feedforward_drop_prob is None:
            # The inner dropout shares the outer probability unless overridden.
            feedforward_drop_prob = drop_prob

        self.memory_efficient = memory_efficient
        self.layers = nn.Sequential(
            OrderedDict({
                "linear1": Linear(input_dim, feedforward_dim),
                "drop1": Dropout(feedforward_drop_prob),
                "act": act_layer(),
                "linear2": Linear(feedforward_dim, input_dim),
                "drop2": Dropout(drop_prob),
            }))

        self.norm = LayerNorm(input_dim, eps=eps)
        self.add = Add()
        self.add_weight = add_weight

    def forward(self, hidden: torch.Tensor, identity: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Abstract forward pass; concrete subclasses must override it.

        :raises NotImplementedError: always, on the base class
        """
        raise NotImplementedError

    def extra_repr(self) -> str:
        """Return the summary string used when the module is printed.

        Non-default settings (positive dropout probabilities, an ``add_weight``
        other than 1.0) are appended to the input/feed-forward sizes.
        """
        # The dropouts and linears are registered inside ``self.layers`` only;
        # they are not attributes of ``self``, so they are looked up by name
        # on the Sequential container.
        linear1 = self.layers.linear1
        drop1 = self.layers.drop1
        drop2 = self.layers.drop2

        s = f"{linear1.in_features}, feedforward_dim={linear1.out_features}"
        if drop2.p > 0:
            s += f", drop_prob={drop2.p}"
        if drop1.p > 0:
            s += f", feedforward_drop_prob={drop1.p}"
        if self.add_weight != 1.0:
            s += f", add_weight={self.add_weight}"
        return s


class FeedForwardResidualNorm(_FeedForwardResidualBase):
    """Post-norm block: feed-forward, weighted residual add, then LayerNorm."""

    def forward(self,
                hidden: torch.Tensor,
                identity: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        FeedForward block + (weighted) Residual Add + LayerNorm
        :param hidden:      (batch_size, seq_length, input_dim)
        :param identity:    (batch_size, seq_length, input_dim), defaults to ``hidden``
        :return:            (batch_size, seq_length, input_dim)
        """
        # The residual branch falls back to the block input itself.
        residual = hidden if identity is None else identity

        # Checkpointing is only used while training with memory_efficient set.
        if self.memory_efficient and self.training:
            branch = checkpoint_sequential(self.layers, segments=2, input=hidden)
        else:
            branch = self.layers(hidden)

        # Skip the multiplication entirely for the default weight of 1.0.
        if self.add_weight != 1.0:
            branch = branch * self.add_weight

        return self.norm(self.add(branch, residual))


class NormFeedForwardResidual(_FeedForwardResidualBase):
    """Pre-norm block: LayerNorm, feed-forward, then weighted residual add."""

    def forward(self,
                hidden: torch.Tensor,
                identity: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        LayerNorm + FeedForward block + (weighted) Residual Add
        :param hidden:      (batch_size, seq_length, input_dim)
        :param identity:    (batch_size, seq_length, input_dim), defaults to the
                            un-normalized ``hidden``
        :return:            (batch_size, seq_length, input_dim)
        """
        if identity is None:
            # Captured before normalization so the residual carries the raw input.
            identity = hidden
        hidden = self.norm(hidden)

        if (not self.memory_efficient) or (not self.training):
            hidden = self.layers(hidden)
        else:
            # Trade compute for memory: recompute activations during backward.
            hidden = checkpoint_sequential(self.layers, segments=2, input=hidden)

        # Only the default weight of 1.0 may skip the scaling. Comparing
        # against 0 would make add_weight == 0 fall through to the unweighted
        # add and return hidden + identity instead of 0 * hidden + identity.
        if self.add_weight != 1.0:
            out = self.add(hidden * self.add_weight, identity)
        else:
            out = self.add(hidden, identity)
        return out


class FeedForwardResidual(_FeedForwardResidualBase):
    """Feed-forward block with a weighted residual add and no LayerNorm."""

    def __init__(self, *args, **kwargs):
        """Build the base block, then drop the LayerNorm it created."""
        super().__init__(*args, **kwargs)
        # This variant never normalizes, so the base-class norm is removed.
        del self.norm

    def forward(self, hidden: torch.Tensor, identity: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        FeedForward block + (weighted) Residual Add
        :param hidden:      (batch_size, seq_length, input_dim)
        :param identity:    (batch_size, seq_length, input_dim), defaults to ``hidden``
        :return:            (batch_size, seq_length, input_dim)
        """
        # The residual branch falls back to the block input itself.
        residual = hidden if identity is None else identity

        # Checkpointing is only used while training with memory_efficient set.
        if self.memory_efficient and self.training:
            branch = checkpoint_sequential(self.layers, segments=2, input=hidden)
        else:
            branch = self.layers(hidden)

        # Skip the multiplication entirely for the default weight of 1.0.
        if self.add_weight != 1.0:
            branch = branch * self.add_weight

        return self.add(branch, residual)
