import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class ConvLayer(nn.Module):
    """Distilling block: circular Conv1d -> BatchNorm1d -> ELU -> stride-2 MaxPool1d.

    Input and output are laid out as ``(batch, length, c_in)``; the channel count is
    preserved. With ``kernel_size=3`` and ``padding=2`` the convolution output is two
    steps longer than its input, and the max-pool (kernel 3, stride 2, padding 1) then
    roughly halves the length.

    Args:
        c_in: number of channels (last dimension of the input).
    """

    def __init__(self, c_in):
        super().__init__()
        self.downConv = nn.Conv1d(
            in_channels=c_in,
            out_channels=c_in,
            kernel_size=3,
            padding=2,
            padding_mode='circular',
        )
        self.norm = nn.BatchNorm1d(c_in)
        self.activation = nn.ELU()
        self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        # The 1-D conv / norm / pool layers want channels on dim 1, so swap the
        # last two axes on the way in and swap them back on the way out.
        channels_first = x.permute(0, 2, 1)
        hidden = self.activation(self.norm(self.downConv(channels_first)))
        pooled = self.maxPool(hidden)
        return pooled.permute(0, 2, 1)


class EncoderLayer(nn.Module):
    """Transformer encoder layer with optional frequency-based feature scaling.

    Args:
        attention: module called as
            ``attention(q, k, v, attn_mask=..., tau=..., delta=...)`` and returning a
            pair ``(output, attn)``.
        d_model: width of the last dimension of ``x``.
        d_ff: hidden width of the pointwise feed-forward block; defaults to
            ``4 * d_model``.
        featscale: if True, ``forward`` adds learnable per-channel scalings
            (``lamb1``, ``lamb2``) of the two parts produced by ``freq_decompose``.
        dropout: dropout probability.
        activation: ``"relu"`` selects ReLU; any other value selects GELU.
        gamma: number of largest-magnitude frequency bins kept by ``freq_decompose``
            (clamped to the number of available bins).
        alpha: stored on the instance but not used by this layer.
    """

    def __init__(self, attention, d_model, d_ff=None, featscale=True, dropout=0.1, activation="relu", gamma=3, alpha=0.5):
        super(EncoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.gamma = gamma
        self.alpha = alpha

        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)

        # Not used in forward; presumably kept so existing checkpoints still load.
        self.gen1 = nn.Linear(d_model, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

        self.featscale = featscale
        if self.featscale:
            # Scale for the dominant-frequency part starts at zero; scale for the
            # residual part starts random.
            self.lamb1 = nn.Parameter(torch.zeros(d_model), requires_grad=True)
            self.lamb2 = nn.Parameter(torch.randn(d_model), requires_grad=True)

    def freq_decompose(self, x):
        """Split ``x`` into a dominant-frequency part and a residual along dim 1.

        An orthonormal real FFT is taken over dim 1. For every (batch, last-dim)
        position, only the ``gamma`` bins with the largest magnitude are kept. The
        inverse FFT of that sparse spectrum is ``x_d``. If ``gamma`` exceeds the
        number of bins (``num_vars // 2 + 1``), all bins are kept.

        Earlier variants of this method included a low-pass filter that kept the
        first ``gamma`` bins, and an approximation using the mean over dim 1. This
        version uses top-k magnitude selection instead.

        Args:
            x: tensor of shape ``(batch, num_vars, embedding_dim)``.

        Returns:
            ``(x_d, x_h)``, where ``x_d`` is the float32 reconstruction from the
            kept bins and ``x_h = x - x_d``.
        """
        _, num_vars, _ = x.shape

        x_fft = torch.fft.rfft(x, dim=1, norm='ortho')  # (batch, num_vars // 2 + 1, embedding_dim)
        # topk raises if k exceeds the size of dim 1, which happens for short
        # inputs (e.g. num_vars=1 yields a single bin); clamp k instead.
        k = min(self.gamma, x_fft.size(1))
        indices = torch.topk(x_fft.abs(), k, dim=1).indices
        mask = torch.zeros_like(x_fft)
        mask.scatter_(1, indices, 1)
        # irfft already returns a real tensor; .float() casts it to float32.
        x_d = torch.fft.irfft(x_fft * mask, n=num_vars, dim=1, norm='ortho').float()

        x_h = x - x_d
        return x_d, x_h

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        """Apply attention, optional feature scaling, and the feed-forward block.

        Returns:
            ``(out, attn, feature_map)``. ``attn`` is whatever the attention module
            returned. ``feature_map`` is a detached CPU copy of the residual stream,
            taken just before ``norm1``.
        """
        # Attention is computed on the input before feature scaling is added.
        new_x, attn = self.attention(
            x, x, x,
            attn_mask=attn_mask,
            tau=tau, delta=delta
        )

        if self.featscale:
            x_d, x_h = self.freq_decompose(x)
            x = x + x_d * self.lamb1 + x_h * self.lamb2

        x = x + self.dropout(new_x)
        # NOTE(review): this is a device->host copy on every call; it is costly on
        # GPU if callers do not use the returned feature maps.
        feature_map = x.detach().cpu()
        y = x = self.norm1(x)
        # Pointwise feed-forward: kernel-size-1 convs with channels moved to dim 1.
        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))

        return self.norm2(y + x), attn, feature_map



class Encoder(nn.Module):
    """Stack of encoder layers with optional interleaved conv layers and a final norm.

    Each entry of ``attn_layers`` is called as
    ``layer(x, attn_mask=..., tau=..., delta=...)`` and must return
    ``(x, attn, feature_map)``.

    If ``conv_layers`` is given, ``conv_layers[i]`` is applied to ``x`` right after
    ``attn_layers[i]``. Any conv layers beyond the number of attention layers are
    unused. The usual configuration is ``len(attn_layers) - 1`` conv layers (e.g.
    ``ConvLayer``).

    Args:
        attn_layers: iterable of encoder layer modules.
        conv_layers: optional iterable of modules mapping ``x -> x``.
        norm_layer: optional callable applied to the final output.
    """

    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None

        self.norm = norm_layer

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        """Run all layers and return ``(x, attns, feature_maps)``.

        ``attns`` and ``feature_maps`` hold one entry per attention layer, in order.
        Without conv layers, every attention layer receives ``attn_mask``, ``tau``
        and ``delta``. With conv layers, only the first attention layer receives
        ``attn_mask`` and ``delta``. A conv layer may change the sequence length,
        so later layers get ``None`` for both.
        """
        attns = []
        feature_maps = []
        n_convs = len(self.conv_layers) if self.conv_layers is not None else 0

        for i, attn_layer in enumerate(self.attn_layers):
            # True once conv_layers[0] has been applied to x.
            after_conv = i > 0 and n_convs > 0
            x, attn, feature_map = attn_layer(
                x,
                attn_mask=None if after_conv else attn_mask,
                tau=tau,
                delta=None if after_conv else delta,
            )
            attns.append(attn)
            feature_maps.append(feature_map)
            if i < n_convs:
                x = self.conv_layers[i](x)

        if self.norm is not None:
            x = self.norm(x)

        return x, attns, feature_maps


class DecoderLayer(nn.Module):
    """Post-norm decoder layer: self-attention, cross-attention, then feed-forward.

    Both attention modules are called as
    ``attn(q, k, v, attn_mask=..., tau=..., delta=...)``; only element ``[0]`` of
    their return value is used. Self-attention always receives ``delta=None``, and
    the caller's ``delta`` is forwarded to cross-attention only.

    Args:
        self_attention: attention module over the decoder sequence.
        cross_attention: attention module from the decoder sequence to ``cross``.
        d_model: width of the last dimension.
        d_ff: hidden width of the feed-forward block; defaults to ``4 * d_model``.
        dropout: dropout probability.
        activation: ``"relu"`` selects ReLU; any other value selects GELU.
    """

    def __init__(self, self_attention, cross_attention, d_model, d_ff=None,
                 dropout=0.1, activation="relu"):
        super().__init__()
        hidden = d_ff or 4 * d_model
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=hidden, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=hidden, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
        self_out = self.self_attention(
            x, x, x,
            attn_mask=x_mask,
            tau=tau, delta=None,
        )[0]
        h = self.norm1(x + self.dropout(self_out))

        cross_out = self.cross_attention(
            h, cross, cross,
            attn_mask=cross_mask,
            tau=tau, delta=delta,
        )[0]
        h = self.norm2(h + self.dropout(cross_out))

        # Pointwise feed-forward: kernel-size-1 convs with channels moved to dim 1.
        ff = self.dropout(self.activation(self.conv1(h.transpose(-1, 1))))
        ff = self.dropout(self.conv2(ff).transpose(-1, 1))
        return self.norm3(h + ff)


class Decoder(nn.Module):
    """Stack of decoder layers followed by an optional norm and an optional projection.

    Args:
        layers: iterable of decoder layer modules, each called as
            ``layer(x, cross, x_mask=..., cross_mask=..., tau=..., delta=...)``.
        norm_layer: optional callable applied after the last layer.
        projection: optional callable applied after the norm.
    """

    def __init__(self, layers, norm_layer=None, projection=None):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer
        self.projection = projection

    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
        out = x
        for block in self.layers:
            out = block(out, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta)

        # Optional post-processing, applied in order: norm, then projection.
        for post in (self.norm, self.projection):
            if post is not None:
                out = post(out)
        return out
