import torch
import torch.nn as nn
from layers.Transformer_EncDec import Encoder, EncoderLayer
from layers.SelfAttention_Family import FullAttention, AttentionLayer
from layers.Embed import DataEmbedding_inverted
from layers.Autoformer_EncDec import series_decomp


class Model(nn.Module):
    """Multi-scale forecaster working on a seasonal/trend decomposition.

    The input series is normalised, split into seasonal and trend parts by
    ``series_decomp``, and each part is down-sampled ``multiscale_levels``
    times.  Every scale has its own ``DataEmbedding_inverted`` + ``Encoder``
    pair per component; the per-scale outputs (seasonal + trend) are
    concatenated along the last axis and mapped to ``pred_len`` steps by a
    single linear projection.

    ``configs`` must provide: ``seq_len``, ``pred_len``, ``multiscale_levels``,
    ``enc_in``, ``d_model``, ``d_ff``, ``n_heads``, ``e_layers``, ``factor``,
    ``dropout``, ``activation``, ``embed``, ``freq`` and ``moving_avg``.
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.configs = configs
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.num_down_layers = configs.multiscale_levels
        self.down_sampling_window = 2
        self.down_sampling_method = 'conv'

        if self.down_sampling_method == 'max':
            self.down_pool = torch.nn.MaxPool1d(self.down_sampling_window, return_indices=False)
        elif self.down_sampling_method == 'avg':
            self.down_pool = torch.nn.AvgPool1d(self.down_sampling_window)
        elif self.down_sampling_method == 'conv':
            padding = self._circular_conv_padding(torch.__version__)
            self.down_pool = nn.Conv1d(in_channels=self.configs.enc_in, out_channels=self.configs.enc_in,
                                       kernel_size=3, padding=padding,
                                       stride=self.down_sampling_window,
                                       padding_mode='circular',
                                       bias=False)

        # NOTE(review): per-scale sizes below use floor division, whereas the
        # stride-2 conv and the ``[::2]`` mark slicing produce ceil(T / 2) for
        # odd T. Presumably seq_len must be divisible by
        # 2 ** multiscale_levels -- confirm against DataEmbedding_inverted.

        # Creation order (embeddings, encoders, decomposition, projection) is
        # kept stable so parameter ordering and seeded initialisation do not
        # change.
        # Embedding
        self.season_enc_embedding = self._build_embeddings(configs)
        self.trend_enc_embedding = self._build_embeddings(configs)
        # Encoder
        self.season_encoder = self._build_encoders(configs)
        self.trend_encoder = self._build_encoders(configs)
        self.decomposition = series_decomp(configs.moving_avg)
        # Decoder: input width is the sum of every scale's model width.
        proj_len = sum(configs.d_model // (2**i) for i in range(self.num_down_layers + 1))
        self.projection = nn.Linear(proj_len, configs.pred_len, bias=True)

    @staticmethod
    def _circular_conv_padding(version):
        """Return the ``Conv1d`` padding used with ``padding_mode='circular'``.

        Gives 1 for releases >= 1.5 and 2 for older ones.  The (major, minor)
        pair is compared numerically, so a plain-string version such as
        ``'1.10.0'`` is not mistaken for something older than ``'1.5.0'``
        (which a lexicographic string comparison would do).  Local build tags
        (``+cu118``) and pre-release suffixes (``0a0``) are ignored.
        """
        release = []
        for part in str(version).split('+')[0].split('.')[:2]:
            digits = ''
            for ch in part:
                if not ch.isdigit():
                    break
                digits += ch
            release.append(int(digits) if digits else 0)
        release += [0] * (2 - len(release))
        return 1 if tuple(release) >= (1, 5) else 2

    def _build_embeddings(self, configs):
        """Create one ``DataEmbedding_inverted`` per scale (level 0 = full size)."""
        return nn.ModuleList([
            DataEmbedding_inverted(configs.seq_len // (2**i), configs.d_model // (2**i), configs.embed,
                                   configs.freq, configs.dropout)
            for i in range(self.num_down_layers + 1)
        ])

    def _build_encoders(self, configs):
        """Create one ``Encoder`` stack per scale; widths halve at every level."""
        return nn.ModuleList([
            Encoder(
                [
                    EncoderLayer(
                        AttentionLayer(
                            FullAttention(False, configs.factor, attention_dropout=configs.dropout,
                                          output_attention=False), configs.d_model // (2**i), configs.n_heads),
                        configs.d_model // (2**i),
                        configs.d_ff // (2**i),
                        dropout=configs.dropout,
                        activation=configs.activation
                    ) for _ in range(configs.e_layers)
                ],
                norm_layer=nn.LayerNorm(configs.d_model // (2**i))
            ) for i in range(self.num_down_layers + 1)
        ])

    def __multi_scale_process_inputs(self, x_enc, x_mark_enc):
        """Build the list of progressively down-sampled inputs.

        Args:
            x_enc: tensor laid out as (B, T, C).
            x_mark_enc: optional mark tensor indexed as ``[:, t, :]``, or None.

        Returns:
            ``(x_list, mark_list)`` with ``num_down_layers + 1`` entries each,
            entry 0 being the untouched input.  ``mark_list`` is None when
            ``x_mark_enc`` is None.  If the down-sampling method is not one of
            'conv'/'max'/'avg', the inputs are returned unchanged (not wrapped
            in lists).
        """
        if self.down_sampling_method not in ['conv', 'max', 'avg']:
            return x_enc, x_mark_enc

        # The pooling/conv layer works on (B, C, T); results are stored as (B, T, C).
        current = x_enc.permute(0, 2, 1)
        current_mark = x_mark_enc

        x_enc_sampling_list = [current.permute(0, 2, 1)]
        x_mark_sampling_list = [x_mark_enc]

        for _ in range(self.num_down_layers):
            current = self.down_pool(current)
            x_enc_sampling_list.append(current.permute(0, 2, 1))

            if current_mark is not None:
                # Marks are simply sub-sampled with the same stride as the data.
                current_mark = current_mark[:, ::self.down_sampling_window, :]
                x_mark_sampling_list.append(current_mark)

        return x_enc_sampling_list, (x_mark_sampling_list if x_mark_enc is not None else None)

    def forecast(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, target_x=None):
        """Run decomposition, per-scale encoding and the final projection.

        ``x_dec``, ``x_mark_dec`` and ``target_x`` are accepted for interface
        compatibility but are not used here.

        Returns:
            Tensor with the last ``pred_len`` steps, laid out as [B, L, D].
        """
        _, _, N = x_enc.shape

        seasonal_init, trend_init = self.decomposition(x_enc)
        seasonal_init, season_x_mark_enc = self.__multi_scale_process_inputs(seasonal_init, x_mark_enc)
        trend_init, trend_x_mark_enc = self.__multi_scale_process_inputs(trend_init, x_mark_enc)

        # Without marks, feed None to every scale's embedding.
        if season_x_mark_enc is None:
            season_x_mark_enc = [None] * len(seasonal_init)
        if trend_x_mark_enc is None:
            trend_x_mark_enc = [None] * len(trend_init)

        out_list = []
        for i, (season_x, season_x_mark, trend_x, trend_x_mark) in enumerate(
            zip(seasonal_init, season_x_mark_enc, trend_init, trend_x_mark_enc)
        ):
            season_out = self.season_enc_embedding[i](season_x, x_mark=season_x_mark)
            season_out, _ = self.season_encoder[i](season_out, attn_mask=None)
            trend_out = self.trend_enc_embedding[i](trend_x, x_mark=trend_x_mark)
            trend_out, _ = self.trend_encoder[i](trend_out, attn_mask=None)
            out_list.append(season_out + trend_out)

        out_cat = torch.cat(out_list, -1)
        # Keep only the first N entries of the last axis after the permute
        # (presumably dropping tokens the embedding adds for x_mark -- confirm).
        dec_out = self.projection(out_cat).permute(0, 2, 1)[:, :, :N]
        return dec_out[:, -self.pred_len:, :]  # [B, L, D]

    def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, target_x=None):
        """Normalise the input, forecast, and undo the normalisation.

        Mean and standard deviation are taken over dim 1 of ``target_x`` when
        it is given, otherwise of ``x_enc``.  The caller's ``x_enc`` tensor is
        not modified.
        """
        # Normalization from Non-stationary Transformer
        stats_source = target_x if target_x is not None else x_enc
        means = stats_source.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        # As before, the variance fallback uses the already-centred x_enc.
        var_source = target_x if target_x is not None else x_enc
        stdev = torch.sqrt(torch.var(var_source, dim=1, keepdim=True, unbiased=False) + 1e-5)
        # Out-of-place on purpose: an in-place divide would overwrite the tensor
        # torch.var saved for backward and break autograd when x_enc requires grad.
        x_enc = x_enc / stdev

        dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)

        # De-Normalization from Non-stationary Transformer
        # stdev / means keep a size-1 dim 1, so they broadcast over the horizon.
        dec_out = dec_out * stdev
        dec_out = dec_out + means
        return dec_out
