import torch
import torch.nn as nn
from layers.Embed import PositionalEmbedding
import math


import torch
import torch.nn as nn
from layers.Embed import PositionalEmbedding

class Model(nn.Module):
    """Segment-wise linear forecaster with a time branch and a low-pass frequency branch.

    Pipeline (see ``forward``):

    1. Subtract the per-channel mean over the input window.
    2. Add a residual 1-D convolution, applied to every channel independently.
    3. Split each channel's ``seq_len`` steps into ``seg_num_x = seq_len // period_len``
       segments of width ``period_len``.
    4. Map the segment axis ``seg_num_x -> seg_num_y`` with two branches:

       * time branch: ``linear1`` (seg_num_x -> com_len) then ``linear2``
         (com_len -> seg_num_y);
       * frequency branch: FFT over the segment axis, keep the first ``lpf``
         bins, complex linear maps ``lpf -> 2 -> seg_num_y``, inverse FFT,
         keep the real part.

    5. Blend ``alpha * time + (1 - alpha) * freq``, restore the layout and add
       the mean back.

    Args:
        configs: object exposing ``seq_len``, ``pred_len``, ``enc_in``,
            ``period_len``, ``com_len``, ``lpf`` and ``alpha`` attributes.

    Raises:
        ValueError: if ``period_len`` is not positive, if ``seq_len`` or
            ``pred_len`` is not a positive multiple of ``period_len``, or if
            ``lpf`` is outside ``[1, seq_len // period_len]``.
    """

    def __init__(self, configs):
        super(Model, self).__init__()

        # get parameters
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.enc_in = configs.enc_in
        self.period_len = configs.period_len
        self.com_len = configs.com_len
        self.lpf = configs.lpf
        self.alpha = configs.alpha

        # Fail fast on shapes the segment reshapes in forward() cannot handle:
        # a non-multiple seq_len makes reshape(-1, seg_num_x, period_len) mix
        # channels (or crash), and a non-multiple pred_len truncates the output.
        if self.period_len <= 0:
            raise ValueError(f"period_len must be positive, got {self.period_len}")
        for name, length in (("seq_len", self.seq_len), ("pred_len", self.pred_len)):
            if length <= 0 or length % self.period_len:
                raise ValueError(
                    f"{name} ({length}) must be a positive multiple of "
                    f"period_len ({self.period_len})"
                )

        self.seg_num_x = self.seq_len // self.period_len
        self.seg_num_y = self.pred_len // self.period_len

        # FLinear1 expects exactly `lpf` input bins, but the FFT over the
        # segment axis only yields seg_num_x of them.
        if not 1 <= self.lpf <= self.seg_num_x:
            raise ValueError(
                f"lpf ({self.lpf}) must be in [1, seq_len // period_len = {self.seg_num_x}]"
            )

        # Odd kernel of width 2*(period_len//2)+1 with padding period_len//2
        # keeps the output length equal to seq_len for both odd and even
        # period_len. (Without the parentheses, odd period_len produced an
        # even kernel and a sequence one step too short.)
        self.conv1d = nn.Conv1d(
            in_channels=1,
            out_channels=1,
            kernel_size=1 + 2 * (self.period_len // 2),
            stride=1,
            padding=self.period_len // 2,
            padding_mode="zeros",
            bias=False,
        )

        # Not used in forward(); kept so existing checkpoints/state_dicts keep
        # loading. It receives no gradient.
        self.linear = nn.Linear(self.seg_num_x, self.seg_num_y, bias=False)

        # Time branch: low-rank map over the segment axis.
        self.linear1 = nn.Linear(self.seg_num_x, self.com_len, bias=False)
        self.linear2 = nn.Linear(self.com_len, self.seg_num_y, bias=False)

        # Frequency branch: complex-valued maps lpf bins -> 2 -> seg_num_y bins.
        self.FLinear1 = nn.Linear(self.lpf, 2, bias=False).to(torch.cfloat)
        self.FLinear2 = nn.Linear(2, self.seg_num_y, bias=False).to(torch.cfloat)

    def forward(self, x):
        """Forecast from an input window.

        Args:
            x: tensor of shape ``(batch, seq_len, enc_in)``.

        Returns:
            Tensor of shape ``(batch, pred_len, enc_in)``. This is
            ``seg_num_y * period_len``, which equals ``pred_len`` because
            ``__init__`` validates divisibility.
        """
        batch_size = x.shape[0]

        # Per-channel mean normalisation, then b,s,c -> b,c,s.
        seq_mean = torch.mean(x, dim=1, keepdim=True)
        x = (x - seq_mean).permute(0, 2, 1)

        # Residual 1-D convolution: each channel becomes its own batch row.
        smoothed = self.conv1d(x.reshape(-1, 1, self.seq_len))
        x = smoothed.reshape(-1, self.enc_in, self.seq_len) + x

        # Downsampling: b,c,s -> bc,n,w -> bc,w,n.
        x = x.reshape(-1, self.seg_num_x, self.period_len).permute(0, 2, 1)

        # Time branch: bc,w,n -> bc,w,m.
        y_t = self.linear2(self.linear1(x))

        # Frequency branch: keep the lowest `lpf` FFT bins along the segment
        # axis, map them to seg_num_y complex bins, and go back to time domain.
        x_fft = torch.fft.fft(x, dim=2)[:, :, :self.lpf]
        y_fft = self.FLinear2(self.FLinear1(x_fft))
        # `.real` yields the same values as `.float()` without emitting a
        # "discards the imaginary part" warning on every call.
        y_f = torch.fft.ifft(y_fft, dim=2).real

        y = y_t * self.alpha + y_f * (1 - self.alpha)

        # Upsampling: bc,w,m -> bc,m,w -> b,c,s.
        y = y.permute(0, 2, 1).reshape(batch_size, self.enc_in, -1)

        # Back to b,s,c and undo the normalisation.
        return y.permute(0, 2, 1) + seq_mean





