import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.Autoformer_EncDec import series_decomp
from layers.Embed import DataEmbedding_wo_pos
from layers.StandardNorm import Normalize
from layers.ChebyKANLayer import ChebyKANLinear
import math
from models.KAN import KAN


class ChebyKANLayer(nn.Module):
    """Apply a ``ChebyKANLinear`` map to the last axis of a 3-D tensor.

    ``ChebyKANLinear`` is called on a 2-D input here, so the two leading axes
    are merged before the call and restored afterwards.
    """

    def __init__(self, in_features, out_features,order):
        super().__init__()
        self.fc1 = ChebyKANLinear(in_features, out_features, order)

    def forward(self, x):
        """Map ``x`` of shape ``[B, N, C]`` to ``[B, N, out]``."""
        batch, tokens, channels = x.shape
        flat = x.reshape(batch * tokens, channels)
        # -1 lets the wrapped layer decide the output width.
        return self.fc1(flat).reshape(batch, tokens, -1).contiguous()

# class TemporalKAN(nn.Module):

#     def __init__(self, configs):
#         super(TemporalKAN, self).__init__()
#         self.N = configs.d_model
#         self.kans = nn.ModuleList([M_KAN(configs.d_model, configs.begin_order) for _ in range(3)])
     
#     def forward(self, x_low, x_band, x_high):
#         x_low_t = torch.fft.irfft(x_low, n=self.N, dim=1, norm='ortho')
#         x_band_t = torch.fft.irfft(x_band, n=self.N, dim=1, norm='ortho')
#         x_high_t = torch.fft.irfft(x_high, n=self.N, dim=1, norm='ortho')
#         out_1 = self.kans[0](x_low_t)
#         out_2 = self.kans[1](x_band_t)
#         out_3 = self.kans[2](x_high_t)

#         self.weight_low = nn.Parameter(torch.tensor(1.0))
#         self.weight_band = nn.Parameter(torch.tensor(1.0))
#         self.weight_high = nn.Parameter(torch.tensor(1.0))

#         out = self.weight_low * out_1 + self.weight_band * out_2 + self.weight_high * out_3
        
#         return out
    



class TemporalKAN(nn.Module):
    """Process three frequency bands in the time domain and fuse them.

    Each band spectrum is brought back to the time domain with an inverse
    real FFT of length ``configs.d_model`` along dim 1, passed through its
    own ``M_KAN`` branch, and the three branch outputs are combined in two
    ways that are summed:

    * a learnable weighted sum (``freq_weights``), and
    * a multi-head attention over the three bands at every position,
      averaged over the bands and linearly projected.

    NOTE(review): ``num_heads`` is fixed at 4, so ``configs.d_model`` must be
    divisible by 4.
    """

    def __init__(self, configs):
        super(TemporalKAN, self).__init__()
        self.N = configs.d_model
        self.kans = nn.ModuleList([M_KAN(configs.d_model, configs.begin_order) for _ in range(3)])

        # Learnable frequency fusion weights
        self.freq_weights = nn.Parameter(torch.ones(3))

        # Attention across the three frequency components
        self.cross_attn = nn.MultiheadAttention(embed_dim=configs.d_model, num_heads=4, batch_first=True)
        self.attn_proj = nn.Linear(configs.d_model, configs.d_model)

    def forward(self, x_low, x_band, x_high):
        """Fuse the low/band/high spectra.

        Args:
            x_low, x_band, x_high: complex spectra with the frequency axis on
                dim 1 (as produced by ``torch.fft.rfft(..., dim=1)``).

        Returns:
            Real tensor ``[B, T, C]`` where ``T == self.N``.
        """
        # IFFT to time domain
        x_low_t = torch.fft.irfft(x_low, n=self.N, dim=1, norm='ortho')
        x_band_t = torch.fft.irfft(x_band, n=self.N, dim=1, norm='ortho')
        x_high_t = torch.fft.irfft(x_high, n=self.N, dim=1, norm='ortho')

        # Pass through KAN branches
        out_1 = self.kans[0](x_low_t)   # [B, T, C]
        out_2 = self.kans[1](x_band_t)
        out_3 = self.kans[2](x_high_t)

        # Treat the three bands as a length-3 sequence at every (batch, time)
        # position so attention mixes information across bands only.
        B, T, C = out_1.shape
        outs = torch.stack((out_1, out_2, out_3), dim=2).reshape(B * T, 3, C)  # [B*T, 3, C]

        # Apply cross-attention across frequency bands
        attn_out, _ = self.cross_attn(outs, outs, outs)  # [B*T, 3, C]
        attn_out = attn_out.mean(dim=1)  # Mean over the 3 bands => [B*T, C]
        attn_out = self.attn_proj(attn_out)
        attn_out = attn_out.reshape(B, T, C)

        # Frequency-wise learnable fusion
        fused = (self.freq_weights[0] * out_1 +
                 self.freq_weights[1] * out_2 +
                 self.freq_weights[2] * out_3)

        return fused + attn_out
    

class FrequencyKAN(nn.Module):
    """Process three frequency bands directly in the frequency domain.

    Each band has its own ``KAN`` (d_model -> hidden_size -> d_model) that is
    applied separately to the real and the imaginary part of the spectrum.
    The three complex results are combined with learnable scalar weights.
    """

    def __init__(self, configs):
        super(FrequencyKAN, self).__init__()
        self.seq_len = configs.seq_len
        self.N = configs.d_model
        self.d_model = configs.d_model
        self.sparsity_threshold = 0.001
        self.hidden_size = 256
        self.kans = nn.ModuleList([self._init_kan_layer(KAN, [self.d_model, self.hidden_size,
                                                                  self.d_model]) for _ in range(3)])

        # Fusion weights are registered once here so they are part of
        # ``parameters()`` / ``state_dict()`` and keep their learned values
        # between forward passes.
        self.weight_low = nn.Parameter(torch.tensor(1.0))
        self.weight_band = nn.Parameter(torch.tensor(1.0))
        self.weight_high = nn.Parameter(torch.tensor(1.0))

    def _init_kan_layer(self, kan_class, layers_hidden):
        """Build one KAN with the fixed hyper-parameters used by this module."""
        return kan_class(
            layers_hidden=layers_hidden,
            grid_size=2,
            spline_order=1,
            scale_noise=0.1,
            scale_base=1.0,
            scale_spline=1.0,
            base_activation=torch.nn.SiLU,
            grid_eps=0.02,
            grid_range=[-1, 1],
            regularize_activation=1.0,
            regularize_entropy=1.0,
            update_grid=False,
        )

    # frequency learner
    def KAN_frequency(self, x, index):
        """Apply KAN number ``index`` to a complex tensor ``x``.

        The same KAN is applied to ``x.real`` and ``x.imag``; both outputs are
        soft-thresholded with ``sparsity_threshold`` and recombined into a
        complex tensor.
        """
        y_real = self.kans[index](x.real)
        y_imag = self.kans[index](x.imag)
        y = torch.stack([y_real, y_imag], dim=-1)
        y = F.softshrink(y, lambd=self.sparsity_threshold)
        y = torch.view_as_complex(y)
        return y

    def forward(self, x_low, x_band, x_high):
        """Return the weighted sum of the three processed spectra (complex)."""
        out_1 = self.KAN_frequency(x_low, 0)
        out_2 = self.KAN_frequency(x_band, 1)
        out_3 = self.KAN_frequency(x_high, 2)

        out = self.weight_low * out_1 + self.weight_band * out_2 + self.weight_high * out_3

        return out
    
    
class M_KAN(nn.Module):
    """Sum of two parallel branches applied to the same input.

    One branch is a ``ChebyKANLayer`` (d_model -> d_model) wrapped in an
    ``nn.Sequential``; the other is a ``BasicConv`` with kernel size 3 and
    ``groups=d_model``.
    """

    def __init__(self,d_model,order):
        super().__init__()
        mixer = ChebyKANLayer(d_model, d_model, order)
        # Kept inside nn.Sequential so parameter names stay "channel_mixer.0.*".
        self.channel_mixer = nn.Sequential(mixer)
        self.conv = BasicConv(
            d_model,
            d_model,
            kernel_size=3,
            degree=order,
            groups=d_model,
        )

    def forward(self,x):
        """Return ``channel_mixer(x) + conv(x)``."""
        return self.channel_mixer(x) + self.conv(x)


# x = self.conv(x.transpose(-1,-2)).transpose(-1,-2)
# if self.bn: x = self.bn(x)


class BasicConv(nn.Module):
    """1-D convolution over the time axis of a channel-last tensor.

    Input and output are ``[B, T, C]``; the tensor is transposed to
    ``[B, C, T]`` for the convolution (and batch norm) and transposed back.

    Args:
        c_in, c_out: input / output channel counts.
        kernel_size: convolution kernel size.
        degree: accepted for interface compatibility; not used here.
        stride, dilation, groups, bias: forwarded to ``nn.Conv1d``.
        padding: accepted for interface compatibility but ignored; the
            convolution always pads with ``kernel_size // 2``.
        act: append a GELU when true.
        bn: append a ``BatchNorm1d`` over the channel axis when true.
        dropout: dropout probability applied to the output.
    """

    def __init__(self,c_in,c_out, kernel_size, degree,stride=1, padding=0, dilation=1, groups=1, act=False, bn=False, bias=False,dropout=0.):
        super(BasicConv, self).__init__()
        self.out_channels = c_out
        # NOTE(review): the ``padding`` argument is deliberately not used so
        # that existing callers keep the kernel_size // 2 padding they get today.
        self.conv = nn.Conv1d(c_in,c_out, kernel_size=kernel_size, stride=stride, padding=kernel_size//2, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm1d(c_out) if bn else None
        self.act = nn.GELU() if act else None
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Convolve ``x`` (``[B, T, C_in]``) and return ``[B, T', C_out]``."""
        x = self.conv(x.transpose(-1,-2))  # [B, C_out, T']
        if self.bn is not None:
            # BatchNorm1d(c_out) normalises dim 1, so it must run while the
            # tensor is still channel-first.
            x = self.bn(x)
        x = x.transpose(-1,-2)  # back to [B, T', C_out]
        if self.act is not None:
            x = self.act(x)
        return self.dropout(x)
    

class TemporalSelfAttention(nn.Module):
    """Pre-norm self-attention block followed by a pre-norm feed-forward block.

    Both sub-blocks use a residual connection. The head count is read from
    ``configs.n_heads`` and falls back to 4 when that attribute is absent.
    """

    def __init__(self, configs):
        super().__init__()
        width = configs.d_model
        self.d_model = width
        self.n_heads = getattr(configs, 'n_heads', 4)
        self.dropout = nn.Dropout(configs.dropout)

        self.attn = nn.MultiheadAttention(
            embed_dim=width, num_heads=self.n_heads, batch_first=True
        )

        # Linear -> GELU -> Dropout -> Linear, all at model width.
        ffn_layers = [
            nn.Linear(width, width),
            nn.GELU(),
            nn.Dropout(configs.dropout),
            nn.Linear(width, width),
        ]
        self.ffn = nn.Sequential(*ffn_layers)
        self.norm1 = nn.LayerNorm(width)
        self.norm2 = nn.LayerNorm(width)

    def forward(self, x):
        """Transform ``x`` of shape ``[B, T, C]``; the shape is preserved."""
        # Attention sub-block: normalise, attend, dropout, add the skip path.
        normed = self.norm1(x)
        attended, _ = self.attn(normed, normed, normed)
        hidden = self.dropout(attended) + x

        # Feed-forward sub-block with its own skip path.
        return self.ffn(self.norm2(hidden)) + hidden



class Model(nn.Module):
    """Channel-independent forecaster built from spectral filtering and KAN branches.

    Every input variable is embedded on its own, split into low / band / high
    spectra by an ``Adaptive_Spectral_Filter_Module``, processed by a
    ``FrequencyKAN`` and a ``TemporalKAN``, and the sum of both branches is
    projected to ``pred_len`` steps.

    Only ``task_name == 'long_term_forecast'`` is supported.
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.configs = configs
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.label_len = configs.label_len
        self.pred_len = configs.pred_len
        self.channel_independence = configs.channel_independence
        self.asfm = nn.ModuleList([Adaptive_Spectral_Filter_Module(configs.d_model)
                                         for _ in range(configs.e_layers)])
        self.t_kans = nn.ModuleList([TemporalKAN(configs)
                                         for _ in range(configs.e_layers)])
        self.f_kans = nn.ModuleList([FrequencyKAN(configs)
                                         for _ in range(configs.e_layers)])

        self.enc_in = configs.enc_in

        # forecast() feeds one variable at a time ([B*N, T, 1]), hence c_in=1.
        # NOTE(review): argument order (c_in, d_model, embed_type, freq,
        # dropout) and the 'fixed' / 'h' fall-backs are assumed from the usual
        # DataEmbedding_wo_pos signature - confirm against layers/Embed.py.
        self.enc_embedding = DataEmbedding_wo_pos(
            1,
            configs.d_model,
            getattr(configs, 'embed', 'fixed'),
            getattr(configs, 'freq', 'h'),
            configs.dropout,
        )

        self.layer = configs.e_layers
        self.normalize_layers = Normalize(self.configs.enc_in, affine=True, non_norm=True if configs.use_norm == 0 else False)

        self.projection_layer = nn.Linear(
                    configs.d_model, 1, bias=True)
        # The time axis handed to this layer has length d_model because
        # TemporalKAN resamples it with irfft(n=d_model).
        self.predict_layer = nn.Linear(
                        configs.d_model,
                        configs.pred_len,
                    )

    def forecast(self, x):
        """Forecast ``pred_len`` steps.

        Args:
            x: tensor ``[B, T, N]`` (batch, input length, variables).

        Returns:
            Tensor ``[B, pred_len, c_out]``; the final reshape requires
            ``configs.c_out`` to equal ``N``.
        """
        B, T, N = x.size()
        x = self.normalize_layers(x, 'norm')
        # Channel independence: every variable becomes its own sample.
        x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)

        enc_out = self.enc_embedding(x, None)  # [B*N, T, d_model]

        # NOTE(review): every layer reads the same enc_out and overwrites
        # f_out / t_out, so only the last layer contributes to the result.
        # Chaining the layers was probably intended - confirm before changing.
        for i in range(self.layer):
            freq_enc = self.asfm[i](enc_out)
            f_out = self.f_kans[i](*freq_enc)
            t_out = self.t_kans[i](*freq_enc)

        # FrequencyKAN returns a complex spectrum; the linear heads below are
        # real-valued, so bring it to the time domain with the same length as
        # the temporal branch.
        # NOTE(review): assumed intent - confirm.
        if torch.is_complex(f_out):
            f_out = torch.fft.irfft(f_out, n=t_out.shape[1], dim=1, norm='ortho')

        # Fallback: resample t_out along time if the two branches still differ.
        if f_out.shape[1] != t_out.shape[1]:
            t_out = F.interpolate(t_out.permute(0, 2, 1), size=f_out.shape[1], mode='linear', align_corners=True).permute(0, 2, 1)

        dec_out = f_out + t_out
        dec_out = self.predict_layer(dec_out.permute(0, 2, 1)).permute(0, 2, 1)
        dec_out = self.projection_layer(dec_out).reshape(B, self.configs.c_out, self.pred_len).permute(0, 2, 1).contiguous()
        dec_out = self.normalize_layers(dec_out, 'denorm')
        return dec_out

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Run the forecast; only ``x_enc`` is used.

        Raises:
            ValueError: if ``task_name`` is not ``'long_term_forecast'``.
        """
        if self.task_name == 'long_term_forecast':
            dec_out = self.forecast(x_enc)
            return dec_out
        else:
            raise ValueError('Other tasks are not implemented yet')

def count_parameters(model):
    """Print the total and the trainable parameter counts of ``model``.

    Both numbers are printed with thousands separators; nothing is returned.
    """
    sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
    total_params = sum(count for count, _ in sizes)
    trainable_params = sum(count for count, trainable in sizes if trainable)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

        

# class Adaptive_Spectral_Filter_Module(nn.Module):
#     def __init__(self, d_model):
#         super().__init__()
#         self.d_model = d_model

#         # Learnable thresholds for soft masking
#         self.low_pass_cut = nn.Parameter(torch.rand(1) * 0.5)
#         self.high_pass_cut = nn.Parameter(torch.rand(1) * 0.5)

#         # Learnable frequency weights
#         self.freq_weights = nn.Parameter(torch.randn(3))

#     def forward(self, x):
#         B, T, C = x.shape
#         x = x.to(torch.float32)
#         x_fft = torch.fft.rfft(x, dim=1, norm='ortho')  # [B, F, C]

#         freq = torch.fft.rfftfreq(T, d=1.0 / T).to(x.device)  # [F]

#         # Generate soft masks
#         low_mask = torch.sigmoid(20 * (self.low_pass_cut - freq)).unsqueeze(0).unsqueeze(-1)  # [1, F, 1]
#         high_mask = torch.sigmoid(20 * (freq - self.high_pass_cut)).unsqueeze(0).unsqueeze(-1)
#         band_mask = 1.0 - low_mask - high_mask

#         x_low = x_fft * low_mask
#         x_band = x_fft * band_mask
#         x_high = x_fft * high_mask

#         return x_low, x_band, x_high
    
class Adaptive_Spectral_Filter_Module(nn.Module):
    """Split a sequence into low-, band- and high-pass spectra.

    ``forward`` takes the real FFT along dim 1 and returns three masked
    copies of the spectrum. The pass bands are defined on the frequency-bin
    index ``k = 0 .. F-1``:

    * low:  ``k <= low_pass_cut_freq_param``
    * high: ``k >= high_pass_cut_freq_param``
    * band: both conditions at once (a high-pass followed by a low-pass)

    NOTE(review): the masks are hard comparisons, so no gradient reaches the
    two cut-off parameters; they keep their random initial values unless they
    are set explicitly.
    """

    def __init__(self, dim):
        super().__init__()
        self.threshold_param = nn.Parameter(torch.rand(1) * 0.5)
        # Cut-off bin of the low-pass filter, drawn from [0, 65).
        self.low_pass_cut_freq_param = nn.Parameter(torch.rand(1)*65)
        # Cut-off bin of the high-pass filter, drawn from [0, 5).
        self.high_pass_cut_freq_param = nn.Parameter(torch.rand(1)*5)

        # Learnable frequency weights (registered but not used in this class)
        self.freq_weights = nn.Parameter(torch.randn(3))

    def create_adaptive_high_freq_mask(self, x_fft):
        """Return a 0/1 mask (same shape and dtype as ``x_fft``) of dominant bins.

        A bin is dominant when its energy, summed over the last axis and
        divided by the per-sample median energy, exceeds the
        ``threshold_param`` quantile of all normalised energies.
        """
        B, _, _ = x_fft.shape

        # Calculate energy in the frequency domain
        energy = torch.abs(x_fft).pow(2).sum(dim=-1)

        # Per-sample median energy, used as the normalisation scale
        flat_energy = energy.view(B, -1)
        median_energy = flat_energy.median(dim=1, keepdim=True)[0]
        median_energy = median_energy.view(B, 1)

        # Normalize energy
        normalized_energy = energy / (median_energy + 1e-6)

        # torch.quantile needs q in [0, 1] with the input's dtype. The result
        # only feeds a comparison, so detaching loses no gradient.
        q = self.threshold_param.detach().clamp(0.0, 1.0).to(normalized_energy.dtype)
        threshold = torch.quantile(normalized_energy, q)
        dominant_frequencies = normalized_energy > threshold

        # Initialize adaptive mask
        adaptive_mask = torch.zeros_like(x_fft, device=x_fft.device)
        adaptive_mask[dominant_frequencies] = 1

        return adaptive_mask

    def adaptive_freq_pass(self, x_fft, flag="high"):
        """Zero every frequency bin of ``x_fft`` outside the requested band.

        Args:
            x_fft: complex spectrum with the frequency axis on dim 1.
            flag: ``"low"``, ``"band"`` or ``"high"``.

        Raises:
            ValueError: for any other ``flag``.
        """
        n_bins = x_fft.shape[1]
        freq = torch.arange(n_bins, device=x_fft.device, dtype=x_fft.real.dtype)
        below_low_cut = freq <= self.low_pass_cut_freq_param.to(x_fft.device)
        above_high_cut = freq >= self.high_pass_cut_freq_param.to(x_fft.device)

        if flag == "low":
            keep = below_low_cut
        elif flag == "high":
            keep = above_high_cut
        elif flag == "band":
            keep = below_low_cut & above_high_cut
        else:
            raise ValueError(f"Unknown filter flag: {flag!r}")

        # Broadcast the per-bin mask over the batch and channel axes.
        return x_fft * keep.to(x_fft.real.dtype).view(1, n_bins, 1)

    def forward(self, x_in):
        """Return ``(x_low, x_band, x_high)`` spectra for ``x_in`` (``[B, T, C]``)."""
        x = x_in.to(torch.float32)

        # Apply FFT along the time dimension
        x_fft = torch.fft.rfft(x, dim=1, norm='ortho')

        x_low_pass = self.adaptive_freq_pass(x_fft, flag="low")    # low-pass
        x_band_pass = self.adaptive_freq_pass(x_fft, flag="band")  # band-pass
        x_high_pass = self.adaptive_freq_pass(x_fft, flag="high")  # high-pass

        return x_low_pass, x_band_pass, x_high_pass




class Adaptive_Spectral_Filter_Module_two(nn.Module):
    """Split a sequence into low / band / high spectra with smooth masks.

    The frequency axis is normalised to ``[0, 1]`` (0 = DC, 1 = highest rFFT
    bin). Two learnable cut-offs define sigmoid-shaped masks; the three masks
    sum to one at every frequency, so the three returned spectra add up to
    the full spectrum.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # Learnable soft thresholds (in normalized freq [0,1])
        self.low_cut = nn.Parameter(torch.tensor(0.3))   # lower boundary
        self.high_cut = nn.Parameter(torch.tensor(0.7))  # upper boundary

        # Steepness of the sigmoid transitions; larger means sharper masks
        self.sharpness = 40.0

        # Learnable frequency fusion weights (optional; not used in this class)
        self.freq_weights = nn.Parameter(torch.randn(3))

    def soft_mask(self, freq, low_cut, high_cut):
        """Build the three smooth masks for normalised frequencies ``freq``.

        Args:
            freq: 1-D tensor ``[F]`` of frequencies in ``[0, 1]``.
            low_cut, high_cut: scalar cut-offs in the same units.

        Returns:
            ``(low_mask, band_mask, high_mask)``, each shaped ``[1, F, 1]``.
        """
        low_mask = torch.sigmoid(self.sharpness * (low_cut - freq))
        high_mask = torch.sigmoid(self.sharpness * (freq - high_cut))
        # The band takes whatever the other two leave, which keeps the three
        # masks summing to exactly one.
        band_mask = 1.0 - low_mask - high_mask
        return (low_mask.view(1, -1, 1),
                band_mask.view(1, -1, 1),
                high_mask.view(1, -1, 1))

    def forward(self, x_in):
        """Return ``(x_low, x_band, x_high)`` spectra for ``x_in`` (``[B, N, C]``)."""
        x = x_in.to(torch.float32)

        # FFT: [B, N, C] -> [B, F, C]
        x_fft = torch.fft.rfft(x, dim=1, norm='ortho')
        F_size = x_fft.shape[1]

        # Normalised frequency axis so the [0, 1] cut-offs are meaningful.
        freq = torch.linspace(0.0, 1.0, F_size, device=x.device)  # [F]

        # Create smooth frequency masks
        low_mask, band_mask, high_mask = self.soft_mask(freq, self.low_cut, self.high_cut)  # each [1, F, 1]

        # Apply masks to FFT
        x_low = x_fft * low_mask
        x_band = x_fft * band_mask
        x_high = x_fft * high_mask

        return x_low, x_band, x_high
