import torch
import torch.nn as nn
import torch.nn.functional as F
from gluonts.torch.model.fedformer.layers.Embed import (
    DataEmbedding,
    DataEmbedding_wo_pos,
    DataEmbedding_wo_pos_temp,
    DataEmbedding_wo_temp,
)
from gluonts.torch.model.fedformer.layers.AutoCorrelation import (
    AutoCorrelation,
    AutoCorrelationLayer,
)
from gluonts.torch.model.fedformer.layers.FourierCorrelation import (
    FourierBlock,
    FourierCrossAttention,
)

# from gluonts.torch.model.fedformer.layers.MultiWaveletCorrelation import MultiWaveletCross, MultiWaveletTransform
from gluonts.torch.model.fedformer.layers.SelfAttention_Family import (
    FullAttention,
    ProbAttention,
)
from gluonts.torch.model.fedformer.layers.Autoformer_EncDec import (
    Encoder,
    Decoder,
    EncoderLayer,
    DecoderLayer,
    my_Layernorm,
    series_decomp,
    series_decomp_multi,
)
import math
import numpy as np


# Module-wide default device: the first CUDA GPU when one is visible, else the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


class FEDformer(nn.Module):
    """
    FEDformer performs the attention mechanism in the frequency domain and
    achieves O(N) complexity.

    The input window is split into seasonal and trend components. The
    encoder processes the full window with Fourier-block self-attention.
    The decoder receives the last ``label_len`` steps of the seasonal and
    trend components, extended by ``pred_len`` horizon steps (zeros for the
    seasonal part, the window mean for the trend part). The last
    ``pred_len`` steps of ``trend + seasonal`` are returned.

    Only the Fourier variant is built. ``version`` is stored but does not
    change which attention modules are created. ``L``, ``base`` and
    ``cross_activation`` are accepted for the (unavailable) multi-wavelet
    variant and are currently ignored.
    """

    def __init__(
        self,
        freq,
        seq_len,
        label_len,
        pred_len,
        enc_in,
        dec_in,
        c_out,
        version="Fourier",
        mode_select="random",
        modes=64,
        embed="timeF",
        moving_avg=None,
        d_model=512,
        L=3,
        base="legendre",
        cross_activation="tanh",
        n_heads=8,
        d_ff=2048,
        dropout=0.05,
        activation="gelu",
        e_layers=2,
        d_layers=1,
    ):
        """
        Build the embeddings, decomposition, encoder and decoder.

        Parameters
        ----------
        freq, embed, dropout
            Forwarded to ``DataEmbedding_wo_pos`` (``dropout`` is also
            passed to the encoder/decoder layers).
        seq_len
            Length of the encoder input window.
        label_len
            Number of trailing context steps handed to the decoder.
        pred_len
            Forecast horizon.
        enc_in, dec_in
            Input channel counts of the encoder / decoder embeddings.
        c_out
            Output channel count of the decoder projection.
        version, L, base, cross_activation
            Kept for signature compatibility; they do not affect the
            modules built here.
        mode_select, modes
            Mode-selection method and mode budget for the Fourier blocks.
        moving_avg
            Kernel size(s) for the series decomposition. A ``list`` selects
            ``series_decomp_multi``, anything else ``series_decomp``.
            Defaults to ``[7]``.
        d_model, n_heads, d_ff, activation
            Model width, attention heads, feed-forward width and activation
            of the encoder/decoder layers.
        e_layers, d_layers
            Number of encoder / decoder layers.
        """
        import logging  # local import keeps this block self-contained

        super().__init__()
        if moving_avg is None:
            # Fresh list per instance instead of a shared mutable default.
            moving_avg = [7]
        self.version = version
        self.mode_select = mode_select
        self.modes = modes
        self.seq_len = seq_len
        self.label_len = label_len
        self.pred_len = pred_len
        self.d_model = d_model

        self.enc_embedding = DataEmbedding_wo_pos(
            enc_in, d_model, embed, freq, dropout
        )
        self.dec_embedding = DataEmbedding_wo_pos(
            dec_in, d_model, embed, freq, dropout
        )

        # Decomposition: a list of kernel sizes selects the multi-kernel variant.
        kernel_size = moving_avg
        if isinstance(kernel_size, list):
            self.decomp = series_decomp_multi(kernel_size)
        else:
            self.decomp = series_decomp(kernel_size)

        # Length of the sequence the decoder actually receives in ``forward``:
        # the kept context steps plus the horizon. Using ``seq_len // 2``
        # here is only correct when ``label_len == seq_len // 2``.
        dec_len = self._decoder_length(seq_len, label_len, pred_len)

        # Only Fourier attention is built; the multi-wavelet modules
        # (MultiWaveletTransform / MultiWaveletCross) are not available.
        encoder_self_att = FourierBlock(
            in_channels=d_model,
            out_channels=d_model,
            seq_len=self.seq_len,
            modes=modes,
            mode_select_method=mode_select,
        )
        decoder_self_att = FourierBlock(
            in_channels=d_model,
            out_channels=d_model,
            seq_len=dec_len,
            modes=modes,
            mode_select_method=mode_select,
        )
        decoder_cross_att = FourierCrossAttention(
            in_channels=d_model,
            out_channels=d_model,
            seq_len_q=dec_len,
            seq_len_kv=self.seq_len,
            modes=modes,
            mode_select_method=mode_select,
        )

        enc_modes = min(modes, seq_len // 2)
        dec_modes = min(modes, dec_len // 2)
        logging.getLogger(__name__).debug(
            "enc_modes: %s, dec_modes: %s", enc_modes, dec_modes
        )

        # NOTE: each attention module above is created once and wrapped by
        # every layer below, so it is shared across layers.
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AutoCorrelationLayer(encoder_self_att, d_model, n_heads),
                    d_model,
                    d_ff,
                    moving_avg=moving_avg,
                    dropout=dropout,
                    activation=activation,
                )
                for _ in range(e_layers)
            ],
            norm_layer=my_Layernorm(d_model),
        )
        self.decoder = Decoder(
            [
                DecoderLayer(
                    AutoCorrelationLayer(decoder_self_att, d_model, n_heads),
                    AutoCorrelationLayer(decoder_cross_att, d_model, n_heads),
                    d_model,
                    c_out,
                    d_ff,
                    moving_avg=moving_avg,
                    dropout=dropout,
                    activation=activation,
                )
                for _ in range(d_layers)
            ],
            norm_layer=my_Layernorm(d_model),
            projection=nn.Linear(d_model, c_out, bias=True),
        )

    @staticmethod
    def _decoder_length(seq_len, label_len, pred_len):
        """Return the decoder input length: kept context steps plus the horizon."""
        # At most ``seq_len`` context steps exist; a non-positive label_len keeps none.
        return min(max(label_len, 0), seq_len) + pred_len

    def _init_decoder_inputs(self, x_enc):
        """
        Build the seasonal and trend decoder initialisations from ``x_enc``.

        Both outputs consist of the last ``label_len`` decomposed steps
        followed by ``pred_len`` horizon steps: zeros for the seasonal part,
        the per-channel mean of ``x_enc`` over time for the trend part.
        """
        seasonal, trend = self.decomp(x_enc)

        def _tail(x):
            # Keep the last ``label_len`` steps along dim 1. An explicit start
            # index makes label_len == 0 keep nothing (``x[:, -0:]`` keeps all).
            length = x.shape[1]
            return x[:, length - min(max(self.label_len, 0), length) :, :]

        mean = x_enc.mean(dim=1, keepdim=True).expand(-1, self.pred_len, -1)
        trend_init = torch.cat([_tail(trend), mean], dim=1)
        # Pad (time) dim -2 at the end with ``pred_len`` zero steps.
        seasonal_init = F.pad(_tail(seasonal), (0, 0, 0, self.pred_len))
        return seasonal_init, trend_init

    def forward(
        self,
        x_enc,
        x_mark_enc,
        x_mark_dec,
        enc_self_mask=None,
        dec_self_mask=None,
        dec_enc_mask=None,
    ):
        """
        Forecast the next ``pred_len`` steps.

        Parameters
        ----------
        x_enc
            Encoder input, indexed as ``(batch, time, channels)``; the time
            length is expected to equal ``seq_len`` (TODO confirm at callers).
        x_mark_enc
            Time features for the encoder embedding.
        x_mark_dec
            Time features for the decoder embedding; presumably spanning the
            kept ``label_len`` steps plus ``pred_len`` — verify against callers.
        enc_self_mask, dec_self_mask, dec_enc_mask
            Optional masks forwarded to encoder self-attention, decoder
            self-attention and decoder cross-attention respectively.

        Returns
        -------
        The last ``pred_len`` steps (dim 1) of ``trend_part + seasonal_part``.
        """
        seasonal_init, trend_init = self._init_decoder_inputs(x_enc)

        # Encoder; the attention maps it also returns are not used here.
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, _ = self.encoder(enc_out, attn_mask=enc_self_mask)

        # Decoder: the seasonal init is embedded, the trend init is passed as ``trend``.
        dec_out = self.dec_embedding(seasonal_init, x_mark_dec)
        seasonal_part, trend_part = self.decoder(
            dec_out,
            enc_out,
            x_mask=dec_self_mask,
            cross_mask=dec_enc_mask,
            trend=trend_init,
        )
        dec_out = trend_part + seasonal_part

        # Explicit start index so pred_len == 0 yields an empty horizon
        # instead of the whole sequence (``x[:, -0:]``).
        start = max(dec_out.shape[1] - self.pred_len, 0)
        return dec_out[:, start:, :]
