import argparse

from pathlib import Path
import numpy as np
import torch
from torch.utils.data import DataLoader
from scipy.stats import pearsonr

import ORE

# NOTE(review): `mean` and `std` are not referenced anywhere in this script; they look
# like dataset normalization statistics (presumably from training) — confirm before use.
mean = 0.01615269998526035
std = 4.988357229871697

# Amplitude threshold. In main(), time steps whose ground-truth |x| exceeds it form the
# evaluation mask, and per-sample max |x| values above it are averaged for the PSNR peak.
PEAK_TH = 1.505
def unpatchify_1d(x, patch_size, in_chans=1):
    """Reassemble a patchified 1-D sequence into ``[B, C, L]`` form.

    Each patch vector holds ``patch_size`` consecutive time steps. The
    ``in_chans`` channel values of a single time step are stored next to each
    other (channels-last inside a patch).

    Args:
        x: Tensor of shape ``[B, n_patches, patch_size * in_chans]``.
        patch_size: Number of time steps per patch.
        in_chans: Number of channels ``C``.

    Returns:
        Tensor of shape ``[B, in_chans, n_patches * patch_size]``.

    Raises:
        ValueError: If the last dimension of ``x`` is not ``patch_size * in_chans``.
    """
    batch, n_patches, dim = x.shape
    # Explicit check instead of `assert`, which disappears under `python -O`.
    if dim != patch_size * in_chans:
        raise ValueError(f"Expected dim={patch_size * in_chans}, but got {dim}")
    length = n_patches * patch_size
    x = x.reshape(batch, n_patches, patch_size, in_chans)
    # [B, n, p, C] -> [B, C, n, p] -> [B, C, n*p]; reshape copies if permute made x non-contiguous.
    return x.permute(0, 3, 1, 2).reshape(batch, in_chans, length)


class TimeSeriesDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves one ``.npy`` file per sample.

    Every ``*.npy`` file directly under ``root`` is one sample, in sorted path
    order (main() pairs two such folders index-by-index).

    Each item is ``(tensor, 0)``. ``tensor`` is the loaded array cast to
    float32 with a new leading axis (a 1-D file of length L becomes ``[1, L]``),
    optionally passed through ``transform``. The ``0`` is a placeholder label.
    """

    def __init__(self, root, transform=None):
        self.root = Path(root)
        # Sorted so that index i refers to a stable file across runs.
        self.files = sorted(self.root.glob("*.npy"))
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        # Each file is expected to hold an array of shape (512,); this is not validated here.
        array = np.load(self.files[index])
        sample = torch.tensor(array, dtype=torch.float32).unsqueeze(dim=0)
        sample = self.transform(sample) if self.transform else sample
        return sample, 0


def parse_args(argv=None):
    """Parse command-line options for the evaluation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, matching the previous behaviour.

    Returns:
        argparse.Namespace with ``checkpoint``, ``data_path``, ``data2_path``,
        ``batch_size``, ``num_workers``, ``model`` and ``seq_len``.
    """
    parser = argparse.ArgumentParser(description="Test MAE for 1D Time Series with Two Data Sources")
    parser.add_argument('--checkpoint', type=str, required=True,
                        help='Path to checkpoint file')
    # An empty path resolves to the current working directory (Path("") == Path(".")).
    parser.add_argument('--data_path', type=str, default="",
                        help='Path to raw test dataset (folder containing .npy files) used as ground truth')
    parser.add_argument('--data2_path', type=str, default="",
                        help='Path to prediction dataset (folder containing .npy files) used as model input')
    parser.add_argument('--batch_size', type=int, default=64, help='Batch size for testing')
    parser.add_argument('--num_workers', type=int, default=4, help='Number of DataLoader workers')
    # Required: the old default '' can never name a constructor in ORE, so omitting it
    # only surfaced later as an opaque KeyError in main().
    parser.add_argument('--model', type=str, required=True,
                        help='Name of the model constructor in the ORE module to instantiate')
    # NOTE(review): seq_len is parsed but not read by main() in this file.
    parser.add_argument('--seq_len', type=int, default=256, help='Input sequence length')
    return parser.parse_args(argv)




def _peak_positions(raw_flat, valid_flat):
    """Return a boolean ``[B, L]`` mask of interior turning points inside the valid region.

    Position t (1 <= t <= L-2) is flagged when the sign of the first difference
    changes between steps (t-1 -> t) and (t -> t+1). Because ``np.sign(0) == 0``,
    entering or leaving a flat segment also counts. Positions where
    ``valid_flat`` is False are cleared.
    """
    slope_sign = np.sign(np.diff(raw_flat, axis=1))  # [B, L-1]
    peak = np.zeros(raw_flat.shape, dtype=bool)
    peak[:, 1:-1] = slope_sign[:, 1:] != slope_sign[:, :-1]
    return peak & valid_flat


def main():
    """Evaluate a model on two paired ``.npy`` folders and print masked metrics.

    ``--data2_path`` samples are fed to the model. ``--data_path`` samples are the
    ground truth. Samples are paired by index in sorted filename order.

    All metrics are restricted to time steps where ``|ground truth| > PEAK_TH``:
      * Global MSE over all such time steps.
      * PSNR, using (mean per-sample max|x| above PEAK_TH) - PEAK_TH as the peak value.
      * Mean per-sample Pearson correlation (samples with < 2 valid points are skipped).
      * Peak-MSE over turning points of the ground truth inside the mask.

    Raises:
        ValueError: If the two folders contain a different number of ``.npy`` files.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ---------- 1. Load model ----------
    model = ORE.__dict__[args.model]()
    # Allow-list argparse.Namespace for torch.load's restricted unpickler; the checkpoint
    # presumably stores the training args alongside the weights.
    with torch.serialization.safe_globals([argparse.Namespace]):
        ckpt = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(ckpt.get('model', ckpt))
    model.to(device).eval()

    # ---------- 2. Data ----------
    raw_set = TimeSeriesDataset(args.data_path)
    pred_set = TimeSeriesDataset(args.data2_path)
    # zip() below would silently truncate and misalign the pairing if the counts differ.
    if len(raw_set) != len(pred_set):
        raise ValueError(
            f"Sample count mismatch: {len(raw_set)} .npy files in {args.data_path!r} vs "
            f"{len(pred_set)} in {args.data2_path!r}; samples are paired by sorted "
            "filename order, so both folders must contain the same number of files."
        )
    loader_kwargs = dict(batch_size=args.batch_size, shuffle=False,
                         num_workers=args.num_workers,
                         pin_memory=device.type == "cuda")  # pinning is useless (and warns) on CPU
    raw_loader = DataLoader(raw_set, **loader_kwargs)
    pred_loader = DataLoader(pred_set, **loader_kwargs)

    # ---------- 3. Accumulators ----------
    all_peaks = []                          # per-sample max|x| values above PEAK_TH
    total_err, total_valid = 0.0, 0         # masked MSE
    total_corr, num_corr = 0.0, 0           # Pearson
    total_peak_err, peak_cnt = 0.0, 0       # Peak-MSE

    # ---------- 4. Evaluation loop ----------
    with torch.no_grad():
        for (raw_batch, _), (pred_batch, _) in zip(raw_loader, pred_loader):
            raw_data = raw_batch.to(device)  # [B, 1, L]
            pred_in = pred_batch.to(device)

            # Assumes the model returns a 5-tuple whose 2nd item is the patchified reconstruction.
            _, pred_full, _, _, _ = model(pred_in)
            pred_full = unpatchify_1d(pred_full, patch_size=2, in_chans=1)

            valid_mask = torch.abs(raw_data) > PEAK_TH  # [B, 1, L]

            sq_err = (pred_full - raw_data) ** 2
            total_err += sq_err[valid_mask].sum().item()
            total_valid += valid_mask.sum().item()

            raw_flat = raw_data.squeeze(1).cpu().numpy()
            pred_flat = pred_full.squeeze(1).cpu().numpy()
            mask_flat = valid_mask.squeeze(1).cpu().numpy()

            # Peak statistics for PSNR, gathered here instead of in a separate data pass.
            sample_peaks = np.max(np.abs(raw_flat), axis=1)
            all_peaks.extend(sample_peaks[sample_peaks > PEAK_TH].tolist())

            # Peak-MSE: compare prediction and ground truth at the SAME turning points.
            peak_mask = _peak_positions(raw_flat, mask_flat)
            total_peak_err += float(((pred_flat[peak_mask] - raw_flat[peak_mask]) ** 2).sum())
            peak_cnt += int(peak_mask.sum())

            # Per-sample Pearson correlation on the masked points.
            for true_row, pred_row, row_mask in zip(raw_flat, pred_flat, mask_flat):
                if row_mask.sum() < 2:
                    continue
                corr, _ = pearsonr(true_row[row_mask], pred_row[row_mask])
                # Constant inputs yield NaN; leave them out of the average.
                if not np.isnan(corr):
                    total_corr += corr
                    num_corr += 1

    # ---------- 5. Global results ----------
    # With no sample above PEAK_TH the peak value is undefined; report NaN explicitly
    # instead of letting np.mean([]) warn.
    peak_const = (float(np.mean(all_peaks)) - PEAK_TH) if all_peaks else float("nan")
    global_mse = total_err / total_valid if total_valid else 0.0
    global_psnr = 10 * np.log10((peak_const ** 2) / global_mse) if global_mse > 0 else 0.0
    avg_corr = total_corr / num_corr if num_corr else 0.0
    peak_mse = total_peak_err / peak_cnt if peak_cnt else 0.0

    print(f"Fixed Peak Value       : {peak_const:.4f}")
    print(f"Global MSE   (masked)  : {global_mse:.6f}")
    print(f"PSNR  (masked)         : {global_psnr:.2f} dB")
    print(f"Average Corr (masked)  : {avg_corr:.4f}")
    print(f"Peak‑MSE     (masked)  : {peak_mse:.6f}")


if __name__ == "__main__":
    main()

