import torch
import torch.nn as nn
from layers.Tem_Agg_Block import Tem_Agg_Block,series_decomp

class Model(nn.Module):
    """Forecaster built around an LSTM backbone with an optional decomposition path.

    When ``configs.viper`` is truthy, ``forward`` routes the input through
    ``encoder``: per-instance normalization, ``series_decomp`` split into two
    components, the first one through ``tem_agg_layer`` and the LSTM backbone,
    the second one through a linear map over the time axis, then summed and
    de-normalized. Otherwise the raw input goes straight to the backbone.
    """

    def __init__(self, configs):
        """Build the sub-modules.

        Args:
            configs: object providing ``seq_len``, ``pred_len``, ``viper``,
                ``moving_avg`` and ``top_k``; it is also handed unchanged to
                ``LSTM``, which reads further fields from it.
        """
        super().__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.viper = configs.viper

        # Series decomposition block from Autoformer
        self.decompsition = series_decomp(configs.moving_avg)

        self.Linear_Trend = nn.Linear(self.seq_len, self.pred_len)

        self.tem_agg_layer = Tem_Agg_Block(configs.top_k, configs.seq_len)

        # Replace the random init so every output step starts as the plain
        # average of the input window (all weights equal to 1 / seq_len).
        self.Linear_Trend.weight = nn.Parameter(
            (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len]))
        self.backbone = LSTM(configs)

    def encoder(self, x):
        """Run the normalize -> decompose -> predict -> de-normalize path.

        Args:
            x: rank-3 tensor whose dim 1 is the time axis of length
                ``seq_len`` (it is reduced over dim 1 and fed to a
                ``Linear(seq_len, pred_len)`` after ``permute(0, 2, 1)``).

        Returns:
            The last ``pred_len`` steps along dim 1 of the de-normalized
            prediction.
        """
        # Normalization from Non-stationary Transformer
        means = x.mean(1, keepdim=True).detach()
        x = x - means
        stdev = torch.sqrt(
            torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5)
        # Out-of-place on purpose: torch.var keeps ``x`` for its backward pass,
        # so an in-place ``x /= stdev`` makes backward() fail whenever the
        # input carries gradients.
        x = x / stdev
        seasonal_init, trend_init = self.decompsition(x)

        seasonal_output = self.backbone(self.tem_agg_layer(seasonal_init))
        # The linear layer acts on the last dim, so move time there and back.
        trend_output = self.Linear_Trend(
            trend_init.permute(0, 2, 1)).permute(0, 2, 1)

        pred = seasonal_output + trend_output
        # De-normalization: stdev and means have size 1 on dim 1 (keepdim=True)
        # and broadcast over the pred_len steps.
        pred = pred * stdev + means
        return pred[:, -self.pred_len:, :]  # [B, L, D]

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Return the last ``pred_len`` steps of the forecast for ``x_enc``.

        ``x_mark_enc``, ``x_dec``, ``x_mark_dec`` and ``mask`` are accepted but
        not used by this model.
        """
        if self.viper:
            dec_out = self.encoder(x_enc)
        else:
            dec_out = self.backbone(x_enc)
        return dec_out[:, -self.pred_len:, :]  # [B, L, D]
    
class LSTM(nn.Module):
    """LayerNorm -> LSTM -> dropout -> linear head over the final time step.

    The linear head emits ``pred_len * enc_in`` values per sample, which
    ``forward`` reshapes to ``(batch, -1, enc_in)``.
    """

    def __init__(self, configs):
        """Create the layers from ``configs``.

        Reads ``enc_in``, ``d_model``, ``e_layers``, ``pred_len``, ``dropout``
        and ``seq_len`` from ``configs``.
        """
        super().__init__()
        n_vars = configs.enc_in
        width = configs.d_model

        self.lstm = nn.LSTM(
            input_size=n_vars,
            hidden_size=width,
            num_layers=configs.e_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(width, configs.pred_len * n_vars)
        self.dropout = nn.Dropout(configs.dropout)
        # Normalizes jointly over the trailing (seq_len, enc_in) dims.
        self.layernorm = nn.LayerNorm(normalized_shape=[configs.seq_len, n_vars])
        self.channels = n_vars

    def forward(self, x_enc):
        """Map a rank-3 input to a ``(batch, -1, channels)`` prediction."""
        # Three-way unpack doubles as a rank check on the input.
        batch, _, _ = x_enc.shape

        hidden_seq = self.lstm(self.layernorm(x_enc))[0]
        # Dropout runs on the whole sequence output; only the final step is
        # passed on to the linear head.
        last_step = self.dropout(hidden_seq)[:, -1, :]

        flat = self.fc(last_step)
        return flat.reshape(batch, -1, self.channels)

