import torch
import torch.nn as nn
import math

from layers.xPatch_layers import DECOMP, Network


class Model(nn.Module):
    """Forecasting model built from ``DECOMP`` and ``Network`` (layers.xPatch_layers).

    Unless ``ma_type == 'reg'``, the input is split by ``DECOMP`` into seasonal
    and trend parts, which are fed to ``Network``. ``forward`` accepts a 4-D
    input ``[B, L, N, C]`` and flattens the last two axes into channels. It
    normalizes with the mean/std over the time axis before forecasting and
    inverts the normalization afterwards.
    """

    def __init__(self, configs):
        super(Model, self).__init__()

        # Parameters
        seq_len = configs.seq_len    # lookback window L
        pred_len = configs.pred_len  # prediction length (96, 192, 336, 720)
        # Stored on the instance: forward() slices and de-normalizes with pred_len.
        self.seq_len = seq_len
        self.pred_len = pred_len

        # Patching (fixed here, not read from configs)
        patch_len = 16
        stride = 8
        padding_patch = 'end'

        # Moving Average
        self.ma_type = 'ema'
        alpha = 0.3  # smoothing factor for EMA (Exponential Moving Average)
        beta = 0.3   # smoothing factor for DEMA (Double Exponential Moving Average)

        self.decomp = DECOMP(self.ma_type, alpha, beta)
        self.net = Network(seq_len, pred_len, patch_len, stride, padding_patch)
        # self.net_mlp = NetworkMLP(seq_len, pred_len) # For ablation study with MLP-only stream
        # self.net_cnn = NetworkCNN(seq_len, pred_len, patch_len, stride, padding_patch) # For ablation study with CNN-only stream

    def forecast(self, x):
        """Run the network on an already-normalized input.

        Args:
            x: tensor laid out as [Batch, Input, Channel].

        Returns:
            Whatever ``self.net`` returns for the (seasonal, trend) pair. In
            ``'reg'`` mode no decomposition is done and ``x`` is passed as
            both streams.
        """
        if self.ma_type == 'reg':   # If no decomposition, directly pass the input to the network
            x = self.net(x, x)
            # x = self.net_mlp(x) # For ablation study with MLP-only stream
            # x = self.net_cnn(x) # For ablation study with CNN-only stream
        else:
            seasonal_init, trend_init = self.decomp(x)
            x = self.net(seasonal_init, trend_init)
        return x

    def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, target_x=None):
        """Normalize, forecast, and de-normalize a 4-D input.

        Args:
            x_enc: input tensor of shape ``[B, L, N, C]``.
            x_mark_enc, x_dec, x_mark_dec: unused; accepted for interface
                compatibility with other models.
            target_x: optional tensor whose mean/variance over dim 1 are used
                for normalization instead of ``x_enc``'s own statistics.
                NOTE(review): presumably shaped ``[B, *, N, C]``; confirm with callers.

        Returns:
            Tensor of shape ``[B, T, N, C]`` with ``T = pred_len`` (or fewer,
            if the network emits fewer time steps).
        """
        B, L, N, C = x_enc.shape

        # Normalization from Non-stationary Transformer
        if target_x is not None:
            means = target_x.mean(1, keepdim=True).detach()
        else:
            means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        # Without target_x, the std is taken from the already-centered input.
        var_src = target_x if target_x is not None else x_enc
        stdev = torch.sqrt(torch.var(var_src, dim=1, keepdim=True, unbiased=False) + 1e-5)
        # Out-of-place on purpose: an in-place divide would overwrite the tensor
        # torch.var saved for backward and break autograd if x_enc requires grad.
        x_enc = x_enc / stdev

        x_enc = x_enc.reshape([B, L, -1])
        dec_out = self.forecast(x_enc)[:, -self.pred_len:, -N * C:]
        dec_out = dec_out.reshape([B, dec_out.shape[1], N, C])  # [B, T, N, C]

        # De-Normalization from Non-stationary Transformer.
        # means/stdev keep a size-1 time axis (keepdim=True), so they broadcast
        # over however many steps dec_out has.
        dec_out = dec_out * stdev + means
        return dec_out
