import torch
import numpy as np
import argparse
from pathlib import Path
from DE import denoise_dualmask
def main():
    # ===============================
    # 1. 设置相关路径与加载模型
    # ===============================
    # 数据文件夹：保存多个 256 长度的一维 npy 文件（原始信号）
    data_dir = Path(r"")
    # 有效信号位置信息文件夹
    position_dir = Path(r"")
    checkpoint_path = r""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = denoise_dualmask()
    with torch.serialization.safe_globals([argparse.Namespace]):
        ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model.load_state_dict(ckpt['model'])
    model.to(device)
    model.eval()
    # ===============================
    # 2. 遍历文件夹，对每个文件计算扰动区域的 MSE 与 SNR
    # ===============================
    # 获取所有数据文件（假设都是 .npy 文件，每个文件中存储长度为256的一维数组）
    data_files = sorted(data_dir.glob("*.npy"))
    snr_list = []
    # 线性变换系数，保持与训练或后处理一致
    std = 0.0011154172318002256
    mean = 0.029607261518762952
    for data_file in data_files:
        # 加载原始信号（既作为输入，也作为计算参考的“干净信号”），形状: (256,)
        orig_signal = np.load(data_file)
        # 将原始信号转换为模型输入形状 [B,1,L]
        signal_tensor = torch.from_numpy(orig_signal).float().unsqueeze(0).unsqueeze(0).to(device)
        #模型预测（去噪）——传入相同信号作为噪声与参考
        output = model(signal_tensor, signal_tensor)
        denoised_signal = output[1]
        # 将预测结果转换为 numpy 数组，最终形状 (256,)
        denoised_np = denoised_signal.cpu().detach().numpy().squeeze()
        #读取对应的扰动位置信息文件
        base_name = data_file.stem
        pos_file = position_dir / f"{base_name}.npy"
        if not pos_file.exists():
            print(f"扰动位置文件不存在，跳过: {pos_file}")
            continue
        # 读入扰动位置
        mask_data = np.load(pos_file, allow_pickle=True)
        mask_indices = mask_data.flatten().astype(int)
        orig_trans = orig_signal * std + mean
        pred_trans = denoised_np * std + mean
        # ------------- SNR 计算 -------------
        # 非扰动区域：全集索引中不在 mask_indices 内的部分
        all_indices = np.arange(len(orig_signal))
        non_mask_indices = np.setdiff1d(all_indices, mask_indices)
        # 基线：非扰动区域的均值
        baseline = np.mean(orig_trans[non_mask_indices])
        # 有效信号功率：扰动区域中与 baseline 之差的均方值
        signal_power = np.mean((pred_trans[mask_indices] - baseline) ** 2)
        # 噪声功率：非扰动区域中与 baseline 之差的均方值
        noise_power = np.mean((pred_trans[non_mask_indices] - baseline) ** 2)
        if noise_power > 0:
            snr_db = np.abs(10 * np.log10(signal_power / noise_power))
        else:
            snr_db = float('inf')
        snr_list.append(snr_db)


    # ===============================
    # 3. 输出所有文件的统计指标
    # ===============================
    if snr_list:
        avg_snr = np.mean(snr_list)
        print("所有文件扰动区域 SNR 的平均值: {:.2f} dB".format(avg_snr))
    else:
        print("未计算到任何文件的 MSE。")
if __name__ == '__main__':
    main()