import torch
import torch.nn as nn
from layers.Pyraformer_EncDec import Encoder


class Model(nn.Module):
    """
    Pyraformer: Pyramidal attention to reduce complexity
    Paper link: https://openreview.net/pdf?id=0EXmFzUn5I

    Runs the project ``Encoder`` on the input series, keeps the encoder
    output at the last index of dimension 1 and maps it to the forecast with
    a single linear layer.
    """

    def __init__(self, configs, window_size=None, inner_size=5):
        """
        configs: object providing ``pred_len``, ``d_model`` and ``enc_in``;
            it is also forwarded unchanged to ``Encoder``.
        window_size: list, the downsample window size in pyramidal attention.
            ``None`` (the default) selects ``[4, 4]``.
        inner_size: int, the size of neighbour attention
        """
        super().__init__()
        if window_size is None:
            # Build the default per instance: a list literal in the signature
            # would be one shared object handed to every Encoder, so an
            # in-place change made anywhere would leak into later models.
            window_size = [4, 4]

        self.pred_len = configs.pred_len
        self.d_model = configs.d_model
        self.encoder = Encoder(configs, window_size, inner_size)

        # The projection consumes (len(window_size) + 1) * d_model features
        # (presumably one d_model-sized slice per pyramid scale -- confirm
        # against Encoder) and emits pred_len * enc_in values, which
        # long_forecast reshapes into (batch, pred_len, enc_in).
        self.projection = nn.Linear(
            (len(window_size) + 1) * self.d_model,
            self.pred_len * configs.enc_in,
        )

    def long_forecast(self, x_enc, x_enc_mark=None, x_dec=None, x_dec_mark=None):
        """
        Return the forecast for ``x_enc`` with shape (batch, pred_len, -1).

        Only ``x_enc`` is used: the encoder is called with ``None`` as its
        second argument, and ``x_enc_mark``, ``x_dec`` and ``x_dec_mark`` are
        accepted but ignored.
        """
        # Keep only the last index along dimension 1 of the encoder output.
        enc_out = self.encoder(x_enc, None)[:, -1, :]
        # Unflatten the projected vector; the trailing dimension is inferred
        # (enc_in, given the projection's output width).
        dec_out = self.projection(enc_out).view(
            enc_out.size(0), self.pred_len, -1)
        return dec_out

    def forward(self, x_enc, x_enc_mark=None, x_dec=None, x_dec_mark=None, target_x=None):
        """
        Run ``long_forecast`` and return its last ``pred_len`` steps.

        ``target_x`` is accepted but not used by this method.
        """
        dec_out = self.long_forecast(x_enc, x_enc_mark, x_dec, x_dec_mark)
        return dec_out[:, -self.pred_len:, :]  # [B, L, D]
