""" Implementation of FFN block in the style of Transformers """

from functools import partial
from torch import nn
import torch
from src.models.sequence.base import SequenceModule
from src.models.nn import LinearActivation, DropoutNd

class FF_lookback(SequenceModule):
    """Feed-forward block whose input at each step also includes the previous ``T`` steps.

    ``forward`` concatenates, along the feature axis, the input with ``T``
    copies of itself shifted back by 1..T steps (zero-filled where the shift
    reaches before the start of the sequence). The widened tensor is then
    passed through ``linear1 -> dropout -> linear2``.

    NOTE(review): ``forward`` always shifts along dim 1 and concatenates along
    dim 2, regardless of ``transposed``; ``transposed`` is only forwarded to
    the linear/dropout layers. Confirm the layout expected when
    ``transposed=True``.
    """

    def __init__(self, d_input, T=10, expand=2, d_output=None, transposed=False, activation='gelu', initializer=None, dropout=0.0, tie_dropout=False, **kwargs):
        """Build the two-layer feed-forward stack.

        Args:
            d_input: feature size of a single time step of the input.
            T: number of previous steps concatenated to each step.
            expand: multiplier applied to the widened input size
                (``d_input * (T + 1)``) to get the inner size.
            d_output: output feature size; defaults to ``d_input``.
            transposed: forwarded to ``LinearActivation`` and ``DropoutNd``.
            activation: activation name forwarded to the first linear layer.
            initializer: forwarded to both linear layers.
            dropout: dropout probability; no dropout layer is used when <= 0.
            tie_dropout: use ``DropoutNd`` instead of ``nn.Dropout``.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.T = T
        self.d_output = d_input if d_output is None else d_output
        # forward() stacks the current step plus T shifted copies on the
        # feature axis, so the first linear layer sees (T + 1) * d_input.
        d_input = d_input * (T + 1)

        self.transposed = transposed
        d_inner = expand * d_input
        linear1 = LinearActivation(
            d_input, d_inner,
            transposed=transposed,
            activation=activation,
            initializer=initializer,
            activate=True,
        )
        dropout_cls = partial(DropoutNd, transposed=self.transposed) if tie_dropout else nn.Dropout
        drop = dropout_cls(dropout) if dropout > 0.0 else nn.Identity()

        linear2 = LinearActivation(
            d_inner, self.d_output,
            transposed=transposed,
            activation=None,
            initializer=initializer,
            activate=False,
        )

        self.ff = nn.Sequential(
            linear1,
            drop,
            linear2,
        )

    def forward(self, x, *args, **kwargs):
        """Apply the feed-forward stack to each step together with its ``T`` predecessors.

        Args:
            x: tensor of shape (batch, seq_len, hidden_dim).
            *args, **kwargs: accepted and ignored.

        Returns:
            Tuple ``(y, None)`` where ``y`` is the output of the feed-forward
            stack applied to the widened input.
        """
        batch, seq_len, hidden = x.shape
        lookback = self.T

        # Prepend `lookback` zero steps once. new_zeros keeps x's device and
        # dtype, so this works on CPU, on any GPU, and under mixed precision.
        history = x.new_zeros(batch, lookback, hidden)
        padded = torch.cat((history, x), dim=1)

        # Window `shift` is x delayed by `shift` steps: shift 0 is x itself,
        # shift k is k zero steps followed by x[:, :seq_len - k]. Shifts
        # longer than the sequence are all zeros, so seq_len < T is valid.
        windows = [
            padded[:, lookback - shift:lookback - shift + seq_len, :]
            for shift in range(lookback + 1)
        ]
        # Feature order is [x_t, x_{t-1}, ..., x_{t-T}].
        x_cat = torch.cat(windows, dim=2)

        return self.ff(x_cat), None