import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.Transformer_EncDec import Encoder, EncoderLayer
from layers.SelfAttention_Family import FullAttention, AttentionLayer
from layers.Embed import DataEmbedding_inverted
import numpy as np
from model import VanillaVAE

class Model(nn.Module):
    """
    Paper link: https://arxiv.org/abs/2310.06625

    Encoder-only forecaster with two VAE branches:
    a prior VAE fed the embedding of a randomly masked input, and a posterior
    VAE fed the embedding of the unmasked input. The masked embedding is also
    run through the main Transformer encoder. Its output is concatenated with
    the prior latent ``prior_z`` and projected to ``pred_len`` steps. The KL
    divergence between the prior and posterior latent Gaussians is returned
    alongside the prediction by :meth:`reconstrction` / :meth:`forecast`.
    """

    # torch.cat dimension used for each concatenation mode.
    # 'E_cat' joins along the last (feature) axis, 'N_cat' along the token axis.
    # NOTE(review): assumes prior_z has the same layout as the encoder output
    # (B, N, d_model) — depends on VanillaVAE.encoder_forward; confirm.
    _CAT_DIMS = {'E_cat': 2, 'N_cat': 1}

    def __init__(self, configs, mask_rate=0.5):
        """
        :param configs: experiment config object. Reads seq_len, pred_len,
            output_attention, use_norm, d_model, embed, freq, dropout,
            class_strategy, factor, n_heads, d_ff, activation, e_layers, and
            optionally cat_mode ('E_cat' by default).
        :param mask_rate: fraction of input time steps zeroed by random_mask.
        :raises ValueError: if cat_mode is not one of ``_CAT_DIMS``.
        """
        super().__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.output_attention = configs.output_attention
        self.use_norm = configs.use_norm
        # Embeddings for the unmasked and masked views of the input.
        self.enc_embedding = DataEmbedding_inverted(configs.seq_len, configs.d_model, configs.embed, configs.freq,
                                                    configs.dropout)
        self.mask_enc_embedding = DataEmbedding_inverted(configs.seq_len, configs.d_model, configs.embed,
                                                         configs.freq, configs.dropout)

        self.class_strategy = configs.class_strategy
        # Encoder-only architecture
        self.encoder = self._build_encoder(configs)

        self.cat_mode = getattr(configs, 'cat_mode', 'E_cat')
        if self.cat_mode not in self._CAT_DIMS:
            raise ValueError(
                f"Unknown cat_mode {self.cat_mode!r}; expected one of {sorted(self._CAT_DIMS)}")
        # Feature concatenation doubles the token width; token concatenation does not.
        proj_in = configs.d_model * 2 if self.cat_mode == 'E_cat' else configs.d_model
        self.projector = nn.Linear(proj_in, configs.pred_len, bias=True)

        self.mask_rate = mask_rate

        # Prior branch sees the masked input, posterior branch the unmasked one.
        self.prior_vae_encoder = VanillaVAE(in_channels=3, latent_dim=configs.d_model,
                                            encoder=self._build_encoder(configs))
        self.post_vae_encoder = VanillaVAE(in_channels=3, latent_dim=configs.d_model,
                                           encoder=self._build_encoder(configs))

    @staticmethod
    def _build_encoder(configs):
        """Build an ``e_layers``-deep Transformer encoder over ``d_model``-wide tokens."""
        return Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        FullAttention(False, configs.factor, attention_dropout=configs.dropout,
                                      output_attention=configs.output_attention),
                        configs.d_model, configs.n_heads),
                    configs.d_model,
                    configs.d_ff,
                    dropout=configs.dropout,
                    activation=configs.activation,
                ) for _ in range(configs.e_layers)
            ],
            norm_layer=torch.nn.LayerNorm(configs.d_model),
        )

    def random_mask(self, x, batch_size, seq_len, mask_rate=0.5):
        """Return a copy of ``x`` with a random subset of time steps zeroed.

        For every batch row, ``ceil(seq_len * mask_rate)`` positions along
        dim 1 (clamped to ``[0, seq_len]``) are chosen uniformly at random.
        All values at those positions are set to 0. The input tensor is
        never modified, so callers can keep using the unmasked ``x``.

        :param x: tensor whose first two dims are (batch_size, seq_len).
        :param batch_size: number of rows to mask (dim 0 of ``x``).
        :param seq_len: number of positions per row (dim 1 of ``x``).
        :param mask_rate: fraction of positions to zero.
        :return: masked copy of ``x``.
        """
        num_masked_tokens = min(int(np.ceil(seq_len * mask_rate)), seq_len)
        # Clone so the caller's tensor (and autograd history) stays untouched.
        x = x.clone()
        if num_masked_tokens <= 0:
            return x
        noise = torch.rand(batch_size, seq_len, device=x.device)  # noise in [0, 1)
        # Rank of each position's noise within its row. The num_masked_tokens
        # smallest are masked: exactly that many per row, even if noise ties.
        ranks = noise.argsort(dim=1).argsort(dim=1)
        token_mask = ranks < num_masked_tokens  # True = mask, False = keep
        x[token_mask] = 0
        return x

    def kl_loss(self, mu, log_var, target_mu, target_log_var):
        """
        KL divergence KL(N(mu, exp(log_var)) || N(target_mu, exp(target_log_var)))
        between two diagonal Gaussians.

        The element-wise KL is summed over dim 1 and then averaged over all
        remaining elements.

        :param mu: (Tensor) mean of the first distribution (here: the masked/prior branch)
        :param log_var: (Tensor) log-variance of the first distribution
        :param target_mu: (Tensor) mean of the target distribution (here: the unmasked/posterior branch)
        :param target_log_var: (Tensor) log-variance of the target distribution
        :return: (Tensor) scalar KL loss
        """
        kl_loss = 0.5 * torch.sum(target_log_var - log_var + torch.exp(log_var - target_log_var) +
                                  (mu - target_mu) ** 2 / torch.exp(target_log_var) - 1, dim=1)
        return torch.mean(kl_loss)

    @staticmethod
    def _normalize(x_enc):
        """Per-series instance normalization (Non-stationary Transformer).

        :return: (normalized x_enc, means, stdev); statistics are taken over dim 1
            with keepdim=True.
        """
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        # Out-of-place division: x_enc is masked later, and an in-place op here is unnecessary.
        return x_enc / stdev, means, stdev

    @staticmethod
    def _denormalize(dec_out, means, stdev):
        """Undo :meth:`_normalize`; ``means``/``stdev`` broadcast over dim 1 of ``dec_out``."""
        return dec_out * stdev + means

    def reconstrction(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Predict ``pred_len`` steps from a randomly masked input.

        ``x_dec`` and ``x_mark_dec`` are accepted for interface compatibility
        and are unused.

        :return: (dec_out of shape (B, pred_len, N), scalar KL loss)
        """
        if self.use_norm:
            x_enc, means, stdev = self._normalize(x_enc)

        B, L, N = x_enc.shape
        # B: batch_size;    E: d_model;
        # L: seq_len;       S: pred_len;
        # N: number of variate (tokens), can also include covariates

        # random_mask returns a masked copy, so x_enc stays unmasked for the posterior branch.
        x_enc_mask = self.random_mask(x_enc, batch_size=B, seq_len=L, mask_rate=self.mask_rate)

        # Embedding: B L N -> B N E (covariates such as timestamps become extra tokens)
        enc_embed = self.enc_embedding(x_enc, x_mark_enc)
        enc_embed_mask = self.mask_enc_embedding(x_enc_mask, x_mark_enc)

        # Prior VAE on the masked view, posterior VAE on the unmasked view.
        prior_z, mu, log_var = self.prior_vae_encoder.encoder_forward(enc_embed_mask)
        _, target_mu, target_log_var = self.post_vae_encoder.encoder_forward(enc_embed)
        kl_loss = self.kl_loss(mu, log_var, target_mu, target_log_var)

        # B N E -> B N E: attention/LayerNorm/FFN over the inverted (variate) tokens.
        enc_out, _ = self.encoder(enc_embed_mask, attn_mask=None)

        concatenated_enc_out = torch.cat((enc_out, prior_z), dim=self._CAT_DIMS[self.cat_mode])

        # B N' E' -> B N' S -> B S N', then keep only the first N (drop covariate tokens)
        dec_out = self.projector(concatenated_enc_out).permute(0, 2, 1)[:, :, :N]

        if self.use_norm:
            dec_out = self._denormalize(dec_out, means, stdev)

        return dec_out, kl_loss

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Forecast ``pred_len`` steps; returns ``(dec_out, kl_loss)``.

        Uses the same masked-VAE pipeline as :meth:`reconstrction`. The
        projector is sized for the concatenated encoder + prior-latent
        features, so the encoder output cannot be projected on its own.
        """
        return self.reconstrction(x_enc, x_mark_enc, x_dec, x_mark_dec)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Return the last ``pred_len`` predicted steps, shape (B, pred_len, N).

        The KL term of the most recent call is stored in ``self.last_kl_loss``
        so training code can add it to the loss without a signature change.
        NOTE(review): random masking is applied in eval mode as well — confirm
        this is intended.
        """
        dec_out, kl_loss = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
        self.last_kl_loss = kl_loss
        return dec_out[:, -self.pred_len:, :]  # [B, S, N]