import torch
import torch.nn as nn
from models.s4.s4d import S4D
import torch.nn.functional as F

class MLP1d(nn.Module):
    """Position-wise two-layer MLP built from kernel-size-1 ``Conv1d`` layers.

    Computes ``mlp2(gelu(mlp1(x)))`` for a channels-first input ``x`` of shape
    (B, in_channels, L), producing (B, out_channels, L). Because the kernel
    size is 1, every position is transformed independently (channels are mixed,
    positions are not).
    """

    def __init__(self, in_channels, out_channels, mid_channels):
        super().__init__()
        # in_channels -> mid_channels -> out_channels, one position at a time.
        self.mlp1 = nn.Conv1d(in_channels, mid_channels, 1)
        self.mlp2 = nn.Conv1d(mid_channels, out_channels, 1)

    def forward(self, x):
        hidden = F.gelu(self.mlp1(x))
        return self.mlp2(hidden)

class S4DModel(nn.Module):
    """Residual stack of S4D layers mapping a sequence to per-position outputs.

    ``x`` is concatenated with ``grid`` along the feature axis, linearly encoded
    to ``d_model`` channels, passed through ``n_layers`` residual S4D blocks
    (LayerNorm applied before or after each block depending on ``prenorm``),
    and decoded position-wise to ``d_output`` channels.

    When ``bidirectional`` is True, a second, independent residual stack is run
    over the time-reversed sequence. Its output is flipped back and concatenated
    with the forward stack's output, so the decoder sees ``2 * d_model`` channels
    per position.

    Args:
        d_input: number of feature channels in ``x`` (the grid adds one more).
        d_output: number of output channels per position.
        d_model: hidden width of the S4D stack.
        n_layers: number of residual S4D blocks (per direction).
        dropout: dropout probability, passed to S4D and used after each block.
        lr: passed to S4D as ``lr`` after capping at 0.001.
        bidirectional: also run a reversed-time stack and concatenate it.
        prenorm: apply LayerNorm before each block instead of after.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        d_input,
        d_output=10,
        d_model=256,
        n_layers=4,
        dropout=0.2,
        lr=0.001,
        bidirectional=False,
        prenorm=False,
        **kwargs,
    ):
        super().__init__()

        self.prenorm = prenorm
        self.bidirectional = bidirectional

        # Linear encoder over the d_input features plus the one grid channel
        # concatenated in forward() (hence d_input + 1).
        self.encoder = nn.Linear(d_input + 1, d_model)

        # The `lr` handed to S4D is never larger than 0.001.
        s4_lr = min(0.001, lr)

        # Stack S4 layers as residual blocks.
        self.s4_layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        if self.bidirectional:
            # Reverse-time stack with its own S4D layers and norms, so the two
            # residual streams are independent.
            self.bidirectional_layers = nn.ModuleList()
            self.bidirectional_norms = nn.ModuleList()
        for _ in range(n_layers):
            self.s4_layers.append(
                S4D(d_model, dropout=dropout, transposed=True, lr=s4_lr)
            )
            if self.bidirectional:
                self.bidirectional_layers.append(
                    S4D(d_model, dropout=dropout, transposed=True, lr=s4_lr)
                )
            self.norms.append(nn.LayerNorm(d_model))
            if self.bidirectional:
                self.bidirectional_norms.append(nn.LayerNorm(d_model))
            # Dropout has no parameters, so both directions share these modules.
            self.dropouts.append(nn.Dropout(dropout))

        # Position-wise linear decoder; the bidirectional variant receives the
        # concatenation of both streams.
        if self.bidirectional:
            self.decoder = nn.Linear(2 * d_model, d_output)
        else:
            self.decoder = nn.Linear(d_model, d_output)

    def forward(self, x, grid):
        """Map a sequence to per-position outputs.

        Args:
            x: input features, shape (B, L, d_input).
            grid: coordinates concatenated onto ``x`` along the last axis;
                presumably (B, L, 1) to match the encoder's ``d_input + 1``
                inputs — confirm against callers.

        Returns:
            Tensor of shape (B, L, d_output).
        """
        x = torch.cat((x, grid), dim=-1)  # (B, L, d_input + 1)
        x = self.encoder(x)  # (B, L, d_input + 1) -> (B, L, d_model)
        x = x.transpose(-1, -2)  # (B, L, d_model) -> (B, d_model, L)

        out = self._run_stack(x, self.s4_layers, self.norms)
        if self.bidirectional:
            # Process the time-reversed sequence, then flip back so positions
            # line up with the forward stream before concatenating channels.
            out_rev = self._run_stack(
                x.flip(-1), self.bidirectional_layers, self.bidirectional_norms
            ).flip(-1)
            out = torch.cat((out, out_rev), dim=1)  # (B, 2 * d_model, L)

        out = out.transpose(-1, -2)  # (B, C, L) -> (B, L, C)

        # No pooling: decode every position, (B, L, C) -> (B, L, d_output).
        return self.decoder(out)

    def _run_stack(self, x, layers, norms):
        """Apply residual S4D blocks to ``x`` of shape (B, d_model, L)."""
        for layer, norm, dropout in zip(layers, norms, self.dropouts):
            z = x
            if self.prenorm:
                # LayerNorm normalizes the last axis, so move channels there.
                z = norm(z.transpose(-1, -2)).transpose(-1, -2)

            # Apply S4 block; the second returned value (state) is ignored.
            z, _ = layer(z)

            # Dropout on the block output, then residual connection.
            z = dropout(z)
            x = z + x

            if not self.prenorm:
                # Postnorm
                x = norm(x.transpose(-1, -2)).transpose(-1, -2)
        return x