import torch
import torch.nn as nn
import torch.nn.functional as F
from model.layers.Embed import DataEmbedding
from model.layers.AutoCorrelation import AutoCorrelationLayer
from model.layers.FourierCorrelation import FourierBlock
from model.layers.MultiWaveletCorrelation import MultiWaveletTransform
from model.layers.Autoformer_EncDec import Encoder, EncoderLayer, my_Layernorm
from model.base import LOBAutoEncoder


class FEDformer_AE(LOBAutoEncoder):
    """
    FEDformer encoder wrapped in an autoencoder-style model.

    FEDformer performs the attention mechanism in the frequency domain and achieves O(N) complexity.
    Paper link: https://proceedings.mlr.press/v162/zhou22g.html

    The encoder is a stack of ``EncoderLayer``s whose self-attention is replaced by a
    Fourier block (``FourierBlock``) or a multi-wavelet block (``MultiWaveletTransform``).
    ``encode`` flattens the encoder output and projects it to one ``d_model``-sized
    vector per sample. ``forward`` passes the full encoder output through a
    ``nn.TransformerDecoder`` and projects it to ``c_out`` features.
    """

    def __init__(self,
                 seq_len,
                 pred_len,
                 moving_avg,
                 enc_in,
                 d_model,
                 embed,
                 freq,
                 dropout,
                 dim_ff,
                 c_out,
                 activation,
                 e_layers,
                 n_heads,
                 version='fourier', mode_select='random', modes=32,
                 *, dec_n_heads=8, dec_layers=6, **kwargs):
        """
        seq_len: int, input sequence length. It sizes the Fourier block and the
            flattened input (``d_model * seq_len``) of ``linear_encoding``.
        pred_len: int, stored on the instance; not used by the code in this class.
        moving_avg: forwarded to each ``EncoderLayer``.
        enc_in, embed, freq: forwarded to ``DataEmbedding`` together with ``d_model``
            and ``dropout``.
        d_model: int, model width. It must be divisible by ``dec_n_heads``, which the
            PyTorch decoder's multi-head attention requires.
        dropout: float, forwarded to ``DataEmbedding`` and each ``EncoderLayer``.
        dim_ff: int, feed-forward width of each ``EncoderLayer``.
        c_out: int, output features of the final projection in ``forward``.
        activation: forwarded to each ``EncoderLayer``.
        e_layers: int, number of encoder layers.
        n_heads: int, number of heads of each encoder ``AutoCorrelationLayer``.
        version: str, for FEDformer, there are two versions to choose, options: [Fourier, Wavelets].
            Only the exact (case-sensitive) string 'Wavelets' selects the wavelet block;
            any other value selects the Fourier block.
        mode_select: str, mode selection method of the Fourier block, options: [random, low].
        modes: int, number of modes selected by the Fourier block.
        dec_n_heads: int, attention heads of the Transformer decoder (default 8).
        dec_layers: int, number of Transformer decoder layers (default 6).
        **kwargs: forwarded to ``LOBAutoEncoder.__init__``.
        """
        super().__init__(**kwargs)
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.d_model = d_model
        self.version = version
        self.mode_select = mode_select
        self.modes = modes

        # Embedding
        self.enc_embedding = DataEmbedding(enc_in, d_model, embed, freq, dropout)

        # A single attention block instance is shared by every encoder layer below.
        if self.version == 'Wavelets':
            encoder_self_att = MultiWaveletTransform(ich=d_model, L=1, base='legendre')
        else:
            encoder_self_att = FourierBlock(in_channels=d_model,
                                            out_channels=d_model,
                                            seq_len=self.seq_len,
                                            modes=self.modes,
                                            mode_select_method=self.mode_select)
        # Encoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AutoCorrelationLayer(
                        encoder_self_att,  # instead of multi-head attention in transformer
                        d_model, n_heads),
                    d_model,
                    dim_ff,
                    moving_avg=moving_avg,
                    dropout=dropout,
                    activation=activation
                ) for _ in range(e_layers)
            ],
            norm_layer=my_Layernorm(d_model)
        )
        # Bottleneck used by ``encode``: flattened (seq_len * d_model) -> d_model.
        self.linear_encoding = nn.Linear(in_features=d_model * seq_len, out_features=d_model)

        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead=dec_n_heads, batch_first=True)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=dec_layers)
        self.projection = nn.Linear(d_model, c_out, bias=True)

    def encode(self, x):
        """
        Encode ``x`` into one ``d_model``-sized vector per sample.

        x: input passed to ``DataEmbedding``; presumably (batch, seq_len, enc_in) — TODO confirm.
        Returns a tensor of shape (enc_out.shape[0], d_model), presumably (batch, d_model).
        """
        enc_out = self.enc_embedding(x, None)
        enc_out, _ = self.encoder(enc_out, attn_mask=None)
        # Flatten everything except the leading (batch) dimension. Keeping that dimension
        # explicit, instead of view(-1, seq_len * d_model), makes a sequence-length mismatch
        # fail in linear_encoding rather than silently becoming extra batch rows.
        # reshape also accepts non-contiguous tensors, where view would raise.
        enc_out = enc_out.reshape(enc_out.shape[0], -1)
        return self.linear_encoding(enc_out)

    def forward(self, x):
        """
        Run the encoder and the Transformer decoder, then project to ``c_out`` features.

        x: input passed to ``DataEmbedding``; presumably (batch, seq_len, enc_in) — TODO confirm.
        Returns the decoder output passed through ``self.projection`` (last dim ``c_out``).
        """
        enc_out = self.enc_embedding(x, None)
        enc_out, _ = self.encoder(enc_out, attn_mask=None)
        # NOTE(review): the decoder target is the full encoder output and its memory is all
        # zeros. Cross-attention therefore only ever sees a constant input, and the
        # encode()/linear_encoding bottleneck is not used here. Confirm this is intended.
        # zeros_like keeps the memory on the same device *and dtype* as enc_out.
        # torch.zeros would always be float32 and mismatch half/bfloat16 decoder weights.
        memory = torch.zeros_like(enc_out)
        dec_out = self.decoder(enc_out, memory)
        return self.projection(dec_out)