import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import pywt

from models.modules.patching2D import PatchEmbed2D
from models.modules.attention import CustomAttentionBlock
from models.modules.feature_conv import EnhancedHeightWidthFeatureConv, EfficientDimensionReduction
from models.modules.rotary import RotaryPositionEmbedding

################################################################################
# 1) 可学习 WaveFilter
################################################################################

def _resize_filter(coeffs, kernel_size, use_abs):
    """
    Linearly resample a 1D filter to ``kernel_size`` taps and rescale it.

    After resampling, the filter is rescaled so that its tap sum
    (``use_abs=False``) or its absolute tap sum (``use_abs=True``) equals that
    of the original filter. ``coeffs`` is returned unchanged when it already
    has ``kernel_size`` taps.
    """
    if coeffs.numel() == kernel_size:
        return coeffs
    resized = F.interpolate(
        coeffs.view(1, 1, -1),  # [1, 1, original_size]
        size=kernel_size,
        mode='linear',
        align_corners=True
    ).reshape(kernel_size)  # reshape (not squeeze) keeps a 1D result even for kernel_size == 1
    if use_abs:
        return resized * (coeffs.abs().sum() / resized.abs().sum())
    return resized * (coeffs.sum() / resized.sum())


def load_wavelet_kernel(wave_name, kernel_size):
    """
    Build initial low-/high-pass conv kernels from a pywt wavelet.

    Takes ``dec_lo`` / ``dec_hi`` of ``pywt.Wavelet(wave_name)``. If their
    length differs from ``kernel_size``, they are linearly resampled
    (``align_corners=True``) to that length and then rescaled:
      * low-pass: the sum of the taps (the DC gain) is kept equal to the original;
      * high-pass: the taps of a wavelet high-pass filter typically sum to ~0,
        so the sum of absolute tap values is kept instead.
    Note that this preserves sums, not the L2 energy of the filters.

    Args:
        wave_name: any wavelet name accepted by ``pywt.Wavelet`` (e.g. "db6", "sym4").
        kernel_size: desired number of taps (>= 1).

    Returns:
        (dec_lo, dec_hi): two float32 tensors of shape [kernel_size].

    Raises:
        ValueError: if ``kernel_size < 1``.
    """
    if kernel_size < 1:
        raise ValueError(f"kernel_size must be >= 1, got {kernel_size}")

    wave = pywt.Wavelet(wave_name)
    dec_lo = torch.tensor(wave.dec_lo, dtype=torch.float)
    dec_hi = torch.tensor(wave.dec_hi, dtype=torch.float)

    return (
        _resize_filter(dec_lo, kernel_size, use_abs=False),
        _resize_filter(dec_hi, kernel_size, use_abs=True),
    )


class LearnableWaveFilter(nn.Module):
    """
    Learnable one-level wavelet-style filter bank.

    Two Conv1d layers (low-pass / high-pass) are initialised from the pywt
    wavelet ``wave_init``. Their weights are ordinary trainable parameters, so
    training is free to move away from the initial wavelet.
    """
    def __init__(self, in_ch=8, kernel_size=16,
                 wave_init='db6',
                 separate_per_channel=True):
        """
        Args:
          in_ch: number of input (and output) channels.
          kernel_size: 1D kernel size.
          wave_init: pywt wavelet name used for initialisation
                     (e.g. 'db6', 'db4', 'coif3', 'sym4').
          separate_per_channel: True  => depthwise conv, one kernel per channel;
                                False => full conv with cross-channel weights,
                                         initialised so that it initially behaves
                                         like the depthwise case.
        """
        super().__init__()
        self.in_ch = in_ch
        self.kernel_size = kernel_size
        self.wave_init = wave_init
        self.separate_per_channel = separate_per_channel

        # Wavelet coefficients from pywt, used only as initial values.
        low_init, high_init = load_wavelet_kernel(wave_init, kernel_size)

        # groups == in_ch => depthwise.
        groups = in_ch if separate_per_channel else 1

        self.low_filter = nn.Conv1d(
            in_channels=in_ch,
            out_channels=in_ch,
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,
            groups=groups,
            bias=False
        )
        self.high_filter = nn.Conv1d(
            in_channels=in_ch,
            out_channels=in_ch,
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,
            groups=groups,
            bias=False
        )

        # NOTE(review): Conv1d computes cross-correlation, so relative to pywt's
        # convolution-based DWT the initial taps act time-reversed. This only
        # affects the starting point, since the weights are learnable.
        with torch.no_grad():
            for conv, init in ((self.low_filter, low_init), (self.high_filter, high_init)):
                kernel = init.to(dtype=conv.weight.dtype).reshape(kernel_size)
                if separate_per_channel:
                    # Depthwise weight [in_ch, 1, k]: every channel starts from the same kernel.
                    conv.weight.copy_(kernel.expand(in_ch, 1, kernel_size))
                else:
                    # Full weight [in_ch, in_ch, k]: output channel c initially filters only
                    # input channel c (same response as the depthwise case). Cross-channel
                    # weights start at zero and can be learned. Previously only weight[0, 0]
                    # was initialised and all other entries kept their random init.
                    conv.weight.zero_()
                    channel_idx = torch.arange(in_ch)
                    conv.weight[channel_idx, channel_idx] = kernel

        # The conv weights keep requires_grad=True, i.e. they are learnable.

    def forward(self, x):
        """
        x: [B, in_ch, T]
        Returns (approx, detail), each [B, in_ch, ceil(L / 2)], where L is the
        stride-1 conv output length: L = T + 1 for even kernel_size and L = T
        for odd kernel_size (padding = kernel_size // 2). Every 2nd sample is kept.
        """
        approx = self.low_filter(x)
        detail = self.high_filter(x)
        # Downsample by a factor of 2.
        approx = approx[..., ::2]
        detail = detail[..., ::2]
        return approx, detail


class AdaptiveWaveletSelector(nn.Module):
    """
    Input-dependent soft mixture of several LearnableWaveFilter banks.

    A small MLP on the time-averaged input produces softmax weights over the
    wavelets. The approximation / detail outputs of all banks are summed with
    those per-sample weights.
    """
    def __init__(self,
                 in_ch=8,
                 wavelet_names=None,
                 kernel_size=16,
                 separate_per_channel=True,
                 hidden_dim=128):
        """
        Args:
          in_ch: number of input channels.
          wavelet_names: wavelets used to initialise the banks; defaults to
                         ['db4', 'db6', 'sym4', 'coif3'] when None.
          kernel_size: conv kernel size of every bank.
          separate_per_channel: passed to LearnableWaveFilter (depthwise if True).
          hidden_dim: hidden width of the selector MLP (default 128, the
                      previously hard-coded value).

        Raises:
          ValueError: if ``wavelet_names`` is empty.
        """
        super().__init__()
        if wavelet_names is None:
            wavelet_names = ['db4', 'db6', 'sym4', 'coif3']
        wavelet_names = list(wavelet_names)
        if not wavelet_names:
            # With no banks, forward() would silently return the integer 0.
            raise ValueError("wavelet_names must contain at least one wavelet name")

        self.in_ch = in_ch
        self.wavelet_names = wavelet_names
        self.num_wavelets = len(wavelet_names)
        self.kernel_size = kernel_size
        self.separate_per_channel = separate_per_channel

        # One learnable filter bank per wavelet, each initialised from its own wavelet.
        self.wavelet_filters = nn.ModuleList([
            LearnableWaveFilter(
                in_ch=in_ch,
                kernel_size=kernel_size,
                wave_init=wname,
                separate_per_channel=self.separate_per_channel
            )
            for wname in wavelet_names
        ])

        # Global (time-averaged) features of x -> softmax weight per wavelet.
        self.selector = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(in_ch, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, self.num_wavelets),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        """
        x: [B, in_ch, T]
        Returns (approx, detail), shaped like the output of LearnableWaveFilter.
        """
        # 1) Per-sample wavelet weights: [B, num_wavelets], each row sums to 1.
        weights = self.selector(x)

        # 2) Decompose with every bank and accumulate the weighted outputs.
        approx = 0
        detail = 0
        for i, filter_module in enumerate(self.wavelet_filters):
            approx_i, detail_i = filter_module(x)
            w_i = weights[:, i].view(-1, 1, 1)
            approx = approx + approx_i * w_i
            detail = detail + detail_i * w_i
        return approx, detail


################################################################################
# 2) FFN & CrossScaleCAFFN 
################################################################################

class ElementScale(nn.Module):
    """Learnable element-wise multiplicative scale, broadcast against the input."""

    def __init__(self, shape, init_value=1.0):
        super().__init__()
        initial = torch.ones(*shape)
        self.scale = nn.Parameter(initial * init_value)

    def forward(self, x):
        return self.scale * x

class ChannelAggregationFFN(nn.Module):
    """
    Channel-aggregation FFN (MogaNet-style).

    1x1 conv -> depthwise conv -> GELU -> dropout, then feature decomposition,
    then a 1x1 conv back to ``embed_dims`` and dropout.
    """
    def __init__(self, embed_dims, ffn_ratio=4., kernel_size=3, dropout=0.1):
        super().__init__()
        hidden_dims = int(embed_dims * ffn_ratio)
        self.fc1 = nn.Conv2d(embed_dims, hidden_dims, kernel_size=1)
        self.dwconv = nn.Conv2d(hidden_dims, hidden_dims,
                                kernel_size=kernel_size, padding=kernel_size // 2,
                                groups=hidden_dims)
        self.act = nn.GELU()
        self.fc2 = nn.Conv2d(hidden_dims, embed_dims, kernel_size=1)
        self.drop = nn.Dropout(dropout)

        # Feature decomposition: a 1-channel projection and a small learnable
        # per-channel scale for the decomposition branch.
        self.decompose = nn.Conv2d(hidden_dims, 1, kernel_size=1)
        self.decompose_act = nn.GELU()
        self.sigma = ElementScale([1, hidden_dims, 1, 1], init_value=1e-5)

    def feat_decompose(self, x):
        """
        Return ``x + sigma * (x - GELU(decompose(x)))``.

        ``decompose`` projects to a single channel that is broadcast over all
        channels. Because ``sigma`` starts at 1e-5, the step starts as a
        near-identity residual. The previous code computed
        ``(x - t) * (1 + sigma)``, which fully subtracted the decomposed
        component at init and left ``sigma`` with almost no effect.
        """
        t = self.decompose_act(self.decompose(x))  # [B, 1, H, W]
        return x + self.sigma(x - t)

    def forward(self, x):
        """x: [B, embed_dims, H, W] -> [B, embed_dims, H, W]."""
        out = self.fc1(x)
        out = self.dwconv(out)
        out = self.act(out)
        out = self.drop(out)

        out = self.feat_decompose(out)

        out = self.fc2(out)
        out = self.drop(out)
        return out


class CrossScaleCAFFN(nn.Module):
    """
    ChannelAggregationFFN followed by cross-attention to earlier-scale features.

    ``out = CA-FFN(x)``. When earlier-scale features are given, the flattened
    ``out`` attends to all of them (concatenated along the token axis), and the
    result is added back as ``out + attn_scale * attn_out``.
    """
    def __init__(self, embed_dims, ffn_ratio=4., kernel_size=3, dropout=0.1, num_heads=4):
        """
        Args:
            embed_dims: channel count of ``x`` and of every ``prev_feats`` entry;
                must be divisible by ``num_heads``.
            ffn_ratio, kernel_size, dropout: forwarded to ChannelAggregationFFN.
            num_heads: heads of the cross-attention (default 4, the previously
                hard-coded value).
        """
        super().__init__()
        self.base_ffn = ChannelAggregationFFN(embed_dims, ffn_ratio, kernel_size, dropout)
        self.cross_attn = nn.MultiheadAttention(embed_dims, num_heads=num_heads, batch_first=True)

        # Learnable residual scale of the cross-attention branch.
        self.attn_scale = nn.Parameter(torch.tensor(0.1))

    def forward(self, x, prev_feats=None):
        """
        Args:
            x: [B, embed_dims, H, W]
            prev_feats: optional sequence of [B, embed_dims, H_i, W_i] tensors
                (spatial sizes may differ). None or empty => plain CA-FFN.
        Returns:
            [B, embed_dims, H, W]
        """
        out = self.base_ffn(x)
        if prev_feats is None or len(prev_feats) == 0:
            return out

        B, C, H, W = out.shape
        query = out.permute(0, 2, 3, 1).reshape(B, H * W, C)  # [B, HW, C]
        context = torch.cat(
            [pf.permute(0, 2, 3, 1).reshape(pf.shape[0], -1, pf.shape[1]) for pf in prev_feats],
            dim=1,
        )  # [B, sum_i H_i * W_i, C]

        # The attention map is discarded, so don't ask MultiheadAttention to compute it.
        attn_out, _ = self.cross_attn(query, context, context, need_weights=False)
        attn_out = attn_out.reshape(B, H, W, C).permute(0, 3, 1, 2)
        return out + self.attn_scale * attn_out


################################################################################
# 3) MultiHeadGate + SoftGateWaveletDecomp
################################################################################

class MultiHeadGate(nn.Module):
    """
    Per-channel gate computed by self-attention among heads.

    The time axis is average-pooled. The pooled channel vector is projected to
    Q/K/V and split into ``num_heads`` chunks that act as tokens. Attention is
    computed among these chunks, and a sigmoid maps the result to per-channel
    gates in (0, 1).
    """
    def __init__(self, in_channels, num_heads=4):
        """
        Args:
            in_channels: channel count C of the input; must be divisible by num_heads.
            num_heads: number of chunks (heads) the channel vector is split into.

        Raises:
            ValueError: if ``num_heads < 1`` or ``in_channels`` is not divisible by it.
        """
        super().__init__()
        if num_heads < 1 or in_channels % num_heads != 0:
            # Fail at construction instead of with an opaque view() error in forward.
            raise ValueError(
                f"in_channels ({in_channels}) must be divisible by num_heads ({num_heads})"
            )
        self.in_channels = in_channels
        self.num_heads = num_heads
        self.head_dim = in_channels // num_heads
        self.q_proj = nn.Linear(in_channels, in_channels)
        self.k_proj = nn.Linear(in_channels, in_channels)
        self.v_proj = nn.Linear(in_channels, in_channels)

    def forward(self, x):
        """
        x: [B, C, T] with C == in_channels.
        Returns: gate of shape [B, C, 1], values in (0, 1).
        """
        B = x.shape[0]
        x_pool = x.mean(dim=-1)  # [B, C]

        Q = self.q_proj(x_pool).view(B, self.num_heads, -1)  # [B, num_heads, head_dim]
        K = self.k_proj(x_pool).view(B, self.num_heads, -1)
        V = self.v_proj(x_pool).view(B, self.num_heads, -1)

        attn = torch.matmul(Q, K.transpose(1, 2)) / math.sqrt(self.head_dim)
        attn = F.softmax(attn, dim=-1)  # [B, num_heads, num_heads]

        out_v = torch.matmul(attn, V)  # [B, num_heads, head_dim]
        gate = torch.sigmoid(out_v.reshape(B, -1))  # [B, C]
        return gate.unsqueeze(-1)  # [B, C, 1]


class SoftGateWaveletDecomp(nn.Module):
    """
    Multi-level soft-gated wavelet decomposition.

    At each level the current approximation is split by an
    AdaptiveWaveletSelector, and the outputs are upsampled back to length T.
    They are then blended with the previous approximation / detail through a
    MultiHeadGate. The stacked [approx, detail] pair is refined by a
    CrossScaleCAFFN that also attends to the subbands of all earlier levels.

    Output: [B, (max_level + 1) * C, T]. This is the detail band of every
    level, followed by the final approximation, concatenated on the channel axis.
    """
    def __init__(self,
                 in_channels=8,
                 max_level=3,
                 kernel_size=16,
                 wavelet_names=None,
                 use_separate_channel=True,
                 ffn_ratio=4.,
                 ffn_kernel_size=5,
                 ffn_drop=0.1):
        """
        Args:
            in_channels: channel count C of the input.
            max_level: number of decomposition levels.
            kernel_size: wavelet conv kernel size.
            wavelet_names: optional list of wavelet names for the selector.
            use_separate_channel: depthwise wavelet filtering if True.
            ffn_ratio, ffn_kernel_size, ffn_drop: CrossScaleCAFFN settings.
        """
        super().__init__()
        self.max_level = max_level
        self.in_channels = in_channels

        self.wave_filter = AdaptiveWaveletSelector(
            in_ch=in_channels,
            kernel_size=kernel_size,
            wavelet_names=wavelet_names,
            separate_per_channel=use_separate_channel
        )

        self.gate = MultiHeadGate(in_channels, num_heads=4)

        # One CrossScaleCAFFN per level, working on the stacked [approx, detail] pair.
        subband_dims = 2 * in_channels
        ffn_layers = []
        for _ in range(max_level):
            ffn_layers.append(
                CrossScaleCAFFN(
                    embed_dims=subband_dims,
                    ffn_ratio=ffn_ratio,
                    kernel_size=ffn_kernel_size,
                    dropout=ffn_drop
                )
            )
        self.sub_ffn = nn.ModuleList(ffn_layers)
        # Zero-initialised residual scales, so each level starts as an identity refinement.
        self.res_scale = nn.ParameterList(
            [nn.Parameter(torch.zeros(1)) for _ in range(max_level)]
        )

    def forward(self, x):
        """x: [B, C, T] -> [B, (max_level + 1) * C, T]."""
        _, channels, length = x.shape

        def stretch(t):
            # Nearest-neighbour resize [B, C, T'] -> [B, C, T] via a 2D resize with C unchanged.
            return F.interpolate(t.unsqueeze(1), size=(channels, length), mode='nearest').squeeze(1)

        approx = x
        detail = torch.zeros_like(x)
        bands = []
        history = []

        for level in range(self.max_level):
            # Wavelet split of the current approximation (half-rate outputs).
            low, high = self.wave_filter(approx)
            low_up = stretch(low)
            high_up = stretch(high)

            # Soft gate between the previous state and the upsampled new bands.
            g = self.gate(approx)  # [B, C, 1]
            mixed_approx = g * approx + (1 - g) * low_up
            mixed_detail = g * detail + (1 - g) * high_up

            # Cross-scale refinement on a [B, 2C, 1, T] "image".
            subband = torch.cat([mixed_approx, mixed_detail], dim=1).unsqueeze(2)
            refined = self.sub_ffn[level](subband, prev_feats=history)
            refined = (subband + self.res_scale[level] * refined).squeeze(2)  # [B, 2C, T]

            approx = refined[:, :channels]
            detail = refined[:, channels:]
            bands.append(detail)
            history.append(subband)

        bands.append(approx)
        return torch.cat(bands, dim=1)


################################################################################
# 4) Masking 函数
################################################################################

def frequency_guided_masking(x, mask_ratio, importance_ratio=0.5):
    """
    MAE-style masking whose keep order mixes uniform noise with a spectral importance score.

    Args:
        x: input tensor [B, L, D].
        mask_ratio: fraction of the L tokens to mask; ``int(L * (1 - mask_ratio))``
            tokens are kept.
        importance_ratio: weight (0-1) of the importance score versus uniform
            noise. 0 => pure random masking, 1 => fully importance-driven.

    Returns:
        x_masked: [B, len_keep, D] kept tokens, in keep order.
        mask: [B, L] in the original order, 0 = kept, 1 = masked.
        ids_restore: [B, L] indices that undo the shuffle.
    """
    B, L, D = x.shape
    len_keep = int(L * (1 - mask_ratio))

    # 1. Spectral importance score. It only feeds argsort, and no gradient flows
    #    through the indices, so skip building an autograd graph for it.
    with torch.no_grad():
        x_fft = torch.abs(torch.fft.rfft(x.permute(0, 2, 1), dim=2))  # [B, D, L//2+1]
        importance = torch.sum(x_fft, dim=1)  # [B, L//2+1], one score per frequency bin
        # NOTE(review): the per-frequency-bin curve is linearly stretched over the
        # L token positions. Position i therefore receives the score of a
        # frequency bin, not a property of token i itself. Confirm this is intended.
        importance = F.interpolate(
            importance.unsqueeze(1),
            size=L,
            mode='linear',
            align_corners=False
        ).squeeze(1)  # [B, L]
        # Min-max normalise per sample to [0, 1], the same scale as the noise.
        # Raw FFT magnitudes otherwise swamp the noise and make importance_ratio
        # ineffective. An all-constant score becomes 0, giving pure random masking.
        importance = importance - importance.amin(dim=1, keepdim=True)
        importance = importance / importance.amax(dim=1, keepdim=True).clamp_min(1e-12)

    # 2. Mix randomness and importance. Lower score = kept first, so important tokens are favoured.
    random_noise = torch.rand(B, L, device=x.device)
    combined_scores = (1 - importance_ratio) * random_noise - importance_ratio * importance

    # 3. Shuffle by score and keep the first len_keep tokens.
    ids_shuffle = torch.argsort(combined_scores, dim=1)
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    ids_keep = ids_shuffle[:, :len_keep]
    x_masked = torch.gather(
        x, dim=1,
        index=ids_keep.unsqueeze(-1).expand(-1, -1, D)
    )

    mask = torch.ones([B, L], device=x.device)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, dim=1, index=ids_restore)

    return x_masked, mask, ids_restore


def random_masking(x, mask_ratio):
    """
    MAE-style uniform random masking along the token axis.

    Args:
        x: [B, L, D] tokens.
        mask_ratio: fraction of tokens to drop; ``int(L * (1 - mask_ratio))`` are kept.

    Returns:
        x_masked: [B, len_keep, D] kept tokens, in shuffled (keep) order.
        mask: [B, L] in the original order, 0 = kept, 1 = masked.
        ids_restore: [B, L] indices that undo the shuffle.
    """
    batch, length, dim = x.shape
    num_keep = int(length * (1 - mask_ratio))

    # Random per-sample permutation: argsort of uniform noise.
    ids_shuffle = torch.rand(batch, length, device=x.device).argsort(dim=1)
    ids_restore = ids_shuffle.argsort(dim=1)

    keep_index = ids_shuffle[:, :num_keep].unsqueeze(-1).expand(-1, -1, dim)
    x_masked = x.gather(1, keep_index)

    # Build the mask in shuffled order (keep prefix = 0), then un-shuffle it.
    shuffled_mask = torch.ones([batch, length], device=x.device)
    shuffled_mask[:, :num_keep] = 0
    mask = shuffled_mask.gather(1, ids_restore)
    return x_masked, mask, ids_restore


################################################################################
# 5) ChannelVisionTransformer 主体
################################################################################
class ChannelVisionTransformer(nn.Module):
    """
    Channel vision transformer for multi-channel 1D signals.

    Pipeline: SoftGateWaveletDecomp -> EnhancedHeightWidthFeatureConv ->
    EfficientDimensionReduction -> PatchEmbed2D -> rotary position embedding ->
    CustomAttentionBlock stack -> norm.

    When ``use_masking`` is True, patch tokens are masked with
    ``frequency_guided_masking`` (controlled by ``masking_ratio`` /
    ``importance_ratio``) and replaced by a learned [MASK] token. Otherwise a
    [CLS] token is prepended.
    """
    def __init__(
        self,
        # 1) Wavelet decomposition
        in_ch: int = 16,
        max_level: int = 3,
        wave_kernel_size: int = 16,
        wavelet_names=None,          # optional custom wavelet list
        use_separate_channel=True,   # depthwise flag of AdaptiveWaveletSelector

        # 2) FFN & CrossScaleCAFFN
        ffn_ratio: float = 4.0,
        ffn_kernel_size: int = 5,
        ffn_drop: float = 0.1,

        # 3) Feature extraction & channel reduction
        hw_square_kernel: int = 3,
        hw_band_kernel: int = 11,
        reduced_dim: int = 32,

        # 4) Patch extraction
        timesteps: int = 1000,
        patch_size: tuple = (1, 50),

        # 5) Transformer & masking
        embed_dim: int = 512,
        depth: int = 12,
        num_heads: int = 8,
        mlp_ratio: float = 4.,
        norm_layer=nn.LayerNorm,
        drop_path=0.1,
        attention_type: str = 'wavelet_enhanced',
        masking_ratio: float = 0.75,
        importance_ratio: float = 0.6,  # used by frequency_guided_masking

        # Masking is configured at construction time, not per forward() call.
        use_masking: bool = True,

        # 6) RoPE
        max_seq_len: int = 2048,  # maximum sequence length for RoPE
    ):
        super().__init__()

        self.use_masking = use_masking

        self.in_ch = in_ch
        self.max_level = max_level
        self.masking_ratio = masking_ratio
        self.importance_ratio = importance_ratio
        self.embed_dim = embed_dim
        self.timesteps = timesteps
        self.patch_size = patch_size

        # 1) Wavelet decomposition
        self.wavelet_decomp = SoftGateWaveletDecomp(
            in_channels=in_ch,
            max_level=max_level,
            kernel_size=wave_kernel_size,
            wavelet_names=wavelet_names,
            use_separate_channel=use_separate_channel,
            ffn_ratio=ffn_ratio,
            ffn_kernel_size=ffn_kernel_size,
            ffn_drop=ffn_drop,
        )

        # Channel count of the decomposition output.
        self.wave_decomp_ch = (max_level + 1) * in_ch

        # 2) Feature extraction and channel reduction
        self.feature_extraction = EnhancedHeightWidthFeatureConv(
            square_kernel=hw_square_kernel,
            band_kernel=hw_band_kernel
        )
        self.dimension_reduction = EfficientDimensionReduction(
            in_channels=self.wave_decomp_ch,
            out_channels=reduced_dim
        )

        # 3) Patch embedding
        self.patch_embed = PatchEmbed2D(
            input_channels=1,
            in_height=reduced_dim,
            in_width=self.timesteps,
            kernel_size=patch_size,
            stride=patch_size,
            embed_dim=embed_dim,
            flatten=True
        )
        self.num_patches = self.patch_embed.num_patches

        # 4) [CLS] / [MASK] tokens
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # 5) Rotary position embedding
        self.rotary_pos_embed = RotaryPositionEmbedding(
            dim=embed_dim,
            max_seq_len=max_seq_len
        )

        # 6) Transformer blocks
        self.blocks = nn.ModuleList([
            CustomAttentionBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=True,
                norm_layer=norm_layer,
                attention_type=attention_type,
                block_idx=i,
                num_channels=reduced_dim,  # used by the block's internal 2D processing
                drop_path=drop_path
            )
            for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

        self.initialize_weights()

    def initialize_weights(self):
        """Initialise patch embedding, Linear/LayerNorm layers and the special tokens."""
        if hasattr(self.patch_embed, 'init_patch_embed'):
            self.patch_embed.init_patch_embed()

        # Xavier for every nn.Linear and unit/zero LayerNorm in the whole model.
        self.apply(self._init_weights)

        nn.init.normal_(self.mask_token, std=.02)
        nn.init.normal_(self.cls_token, std=.02)

    def _init_weights(self, m):
        """Initialise a single module (used with ``self.apply``)."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # LayerNorm(elementwise_affine=False) has no weight/bias.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """
        Args:
            x: [B, in_ch, T] input signal.
        Returns:
            x: Transformer output, [B, N, D] when ``self.use_masking`` is True,
               and [B, N + 1, D] (leading [CLS] token) when it is False.
            mask: [B, N], 0 = kept and 1 = masked; None when masking is disabled.
            ids_restore: [B, N] un-shuffle indices; None when masking is disabled.
        """
        # 1)-3) Wavelet decomposition, feature extraction and channel reduction.
        reduced_features = self.extract_features(x)

        # 4) Patch embedding -> [B, N, D]
        x = self.patch_embed(reduced_features)

        mask, ids_restore = None, None
        if self.use_masking:
            # 5) Frequency-guided masking.
            x_masked, mask, ids_restore = frequency_guided_masking(
                x,
                mask_ratio=self.masking_ratio,
                importance_ratio=self.importance_ratio
            )
            num_masked = ids_restore.shape[1] - x_masked.shape[1]
            mask_tokens = self.mask_token.repeat(x_masked.size(0), num_masked, 1)

            # Always un-shuffle, even when nothing is masked. Previously the
            # tokens stayed in shuffled order when num_masked == 0.
            x = torch.cat([x_masked, mask_tokens], dim=1)
            x = torch.gather(
                x, dim=1,
                index=ids_restore.unsqueeze(-1).repeat(1, 1, self.embed_dim)
            )
            # Apply RoPE after restoring patch order, so every token (including
            # [MASK]) is encoded with its own patch index rather than its rank
            # among the kept tokens.
            # NOTE(review): here patch i gets sequence position i, while in the
            # [CLS] path below patch i is at position i + 1.
            x = self.rotary_pos_embed(x)

            # 6) Transformer blocks
            for blk in self.blocks:
                x = blk(x)
            x = self.norm(x)
        else:
            # 5) Prepend [CLS] and apply RoPE to the full sequence.
            cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
            x = self.rotary_pos_embed(torch.cat((cls_tokens, x), dim=1))

            # 6) Transformer blocks on patch tokens only.
            # NOTE(review): [CLS] bypasses the blocks and the final norm, so it
            # never attends to the patches. Confirm this is intended (e.g. the
            # blocks require exactly N tokens).
            x_cls, x_patches = x[:, :1], x[:, 1:]
            for blk in self.blocks:
                x_patches = blk(x_patches)
            x_patches = self.norm(x_patches)
            x = torch.cat([x_cls, x_patches], dim=1)

        return x, mask, ids_restore

    def extract_features(self, x):
        """
        Run the pipeline up to channel reduction (no masking, no Transformer).

        x: [B, in_ch, T]. The result is expected to be [B, 1, reduced_dim, T]
        (the shape PatchEmbed2D is configured for). The exact shape depends on
        EnhancedHeightWidthFeatureConv / EfficientDimensionReduction.
        """
        wave_spec = self.wavelet_decomp(x)  # [B, wave_decomp_ch, T]
        wave_2d = wave_spec.unsqueeze(1)    # [B, 1, wave_decomp_ch, T]
        enhanced_features = self.feature_extraction(wave_2d)
        reduced_features = self.dimension_reduction(enhanced_features)
        return reduced_features
