import torch
from torch import nn
from layers.Transformer_EncDec import Encoder, EncoderLayer
from layers.SelfAttention_Family import FullAttention, AttentionLayer
from layers.Embed import PatchEmbedding

class Model(nn.Module):
    """Linear forecaster applied along the time axis, shared across channels.

    A single ``nn.Linear`` maps the ``seq_len`` time steps of every channel
    to ``output_token_len`` new steps. The returned sequence is the input
    from step ``input_token_len`` onward, followed by those new steps.

    NOTE(review): this docstring previously pointed at the PatchTST paper
    (https://arxiv.org/abs/2211.14730) and described the implementation as
    basically consistent with PatchTST / Moment. No patching, embedding or
    attention happens in this class -- confirm the intended reference.
    """

    def __init__(self, configs):
        """Build the layer from ``configs``.

        Reads ``test_pred_len``, ``seq_len``, ``output_token_len``,
        ``input_token_len`` and ``use_norm`` from ``configs``.
        """
        super().__init__()
        # Not read anywhere inside this class; kept because external code
        # may access it.
        self.pred_len = configs.test_pred_len
        self.seq_len = configs.seq_len
        # Number of time steps produced by the linear layer.
        self.output_token_len = configs.output_token_len
        # Number of leading input steps dropped from the returned sequence.
        self.input_token_len = configs.input_token_len

        # Output width follows output_token_len. The attribute keeps its
        # original name so existing state_dict keys still load.
        self.Linear = nn.Linear(self.seq_len, self.output_token_len)
        self.use_norm = configs.use_norm

    def forecast(self, x: torch.Tensor, x_mark, y_mark) -> torch.Tensor:
        """Run the linear forecast, with optional per-series normalization.

        Args:
            x: Tensor of shape ``[B, seq_len, D]``.
            x_mark: Unused by this model.
            y_mark: Unused by this model.

        Returns:
            Tensor of shape
            ``[B, seq_len - input_token_len + output_token_len, D]``.
        """
        if self.use_norm:
            # Normalization from Non-stationary Transformer: remove the
            # per-series mean and scale by the per-series std over time.
            means = x.mean(1, keepdim=True).detach()
            x = x - means
            stdev = torch.sqrt(
                torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5)
            # Out-of-place on purpose: torch.var saved `x` for its backward
            # pass, so `x /= stdev` would invalidate that saved tensor and
            # make backward() raise whenever the input requires grad.
            x = x / stdev

        # Put time last so the linear layer acts along the time axis.
        x = x.permute(0, 2, 1)  # [B, D, seq_len]

        dec_out = self.Linear(x)  # [B, D, output_token_len]
        # Prepend the input steps after the first `input_token_len` ones.
        dec_out = torch.cat([x[:, :, self.input_token_len:], dec_out], dim=2)

        dec_out = dec_out.permute(0, 2, 1)  # back to [B, L, D]

        if self.use_norm:
            # De-Normalization from Non-stationary Transformer
            dec_out = dec_out * stdev + means

        return dec_out

    def forward(self, x: torch.Tensor, x_mark, y_mark) -> torch.Tensor:
        """Delegate to :meth:`forecast` and return its result unchanged."""
        dec_out = self.forecast(x, x_mark, y_mark)
        return dec_out  # [B, L, D]
