import torch
import torch.nn as nn
from model.embed import DataEmbedding
from model.local_global import Seasonal_Prediction, series_decomp_multi


class Model(nn.Module):
    """Forecasting model that sums a seasonal and a trend-cyclical prediction.

    The encoder input is split by ``series_decomp_multi`` into a seasonal part
    and a trend part.  The seasonal part is zero-padded by ``pred_len`` steps,
    embedded, and passed through ``Seasonal_Prediction``; the trend part is
    extrapolated either by a linear layer over the time axis (``mode='regre'``)
    or by repeating the per-series mean of the input (``mode='mean'``).
    """

    def __init__(self, configs):
        """Build the sub-modules from a configuration object.

        Args:
            configs: object exposing the attributes read below:
                ``conv_kernel`` (iterable of ints), ``seq_len``, ``pred_len``,
                ``dec_in``, ``c_out``, ``mode``, ``gpu``, ``dropout``,
                ``d_model``, ``embed``, ``freq``, ``n_heads`` and ``d_layers``.
        """
        super().__init__()

        total_len = configs.seq_len + configs.pred_len
        decomp_kernel = []  # kernel sizes of the decomposition operation
        isometric_kernel = []  # kernel sizes of the isometric convolution
        for kernel in configs.conv_kernel:
            if kernel % 2 == 0:
                # The kernel of the decomposition operation must be odd.
                decomp_kernel.append(kernel + 1)
                isometric_kernel.append((total_len + kernel) // kernel)
            else:
                decomp_kernel.append(kernel)
                isometric_kernel.append((total_len + kernel - 1) // kernel)

        dec_in = configs.dec_in
        self.out_len = configs.pred_len
        self.pred_len = configs.pred_len
        self.seq_len = configs.seq_len
        self.c_out = configs.c_out
        self.decomp_kernel = decomp_kernel
        self.conv_kernel = configs.conv_kernel
        self.isometric_kernel = isometric_kernel
        self.mode = configs.mode
        self.device = configs.gpu
        self.dropout = configs.dropout
        self.d_model = configs.d_model

        self.decomp_multi = series_decomp_multi(decomp_kernel)

        # embedding
        self.dec_embedding = DataEmbedding(
            dec_in, self.d_model, configs.embed, configs.freq, self.dropout
        )

        self.conv_trans = Seasonal_Prediction(
            embedding_size=self.d_model,
            n_heads=configs.n_heads,
            dropout=self.dropout,
            d_layers=configs.d_layers,
            decomp_kernel=self.decomp_kernel,
            c_out=self.c_out,
            conv_kernel=self.conv_kernel,
            isometric_kernel=self.isometric_kernel,
            device=self.device,
        )

        # Linear map over the time axis (seq_len -> out_len) used by the
        # 'regre' trend branch.  Its weight starts as a constant 1/out_len
        # matrix; the bias keeps nn.Linear's default initialisation.
        self.regression = nn.Linear(self.seq_len, self.out_len)
        self.regression.weight = nn.Parameter(
            (1 / self.out_len) * torch.ones([self.out_len, self.seq_len]),
            requires_grad=True,
        )

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec,
                enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
        """Predict the last ``pred_len`` steps.

        Args:
            x_enc: 3-D encoder input with time on dim 1
                (presumably ``(batch, seq_len, channels)`` - confirm with caller).
            x_mark_enc: unused by this method; kept for interface compatibility.
            x_dec: 3-D decoder input; only its dim-0 and dim-2 sizes are read,
                to size the zero padding appended to the seasonal part.
            x_mark_dec: passed to ``dec_embedding`` together with the padded
                seasonal series.
            enc_self_mask, dec_self_mask, dec_enc_mask: unused by this method.

        Returns:
            The sum of the last ``pred_len`` steps (dim 1) of the seasonal
            prediction and of the trend prediction.

        Raises:
            ValueError: if ``self.mode`` is neither ``'regre'`` nor ``'mean'``.
        """
        # trend-cyclical prediction block: regre or mean
        if self.mode == 'regre':
            seasonal_init_enc, trend = self.decomp_multi(x_enc)
            # Move time to the last dim so the linear layer maps
            # seq_len -> out_len, then restore the original layout.
            trend = self.regression(trend.permute(0, 2, 1)).permute(0, 2, 1)
        elif self.mode == 'mean':
            # Per-series mean over time, repeated for every predicted step.
            mean = torch.mean(x_enc, dim=1).unsqueeze(1).repeat(1, self.pred_len, 1)
            seasonal_init_enc, trend = self.decomp_multi(x_enc)
            trend = torch.cat([trend[:, -self.seq_len:, :], mean], dim=1)
        else:
            # Without this branch an unknown mode surfaced later as an
            # UnboundLocalError on `seasonal_init_enc`.
            raise ValueError(
                f"Unsupported mode {self.mode!r}; expected 'regre' or 'mean'"
            )

        # embedding: seasonal history followed by pred_len zero placeholders
        zeros = torch.zeros(
            [x_dec.shape[0], self.pred_len, x_dec.shape[2]], device=x_enc.device
        )
        seasonal_init_dec = torch.cat(
            [seasonal_init_enc[:, -self.seq_len:, :], zeros], dim=1
        )
        dec_out = self.dec_embedding(seasonal_init_dec, x_mark_dec)

        dec_out = self.conv_trans(dec_out)
        dec_out = dec_out[:, -self.pred_len:, :] + trend[:, -self.pred_len:, :]
        return dec_out

