import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from statsmodels.tsa.seasonal import STL


class RecurrentCycle(torch.nn.Module):
    """Learnable periodic pattern that can be read out starting at any phase.

    The pattern is stored in ``data``, a ``(cycle_len, channel_size)`` parameter
    initialised to zeros. Reading wraps around modulo ``cycle_len``.
    """

    def __init__(self, cycle_len, channel_size):
        super().__init__()
        self.cycle_len = cycle_len
        self.channel_size = channel_size
        self.data = torch.nn.Parameter(torch.zeros(cycle_len, channel_size), requires_grad=True)

    def forward(self, index, length):
        """Return ``length`` consecutive pattern rows for every start phase in ``index``.

        Args:
            index: integer tensor of start phases; it is flattened, so its
                element count becomes the leading output dimension.
            length: number of consecutive steps to read per start phase.

        Returns:
            Tensor of shape ``(index.numel(), length, channel_size)`` whose
            ``[b, t]`` row is ``data[(index[b] + t) % cycle_len]``.
        """
        steps = torch.arange(length, device=index.device).unsqueeze(0)  # (1, length)
        starts = index.view(-1, 1)                                       # (N, 1)
        wrapped = (starts + steps) % self.cycle_len                      # (N, length)
        return self.data[wrapped]


class SeriesDecomp(nn.Module):
    """Split a series into a moving-average trend and the remaining seasonal part.

    The trend comes from ``nn.AvgPool1d`` with stride 1 and symmetric padding
    of ``(kernel_size - 1) // 2``, so the output keeps the input length.
    AvgPool1d's default ``count_include_pad=True`` applies, which means edge
    windows average in the zero padding.

    Args:
        kernel_size: odd moving-average window length.
    """

    def __init__(self, kernel_size):
        super().__init__()
        assert kernel_size % 2 == 1
        half_window = (kernel_size - 1) // 2
        self.avg_pool = nn.AvgPool1d(kernel_size=kernel_size, stride=1, padding=half_window)

    def forward(self, x):
        """Decompose a 3-D tensor along dim 1.

        Args:
            x: 3-D tensor; averaging runs over dim 1, which ``Model`` uses as
                the time axis.

        Returns:
            ``(seasonal, trend)``, both shaped like ``x``, with
            ``seasonal = x - trend``.
        """
        # AvgPool1d pools over the last dimension, so move dim 1 there and back.
        smoothed = self.avg_pool(x.permute(0, 2, 1))
        trend = smoothed.permute(0, 2, 1)
        return x - trend, trend


class Expert(nn.Module):
    """A single linear expert mapping ``input`` features to ``output`` features.

    Args:
        input: number of input features (size of the last dimension of ``x``).
        output: number of output features.

    The parameter names shadow builtins but are kept so keyword callers stay
    compatible.
    """

    # Previously misspelled ``__int__``: ``nn.Module.__init__`` ran instead,
    # rejected the two arguments, and ``self.linear`` was never created.
    def __init__(self, input, output):
        super().__init__()
        self.linear = nn.Linear(input, output)

    def forward(self, x):
        """Apply the affine map to the last dimension of ``x``."""
        return self.linear(x)


class Model(nn.Module):
    """Cycle-aware forecaster with a gated mixture of experts on seasonal residuals.

    ``configs`` must provide: ``seq_len``, ``pred_len``, ``enc_in``, ``cycle``,
    ``model_type``, ``d_model``, ``use_revin``, ``n`` (number of decomposition
    levels, which is also the number of experts), ``k`` (experts used per token),
    ``act`` ('RELU' | 'Tanh' | anything else -> GELU) and
    ``act2`` ('RELU' | 'Tanh' | 'GELU' | anything else -> SiLU/GELU head).

    Raises:
        ValueError: if ``n > 0`` and ``k > n``.
    """

    def __init__(self, configs):
        super(Model, self).__init__()

        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.enc_in = configs.enc_in
        self.cycle_len = configs.cycle
        self.model_type = configs.model_type
        self.d_model = configs.d_model
        self.use_revin = configs.use_revin
        self.n = configs.n
        self.k = configs.k
        # top-k routing needs at least k experts to choose from.
        if self.n > 0 and self.k > self.n:
            raise ValueError(
                f"k ({self.k}) cannot exceed the number of experts n ({self.n})"
            )
        self.cycleQueue = RecurrentCycle(cycle_len=self.cycle_len, channel_size=self.enc_in)

        self.w = nn.Parameter(0.02 * torch.randn(1, self.seq_len, self.enc_in))
        self.act = configs.act
        self.act2 = configs.act2

        # Kept as an attribute for backward compatibility; it is also experts[0].
        self.expert = self._make_expert()
        self.model = self._make_head()

        self.decom = SeriesDecomp(kernel_size=3)
        # Each expert must be its own module. Repeating one instance n times
        # would make every "expert" share the same weights.
        experts = [self.expert] + [self._make_expert() for _ in range(self.n - 1)]
        self.experts = nn.ModuleList(experts[: self.n])
        # One gate logit per expert, so top-k actually selects a subset of experts.
        self.gate = nn.Sequential(nn.Linear(self.enc_in, self.n))

    def _make_expert(self):
        """Build one expert MLP ``enc_in -> d_model -> enc_in`` using ``self.act``."""
        activations = {'RELU': nn.ReLU, 'Tanh': nn.Tanh}
        act_cls = activations.get(self.act, nn.GELU)
        return nn.Sequential(nn.Linear(self.enc_in, self.d_model),
                             act_cls(),
                             nn.Linear(self.d_model, self.enc_in))

    def _make_head(self):
        """Build the temporal MLP ``seq_len -> pred_len -> pred_len -> pred_len`` using ``self.act2``."""
        activation_pairs = {
            'RELU': (nn.ReLU, nn.ReLU),
            'Tanh': (nn.Tanh, nn.Tanh),
            'GELU': (nn.GELU, nn.GELU),
        }
        # The fallback mixes SiLU and GELU, as the original code did.
        # NOTE(review): confirm this mix is intended.
        first_act, second_act = activation_pairs.get(self.act2, (nn.SiLU, nn.GELU))
        return nn.Sequential(nn.Linear(self.seq_len, self.pred_len),
                             first_act(),
                             nn.Linear(self.pred_len, self.pred_len),
                             second_act(),
                             nn.Dropout(0.6),
                             nn.Linear(self.pred_len, self.pred_len))

    def _mix_experts(self, r):
        """Route every (batch, time) token of ``r`` to its top-k experts.

        Args:
            r: tensor of shape (B, L, C) with ``C == enc_in``, which the gate's
                ``Linear(enc_in, n)`` requires.

        Returns:
            Tensor of shape (B, L, C). Each token becomes
            ``sum over e in top_k of softmax(gate(token))[e] * expert_e(token)``.
            The selected weights are not renormalised to sum to 1.
        """
        B, L, C = r.shape
        tokens = r.reshape(B * L, C)
        gates = F.softmax(self.gate(tokens), dim=-1)              # (B*L, n)
        top_w, top_idx = torch.topk(gates, k=self.k, dim=-1)      # (B*L, k)

        mixed = torch.zeros_like(tokens)
        for e, expert in enumerate(self.experts):
            # topk indices are distinct per row, so each row has at most one hit.
            hit = top_idx == e                                    # (B*L, k)
            rows = hit.any(dim=-1).nonzero(as_tuple=True)[0]
            if rows.numel() == 0:
                continue
            weight = (top_w * hit).sum(dim=-1, keepdim=True)[rows]  # gate weight of expert e
            # Accumulate rather than overwrite, so all selected experts contribute.
            mixed = mixed.index_add(0, rows, weight * expert(tokens[rows]))
        return mixed.view(B, L, C)

    def forward(self, x, cycle_index):
        """Forecast ``pred_len`` steps.

        Args:
            x: input window of shape (batch, seq_len, enc_in).
            cycle_index: one cycle phase per batch element, giving the phase of
                the first input step.

        Returns:
            Tensor of shape (batch, pred_len, enc_in). It is de-normalised back
            to the input scale when ``use_revin`` is set.
        """
        if self.use_revin:
            seq_mean = torch.mean(x, dim=1, keepdim=True)
            seq_var = torch.var(x, dim=1, keepdim=True) + 1e-5
            x = (x - seq_mean) / torch.sqrt(seq_var)

        # n successive decompositions; each level is the seasonal part of the previous one.
        residuals = []
        level = x
        for _ in range(self.n):
            level, _trend = self.decom(level)
            residuals.append(level)

        x = x - self.cycleQueue(cycle_index, self.seq_len)
        for r in residuals:
            x = x + self._mix_experts(r)
        x = x + self.w

        # The head maps along the time axis: (B, C, seq_len) -> (B, C, pred_len).
        y = self.model(x.permute(0, 2, 1)).permute(0, 2, 1)

        # Add back the cycle pattern for the forecast window, which starts right after the input.
        y = y + self.cycleQueue((cycle_index + self.seq_len) % self.cycle_len, self.pred_len)

        if self.use_revin:
            y = y * torch.sqrt(seq_var) + seq_mean

        return y
