import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fft
import math
from model.base import LOBAutoEncoder
from model.layers.Embed import DataEmbedding
from model.layers.Conv_Blocks import Inception_Block_V1
# NOTE(review): this enables autograd anomaly detection globally, as a side
# effect of importing this module, for every backward pass in the process.
# PyTorch documents anomaly detection as a debugging aid that slows execution;
# consider gating it behind a debug flag instead of enabling it at import time.
torch.autograd.set_detect_anomaly(True)
    
def FFT_for_Period(x, k=2):
    """Estimate the ``k`` dominant periods of ``x`` from its real FFT.

    Args:
        x: tensor indexed as [B, T, C]; the FFT is taken along dim 1 (time).
        k: number of dominant frequency bins to keep.

    Returns:
        period: numpy integer array of length ``k``. Each entry is ``T // f``
            for a selected frequency bin ``f``. Bins are ranked by amplitude
            averaged over batch and channels. Every entry is >= 1.
        weights: tensor [B, k] with the channel-averaged amplitude of each
            sample at the selected bins; it stays on the autograd graph.
    """
    # [B, T, C] -> complex spectrum [B, T // 2 + 1, C]
    xf = torch.fft.rfft(x, dim=1)
    amplitude = torch.abs(xf)
    # Only the *indices* of the strongest bins are used from this ranking, so
    # it is computed on a detached copy; the in-place edit below then never
    # touches the autograd graph.
    frequency_list = amplitude.detach().mean(0).mean(-1)
    # Ignore the DC component (bin 0): it carries no period information.
    frequency_list[0] = 0
    _, top_list = torch.topk(frequency_list, k)
    # If fewer than k bins have non-zero amplitude (e.g. constant input), topk
    # may fall back to bin 0, which would yield a period of T // 0 == 0 and
    # crash callers that pad/reshape by the period. Clamp to bin 1 (period T).
    top_list = top_list.clamp(min=1).cpu().numpy()
    period = x.shape[1] // top_list
    return period, amplitude.mean(-1)[:, top_list]


class TimesBlock(nn.Module):
    """TimesNet block: 2D-convolve the sequence folded by its dominant periods.

    For each of the ``top_k`` periods returned by ``FFT_for_Period``, the forward pass:
    - zero-pads the input [B, T, N] on the time axis to a multiple of the period;
    - folds it to [B, N, T_padded // period, period];
    - runs it through two Inception blocks, unfolds it back and crops to T.

    The ``top_k`` results are then combined using softmax-normalized FFT
    amplitudes as weights. A residual connection adds the input back.

    Args:
        seq_len: input length, stored for reference. The forward pass uses the
            actual time dimension of its input.
        pred_len: prediction length, stored for reference (see ``seq_len``).
        top_k: number of dominant periods to use.
        d_model: channel count N of the input and output. The conv stack maps
            d_model -> d_ff -> d_model, so N must equal d_model.
        d_ff: hidden channel count between the two Inception blocks.
        num_kernels: forwarded to ``Inception_Block_V1``.
    """

    def __init__(self, seq_len, pred_len, top_k, d_model, d_ff, num_kernels):
        super(TimesBlock, self).__init__()
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.k = top_k
        self.conv = nn.Sequential(
            Inception_Block_V1(d_model, d_ff, num_kernels=num_kernels),
            nn.GELU(),
            Inception_Block_V1(d_ff, d_model, num_kernels=num_kernels)
        )

    def forward(self, x):
        """Apply the block to ``x`` of shape [B, T, N]; returns [B, T, N]."""
        B, T, N = x.size()
        period_list, period_weight = FFT_for_Period(x, self.k)

        res = []
        for i in range(self.k):
            period = int(period_list[i])
            # Pad the time axis up to the next multiple of `period`. The actual
            # input length T is used, so the fold works for any input length.
            length = ((T + period - 1) // period) * period
            if length != T:
                # new_zeros matches x's dtype and device (e.g. under AMP).
                padding = x.new_zeros(B, length - T, N)
                out = torch.cat([x, padding], dim=1)
            else:
                out = x
            # 1D -> 2D: [B, length, N] -> [B, N, length // period, period]
            out = out.reshape(B, length // period, period,
                              N).permute(0, 3, 1, 2).contiguous()
            # 2D conv over (cycle, phase) variation
            out = self.conv(out)
            # 2D -> 1D. This assumes the Inception blocks preserve the spatial
            # size (TODO confirm). Then crop the padding away.
            out = out.permute(0, 2, 3, 1).reshape(B, -1, N)
            res.append(out[:, :T, :])
        res = torch.stack(res, dim=-1)  # [B, T, N, k]
        # Adaptive aggregation: per-sample softmax over the k periods,
        # broadcast over time and channels.
        period_weight = F.softmax(period_weight, dim=1)  # [B, k]
        res = torch.sum(res * period_weight[:, None, None, :], dim=-1)
        # Residual connection
        return res + x
    
class TimesNet_AE(LOBAutoEncoder):
    """Autoencoder with a TimesNet encoder and a Transformer-decoder stack.

    ``forward`` applies these steps:
    1. Normalize each sample over time (Non-stationary Transformer style).
    2. Embed and pass through ``e_layers`` TimesBlocks, each followed by a
       shared LayerNorm.
    3. Decode with ``nn.TransformerDecoder``, then project to ``c_out``
       channels.
    4. De-normalize with the per-sample statistics.

    ``encode`` maps an input to a single ``d_model`` vector per sample through
    ``linear_encoding``.
    """

    def __init__(self, 
                 seq_len,
                 pred_len,
                 e_layers,
                 enc_in,
                 d_model,
                 dim_ff,
                 embed,
                 freq,
                 dropout,
                 top_k,
                 num_kernels,
                 c_out,
                 **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.enc_in = enc_in
        self.layer = e_layers
        # Submodules are created in the original order so that parameter
        # initialization order (and thus seeded reproducibility) is unchanged.
        self.model = nn.ModuleList([TimesBlock(seq_len, pred_len, top_k, d_model, dim_ff, num_kernels)
                                    for _ in range(e_layers)])
        self.enc_embedding = DataEmbedding(enc_in, d_model, embed, freq, dropout)
        self.layer_norm = nn.LayerNorm(d_model)
        # Flattened [seq_len * d_model] features -> one d_model vector (encode()).
        self.linear_encoding = nn.Linear(in_features=d_model * seq_len, out_features=d_model)

        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead=8, batch_first=True)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
        self.projection = nn.Linear(d_model, c_out, bias=True)

    def _timesnet_features(self, x):
        """Embed ``x`` and run it through the stacked, layer-normed TimesBlocks."""
        enc_out = self.enc_embedding(x, None)  # [B, T, d_model]
        for i in range(self.layer):
            enc_out = self.layer_norm(self.model[i](enc_out))
        return enc_out

    def encode(self, x):
        """Encode ``x`` into a [B, d_model] representation.

        The TimesNet features [B, T, d_model] are flattened per sample and fed
        to ``linear_encoding``, which expects T == seq_len.
        """
        # NOTE(review): unlike forward(), this skips the per-sample mean/std
        # normalization before embedding. Confirm that callers pass input
        # already on the scale forward() feeds the embedding.
        enc_out = self._timesnet_features(x)
        # flatten(1) keeps the batch dimension explicit. A length mismatch then
        # raises in the Linear layer instead of silently regrouping rows.
        return self.linear_encoding(enc_out.flatten(start_dim=1))

    def forward(self, x):
        """Reconstruct ``x`` (indexed as [B, T, enc_in]); returns [B, T, c_out]."""
        # Normalization from Non-stationary Transformer
        means = x.mean(1, keepdim=True).detach()
        x = x - means
        stdev = torch.sqrt(
            torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5)
        # Out-of-place: torch.var saved `x` for backward, so an in-place
        # division would break autograd when the input requires grad.
        x = x / stdev

        enc_out = self._timesnet_features(x)

        # Project back. NOTE(review): memory is all zeros, so the decoder's
        # cross-attention sees an input-independent memory. Only its
        # self-attention and feed-forward sub-layers carry information from
        # enc_out.
        memory = torch.zeros_like(enc_out)
        dec_out = self.decoder(enc_out, memory)
        out = self.projection(dec_out)

        # De-normalization from Non-stationary Transformer: the [B, 1, C]
        # statistics broadcast over the time axis of any length.
        out = out * stdev
        out = out + means
        return out
