import torch
import numpy as np
import argparse
from pathlib import Path
import allantools
from DE import denoise_dualmask
import pandas as pd
# Normalization statistics -- presumably the mean/std used to standardize the
# training data (confirm). compute_allan_metrics maps values back with
# `data * std + mean`.
mean, std = 0.029607261518762952, 0.0011154172318002256



def simple_chunk_denoise(signal, model, window_size=256, device='cpu'):
    """Denoise a 1-D signal by running ``model`` over consecutive fixed-size windows.

    The signal is cut into non-overlapping windows of ``window_size`` samples.
    A trailing window shorter than ``window_size`` is zero-padded before
    inference, and only its first ``real_len`` predicted samples are kept, so
    the output always has the same length as the input.

    Args:
        signal: 1-D array-like of samples.
        model: Callable invoked as ``model(x, x)`` with a float32 tensor of
            shape ``(1, 1, window_size)``; it must return a 3-tuple whose
            second element is the denoised window (any shape holding
            ``window_size`` values).
        window_size: Number of samples per window; must be positive.
        device: Device the input tensors are moved to. It must match the
            device ``model`` lives on.

    Returns:
        A float64 NumPy array with the same length as ``signal``.

    Raises:
        ValueError: If ``window_size`` is not positive (otherwise the window
            cursor would never advance).
    """
    if window_size <= 0:
        raise ValueError(f"window_size must be positive, got {window_size}")

    # asarray lets list inputs work: torch.from_numpy only accepts ndarrays.
    samples = np.asarray(signal, dtype=float)
    n_samples = len(samples)
    out_arr = np.zeros(n_samples, dtype=float)

    with torch.no_grad():
        for start in range(0, n_samples, window_size):
            chunk = samples[start:start + window_size]
            real_len = len(chunk)
            if real_len < window_size:
                # Zero-pad the final short window so the model sees a fixed length.
                chunk = np.concatenate([chunk, np.zeros(window_size - real_len)], axis=0)
            x = torch.from_numpy(chunk).float().unsqueeze(0).unsqueeze(0).to(device)
            _loss, y_final, _ = model(x, x)
            # reshape(-1) instead of squeeze(): squeeze() turns a one-sample
            # window into a 0-d array, which cannot be sliced below.
            pred = y_final.cpu().numpy().reshape(-1)
            out_arr[start:start + real_len] = pred[:real_len]
    return out_arr


def compute_allan_metrics(data, dt, arw_index=7, bi_factor=0.664):
    """Compute Allan-deviation noise metrics (ARW, BI, QN) for a rate signal.

    ``data`` is first mapped back with ``data * std + mean`` using the
    module-level ``std``/``mean``. The Allan deviation is then computed with
    ``allantools.adev``, treating ``data`` as rate ("freq") samples taken
    every ``dt`` seconds on allantools' default tau grid.

    Args:
        data: 1-D array-like of normalized rate samples.
        dt: Sampling period in seconds (allantools gets ``rate = 1 / dt``).
        arw_index: Index of the tau point whose sigma is reported as ARW.
            The default of 7 reproduces the original estimator.
        bi_factor: Factor applied to the minimum sigma to obtain BI.

    Returns:
        ``(ARW, BI, QN, taus, sigma)``. ``taus``/``sigma`` are the tau grid and
        Allan deviation returned by ``allantools.adev``.

    Raises:
        ValueError: If the Allan curve has ``arw_index`` or fewer points,
            i.e. the signal is too short for the requested ARW read-out.
    """
    data = np.asarray(data)
    # Undo the normalization. The constant `+ mean` offset does not change the
    # Allan deviation (it depends only on differences of cluster averages);
    # only the `std` scale matters.
    data = data * std + mean
    # rate is 1 / dt because dt is the sampling period.
    taus, sigma, _sigma_err, _ns = allantools.adev(data, rate=1 / dt, data_type="freq")
    if len(sigma) <= arw_index:
        raise ValueError(
            f"Allan curve has only {len(sigma)} tau points; need more than "
            f"{arw_index} for the ARW estimate (signal too short?)"
        )

    # ARW: sigma at a single short-tau point, intended to lie in the
    # angle-random-walk (slope -1/2) region.
    # NOTE(review): this reports sigma itself, not sigma * sqrt(tau). With an
    # octave tau grid, index 7 is presumably tau = 2**7 * dt (1.28 s at 100 Hz),
    # close to but not exactly the conventional tau = 1 s read-out -- confirm.
    ARW = sigma[arw_index]

    # Bias instability: read off the flat bottom of the Allan curve, i.e. the
    # minimum sigma. The longest-tau point is not used because it is estimated
    # from the fewest clusters.
    # NOTE(review): IEEE Std 952 relates the plateau as sigma_min ~= 0.664 * B
    # (so B = sigma_min / 0.664); confirm that multiplying is the intended
    # convention before changing bi_factor's use.
    BI = np.min(sigma) * bi_factor

    # Quantization noise: in the QN region sigma(tau) = sqrt(3) * Q / tau,
    # hence Q = sigma * tau / sqrt(3), evaluated at the shortest tau.
    QN = sigma[0] * taus[0] / np.sqrt(3)
    return ARW, BI, QN, taus, sigma


def main(checkpoint_path="", data_dir="", output_csv="", group_size=200, fs=100,
         window_size=256):
    """Evaluate the denoising model with Allan-deviation metrics over groups of segments.

    Steps:
      1. Load a ``denoise_dualmask`` model whose state dict is stored under
         ``'model'`` in ``checkpoint_path``.
      2. Sort all ``*.npy`` files in ``data_dir`` and split them into
         consecutive groups of ``group_size`` files. Files left over after the
         last full group are ignored.
      3. For each group, concatenate the segments, denoise them window by
         window, multiply by ``180 / pi`` and compute ARW / BI / QN.
      4. Print the metrics averaged over groups and write the group-averaged
         Allan deviation curve to ``output_csv``.

    Args:
        checkpoint_path: Path of the model checkpoint.
        data_dir: Directory containing the ``*.npy`` segment files.
        output_csv: Destination path of the ``(tau, sigma)`` table.
        group_size: Number of files concatenated into one long signal.
        fs: Sampling frequency in Hz.
        window_size: Window length passed to ``simple_chunk_denoise``.

    Raises:
        ValueError: If ``data_dir`` holds fewer than ``group_size`` files.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1. Load the model.
    model1 = denoise_dualmask()
    # NOTE(review): with weights_only=False torch.load performs full unpickling,
    # so the safe_globals allow-list has no effect and a malicious checkpoint
    # can execute code. Only load trusted files, or try weights_only=True
    # (which is what the allow-list is meant for).
    with torch.serialization.safe_globals([argparse.Namespace]):
        ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model1.load_state_dict(ckpt['model'])
    model1.to(device)
    model1.eval()

    # 2. Collect every .npy segment and split into full groups of `group_size`.
    npy_files = sorted(Path(data_dir).glob("*.npy"))
    num_groups = len(npy_files) // group_size
    if num_groups == 0:
        # Without at least one group the averages below would divide by zero.
        raise ValueError(
            f"found {len(npy_files)} .npy files in {str(data_dir)!r}; "
            f"need at least group_size={group_size}"
        )

    dt = 1 / fs

    # Running totals of the per-group metrics.
    ARW_denoised_total = 0.0
    BI_denoised_total = 0.0
    QN_denoised_total = 0.0
    sigma_denoised_list = []
    taus_list = []

    for i in range(num_groups):
        group_files = npy_files[i * group_size:(i + 1) * group_size]
        # Concatenate the group's segments into one long signal.
        orig_long = np.concatenate([np.load(f) for f in group_files], axis=0)
        print(f"Processing group {i + 1}/{num_groups}, signal length: {orig_long.shape[0]}")
        # The model lives on `device`; pass it explicitly, otherwise the
        # helper's 'cpu' default feeds CPU tensors to a CUDA model.
        denoised_signal = simple_chunk_denoise(
            signal=orig_long,
            model=model1,
            window_size=window_size,
            device=device,
        )
        # 180 / pi converts radians to degrees (presumably rad/s -> deg/s).
        ARW_denoised, BI_denoised, QN_denoised, taus, sigma_denoised = compute_allan_metrics(
            denoised_signal * (180 / np.pi), dt
        )
        ARW_denoised_total += ARW_denoised
        BI_denoised_total += BI_denoised
        QN_denoised_total += QN_denoised
        sigma_denoised_list.append(sigma_denoised)
        taus_list.append(taus)

    # 3. Average the metrics over groups.
    ARW_denoised_avg = ARW_denoised_total / num_groups
    BI_denoised_avg = BI_denoised_total / num_groups
    QN_denoised_avg = QN_denoised_total / num_groups

    print("Averaged Denoised Signal Allan Metrics:")
    # NOTE(review): x60 / x3600 presumably convert per-second figures to
    # per-hour units (deg/sqrt(h), deg/h) -- confirm.
    print("  ARW: {:.6f}".format(ARW_denoised_avg * 60))
    print("  BI: {:.6f}".format(BI_denoised_avg * 3600))
    print("  QN: {:.6f}".format(QN_denoised_avg))

    # Assumes every group yields the same tau grid (true when all groups have
    # the same length), so the first group's taus label the averaged curve.
    taus_plot = taus_list[0]
    sigma_denoised_avg_curve = np.mean(sigma_denoised_list, axis=0)

    # 4. Export the averaged Allan deviation curve.
    df = pd.DataFrame({"tau (s)": taus_plot, "sigma": sigma_denoised_avg_curve})
    df.to_csv(output_csv, index=False)

if __name__ == "__main__":
    # Script entry point; main() returns None, so the process exits with status 0.
    raise SystemExit(main())