import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.StandardNorm import myNormalize
from layers.Transformer_EncDec import Encoder, EncoderLayer
from layers.SelfAttention_Family import FullAttention, AttentionLayer
from layers.Embed import DataEmbedding_inverted
from layers.Autoformer_EncDec import series_decomp


class Model(nn.Module):
    """Seasonal/trend decomposition forecaster with two inverted-Transformer branches.

    The input is normalized by ``myNormalize`` and split by ``series_decomp`` into
    seasonal and trend components. Each component is flattened over time and
    channels into one ``seq_len * enc_in`` vector, embedded with
    ``DataEmbedding_inverted``, passed through its own Transformer encoder and
    projected to ``pred_len`` steps. The two branch forecasts are summed and
    de-normalized.
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.channels = configs.enc_in
        self.decomposition = series_decomp(configs.moving_avg)
        # Embedding: the input size is seq_len * enc_in because forward() flattens
        # the whole multivariate window into a single vector before embedding.
        self.enc_embedding_seasonal = DataEmbedding_inverted(
            configs.enc_in * configs.seq_len, configs.d_model, configs.embed, configs.freq, configs.dropout)
        self.enc_embedding_trend = DataEmbedding_inverted(
            configs.enc_in * configs.seq_len, configs.d_model, configs.embed, configs.freq, configs.dropout)
        # Encoder: two independent, identically configured stacks. Construction
        # order (seasonal, then trend) matches the original, so parameter
        # initialization under a fixed seed is unchanged.
        self.encoder_seasonal = self._build_encoder(configs)
        self.encoder_trend = self._build_encoder(configs)
        # Decoder: per-token linear maps from d_model to the forecast horizon.
        self.projection_seasonal = nn.Linear(configs.d_model, configs.pred_len, bias=True)
        self.projection_trend = nn.Linear(configs.d_model, configs.pred_len, bias=True)
        self.instance_norm = myNormalize(self.channels, affine=True, subtract_last=True)

    @staticmethod
    def _build_encoder(configs):
        """Build one ``Encoder`` stack of ``configs.e_layers`` attention layers.

        The first ``FullAttention`` argument is ``False``; presumably this is the
        causal-mask flag, i.e. attention is unmasked. TODO confirm against
        ``layers.SelfAttention_Family``.
        """
        return Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        FullAttention(False, configs.factor, attention_dropout=configs.dropout,
                                      output_attention=False), configs.d_model, configs.n_heads),
                    configs.d_model,
                    configs.d_ff,
                    dropout=configs.dropout,
                    activation=configs.activation
                ) for _ in range(configs.e_layers)
            ],
            norm_layer=torch.nn.LayerNorm(configs.d_model)
        )

    def forward(self, x_enc, x_enc_mark=None, x_dec=None, x_dec_mark=None, target_x=None):
        """Forecast ``pred_len`` steps from the encoder window.

        Args:
            x_enc: input window; indexed as (batch, time, channels) below. Assumes
                ``(B, seq_len, enc_in)``; TODO confirm against the data loader.
            x_enc_mark: optional time features for the window, presumably
                ``(B, seq_len, n_marks)``. When ``None`` the seasonal branch is
                embedded without time features (as the trend branch always is).
            x_dec, x_dec_mark: unused; kept for interface compatibility.
            target_x: forwarded to ``self.instance_norm`` in ``'norm'`` mode.

        Returns:
            The de-normalized forecast, sliced to the last ``pred_len`` steps. Before
            de-normalization it has shape ``(B, pred_len, 1)`` (first token only).
        """
        x_enc = self.instance_norm(x_enc, target_x, 'norm')
        seasonal_init, trend_init = self.decomposition(x_enc)
        # Flatten time x channels into one vector per sample: (B, L * D, 1).
        seasonal_init = seasonal_init.reshape(seasonal_init.shape[0], seasonal_init.shape[1] * seasonal_init.shape[2], 1)
        trend_init = trend_init.reshape(trend_init.shape[0], trend_init.shape[1] * trend_init.shape[2], 1)

        if x_enc_mark is not None:
            # Tile the time features so their length matches the flattened L * D axis.
            # NOTE(review): ``repeat`` tiles the whole mark block D times, so mark
            # position j is time step j % L. In the row-major flattened values,
            # position j is time step j // D. ``repeat_interleave(D, dim=1)`` would
            # align them. Left unchanged to preserve trained behavior; confirm intent.
            x_enc_mark = x_enc_mark.repeat(1, x_enc.shape[2], 1)

        # Seasonal branch (uses time features when provided).
        seasonal_out = self.enc_embedding_seasonal(seasonal_init, x_mark=x_enc_mark)
        # The time-feature tensor is not an attention mask. Attention is built with
        # its mask flag off, so pass None exactly as the trend branch does.
        seasonal_out, _ = self.encoder_seasonal(seasonal_out, attn_mask=None)
        # Project each token to pred_len, move the horizon to dim 1, and keep only the
        # first token (presumably the value token; any others come from time marks).
        seasonal_out = self.projection_seasonal(seasonal_out).permute(0, 2, 1)[:, :, :1]

        # Trend branch (never uses time features).
        trend_out = self.enc_embedding_trend(trend_init, x_mark=None)
        trend_out, _ = self.encoder_trend(trend_out, attn_mask=None)
        trend_out = self.projection_trend(trend_out).permute(0, 2, 1)[:, :, :1]

        dec_out = seasonal_out + trend_out
        dec_out = self.instance_norm(dec_out, mode='denorm')
        return dec_out[:, -self.pred_len:, :]  # [B, L, D]
