import torch
import torch.nn as nn


class MASKModel(nn.Module):
    """Linear forecaster mapping ``seq_len`` past steps to ``pred_len`` steps.

    The model works directly in the time domain: each channel's history of
    length ``seq_len`` is passed through an ``nn.Linear(seq_len, pred_len)``.
    Depending on ``configs.individual`` either one linear layer per channel
    is used, or a single layer is shared by all channels.

    When ``configs.revin`` is truthy, every series is standardised along the
    time axis before the linear map (mean removed, divided by the standard
    deviation with ``1e-5`` added to the variance), and the prediction is
    mapped back with the same statistics.

    Args:
        configs: Object exposing the attributes ``seq_len``, ``pred_len``,
            ``individual``, ``enc_in`` (number of channels) and ``revin``.
    """

    def __init__(self, configs):
        super().__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.individual = configs.individual
        self.channels = configs.enc_in
        self.revin = configs.revin

        if self.individual:
            # One independent seq_len -> pred_len projection per channel.
            self.linears = nn.ModuleList(
                [
                    nn.Linear(self.seq_len, self.pred_len)
                    for _ in range(self.channels)
                ]
            )
        else:
            # A single projection shared by every channel.
            self.linear = nn.Linear(self.seq_len, self.pred_len)

    def forward(self, x):
        """Forecast ``pred_len`` future steps for every channel.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, channels)``.

        Returns:
            Tensor of shape ``(batch_size, pred_len, channels)``.

        Note:
            With ``revin`` enabled the variance is taken with ``torch.var``'s
            default (Bessel-corrected) estimator, so ``seq_len == 1`` yields
            NaN statistics.
        """
        if self.revin:
            # Per-series statistics over the time axis (dim=1); kept so the
            # prediction can be de-normalised after the linear map.
            x_mean = torch.mean(x, dim=1, keepdim=True)
            x = x - x_mean
            x_var = torch.var(x, dim=1, keepdim=True) + 1e-5
            x_std = torch.sqrt(x_var)
            x = x / x_std

        # (batch_size, seq_len, channels) -> (batch_size, channels, seq_len)
        # so that nn.Linear acts on the time axis.
        x = x.permute(0, 2, 1)
        if self.individual:
            # Allocate the output on x's device with x's dtype directly,
            # instead of creating it on the default device and copying.
            # Channels of x beyond len(self.linears) are left as zeros.
            y = x.new_zeros((x.size(0), x.size(1), self.pred_len))
            for idx, layer in enumerate(self.linears):
                y[:, idx, :] = layer(x[:, idx, :])
        else:
            y = self.linear(x)

        # Back to (batch_size, pred_len, channels).
        y = y.permute(0, 2, 1)
        if self.revin:
            y = y * x_std + x_mean
        return y
