import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class Model(nn.Module):
    """Linear forecaster mapping an input window to a forecast horizon.

    A single learnable matrix ``w`` of shape ``[seq_len, pred_len]`` is shared
    by all channels and applied along the time axis.

    Args:
        configs: Object exposing ``seq_len``, ``pred_len``, ``enc_in`` and
            ``individual`` attributes.
    """

    def __init__(self, configs):
        super().__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len

        # NOTE(review): the four attributes below are not read anywhere in
        # this class; kept in case other code inspects them.
        self.channels = configs.enc_in
        self.individual = configs.individual
        self.scale = 0.02
        self.sparsity_threshold = 0.01

        # Time-axis projection from seq_len input steps to pred_len output
        # steps. Its shape must be [seq_len, pred_len] so forward() returns
        # pred_len steps. Every entry starts at 1/seq_len, so at
        # initialisation each forecast step is the mean of the input window.
        self.w = nn.Parameter(
            (1 / self.seq_len) * torch.ones([self.seq_len, self.pred_len])
        )

    def MLP_time(self, x):
        """Project ``x`` along its time axis with ``self.w``.

        Args:
            x: Tensor of shape [Batch, seq_len, Channel].

        Returns:
            Tensor of shape [Batch, pred_len, Channel].
        """
        # Contract the time axis (l) of x with the rows of w. This equals
        # permuting to [B, C, L], right-multiplying by w, and permuting back.
        return torch.einsum('blc,lp->bpc', x, self.w)

    def MLP_frequency_input(self, x):
        """Project the real part of the input's spectrum with ``self.w``.

        Applies an orthonormal FFT along the time axis and discards the
        imaginary part, then multiplies by ``self.w``. The result is NOT
        transformed back to the time domain.

        Args:
            x: Tensor of shape [Batch, seq_len, Channel].

        Returns:
            Real tensor of shape [Batch, pred_len, Channel].
        """
        spectrum = torch.fft.fft(x, dim=1, norm='ortho')
        # Only the real part is used; the imaginary part is dropped.
        return torch.einsum('blc,lp->bpc', spectrum.real, self.w)

    def forward(self, x):
        """Forecast ``pred_len`` steps from an input window.

        Args:
            x: Tensor of shape [Batch, seq_len, Channel].

        Returns:
            Tensor of shape [Batch, pred_len, Channel].
        """
        # The frequency-domain variant (MLP_frequency_input) exists but is
        # not used by forward().
        return self.MLP_time(x)