import math
import torch
import torch.nn as nn
import numpy as np
from functools import partial
# 如果用 timm 里的 Transformer Block:
from vendor.timm.models.vision_transformer import Block
from vendor.timm.models.vision_transformer import Block_la_o


# -------------------------------------------------------------------------
# 1) 1D Sin-Cos位置编码 (替代原2D版本)
# -------------------------------------------------------------------------
def get_1d_sincos_pos_embed(embed_dim, length, cls_token=False):
    """Build a fixed 1D sine/cosine positional-encoding table.

    Even columns hold ``sin(pos * w_i)`` and odd columns ``cos(pos * w_i)`` with
    ``w_i = 10000 ** (-2i / embed_dim)`` (the Transformer encoding).

    Args:
        embed_dim: Width of each encoding row. Odd widths are supported: the
            last column is then a sine column with no cosine partner.
        length: Number of positions (0 .. length-1).
        cls_token: If True, prepend an all-zero row for a cls token.

    Returns:
        float64 ndarray of shape ``[length, embed_dim]``, or
        ``[length + 1, embed_dim]`` when ``cls_token`` is True (row 0 = zeros).
    """
    position = np.arange(length, dtype=float)
    # One frequency per sine column: ceil(embed_dim / 2) of them.
    div_term = np.exp(np.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
    angles = position[:, None] * div_term[None, :]  # [length, ceil(embed_dim/2)]

    pos_embed = np.zeros((length, embed_dim), dtype=float)
    pos_embed[:, 0::2] = np.sin(angles)
    # Only floor(embed_dim / 2) cosine columns exist; slicing keeps odd widths valid.
    pos_embed[:, 1::2] = np.cos(angles[:, : embed_dim // 2])

    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


# -------------------------------------------------------------------------
# 2) 1D PatchEmbed
#    [B, in_chans, seq_len] => [B, seq_len // patch_size, embed_dim], kernel_size = stride = patch_size
#    (defaults: [B, 1, 256] => [B, 128, 384])
# -------------------------------------------------------------------------
class PatchEmbed1D(nn.Module):
    """Split a 1D signal into non-overlapping patches and embed each as one token.

    A ``Conv1d`` whose kernel size equals its stride projects every window of
    ``patch_size`` samples (across all ``in_chans`` channels) to an
    ``embed_dim``-dimensional vector.

    Args:
        seq_len: Expected signal length; must be a multiple of ``patch_size``.
        patch_size: Samples per patch.
        in_chans: Number of input channels.
        embed_dim: Token width.

    Shapes:
        input  ``[B, in_chans, seq_len]``
        output ``[B, seq_len // patch_size, embed_dim]``
    """

    def __init__(self, seq_len=256, patch_size=2, in_chans=1, embed_dim=384):
        super().__init__()
        assert seq_len % patch_size == 0, "seq_len 必须能整除 patch_size"
        self.seq_len, self.patch_size = seq_len, patch_size
        self.num_patches = seq_len // patch_size
        self.in_chans, self.embed_dim = in_chans, embed_dim

        # kernel == stride: each patch is projected independently, no overlap.
        self.proj = nn.Conv1d(in_chans, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Map ``[B, in_chans, L]`` to tokens ``[B, L // patch_size, embed_dim]``."""
        tokens = self.proj(x)          # [B, embed_dim, num_patches]
        return tokens.transpose(1, 2)  # [B, num_patches, embed_dim]


# -------------------------------------------------------------------------
# 3) patchify / unpatchify for 1D
# -------------------------------------------------------------------------
def patchify_1d(x, patch_size):
    """Cut a signal into consecutive, non-overlapping patches.

    Args:
        x: tensor ``[B, C, L]``; ``L`` must be divisible by ``patch_size``.
        patch_size: samples per patch.

    Returns:
        Tensor ``[B, L // patch_size, patch_size * C]``. Inside a patch the values
        are laid out time-major, channel-minor (element ``t * C + c``).
    """
    batch, chans, length = x.shape
    assert length % patch_size == 0
    num = length // patch_size
    windows = x.unfold(-1, patch_size, patch_size)  # [B, C, num, patch_size]
    windows = windows.permute(0, 2, 3, 1)           # [B, num, patch_size, C]
    return windows.reshape(batch, num, patch_size * chans)

def unpatchify_1d(x, patch_size, in_chans=1):
    """Inverse of ``patchify_1d``: stitch patches back into a signal.

    Args:
        x: tensor ``[B, n_patches, patch_size * in_chans]`` with the time-major,
            channel-minor layout produced by ``patchify_1d``.
        patch_size: samples per patch.
        in_chans: number of channels.

    Returns:
        Tensor ``[B, in_chans, n_patches * patch_size]``.
    """
    batch, num, width = x.shape
    assert width == patch_size*in_chans
    per_patch = x.reshape(batch, num, patch_size, in_chans)  # [B, N, P, C]
    per_chan = per_patch.movedim(-1, 1)                      # [B, C, N, P]
    return per_chan.flatten(2)                               # [B, C, N * P]


# -------------------------------------------------------------------------
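# 4) Mask selection by per-patch extremes (2-sample patches; despite the "means" name, max/min are used)
#    e.g. topk=8 => the 8 patches with the largest maxima + the 8 with the smallest minima => 16 masked patches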
# -------------------------------------------------------------------------
def mask_by_patch_means(x_raw, topk=24):
    """Select patches to mask from the extremes of each patch.

    Despite the name, selection uses per-patch maxima and minima, not means.
    The signal is split into 128 patches of 2 samples. For each sample:

    * "high" set: the ``topk`` patches with the largest maxima;
    * "low" set: walking the ``2 * topk`` patches with the smallest minima in
      ascending order, the first ``topk`` that are not already in the high set.

    Because the low side may skip at most ``topk`` entries of its ``2 * topk``
    candidates, each row has ``2 * topk`` distinct ids whenever ``2 * topk <= 128``.

    Args:
        x_raw: tensor of shape ``[B, 1, 256]``.
        topk: number of patches taken from each end.

    Returns:
        int64 tensor ``[B, n_high + n_low]`` on ``x_raw.device``: high ids in
        descending-max order followed by low ids in ascending-min order.
    """
    B, C, L = x_raw.shape
    assert L == 256 and C == 1, "输入形状应为 [B, 1, 256]"
    patch_size = 2
    n_patch = L // patch_size  # 128

    # Per-patch extremes in one vectorised pass (C == 1): [B, n_patch].
    patches = x_raw.reshape(B, n_patch, patch_size)
    patch_maxes = patches.max(dim=-1).values
    patch_mins = patches.min(dim=-1).values

    mask_patch_ids = []
    for b in range(B):
        high_sorted = torch.argsort(patch_maxes[b], descending=True)
        low_sorted = torch.argsort(patch_mins[b], descending=False)

        # The high side is chosen first, so it never conflicts with anything.
        selected_high = high_sorted[:topk].tolist()
        high_set = set(selected_high)
        # Low side: best 2*topk candidates, minus any already chosen as high.
        selected_low = [
            idx for idx in low_sorted[:topk * 2].tolist() if idx not in high_set
        ][:topk]

        mask_patch_ids.append(
            torch.tensor(selected_high + selected_low, dtype=torch.long,
                         device=x_raw.device)
        )
    return torch.stack(mask_patch_ids, dim=0)



def mask_by_threshold(x_raw, threshold=3.1, total_mask=48, min_abs_threshold=0.3, patch_size=2):
    """Pick ``total_mask`` patches per sample to mask, preferring large amplitudes.

    Channel 0 of each sample is cut into non-overlapping patches of
    ``patch_size`` samples. Each patch is scored by its largest absolute value.

    * Patches scoring ``>= threshold`` are the primary candidates. If there are at
      least ``total_mask`` of them, ``total_mask`` are drawn at random without
      replacement.
    * Otherwise every candidate is kept. The shortfall is drawn at random from the
      "quiet" patches scoring ``< min_abs_threshold`` that are not already
      candidates. If there are not enough quiet patches, all of them are used and
      the rest is drawn at random from the remaining patches. Every row therefore
      has exactly ``total_mask`` distinct ids, which lets the batch be stacked.

    Randomness comes from NumPy's global RNG (``np.random``).

    Args:
        x_raw: tensor ``[B, C, L]``; only channel 0 is inspected.
        threshold: absolute-amplitude threshold for primary candidates.
        total_mask: number of patches to mask per sample.
        min_abs_threshold: amplitude below which a patch counts as "quiet" filler.
        patch_size: samples per patch (default 2).

    Returns:
        ``(mask_patch_ids, all_mask_patch_ids)``: int64 tensors
        ``[B, total_mask]`` on ``x_raw.device``, sorted ascending per row.
        Both currently hold identical values; the second is kept for callers that
        expect two outputs.

    Raises:
        ValueError: if ``total_mask`` exceeds the number of patches.
    """
    B, C, L = x_raw.shape
    n_patch = L // patch_size
    if total_mask > n_patch:
        raise ValueError(
            f"total_mask ({total_mask}) exceeds the number of patches ({n_patch})"
        )
    mask_patch_ids = []
    all_mask_patch_ids = []

    for b in range(B):
        # [n_patch, patch_size] windows of channel 0, scored by peak |value|.
        patches = x_raw[b, 0].unfold(0, patch_size, patch_size)
        patch_abs_max = patches.abs().max(dim=1).values  # [n_patch]

        # flatten() keeps the result 1-D for zero or one match, so tolist() is always a list.
        candidate = (patch_abs_max >= threshold).nonzero(as_tuple=False).flatten().tolist()
        candidate_set = set(candidate)
        quiet = (patch_abs_max < min_abs_threshold).nonzero(as_tuple=False).flatten().tolist()
        # Only matters when min_abs_threshold > threshold: avoid double-selecting a patch.
        remaining = [i for i in quiet if i not in candidate_set]

        if len(candidate) < total_mask:
            num_needed = total_mask - len(candidate)
            if len(remaining) >= num_needed:
                additional = np.random.choice(remaining, num_needed, replace=False).tolist()
            else:
                # Not enough quiet patches: use them all, then pad from the other patches.
                taken = candidate_set.union(remaining)
                others = [i for i in range(n_patch) if i not in taken]
                extra = np.random.choice(others, num_needed - len(remaining), replace=False).tolist()
                additional = list(remaining) + extra
            selected = candidate + additional
        else:
            # More candidates than needed: sample total_mask of them.
            selected = np.random.choice(candidate, total_mask, replace=False).tolist()

        # Sorted order keeps downstream gather/restore logic deterministic.
        ids = torch.tensor(sorted(selected), dtype=torch.long, device=x_raw.device)
        mask_patch_ids.append(ids)
        # `selected` always has total_mask entries, so no padding from `remaining`
        # is ever needed here; this row equals the mask_patch_ids row.
        all_mask_patch_ids.append(ids.clone())

    mask_patch_ids = torch.stack(mask_patch_ids, dim=0)          # [B, total_mask]
    all_mask_patch_ids = torch.stack(all_mask_patch_ids, dim=0)  # [B, total_mask]

    return mask_patch_ids, all_mask_patch_ids

def mle_loss(full_pred, mask_time, kappa):
    """MLE-style regulariser on the local "power" of the prediction inside a mask.

    For each time step ``t`` in ``0 .. L-2`` the code forms::

        p[t] = (x[t+1] - 2*x[t] + x[t-1]) * (x[t+1] - x[t])   for t >= 1
        p[0] = 0

    This is the centred second difference times the forward first difference. It
    is not squared, and no ``dt`` scaling is applied. ``p`` is averaged over the
    steps where ``mask_time[..., t]`` is set (``t <= L-2``). The average goes
    through a sigmoid to give ``E`` in (0, 1), and the penalty is::

        -log(E) - kappa * log(1 - E)

    This grows as ``E`` approaches 0 or 1, pushing the masked power toward a
    moderate level. A ``1e-6`` epsilon guards both logs and the mask count.

    Args:
        full_pred: ``[B, C, L]`` reconstructed signal (callers pass C == 1).
        mask_time: ``[B, C, L]`` time-domain mask, 1 = constrained step.
        kappa: weight of the high-power (``log(1 - E)``) term.

    Returns:
        Scalar tensor: the penalty averaged over batch (and channel).
    """
    # Centred second difference for t = 1 .. L-2: [B, C, L-2]
    second_diff = full_pred[:, :, 2:] - 2 * full_pred[:, :, 1:-1] + full_pred[:, :, :-2]
    # Forward first difference x[t+1] - x[t] for t = 0 .. L-2: [B, C, L-1]
    first_diff = full_pred[:, :, 1:] - full_pred[:, :, :-1]
    # Left-pad so index t in both tensors is the same time step (t = 0 has no centred term).
    zero_pad = torch.zeros_like(first_diff[:, :, :1])
    second_diff = torch.cat([zero_pad, second_diff], dim=-1)  # [B, C, L-1]

    power = second_diff * first_diff      # [B, C, L-1]
    mask_reg = mask_time[:, :, :-1]       # mask for t = 0 .. L-2

    # Masked mean power per sample; epsilon avoids division by zero for empty masks.
    power_mean = (power * mask_reg).sum(dim=-1) / (mask_reg.sum(dim=-1) + 1e-6)  # [B, C]
    E_norm = torch.sigmoid(power_mean)

    loss_per_sample = -torch.log(E_norm + 1e-6) - kappa * torch.log(1 - E_norm + 1e-6)
    return loss_per_sample.mean()
# -------------------------------------------------------------------------
# 5) Core: MaskedAutoencoderViT with random_masking replaced by threshold-based masking (mask_by_threshold)
# -------------------------------------------------------------------------
class MaskedAutoencoderViT_PatchMeanMask(nn.Module):
    """Masked autoencoder with a ViT encoder/decoder for single-channel 1D signals.

    ``forward`` runs:
      1. ``mask_by_threshold`` picks the patch indices to hide for each sample:
         patches whose absolute peak reaches ``mask_threshold``, topped up with
         random patches whose absolute peak is below ``min_abs_threshold``.
      2. The whole signal is patch-embedded. Hidden tokens are dropped, and only the
         visible tokens (plus a cls token) go through the encoder.
      3. The decoder puts a learned mask token at every hidden position and
         predicts the raw samples of every patch.
      4. ``forward_loss_o`` scores the reconstruction on the hidden patches.

    The class name and ``topk`` are historical: masking is threshold-based, and
    ``topk`` is stored but not used by ``forward``.

    NOTE(review): ``mask_by_threshold`` is called without a patch size, so the mask
    indices follow that helper's default patch size (2 samples). Keep
    ``patch_size=2`` unless that call is updated.
    """

    def __init__(self,
                 seq_len=256,
                 patch_size=2,
                 in_chans=1,
                 embed_dim=384,
                 depth=6,
                 num_heads=6,
                 decoder_embed_dim=256,
                 decoder_depth=4,
                 decoder_num_heads=8,
                 mlp_ratio=4.,
                 norm_layer=nn.LayerNorm,
                 norm_pix_loss=False,
                 topk=24,
                 mask_threshold=1.505,
                 total_mask=48,
                 min_abs_threshold=1.0):
        super().__init__()
        self.seq_len = seq_len
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.norm_pix_loss = norm_pix_loss
        self.topk = topk  # legacy setting; not used by forward()
        # Forwarded to mask_by_threshold() in forward(). The defaults reproduce the
        # previously hard-coded call mask_by_threshold(x, 1.505, 48, 1).
        self.mask_threshold = mask_threshold
        self.total_mask = total_mask
        self.min_abs_threshold = min_abs_threshold

        # ---------------- encoder ----------------
        self.patch_embed = PatchEmbed1D(seq_len, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches  # seq_len // patch_size

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Fixed (non-trainable) sin-cos table; row 0 belongs to the cls token.
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dim),
            requires_grad=False
        )

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True,
                  norm_layer=norm_layer)
            for _ in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

        # ---------------- decoder ----------------
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, decoder_embed_dim),
            requires_grad=False
        )
        self.decoder_blocks = nn.ModuleList([
            Block_la_o(decoder_embed_dim, decoder_num_heads, mlp_ratio, window_size=27,
                       qkv_bias=True, norm_layer=norm_layer)
            for _ in range(decoder_depth)
        ])
        self.decoder_norm = norm_layer(decoder_embed_dim)
        # Predicts the raw samples of one patch per token.
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size * in_chans, bias=True)

        self.initialize_weights()

    def initialize_weights(self):
        """Fill the sin-cos tables and initialise conv, tokens, Linear and LayerNorm."""
        n_patches = self.patch_embed.num_patches
        pe = get_1d_sincos_pos_embed(self.pos_embed.shape[-1], n_patches, cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pe).float().unsqueeze(0))

        dec_pe = get_1d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], n_patches, cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(dec_pe).float().unsqueeze(0))

        # Treat the conv kernel as a [embed_dim, in_chans * patch_size] matrix for Xavier.
        w = self.patch_embed.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        nn.init.normal_(self.cls_token, std=.02)
        nn.init.normal_(self.mask_token, std=.02)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform Linear weights with zero bias; LayerNorm to identity."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    # --------------------------
    # patchify / unpatchify
    # --------------------------
    def patchify(self, x):
        """``[B, C, L]`` -> ``[B, L // patch_size, patch_size * C]``."""
        return patchify_1d(x, self.patch_size)

    def unpatchify(self, x):
        """``[B, n_patches, patch_size * C]`` -> ``[B, C, n_patches * patch_size]``."""
        return unpatchify_1d(x, patch_size=self.patch_size, in_chans=self.in_chans)

    # --------------------------
    # forward_encoder
    # --------------------------
    def forward_encoder(self, x_embed, mask_patch_ids):
        """Drop the hidden tokens and run the encoder on the visible ones.

        Args:
            x_embed: patch tokens ``[B, L, D]``.
            mask_patch_ids: integer tensor ``[B, K]`` of patch indices to hide;
                each row must hold K distinct indices in ``[0, L)``.

        Returns:
            latent: ``[B, 1 + (L - K), D]`` encoder output, cls token first.
            mask: float ``[B, L]``, 1 = hidden patch, 0 = visible patch.
            ids_restore: ``[B, L]`` permutation mapping the
                ``[visible..., hidden...]`` order back to patch order.

        Raises:
            ValueError: if a row of ``mask_patch_ids`` repeats an index.
        """
        B, L, D = x_embed.shape
        device = x_embed.device

        # Positional embedding for patch tokens (row 0 is reserved for cls).
        x_embed = x_embed + self.pos_embed[:, 1:, :]

        # Hidden ids sorted ascending so the [visible, hidden] order is deterministic.
        mask_ids_sorted, _ = torch.sort(mask_patch_ids.to(device), dim=1)
        bool_mask = torch.zeros(B, L, dtype=torch.bool, device=device)
        bool_mask[torch.arange(B, device=device).unsqueeze(1), mask_ids_sorted] = True
        keep_mask = ~bool_mask

        n_keep = L - mask_ids_sorted.shape[1]
        # Duplicate ids would hide fewer than K patches and break the reshapes below.
        if not bool(torch.all(keep_mask.sum(dim=1) == n_keep)):
            raise ValueError("mask_patch_ids must contain distinct patch indices in [0, L) for every row")

        # Boolean indexing / nonzero are row-major, so each row keeps ascending order.
        x_masked = x_embed[keep_mask].reshape(B, n_keep, D)               # [B, n_keep, D]
        keep_ids = keep_mask.nonzero(as_tuple=False)[:, 1].reshape(B, n_keep)

        ids_shuffle = torch.cat([keep_ids, mask_ids_sorted], dim=1)  # [B, L]
        ids_restore = torch.argsort(ids_shuffle, dim=1)              # [B, L]

        # 0 for the first n_keep (visible) slots, 1 for the hidden ones, then unshuffle.
        mask = torch.ones(B, L, device=device)
        mask[:, :n_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        cls_token = (self.cls_token + self.pos_embed[:, :1, :]).expand(B, -1, -1)  # [B, 1, D]
        x = torch.cat([cls_token, x_masked], dim=1)                                # [B, 1 + n_keep, D]

        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x, mask, ids_restore

    # --------------------------
    # forward_decoder
    # --------------------------
    def forward_decoder(self, x, ids_restore):
        """Re-insert mask tokens, decode, and predict the samples of every patch.

        Args:
            x: encoder output ``[B, 1 + n_keep, encoder_dim]`` (cls token first).
            ids_restore: ``[B, L]`` permutation returned by ``forward_encoder``.

        Returns:
            ``[B, L, patch_size * in_chans]`` predictions in patch order (cls removed).
        """
        x = self.decoder_embed(x)  # [B, 1 + n_keep, D_dec]
        B, n_tokens, D_dec = x.shape
        n_patches = ids_restore.shape[1]

        # One mask token per hidden patch: (L + 1) - (1 + n_keep).
        mask_tokens = self.mask_token.repeat(B, n_patches + 1 - n_tokens, 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # [B, L, D_dec], [visible, hidden] order
        # Unshuffle back to patch order.
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, D_dec))
        x = torch.cat([x[:, :1, :], x_], dim=1)  # [B, 1 + L, D_dec]

        x = x + self.decoder_pos_embed

        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        x = self.decoder_pred(x)  # [B, 1 + L, patch_size * in_chans]
        return x[:, 1:, :]

    def forward_loss_o(self, x_raw, pred, mask_patch_ids):
        """Reconstruction loss restricted to the hidden patches.

        ``total = basic + 0.5 * sign_change + 0.5 * derivative + 0.2 * mle``

        * basic: per-patch MSE between ``pred`` and the patchified target,
          averaged over hidden patches.
        * derivative: MSE between the first differences ``x[t+1] - x[t]`` of the
          reconstructed signal and of ``x_raw``, over steps ``t`` in hidden patches.
        * sign_change: squared error between reconstruction and ``x_raw`` at steps
          in hidden patches where the first difference of ``x_raw`` changes sign
          (local extrema of the target).
        * mle: ``mle_loss(full_pred, time_mask, kappa=1)``.

        Every average divides by ``max(count, 1)``. A term with no contributing
        positions is therefore a zero tensor, not NaN or a Python float.

        NOTE(review): with ``norm_pix_loss=True`` only the basic term uses the
        normalised target; the other terms compare against raw ``x_raw``.
        The code assumes ``in_chans == 1`` (it squeezes dim 1).

        Args:
            x_raw: ``[B, 1, L]`` original signal.
            pred: ``[B, n_patches, patch_size]`` decoder output.
            mask_patch_ids: integer ``[B, K]`` indices of hidden patches.

        Returns:
            ``(total_loss, basic_loss, loss_sign_change)`` as scalar tensors.
        """
        target = patchify_1d(x_raw, self.patch_size)  # [B, n_patches, patch_size]
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1e-6).sqrt()

        B, n_patches, _ = target.shape
        device = x_raw.device

        # Patch-level mask (1.0 = hidden) and the same mask per time step: [B, L].
        mask_ids = torch.as_tensor(mask_patch_ids, device=device, dtype=torch.long)
        mask_bool = torch.zeros(B, n_patches, device=device, dtype=torch.float32)
        mask_bool.scatter_(1, mask_ids, 1.0)
        time_mask = mask_bool.repeat_interleave(self.patch_size, dim=1)

        full_pred = self.unpatchify(pred)      # [B, 1, L]
        x_flat = x_raw.squeeze(1)              # [B, L]
        pred_flat = full_pred.squeeze(1)       # [B, L]
        target_first = x_flat[:, 1:] - x_flat[:, :-1]        # [B, L-1]
        pred_first = pred_flat[:, 1:] - pred_flat[:, :-1]    # [B, L-1]

        # 1) patch-level MSE on hidden patches
        per_patch = ((pred - target) ** 2).mean(dim=-1)  # [B, n_patches]
        basic_loss = (per_patch * mask_bool).sum() / mask_bool.sum().clamp_min(1.0)

        # 2) first-difference MSE on steps inside hidden patches
        step_mask = time_mask[:, :-1]
        loss_derivative = (((target_first - pred_first) ** 2) * step_mask).sum() \
            / step_mask.sum().clamp_min(1.0)

        # 3) error at target extrema: sign change between d[k] and d[k+1] marks step k+1
        d_sign = torch.sign(target_first)
        sign_change = (d_sign[:, :-1] != d_sign[:, 1:]).float()   # [B, L-2]
        zero_column = torch.zeros(B, 1, device=device)
        extremum_mask = torch.cat([zero_column, sign_change, zero_column], dim=-1)  # [B, L]
        final_mask = extremum_mask * time_mask
        sq_err = (pred_flat - x_flat) ** 2
        loss_sign_change = (sq_err * final_mask).sum() / final_mask.sum().clamp_min(1.0)

        # 4) MLE-style regulariser on masked local power
        loss4 = mle_loss(full_pred, time_mask.unsqueeze(1), 1)

        total_loss = basic_loss + 0.5 * loss_sign_change + 0.5 * loss_derivative + 0.2 * loss4
        return total_loss, basic_loss, loss_sign_change

    # --------------------------
    # forward
    # --------------------------
    def forward(self, x_raw):
        """Mask, encode, decode and score one batch.

        Args:
            x_raw: ``[B, 1, seq_len]`` signal.

        Returns:
            ``(loss, pred, mask, basic_loss, sign_change_loss)``, where ``pred`` is
            ``[B, n_patches, patch_size]`` and ``mask`` is ``[B, n_patches]``
            (1 = hidden).
        """
        mask_patch_ids, all_mask_patch_ids = mask_by_threshold(
            x_raw, self.mask_threshold, self.total_mask, self.min_abs_threshold)

        x_embed = self.patch_embed(x_raw)  # [B, n_patches, embed_dim]
        latent, mask, ids_restore = self.forward_encoder(x_embed, all_mask_patch_ids)
        pred = self.forward_decoder(latent, ids_restore)

        loss, mse, pl = self.forward_loss_o(x_raw, pred, mask_patch_ids)
        return loss, pred, mask, mse, pl


# -------------------------------------------------------------------------
# 6) Factory function: overrange_mae (default model configuration)
# -------------------------------------------------------------------------
def overrange_mae(**kwargs):
    """Build ``MaskedAutoencoderViT_PatchMeanMask`` with the default configuration.

    Keyword arguments override the defaults below, e.g. ``overrange_mae(depth=8)``.
    Previously, passing any of these keys raised "got multiple values for keyword
    argument". Other keywords are passed through to the constructor unchanged.
    """
    config = dict(
        seq_len=256,
        patch_size=2,
        in_chans=1,
        embed_dim=384,          # reduced embedding width
        depth=6,                # fewer encoder Transformer layers
        num_heads=6,            # fewer attention heads
        decoder_embed_dim=256,  # reduced decoder embedding width
        decoder_depth=4,        # fewer decoder layers
        decoder_num_heads=8,
        mlp_ratio=4.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        norm_pix_loss=False,
        topk=24,
    )
    config.update(kwargs)
    return MaskedAutoencoderViT_PatchMeanMask(**config)

