import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.Informer_layer import DataEmbedding
from layers.Transformer_layer import Encoder, EncoderLayer, Decoder, DecoderLayer, ConvLayer
from layers.Attention import AttentionLayer, ProbAttention
from layers.RevIn import RevIN


class Informer(nn.Module):
    """Encoder-decoder forecasting model assembled from ProbAttention layers.

    The constructor wires together, in this order: an optional RevIN layer,
    the encoder and decoder input embeddings, the encoder stack, the decoder
    stack and a final linear projection onto ``args.c_out`` features.

    ``args`` is used purely as an attribute container. Fields read here:
    ``d_model``, ``pred_len``, ``output_attention``, ``revin``, ``enc_in``,
    ``dec_in``, ``embed``, ``pos_embed_type``, ``freq``, ``dropout``,
    ``factor``, ``n_heads``, ``d_ff``, ``activation``, ``e_layers``,
    ``distil``, ``d_layers``, ``c_out`` and, only when ``revin`` is truthy,
    ``affine``.
    """

    def __init__(self, args) -> None:
        super().__init__()
        self.d_model = args.d_model
        self.pred_len = args.pred_len
        self.output_attention = args.output_attention
        self.revin = args.revin

        # The RevIN layer only exists when the option is enabled; forward()
        # checks the same flag before touching it.
        if self.revin:
            self.revin_layer = RevIN(args.enc_in, affine=args.affine, subtract_last=False)

        # Input embeddings for the encoder and decoder streams.
        self.enc_embedding = DataEmbedding(
            args.enc_in, args.d_model, args.embed,
            args.pos_embed_type, args.freq, args.dropout,
        )
        self.dec_embedding = DataEmbedding(
            args.dec_in, args.d_model, args.embed,
            args.pos_embed_type, args.freq, args.dropout,
        )

        # Sub-modules are created and registered in a fixed order
        # (encoder, decoder, projection) so that seeded initialisation and
        # state_dict key order stay stable.
        self.encoder = self._build_encoder(args)
        self.decoder = self._build_decoder(args)
        self.projection = nn.Linear(args.d_model, args.c_out, bias=True)

    @staticmethod
    def _build_encoder(args):
        """Create the encoder: ``e_layers`` attention blocks plus optional conv blocks."""
        attn_blocks = []
        for _ in range(args.e_layers):
            prob_attn = ProbAttention(
                False,
                args.factor,
                attention_dropout=args.dropout,
                output_attention=args.output_attention,
            )
            attn_blocks.append(
                EncoderLayer(
                    AttentionLayer(prob_attn, args.d_model, args.n_heads),
                    args.d_model,
                    args.d_ff,
                    dropout=args.dropout,
                    activation=args.activation,
                )
            )

        # With distilling enabled there is one ConvLayer fewer than there are
        # attention blocks; otherwise no conv blocks are passed at all.
        if args.distil:
            conv_blocks = []
            for _ in range(args.e_layers - 1):
                conv_blocks.append(ConvLayer(args.d_model))
        else:
            conv_blocks = None

        final_norm = torch.nn.LayerNorm(args.d_model)
        return Encoder(attn_blocks, conv_blocks, norm_layer=final_norm)

    @staticmethod
    def _build_decoder(args):
        """Create the decoder: ``d_layers`` blocks, each holding two attention layers."""
        blocks = []
        for _ in range(args.d_layers):
            # Presumably the self-attention (first ProbAttention flag True) and
            # the cross-attention (flag False) -- confirm against DecoderLayer.
            first_attn = AttentionLayer(
                ProbAttention(True, args.factor, attention_dropout=args.dropout, output_attention=False),
                args.d_model,
                args.n_heads,
            )
            second_attn = AttentionLayer(
                ProbAttention(False, args.factor, attention_dropout=args.dropout, output_attention=False),
                args.d_model,
                args.n_heads,
            )
            blocks.append(
                DecoderLayer(
                    first_attn,
                    second_attn,
                    args.d_model,
                    args.d_ff,
                    dropout=args.dropout,
                    activation=args.activation,
                )
            )
        return Decoder(blocks, norm_layer=torch.nn.LayerNorm(args.d_model))

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec,
                enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
        """Run the encoder-decoder pass and return the last ``pred_len`` steps.

        Args:
            x_enc: encoder input; the original author's notes give its shape
                as [batch_size x seq_len x channels] (not checked here).
            x_mark_enc: companion tensor handed to the encoder embedding
                together with ``x_enc``.
            x_dec: decoder input; per the original notes
                [batch_size x (label_len + pred_len) x channels].
            x_mark_dec: companion tensor handed to the decoder embedding
                together with ``x_dec``.
            enc_self_mask: forwarded to the encoder as ``attn_mask``.
            dec_self_mask: forwarded to the decoder as ``x_mask``.
            dec_enc_mask: forwarded to the decoder as ``cross_mask``.

        Returns:
            The projected decoder output sliced to its last ``pred_len``
            positions along dim 1. When ``output_attention`` is truthy, a
            ``(output, attns)`` tuple is returned instead, where ``attns`` is
            the second value produced by the encoder.
        """
        if self.revin:
            x_enc = self.revin_layer(x_enc, 'norm')
            # NOTE(review): x_dec is not passed through RevIN, so the decoder
            # input stays in the original scale -- confirm this is intended.

        # Move every other input onto the device that holds x_enc.
        target_device = x_enc.device
        x_mark_enc = x_mark_enc.to(target_device)
        x_dec = x_dec.to(target_device)
        x_mark_dec = x_mark_dec.to(target_device)

        # Embed the encoder input and run the encoder. Original author's note:
        # Informer works point-wise and mainly uses 1-D convolution to map each
        # data point's features onto different channels; the sequence axis is
        # presumably shortened by the conv blocks when distilling is enabled
        # -- confirm against Encoder/ConvLayer.
        memory, attns = self.encoder(
            self.enc_embedding(x_enc, x_mark_enc),
            attn_mask=enc_self_mask,
        )

        # The decoder input goes through the same embed-then-process steps.
        decoded = self.decoder(
            self.dec_embedding(x_dec, x_mark_dec),
            memory,
            x_mask=dec_self_mask,
            cross_mask=dec_enc_mask,
        )
        # Linear map from d_model features to c_out features.
        decoded = self.projection(decoded)

        if self.revin:
            # NOTE(review): RevIN was built for enc_in features but is applied
            # to a c_out-wide tensor here; presumably requires
            # c_out == enc_in -- confirm against RevIN.
            decoded = self.revin_layer(decoded, 'denorm')

        forecast = decoded[:, -self.pred_len:, :]  # [B, L, D]
        if not self.output_attention:
            return forecast
        return forecast, attns
