#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import json
from pathlib import Path
import argparse
import numpy as np
from tqdm import trange
from scipy.stats import norm
import matplotlib.pyplot as plt


def load_jsonl(file_path: Path, dataset: str):
    """Read ``file_path`` as JSON Lines and return the rows that carry a judge label.

    Lines that are not valid JSON are ignored, and so are records whose
    ``"judge1"`` value is missing or ``None``. ``dataset`` is not used here.

    Returns:
        list[dict]: the surviving records, in file order.

    Raises:
        ValueError: when no record survives the filtering.
    """

    def parsed_records(lines):
        # Yield each line decoded as JSON, silently dropping undecodable ones.
        for raw in lines:
            try:
                record = json.loads(raw)
            except json.JSONDecodeError:
                continue
            yield record

    with file_path.open("r", encoding="utf-8") as handle:
        rows = [
            record
            for record in parsed_records(handle)
            if record.get("judge1") is not None
        ]

    if not rows:
        raise ValueError(f"No valid rows loaded from {file_path}")
    return rows


def compute_S_and_Sprime(data, dataset: str):
    """Extract the ground-truth and judge labels as two parallel int8 arrays.

    Returns:
        (S, S_prime) where ``S[i] == int(data[i]["gt"])`` and
        ``S_prime[i] == int(data[i]["judge1"])``; both are ``np.int8``.
    ``dataset`` is not used here.
    """
    # One pass over the rows, reading "gt" then "judge1" for each row.
    label_pairs = np.array(
        [(int(row["gt"]), int(row["judge1"])) for row in data],
        dtype=np.int8,
    ).reshape(-1, 2)
    return label_pairs[:, 0].copy(), label_pairs[:, 1].copy()


def run_once(
    rng: np.random.Generator,
    d_list,
    dataset: str,
    m_dir: int,
    m_cal: int,
    n_test: int,
    alpha_grid: np.ndarray,
    z_zeta: float,
    z_ppi: float,
    tpr_true: float,
    fpr_true: float,
):
    """Run one Monte-Carlo replicate of six one-sided tests over ``alpha_grid``.

    Three index sets are drawn from ``d_list`` (each without replacement, but
    independently of one another, so they may overlap): a calibration set of
    size ``m_cal`` (ground-truth and judge labels are used), a direct set of
    size ``m_dir`` (ground truth only) and a test set of size ``n_test``
    (judge labels only). For every ``alpha`` in ``alpha_grid`` each test
    rejects when its statistic is at or below a threshold built from
    ``z_zeta``.

    Args:
        rng: source of randomness for the draws and the Ridge-PPI fold split.
        d_list: rows as returned by ``load_jsonl``.
        dataset: forwarded to ``compute_S_and_Sprime``.
        m_dir, m_cal, n_test: sample sizes described above.
        alpha_grid: thresholds at which every test is evaluated.
        z_zeta: normal quantile used by all six tests.
        z_ppi: not used; kept for interface compatibility.
        tpr_true, fpr_true: population judge rates used by the oracle test.

    Returns:
        ``None`` when the calibration draw is degenerate (one ground-truth
        class is missing, or the judge label is constant on it). Otherwise a
        6-tuple of boolean arrays shaped like ``alpha_grid`` with the
        rejection masks of Direct HT, Noisy HT (estimated rates), Noisy HT
        (oracle rates), PPI, PPI++ and Ridge PPI, in that order.
    """
    n = len(d_list)
    idx_cal = rng.choice(n, size=m_cal, replace=False)
    idx_dir = rng.choice(n, size=m_dir, replace=False)
    idx_tst = rng.choice(n, size=n_test, replace=False)

    S_cal, Sprime_cal = compute_S_and_Sprime([d_list[i] for i in idx_cal], dataset)
    S_dir, _ = compute_S_and_Sprime([d_list[i] for i in idx_dir], dataset)
    _, Sprime_tst = compute_S_and_Sprime([d_list[i] for i in idx_tst], dataset)

    m0 = int((S_cal == 0).sum())
    m1 = int((S_cal == 1).sum())
    if m0 == 0 or m1 == 0:
        return None

    hat_fpr = ((S_cal == 0) & (Sprime_cal == 1)).sum() / m0
    hat_tpr = ((S_cal == 1) & (Sprime_cal == 1)).sum() / m1
    # Both rates 0 (judge always 0) or both rates 1 (judge always 1) make the
    # plug-in variance of the noisy test exactly zero (Z would be +/-inf or
    # nan) and can also zero the PPI++ denominator; skip such draws.
    if (hat_fpr == 0 and hat_tpr == 0) or (hat_fpr == 1 and hat_tpr == 1):
        return None

    gamma_hat = hat_tpr - hat_fpr
    sign_hat = 1.0 if gamma_hat >= 0 else -1.0

    R = S_dir.mean()
    Rprime = Sprime_tst.mean()

    # -------------------------------
    # Direct HT: ground-truth labels of the direct sample only.
    # -------------------------------
    C_dir = alpha_grid + z_zeta * np.sqrt(alpha_grid * (1.0 - alpha_grid) / m_dir)
    rej_direct_mask = R <= C_dir

    # -------------------------------
    # Noisy HT (estimated rates). gamma * alpha + fpr is the expected
    # judge-positive rate when the true rate equals alpha; the variance adds
    # test-sample noise and the uncertainty of both calibrated rates. The
    # sign flips the test direction when the judge is anti-correlated.
    # -------------------------------
    Rprime_hat = gamma_hat * alpha_grid + hat_fpr
    var_D_alpha = (
        Rprime_hat * (1.0 - Rprime_hat) / n_test
        + (1.0 - alpha_grid) ** 2 * hat_fpr * (1.0 - hat_fpr) / m0
        + alpha_grid**2 * hat_tpr * (1.0 - hat_tpr) / m1
    )
    std_D_hat = np.sqrt(var_D_alpha)
    Z = sign_hat * (Rprime - Rprime_hat) / std_D_hat
    rej_est_q_mask = Z <= z_zeta

    # -------------------------------
    # Noisy HT (oracle rates): same statistic with the population rates, so
    # only the test-sample variance remains.
    # -------------------------------
    gamma_true = tpr_true - fpr_true
    sign_true = 1.0 if gamma_true >= 0 else -1.0
    Rprime_true = gamma_true * alpha_grid + fpr_true
    var_true = Rprime_true * (1.0 - Rprime_true) / n_test
    std_true = np.sqrt(var_true)
    Z_true = sign_true * (Rprime - Rprime_true) / std_true
    rej_known_q_mask = Z_true <= z_zeta

    # -------------------------------
    # Statistics shared by the PPI family.
    # -------------------------------
    hat_R_M = S_cal.mean()
    hat_R_JP = Sprime_cal.mean()
    hat_R_J = Sprime_tst.mean()
    hat_R_11 = np.mean((S_cal == 1) & (Sprime_cal == 1))
    hat_C = hat_R_11 - hat_R_M * hat_R_JP  # empirical covariance of S and S' on cal

    # PPI (lambda fixed to 1).
    hat_R = hat_R_M + (hat_R_J - hat_R_JP)
    hat_V = (
        hat_R_M * (1 - hat_R_M) / m_cal
        + hat_R_J * (1 - hat_R_J) / n_test
        + hat_R_JP * (1 - hat_R_JP) / m_cal
        - 2 * hat_C / m_cal
    )
    hat_SE = np.sqrt(np.clip(hat_V, 0.0, None))
    rej_ppi_mask = hat_R <= alpha_grid + z_zeta * hat_SE

    # Variance of the lambda-weighted estimator:
    #   V(lam) = R_M (1 - R_M) / m_cal + lam^2 * hat_A - 2 * lam * hat_B
    var_J_term = hat_R_J * (1 - hat_R_J) / n_test
    hat_A = var_J_term + hat_R_JP * (1 - hat_R_JP) / m_cal
    hat_B = hat_C / m_cal

    def lambda_test(lam):
        """Rejection mask of the estimator R_M + lam * (R_J - R_JP)."""
        est = hat_R_M + lam * (hat_R_J - hat_R_JP)
        var = hat_R_M * (1 - hat_R_M) / m_cal + lam ** 2 * hat_A - 2 * lam * hat_B
        se = np.sqrt(np.clip(var, 0.0, None))
        return est <= alpha_grid + z_zeta * se

    # PPI++: variance-minimising lambda, no regularisation.
    rej_ppip_mask = lambda_test(hat_B / hat_A)

    # -------------------------------
    # Ridge PPI: lambda = hat_B / (hat_A + tau), with tau picked by 2-fold
    # cross-validation of the estimated variance on the calibration set.
    # -------------------------------
    perm = rng.permutation(m_cal)
    half = m_cal // 2
    A_idx, B_idx = perm[:half], perm[half:]

    def fold_stats(sel):
        """(size, mean of S, A term, B term) for a subset of the calibration set."""
        S_f = S_cal[sel]
        Sp_f = Sprime_cal[sel]
        m_f = len(sel)
        R_M_f = S_f.mean()
        R_JP_f = Sp_f.mean()
        R_11_f = np.mean((S_f == 1) & (Sp_f == 1))
        A_f = var_J_term + R_JP_f * (1 - R_JP_f) / m_f
        B_f = (R_11_f - R_M_f * R_JP_f) / m_f
        return m_f, R_M_f, A_f, B_f

    def cv_variance(tau, est_fold, val_fold):
        """Variance on ``val_fold`` of the lambda fitted on ``est_fold``, floored at 0."""
        _, _, A_e, B_e = est_fold
        lam = B_e / (A_e + tau)
        m_v, R_M_v, A_v, B_v = val_fold
        V = R_M_v * (1 - R_M_v) / m_v + lam ** 2 * A_v - 2 * lam * B_v
        return max(V, 0.0)

    A_fold = fold_stats(A_idx)
    B_fold = fold_stats(B_idx)

    def cv_score(tau):
        # Average over both directions (fit on A / validate on B, and vice versa).
        return 0.5 * (cv_variance(tau, A_fold, B_fold) + cv_variance(tau, B_fold, A_fold))

    taus = np.linspace(0.001, 0.1, 100)
    # min() keeps the first minimiser, like a strict "<" scan over taus.
    best_tau = min(taus, key=cv_score)
    rej_ppir_mask = lambda_test(hat_B / (hat_A + best_tau))

    return rej_direct_mask, rej_est_q_mask, rej_known_q_mask, rej_ppi_mask, rej_ppip_mask, rej_ppir_mask


def main():
    """Command-line entry point: simulate every test and plot its error curve.

    Loads ``./data/<dataset>/<j>.jsonl`` and computes the population judge rates
    and the mean of ``S`` (``R_M``) on the full file. It then collects ``B``
    non-degenerate replicates of ``run_once`` and saves a Type-I / Type-II
    error plot to ``./figs/<dataset>/<j>_<m>.png``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="SAFE")
    parser.add_argument("--m", type=int, help="m_dir (= m_cal)", default=100)
    parser.add_argument("--j", type=str, default="q")
    parser.add_argument("--B", type=int, default=1000)
    parser.add_argument("--n_test", type=int, default=5000)
    parser.add_argument("--zeta", type=float, default=0.05)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--alpha_min", type=float, default=0.01)
    parser.add_argument("--alpha_max", type=float, default=0.99)
    parser.add_argument("--alpha_n", type=int, default=1000)
    parser.add_argument(
        "--max_attempts",
        type=int,
        default=None,
        help="cap on run_once calls (default: 100 * B); stops an endless loop "
        "when almost every calibration draw is degenerate",
    )
    args = parser.parse_args()
    if args.B < 1:
        # B_eff is the divisor of every rejection rate below.
        parser.error("--B must be a positive integer")
    max_attempts = args.max_attempts if args.max_attempts is not None else 100 * args.B

    rng = np.random.default_rng(args.seed)

    file_path = Path(f"./data/{args.dataset}/{args.j}.jsonl")
    data = load_jsonl(file_path, args.dataset)

    # Population quantities computed on the full file.
    S_d, Sprime_d = compute_S_and_Sprime(data, args.dataset)
    d0 = int((S_d == 0).sum())
    d1 = len(S_d) - d0
    fpr_true = ((S_d == 0) & (Sprime_d == 1)).sum() / max(d0, 1)
    tpr_true = ((S_d == 1) & (Sprime_d == 1)).sum() / max(d1, 1)
    R_M = float(S_d.mean())

    alpha_grid = np.linspace(args.alpha_min, args.alpha_max, args.alpha_n)
    z_zeta = norm.ppf(args.zeta)
    z_ppi = norm.ppf(1 - args.zeta)

    # (label, line style) per test; order matches the tuple returned by run_once.
    series_spec = [
        ("Direct HT", dict()),
        ("Noisy HT (general)", dict()),
        ("Noisy HT (oracle)", dict(ls="--")),
        ("PPI", dict()),
        ("PPI++", dict()),
        ("Ridge PPI", dict()),
    ]
    rej_counts = np.zeros((len(series_spec), alpha_grid.size), dtype=np.int32)

    B_eff = 0
    attempts = 0
    with trange(args.B, desc=f"{args.dataset} m={args.m}") as pbar:
        while B_eff < args.B:
            if attempts >= max_attempts:
                raise RuntimeError(
                    f"Only {B_eff}/{args.B} valid replicates after {attempts} attempts; "
                    "calibration draws are almost always degenerate (try a larger --m)."
                )
            attempts += 1
            out = run_once(
                rng=rng,
                d_list=data,
                dataset=args.dataset,
                m_dir=args.m,
                m_cal=args.m,  # by design m_cal = m_dir
                n_test=args.n_test,
                alpha_grid=alpha_grid,
                z_zeta=z_zeta,
                z_ppi=z_ppi,
                tpr_true=tpr_true,
                fpr_true=fpr_true,
            )
            if out is None:
                pbar.set_postfix(skipped=attempts - B_eff, refresh=False)
                continue
            rej_counts += np.asarray(out, dtype=np.int32)
            B_eff += 1
            # Advance only on accepted replicates so the bar ends exactly at B.
            pbar.update(1)

    probs = rej_counts / B_eff

    # Left of R_M: rejection rate is the Type-I error; right of it: 1 - rejection
    # rate is the Type-II error.
    left = alpha_grid <= R_M
    right = ~left

    def to_error(prob):
        """Map a rejection-probability curve to the matching error-probability curve."""
        err = np.empty_like(prob)
        err[left] = prob[left]            # Type-I
        err[right] = 1.0 - prob[right]    # Type-II
        return err

    Path(f"./figs/{args.dataset}/").mkdir(parents=True, exist_ok=True)
    plt.figure(figsize=(6, 5))

    for (label, style), prob in zip(series_spec, probs):
        # Invisible line used only to take the next colour from the cycle.
        (line,) = plt.plot(alpha_grid, prob, alpha=0)
        color = line.get_color()
        ls = style.get("ls", "-")
        err = to_error(prob)
        plt.plot(alpha_grid[left], err[left], lw=1.2, color=color, ls=ls, label=label)
        plt.plot(alpha_grid[right], err[right], lw=1.2, color=color, ls=ls)

    plt.axvline(R_M, color="black", ls="--", lw=1, label=fr"$R_M$ = {R_M:.2f}")
    plt.axhline(args.zeta, color="grey", ls=":", lw=1, label=fr"$\zeta$ = {args.zeta:.2f}")
    y_txt = 0.9
    plt.text(R_M / 2, y_txt, "Type I error", ha="center", va="center")
    plt.text((1 + R_M) / 2, y_txt, "Type II error", ha="center", va="center")

    plt.xlabel(r"Threshold $\alpha$")
    plt.ylabel(r"$P_e^{(I)} / P_e^{(II)}$")
    plt.title(
        fr"$n_M$ = {args.m}, $n_J$ = {args.n_test}, "
        fr"$TPR$ = {tpr_true:.3f}, $FPR$ = {fpr_true:.3f}"
    )
    plt.grid(alpha=0.3)
    # Legend is not drawn; uncomment to show the series labels set above.
    # plt.legend()
    plt.tight_layout()
    out_path = f"./figs/{args.dataset}/{args.j}_{args.m}.png"
    plt.savefig(out_path, dpi=300)
    plt.close()
    print(f"[Saved] {out_path}")


# Run the simulation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()