import torch
import torch.nn as nn
from layers.RevIN import RevIN
from layers.Energy import EnergyEnhancer, EnergyPredictor

class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of time series.

    The series is replicate-padded with its first/last time step before
    averaging. With ``stride=1`` the output has the same number of time steps
    as the input, for both odd and even ``kernel_size``.

    Args:
        kernel_size: Width of the averaging window, in time steps.
        stride: Step of the averaging window.

    Layout: input and output are ``(batch, time, channels)``; averaging runs
    over dim 1 (time).
    """

    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # Replicate-pad both ends with kernel_size - 1 extra steps in total:
        # (k - 1) // 2 in front and k // 2 at the end. The padded length is
        # then T + k - 1, so a stride-1 AvgPool1d yields exactly T outputs.
        # (Padding (k - 1) // 2 on both sides drops one step when k is even.)
        pad_front = (self.kernel_size - 1) // 2
        pad_end = self.kernel_size // 2
        front = x[:, 0:1, :].repeat(1, pad_front, 1)
        end = x[:, -1:, :].repeat(1, pad_end, 1)
        x = torch.cat([front, x, end], dim=1)
        # AvgPool1d pools over the last dim, so move time there and back.
        x = self.avg(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        return x


class series_decomp(nn.Module):
    """
    Series decomposition block.

    Separates an input series into a trend (its moving average) and the
    residual that remains once that trend is subtracted.
    """

    def __init__(self, kernel_size):
        super().__init__()
        # stride 1: the moving average is evaluated at every time step
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        """Return ``(residual, trend)`` with ``residual = x - trend``."""
        trend = self.moving_avg(x)
        return x - trend, trend


class Model(nn.Module):
    """
    Forecasting model.

    Pipeline: RevIN normalisation -> ``EnergyEnhancer`` -> optional SCI block
    (common / specific pattern modelling) -> seasonal + trend MLP forecaster
    -> ``EnergyPredictor`` -> RevIN de-normalisation.

    Args:
        configs: Namespace-like object. Attributes read here: ``seq_len``,
            ``pred_len``, ``enc_in``, ``hidden_size``, ``task_name``, ``SCI``,
            ``embed_dim``, ``embed_dim_out``, ``lambda_init``, ``alpha``.

    Note:
        Some submodule attribute names (``decompsition``,
        ``model_spacific_pattern``) are misspelled. They are kept as-is because
        they are ``state_dict`` keys, and renaming them would break loading of
        existing checkpoints.
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.channels = configs.enc_in
        self.hidden_size = configs.hidden_size
        self.revin_layer = RevIN(configs.enc_in, affine=True, subtract_last=False)
        self.task_name = configs.task_name

        # Moving-average window used for the seasonal / trend split.
        kernel_size = 25
        self.decompsition = series_decomp(kernel_size)

        # Seasonal / trend forecasters map the time axis seq_len -> pred_len.
        self.linear_seasonal = self._mlp(self.seq_len, self.hidden_size, self.pred_len)
        self.linear_trend = self._mlp(self.seq_len, self.hidden_size, self.pred_len)

        # SCI block: collapse channels to one common pattern, then model the
        # common pattern and the per-channel remainder along the time axis.
        self.SCI = configs.SCI
        self.extract_common_pattern = self._mlp(self.channels, self.channels, 1)
        self.model_common_pattern = self._mlp(self.seq_len, self.hidden_size, self.seq_len)
        self.model_spacific_pattern = self._mlp(self.seq_len, self.hidden_size, self.seq_len)

        self.embed_dim = configs.embed_dim
        self.embed_dim_out = configs.embed_dim_out
        self.lambda_init = configs.lambda_init
        self.alpha = configs.alpha
        self.energy_enhancer = EnergyEnhancer(self.seq_len, self.channels, self.embed_dim, self.lambda_init, self.alpha)
        self.energy_predictor = EnergyPredictor(self.seq_len, self.pred_len, self.embed_dim_out)

    @staticmethod
    def _mlp(in_features, hidden_features, out_features):
        """Build ``Linear -> LeakyReLU -> Linear`` (state_dict keys ``0.*`` / ``2.*``)."""
        return nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.LeakyReLU(),
            nn.Linear(hidden_features, out_features),
        )

    def _sci_block(self, x):
        """Model a channel-shared pattern and per-channel residuals of ``x``, then recombine them.

        ``x`` is laid out ``(batch, seq_len, channels)``; the result has the same shape.
        """
        # Linear(channels -> 1) on the last dim: one pattern shared by all channels.
        common = self.extract_common_pattern(x)
        common = self.model_common_pattern(common.permute(0, 2, 1)).permute(0, 2, 1)
        # `common` has a size-1 channel dim, so it broadcasts across every
        # channel. This matches repeating it C times without materialising copies.
        specific = self.model_spacific_pattern((x - common).permute(0, 2, 1)).permute(0, 2, 1)
        return specific + common

    def encoder(self, x):
        """
        Run the full forecasting pipeline on a look-back window.

        Args:
            x: Input tensor laid out as ``(batch, seq_len, channels)``.

        Returns:
            Tuple ``(out, loss_nonstat)``. ``out`` is the de-normalised output
            of ``energy_predictor``, shaped ``(bs, pred_len, n_vars)``.
            ``loss_nonstat`` is the auxiliary loss returned by
            ``energy_enhancer``.
        """
        x = self.revin_layer(x, 'norm')

        # NOTE(review): the meaning of these three outputs is defined in
        # layers.Energy and is not visible from this file.
        x_denoised, x_inverse_fft, loss_nonstat = self.energy_enhancer(x)

        if self.SCI:
            x_denoised = self._sci_block(x_denoised)

        # Seasonal / trend forecaster, applied along the time axis.
        seasonal, trend = self.decompsition(x_denoised)
        seasonal = self.linear_seasonal(seasonal.permute(0, 2, 1)).permute(0, 2, 1)
        trend = self.linear_trend(trend.permute(0, 2, 1)).permute(0, 2, 1)
        out = seasonal + trend

        out = self.energy_predictor(x_inverse_fft, out)  # (bs, pred_len, n_vars)

        # Undo the RevIN normalisation applied at the start.
        return self.revin_layer(out, 'denorm'), loss_nonstat

    def forecast(self, x_enc):
        """Forecast from ``x_enc``; returns ``(out, loss_nonstat)`` from :meth:`encoder`."""
        out, loss_nonstat = self.encoder(x_enc)
        return out, loss_nonstat

    def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
        """
        Model entry point.

        Only ``x_enc`` is used. The remaining parameters are accepted so the
        signature matches a shared model interface.

        Returns:
            ``(forecast, loss_nonstat)`` for ``'long_term_forecast'`` and
            ``'short_term_forecast'``, where ``forecast`` keeps the last
            ``pred_len`` steps. Returns ``None`` for any other ``task_name``.
        """
        if self.task_name in ('long_term_forecast', 'short_term_forecast'):
            dec_out, loss_nonstat = self.forecast(x_enc)
            return dec_out[:, -self.pred_len:, :], loss_nonstat
        # NOTE(review): other tasks are unsupported and silently yield None;
        # callers that unpack the result will fail there instead of here.
        return None

if __name__ == "__main__":
    ...  # no standalone entry point; this module only defines the model