import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import pywt

################################################################################
# 1) Learnable wavelet filters (LearnableWaveFilter, AdaptiveWaveletSelector)
################################################################################

def _resample_filter(coeffs, kernel_size, match_abs_sum):
    """
    Linearly resample a 1-D filter to ``kernel_size`` taps and rescale it.

    Args:
      coeffs: 1-D float tensor with the original filter taps
      kernel_size: number of taps wanted
      match_abs_sum: True => rescale so the sum of absolute values equals the
        original one; False => rescale so the plain (signed) sum is preserved

    Returns:
      1-D tensor of shape [kernel_size]; ``coeffs`` itself when it already has
      the requested length.
    """
    if coeffs.numel() == kernel_size:
        return coeffs

    resized = F.interpolate(
        coeffs.view(1, 1, -1),  # F.interpolate expects [N, C, L]
        size=kernel_size,
        mode='linear',
        align_corners=True,
    ).reshape(-1)  # stays 1-D even for kernel_size == 1 (squeeze() would give 0-D)

    if match_abs_sum:
        target, current = coeffs.abs().sum(), resized.abs().sum()
    else:
        target, current = coeffs.sum(), resized.sum()
    return resized * (target / current)


def load_wavelet_kernel(wave_name, kernel_size):
    """
    Build initial low-pass / high-pass conv kernels from a pywt wavelet.

    The decomposition filters ``dec_lo`` / ``dec_hi`` of ``wave_name`` are
    converted to float tensors. When their length differs from ``kernel_size``
    they are linearly interpolated to that length and rescaled:
      - low-pass: the sum of the coefficients is kept equal to the original
      - high-pass: the sum of absolute coefficients is kept equal to the original

    Args:
      wave_name: wavelet name understood by ``pywt.Wavelet`` (e.g. 'db6', 'sym4')
      kernel_size: number of taps of the returned kernels

    Returns:
      (dec_lo_t, dec_hi_t): two 1-D float tensors of length ``kernel_size``
    """
    wave = pywt.Wavelet(wave_name)

    dec_lo_t = torch.tensor(wave.dec_lo, dtype=torch.float)
    dec_hi_t = torch.tensor(wave.dec_hi, dtype=torch.float)

    dec_lo_t = _resample_filter(dec_lo_t, kernel_size, match_abs_sum=False)
    dec_hi_t = _resample_filter(dec_hi_t, kernel_size, match_abs_sum=True)
    return dec_lo_t, dec_hi_t

class LearnableWaveFilter(nn.Module):
    """
    Learnable low-pass / high-pass filter pair seeded from a pywt wavelet.

    ``wave_init`` is only used to initialise the convolution weights; both
    convolutions remain trainable afterwards.
    """
    def __init__(self, in_ch=8, kernel_size=16,
                 wave_init='db6',
                 separate_per_channel=True):
        """
        Args:
          in_ch: number of input (and output) channels
          kernel_size: 1D kernel size
          wave_init: wavelet name available in pywt ('db6', 'db4', 'coif3', 'sym4', ...)
          separate_per_channel: True => depthwise conv (groups=in_ch), one kernel
            per channel; False => dense conv (groups=1)
        """
        super().__init__()
        self.in_ch = in_ch
        self.kernel_size = kernel_size
        self.wave_init = wave_init
        self.separate_per_channel = separate_per_channel

        # Wavelet decomposition filters, already resized to kernel_size taps.
        low_init, high_init = load_wavelet_kernel(wave_init, kernel_size)

        conv_kwargs = dict(
            in_channels=in_ch,
            out_channels=in_ch,
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,
            groups=in_ch if separate_per_channel else 1,
            bias=False,
        )
        self.low_filter = nn.Conv1d(**conv_kwargs)
        self.high_filter = nn.Conv1d(**conv_kwargs)

        self._init_from_wavelet(low_init, high_init)

    def _init_from_wavelet(self, low_init, high_init):
        """
        Copy the wavelet taps into both convolutions (weights stay learnable).

        Depthwise case: every channel's single kernel is set to the taps.
        Dense case: the diagonal slices ``weight[c, c, :]`` are set to the taps
        and all cross-channel slices are zeroed, so the layer starts out as the
        same per-channel wavelet filter. Previously only ``weight[0, 0, :]`` was
        written and every other slice kept its random default initialisation.
        """
        with torch.no_grad():
            for conv, taps in ((self.low_filter, low_init),
                               (self.high_filter, high_init)):
                weight = conv.weight  # [in_ch, in_ch // groups, kernel_size]
                taps = taps.to(dtype=weight.dtype, device=weight.device)
                if self.separate_per_channel:
                    weight.copy_(taps.expand_as(weight))
                else:
                    weight.zero_()
                    diag = torch.arange(self.in_ch, device=weight.device)
                    weight[diag, diag, :] = taps

    def forward(self, x):
        """
        x: [B, in_ch, T]
        return approx, detail => each [B, in_ch, ceil(L / 2)], where
          L = T + 2 * (kernel_size // 2) - kernel_size + 1 is the conv output
          length (T + 1 for an even kernel_size, T for an odd one).
        """
        # Filter, then keep every second sample (downsample by 2).
        approx = self.low_filter(x)[..., ::2]
        detail = self.high_filter(x)[..., ::2]
        return approx, detail

class AdaptiveWaveletSelector(nn.Module):
    """
    Soft mixture of several LearnableWaveFilter banks.

    A small selector network maps the time-averaged input to one softmax
    weight per wavelet; the approx / detail outputs of all banks are blended
    with those per-sample weights.
    """
    def __init__(self,
                 in_ch=8,
                 wavelet_names=None,
                 kernel_size=16,
                 separate_per_channel=True):
        """
        Args:
          in_ch: number of input channels
          wavelet_names: list of wavelet names, e.g. ['db4','db6','sym4','coif3']
            (this list is also the default when None is given)
          kernel_size: conv kernel size of every filter bank
          separate_per_channel: forwarded to each LearnableWaveFilter (depthwise)
        """
        super().__init__()
        names = ['db4', 'db6', 'sym4', 'coif3'] if wavelet_names is None else wavelet_names

        self.in_ch = in_ch
        self.num_wavelets = len(names)
        self.kernel_size = kernel_size
        self.separate_per_channel = separate_per_channel

        # One learnable filter bank per wavelet, each seeded from its own wavelet.
        banks = []
        for name in names:
            banks.append(
                LearnableWaveFilter(
                    in_ch=in_ch,
                    kernel_size=kernel_size,
                    wave_init=name,
                    separate_per_channel=self.separate_per_channel,
                )
            )
        self.wavelet_filters = nn.ModuleList(banks)

        # Global (time-averaged) channel features -> softmax weight per wavelet.
        self.selector = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(in_ch, 128),
            nn.ReLU(),
            nn.Linear(128, self.num_wavelets),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """
        x: [B, in_ch, T]
        return approx, detail => weighted sums of the per-bank outputs
        """
        batch, _, _ = x.shape  # also enforces a 3-D input

        # 1) per-sample mixing weights => [B, num_wavelets]
        mix = self.selector(x)

        # 2) decompose with every bank and accumulate the weighted outputs
        approx = 0
        detail = 0
        for idx, bank in enumerate(self.wavelet_filters):
            bank_approx, bank_detail = bank(x)
            coef = mix[:, idx, None, None]  # [B, 1, 1], broadcasts over C and T
            approx = approx + bank_approx * coef
            detail = detail + bank_detail * coef
        return approx, detail


################################################################################
# 3) FFN & CrossScaleCAFFN 
################################################################################

class ElementScale(nn.Module):
    """
    Learnable scale that is multiplied with the input (with broadcasting).

    Args:
      shape: shape of the scale parameter
      init_value: value every entry of the scale starts from
    """
    def __init__(self, shape, init_value=1.0):
        super().__init__()
        initial = torch.full(tuple(shape), init_value,
                             dtype=torch.get_default_dtype())
        self.scale = nn.Parameter(initial)

    def forward(self, x):
        scaled = x * self.scale
        return scaled

class ChannelAggregationFFN(nn.Module):
    """
    Conv-based FFN: 1x1 expand -> depthwise conv -> GELU -> channel
    decomposition -> 1x1 project, with dropout after the activation and after
    the projection.

    Args:
      embed_dims: number of input / output channels
      ffn_ratio: hidden width is int(embed_dims * ffn_ratio)
      kernel_size: kernel size of the depthwise conv (padding = kernel_size // 2)
      dropout: dropout probability
    """
    def __init__(self, embed_dims, ffn_ratio=4., kernel_size=3, dropout=0.1):
        super().__init__()
        width = int(embed_dims * ffn_ratio)

        # Main path.
        self.fc1 = nn.Conv2d(embed_dims, width, kernel_size=1)
        self.dwconv = nn.Conv2d(
            width, width,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=width,  # depthwise
        )
        self.act = nn.GELU()
        self.fc2 = nn.Conv2d(width, embed_dims, kernel_size=1)
        self.drop = nn.Dropout(dropout)

        # Decomposition branch: collapse the hidden channels into one map.
        self.decompose = nn.Conv2d(width, 1, kernel_size=1)
        self.decompose_act = nn.GELU()
        self.sigma = ElementScale([1, width, 1, 1], init_value=1e-5)

    def forward(self, x):
        """
        x: [B, embed_dims, H, W]
        return: [B, embed_dims, H, W] for odd kernel_size
        """
        hidden = self.drop(self.act(self.dwconv(self.fc1(x))))

        # Single-channel aggregated map, broadcast over all hidden channels.
        pooled = self.decompose_act(self.decompose(hidden))
        # NOTE(review): the aggregated map is removed from the main path itself,
        # so the result is (hidden - pooled) * (1 + sigma), not
        # hidden + sigma * (hidden - pooled) — confirm this is intended.
        residual = hidden - pooled
        hidden = self.sigma(residual) + residual

        return self.drop(self.fc2(hidden))

class CrossScaleCAFFN(nn.Module):
    """
    ChannelAggregationFFN followed by optional cross-attention over features
    from previous scales: ``out + attn_scale * attn_out``.

    Args:
      embed_dims: channel count of the input and of every previous feature;
        must be divisible by the 4 attention heads
      ffn_ratio, kernel_size, dropout: forwarded to ChannelAggregationFFN
    """
    def __init__(self, embed_dims, ffn_ratio=4., kernel_size=3, dropout=0.1):
        super().__init__()
        self.base_ffn = ChannelAggregationFFN(embed_dims, ffn_ratio, kernel_size, dropout)
        self.cross_attn = nn.MultiheadAttention(embed_dims, num_heads=4, batch_first=True)

        # Learnable weight of the attention branch (starts at 0.1).
        self.attn_scale = nn.Parameter(torch.tensor(0.1), requires_grad=True)

    @staticmethod
    def _to_tokens(feat):
        """Flatten a [B, C, H, W] map into a [B, H*W, C] token sequence."""
        batch, channels = feat.shape[:2]
        return feat.permute(0, 2, 3, 1).reshape(batch, -1, channels)

    def forward(self, x, prev_feats=None):
        """
        Args:
          x: [B, embed_dims, H, W]
          prev_feats: optional sequence of [B, embed_dims, Hp, Wp] tensors used
            as keys/values. None or an empty sequence skips the attention
            branch. The argument is never modified.

        Returns:
          Tensor with the same shape as the FFN output.
        """
        out = self.base_ffn(x)
        if not prev_feats:
            return out

        B, C, H, W = out.shape
        q = self._to_tokens(out)  # [B, H*W, C]
        # All previous scales are concatenated along the token axis.
        kv = torch.cat([self._to_tokens(pf) for pf in prev_feats], dim=1)

        # The attention map itself is unused; need_weights=False avoids
        # computing and averaging it.
        attn_out, _ = self.cross_attn(q, kv, kv, need_weights=False)
        attn_out = attn_out.reshape(B, H, W, C).permute(0, 3, 1, 2)
        return out + self.attn_scale * attn_out

################################################################################
# 4) MultiHeadGate + SoftGateWaveletDecomp
################################################################################

class MultiHeadGate(nn.Module):
    """
    Channel gate computed with attention between heads.

    The input is averaged over time, projected to Q/K/V, split into
    ``num_heads`` chunks of ``in_channels // num_heads`` features, and the
    chunks attend to each other. A sigmoid turns the result into a gate.

    Args:
      in_channels: number of channels C of the input; must be divisible by
        ``num_heads``
      num_heads: number of chunks the channel vector is split into

    Raises:
      ValueError: if ``num_heads`` is not positive or does not divide
        ``in_channels`` (previously this only surfaced as a ``view`` error
        during forward).
    """
    def __init__(self, in_channels, num_heads=4):
        super().__init__()
        if num_heads <= 0 or in_channels % num_heads != 0:
            raise ValueError(
                f"in_channels ({in_channels}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.q_proj = nn.Linear(in_channels, in_channels)
        self.k_proj = nn.Linear(in_channels, in_channels)
        self.v_proj = nn.Linear(in_channels, in_channels)

    def forward(self, x):
        """
        x: [B, C, T]
        return: gate in (0, 1) => [B, C, 1]
        """
        B, C, _ = x.shape
        c_head = C // self.num_heads
        x_pool = x.mean(dim=-1)  # => [B, C]

        Q = self.q_proj(x_pool).view(B, self.num_heads, c_head)
        K = self.k_proj(x_pool).view(B, self.num_heads, c_head)
        V = self.v_proj(x_pool).view(B, self.num_heads, c_head)

        # Scaled dot-product attention across heads => [B, num_heads, num_heads]
        attn = torch.matmul(Q, K.transpose(1, 2)) / math.sqrt(c_head)
        attn = F.softmax(attn, dim=-1)

        out_v = torch.matmul(attn, V)  # => [B, num_heads, c_head]
        gate = torch.sigmoid(out_v.reshape(B, C))
        return gate.unsqueeze(-1)  # => [B, C, 1]


class SoftGateWaveletDecomp(nn.Module):
    """
    Multi-level wavelet decomposition (AdaptiveWaveletSelector) + attention
    gate (MultiHeadGate) + cross-scale FFN (CrossScaleCAFFN).
    => output: [B, (max_level+1)*C, T]
    """
    def __init__(self,
                 in_channels=8,
                 max_level=3,
                 kernel_size=16,
                 ffn_ratio=4.,
                 ffn_kernel_size=3,
                 ffn_drop=0.1,
                 use_separate_channel=True,
                 wavelet_names=None):
        """
        Args:
          in_channels: number of input channels C
          max_level: number of decomposition levels
          kernel_size: kernel size of the wavelet filters
          ffn_ratio, ffn_kernel_size, ffn_drop: forwarded to CrossScaleCAFFN
          use_separate_channel: passed to AdaptiveWaveletSelector as
            separate_per_channel (True => one kernel per channel)
          wavelet_names: wavelets mixed by the selector; None keeps the
            previous fixed set ['db4', 'db6', 'sym4', 'coif3']
        """
        super().__init__()
        self.max_level = max_level
        self.in_channels = in_channels

        if wavelet_names is None:
            wavelet_names = ['db4', 'db6', 'sym4', 'coif3']

        # One selector and one gate are shared by all levels.
        self.wave_filter = AdaptiveWaveletSelector(
            in_ch=in_channels,
            kernel_size=kernel_size,
            wavelet_names=wavelet_names,
            separate_per_channel=use_separate_channel
        )

        self.gate = MultiHeadGate(in_channels, num_heads=4)

        # One FFN per level; it sees approx and detail stacked => 2*C channels.
        self.sub_ffn = nn.ModuleList([
            CrossScaleCAFFN(
                embed_dims=2*in_channels,
                ffn_ratio=ffn_ratio,
                kernel_size=ffn_kernel_size,
                dropout=ffn_drop
            )
            for _ in range(max_level)
        ])
        # Residual weights start at zero, so each level initially passes its
        # gated sub-band through unchanged.
        self.res_scale = nn.ParameterList([
            nn.Parameter(torch.zeros(1)) for _ in range(max_level)
        ])

    def forward(self, x):
        """
        x: [B, C, T]
        return: [B, (max_level+1)*C, T] — the detail band of every level in
          order, followed by the final approximation band.
        """
        _, C, T = x.shape
        approx = x
        detail_accum = torch.zeros_like(x)
        freq_bands = []
        prev_feats = []

        for lvl in range(self.max_level):
            # 1) wavelet decomposition => downsampled approx / detail
            approx_new, detail_new = self.wave_filter(approx)

            # 2) nearest-neighbour upsampling back to length T, then gated
            #    blend with the current approx / accumulated detail
            up_approx = F.interpolate(approx_new, size=T, mode='nearest')
            up_detail = F.interpolate(detail_new, size=T, mode='nearest')

            gate_score = self.gate(approx)  # => [B, C, 1]
            mixed_approx = gate_score * approx + (1 - gate_score) * up_approx
            mixed_detail = gate_score * detail_accum + (1 - gate_score) * up_detail

            # 3) cross-scale FFN on the stacked sub-bands => [B, 2*C, 1, T]
            sb_2d = torch.cat([mixed_approx, mixed_detail], dim=1).unsqueeze(2)
            out_2d = self.sub_ffn[lvl](sb_2d, prev_feats=prev_feats)
            out_2d = (sb_2d + self.res_scale[lvl] * out_2d).squeeze(2)  # [B, 2*C, T]

            # Split back into approx / detail for the next level.
            approx = out_2d[:, :C]
            detail_accum = out_2d[:, C:]

            freq_bands.append(detail_accum)
            # Later levels attend to the (pre-FFN) sub-bands of earlier levels.
            prev_feats.append(sb_2d)

        freq_bands.append(approx)
        return torch.cat(freq_bands, dim=1)
