import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fft as fft

import numpy as np
from einops import rearrange, reduce, repeat
import math, random
from scipy.fftpack import next_fast_len

from typing import Dict
from math import sqrt

class MovingAverage(nn.Module):
    """Moving average along the time axis (dim 1) with edge-replicate padding.

    The input is padded by repeating its first and last time steps, then smoothed
    with ``nn.AvgPool1d``. The padding totals ``kernel_size - 1`` steps:
    ``(kernel_size - 1) // 2`` in front and ``kernel_size // 2`` at the back.
    With ``stride=1`` the output therefore has the same length as the input for
    both odd and even kernel sizes.

    Args:
        kernel_size: averaging window length.
        stride: pooling stride; values > 1 shorten the output.
    """

    def __init__(self, kernel_size, stride):
        super(MovingAverage, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.AvgPool = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        """Smooth ``x`` along dim 1.

        Indexing below fixes ``x`` to rank 3. Assumes time on dim 1 and channels
        on dim 2 (B, L, D) — TODO confirm against callers.

        Returns:
            Tensor of shape (B, L_out, D); ``L_out == L`` when ``stride == 1``.
        """
        # Splitting kernel_size - 1 as floor/ceil keeps the output length equal to
        # the input length for even kernels too. Symmetric floor padding on both
        # sides would lose one step, which breaks ``x - trend``.
        front = (self.kernel_size - 1) // 2
        back = self.kernel_size // 2
        front_padding = x[:, 0:1, :].repeat(1, front, 1)
        back_padding = x[:, -1:, :].repeat(1, back, 1)

        x = torch.cat([front_padding, x, back_padding], dim=1)
        # AvgPool1d pools over the last axis, so move time there and back.
        x = self.AvgPool(x.permute(0, 2, 1))
        return x.permute(0, 2, 1)

class TimeSeriesDecomposition(nn.Module):
    """Split a series into a moving-average trend and the residual seasonality.

    Args:
        kernel_size: window length passed to ``MovingAverage``.
        stride: pooling stride passed to ``MovingAverage``.
    """

    def __init__(self, kernel_size, stride):
        super(TimeSeriesDecomposition, self).__init__()
        self.MovingAverage = MovingAverage(kernel_size, stride)

    def forward(self, x):
        """Return the pair ``(trend, seasonality)``.

        ``trend`` is ``x`` smoothed along dim 1 by ``MovingAverage``, and
        ``seasonality`` is ``x - trend``. Assumes ``x`` is (B, L, D) and that the
        moving average preserves L, so the subtraction lines up — TODO confirm
        for non-default kernel/stride settings.
        """
        smoothed = self.MovingAverage(x)
        residual = x - smoothed
        return smoothed, residual

class Mixerlayer(nn.Module):
    """Two-step MLP mixer: a channel projection followed by a time projection.

    ``fc1`` maps the last axis from ``in_channels`` to ``out_channels``. ``fc2``
    then mixes along axis 1, which must have length ``seq_len``. Each projection
    is followed by a GELU.

    Args:
        in_channels: size of the input's last axis.
        out_channels: size of the output's last axis.
        seq_len: length of axis 1 (mixed by ``fc2``).
    """

    def __init__(self, in_channels, out_channels, seq_len):
        super().__init__()
        self.fc1 = nn.Linear(in_channels, out_channels)
        self.fc2 = nn.Linear(seq_len, seq_len)

    def forward(self, x):
        # Channel mixing on the last axis.
        hidden = F.gelu(self.fc1(x))
        # Time mixing: fc2 acts on the last axis, so swap axes 1 and 2 around it.
        time_last = hidden.permute(0, 2, 1)
        mixed = self.fc2(time_last).permute(0, 2, 1)
        return F.gelu(mixed)

class TrendGenerator(nn.Module):
    """Produce predictor weight/bias vectors from the trend component.

    A stack of ``n_layer`` ``Mixerlayer`` blocks encodes the input (the first maps
    ``d_model`` channels to ``d_rep``; the rest keep ``d_rep``). The flattened
    encoding is projected to a ``d_feat`` feature. From that feature, a
    ``d_latent * d_latent`` weight vector and a ``pred_len`` bias vector are
    produced, each passed through dropout.

    Args:
        d_model: channel count of the input's last axis.
        d_rep: hidden channel width of the mixer stack.
        d_feat: size of the intermediate feature vector.
        d_latent: side length of the generated square weight matrix.
        seq_len: input length along axis 1.
        pred_len: length of the generated bias vector.
        k: accepted for signature parity with ``SeasonalGenerator``; not used here.
        n_layer: number of mixer layers.
        dropout: dropout probability applied to the outputs.
    """

    def __init__(self, d_model, d_rep, d_feat, d_latent, seq_len, pred_len, k, n_layer, dropout):
        super().__init__()
        mixers = []
        for depth in range(n_layer):
            width_in = d_model if depth == 0 else d_rep
            mixers.append(Mixerlayer(in_channels=width_in, out_channels=d_rep, seq_len=seq_len))
        self.fcn = nn.Sequential(*mixers)
        self.W_feature = nn.Linear(seq_len * d_rep, d_feat)
        self.W_weight = nn.Linear(d_feat, d_latent * d_latent)
        self.W_bias = nn.Linear(d_feat, pred_len)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``(weight, bias)`` of shapes (B, d_latent**2) and (B, pred_len).

        ``x`` must be rank 3; assumed (B, seq_len, d_model) — TODO confirm.
        """
        batch, _, _ = x.shape
        flat = self.fcn(x).reshape(batch, -1)
        feature = F.gelu(self.W_feature(flat))
        # The weight's dropout mask is drawn before the bias's mask.
        weight = self.dropout(self.W_weight(feature))
        bias = self.dropout(self.W_bias(feature))
        return weight, bias

class DilatedConv(nn.Module):
    """Length-preserving dilated ``nn.Conv1d`` followed by GELU.

    The convolution is padded by ``receptive_field // 2`` on each side. When the
    receptive field is even, this produces one extra trailing step, which is
    trimmed in ``forward`` so the output length equals the input length.

    Args:
        in_channels: input channels (axis 1 of a (B, C, L) tensor).
        out_channels: output channels.
        kernel_size: convolution kernel size.
        dilation: spacing between kernel taps.
        groups: ``nn.Conv1d`` groups.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, groups=1):
        super().__init__()
        span = dilation * (kernel_size - 1) + 1
        self.receptive_field = span
        self.padding = span // 2
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            padding=self.padding,
            dilation=dilation,
            groups=groups,
        )
        # Number of trailing steps to drop: 1 for an even span, 0 for an odd one.
        self.remove = 0 if span % 2 else 1

    def forward(self, x):
        y = self.conv(x)
        if self.remove:
            y = y[:, :, :-self.remove]
        return F.gelu(y)

class SeasonalGenerator(nn.Module):
    """Produce predictor weight/bias vectors from the seasonal component.

    A stack of ``n_layer`` ``DilatedConv`` layers with dilation ``2**i`` encodes
    the input. The first layer maps ``d_model`` channels to ``d_rep``; the rest
    keep ``d_rep``. The flattened encoding is projected to a ``d_feat`` feature.
    From that feature, a ``d_latent * d_latent`` weight vector and a ``pred_len``
    bias vector are produced, each passed through dropout.

    Args:
        d_model: channel count of the input's last axis.
        d_rep: hidden channel width of the convolution stack.
        d_feat: size of the intermediate feature vector.
        d_latent: side length of the generated square weight matrix.
        seq_len: input length along axis 1.
        pred_len: length of the generated bias vector.
        k: convolution kernel size.
        n_layer: number of dilated convolution layers.
        dropout: dropout probability applied to the outputs.
    """

    def __init__(self, d_model, d_rep, d_feat, d_latent, seq_len, pred_len, k, n_layer, dropout):
        super().__init__()
        convs = []
        for depth in range(n_layer):
            convs.append(
                DilatedConv(
                    in_channels=d_model if depth == 0 else d_rep,
                    out_channels=d_rep,
                    kernel_size=k,
                    dilation=2 ** depth,
                )
            )
        self.tcn = nn.Sequential(*convs)
        self.W_feature = nn.Linear(seq_len * d_rep, d_feat)
        self.W_weight = nn.Linear(d_feat, d_latent * d_latent)
        self.W_bias = nn.Linear(d_feat, pred_len)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``(weight, bias)`` of shapes (B, d_latent**2) and (B, pred_len).

        ``x`` must be rank 3; assumed (B, seq_len, d_model) — TODO confirm.
        """
        batch, _, _ = x.shape
        # Conv1d wants channels on axis 1: swap the last two axes first.
        encoded = self.tcn(x.permute(0, 2, 1))
        feature = F.gelu(self.W_feature(encoded.reshape(batch, -1)))
        # The weight's dropout mask is drawn before the bias's mask.
        weight = self.dropout(self.W_weight(feature))
        bias = self.dropout(self.W_bias(feature))
        return weight, bias
    

class PredictorGenerator(nn.Module):
    """Combine trend- and seasonal-driven generators into one predictor.

    Both sub-generators receive identical hyper-parameters. Their weight outputs
    are summed and reshaped into a batch of ``d_latent x d_latent`` matrices.
    Their bias outputs are summed.
    """

    def __init__(self, seq_len, pred_len, d_model, d_rep, d_feat, d_latent, k, n_layer, dropout):
        super().__init__()
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.d_div = d_latent

        shared = (d_model, d_rep, d_feat, d_latent, seq_len, pred_len, k, n_layer, dropout)
        self.trend_predictor_generator = TrendGenerator(*shared)
        self.seasonal_predictor_generator = SeasonalGenerator(*shared)

    def forward(self, trend, seasonal):
        """Return ``(weight_gen, bias_gen)``.

        ``weight_gen`` has shape (-1, d_latent, d_latent); ``bias_gen`` is the sum
        of the two generators' bias vectors.
        """
        w_trend, b_trend = self.trend_predictor_generator(trend)
        w_season, b_season = self.seasonal_predictor_generator(seasonal)

        square = (-1, self.d_div, self.d_div)
        return (w_trend + w_season).view(*square), b_trend + b_season



class LGPred(nn.Module):
    """Linear forecaster whose latent mixing matrix and output bias are generated per sample.

    Forward pipeline:
      1. Subtract the (detached) last observed time step from every step.
      2. ``Linear_base1`` projects the time axis from ``seq_len`` to ``d_latent``.
      3. A per-sample ``d_latent x d_latent`` matrix is generated from the
         trend/seasonality decomposition of the raw input and applied with
         ``torch.bmm``.
      4. ``Linear_base2`` projects ``d_latent`` to ``pred_len``, and a generated
         per-step bias is added.
      5. The last observed step is added back.

    Args:
        seq_len: input length.
        pred_len: forecast length.
        d_model: number of input channels.
        d_rep, d_feat, d_latent, k, n_layer, dropout: generator hyper-parameters
            (see ``PredictorGenerator``).
        decomp_kernel_size: moving-average window of the trend/seasonality split
            (default 7, the previously hard-coded value). An odd value keeps the
            trend aligned with the input.
    """

    def __init__(self, seq_len, pred_len, d_model, d_rep, d_feat, d_latent, k, dropout, n_layer,
                 decomp_kernel_size=7):
        super().__init__()
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.d_model = d_model
        self.Linear_base1 = nn.Linear(self.seq_len, d_latent)
        self.Linear_base2 = nn.Linear(d_latent, self.pred_len)
        self.Decomposition = TimeSeriesDecomposition(kernel_size=decomp_kernel_size, stride=1)

        self.generator = PredictorGenerator(seq_len, pred_len, d_model, d_rep, d_feat, d_latent, k, n_layer, dropout)

    def forward(self, x):
        """Forecast ``pred_len`` steps from ``x``.

        Args:
            x: rank-3 tensor, presumably (B, seq_len, d_model) — TODO confirm.

        Returns:
            ``(preds, trend, seasonality, None)``. ``preds`` has shape
            (B, pred_len, D); ``trend`` and ``seasonality`` come from the
            decomposition of the raw ``x``.

        Raises:
            ValueError: if ``x`` is not 3-D.
        """
        if x.dim() != 3:
            raise ValueError(f"LGPred expects a 3-D input (B, L, D); got shape {tuple(x.shape)}")

        # Detached, so no gradient flows through the offset that is subtracted and re-added.
        seq_last = x[:, -1:, :].detach()
        # The predictor parameters are generated from the un-shifted series.
        trend, seasonality = self.Decomposition(x)
        weight_gen, bias_gen = self.generator(trend, seasonality)

        x = x - seq_last
        # (B, L, D) -> (B, D, L) -> (B, D, d_latent) -> (B, d_latent, D)
        latent = self.Linear_base1(x.permute(0, 2, 1)).permute(0, 2, 1)
        # Per-sample latent mixing: (B, d_latent, d_latent) @ (B, d_latent, D).
        latent = torch.bmm(weight_gen, latent)
        # (B, d_latent, D) -> (B, D, pred_len) -> (B, pred_len, D); the bias broadcasts over D.
        preds = self.Linear_base2(latent.permute(0, 2, 1)).permute(0, 2, 1) + bias_gen.unsqueeze(2)
        preds = preds + seq_last

        return preds, trend, seasonality, None
        




if __name__ == '__main__':
    # Smoke test: one forward / backward / optimizer step of LGPred on random data
    # (batch 16, seq_len = pred_len = 720, 862 channels). Falls back to CPU so the
    # script also runs on machines without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = LGPred(720, 720, 862, 128, 128, 256, 5, 0.1, 4)
    x = torch.randn(16, 720, 862)
    y = torch.randn(16, 720, 862)
    x = x.to(device)
    model.to(device)
    y = y.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    optimizer.zero_grad()
    output, trends, seasonals, res = model(x)
    print(output.shape)
    loss = nn.MSELoss()
    MSE_loss = loss(output, y)
    MSE_loss.backward()
    optimizer.step()
    print(f"MSE loss: {MSE_loss.item():.6f}")
    # NOTE(review): an earlier comment here said "When individual=True, maximum
    # batchsize=4", but LGPred has no `individual` option. Confirm whether a
    # memory-driven batch-size limit still applies to this configuration.