import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fft
from einops import rearrange


def FFT_for_Period(x, k=2):
    """Estimate the ``k`` dominant periods of ``x`` along dimension 1.

    The magnitude spectrum of a real FFT over dim 1 is averaged over dim 0
    and then over the last dimension, giving one score per frequency bin.
    This means the selected periods are shared by every element of dim 0.
    The DC bin is zeroed before the ``k`` largest bins are selected.

    Args:
        x: tensor transformed along dim 1. It should be at least 3-D so that
            the two averaging reductions leave a 1-D spectrum. The caller in
            this file passes ``[B, L, 1]``.
        k: number of frequency bins (periods) to select.

    Returns:
        A tuple ``(period, weights)``:
        - ``period`` is a NumPy integer array of length ``k``, computed as
          ``x.shape[1] // bin_index``.
        - ``weights`` is the softmax of the selected bins' averaged
          magnitudes. It stays on the autograd graph.
    """
    spectrum = torch.fft.rfft(x, dim=1)
    amplitude = abs(spectrum).mean(0).mean(-1)
    # Drop the DC component so a constant offset is never reported as a period.
    amplitude[0] = 0
    top_idx = torch.topk(amplitude, k).indices
    weights = F.softmax(amplitude[top_idx], dim=-1)
    freq_idx = top_idx.detach().cpu().numpy()
    # Bin 0 can still be picked when the spectrum is flat (e.g. all zeros);
    # map it to bin 1 to avoid dividing by zero.
    freq_idx[freq_idx == 0] = 1
    return x.shape[1] // freq_idx, weights


class DYN_CONVNET_Model(nn.Module):
    """Univariate encoder mixing periodic (Fourier) and local convolutional features.

    Maps a batch of 1-D sequences of shape ``[B, L]`` to an output of shape
    ``[B, L]``.

    On every forward pass the ``k_periods`` dominant periods are estimated
    with ``FFT_for_Period``. Because that function averages the spectrum over
    the batch dimension, the periods are shared by all sequences in the batch.
    The raw signal is stacked with weighted cos/sin encodings of those
    periods, and the stack feeds two things:
    - a periodic branch;
    - a sigmoid gate over two local convolutions (kernel sizes 3 and 5).

    The two branches are fused and refined by a residual convolution stack.

    Args:
        d_model: hidden channel width. It must be greater than 1 because it
            is split between the kernel-3 and kernel-5 convolutions.
        k_periods: number of dominant periods encoded as cos/sin channels.

    Raises:
        ValueError: if ``d_model`` is not greater than 1.
    """

    def __init__(self, d_model: int, k_periods: int = 3):
        super().__init__()
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``. With d_model == 1 the kernel-3 convolution would
        # silently get zero output channels.
        if d_model <= 1:
            raise ValueError(f"d_model must be greater than 1, but got {d_model}")
        self.d_model = d_model
        self.k_periods = k_periods
        # Raw signal + one (cos, sin) pair per selected period.
        self.fourier_channels = 2 * k_periods + 1
        self.fourier_conv = nn.Conv1d(self.fourier_channels, d_model, kernel_size=1)
        self.fourier_norm = nn.LayerNorm(d_model)
        # The two local convolutions together produce exactly d_model channels.
        self.wavelet_conv_3 = nn.Conv1d(1, d_model // 2, kernel_size=3, padding=1)
        self.wavelet_conv_5 = nn.Conv1d(1, d_model - (d_model // 2), kernel_size=5, padding=2)
        self.gating_unit = nn.Sequential(
            nn.Conv1d(self.fourier_channels, d_model, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv1d(d_model, d_model, kernel_size=1),
            nn.Sigmoid()
        )
        self.fusion_conv = nn.Conv1d(2 * d_model, d_model, kernel_size=1)
        self.fusion_norm = nn.LayerNorm(d_model)
        self.core_conv = nn.Sequential(
            nn.Conv1d(d_model, d_model, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv1d(d_model, d_model, kernel_size=3, padding=1),
        )
        self.final_projection = nn.Conv1d(d_model, 1, 1)

    @staticmethod
    def _periodic_features(x_reshaped, periods, period_weights):
        """Stack the raw signal with weighted cos/sin encodings of each period.

        Args:
            x_reshaped: input of shape ``[B, 1, L]``.
            periods: sequence of integer periods, one per selected frequency.
            period_weights: 1-D tensor of per-period weights, aligned with
                ``periods``.

        Returns:
            Tensor of shape ``[B, 1 + 2 * len(periods), L]`` with channel order
            ``[x, cos_0, sin_0, cos_1, sin_1, ...]``. Each cos/sin channel is
            scaled by its period's weight.
        """
        batch, _, length = x_reshaped.shape
        time_steps = torch.arange(length, device=x_reshaped.device)
        features = [x_reshaped]
        for period, weight in zip(periods, period_weights):
            angle = 2 * np.pi * time_steps / period
            # Keep the weight dimensioned ([1, 1, 1], not 0-d) so dtype
            # promotion matches multiplying by a shaped tensor.
            scale = weight.reshape(1, 1, 1)
            for wave in (torch.cos(angle), torch.sin(angle)):
                # The encoding is the same for every sample: build it once as
                # [1, 1, L] and broadcast with expand() instead of copying it.
                features.append((wave.view(1, 1, length) * scale).expand(batch, -1, -1))
        return torch.cat(features, dim=1)

    def forward(self, x: torch.Tensor):
        """Encode a batch of univariate series.

        Args:
            x: tensor of shape ``[B, L]``.

        Returns:
            Tensor of shape ``[B, L]``.

        Raises:
            ValueError: if ``x`` is not 2-D.
        """
        if x.dim() != 2:
            raise ValueError("Input to DYN_CONVNET_Model must be 2D tensor [B, L]")
        x_reshaped = x.unsqueeze(1)  # [B, 1, L]
        # FFT_for_Period transforms along dim 1, so it is given [B, L, 1].
        periods, period_weights = FFT_for_Period(x.unsqueeze(-1), self.k_periods)
        fourier_features = self._periodic_features(x_reshaped, periods, period_weights)

        # Periodic branch; LayerNorm normalises the channel axis, hence the permutes.
        fourier_out = self.fourier_conv(fourier_features)
        fourier_out = self.fourier_norm(fourier_out.permute(0, 2, 1)).permute(0, 2, 1)

        # Local branch (kernel sizes 3 and 5), gated by the periodic features.
        wavelet_out = torch.cat(
            [self.wavelet_conv_3(x_reshaped), self.wavelet_conv_5(x_reshaped)], dim=1
        )
        wavelet_gated = wavelet_out * self.gating_unit(fourier_features)

        fused = self.fusion_conv(torch.cat([fourier_out, wavelet_gated], dim=1))
        fused = self.fusion_norm(fused.permute(0, 2, 1)).permute(0, 2, 1)
        enc_out = self.core_conv(fused) + fused  # residual refinement
        return self.final_projection(enc_out).squeeze(1)


class Model(nn.Module):
    """Anomaly-detection wrapper that reconstructs every channel independently.

    Each channel of an input laid out as ``B L D`` is normalised over the
    time axis and flattened into its own univariate sequence. It is then
    passed through ``DYN_CONVNET_Model`` and mapped back to the original
    scale and layout.

    Args:
        configs: object providing ``task_name``, ``seq_len``, ``pred_len``,
            ``d_model`` and ``top_k`` attributes.
    """

    def __init__(self, configs):
        super().__init__()
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.enc = DYN_CONVNET_Model(d_model=configs.d_model, k_periods=configs.top_k)

    def anomaly_detection(self, x_enc):
        """Reconstruct ``x_enc`` (layout ``B L D``) channel by channel."""
        batch = x_enc.size(0)
        # Per-series statistics over the time axis (dim 1). The mean is
        # detached, so no gradient flows through the centring offset.
        centre = x_enc.mean(1, keepdim=True).detach()
        centred = x_enc.sub(centre)
        scale = torch.sqrt(
            torch.var(centred, dim=1, keepdim=True, unbiased=False) + 1e-5)
        normed = centred.div(scale)
        # Treat every (sample, channel) pair as an independent 1-D series.
        series = rearrange(normed, "B L D -> (B D) L")
        recon = rearrange(self.enc(series), "(B D) L -> B L D", B=batch)
        return recon * scale + centre

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Dispatch on ``task_name``; only anomaly detection is supported.

        Raises:
            NotImplementedError: for any other ``task_name``.
        """
        if self.task_name != "anomaly_detection":
            raise NotImplementedError("This model is optimized for anomaly detection.")
        return self.anomaly_detection(x_enc)