import torch
from torch import nn
import torch_dct as dct
import math
from layers.RevIN import RevIN


class Model(nn.Module):
    """Three-branch segmented-DCT linear forecaster.

    Each branch (see ``V2``) splits the look-back window into segments of a
    branch-specific length, takes a DCT of every segment, gates the
    concatenated coefficients with a learnable mask, maps them to the
    forecast horizon with the shared ``FLinear`` layer and applies the
    inverse DCT. ``forward`` averages the three branch forecasts with equal
    weights.

    Args:
        configs: object providing ``enc_in``, ``seq_len``, ``pred_len`` and
            ``period_len`` attributes.
        revin: if true, every branch calls ``RevIN`` with 'norm' before and
            'denorm' after the transform.
        affine: NOTE(review): accepted but unused here and not forwarded to
            ``RevIN`` -- confirm against layers/RevIN.py whether it should be.
        subtract_last: NOTE(review): same situation as ``affine``.
    """

    def __init__(self, configs, revin=True, affine=True, subtract_last=False):
        super().__init__()

        self.enc_in = configs.enc_in
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.seg_len = configs.period_len

        # Factor applied to the inverse-DCT output in V2.
        self.length_ratio = (self.seq_len + self.pred_len) / self.seq_len

        self.revin_layer = RevIN(self.enc_in) if revin else None

        # One learnable gate per DCT coefficient for each branch, initialised to 0.1.
        self.mask1 = nn.Parameter(torch.ones(self.seq_len) * 0.1)
        self.mask2 = nn.Parameter(torch.ones(self.seq_len) * 0.1)
        self.mask3 = nn.Parameter(torch.ones(self.seq_len) * 0.1)

        # NOTE(review): registered but never applied in forward/V2.
        self.dropout = nn.Dropout(0.1)

        # Coefficient-to-horizon map shared by all three branches.
        self.FLinear = nn.Linear(self.seq_len, self.pred_len)

    def forward(self, x):
        """Return the equal-weight average of the three branch forecasts."""
        scale1, scale2, scale3 = self._branch_scales()

        # The weights must sum to exactly 1: the previous 0.33 * 3 = 0.99
        # shrank the de-normalised forecast (series level included) by 1%.
        weight = 1.0 / 3.0
        x_v1 = self.V2(x, scale1, self.mask1) * weight
        x_v2 = self.V2(x, scale2, self.mask2) * weight
        x_v3 = self.V2(x, scale3, self.mask3) * weight
        return x_v1 + x_v2 + x_v3

    def _branch_scales(self):
        """Return the DCT segment lengths for the (fine, medium, whole-window) branches."""
        # period_len == 24 is special-cased to 2-step segments.
        fine = 2 if self.seg_len == 24 else self.seg_len
        # seq_len // 30 is 0 when seq_len < 30; clamp so V2 never sees a zero scale.
        medium = max(1, self.seq_len // 30)
        # A single segment spanning the whole look-back window.
        whole = self.seq_len
        return fine, medium, whole

    @staticmethod
    def _segment_slices(length, scale):
        """Return ``(start, end)`` bounds tiling ``range(length)`` in steps of ``scale``.

        The final segment is shorter when ``length`` is not a multiple of
        ``scale``, so every index is covered exactly once and the
        concatenated per-segment DCTs always have ``length`` coefficients.
        A non-positive ``scale`` is treated as 1.
        """
        step = max(1, int(scale))
        return [(start, min(start + step, length)) for start in range(0, length, step)]

    def V2(self, x, scale, mask):
        """Forecast with one branch: block-wise DCT -> mask -> FLinear -> inverse DCT.

        Args:
            x: input series. The code reduces over dim 1 and permutes
                (0, 2, 1) before the transform, so presumably
                (batch, seq_len, channels) -- TODO confirm with callers.
            scale: segment length for the block-wise DCT. When ``seq_len`` is
                not a multiple of ``scale`` the last segment is shorter
                (previously the remainder was dropped, which broke the
                ``mask`` / ``FLinear`` shapes).
            mask: learnable gate of length ``seq_len`` applied to the
                concatenated DCT coefficients.

        Returns:
            Forecast tensor with the horizon (``pred_len``) on dim 1.
        """
        if self.revin_layer is not None:
            x = self.revin_layer(x, 'norm')

        seq_mean = torch.mean(x, dim=1, keepdim=True)
        x_std = torch.sqrt(torch.var(x, dim=1, keepdim=True) + 1e-5)
        # Standardise on the way in so the multiplication by x_std on the way
        # out is a true inverse (previously only the output was rescaled).
        x = ((x - seq_mean) / x_std).permute(0, 2, 1)

        # Block-wise DCT along the (now last) time axis; concatenating on the
        # last dim keeps each row's segments in temporal order.
        x_dct = torch.cat(
            [dct.dct(x[:, :, start:end])
             for start, end in self._segment_slices(self.seq_len, scale)],
            dim=-1,
        )
        # mask has shape (seq_len,) and broadcasts over the leading dims.
        x_dct = x_dct * mask

        y = self.FLinear(x_dct)
        y = dct.idct(y).permute(0, 2, 1)
        y = y * self.length_ratio * x_std + seq_mean

        if self.revin_layer is not None:
            y = self.revin_layer(y, 'denorm')

        return y



