import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
from tqdm import tqdm, trange

# ---------------------------------------------------------------------
# Simulation parameters
# ---------------------------------------------------------------------
B = 1000  # Monte Carlo replicates per alpha value
m_dir = 100  # size of the independent gold-label sample used by Direct HT
m_cal = m_dir  # size of the paired (S, S') calibration sample
n_test = 10000  # size of the noisy-label test sample
p_true = 0.25  # true P(S = 1)
tpr_true = 0.75  # true P(S' = 1 | S = 1)
fpr_true = 0.25  # true P(S' = 1 | S = 0)
zeta = 0.05  # nominal level of the one-sided tests
z_zeta = norm.ppf(zeta)  # lower-tail critical value (negative for zeta < 0.5)
z_ppi = norm.ppf(1 - zeta / 2)  # two-sided critical value; currently unreferenced in this script
alpha_grid = np.linspace(0.01, 0.49, 1000)  # null thresholds alpha to sweep

# Rejection probabilities for each method: one entry per alpha in alpha_grid
prob_direct_list = []
prob_est_q_list = []
prob_known_q_list = []
prob_ppi = []
prob_ppip = []
prob_ppir = []


# ---------------------------------------------------------------------
# Helper: generate S′ from S through an asymmetric noisy channel
# ---------------------------------------------------------------------
def noisy_channel(S, tpr, fpr, rng=None):
    """
    Apply an asymmetric noisy channel to binary labels.

      - If S = 1, keep it as 1 with prob tpr, flip to 0 with prob 1-tpr.
      - If S = 0, flip to 1 with prob fpr, keep as 0 with prob 1-fpr.

    Any entry of S that is not equal to 1 is treated as a 0 (negative).

    Parameters
    ----------
    S : numpy array of 0/1 labels.
    tpr : P(S' = 1 | S = 1).
    fpr : P(S' = 1 | S = 0).
    rng : optional numpy.random.Generator. When None (the default) the
        uniforms come from the legacy global ``np.random`` state, exactly as
        before; pass a Generator to make draws reproducible from one seed.

    Returns
    -------
    Array with the same shape and dtype as S holding the noisy labels S'.
    """
    # np.random.random(k) draws from the same global stream as np.random.rand(k).
    uniform = np.random.random if rng is None else rng.random

    S_noisy = np.empty_like(S)
    mask_pos = S == 1
    mask_neg = ~mask_pos
    S_noisy[mask_pos] = (uniform(mask_pos.sum()) < tpr).astype(int)
    S_noisy[mask_neg] = (uniform(mask_neg.sum()) < fpr).astype(int)

    return S_noisy


# ---------------------------------------------------------------------
# Loop over alpha grid
# ---------------------------------------------------------------------
rng = np.random.default_rng()

# Candidate ridge penalties for Ridge PPI. The grid depends on neither alpha
# nor the simulated data, so it is built once instead of per replicate.
taus = np.linspace(0.001, 0.1, 100)

# Sign of the true channel gain TPR - FPR. Under the channel model the
# noisy-label mean is FPR + (TPR - FPR) * P(S = 1), so the noisy test must be
# flipped exactly when this gain is negative. This mirrors `sign_hat` below,
# which uses the estimated gain (thresholding TPR at 0.5 gets the sign wrong
# whenever that disagrees with the sign of TPR - FPR).
sign_true = 1 if tpr_true - fpr_true >= 0 else -1


def _ppi_moments(hat_R_M, hat_R_J, hat_R_JP, hat_R_11, m, n):
    """Return the plug-in terms (A, B) shared by the PPI-family tests.

    Labels are Bernoulli, so a sample mean r has per-observation variance
    r (1 - r):
      A = r_J (1 - r_J) / n + r_JP (1 - r_JP) / m
          -- variance of the correction term hat_R_J - hat_R_JP
      B = (P(S = 1, S' = 1) - r_M r_JP) / m
          -- covariance of hat_R_M and hat_R_JP on the calibration sample
    hat_R_M / hat_R_JP are the means of S / S' over a calibration sample of
    size m, hat_R_J is the mean of S' over a test sample of size n, and
    hat_R_11 is the fraction of calibration pairs with S = 1 and S' = 1.
    """
    hat_A = hat_R_J * (1 - hat_R_J) / n + hat_R_JP * (1 - hat_R_JP) / m
    hat_B = (hat_R_11 - hat_R_M * hat_R_JP) / m
    return hat_A, hat_B


def _ppi_rejects(hat_R_M, hat_R_J, hat_R_JP, hat_A, hat_B, lam, m, alpha):
    """Return True when the PPI estimate with weight `lam` rejects H0: R >= alpha.

    Estimate: hat_R_M + lam * (hat_R_J - hat_R_JP)
    Variance: r_M (1 - r_M) / m + lam^2 * A - 2 * lam * B, clipped at zero.
    Rejects when the estimate is at or below alpha + z_zeta * SE, where
    z_zeta is the module-level lower-tail critical value.
    """
    hat_R = hat_R_M + lam * (hat_R_J - hat_R_JP)
    hat_V = hat_R_M * (1 - hat_R_M) / m + lam**2 * hat_A - 2 * lam * hat_B
    hat_SE = np.sqrt(max(hat_V, 0.0))
    return hat_R <= alpha + z_zeta * hat_SE


def _cross_fit_variance(est_fold, val_fold, taus):
    """Score every tau at once: fit lambda = B / (A + tau) on `est_fold` and
    return the zero-clipped PPI variance evaluated with `val_fold`'s moments.

    Each fold is a tuple (r_M (1 - r_M) / m, A, B); the result has one entry
    per element of `taus`.
    """
    _, A_est, B_est = est_fold
    var_M_val, A_val, B_val = val_fold
    lam = B_est / (A_est + taus)
    return np.maximum(var_M_val + lam**2 * A_val - 2 * lam * B_val, 0.0)


def _select_ridge_tau(S_cal, Sprime_cal, hat_R_J, n, taus, rng):
    """Choose the Ridge PPI penalty by 2-fold cross-fitting on the calibration pairs.

    The pairs are split in half with a random permutation drawn from `rng`.
    For each candidate tau, lambda is fitted on one half and the variance is
    scored on the other; both directions are averaged, and the first tau
    attaining the minimum score is returned. Requires len(S_cal) >= 2 so that
    both halves are non-empty.
    """
    perm = rng.permutation(len(S_cal))
    half = len(S_cal) // 2
    folds = []
    for sel in (perm[:half], perm[half:]):
        S_f, Sp_f = S_cal[sel], Sprime_cal[sel]
        m_f = len(sel)
        R_M_f = S_f.mean()
        R_11_f = np.mean((S_f == 1) & (Sp_f == 1))
        A_f, B_f = _ppi_moments(R_M_f, hat_R_J, Sp_f.mean(), R_11_f, m_f, n)
        folds.append((R_M_f * (1 - R_M_f) / m_f, A_f, B_f))

    score = 0.5 * (
        _cross_fit_variance(folds[0], folds[1], taus)
        + _cross_fit_variance(folds[1], folds[0], taus)
    )
    # np.argmin returns the first minimiser, i.e. the smallest tau on ties.
    return taus[np.argmin(score)]


for alpha in tqdm(alpha_grid):
    rej_direct = 0
    rej_est_q = 0
    rej_known_q = 0
    rej_ppi = 0
    rej_ppip = 0
    rej_ppir = 0

    # Oracle noisy test: the null mean of S' and its standard error depend
    # only on alpha and the true channel, so compute them once per alpha.
    Rprime_true = fpr_true + (tpr_true - fpr_true) * alpha
    std_true = np.sqrt(Rprime_true * (1 - Rprime_true) / n_test)

    idx = 0
    while idx < B:
        # --- Calibration sample: paired gold labels S and noisy labels S'.
        S_cal = rng.binomial(1, p_true, size=m_cal)
        Sprime_cal = noisy_channel(S_cal, tpr=tpr_true, fpr=fpr_true)

        m0 = (S_cal == 0).sum()
        m1 = m_cal - m0
        if m0 == 0 or m1 == 0:
            # FPR or TPR cannot be estimated from this draw; redraw it.
            continue

        hat_fpr = ((S_cal == 0) & (Sprime_cal == 1)).sum() / m0
        hat_tpr = ((S_cal == 1) & (Sprime_cal == 1)).sum() / m1
        sign_hat = 1 if hat_tpr - hat_fpr >= 0 else -1

        # Null mean of S' under the estimated channel, and the plug-in
        # variance of (test mean of S') - Rprime_hat. Neither uses the test
        # sample, so a degenerate (zero) variance is caught here, before any
        # method records a rejection for a replicate that is then redrawn.
        Rprime_hat = hat_fpr + (hat_tpr - hat_fpr) * alpha
        var_D_alpha = (
            Rprime_hat * (1 - Rprime_hat) / n_test
            + (1 - alpha) ** 2 * hat_fpr * (1 - hat_fpr) / m0
            + alpha**2 * hat_tpr * (1 - hat_tpr) / m1
        )
        std_D_hat = np.sqrt(var_D_alpha)
        if std_D_hat == 0:
            continue

        # --- Independent gold-label sample (Direct HT) and noisy test sample.
        R = rng.binomial(1, p_true, size=m_dir).mean()
        S_test = rng.binomial(1, p_true, size=n_test)
        Sprime_test = noisy_channel(S_test, tpr=tpr_true, fpr=fpr_true)
        Rprime = Sprime_test.mean()

        # --- Direct HT: one-sided normal-approximation test on the gold sample.
        C_dir = alpha + z_zeta * np.sqrt(alpha * (1 - alpha) / m_dir)
        if R <= C_dir:
            rej_direct += 1

        # --- Noisy HT (estimated channel)
        Z = sign_hat * (Rprime - Rprime_hat) / std_D_hat
        if Z <= z_zeta:
            rej_est_q += 1

        # --- Noisy HT (true channel, oracle)
        Z_true = sign_true * (Rprime - Rprime_true) / std_true
        if Z_true <= z_zeta:
            rej_known_q += 1

        # --- PPI family: calibration gold mean corrected by the gap between
        # the noisy-label means on the test and calibration samples.
        hat_R_M = S_cal.mean()
        hat_R_JP = Sprime_cal.mean()
        hat_R_J = Rprime
        hat_R_11 = np.mean((S_cal == 1) & (Sprime_cal == 1))
        hat_A, hat_B = _ppi_moments(
            hat_R_M, hat_R_J, hat_R_JP, hat_R_11, m_cal, n_test
        )

        # PPI: fixed weight lambda = 1.
        if _ppi_rejects(hat_R_M, hat_R_J, hat_R_JP, hat_A, hat_B, 1, m_cal, alpha):
            rej_ppi += 1

        # PPI++: lambda = B / A, the minimiser of the plug-in variance.
        lam_pp = hat_B / hat_A
        if _ppi_rejects(
            hat_R_M, hat_R_J, hat_R_JP, hat_A, hat_B, lam_pp, m_cal, alpha
        ):
            rej_ppip += 1

        # Ridge PPI: lambda = B / (A + tau), tau chosen by cross-fitting.
        tau = _select_ridge_tau(S_cal, Sprime_cal, hat_R_J, n_test, taus, rng)
        lam_ridge = hat_B / (hat_A + tau)
        if _ppi_rejects(
            hat_R_M, hat_R_J, hat_R_JP, hat_A, hat_B, lam_ridge, m_cal, alpha
        ):
            rej_ppir += 1

        idx += 1

    # Average rejection probabilities over the B retained replicates
    prob_direct_list.append(rej_direct / B)
    prob_est_q_list.append(rej_est_q / B)
    prob_known_q_list.append(rej_known_q / B)
    prob_ppi.append(rej_ppi / B)
    prob_ppip.append(rej_ppip / B)
    prob_ppir.append(rej_ppir / B)


# ---------------------------------------------------------------------
# Plot: rejection rates turned into Type I / Type II error curves
# ---------------------------------------------------------------------
plt.figure(num=None, figsize=(6, 5))


def fold_tail(y, alpha, pivot):
    """Map rejection rates to error rates on either side of `pivot`.

    Where alpha <= pivot the rejection rate is returned unchanged (plotted as
    Type I error); where alpha > pivot it is replaced by 1 - rate (plotted as
    Type II error). `y` is converted with np.asarray; `alpha` is compared
    elementwise against `pivot`.
    """
    rates = np.asarray(y)
    keep_as_is = alpha <= pivot
    return np.where(keep_as_is, rates, 1 - rates)


# Fold each method's rejection rate into an error rate (Type I at or left of
# p_true, Type II right of it) and draw the curves. The list order fixes the
# matplotlib colour cycle and matches the original plotting order.
curves = [
    ("Direct HT", prob_direct_list, {}),
    ("Noisy HT (general)", prob_est_q_list, {}),
    ("Noisy HT (oracle)", prob_known_q_list, {"ls": "--"}),
    ("PPI", prob_ppi, {}),
    ("PPI++", prob_ppip, {}),
    ("Ridge PPI", prob_ppir, {}),
]
for label, rejection_rates, style in curves:
    errors = fold_tail(np.array(rejection_rates), alpha_grid, p_true)
    plt.plot(alpha_grid, errors, label=label, lw=1.2, **style)

# Region labels, centred over the part of the alpha grid each error type covers.
y_txt = 0.8
plt.text((alpha_grid[0] + p_true) / 2, y_txt, "Type I error", ha="center", va="center")
plt.text((p_true + alpha_grid[-1]) / 2, y_txt, "Type II error", ha="center", va="center")

plt.axhline(zeta, color="gray", ls=":", lw=1, label=f"$\\zeta$ = {zeta:.2f}")
plt.axvline(p_true, color="black", ls="--", lw=1, label=f"$R_M$ = {p_true:.2f}")

plt.xlabel("$\\alpha$")
plt.ylabel("Error Probability")
plt.title(
    f"$TPR_{{true}}$ = {tpr_true}, $FPR_{{true}}$ = {fpr_true}, $n_M$ = {m_cal}, $n_J$ = {n_test}"
)
# Every curve carries a label, but no legend is drawn; call plt.legend() here to add one.
plt.grid(alpha=0.3)
plt.tight_layout()

# NOTE(review): ':' is not a valid file-name character on Windows, and the
# "fpr" part lacks the ':' used after "p" and "tpr". The path is kept
# byte-identical so existing figure names stay stable.
fig_path = f"./figs/synth/NHT_syn_p:{p_true}_tpr:{tpr_true}_fpr{fpr_true}.png"
# Create the output directory up front: otherwise savefig raises
# FileNotFoundError only after the whole simulation has finished.
os.makedirs(os.path.dirname(fig_path), exist_ok=True)
plt.savefig(fig_path, dpi=300)
