import torch
import torch.nn as nn

from layers.VQ_model import VQ_model
from layers.RevIN import RevIN
from layers.Embed import DataEmbedding_inverted
from layers.Transformer_EncDec import Encoder, EncoderLayer
from torch.nn import functional as F
import warnings
warnings.filterwarnings("ignore")



class Model(nn.Module):
    """Forecasting model built from RevIN, an inverted embedding, a VQ module,
    a complex-valued frequency mapping and a stack of ChannelMixing encoder layers.

    Only the constructor is usable in this release: the original source states
    that the forecasting code will be provided later. ``forecast`` therefore
    raises ``NotImplementedError`` instead of silently returning ``None``, which
    would otherwise only fail later, far from the cause (e.g. in the loss).

    Args:
        configs: Namespace-like object. Attributes read here: pred_len, seq_len,
            d_model, task_name, dropout, enc_in, activation, e_layers, n_heads,
            patch_len, patch_stride, d_ff, d_code, kernel, stride, num_code,
            vq_layers, lpf.
    """

    def __init__(self, configs):
        super().__init__()
        self.pred_len = configs.pred_len
        self.seq_len = configs.seq_len
        self.d_model = configs.d_model
        self.task_name = configs.task_name
        self.dropout = configs.dropout
        self.enc_in = configs.enc_in
        self.activation = configs.activation
        self.e_layers = configs.e_layers
        self.n_heads = configs.n_heads
        self.patch_len = configs.patch_len
        self.patch_stride = configs.patch_stride
        self.d_ff = configs.d_ff
        self.d_code = configs.d_code
        self.kernel = configs.kernel
        self.stride = configs.stride
        # Presumably the number of length-patch_len windows taken with step
        # patch_stride; stored for the forecasting code (unused in this constructor).
        self.patch_num = int((self.seq_len - self.patch_len) / self.patch_stride + 1)
        self.num_code = configs.num_code

        # NOTE: submodules are created in the original order on purpose. The
        # order decides which random numbers each layer's weight init consumes
        # under a fixed seed, so reordering would change initial weights.

        # affine=False: presumably non-learnable normalization -- confirm in layers/RevIN.
        self.normalized = RevIN(self.enc_in, affine=False)
        # c_in=seq_len: presumably each variate's whole lookback window becomes one
        # token (inverted embedding) -- confirm in layers/Embed.
        self.embedding = DataEmbedding_inverted(c_in=self.seq_len, d_model=configs.d_model, dropout=configs.dropout)

        self.vq_model = VQ_model(seq_len=self.seq_len, d_code=self.d_code, enc_in=self.enc_in, kernel=self.kernel,
                                 stride=self.stride, num_code=self.num_code, vq_layers=configs.vq_layers)

        # Complex-valued linear map; the sizes equal torch.fft.rfft output lengths
        # for seq_len and pred_len (seq_len // 2 + 1, pred_len // 2 + 1), so it is
        # presumably applied to rfft spectra in the (unreleased) forecasting code.
        in_bins = int(self.seq_len / 2 + 1)
        out_bins = int(self.pred_len / 2 + 1)
        self.freq_mapping = nn.Linear(in_bins, out_bins).to(torch.cfloat)
        # Integer cut-off (seq_len / 2 + 1) * (1 - lpf), truncated; presumably the
        # number of frequency bins a low-pass filter keeps. It uses the float
        # seq_len / 2 + 1, which exceeds in_bins by 0.5 when seq_len is odd.
        self.LPF = int((self.seq_len / 2 + 1) * (1 - configs.lpf))

        self.linear = nn.Linear(self.seq_len * 2, self.seq_len)
        self.output = nn.Linear(self.d_model, self.pred_len)

        # e_layers identical EncoderLayers, each wrapping its own ChannelMixing.
        self.encoder = Encoder(
            [
                EncoderLayer(
                    ChannelMixing(configs.d_model, configs.pred_len, configs.enc_in),
                    configs.d_model,
                    configs.d_ff,
                    dropout=configs.dropout,
                    activation=configs.activation,
                )
                for _ in range(configs.e_layers)
            ]
        )

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Forecasting pass -- not included in this release.

        The original source states the complete code will be provided later.
        Raising here makes a call fail immediately with a clear message instead
        of returning ``None`` out of ``forward``.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "Model.forecast is not included in this release of the code; "
            "the authors state it will be provided later."
        )

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Dispatch on ``self.task_name``.

        Returns:
            The result of ``forecast`` for 'long_term_forecast' and
            'short_term_forecast'; ``None`` for any other task name.
        """
        if self.task_name in ('long_term_forecast', 'short_term_forecast'):
            return self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
        return None


class ChannelMixing(nn.Module):
    """Layer passed as the first positional argument of each ``EncoderLayer``
    built by ``Model``.

    Only the constructor is available in this release; the original source
    states the complete implementation will be provided later. ``forward``
    raises ``NotImplementedError`` so an accidental call fails loudly instead
    of returning ``None`` into the surrounding encoder layer.

    Args:
        d_model: Model width; accepted but not used by this stub.
        pred_len: Forecast horizon; accepted but not used by this stub.
        enc_in: Stored as ``self.enc_in`` (presumably the number of input
            variates -- Model passes ``configs.enc_in``).
    """

    def __init__(self, d_model, pred_len, enc_in):
        super().__init__()
        self.enc_in = enc_in
        # NOTE: the remaining sub-layers are not included in this release.

    def forward(self, x_enc, fft_enc, *args, **kwargs):
        """Mixing pass -- not included in this release.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "ChannelMixing.forward is not included in this release of the code; "
            "the authors state it will be provided later."
        )



