import math
import torch
import torch.nn as nn
import numpy as np
from functools import partial
from vendor.timm.models.vision_transformer import Block



#####################################
# 1) 1D sin-cos positional embedding (unchanged)
#####################################
def get_1d_sincos_pos_embed(embed_dim, length, cls_token=False):
    """Build a fixed 1D sine/cosine positional-embedding table.

    Even columns hold ``sin(pos * w_k)`` and odd columns hold ``cos(pos * w_k)``
    with ``w_k = 10000 ** (-2k / embed_dim)``.

    Args:
        embed_dim: Width of each embedding vector. Odd widths are supported;
            the last column is then a sine term without a cosine partner.
        length: Number of positions (rows of the table).
        cls_token: If True, prepend one all-zero row for a class token.

    Returns:
        A float64 ``numpy.ndarray`` of shape ``(length, embed_dim)``, or
        ``(length + 1, embed_dim)`` when ``cls_token`` is True.
    """
    position = np.arange(length, dtype=float)
    div_term = np.exp(np.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
    angles = position[:, None] * div_term[None, :]

    pos_embed = np.zeros((length, embed_dim), dtype=float)
    pos_embed[:, 0::2] = np.sin(angles)
    # An odd embed_dim has one fewer odd column than even columns, so drop the
    # surplus frequency; for even widths the slice is a no-op.
    pos_embed[:, 1::2] = np.cos(angles[:, : embed_dim // 2])

    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


#####################################
# 2) PatchEmbed1D (unchanged)
#####################################
class PatchEmbed1D(nn.Module):
    """Turn a 1D signal into a sequence of patch embeddings.

    A ``Conv1d`` whose kernel size and stride both equal ``patch_size``
    projects every non-overlapping patch to an ``embed_dim``-wide vector.
    """

    def __init__(self, seq_len=256, patch_size=2, in_chans=1, embed_dim=384):
        """Store the geometry and create the patch projection.

        Args:
            seq_len: Signal length; must be a multiple of ``patch_size``.
            patch_size: Number of samples per patch.
            in_chans: Number of input channels.
            embed_dim: Output width of each patch embedding.
        """
        super().__init__()
        assert seq_len % patch_size == 0, "seq_len 必须能整除 patch_size"
        self.seq_len, self.patch_size = seq_len, patch_size
        self.in_chans, self.embed_dim = in_chans, embed_dim
        self.num_patches = seq_len // patch_size
        self.proj = nn.Conv1d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Map ``[B, in_chans, length]`` to ``[B, n_patches, embed_dim]``."""
        # Conv output is [B, embed_dim, n_patches]; swap to token-major layout.
        return self.proj(x).transpose(1, 2)


#####################################
# 3) patchify / unpatchify for 1D (unchanged)
#####################################
def patchify_1d(x, patch_size):
    """Cut a ``[B, C, L]`` signal into flattened non-overlapping patches.

    Returns a tensor of shape ``[B, L // patch_size, patch_size * C]``. Inside
    each patch the values are ordered position-major, channel-minor.
    ``L`` must be a multiple of ``patch_size``.
    """
    batch, channels, length = x.shape
    assert length % patch_size == 0
    count = length // patch_size
    patches = x.reshape(batch, channels, count, patch_size)
    # Move channels last so that one patch's samples are contiguous.
    patches = patches.permute(0, 2, 3, 1)
    return patches.reshape(batch, count, patch_size * channels)


def unpatchify_1d(x, patch_size, in_chans=1):
    """Reassemble flattened patches into a ``[B, in_chans, L]`` signal.

    ``x`` has shape ``[B, n_patches, patch_size * in_chans]`` with each patch
    laid out position-major, channel-minor; the result has
    ``L = n_patches * patch_size``.
    """
    batch, count, width = x.shape
    assert width == patch_size * in_chans
    # Split each patch into (position, channel), then bring channels forward.
    signal = x.reshape(batch, count, patch_size, in_chans).permute(0, 3, 1, 2)
    return signal.reshape(batch, in_chans, count * patch_size)


# PatchEmbed1D, Block, get_1d_sincos_pos_embed, patchify_1d and unpatchify_1d are assumed to be defined or imported above

def get_fixed_complementary_masks(batch_size, seq_len, patch_size):
    """Create two fixed, complementary alternating patch masks.

    With ``n_patches = seq_len // patch_size``, the first mask is
    ``0, 1, 0, 1, ...`` and the second is ``1, 0, 1, 0, ...`` along the patch
    axis, so every patch is flagged in exactly one of them.

    Returns:
        ``(mask_A, mask_B)``: two float32 tensors of shape
        ``[batch_size, n_patches]``.
    """
    n_patches = seq_len // patch_size
    pairs, leftover = divmod(n_patches, 2)
    # An odd patch count needs one extra entry that continues the alternation.
    row_a = [0, 1] * pairs + [0] * leftover
    row_b = [1, 0] * pairs + [1] * leftover

    def _tile(row):
        return torch.tensor(row, dtype=torch.float32).unsqueeze(0).repeat(batch_size, 1)

    return _tile(row_a), _tile(row_b)





#####################################
# DualMaskMAE: denoising MAE whose forward(...) uses two complementary alternating patch masks
#####################################
class DualMaskMAE(nn.Module):
    """Masked-autoencoder denoiser for 1D signals using two complementary masks.

    Pipeline implemented by :meth:`forward`:
      1. ``x_noisy`` is split into patches and embedded once.
      2. Two fixed complementary patch masks are created: ``mask_A`` hides the
         odd-indexed patches and ``mask_B`` hides the even-indexed ones.
      3. Each masked view is encoded by the shared encoder, mask tokens are
         re-inserted at the hidden positions, and the shared decoder predicts
         every patch.
      4. The final signal takes, for every patch, the prediction of the branch
         in which that patch was hidden, so each output patch is predicted
         without the encoder having seen its noisy input.
      5. The loss is the mean squared error between that signal and ``x_clean``.

    ``norm_pix_loss`` is stored on the instance but is not used by ``forward``.
    Extra keyword arguments are accepted and ignored.
    """

    def __init__(self,
                 seq_len=256,
                 patch_size=2,
                 in_chans=1,
                 embed_dim=384,
                 depth=6,
                 num_heads=6,
                 decoder_embed_dim=256,
                 decoder_depth=4,
                 decoder_num_heads=8,
                 mlp_ratio=4.,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6),
                 norm_pix_loss=False,
                 **kwargs):
        """Create encoder, decoder and fixed positional tables.

        Args:
            seq_len: Length of the input signal.
            patch_size: Samples per patch; must divide ``seq_len``.
            in_chans: Number of signal channels.
            embed_dim: Encoder token width.
            depth: Number of encoder transformer blocks.
            num_heads: Attention heads per encoder block.
            decoder_embed_dim: Decoder token width.
            decoder_depth: Number of decoder transformer blocks.
            decoder_num_heads: Attention heads per decoder block.
            mlp_ratio: MLP expansion ratio passed to every block.
            norm_layer: Factory for normalisation layers.
            norm_pix_loss: Stored only; not used by ``forward``.
            **kwargs: Ignored.
        """
        super().__init__()
        self.seq_len = seq_len
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.norm_pix_loss = norm_pix_loss

        # encoder
        self.patch_embed = PatchEmbed1D(seq_len, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Positional tables are fixed sin-cos values, hence requires_grad=False.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim),
                                      requires_grad=False)
        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

        # decoder
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim),
                                              requires_grad=False)
        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(decoder_depth)
        ])
        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size * in_chans, bias=True)

        self.initialize_weights()

    def initialize_weights(self):
        """Fill the positional tables and initialise all learnable weights."""
        n_patches = self.patch_embed.num_patches
        pe = get_1d_sincos_pos_embed(self.pos_embed.shape[-1], n_patches, cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pe).float().unsqueeze(0))
        dec_pe = get_1d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], n_patches, cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(dec_pe).float().unsqueeze(0))

        # Treat the conv kernel as a 2D matrix so it is initialised like a Linear.
        w = self.patch_embed.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.normal_(self.mask_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-init Linear layers; reset LayerNorm to identity scale/shift."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, x):
        """Split ``[B, C, L]`` into ``[B, n_patches, patch_size * C]``."""
        return patchify_1d(x, self.patch_size)

    def unpatchify(self, x):
        """Inverse of :meth:`patchify` for this model's patch size and channels."""
        return unpatchify_1d(x, patch_size=self.patch_size, in_chans=self.in_chans)

    def forward_encoder(self, x_embed, mask):
        """Encode only the visible patches of one masked view.

        Args:
            x_embed: Patch embeddings ``[B, n_patches, embed_dim]`` (no cls token).
            mask: ``[B, n_patches]``; a non-zero entry marks a hidden patch.
                Every row must hide the same number of patches.

        Returns:
            Tuple ``(latent, mask, ids_restore)``. ``latent`` is
            ``[B, 1 + n_visible, embed_dim]`` with the cls token first and the
            visible patches in their original order. ``mask`` is the argument,
            returned unchanged. ``ids_restore`` is ``[B, n_patches]``, the gather
            index that undoes the "visible first, hidden last" ordering.

        Raises:
            RuntimeError: If rows of ``mask`` hide different numbers of patches.
        """
        B, L, D = x_embed.shape
        device = x_embed.device
        x_embed = x_embed + self.pos_embed[:, 1:, :]

        hidden = mask.bool().to(device)
        visible_counts = (~hidden).sum(dim=1)
        if B > 0 and not bool((visible_counts == visible_counts[0]).all()):
            raise RuntimeError("every row of mask must hide the same number of patches")
        n_visible = int(visible_counts[0]) if B > 0 else 0

        # Unique sort keys: a visible patch keeps its index, a hidden one gets
        # index + L. Sorting therefore lists visible patches first and keeps the
        # original order inside each group without needing a stable sort.
        keys = hidden.long() * L + torch.arange(L, device=device)
        ids_shuffle = torch.argsort(keys, dim=1)
        # The inverse permutation maps "visible first" order back to signal order.
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        ids_keep = ids_shuffle[:, :n_visible]
        x_visible = torch.gather(x_embed, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))

        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_token = cls_token.expand(B, -1, -1)
        x_final = torch.cat([cls_token, x_visible], dim=1)
        for blk in self.blocks:
            x_final = blk(x_final)
        x_final = self.norm(x_final)
        return x_final, mask, ids_restore

    def get_decoder_input(self, latent, ids_restore):
        """Project encoder output to decoder width and re-insert mask tokens.

        Args:
            latent: Encoder output ``[B, 1 + n_visible, embed_dim]``.
            ids_restore: ``[B, n_patches]`` index returned by
                :meth:`forward_encoder`.

        Returns:
            ``[B, 1 + n_patches, decoder_embed_dim]`` in original patch order,
            cls token first. Positional embeddings are not added and the
            decoder blocks are not applied here.
        """
        x = self.decoder_embed(latent)
        B, n_tokens, D_dec = x.shape
        n_hidden = ids_restore.shape[1] + 1 - n_tokens
        mask_tokens = self.mask_token.repeat(B, n_hidden, 1)

        cls_part = x[:, :1, :]
        patches = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
        patches = torch.gather(patches, dim=1,
                               index=ids_restore.unsqueeze(-1).expand(-1, -1, D_dec))
        return torch.cat([cls_part, patches], dim=1)

    def _run_decoder(self, dec_in):
        """Run the shared decoder and return per-patch predictions without cls."""
        x = dec_in + self.decoder_pos_embed
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_pred(self.decoder_norm(x))
        return x[:, 1:, :]

    def forward(self, x_noisy, x_clean):
        """Denoise ``x_noisy`` and compute the reconstruction loss.

        Args:
            x_noisy: Noisy input ``[B, in_chans, L]``. ``L`` is expected to
                match the ``seq_len`` the model was built with, because the
                positional tables are sized for it.
            x_clean: Clean target with the same shape as ``x_noisy``.

        Returns:
            Tuple ``(loss, y_final, (mask_A, mask_B))``: the scalar MSE loss,
            the reconstruction ``[B, in_chans, L]``, and the two patch masks
            ``[B, n_patches]`` (1 = hidden from the encoder in that branch).
        """
        B, C, L = x_noisy.shape

        # Complementary alternating patch masks (hide every other patch).
        mask_A, mask_B = get_fixed_complementary_masks(B, L, self.patch_size)
        mask_A = mask_A.to(x_noisy.device)
        mask_B = mask_B.to(x_noisy.device)

        # The patch embedding is shared by both masked views.
        x_embed = self.patch_embed(x_noisy)

        latent_A, _, ids_restore_A = self.forward_encoder(x_embed, mask_A)
        latent_B, _, ids_restore_B = self.forward_encoder(x_embed, mask_B)
        dec_in_A = self.get_decoder_input(latent_A, ids_restore_A)
        dec_in_B = self.get_decoder_input(latent_B, ids_restore_B)

        y_A = self.unpatchify(self._run_decoder(dec_in_A))  # [B, C, L]
        y_B = self.unpatchify(self._run_decoder(dec_in_B))  # [B, C, L]

        # Expand patch masks to sample resolution; the singleton channel axis
        # broadcasts over any number of channels.
        mask_A_full = mask_A.unsqueeze(-1).repeat(1, 1, self.patch_size).view(B, 1, L)
        mask_B_full = mask_B.unsqueeze(-1).repeat(1, 1, self.patch_size).view(B, 1, L)

        # Each sample comes from the branch in which its patch was hidden.
        y_final = y_A * mask_A_full + y_B * mask_B_full

        loss = torch.mean((y_final - x_clean) ** 2)
        return loss, y_final, (mask_A, mask_B)


#####################################
# Factory function for building the model from external code
#####################################
def denoise_dualmask(**kwargs):
    """Build a :class:`DualMaskMAE` with the default denoising configuration.

    Args:
        **kwargs: Optional overrides for any ``DualMaskMAE`` constructor
            argument (for example ``seq_len=512``). Arguments that are not
            overridden keep the defaults listed below.

    Returns:
        The constructed ``DualMaskMAE`` model.
    """
    config = dict(
        seq_len=256,
        patch_size=2,
        in_chans=1,
        embed_dim=384,
        depth=6,
        num_heads=6,
        decoder_embed_dim=256,
        decoder_depth=4,
        decoder_num_heads=8,
        mlp_ratio=4.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        norm_pix_loss=False,
    )
    # Merge rather than pass both explicitly: a keyword supplied twice in one
    # call raises TypeError, which made every override impossible.
    config.update(kwargs)
    return DualMaskMAE(**config)
