import torch
from torch import nn
from layers.Transformer_EncDec import Encoder, EncoderLayer
from layers.SelfAttention_Family import FullAttention, AttentionLayer
from layers.Embed import PatchEmbedding



class Model(nn.Module):
    """Channel-wise linear forecaster.

    A single ``nn.Linear(seq_len, pred_len)`` is applied along the time axis
    of every channel (the same weights for all channels), mapping the
    look-back window directly to the forecast horizon. When ``use_norm`` is
    set, each input series is standardized over time (normalization from the
    Non-stationary Transformer) and the predictions are mapped back with the
    same per-series mean and standard deviation.

    NOTE(review): a previous docstring linked arXiv:2211.14730 (PatchTST) and
    described this as "Moment"; this class performs no patching, embedding,
    or attention, so that description did not match the code.
    """

    def __init__(self, configs):
        """Build the model from a config object.

        Args:
            configs: object exposing ``test_pred_len`` (forecast horizon,
                read instead of ``pred_len``) and ``seq_len`` (look-back
                window length).
        """
        super().__init__()
        self.pred_len = configs.test_pred_len
        self.seq_len = configs.seq_len

        # Maps seq_len input steps to pred_len output steps, per channel.
        self.Linear = nn.Linear(self.seq_len, self.pred_len)
        # Per-series normalization is always enabled for this model.
        self.use_norm = True

    def forecast(self, x: torch.Tensor, x_mark, y_mark) -> torch.Tensor:
        """Predict the next ``pred_len`` steps from a look-back window.

        Args:
            x: input of shape ``(B, seq_len, C)`` (the Linear is applied to
                dim 1 after a permute, so dim 1 must equal ``seq_len``).
            x_mark: unused here (presumably kept so the signature matches
                other models' ``forecast``).
            y_mark: unused here, see ``x_mark``.

        Returns:
            Tensor of shape ``(B, pred_len, C)``.
        """
        if self.use_norm:
            # Normalization from Non-stationary Transformer: standardize each
            # series over the time axis (dim 1).
            means = x.mean(1, keepdim=True).detach()
            x = x - means
            stdev = torch.sqrt(
                torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5)
            # Out-of-place on purpose: torch.var saved ``x`` for its backward
            # pass, so an in-place ``x /= stdev`` breaks autograd whenever the
            # input requires grad.
            x = x / stdev

        # (B, seq_len, C) -> (B, C, seq_len) so the Linear acts on time.
        dec_out = self.Linear(x.permute(0, 2, 1))
        # (B, C, pred_len) -> (B, pred_len, C)
        dec_out = dec_out.permute(0, 2, 1)

        if self.use_norm:
            # De-normalization: stdev and means are (B, 1, C) and broadcast
            # across the pred_len axis.
            dec_out = dec_out * stdev + means
        return dec_out

    def forward(self, x: torch.Tensor, x_mark, y_mark) -> torch.Tensor:
        """Run ``forecast`` and return the last ``pred_len`` steps.

        Returns:
            Tensor of shape ``(B, pred_len, C)``.
        """
        dec_out = self.forecast(x, x_mark, y_mark)
        # forecast already yields exactly pred_len steps; the slice is a
        # defensive no-op kept for parity with other models.
        return dec_out[:, -self.pred_len:, :]  # [B, L, D]
