"""Implementation of FFN block in the style of Transformers."""

from functools import partial
from torch import nn
from src.models.sequence.base import SequenceModule
from src.models.nn import LinearActivation, DropoutNd

class FFN(SequenceModule):
    """Position-wise feed-forward block in the style of Transformers.

    Computes ``linear2(drop(bn(linear1(x))))``, where ``bn`` is only applied
    when ``inner_bn`` is set:

    * ``linear1`` projects ``d_input -> int(expand * d_input)`` with an activation;
    * an optional ``nn.BatchNorm1d`` normalizes the hidden features;
    * dropout is applied to the hidden features;
    * ``linear2`` projects back to ``d_output`` with no activation.

    Args:
        d_input: Number of input features.
        expand: Width multiplier for the hidden layer; the hidden size is
            ``int(expand * d_input)``.
        d_output: Number of output features; defaults to ``d_input``.
        transposed: Forwarded to both ``LinearActivation`` layers (and to
            ``DropoutNd`` when ``tie_dropout``). When True, features are
            expected on dim 1 (``[batch, d, seq_len]``); otherwise on the last
            dim (``[batch, seq_len, d]``). NOTE(review): layout inferred from the
            ``step`` comments and the batch-norm permutation; confirm against
            ``LinearActivation``.
        activation: Activation name applied after the first projection only.
        initializer: Weight initializer forwarded to both projections.
        dropout: Dropout probability on the hidden features; ``0.0`` disables
            dropout (``nn.Identity`` is used instead).
        tie_dropout: If True, use ``DropoutNd`` (constructed with
            ``transposed``) instead of ``nn.Dropout``; presumably this ties the
            mask across the sequence axis — see ``DropoutNd``.
        inner_bn: If True, apply ``nn.BatchNorm1d(d_inner)`` to the hidden
            features between the first projection and dropout.
    """

    def __init__(
            self,
            d_input,
            expand=2,
            d_output=None,
            transposed=False,
            activation='gelu',
            initializer=None,
            dropout=0.0,
            tie_dropout=False,
            inner_bn=False,
        ):
        super().__init__()
        self.d_output = d_input if d_output is None else d_output
        self.transposed = transposed
        self.inner_bn = inner_bn
        d_inner = int(expand * d_input)
        if self.inner_bn:
            self.inner_bn_layer = nn.BatchNorm1d(d_inner)
        linear1 = LinearActivation(
            d_input, d_inner,
            transposed=transposed,
            activation=activation,
            initializer=initializer,
            activate=True,
        )
        dropout_cls = partial(DropoutNd, transposed=self.transposed) if tie_dropout else nn.Dropout
        drop = dropout_cls(dropout) if dropout > 0.0 else nn.Identity()

        linear2 = LinearActivation(
            d_inner, self.d_output,
            transposed=transposed,
            activation=None,
            initializer=initializer,
            activate=False,
        )
        # Registered as separate attributes (not an nn.Sequential) so the
        # optional batch norm can be interleaved; attribute names define the
        # state_dict keys and must stay stable for checkpoint compatibility.
        self.linear1 = linear1
        self.drop = drop
        self.linear2 = linear2

    def _apply_inner_bn(self, x):
        """Apply ``inner_bn_layer``, moving the feature axis to dim 1 if needed."""
        if self.transposed:
            # Features are already on dim 1, which is what BatchNorm1d expects.
            return self.inner_bn_layer(x)
        # Features-last layout: swap the last axis into dim 1 and back again.
        # For 3-D input this equals permute(0, 2, 1); for 2-D [batch, d_inner]
        # input (single-step) the transpose is a no-op.
        return self.inner_bn_layer(x.transpose(1, -1)).transpose(1, -1)

    def _feedforward(self, x):
        """Shared path for ``forward`` and ``step``: linear1 -> [bn] -> drop -> linear2."""
        x = self.linear1(x)
        if self.inner_bn:
            x = self._apply_inner_bn(x)
        x = self.drop(x)
        return self.linear2(x)

    def forward(self, x, *args, **kwargs):
        """Apply the feed-forward block to ``x``.

        Extra positional and keyword arguments are accepted for interface
        compatibility and ignored.

        Returns:
            Tuple ``(y, None)``; the second slot is an unused state placeholder.
        """
        return self._feedforward(x), None

    def step(self, x, state, **kwargs):
        """Apply the block to a single timestep.

        Uses the same layers as ``forward`` (including the optional batch
        norm), so single-step and full-sequence outputs agree.

        Args:
            x: Input of shape ``[batch, d_input]``.
            state: Passed through unchanged.

        Returns:
            Tuple ``(y, state)`` with ``y`` of shape ``[batch, d_output]``.
        """
        if self.transposed:
            # Transposed layers expect [batch, d_input, seq_len]; add a
            # length-1 sequence axis and remove it afterwards.
            return self._feedforward(x.unsqueeze(-1)).squeeze(-1), state
        return self._feedforward(x), state

