import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from layers.Autoformer_EncDec import series_decomp


class GEGLU(nn.Module):
    """Gated GELU activation.

    The last axis is split into two halves with ``torch.chunk``; the first half
    is multiplied element-wise by GELU of the second half, so ``[..., 2k]``
    becomes ``[..., k]``.
    """

    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim=-1)
        gated = F.gelu(gate)
        return value * gated


class QueryAdaptiveMasking(nn.Module):
    """Position-dependent masking along one axis, active only in training mode.

    The drop probability rises linearly from ``start_prob`` at the first
    position along ``dim`` to ``end_prob`` at the last one. One Bernoulli
    keep/drop decision is drawn per position on that axis and broadcast over all
    other axes, so every element sharing a position is kept or dropped together.
    Kept values are not rescaled (unlike ``nn.Dropout``). In eval mode the input
    is returned unchanged.

    Args:
        dim: axis along which the probability ramp is laid out; negative values
            count from the end, as in ``x.shape[dim]``.
        start_prob: drop probability at the first position.
        end_prob: drop probability at the last position.
    """

    def __init__(self, dim=1, start_prob=0.1, end_prob=0.5):
        super().__init__()
        self.dim = dim
        self.start_prob = start_prob
        self.end_prob = end_prob

    def forward(self, x):
        if not self.training:
            return x
        # Normalise a negative axis so the view shape below places the ramp on
        # the intended axis; out-of-range axes still raise from x.shape[dim].
        dim = self.dim + x.dim() if self.dim < 0 else self.dim
        size = x.shape[dim]
        view_shape = [-1 if i == dim else 1 for i in range(x.dim())]
        dropout_prob = torch.linspace(self.start_prob, self.end_prob, steps=size, device=x.device).view(view_shape)
        # Sample in linspace's default dtype, then cast the exact 0/1 mask to the
        # input dtype so reduced-precision inputs are not promoted by the product.
        mask = torch.bernoulli(1 - dropout_prob).to(x.dtype).expand_as(x)
        return x * mask


class Model_backbone(nn.Module):
    """Patch-based forecaster applied channel-wise.

    ``z`` of shape ``[bs x nvars x seq_len]`` is optionally end-padded, cut into
    patches with ``unfold``, encoded by ``Dummy_Embedding`` into
    ``pred_patch_num`` tokens of width ``d_model``, mapped back to raw values per
    patch by ``Projection`` and finally resized to ``pred_len`` by a linear
    layer. Output: ``[bs x nvars x pred_len]``.
    """

    def __init__(self, c_in: int, seq_len: int, pred_len: int, patch_len: int = 24, stride: int = 24, n_layers: int = 3,
                 d_model=128, n_heads=16, d_ff: int = 256,
                 attn_dropout: float = 0., dropout: float = 0., res_attention: bool = True, independence: bool = False,
                 store_attn: bool = False, QAM_start: float = 0.1,
                 QAM_end: float = 0.5, padding_patch=None):
        super().__init__()

        # Patching configuration
        self.patch_len = patch_len
        self.stride = stride
        self.padding_patch = padding_patch

        # Number of output tokens needed to cover pred_len (ceiling division).
        n_out_patches = (pred_len + patch_len - 1) // patch_len
        # Number of patches produced by unfold(size=patch_len, step=stride).
        n_in_patches = int((seq_len - patch_len) / stride + 1)
        if padding_patch == 'end':
            # Appending `stride` copies of the last value yields one extra patch.
            self.padding_patch_layer = nn.ReplicationPad1d((0, stride))
            n_in_patches += 1

        self.backbone = Dummy_Embedding(
            c_in,
            seq_patch_num=n_in_patches,
            patch_len=patch_len,
            pred_patch_num=n_out_patches,
            n_layers=n_layers,
            d_model=d_model,
            n_heads=n_heads,
            d_ff=d_ff,
            attn_dropout=attn_dropout,
            dropout=dropout,
            QAM_start=QAM_start,
            QAM_end=QAM_end,
            res_attention=res_attention,
            independence=independence,
            store_attn=store_attn,
        )

        self.n_vars = c_in
        self.pred_len = pred_len
        self.proj = Projection(d_model, patch_len)
        self.proj2 = nn.Linear(n_out_patches * patch_len, pred_len)

    def forward(self, z):
        """Map ``[bs x nvars x seq_len]`` to ``[bs x nvars x pred_len]``."""
        if self.padding_patch == 'end':
            z = self.padding_patch_layer(z)
        # [bs x nvars x seq_patch_num x patch_len]
        patches = z.unfold(dimension=-1, size=self.patch_len, step=self.stride)
        tokens = self.backbone(patches)  # [bs x nvars x pred_patch_num x d_model]
        flat = self.proj(tokens)  # [bs x nvars x pred_patch_num * patch_len]
        return self.proj2(flat)  # [bs x nvars x pred_len]


class Projection(nn.Module):
    """Project each ``d_model`` token to ``patch_len`` values and concatenate them.

    Shape: ``[..., n_patches, d_model] -> [..., n_patches * patch_len]``.
    """

    def __init__(self, d_model, patch_len):
        super().__init__()
        self.linear = nn.Linear(d_model, patch_len)
        self.flatten = nn.Flatten(start_dim=-2)

    def forward(self, x):
        return self.flatten(self.linear(x))


class Dummy_Embedding(nn.Module):
    """Embed input patches and decode learnable "dummy" query patches against them.

    Input patches are projected by ``W_P`` and offset by the learnable positional
    encoding ``PE``. The learnable ``dummies`` are projected with the same
    ``W_P`` and used as the ``pred`` input of ``Decoder``, which cross-attends
    from them to the embedded input patches.

    ``independence=True`` shares one set of dummies across all channels;
    otherwise there is one set per channel, and the input must then have exactly
    ``c_in`` channels.
    """

    def __init__(self, c_in, seq_patch_num, patch_len, pred_patch_num, n_layers=3, d_model=128, n_heads=16,
                 QAM_start=0.1, QAM_end=0.5,
                 d_ff=256, attn_dropout=0., dropout=0., store_attn=False, res_attention=True, independence=False):

        super().__init__()

        # Input encoding: linear projection of one raw patch to d_model.
        self.W_P = nn.Linear(patch_len, d_model)
        self.dropout = nn.Dropout(dropout)
        # Dummy query patches (raw patch space; projected by W_P in forward).
        self.independence = independence
        if self.independence:
            self.dummies = nn.Parameter(0.5 * torch.randn(pred_patch_num, patch_len))
        else:
            self.dummies = nn.Parameter(0.5 * torch.randn(c_in, pred_patch_num, patch_len))
        # Learnable positional encoding, initialised uniformly in [-0.02, 0.02).
        self.PE = nn.Parameter(0.04 * torch.rand(seq_patch_num, d_model) - 0.02)
        # Cross-attention decoder: dummy queries attend to the input patches.
        self.decoder = Decoder(seq_patch_num, d_model, n_heads, pred_patch_num, d_ff=d_ff, attn_dropout=attn_dropout,
                               dropout=dropout,
                               QAM_start=QAM_start, QAM_end=QAM_end, res_attention=res_attention, n_layers=n_layers,
                               store_attn=store_attn)

    def forward(self, x) -> Tensor:
        """Map ``[bs x nvars x seq_patch_num x patch_len]`` to ``[bs x nvars x pred_patch_num x d_model]``."""
        bs = x.shape[0]
        n_vars = x.shape[1]
        # Input encoding
        x = self.W_P(x) + self.PE  # x: [bs x nvars x seq_patch_num x d_model]
        # dummies: [pred_patch_num x d_model] (independent) or [nvars x pred_patch_num x d_model]
        dummies = self.W_P(self.dummies)
        x = torch.reshape(x, (bs * n_vars, x.shape[2], x.shape[3]))  # x: [bs * nvars x seq_patch_num x d_model]

        seq_patch = self.dropout(x)  # seq_patch: [bs * nvars x seq_patch_num x d_model]

        if self.independence:
            # Same dummies for every (sample, channel) pair.
            pred_patch = dummies.unsqueeze(0).repeat(bs * n_vars, 1, 1)
        else:
            # Per-channel dummies tiled over the batch, then flattened to match seq_patch.
            pred_patch = dummies.unsqueeze(0).repeat(bs, 1, 1, 1)
            pred_patch = torch.reshape(pred_patch, (bs * n_vars, pred_patch.shape[2], pred_patch.shape[3]))
        # pred_patch: [bs * nvars x pred_patch_num x d_model]

        # decoder
        z = self.decoder(seq_patch, pred_patch)  # z: [bs * nvars x pred_patch_num x d_model]
        z = torch.reshape(z, (-1, n_vars, z.shape[-2], z.shape[-1]))  # z: [bs x nvars x pred_patch_num x d_model]
        return z


class Decoder(nn.Module):
    """Stack of ``n_layers`` cross-attention ``DecoderLayer`` blocks.

    The prediction tokens ``pred`` are refined layer by layer while attending to
    ``seq``. With ``res_attention`` each layer's pre-softmax attention scores are
    passed to the next layer as ``prev``. Only the final ``pred`` is returned.

    NOTE(review): the default ``d_ff=None`` is forwarded unchanged to
    DecoderLayer, which computes ``d_ff // 2``; Dummy_Embedding always passes its
    own integer ``d_ff``.
    """

    def __init__(self, seq_patch_num, d_model, n_heads, pred_patch_num, d_ff=None, attn_dropout=0., dropout=0.,
                 QAM_start=0.1, QAM_end=0.5,
                 res_attention=False, n_layers=1, store_attn=False):
        super().__init__()
        layer_kwargs = dict(n_heads=n_heads, d_ff=d_ff, QAM_start=QAM_start, QAM_end=QAM_end,
                            attn_dropout=attn_dropout, dropout=dropout,
                            res_attention=res_attention, store_attn=store_attn)
        self.layers = nn.ModuleList(
            [DecoderLayer(seq_patch_num, d_model, pred_patch_num, **layer_kwargs) for _ in range(n_layers)]
        )
        self.res_attention = res_attention

    def forward(self, seq: Tensor, pred: Tensor):
        scores = None
        for layer in self.layers:
            if self.res_attention:
                seq, pred, scores = layer(seq, pred, prev=scores)
            else:
                seq, pred = layer(seq, pred)
        return pred


class DecoderLayer(nn.Module):
    """One cross-attention block: ``pred`` attends to ``seq``, then a GEGLU feed-forward.

    Each sub-layer output passes through ``QueryAdaptiveMasking`` along dim 1
    before the residual add and LayerNorm (post-norm). ``seq`` is returned
    unchanged so layers can be chained.

    Raises:
        AssertionError: if ``d_model`` is not divisible by ``n_heads``.
        ValueError: if ``d_ff`` is odd; GEGLU halves the hidden width and the
            second linear layer expects exactly ``d_ff // 2`` features.
    """

    def __init__(self, seq_patch_num, d_model, pred_patch_num, n_heads, d_ff=256, store_attn=False, QAM_start=0.1,
                 QAM_end=0.5,
                 attn_dropout=0, dropout=0., bias=True, res_attention=False):
        super().__init__()
        assert not d_model % n_heads, f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
        if d_ff % 2:
            # Fail early instead of with a shape mismatch inside GEGLU on the first forward.
            raise ValueError(f"d_ff ({d_ff}) must be even: GEGLU splits the hidden layer into two halves")
        # Multi-Head attention
        self.res_attention = res_attention
        self.cross_attn = _MultiheadAttention(d_model, n_heads, attn_dropout=attn_dropout, proj_dropout=dropout,
                                              res_attention=res_attention)

        # Add & Norm
        self.dropout_attn = QueryAdaptiveMasking(dim=1, start_prob=QAM_start, end_prob=QAM_end)
        self.norm_attn = nn.LayerNorm(d_model)
        # Position-wise Feed-Forward (GEGLU halves d_ff before the output projection)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_ff, bias=bias),
                                 GEGLU(),
                                 nn.Dropout(dropout),
                                 nn.Linear(d_ff // 2, d_model, bias=bias))

        # Add & Norm
        self.dropout_ffn = QueryAdaptiveMasking(dim=1, start_prob=QAM_start, end_prob=QAM_end)
        self.norm_ffn = nn.LayerNorm(d_model)
        self.store_attn = store_attn

    def forward(self, seq: Tensor, pred: Tensor, prev=None) -> tuple:
        """Refine ``pred`` by attending to ``seq``.

        Args:
            seq: keys/values, e.g. ``[bs * nvars x seq_patch_num x d_model]``.
            pred: queries, e.g. ``[bs * nvars x pred_patch_num x d_model]``.
            prev: previous layer's pre-softmax scores (used only with res_attention).

        Returns:
            ``(seq, pred, scores)`` when ``res_attention`` is True, else ``(seq, pred)``.
        """
        # Multi-Head cross attention
        if self.res_attention:
            pred2, attn, scores = self.cross_attn(pred, seq, seq, prev)
        else:
            pred2, attn = self.cross_attn(pred, seq, seq)
        if self.store_attn:
            self.attn = attn
        pred = self.norm_attn(pred + self.dropout_attn(pred2))

        # Feed-forward
        pred = self.norm_ffn(pred + self.dropout_ffn(self.ffn(pred)))

        if self.res_attention:
            return seq, pred, scores
        return seq, pred


class _MultiheadAttention(nn.Module):
    """Scaled dot-product multi-head cross-attention.

    Shapes:
        Q:    ``[batch x n_queries x d_model]``
        K, V: ``[batch x n_keys x d_model]``

    Returns ``(output, attn_weights)``, plus the pre-softmax ``attn_scores`` when
    ``res_attention`` is True so the next layer can add them via ``prev``.
    ``attn_weights`` has shape ``[batch x n_queries x n_heads x n_keys]`` and is
    taken after attention dropout.
    """

    def __init__(self, d_model, n_heads, res_attention=False, attn_dropout=0., proj_dropout=0., qkv_bias=True):
        super().__init__()
        head_dim = d_model // n_heads
        inner_dim = head_dim * n_heads

        self.n_heads = n_heads
        self.d_h = head_dim
        self.scale = head_dim ** -0.5

        self.W_Q = nn.Linear(d_model, inner_dim, bias=qkv_bias)
        self.W_K = nn.Linear(d_model, inner_dim, bias=qkv_bias)
        self.W_V = nn.Linear(d_model, inner_dim, bias=qkv_bias)

        self.res_attention = res_attention
        self.attn_dropout = nn.Dropout(attn_dropout)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, d_model), nn.Dropout(proj_dropout))

    def forward(self, Q: Tensor, K: Tensor, V: Tensor, prev=None):
        batch = Q.size(0)
        head_shape = (batch, -1, self.n_heads, self.d_h)

        # Project and split into heads: [batch x tokens x heads x d_h]
        queries = self.W_Q(Q).view(*head_shape)
        keys = self.W_K(K).view(*head_shape)
        values = self.W_V(V).view(*head_shape)

        scores = torch.einsum('bphd, bshd -> bphs', queries, keys) * self.scale
        if prev is not None:
            scores = scores + prev  # residual attention from the previous layer

        weights = self.attn_dropout(F.softmax(scores, dim=-1))

        context = torch.einsum('bphs, bshd -> bphd', weights, values)
        merged = context.contiguous().view(batch, -1, self.n_heads * self.d_h)
        out = self.to_out(merged)

        return (out, weights, scores) if self.res_attention else (out, weights)


class Model(nn.Module):
    """Decomposition-based forecaster: multi-scale seasonal encoder + patch-linear trend encoder.

    ``forward`` takes ``x_enc`` of shape ``[Batch, Input length, Channel]`` and
    returns ``[Batch, pred_len, Channel]``. The input is split by
    ``series_decomp`` into seasonal and trend parts.

    - Seasonal part: encoded at several temporal resolutions. ``Model_backbone``
      handles the full resolution and plain linear layers handle the
      down-sampled copies. The concatenated outputs are projected to ``pred_len``.
    - Trend part: embedded as patches and mapped by a linear layer over the
      patch axis.

    The two forecasts are summed. Instance normalisation over the time axis
    (dim 1) is applied in ``forward`` and again inside each encoder.
    """

    def __init__(self, configs):
        super().__init__()

        # load parameters
        self.configs = configs
        self.c_in = configs.enc_in
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.n_layers = configs.d_layers
        self.n_heads = configs.n_heads
        self.d_model = configs.d_model
        self.d_ff = configs.d_ff
        self.dropout = configs.dropout
        independence = False
        self.patch_len = configs.patch_len
        self.stride = configs.stride
        self.padding_patch = 'end'
        store_attn = False

        self.QAM_start = configs.QAM_start
        self.QAM_end = configs.QAM_end

        self.num_down_layers = configs.multiscale_levels
        self.down_sampling_window = 2
        self.down_sampling_method = 'conv'
        self.subtract_last = False

        # Down-sampling operator applied repeatedly to build the coarser scales.
        if self.down_sampling_method == 'max':
            self.down_pool = torch.nn.MaxPool1d(self.down_sampling_window, return_indices=False)
        elif self.down_sampling_method == 'avg':
            self.down_pool = torch.nn.AvgPool1d(self.down_sampling_window)
        elif self.down_sampling_method == 'conv':
            # NOTE(review): torch.__version__ is compared to a string literal here;
            # confirm it compares as a version for the torch releases in use.
            padding = 1 if torch.__version__ >= '1.5.0' else 2
            self.down_pool = nn.Conv1d(in_channels=self.configs.enc_in, out_channels=self.configs.enc_in,
                                       kernel_size=3, padding=padding,
                                       stride=self.down_sampling_window,
                                       padding_mode='circular',
                                       bias=False)

        self.decomposition = series_decomp(configs.moving_avg)
        # Scale 0 uses the full patch/attention backbone (output length d_model);
        # scale i >= 1 uses a Linear from seq_len // 2**i to d_model // 2**i.
        # NOTE(review): these in_features presumably require seq_len to be
        # divisible by 2**multiscale_levels so they match the down-sampled lengths.
        self.season_enc = nn.ModuleList([
                                            Model_backbone(c_in=self.c_in, seq_len=self.seq_len,
                                                           pred_len=self.d_model, patch_len=self.patch_len,
                                                           stride=self.stride, n_layers=self.n_layers,
                                                           d_model=self.d_model, n_heads=self.n_heads,
                                                           d_ff=self.d_ff, dropout=self.dropout,
                                                           independence=independence,
                                                           store_attn=store_attn, padding_patch=self.padding_patch,
                                                           QAM_start=self.QAM_start, QAM_end=self.QAM_end)
                                        ] + [
                                            nn.Linear(in_features=self.seq_len // (2 ** i),
                                                      out_features=self.d_model // (2 ** i), bias=True)
                                            for i in range(1, self.num_down_layers + 1)
                                        ])  # first layer use backbone, others use Linear

        # Trend branch: patch embedding + linear map over the patch axis.
        pred_patch_num = (self.d_model + self.patch_len - 1) // self.patch_len
        seq_patch_num = int((self.seq_len - self.patch_len) / self.stride + 1)
        if self.padding_patch == 'end':
            self.padding_patch_layer = nn.ReplicationPad1d((0, self.stride))
            seq_patch_num += 1
        # NOTE(review): trend_dummies is never read in this class; it is kept so
        # state_dict keys (and existing checkpoints) stay unchanged.
        self.trend_dummies = nn.Parameter(0.5 * torch.randn(self.c_in, pred_patch_num, self.patch_len))
        self.trend_pe = nn.Parameter(0.04 * torch.rand(seq_patch_num, self.d_model) - 0.02)
        self.trend_wp = nn.Linear(self.patch_len, self.d_model)
        self.trend_enc = nn.Linear(seq_patch_num, pred_patch_num)
        self.trend_proj = Projection(self.d_model, self.patch_len)
        self.trend_proj2 = nn.Linear(pred_patch_num * self.patch_len, self.pred_len)

        # Final seasonal projection from the concatenated multi-scale features.
        proj_len = sum(configs.d_model // (2 ** i) for i in range(self.num_down_layers + 1))
        self.projection = nn.Linear(proj_len, configs.pred_len, bias=True)

    @staticmethod
    def _normalize(x):
        """Instance-normalise ``x`` over dim 1.

        Returns ``(x_norm, means, stdev)``; ``means`` is detached and both
        statistics keep dim 1 (size 1) so they broadcast back over time.
        """
        means = x.mean(1, keepdim=True).detach()
        x = x - means
        stdev = torch.sqrt(torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5)
        # Out-of-place division: an in-place `/=` would modify the tensor torch.var
        # saved for backward and break autograd whenever x requires grad.
        return x / stdev, means, stdev

    def __multi_scale_process_inputs(self, x_enc):
        """Return ``[x, down(x), down(down(x)), ...]`` with each entry shaped ``[B, T_i, C]``.

        For an unsupported down-sampling method only the original scale is returned.
        """
        if self.down_sampling_method not in ['conv', 'max', 'avg']:
            return [x_enc]
        current = x_enc.permute(0, 2, 1)  # B,T,C -> B,C,T for the 1-D pooling/conv
        scales = [x_enc]
        for _ in range(self.num_down_layers):
            current = self.down_pool(current)
            scales.append(current.permute(0, 2, 1))
        return scales

    def seasonal_encoder(self, seasonal_init):
        """Encode the seasonal part ``[B, T, C]`` into a forecast ``[B, C, pred_len]``."""
        seasonal_init, means, stdev = self._normalize(seasonal_init)
        scales = self.__multi_scale_process_inputs(seasonal_init)

        season_out_list = []
        for i, season_x in enumerate(scales):
            if self.subtract_last:
                seq_last = season_x[:, -1:, :].detach()
                season_x = season_x - seq_last
            season_out = self.season_enc[i](season_x.permute(0, 2, 1))
            if self.subtract_last:
                season_out = season_out + seq_last.permute(0, 2, 1)
            season_out_list.append(season_out)
        season_out = self.projection(torch.cat(season_out_list, -1))  # [B, C, pred_len]

        # De-normalise: statistics [B, 1, C] -> [B, C, 1], broadcast over pred_len.
        return season_out * stdev.permute(0, 2, 1) + means.permute(0, 2, 1)

    def trend_encoder(self, trend_init):
        """Encode the trend part ``[B, T, C]`` into a forecast ``[B, C, pred_len]``."""
        trend_init, means, stdev = self._normalize(trend_init)
        if self.subtract_last:
            seq_last = trend_init[:, -1:, :].detach()
            trend_init = trend_init - seq_last

        x = trend_init.permute(0, 2, 1)  # [B, C, T]
        # do patching
        if self.padding_patch == 'end':
            x = self.padding_patch_layer(x)
        x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)  # [B, C, seq_patch_num, patch_len]

        bs, n_vars = x.shape[0], x.shape[1]
        # Input encoding
        x = self.trend_wp(x) + self.trend_pe  # [B, C, seq_patch_num, d_model]
        x = torch.reshape(x, (bs * n_vars, x.shape[2], x.shape[3]))  # [B * C, seq_patch_num, d_model]
        # Linear map over the patch axis: seq_patch_num -> pred_patch_num tokens.
        trend_out = self.trend_enc(x.permute(0, 2, 1)).permute(0, 2, 1)  # [B * C, pred_patch_num, d_model]
        trend_out = torch.reshape(trend_out, (-1, n_vars, trend_out.shape[-2], trend_out.shape[-1]))
        trend_out = self.trend_proj(trend_out)  # [B, C, pred_patch_num * patch_len]
        trend_out = self.trend_proj2(trend_out)  # [B, C, pred_len]
        if self.subtract_last:
            trend_out = trend_out + seq_last.permute(0, 2, 1)

        # De-normalise: statistics [B, 1, C] -> [B, C, 1], broadcast over pred_len.
        return trend_out * stdev.permute(0, 2, 1) + means.permute(0, 2, 1)

    def forecast(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None):
        """Decompose ``x_enc`` and sum the seasonal and trend forecasts -> ``[B, pred_len, C]``."""
        seasonal_init, trend_init = self.decomposition(x_enc)

        season_out = self.seasonal_encoder(seasonal_init)
        trend_out = self.trend_encoder(trend_init)

        dec_out = season_out + trend_out
        return dec_out.permute(0, 2, 1)

    def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, target_x=None):
        """Forecast ``[Batch, pred_len, Channel]`` from ``x_enc`` ``[Batch, Input length, Channel]``.

        When ``target_x`` is given, its mean and standard deviation (over dim 1)
        are used to normalise ``x_enc`` and de-normalise the forecast instead of
        the statistics of ``x_enc`` itself.
        """
        # Normalization from Non-stationary Transformer
        if target_x is None:
            x_enc, means, stdev = self._normalize(x_enc)
        else:
            means = target_x.mean(1, keepdim=True).detach()
            stdev = torch.sqrt(torch.var(target_x, dim=1, keepdim=True, unbiased=False) + 1e-5)
            x_enc = (x_enc - means) / stdev

        if self.subtract_last:
            seq_last = x_enc[:, -1:, :].detach()
            x_enc = x_enc - seq_last

        dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)

        if self.subtract_last:
            dec_out = dec_out + seq_last

        # De-Normalization: statistics are [B, 1, C] and broadcast over pred_len.
        return dec_out * stdev + means