
import argparse
import os

import torch
import numpy as np
import matplotlib.pyplot as plt

from ORE import overrange_mae
from DE import denoise_dualmask


# -- Hard-coded constants --
# Z-score statistics (mean, std) used to normalize the signal before it is fed
# to each expert, plus the amplitude thresholds used by the gating logic.
PEAK_MEAN = 0.01615269998526035
PEAK_STD = 4.988357229871697
PEAK_TH = 7.5  # samples with |x| above this are gated to the peak expert
DENOISE_MEAN = 0.029607261518762952
DENOISE_STD = 0.0011154172318002256
DENOISE_TH = 0.1  # samples with |x| below this count as near-zero for the noise rule


def unpatchify_1d(x, patch_size, in_chans=1):
    """Reassemble a sequence of patches into a contiguous 1-D signal.

    Args:
        x: Tensor of shape ``(B, n_patches, patch_size * in_chans)``. Within a
            patch the channel index varies fastest (the reshape below assumes
            this layout).
        patch_size: Number of time steps per patch.
        in_chans: Number of signal channels.

    Returns:
        Tensor of shape ``(B, in_chans, n_patches * patch_size)``.

    Raises:
        ValueError: If ``x`` is not 3-D or its last dimension does not equal
            ``patch_size * in_chans``.
    """
    B, n_patches, dim = x.shape
    if dim != patch_size * in_chans:
        raise ValueError(
            f"last dim {dim} != patch_size * in_chans "
            f"({patch_size} * {in_chans})"
        )
    L = n_patches * patch_size
    # (B, N, P*C) -> (B, N, P, C) -> (B, C, N, P) -> (B, C, N*P)
    x = x.reshape(B, n_patches, patch_size, in_chans)
    return x.permute(0, 3, 1, 2).reshape(B, in_chans, L)


def zscore_normalize(data_1d, mean_v, std_v):
    """Standardize ``data_1d`` as ``(data - mean) / std``.

    A small epsilon (1e-9) is added to ``std_v`` so that a zero standard
    deviation does not cause a division by zero.
    """
    centered = data_1d - mean_v
    denom = std_v + 1e-9
    return centered / denom


def zscore_denormalize(normed_1d, mean_v, std_v):
    """Undo z-score standardization: ``normed * std + mean``.

    Unlike ``zscore_normalize``, no epsilon is added to ``std_v`` here, so a
    normalize/denormalize round trip differs from the input by a relative
    error on the order of ``1e-9 / std_v``.
    """
    scaled = normed_1d * std_v
    return scaled + mean_v


def load_model(model_class, ckpt_path, device):
    """Build ``model_class()``, load weights from ``ckpt_path``, set eval mode.

    Supported checkpoint layouts (checked in this order):
      * a dict with a ``'model'`` entry holding the state dict;
      * a dict with a ``'state_dict'`` entry holding the state dict;
      * a bare state dict;
      * a pickled ``torch.nn.Module`` (its ``state_dict()`` is used).

    Args:
        model_class: Zero-argument callable returning an ``nn.Module``.
        ckpt_path: Path to a checkpoint readable by ``torch.load``.
        device: Device used as ``map_location`` and to place the model.

    Returns:
        The model on ``device`` in eval mode.

    Raises:
        TypeError: If the checkpoint is neither a dict nor an ``nn.Module``.
    """
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location=device)
    model = model_class().to(device)
    if isinstance(checkpoint, dict):
        if 'model' in checkpoint:
            sd = checkpoint['model']
        elif 'state_dict' in checkpoint:
            sd = checkpoint['state_dict']
        else:
            sd = checkpoint  # already a bare state dict
    elif isinstance(checkpoint, torch.nn.Module):
        sd = checkpoint.state_dict()
    else:
        raise TypeError(
            f"unsupported checkpoint type {type(checkpoint).__name__} "
            f"in {ckpt_path}"
        )
    model.load_state_dict(sd)
    model.eval()
    return model


def blend_boundary(fused, A, B, start, mid, end):
    """Cross-fade ``fused`` from source ``A`` into source ``B`` at index ``mid``.

    The window ``[start, end)`` is clipped to ``[0, len(fused))`` as
    ``[s, e)``.

    * On ``[s, mid)``, the weight on ``A`` falls linearly from 1 at ``s``
      toward 0 at ``mid``.
    * From ``mid`` onward, the sample is taken purely from ``B``.

    Where ``A`` or ``B`` is None, the current value ``fused[i]`` stands in
    for it.

    Args:
        fused: Mutable sequence (e.g. a float ndarray) modified in place.
        A: Source for the left side of the boundary, or None.
        B: Source for the right side of the boundary, or None.
        start: Window start (inclusive, may be negative).
        mid: Boundary index.
        end: Window end (exclusive, may exceed ``len(fused)``).

    Returns:
        ``fused`` (the same object, after modification).
    """
    n = len(fused)
    s, e = max(0, start), min(n, end)
    left_span = (mid - s) + 1e-9  # epsilon guards against mid == s
    for i in range(s, e):
        val_b = B[i] if B is not None else fused[i]
        if i >= mid:
            # Past the boundary, the right-hand source owns the sample.
            # Ramping back toward A here would produce an A->B->A dip
            # instead of a cross-fade.
            fused[i] = val_b
            continue
        alpha = 1.0 - (i - s) / left_span
        val_a = A[i] if A is not None else fused[i]
        fused[i] = alpha * val_a + (1 - alpha) * val_b
    return fused


def make_noise_mask_30pts_rule(x):
    """Mark samples that sit in a long run of near-zero values.

    A sample is near-zero when ``|x| < DENOISE_TH``. Every maximal run of
    consecutive near-zero samples whose length is at least 30 is marked
    True; shorter runs and all other samples stay False.

    Returns:
        Boolean ndarray of length ``len(x)``.
    """
    below = np.abs(x) < DENOISE_TH
    mask = np.zeros(len(x), bool)
    # Pad with 0 on both ends so every run has a rising (+1) and falling (-1) edge.
    edges = np.diff(np.concatenate(([0], below.astype(np.int8), [0])))
    run_starts = np.flatnonzero(edges == 1)
    run_stops = np.flatnonzero(edges == -1)  # exclusive end of each run
    for lo, hi in zip(run_starts, run_stops):
        if hi - lo >= 30:
            mask[lo:hi] = True
    return mask


def moe_gate_improved(x):
    """Assign an expert id to every sample of ``x``.

    Returns:
        Integer ndarray of length ``len(x)`` with:

        * 0 (peak) where ``|x| > PEAK_TH``;
        * 1 (noise) where ``make_noise_mask_30pts_rule`` flags the sample,
          evaluated over the non-peak samples only;
        * 2 (smooth) everywhere else.
    """
    gates = np.full(len(x), 2, int)  # default: smooth
    is_peak = np.abs(x) > PEAK_TH
    gates[is_peak] = 0
    # NOTE(review): the noise rule runs on x with the peak samples removed.
    # Two sub-threshold runs separated only by peak samples are therefore
    # concatenated and can jointly reach the run-length minimum, whereas on
    # the raw signal the peak samples would break the run. Confirm this
    # merging is intended.
    candidate_idx = np.flatnonzero(~is_peak)
    noise_local = make_noise_mask_30pts_rule(x[candidate_idx])
    gates[candidate_idx[noise_local]] = 1
    return gates


def fuse_with_boundary_blend(x, peak_mask, noise_mask,
                             peak_pred, noise_pred, smooth_pred,
                             boundary_len=2):
    """Stitch the expert predictions together run by run.

    The signal is walked left to right in maximal runs belonging to one
    expert. The run's expert is chosen at its first sample: peak if
    ``peak_mask`` is set, else noise if ``noise_mask`` is set, else smooth.
    The run extends while that expert's own mask stays set.

    Each run is filled from its expert's prediction. If the peak or noise
    prediction is None, ``smooth_pred`` is used instead. At every transition
    index ``j``, ``blend_boundary`` is called on the window
    ``[j - boundary_len, j + boundary_len)`` with mid ``j``. It receives the
    ending run's prediction as ``A`` and the next run's prediction as ``B``
    (either may be None).

    Args:
        x: Input signal; only its length is used.
        peak_mask: Boolean array marking peak samples.
        noise_mask: Boolean array marking noise samples.
        peak_pred: Peak expert output, or None if that expert did not run.
        noise_pred: Denoise expert output, or None if that expert did not run.
        smooth_pred: Output for smooth samples and fallback source.
        boundary_len: Half-width of the blend window at each transition.

    Returns:
        float ndarray of length ``len(x)``.
    """
    total = len(x)
    out = np.zeros(total, float)
    smooth_mask = ~(peak_mask | noise_mask)

    # Expert ids index these tuples: 0 = peak, 1 = noise, 2 = smooth.
    masks = (peak_mask, noise_mask, smooth_mask)
    preds = (peak_pred, noise_pred, smooth_pred)

    def expert_at(k):
        # Peak wins over noise; anything in neither mask is smooth.
        if peak_mask[k]:
            return 0
        return 1 if noise_mask[k] else 2

    start = 0
    while start < total:
        cur = expert_at(start)
        run_mask = masks[cur]
        stop = start
        while stop < total and run_mask[stop]:
            stop += 1
        source = preds[cur] if preds[cur] is not None else smooth_pred
        out[start:stop] = source[start:stop]
        if stop < total:
            nxt = expert_at(stop)
            out = blend_boundary(out, preds[cur], preds[nxt],
                                 stop - boundary_len, stop, stop + boundary_len)
        start = stop

    return out


def inference_moe(x_raw, model_peak, model_noise, device, boundary_len=2):
    """Run the peak / denoise experts on ``x_raw`` and fuse their outputs.

    Pre-gating decides which experts run at all:

    * The peak expert runs only if some sample has ``|x| > PEAK_TH``. This is
      the same test the per-sample gate uses, so the expert runs exactly
      when at least one sample is gated to it.
    * The denoise expert runs only if more than half of the samples satisfy
      ``|x| < DENOISE_TH``.

    Post-gating (``moe_gate_improved``) then labels each sample, and
    ``fuse_with_boundary_blend`` combines the predictions. ``x_raw`` itself
    is used for smooth samples and as the fallback when an expert did not run.

    Args:
        x_raw: Input signal as a numpy array.
        model_peak: Peak expert, called as ``model_peak(t)``. Its second output
            is unpatchified with ``patch_size=2``.
        model_noise: Denoise expert, called as ``model_noise(t, t)``. Its second
            output is used directly.
        device: Torch device on which the input tensors are created.
        boundary_len: Half-width of the blend window at segment transitions.

    Returns:
        Fused numpy array from ``fuse_with_boundary_blend``.
    """
    # Pre-gating.
    x_abs = np.abs(x_raw)
    # Use ``>`` to match moe_gate_improved. With ``>=``, a sample exactly at
    # PEAK_TH would run the peak expert without ever being gated to it.
    peak_needed = bool(np.any(x_abs > PEAK_TH))
    noise_needed = np.sum(x_abs < DENOISE_TH) > len(x_raw) * 0.5

    # Normalize and run the experts.
    pred_peak, pred_noise = None, None
    # Module parameters require grad by default (load_model does not disable
    # it). Without no_grad() the outputs would presumably carry autograd
    # history, and ``.numpy()`` on such a tensor raises. no_grad() also
    # avoids building an unused graph.
    with torch.no_grad():
        if peak_needed:
            x_p = zscore_normalize(x_raw, PEAK_MEAN, PEAK_STD)
            t = torch.as_tensor(x_p, dtype=torch.float32, device=device).reshape(1, 1, -1)
            _, y_p, *_ = model_peak(t)
            y_p = unpatchify_1d(y_p, patch_size=2).squeeze().cpu().numpy()
            pred_peak = zscore_denormalize(y_p, PEAK_MEAN, PEAK_STD)
        if noise_needed:
            x_n = zscore_normalize(x_raw, DENOISE_MEAN, DENOISE_STD)
            t = torch.as_tensor(x_n, dtype=torch.float32, device=device).reshape(1, 1, -1)
            _, y_n, *_ = model_noise(t, t)
            y_n = y_n.squeeze().cpu().numpy()
            pred_noise = zscore_denormalize(y_n, DENOISE_MEAN, DENOISE_STD)

    # Post-gating and fusion.
    gates = moe_gate_improved(x_raw)
    peak_mask = (gates == 0)
    noise_mask = (gates == 1)
    return fuse_with_boundary_blend(
        x_raw, peak_mask, noise_mask,
        pred_peak, pred_noise, x_raw,
        boundary_len=boundary_len,
    )


def parse_args():
    """Parse the command-line options for fused MOE-Gyro inference.

    Returns:
        argparse.Namespace with the required string options ``peak_ckpt``,
        ``denoise_ckpt``, ``input`` and ``output``, plus ``boundary_len``
        (int, default 2).
    """
    parser = argparse.ArgumentParser(
        description="MOE-Gyro Inference: fuse Peak & Denoise experts")
    required_paths = (
        ("-p", "--peak_ckpt", "Peak expert checkpoint (.pth)"),
        ("-d", "--denoise_ckpt", "Denoise expert checkpoint (.pth)"),
        ("-i", "--input", "Input .npy file (256-length)"),
        ("-o", "--output", "Output fused .npy path"),
    )
    for short_flag, long_flag, help_text in required_paths:
        parser.add_argument(short_flag, long_flag, type=str, required=True,
                            help=help_text)
    parser.add_argument("-b", "--boundary_len", type=int, default=2,
                        help="Boundary blend length (default: 2)")
    return parser.parse_args()


def main():
    """Command-line entry point.

    Loads both experts, runs fused inference on the input ``.npy`` file,
    saves the result and plots it against the input.
    """
    args = parse_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the two expert models.
    model_peak = load_model(overrange_mae, args.peak_ckpt, device)
    model_noise = load_model(denoise_dualmask, args.denoise_ckpt, device)

    # Read the input signal. This script expects a 1-D, 256-sample array;
    # other shapes only trigger a warning.
    x = np.load(args.input)
    if x.ndim != 1 or x.shape[0] != 256:
        print(f"[WARN] input shape {x.shape}, expected (256,)")

    # Inference.
    fused = inference_moe(x, model_peak, model_noise, device, args.boundary_len)

    # Save. os.path.dirname() returns "" for a bare filename, and
    # os.makedirs("") raises, so only create a directory when there is one.
    out_dir = os.path.dirname(args.output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    np.save(args.output, fused)
    print(f"Fused result saved to {args.output}")

    # Visualization: overlay the raw input and the fused output.
    plt.figure(figsize=(8, 4))
    plt.plot(x, label="CLIP")
    plt.plot(fused, label="Fused(MOE)")
    plt.legend()
    plt.grid(True)
    plt.show()


if __name__ == "__main__":
    main()

