import torch
import torch.nn as nn
from layers.Tem_Agg_Block import Tem_Agg_Block,series_decomp

class Model(nn.Module):
    """Decomposition-linear forecaster with a temporal aggregation block.

    The input window is split by ``series_decomp`` into a seasonal and a trend
    part. The seasonal part goes through ``Tem_Agg_Block`` and then a linear
    map over the time axis (``seq_len -> pred_len``). The trend part goes
    straight through its own linear map over the time axis. The two forecasts
    are summed. When ``configs.revin`` is truthy, the input is normalized over
    the time axis before forecasting and the forecast is de-normalized
    afterwards.

    Args:
        configs: Config object; this class reads ``seq_len``, ``pred_len``,
            ``revin``, ``moving_avg`` and ``top_k`` from it.
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len

        self.revin = configs.revin

        # Series decomposition block from Autoformer.
        # NOTE: the attribute keeps its original spelling ("decompsition") so
        # external code that references it keeps working.
        self.decompsition = series_decomp(configs.moving_avg)
        self.Linear_Seasonal = nn.Linear(self.seq_len, self.pred_len)
        self.Linear_Trend = nn.Linear(self.seq_len, self.pred_len)

        # top_k / seq_len are forwarded as-is; their meaning is defined by
        # Tem_Agg_Block (not visible here).
        self.tem_agg_layer = Tem_Agg_Block(configs.top_k, configs.seq_len)

        # Initialize both projections as a uniform average over the input
        # window, so each output step starts as the window mean plus the bias.
        # Biases keep nn.Linear's default initialization.
        self.Linear_Seasonal.weight = nn.Parameter(
            (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len]))
        self.Linear_Trend.weight = nn.Parameter(
            (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len]))

    def encoder(self, x):
        """Forecast ``pred_len`` steps from the input window ``x``.

        Args:
            x: Tensor laid out as ``[B, seq_len, D]``. Time is dim 1, because
                the linear layers (``seq_len -> pred_len``) are applied to it
                after a ``permute(0, 2, 1)``.

        Returns:
            Tensor of shape ``[B, pred_len, D]``.
        """
        if self.revin:
            # Normalization from Non-stationary Transformer: statistics are
            # taken over the time axis (dim 1) for every sample and channel.
            # Only the mean is detached from the graph; stdev is not.
            means = x.mean(1, keepdim=True).detach()
            x = x - means
            stdev = torch.sqrt(
                torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5)
            # Out-of-place on purpose: torch.var saved ``x`` for its backward
            # pass, so an in-place ``x /= stdev`` makes backward fail whenever
            # the input requires grad.
            x = x / stdev

        seasonal_init, trend_init = self.decompsition(x)

        # The linear layers act on the time axis, hence the [B, D, L] permute.
        # NOTE(review): tem_agg_layer is assumed to preserve the
        # [B, seq_len, D] layout - confirm against Tem_Agg_Block.
        seasonal_output = self.tem_agg_layer(seasonal_init)
        seasonal_output = self.Linear_Seasonal(seasonal_output.permute(0, 2, 1))
        trend_output = self.Linear_Trend(trend_init.permute(0, 2, 1))

        pred = (seasonal_output + trend_output).permute(0, 2, 1)  # [B, pred_len, D]

        if self.revin:
            # means / stdev are [B, 1, D] and broadcast over the pred_len
            # axis, which is equivalent to repeating them pred_len times.
            pred = pred * stdev + means
        return pred[:, -self.pred_len:, :]  # [B, L, D]

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Return the ``[B, pred_len, D]`` forecast for ``x_enc``.

        Only ``x_enc`` is used; ``x_mark_enc``, ``x_dec``, ``x_mark_dec`` and
        ``mask`` are accepted but ignored by this model.
        """
        dec_out = self.encoder(x_enc)
        return dec_out[:, -self.pred_len:, :]  # [B, L, D]
    