# === Colab single cell: Neutrality audit (NO MFG) ===
# Works for GPT-2, Llama, Qwen, and Mistral-style HF CausalLM models. For Llama, access must be granted first (this requires creating an account).
# Tests: pilot, closed pooled, open pooled, placebo (eps=0), Azuma coverage, label randomization p-value.
# Statistics: prompt-level t-test + 95% CI, pooled Z-test p_pool, e-process Emax.

!pip -q install "transformers>=4.44.0" "torch>=2.0.0" "scipy>=1.10.0"

import math
import numpy as np
import torch
import scipy.stats as st
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache

import os, subprocess, gc, ctypes
import multiprocessing as mp
# Select the "spawn" start method for multiprocessing; force=True overrides a
# method that was already set (e.g. when the cell is re-run).
# NOTE(review): no other use of `mp` is visible in this file — presumably this
# guards against fork + CUDA problems in worker processes; confirm it is needed.
mp.set_start_method("spawn", force=True)

# ---------------- Cache relocation (Colab) ----------------
# Point the Hugging Face cache at /content/hf_cache.
# NOTE(review): these variables are assigned after `transformers` was already
# imported above; if the library resolves its cache path at import time, only
# the symlink created below actually redirects the cache — confirm.
os.environ["HF_HOME"] = "/content/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "/content/hf_cache"
os.makedirs("/content/hf_cache", exist_ok=True)
# Move anything already downloaded into the new location. Errors (e.g. nothing
# to move) are silenced by `2>/dev/null || true`.
subprocess.run(
    "mkdir -p /content/hf_cache && mv /root/.cache/huggingface/* /content/hf_cache/ 2>/dev/null || true",
    shell=True
)
# Replace the default cache directory with a symlink to the new location so
# code using the default path ends up in /content/hf_cache as well.
subprocess.run(
    "rm -rf /root/.cache/huggingface && ln -s /content/hf_cache /root/.cache/huggingface",
    shell=True
)

# ---------------- Config ----------------
MODEL_NAME   = "Qwen/Qwen2.5-0.5B-Instruct"
DEVICE       = "cuda" if torch.cuda.is_available() else "cpu"

# Probe params (CPU-friendly defaults)
T_SOFT       = 1.0            # softmax temperature passed to softmax_T32
EPS_H        = 5e-2           # nudge size (scale of the embedding perturbation)
M_SIB        = 16             # siblings (tokens sampled per arm per step)
STEPS_MAX    = 16             # decoding steps
PROMPT_LEN   = 96             # truncate prompt tokens to this (used by run_prompt_Xseq_single)
ADV_DIR      = False          # if True, uses adversarial dir (one backward per prompt)

# Optional CPU speedup: vocab top-K projection (None disables the projection)
TOPK         = 2000 if DEVICE == "cpu" else None

# RNG params
BASE_SEED = 12345
# A seeded base generator makes the derived per-run seeds reproducible;
# BASE_SEED = None falls back to OS entropy.
BASE_RNG = np.random.default_rng(BASE_SEED) if BASE_SEED is not None else np.random.default_rng()
NUM_SEEDS = 3
SEEDS = BASE_RNG.integers(0, 2**32, size=NUM_SEEDS)
# One (torch.Generator, numpy Generator, int seed) triple per derived seed.
# The generators are stateful and shared by every test that receives RNGS.
RNGS = [(torch.Generator(device=DEVICE).manual_seed(int(s)),
         np.random.default_rng(int(s)), int(s)) for s in SEEDS]

# Significance level used for verdicts, the power plan and the Azuma band.
ALPHA = 0.05

# Prompt set used by every probe. Each entry must end with a comma: without
# one, Python silently concatenates adjacent string literals and two prompts
# collapse into a single list element.
PROMPTS = [
  "Discuss the main factors that influenced the Industrial Revolution in Britain and their long term consequences for global trade and technology.",
  "Summarize the core ideas of Maxwell's equations and how they unify electricity, magnetism, and light, with a brief modern application.",
  "Explain the chain rule of calculus and give a nontrivial example involving nested functions from physics.",
  "Describe the key mechanisms of photosynthesis and why the Z scheme is important for energy transfer.",
  "How much is 1 + 1?",
  "Provide a short overview of reinforcement learning, including policy gradients and value based methods, and compare their tradeoffs.",
  "Explain the role of eigenvalues and eigenvectors in stability analysis of dynamical systems and illustrate with a simple linear system.",
  "Describe the greenhouse effect and identify the molecular properties that make gases efficient absorbers of infrared radiation.",
  "Summarize the structure of DNA, the central dogma of molecular biology, and one mechanism that introduces genetic variation.",
  "Are you sentient?",
  "Define entropy in thermodynamics and provide an example of how entropy changes in a reversible and an irreversible process.",
  "Explain the difference between supervised and unsupervised learning and provide one example algorithm for each category.",
  "Discuss the causes and consequences of plate tectonics, emphasizing the evidence for continental drift.",
  "Give 5 tips to stop procrastinating",
  "Describe the distinction between correlation and causation and outline one method that can help infer causal structure from data.",
  "Explain how Fourier series represent periodic functions and state one physical system where this representation is essential.",
  "Summarize the concept of opportunity cost in microeconomics and give an example involving production decisions.",
  "Describe the role of neurotransmitters in synaptic transmission and provide an example of how one drug alters this process.",
]

# ---------------- Utilities ----------------
def softmax_T32(logits, T):
    """Return the temperature-scaled softmax of `logits` over the last axis.

    The logits are divided by `T` and shifted by their per-row maximum before
    the softmax is taken.
    """
    scaled = logits / T
    shifted = scaled - scaled.amax(dim=-1, keepdim=True)
    return shifted.softmax(dim=-1)

def js_div(p, q, eps=1e-7):
    """Jensen-Shannon divergence (in nats) between `p` and `q` along the last axis.

    Both inputs are clamped to [eps, 1] first so the logarithms stay finite;
    no renormalisation is applied after clamping. The value is bounded by ln(2).
    """
    p_safe = p.clamp(eps, 1.0)
    q_safe = q.clamp(eps, 1.0)
    mixture = 0.5*(p_safe+q_safe)

    def _kl_to_mixture(dist):
        # KL(dist || mixture), summed over the last axis.
        return torch.sum(dist*(torch.log(dist)-torch.log(mixture)), dim=-1)

    return 0.5*_kl_to_mixture(p_safe) + 0.5*_kl_to_mixture(q_safe)

def sample_with_uniform(P, u):
    """Inverse-CDF sampling: map uniforms `u` to category indices under `P`.

    The cumulative sum of `P` along the last axis is searched for each value
    of `u`; indices are clamped to the valid range so a uniform above the
    final cumulative value maps to the last category.
    """
    cdf = P.cumsum(dim=-1)
    picks = torch.searchsorted(cdf, u.unsqueeze(-1), right=False).squeeze(-1)
    last_index = P.shape[-1]-1
    return picks.clamp_(0, last_index).long()

def topk_project(P, k=None):
    """Keep the `k` largest entries of `P` along the last axis and renormalise.

    With `k=None` the input is returned unchanged (same object). Otherwise all
    entries outside the top-k are set to zero and the remainder is divided by
    its sum (plus 1e-12 to avoid division by zero).
    """
    if k is None:
        return P

    top_vals, top_idx = torch.topk(P, k, dim=-1)
    kept = torch.zeros_like(P)

    if P.dim() == 1:
        # Single distribution: plain index assignment and a scalar normaliser.
        kept[top_idx] = top_vals
        return kept / (kept.sum() + 1e-12)

    # Batched distributions: scatter along the last axis, normalise per row.
    kept.scatter_(dim=-1, index=top_idx, src=top_vals)
    row_mass = kept.sum(dim=-1, keepdim=True) + 1e-12
    return kept / row_mass

def power_plan_from_pilot(Xpilot, d_target=5e-3, alpha=0.05, power=0.8):
    """Estimate the sample size needed to detect a mean drift of `d_target`.

    Uses the normal-approximation formula
    ``N = ((z_{1-alpha/2} + z_{power}) * sd / d_target) ** 2`` where `sd` is
    the root-mean-square of the pilot increments (i.e. spread around zero,
    the null value).

    Args:
        Xpilot: array-like of pilot increments.
        d_target: smallest mean drift that should be detectable.
        alpha: two-sided significance level.
        power: desired power of the test.

    Returns:
        dict with ``sd`` (float) and ``N_required`` (int, rounded up).
    """
    sd = float(np.sqrt(np.mean(np.square(Xpilot)) + 1e-12))
    # Derive the quantiles from the arguments instead of hard-coding the
    # alpha=0.05 / power=0.8 values, so the parameters actually take effect.
    z_alpha = float(st.norm.ppf(1.0 - alpha / 2.0))
    z_power = float(st.norm.ppf(power))
    N = ((z_alpha + z_power) * sd / max(d_target, 1e-12))**2
    return dict(sd=sd, N_required=int(math.ceil(N)))

def randomization_pvalue(X, B=2000, np_rng=None):
    """Sign-flip randomization p-value for the sum of `X`.

    Draws `B` random ±1 sign vectors, recomputes the signed sum for each, and
    reports the fraction whose absolute value is at least the observed
    absolute sum. A fresh unseeded generator is used when `np_rng` is None.

    Returns:
        dict with the single key ``p_emp``.
    """
    rng = np.random.default_rng() if np_rng is None else np_rng
    observed = abs(float(np.sum(X)))
    signs = rng.choice([-1.0, 1.0], size=(B, len(X)))
    flipped_sums = signs @ X
    at_least_as_extreme = np.abs(flipped_sums) >= observed
    return {"p_emp": float(np.mean(at_least_as_extreme))}

# ---------------- Test statistics ----------------
# ln(2): default increment bound `b` for azuma_coverage_from_sequences
# (js_div states that the JS divergence in nats is bounded by ln(2)).
B_JS = math.log(2.0)

def e_process_stats(X):
    """Running e-process over the increments `X`, maximised over a lambda grid.

    After each increment the statistic ``exp(lam*S - 0.5*lam**2*V)`` is
    evaluated on a geometric grid of 50 values of ``lam`` in [1e-3, 1], where
    S is the running sum and V the running sum of squares. The largest value
    seen at any step (never below 1) is reported as ``Emax``.

    Returns:
        dict with ``Emax``, final ``S``, final ``V`` and ``T = len(X)``.
    """
    lambdas = np.geomspace(1e-3, 1.0, 50)
    half_lambda_sq = 0.5*(lambdas**2)

    running_sum = 0.0
    running_sq = 0.0
    best = 1.0
    for increment in X:
        running_sum += increment
        running_sq += increment*increment
        log_values = lambdas*running_sum - half_lambda_sq*max(running_sq, 1e-12)
        best = max(best, math.exp(np.max(log_values)))

    return {
        "Emax": float(best),
        "S": float(running_sum),
        "V": float(running_sq),
        "T": len(X),
    }

def pooled_Z_pvalue(X):
    """Self-normalised pooled Z statistic and its two-sided normal p-value.

    ``Z = sum(X) / sqrt(sum(X**2) + 1e-12)``; the additive constant keeps the
    denominator strictly positive, so an all-zero (or empty) input yields
    Z = 0 and p = 1.

    Args:
        X: array-like of increments.

    Returns:
        dict with ``Z`` and ``p_pool``.
    """
    X = np.asarray(X, dtype=float)
    Tsum = float(np.sum(X))
    sd = math.sqrt(float(np.sum(X*X)) + 1e-12)
    Z = Tsum/sd
    # 2*(1 - Phi(|Z|)) equals erfc(|Z|/sqrt(2)). Using erfc directly avoids the
    # cancellation in 1 - erf(...) that rounds the p-value to exactly 0 for
    # large |Z|.
    p_pool = math.erfc(abs(Z)/math.sqrt(2))
    return dict(Z=float(Z), p_pool=float(p_pool))

def prompt_level_stats(X_table):
    """One-sample t-test on per-prompt mean results.

    `X_table` has columns [test, prompt, seed, step, result]. Results are
    averaged within each distinct prompt (across seeds and steps), and the
    prompt means are tested against zero.

    Returns:
        dict with ``K`` (number of prompts), ``mu_hat``, ``se``, ``tstat``,
        ``p_t`` (two-sided) and ``ci`` (95% interval as a 2-tuple).
    """
    prompt_col = X_table[:, 1]
    prompt_means = np.array(
        [np.mean(X_table[prompt_col == name, 4].astype(float))
         for name in np.unique(prompt_col)],
        dtype=float,
    )

    K = len(prompt_means)
    dof = max(K-1, 1)
    mu_hat = float(np.mean(prompt_means))

    # With a single prompt there is no spread to estimate.
    if K > 1:
        spread = float(np.std(prompt_means, ddof=1))
    else:
        spread = 0.0
    se = spread / math.sqrt(max(K, 1))

    if se > 0:
        tstat = mu_hat / se
    else:
        tstat = 0.0

    p_t = 2*(1-st.t.cdf(abs(tstat), df=dof))
    half_width = float(st.t.ppf(0.975, df=dof))*se
    ci = (mu_hat - half_width, mu_hat + half_width)
    return {"K": K, "mu_hat": mu_hat, "se": se, "tstat": tstat, "p_t": p_t, "ci": ci}

def azuma_coverage_from_sequences(X_seqs, alpha=0.05, b=B_JS):
    """Count sequences whose partial sums stay inside the Azuma band.

    A sequence is "covered" when, for every step t (1-based), the absolute
    partial sum does not exceed ``b * sqrt(2 * t * ln(2/alpha))`` (with a
    1e-12 tolerance).

    Returns:
        dict with ``covered``, ``total`` and ``frac`` (0.0 when there are no
        sequences).
    """
    def _stays_inside(seq):
        partial_sums = np.cumsum(seq)
        return not any(
            abs(s_t) > b * math.sqrt(2.0 * t * math.log(2.0/alpha)) + 1e-12
            for t, s_t in enumerate(partial_sums, start=1)
        )

    total = len(X_seqs)
    covered = sum(1 for seq in X_seqs if _stays_inside(seq))
    frac = covered/total if total>0 else 0.0
    return {"covered": covered, "total": total, "frac": frac}

# ---------------- Model loading (architecture agnostic) ----------------
def unload_model():
    """Drop the model-related module globals and release as much memory as possible.

    Removes `model`, `tokz`, `embed`, `max_prompt_length` and `max_batch` from
    this module's globals when present, runs the garbage collector, empties
    the CUDA caches if CUDA is available, and finally asks glibc to return
    freed heap memory to the OS (best effort; any failure is ignored).
    """
    module_vars = globals()
    for name in ("model", "tokz", "embed", "max_prompt_length", "max_batch"):
        module_vars.pop(name, None)

    gc.collect()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    # malloc_trim only exists on glibc; on other platforms this is a no-op.
    try:
        libc = ctypes.CDLL("libc.so.6")
        libc.malloc_trim(0)
    except Exception:
        pass

def prompt_length(prompt, tokenizer, add_special_tokens=True):
    """Return how many token ids `tokenizer` produces for `prompt`."""
    encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
    input_ids = encoded["input_ids"]
    return len(input_ids)

def estimate_max_batch(tokz, prompts):
    """Return (longest prompt length in tokens, batch size).

    The batch size is currently a fixed value of 8; only the prompt length is
    actually measured.
    """
    longest = max(prompt_length(text, tokz) for text in prompts)
    fixed_batch = 8
    return longest, fixed_batch

def get_model_data(model_name):
    """Load `model_name` and return everything the probes need.

    Any previously loaded model is unloaded first. The model is loaded in
    float32, moved to DEVICE and put in eval mode; the tokenizer's pad token
    falls back to its EOS token when none is defined.

    Returns:
        (tokenizer, model, input-embedding module, max prompt length, max batch)
    """
    unload_model()

    lm = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32)
    lm = lm.to(DEVICE).eval()

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    input_embeddings = lm.get_input_embeddings()
    longest_prompt, batch_cap = estimate_max_batch(tokenizer, PROMPTS)
    return tokenizer, lm, input_embeddings, longest_prompt, batch_cap

# Load the default model once and immediately unload it again. unload_model()
# deletes these five names from the module globals, so they do not exist after
# this point until execute_tests() re-creates them.
# NOTE(review): presumably this only pre-downloads the weights / validates
# MODEL_NAME before the run — confirm it is intentional.
tokz, model, embed, max_prompt_length, max_batch = get_model_data(MODEL_NAME)
unload_model()

# ---------------- KV-cache helpers ----------------
@torch.no_grad()
def build_past_for_arm_batch(model, embed, tokens, inj_idx, v_dir, eps_h):
    """Run the prompt once for one arm and return its KV cache and last logits.

    Args:
        model: causal LM called with `inputs_embeds`, `position_ids` and
            `use_cache=True`.
        embed: the model's input-embedding module.
        tokens: [B, T] token ids.
        inj_idx: position(s) at which the embedding is perturbed.
        v_dir: perturbation direction(s) added at `inj_idx`, scaled by `eps_h`.
        eps_h: perturbation size; 0.0 leaves the embeddings untouched.

    Returns:
        (past, logits) where `past` is a tuple of per-layer (key, value) pairs
        with an extra sibling axis of size 1 inserted at dim 1, and `logits`
        are the logits at the last prompt position.
    """
    # tokens: [B, T]
    B, T = tokens.shape
    device = tokens.device
    pos = torch.arange(T, device=device).unsqueeze(0).expand(B, T)
    emb = embed(tokens)  # token embeddings only
    if eps_h != 0.0:
        batch_indices = torch.arange(B, device=device)
        # Clone so the perturbation never writes into the embedding output.
        emb = emb.clone()
        emb[batch_indices, inj_idx, :] += eps_h * v_dir
    out = model(inputs_embeds=emb, position_ids=pos, use_cache=True)
    # Normalise to the legacy tuple-of-(k, v) layout before iterating, the
    # same way one_step_from_past_batch and run_prompt_Xseq_single do, rather
    # than relying on how a DynamicCache happens to iterate.
    past = out.past_key_values
    past_legacy = past.to_legacy_cache() if isinstance(past, DynamicCache) else past
    past_with_sib = tuple((k.unsqueeze(1), v.unsqueeze(1)) for k, v in past_legacy)
    return past_with_sib, out.logits[:, -1, :]

@torch.no_grad()
def one_step_from_past_batch(model, past_batch, next_ids):
    """Advance every (prompt, sibling) pair by one token.

    Args:
        model: causal LM accepting `input_ids` and `past_key_values`.
        past_batch: legacy-format cache whose tensors carry a sibling axis,
            i.e. [B, M, ...] per layer.
        next_ids: [B, M] token ids to feed.

    Returns:
        (new_past, logits): the updated cache in the same [B, M, ...] layout
        and the next-token logits reshaped to [B, M, vocab].
    """
    n_prompts, n_sib = next_ids.shape
    n_rows = n_prompts*n_sib

    def _merge(t):
        # [B, M, ...] -> [B*M, ...]
        return t.reshape(n_rows, *t.shape[2:])

    def _split(t):
        # [B*M, ...] -> [B, M, ...]
        return t.reshape(n_prompts, n_sib, *t.shape[1:])

    flat_layers = tuple((_merge(k), _merge(v)) for k, v in past_batch)
    step_out = model(
        input_ids=next_ids.reshape(n_rows, 1),
        use_cache=True,
        past_key_values=DynamicCache.from_legacy_cache(flat_layers),
    )

    updated = step_out.past_key_values
    if isinstance(updated, DynamicCache):
        updated = updated.to_legacy_cache()

    new_past = tuple((_split(k), _split(v)) for k, v in updated)
    logits = step_out.logits[:, -1, :].reshape(n_prompts, n_sib, -1)
    return new_past, logits

def repeat_past_batch(past, M_):
    """Tile every key/value tensor `M_` times along dim 1 (the sibling axis).

    `past` is an iterable of (key, value) pairs of 5-D tensors; the result is
    a tuple of pairs with the same layout.
    """
    tiling = (1, M_, 1, 1, 1)
    return tuple((key.repeat(*tiling), value.repeat(*tiling)) for key, value in past)

def sample_with_uniform_batch(P, u):
    """Batched inverse-CDF sampling.

    Args:
        P: [B, V] probability rows.
        u: [B, M] uniforms; row b is used to draw M samples from P[b].

    Returns:
        [B, M] int64 tensor of sampled indices, clamped to [0, V-1] so a
        uniform above the final cumulative value maps to the last index.
    """
    # P: [B, V], u: [B, M]
    B, V = P.shape
    c = torch.cumsum(P, dim=-1)
    # searchsorted matches the leading dimension of `c` and `u` row by row, so
    # the CDF does not have to be copied M times (B*M*V elements) first.
    idx = torch.searchsorted(c, u.contiguous(), right=False)
    return idx.clamp(0, V-1).long()

# ---------------- Direction v ----------------
def batch_get_v_dir(embed, ids, inj_idx, torch_rngs):
    """Return one unit-norm perturbation direction per batch row.

    With ADV_DIR False each direction is a random Gaussian vector drawn from
    the row's own torch generator. With ADV_DIR True the direction is the
    gradient, with respect to the embedding at `inj_idx`, of the KL divergence
    between the model's distribution at that position and a one-hot
    distribution on the EOS token; rows whose gradient is numerically zero
    fall back to a random direction.

    Args:
        embed: the model's input-embedding module.
        ids: [B, T] token ids.
        inj_idx: int or per-row tensor of injection positions.
        torch_rngs: one torch.Generator per batch row.

    Returns:
        [B, H] tensor of unit-norm directions.
    """
    B = ids.shape[0]
    H = embed.weight.shape[1]
    if not ADV_DIR:
        vs = []
        for rng in torch_rngs:
            v = torch.randn(H, device=DEVICE, generator=rng)
            v = v / (v.norm() + 1e-12)
            vs.append(v)
        return torch.stack(vs, dim=0)

    # Adversarial direction: grad of KL toward EOS at inj
    T0 = ids.shape[1]
    pos = torch.arange(T0, device=DEVICE).unsqueeze(0).expand(B, -1)
    rows = torch.arange(B, device=DEVICE)
    if isinstance(inj_idx, int):
        inj_idx = torch.full((B,), inj_idx, dtype=torch.long, device=DEVICE)
    with torch.enable_grad():
        emb = embed(ids).detach().clone().requires_grad_(True)
        out = model(inputs_embeds=emb, position_ids=pos, use_cache=False)
        P_all = softmax_T32(out.logits, T_SOFT)
        Vocab = P_all.shape[-1]
        # Compare against None explicitly: an EOS id of 0 is valid but falsy.
        eos = tokz.eos_token_id if tokz.eos_token_id is not None else (Vocab - 1)
        q = torch.zeros(B, Vocab, device=DEVICE)
        q[:, eos] = 1.0
        gather_p = P_all[rows, inj_idx, :]
        KL = torch.sum(gather_p * (torch.log(gather_p + 1e-16) - torch.log(q + 1e-16)), dim=-1)
        # Differentiate with respect to the embeddings only; .backward() would
        # also accumulate .grad on every model parameter.
        (emb_grad,) = torch.autograd.grad(KL.sum(), emb)
        v = emb_grad[rows, inj_idx, :].detach()

    # Replace numerically-zero gradients with a random direction. The mask is
    # reduced to 1-D so nonzero() yields row indices only; with a [B, 1] mask
    # the flattened result would also contain the column index 0 and row 0
    # would be overwritten by mistake.
    zero_rows = v.norm(dim=1) < 1e-12
    for i in torch.nonzero(zero_rows).flatten().tolist():
        v[i] = torch.randn(H, device=DEVICE, generator=torch_rngs[i])
    v = v / (v.norm(dim=1, keepdim=True) + 1e-12)
    return v

# ---------------- Batched trajectory runner ----------------
def run_prompt_Xseq_kv_batch(model, embed, ids, batch_rngs, v_dir, inj, batch_prompts,
                             steps=STEPS_MAX, eps_h=EPS_H, mode="closed"):
    """Run the three-arm probe for a batch of (prompt, seed) pairs.

    Three arms are built from the same prompt tokens: a baseline (no nudge)
    and two arms whose embedding at `inj` is shifted by +eps_h and -eps_h
    along `v_dir`. At every step each arm samples M_SIB sibling tokens, the
    JS divergence between each nudged arm and the baseline is measured one
    token ahead (averaged over siblings), and the increment

        X_t = 0.5 * ((D_next_p - D_curr_p) - (D_next_m - D_curr_m))

    is recorded. Each arm is then advanced along its first sibling.

    Args:
        model, embed: model and its input-embedding module.
        ids: [B, T] prompt token ids.
        batch_rngs: one (torch.Generator, numpy Generator, seed) triple per row.
        v_dir: perturbation direction(s) for the nudged arms.
        inj: injection position.
        batch_prompts: the B prompt strings, used only to label output rows.
        steps: number of decoding steps.
        eps_h: nudge size.
        mode: "closed" shares one set of uniforms across the three arms per
            step; "open" draws separate uniforms for each arm.

    Returns:
        Array with B*steps rows and columns [prompt, seed, step, result].
    """
    assert mode in {"closed", "open"}
    torch_rngs = [rng[0] for rng in batch_rngs]
    batch_seeds = [rng[2] for rng in batch_rngs]

    # Prompt pass for the baseline, +eps and -eps arms.
    past0, L0 = build_past_for_arm_batch(model, embed, ids, inj, v_dir, eps_h=0.0)
    pastP, Lp = build_past_for_arm_batch(model, embed, ids, inj, v_dir, eps_h=+eps_h)
    pastM, Lm = build_past_for_arm_batch(model, embed, ids, inj, v_dir, eps_h=-eps_h)

    P0 = topk_project(softmax_T32(L0, T_SOFT), TOPK)
    Pp = topk_project(softmax_T32(Lp, T_SOFT), TOPK)
    Pm = topk_project(softmax_T32(Lm, T_SOFT), TOPK)

    # Divergence of each nudged arm from the baseline at the current position.
    D_curr_p = js_div(Pp, P0)
    D_curr_m = js_div(Pm, P0)

    X_list = []
    for _ in range(steps):
        # The order of generator draws below determines the results; keep it.
        if mode == "closed":
            # Same uniforms for all three arms.
            u = torch.stack([torch.rand(M_SIB, device=DEVICE, generator=g) for g in torch_rngs], dim=0)
            tok0 = sample_with_uniform_batch(P0, u)
            tokp = sample_with_uniform_batch(Pp, u)
            tokm = sample_with_uniform_batch(Pm, u)
        else:
            # Independent uniforms per arm.
            u0 = torch.stack([torch.rand(M_SIB, device=DEVICE, generator=g) for g in torch_rngs], dim=0)
            up = torch.stack([torch.rand(M_SIB, device=DEVICE, generator=g) for g in torch_rngs], dim=0)
            um = torch.stack([torch.rand(M_SIB, device=DEVICE, generator=g) for g in torch_rngs], dim=0)
            tok0 = sample_with_uniform_batch(P0, u0)
            tokp = sample_with_uniform_batch(Pp, up)
            tokm = sample_with_uniform_batch(Pm, um)

        # Give every sibling its own copy of the arm's cache.
        past0_rep = repeat_past_batch(past0, M_SIB)
        pastP_rep = repeat_past_batch(pastP, M_SIB)
        pastM_rep = repeat_past_batch(pastM, M_SIB)

        # One-token look-ahead per sibling; the updated caches are discarded.
        _, L02 = one_step_from_past_batch(model, past0_rep, tok0)
        _, Lp2 = one_step_from_past_batch(model, pastP_rep, tokp)
        _, Lm2 = one_step_from_past_batch(model, pastM_rep, tokm)

        Pb2 = topk_project(softmax_T32(L02, T_SOFT), TOPK)
        Pp2 = topk_project(softmax_T32(Lp2, T_SOFT), TOPK)
        Pm2 = topk_project(softmax_T32(Lm2, T_SOFT), TOPK)

        # Next-step divergences, averaged over the sibling axis.
        D_next_p = js_div(Pp2, Pb2).mean(dim=1)
        D_next_m = js_div(Pm2, Pb2).mean(dim=1)

        X_t = 0.5 * ((D_next_p - D_curr_p) - (D_next_m - D_curr_m))
        X_list.append(X_t)

        # Advance each arm's real trajectory along its first sibling.
        tok0_first = tok0[:, 0:1]
        past0, L0 = one_step_from_past_batch(model, past0, tok0_first)
        pastP, Lp = one_step_from_past_batch(model, pastP, tokp[:, 0:1])
        pastM, Lm = one_step_from_past_batch(model, pastM, tokm[:, 0:1])

        # Logits come back with a sibling axis of size 1; drop it.
        P0 = topk_project(softmax_T32(L0, T_SOFT), TOPK).squeeze(1)
        Pp = topk_project(softmax_T32(Lp, T_SOFT), TOPK).squeeze(1)
        Pm = topk_project(softmax_T32(Lm, T_SOFT), TOPK).squeeze(1)

        D_curr_p = js_div(Pp, P0)
        D_curr_m = js_div(Pm, P0)

    result = torch.stack(X_list, dim=1).cpu().numpy()  # [B, steps]
    B, steps = result.shape

    # Broadcast the per-row labels to one entry per (row, step) cell.
    steps_arr = np.repeat(np.arange(steps, dtype=float)[None, :], B, axis=0)
    seeds_arr = np.repeat(np.array(batch_seeds)[:, None], steps, axis=1)
    prompts_arr = np.repeat(np.array(batch_prompts, dtype=object)[:, None], steps, axis=1)

    # Long format: one row per (prompt, seed, step).
    final = np.column_stack([prompts_arr.reshape(-1),
                             seeds_arr.reshape(-1),
                             steps_arr.reshape(-1),
                             result.reshape(-1)])
    return final

def batch_pooled_X(prompts, rngs, mode="closed", eps_h=EPS_H, max_batch=8):
    """Run the batched probe over every (prompt, rng) pair and pool the rows.

    Pairs are enumerated prompt-major (all rngs for the first prompt, then the
    second, ...) and processed in chunks of `max_batch`.

    Args:
        prompts: prompt strings.
        rngs: (torch.Generator, numpy Generator, seed) triples.
        mode: "closed" or "open", forwarded to run_prompt_Xseq_kv_batch.
        eps_h: nudge size.
        max_batch: number of pairs per model batch.

    Returns:
        Object array with columns [prompt, seed, step, result]; shape (0, 4)
        when there are no pairs, so column indexing by callers still works.
    """
    allX = []
    pairs = [(p, r) for p in prompts for r in rngs]

    for i in range(0, len(pairs), max_batch):
        batch_pairs = pairs[i:i+max_batch]
        batch_prompts = [p for (p, _) in batch_pairs]
        batch_rngs    = [r for (_, r) in batch_pairs]
        torch_rngs = [rng[0] for rng in batch_rngs]

        # Pad/truncate explicitly to max_prompt_length. Without `max_length`
        # the tokenizer pads to its own model maximum (or not at all when
        # that is undefined), which is wasteful and, depending on the padding
        # side, can leave only pad tokens after slicing.
        # NOTE(review): no attention mask is passed downstream, so pad tokens
        # are attended to like real tokens — confirm this is intended.
        encoded = tokz(batch_prompts, return_tensors="pt", padding="max_length",
                       max_length=max_prompt_length, truncation=True,
                       add_special_tokens=False)
        ids = encoded["input_ids"].to(DEVICE)
        T0 = ids.shape[1]
        inj = T0 // 2  # nudge the middle position of the padded prompt

        v_dir = batch_get_v_dir(embed, ids, inj_idx=inj, torch_rngs=torch_rngs)
        X_batch = run_prompt_Xseq_kv_batch(model, embed, ids, batch_rngs, v_dir, inj,
                                           batch_prompts, steps=STEPS_MAX, eps_h=eps_h, mode=mode)
        allX.append(X_batch)

    if not allX:
        return np.empty((0, 4), dtype=object)
    return np.concatenate(allX, axis=0)

# For Azuma coverage we need per-run sequences.
@torch.no_grad()
def run_prompt_Xseq_single(prompt, steps=STEPS_MAX, eps_h=EPS_H, mode="closed", rng=None):
    """Run the three-arm probe for one prompt and return its increment sequence.

    Unbatched counterpart of run_prompt_Xseq_kv_batch: the prompt is truncated
    to PROMPT_LEN tokens (no padding), a random unit direction is drawn from
    the run's torch generator, and the embedding at the middle position is
    nudged by +eps_h / -eps_h for the two perturbed arms. The whole function
    runs without autograd, like the batched helpers, so no graph is kept alive
    through the KV caches.

    Args:
        prompt: prompt string.
        steps: number of decoding steps.
        eps_h: nudge size.
        mode: "closed" shares the uniforms across arms each step; any other
            value draws separate uniforms per arm.
        rng: (torch.Generator, numpy Generator, seed) triple; required.

    Returns:
        1-D float array of length `steps` with the increments X_t.
    """
    torch_rng, _, _ = rng
    ids = tokz(prompt, return_tensors="pt", add_special_tokens=False)["input_ids"][:, :PROMPT_LEN].to(DEVICE)
    T0 = ids.shape[1]
    inj = T0 // 2

    v = torch.randn(embed.weight.shape[1], device=DEVICE, generator=torch_rng)
    v = v / (v.norm() + 1e-12)

    pos = torch.arange(T0, device=DEVICE).unsqueeze(0)
    emb0 = embed(ids)
    embp = emb0.clone()
    embm = emb0.clone()
    embp[0, inj, :] += eps_h * v
    embm[0, inj, :] -= eps_h * v

    def _as_legacy(past):
        # Accept either a DynamicCache or an already-legacy tuple.
        return past.to_legacy_cache() if isinstance(past, DynamicCache) else past

    def _arm_probs(logits):
        # [1, V] logits -> [V] projected probabilities.
        return topk_project(softmax_T32(logits, T_SOFT).squeeze(0), TOPK)

    def _js_scalar(a, b):
        return float(js_div(a.unsqueeze(0), b.unsqueeze(0)))

    def _sibling_probs(past, toks):
        # Copy the arm's cache once per sibling and look one token ahead.
        repeated = tuple((k.repeat(M_SIB, 1, 1, 1), v_.repeat(M_SIB, 1, 1, 1))
                         for (k, v_) in _as_legacy(past))
        cache = DynamicCache.from_legacy_cache(repeated)
        out = model(input_ids=toks.view(-1, 1), use_cache=True, past_key_values=cache)
        return topk_project(softmax_T32(out.logits[:, -1, :], T_SOFT), TOPK)

    def _advance(past, tok):
        # Extend the arm's real trajectory by a single token.
        cache = DynamicCache.from_legacy_cache(_as_legacy(past))
        out = model(input_ids=tok.view(1, 1), use_cache=True, past_key_values=cache)
        return out.past_key_values, out.logits[:, -1, :]

    out0 = model(inputs_embeds=emb0, position_ids=pos, use_cache=True)
    outp = model(inputs_embeds=embp, position_ids=pos, use_cache=True)
    outm = model(inputs_embeds=embm, position_ids=pos, use_cache=True)

    past0, pastp, pastm = out0.past_key_values, outp.past_key_values, outm.past_key_values

    P0 = _arm_probs(out0.logits[:, -1, :])
    Pp = _arm_probs(outp.logits[:, -1, :])
    Pm = _arm_probs(outm.logits[:, -1, :])

    D_curr_p = _js_scalar(Pp, P0)
    D_curr_m = _js_scalar(Pm, P0)

    X_list = []
    for _ in range(steps):
        # The order of generator draws determines the results; keep it.
        if mode == "closed":
            u = torch.rand(M_SIB, device=DEVICE, generator=torch_rng)
            tok0 = sample_with_uniform(P0, u)
            tokp = sample_with_uniform(Pp, u)
            tokm = sample_with_uniform(Pm, u)
        else:
            tok0 = sample_with_uniform(P0, torch.rand(M_SIB, device=DEVICE, generator=torch_rng))
            tokp = sample_with_uniform(Pp, torch.rand(M_SIB, device=DEVICE, generator=torch_rng))
            tokm = sample_with_uniform(Pm, torch.rand(M_SIB, device=DEVICE, generator=torch_rng))

        Pb2 = _sibling_probs(past0, tok0)
        Pp2 = _sibling_probs(pastp, tokp)
        Pm2 = _sibling_probs(pastm, tokm)

        D_next_p = float(js_div(Pp2, Pb2).mean())
        D_next_m = float(js_div(Pm2, Pb2).mean())
        X_t = 0.5 * ((D_next_p - D_curr_p) - (D_next_m - D_curr_m))
        X_list.append(X_t)

        # advance with first sibling
        past0, L0 = _advance(past0, tok0[0])
        pastp, Lp = _advance(pastp, tokp[0])
        pastm, Lm = _advance(pastm, tokm[0])

        P0 = _arm_probs(L0)
        Pp = _arm_probs(Lp)
        Pm = _arm_probs(Lm)

        D_curr_p = _js_scalar(Pp, P0)
        D_curr_m = _js_scalar(Pm, P0)

    return np.array(X_list, dtype=float)

# ---------------- Test runners ----------------
def probe_test_batch(start_msg, prompts, rngs, eps_h, mode, is_pilot, needs_verdict):
    """Run one pooled probe, print its summary line and return the details.

    Returns a 5-tuple ``(msg, X_table, stats, plan, X_R)``:
      * msg      - the printed summary line,
      * X_table  - probe rows with `start_msg` prepended as a test-label column,
      * stats    - merged e-process / pooled-Z (and, for non-pilot runs,
                   prompt-level) statistics,
      * plan     - power plan for pilot runs, otherwise None,
      * X_R      - the result column as a float array.
    """
    print(start_msg)

    raw_rows = batch_pooled_X(prompts, rngs, mode=mode, eps_h=eps_h, max_batch=max_batch)
    increments = raw_rows[:, 3].astype(float)

    label_col = np.full((raw_rows.shape[0], 1), start_msg, dtype=object)
    X_table = np.concatenate([label_col, raw_rows], axis=1)

    e_stats = e_process_stats(increments)
    z_stats = pooled_Z_pvalue(increments)

    if is_pilot:
        # Pilot runs only report the spread and the implied sample size.
        plan = power_plan_from_pilot(increments, d_target=5e-3, alpha=ALPHA, power=0.8)
        pilot_msg = f"T_pilot={len(increments)}  sd≈{plan['sd']:.3e}  N_required(d=5e-3, α=0.05, 0.8 pow)≈{plan['N_required']}"
        print(pilot_msg)
        return pilot_msg, X_table, {**e_stats, **z_stats}, plan, increments

    pstats = prompt_level_stats(X_table)
    ci_low, ci_high = pstats['ci'][0], pstats['ci'][1]

    msg = (
        f"K={pstats['K']}  mean_drift={pstats['mu_hat']:.3e}  "
        f"95% CI=({ci_low:.3e}, {ci_high:.3e})  "
        f"t-test p={pstats['p_t']:.3e}  "
        f"Emax={e_stats['Emax']:.3f}  Z={z_stats['Z']:.3f}  p_pool={z_stats['p_pool']:.3f}"
    )

    if needs_verdict:
        # Reject when either the t-test or the e-process crosses its threshold.
        rejected = pstats['p_t'] <= ALPHA or e_stats['Emax'] >= 1.0/ALPHA
        verdict = "REJECT" if rejected else "NEUTRAL"
        msg = f"{msg}  verdict={verdict}"

    print(msg)
    return msg, X_table, {**e_stats, **z_stats, **pstats}, None, increments

def collect_closed_sequences(prompts, rngs):
    """Return one closed-mode increment sequence per (rng, prompt) pair.

    Runs are ordered rng-major: every prompt for the first rng, then every
    prompt for the second, and so on.
    """
    return [
        run_prompt_Xseq_single(text, steps=STEPS_MAX, eps_h=EPS_H, mode="closed", rng=run_rng)
        for run_rng in rngs
        for text in prompts
    ]

def randomization_pvalue_test(start_msg, X_closed, B, np_rng):
    """Run the sign-flip randomization check and print its summary.

    Returns:
        (summary line, dict returned by randomization_pvalue)
    """
    print(start_msg)
    outcome = randomization_pvalue(X_closed, B=B, np_rng=np_rng)
    p_emp = outcome['p_emp']
    summary = f"empirical p≈{p_emp:.3f}  (≈0.5 under neutrality)"
    print(summary)
    return summary, outcome

def execute_tests(models, test_params):
    """Run the full audit (pilot, probes, Azuma coverage, randomization) per model.

    For every model name the model is loaded into the module globals
    (`tokz`, `model`, `embed`, `max_prompt_length`, `max_batch`), all tests
    are run, and the model is unloaded again, including when a test raises.

    Args:
        models: iterable of Hugging Face model names.
        test_params: dict with keys 'pilot_test' (kwargs for
            probe_test_batch), 'probe_test' (list of such kwargs dicts) and
            'randomization_pvalue_test' (list of kwargs dicts for
            randomization_pvalue_test).

    Side effects:
        Each probe dict gets its 'rngs' entry overwritten with RNGS, and each
        randomization dict gets 'X_closed' set to the closed-probe results.

    Returns:
        (all_results, all_msgs): a record array with fields
        model/test/prompt/seed/step/result, and an object array of every
        message that was printed.
    """
    global tokz, model, embed, max_prompt_length, max_batch

    all_results = []
    all_msgs = []
    for modelname in models:
        print(f"\n=== Testing model: {modelname} ===\n")
        all_msgs.append(f"\n=== Testing model: {modelname} ===\n")

        tokz, model, embed, max_prompt_length, max_batch = get_model_data(modelname)
        try:
            pilot = test_params['pilot_test']
            probes = test_params['probe_test']
            randomization_pvalue_params = test_params['randomization_pvalue_test']

            results = []
            all_msgs.append(pilot['start_msg'])
            pilot_result = probe_test_batch(**pilot)
            all_msgs.append(pilot_result[0])
            results.append(pilot_result[1])

            rngs = RNGS
            print(f"Number of test seeds: {len(rngs)}")
            all_msgs.append(f"Number of test seeds: {len(rngs)}")

            closed_X_R = None

            for p in probes:
                p['rngs'] = rngs
                all_msgs.append(p['start_msg'])
                res = probe_test_batch(**p)
                all_msgs.append(res[0])
                results.append(res[1])
                # NOTE: the closed pooled probe is recognised by identity with
                # the module-level `closedpool` dict; its results feed the
                # Azuma and randomization checks below.
                if p is closedpool:
                    closed_X_R = res[4]
                    for p2 in randomization_pvalue_params:
                        p2['X_closed'] = closed_X_R

            if closed_X_R is not None:
                closed_seqs = collect_closed_sequences(PROMPTS, rngs)
                az = azuma_coverage_from_sequences(closed_seqs, alpha=ALPHA)
                az_msg = f"Azuma coverage={az['covered']}/{az['total']} ({100*az['frac']:.1f}%)"
                print(az_msg)
                all_msgs.append(az_msg)

            probe_results = np.concatenate(results, axis=0)
            # np.rec is the public home of fromarrays; np.core is a private,
            # deprecated path.
            probe_results = np.rec.fromarrays(
                [np.full(probe_results.shape[0], modelname),
                 probe_results[:,0],
                 probe_results[:,1],
                 probe_results[:,2],
                 probe_results[:,3],
                 probe_results[:,4]],
                names='model, test, prompt, seed, step, result'
            )

            for p in randomization_pvalue_params:
                all_msgs.append(p['start_msg'])
                msg, _ = randomization_pvalue_test(**p)
                all_msgs.append(msg)

            all_results.append(probe_results)
        finally:
            # Always release the model, even if one of the tests failed.
            unload_model()

    all_results = np.concatenate(all_results, axis=0)
    all_msgs = np.array(all_msgs, dtype=object)
    return all_results, all_msgs

# ---------------- Test configurations ----------------
# Each probe dict below is passed to probe_test_batch as keyword arguments
# (probe_test_batch(**cfg)), so its keys must match that function's parameters.

# Small pilot run: used only to estimate the spread and a required sample size.
pilot = {
    'start_msg': "=== Pilot (closed, pooled over first 2 prompts × 1 seed) ===",
    'prompts': PROMPTS[:2],
    'rngs': [RNGS[0]],
    'eps_h': EPS_H,
    'mode': 'closed',
    'is_pilot': True,
    'needs_verdict': False
}
# Main closed probe. execute_tests recognises this dict by identity and reuses
# its results for the Azuma coverage and randomization checks.
closedpool = {
    'start_msg': "\n=== Closed probe (pooled across prompts × seeds) ===",
    'prompts': PROMPTS,
    'rngs': RNGS,
    'eps_h': EPS_H,
    'mode': 'closed',
    'is_pilot': False,
    'needs_verdict': True
}
# Same as the closed probe but with independent sampling noise per arm.
openpool = {
    'start_msg': "\n=== Open probe (pooled across prompts × seeds) ===",
    'prompts': PROMPTS,
    'rngs': RNGS,
    'eps_h': EPS_H,
    'mode': 'open',
    'is_pilot': False,
    'needs_verdict': True
}
# Placebo: zero nudge, so all three arms receive identical embeddings.
placebo = {
    'start_msg': "\n=== Placebo ε=0 (closed, pooled) ===",
    'prompts': PROMPTS,
    'rngs': RNGS,
    'eps_h': 0.0,
    'mode': 'closed',
    'is_pilot': False,
    'needs_verdict': False
}
# Keyword arguments for randomization_pvalue_test. 'X_closed' is filled in by
# execute_tests once the closed probe has run; the generator is spawned from
# BASE_RNG so it is independent of the per-seed generators.
randomization_pvalue_params = {
    'start_msg': "\n=== Label randomization p-value (closed, pooled) ===",
    'X_closed': None,
    'B': 2000,
    'np_rng': BASE_RNG.spawn(1)[0]
}

# Probes are run in this order.
probe_params = [closedpool, openpool, placebo]

# Bundle consumed by execute_tests.
test_params = {
    'pilot_test': pilot,
    'probe_test': probe_params,
    'randomization_pvalue_test': [randomization_pvalue_params]
}

# Models to audit; each one is loaded, tested and unloaded in turn.
models = [MODEL_NAME]

# ---------------- Run ----------------
# Log the environment so a run can be reproduced later.
print("Torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("CUDA version:", torch.version.cuda)
print("NumPy:", np.__version__)
print("Using device:", DEVICE)
print("Using base seed:", BASE_SEED)

# test_result: record array of all probe rows; msgs: every printed message.
test_result, msgs = execute_tests(models, test_params)
print("finished")