import torch
import torch.nn as nn

class MLPBlock(nn.Module):
    """Feed-forward stack of ``nn.Linear`` layers with ReLU and optional Dropout.

    With ``num_layers == 1`` the block is a single ``Linear(input_dim, output_dim)``.
    Otherwise it is ``Linear(input_dim, hidden_dim)``, then ``num_layers - 2``
    ``Linear(hidden_dim, hidden_dim)`` layers, then ``Linear(hidden_dim, output_dim)``.
    Every Linear except the last is followed by ``ReLU`` and, when
    ``dropout > 0``, by ``Dropout(dropout)``.

    Args:
        input_dim: size of the last dimension of the input.
        hidden_dim: width of the hidden layers (unused when ``num_layers == 1``).
        output_dim: size of the last dimension of the output.
        num_layers: total number of Linear layers; must be >= 1.
        dropout: dropout probability after each hidden ReLU; ``<= 0`` disables it.

    Raises:
        ValueError: if ``num_layers < 1``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, dropout=0.0):
        super().__init__()
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would let num_layers <= 0 silently build a
        # two-layer network through the multi-layer branch below.
        if num_layers < 1:
            raise ValueError("num_layers must be >= 1")

        if num_layers == 1:
            layers = [nn.Linear(input_dim, output_dim)]
        else:
            # num_layers - 1 (Linear -> ReLU [-> Dropout]) groups; the first
            # reads input_dim, the rest read hidden_dim. Layer order/indices
            # match the previous implementation, so state_dict keys are stable.
            layers = []
            in_features = input_dim
            for _ in range(num_layers - 1):
                layers.append(nn.Linear(in_features, hidden_dim))
                layers.append(nn.ReLU())
                if dropout > 0:
                    layers.append(nn.Dropout(dropout))
                in_features = hidden_dim
            # output layer: no activation
            layers.append(nn.Linear(hidden_dim, output_dim))

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the layer stack to ``x`` (last dimension must be ``input_dim``)."""
        return self.net(x)

class MLPModel(nn.Module):
    """Map a flattened input ``[B, lag * dim]`` to an output ``[B, dim]`` with MLPs.

    Args:
        dim: number of output components; the input width is ``dim * lag``.
        lag: number of ``dim``-wide chunks in the flattened input.
        hidden_dim: hidden width forwarded to every ``MLPBlock``.
        num_layers: Linear-layer count forwarded to every ``MLPBlock``.
        componentwise: when True, build ``dim`` independent single-output
            ``MLPBlock`` heads (``self.blocks``), each reading the full input;
            otherwise build one ``MLPBlock`` with ``dim`` outputs (``self.block``).
        dropout: dropout probability forwarded to every ``MLPBlock``.
    """

    def __init__(
        self,
        dim: int,
        lag: int,
        hidden_dim: int,
        num_layers: int,
        componentwise: bool = False,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.dim = dim
        self.lag = lag
        self.componentwise = componentwise

        flat_width = dim * lag
        if not componentwise:
            # a single network emits all `dim` outputs jointly
            self.block = MLPBlock(flat_width, hidden_dim, dim, num_layers, dropout)
            return

        # one scalar-output head per component, all fed the whole flattened input
        heads = [MLPBlock(flat_width, hidden_dim, 1, num_layers, dropout) for _ in range(dim)]
        self.blocks = nn.ModuleList(heads)

    def forward(self, x):
        """
        x: [B, lag * dim]
        returns: [B, dim]
        """
        if not self.componentwise:
            return self.block(x)
        # each head yields [B, 1]; stacking them side by side gives [B, dim]
        return torch.cat([head(x) for head in self.blocks], dim=1)

class LSTMBlock(nn.Module):
    """LSTM over a flattened ``[B, lag * dim]`` input, read out by a Linear layer.

    The input is reshaped to ``[B, lag, dim]``. Chunk ``t`` of ``dim``
    consecutive columns becomes sequence step ``t``. Whether step 0 is the
    oldest observation is a caller convention; this block does not check it.
    The top layer's final hidden state is projected to ``output_dim``.

    Args:
        dim: features per sequence step (LSTM ``input_size``).
        lag: sequence length.
        hidden_dim: LSTM hidden size.
        num_layers: number of stacked LSTM layers.
        output_dim: size of the output of the final Linear layer.
        dropout: inter-layer LSTM dropout; used only when ``num_layers > 1``.
    """

    def __init__(self, dim, lag, hidden_dim, num_layers, output_dim, dropout=0.0):
        super().__init__()
        self.dim = dim
        self.lag = lag

        # nn.LSTM's dropout acts only between stacked layers, so a single-layer
        # LSTM gets 0.0 (a non-zero value there only triggers a warning).
        between_layers = 0.0 if num_layers <= 1 else dropout
        self.lstm = nn.LSTM(
            input_size=dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=between_layers,
        )
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map ``x`` of shape ``[B, lag * dim]`` to ``[B, output_dim]``."""
        batch = x.size(0)
        steps = x.view(batch, self.lag, self.dim)    # [B, lag, dim]
        _, (final_h, _) = self.lstm(steps)           # final_h: [num_layers, B, hidden_dim]
        return self.fc(final_h[-1])                  # top layer's last state -> [B, output_dim]

class LSTMModel(nn.Module):
    """Map a flattened input ``[B, lag * dim]`` to an output ``[B, dim]`` with LSTMs.

    Args:
        dim: number of components (LSTM input size and output width).
        lag: sequence length recovered from the flattened input.
        hidden_dim: hidden size forwarded to every ``LSTMBlock``.
        num_layers: stacked-layer count forwarded to every ``LSTMBlock``.
        componentwise: when True, build ``dim`` independent single-output
            ``LSTMBlock`` heads (``self.blocks``), each reading the full input;
            otherwise build one ``LSTMBlock`` with ``dim`` outputs (``self.block``).
        dropout: inter-layer dropout forwarded to every ``LSTMBlock``.
    """

    def __init__(
        self,
        dim: int,
        lag: int,
        hidden_dim: int,
        num_layers: int,
        componentwise: bool = False,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.dim = dim
        self.lag = lag
        self.componentwise = componentwise

        if not componentwise:
            # a single LSTM block emits all `dim` outputs jointly
            self.block = LSTMBlock(dim, lag, hidden_dim, num_layers, output_dim=dim, dropout=dropout)
            return

        # one scalar-output LSTM head per component, all fed the whole input
        self.blocks = nn.ModuleList(
            LSTMBlock(dim, lag, hidden_dim, num_layers, output_dim=1, dropout=dropout)
            for _ in range(dim)
        )

    def forward(self, x):
        """
        x: [B, lag * dim]
        returns: [B, dim]
        """
        if not self.componentwise:
            return self.block(x)
        # each head yields [B, 1]; concatenating along dim 1 gives [B, dim]
        per_head = [head(x) for head in self.blocks]
        return torch.cat(per_head, dim=1)

class ResidualBlock(nn.Module):
    def __init__(self, input, hidden, output, dropout):
        super(ResidualBlock, self).__init__()
        self.linear_1 = torch.nn.utils.parametrizations.weight_norm(nn.Linear(input, hidden))
        self.linear_2 = nn.Linear(hidden, output)
        self.linear_res = nn.Linear(input, output)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.layernorm = nn.LayerNorm(output)

    def forward(self, x):
        """
        x: [Batch, hidden]
        """
        h = self.linear_1(x)
        h = self.relu(h)
        h = self.linear_2(h)
        h = self.dropout(h)
        res = self.linear_res(x)
        out = h + res
        out = self.layernorm(out)
        return out
    
    def struct_loss(self):
        return torch.sum(self.linear_res.weight ** 2)

class ResidualMLP(nn.Module):
    """Input projection, a stack of ``ResidualBlock``s, and an output projection.

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output.
        layers: number of ``ResidualBlock``s (each hidden_dim -> hidden_dim).
        hidden_dim: working width between the two projections.
        dropout: dropout probability forwarded to every ``ResidualBlock``.
    """

    def __init__(self, input_dim, output_dim, layers, hidden_dim, dropout):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Both Linears are built before weight_norm is applied and before the
        # residual blocks, which keeps seeded initialisation order unchanged;
        # registering `inputgate` first keeps the module/state_dict order.
        stem = nn.Linear(input_dim, hidden_dim)
        head = nn.Linear(hidden_dim, output_dim)
        self.inputgate = torch.nn.utils.parametrizations.weight_norm(stem)
        self.outputgate = head

        self.encoders = nn.ModuleList(
            ResidualBlock(hidden_dim, hidden_dim, hidden_dim, dropout) for _ in range(layers)
        )

    def forward(self, x):
        """Project ``x`` to hidden_dim, apply every residual block, project to output_dim."""
        hidden = self.inputgate(x)
        for block in self.encoders:
            hidden = block(hidden)
        return self.outputgate(hidden)