# -*- coding: utf-8 -*-
"""ablom (Sep 20, 2025, 4:50:48 PM)

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/embedded/projects/prj-prd-data-learning-ddb6/locations/europe-west4/repositories/ce0ee1f7-19db-4562-b14f-52907a2e3e70
"""

# -*- coding: utf-8 -*-


# === Colab single cell: Fast cap-free neutrality audit on a small GPT ===
# Includes:
#   - Closed and open probes (CRN) with KV cache and 3-arm batching
#   - Trajectory-as-agent pooled tests (anytime e-value + fixed-horizon p-value)
#   - Reliability checks: pilot power, placebo ε=0, label randomization p
#   - Layer-as-agent MFG test (finite-difference residual gain per block; small, non-KV for correctness)
# Notes:
#   - No Jacobians. No cap constants. Uses only logits, tokens, CRN.
#   - On CPU: keep defaults. On GPU: raise M_SIB, STEPS_MAX, and consider ADV_DIR=True.
#   - Optional TOPK vocab projection speeds up CPU by ~10–25× with minimal effect on decisions.

!pip -q install transformers>=4.44.0
!pip -q install torch>=2.0.0

# @title
import math
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
from concurrent.futures import ProcessPoolExecutor
import multiprocessing as mp
mp.set_start_method("spawn", force=True)

import os
import subprocess

# Relocate the Hugging Face cache to /content/hf_cache.
# 1. Set environment variables so Hugging Face respects the new cache
# NOTE(review): `transformers` is already imported above this point; confirm
# these variables are still honoured when set after the import.
os.environ["HF_HOME"] = "/content/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "/content/hf_cache"

# 2. Make sure the directory exists
os.makedirs("/content/hf_cache", exist_ok=True)

# 3. If anything is already cached in /root/.cache/huggingface, move it
#    (`2>/dev/null || true` discards errors, so a missing or empty source
#    directory is not fatal; the return code is not checked either)
subprocess.run("mkdir -p /content/hf_cache && mv /root/.cache/huggingface/* /content/hf_cache/ 2>/dev/null || true", shell=True)

# 4. Replace the old cache folder with a symlink
#    (destructive: `rm -rf` removes whatever is still at the old path before
#    the link to /content/hf_cache is created)
subprocess.run("rm -rf /root/.cache/huggingface && ln -s /content/hf_cache /root/.cache/huggingface", shell=True)

# ---------------- Config ----------------
# Module-level settings read as globals by the probe and test functions below.
MODEL_NAME   = "distilgpt2"   # model loaded by the warm-up section
DEVICE       = "cuda" if torch.cuda.is_available() else "cpu"

# Probe params
T_SOFT       = 1.7            # temperature passed to softmax_T32
EPS_H        = 5e-3           # ε hidden nudge
M_SIB        = 16             # siblings per step for trajectory tests (raise on GPU)
STEPS_MAX    = 16             # decoding steps per run (raise on GPU)
PROMPT_LEN   = 96             # prompts are truncated to this many tokens
ADV_DIR      = False          # True = adversarial direction at inject index (costs one backward per prompt)

# Optional CPU speedup: vocab top-K projection
TOPK         = 2000 if DEVICE == "cpu" else None  # None or integer like 2000; set None on GPU

# Params for rng
# NOTE: statement order matters here and below -- SEEDS is drawn from BASE_RNG,
# so any additional draw from BASE_RNG before this point changes the seeds.
BASE_SEED = 12345
BASE_RNG = np.random.default_rng(BASE_SEED) if BASE_SEED is not None else np.random.default_rng()
NUM_SEEDS = 2 # increase for more power
SEEDS = BASE_RNG.integers(0, 2**32, size=NUM_SEEDS)
# Each entry is a (torch.Generator, numpy Generator, seed) triple; the
# generators are stateful, so reusing an entry continues its random stream.
RNGS = [(torch.Generator(device=DEVICE).manual_seed(int(s)),np.random.default_rng(int(s)),s) for s in SEEDS]

# Experiment params
ALPHA        = 0.05           # threshold compared against the fixed-horizon p-value
PROMPTS = [
  "Discuss the main factors that influenced the Industrial Revolution in Britain and their long-term consequences for global trade and technology.",
  "Summarize the core ideas of Maxwell's equations and how they unify electricity, magnetism, and light, with a brief modern application.",
  "Explain the chain rule of calculus and give a nontrivial example involving nested functions from physics.",
  "Describe the key mechanisms of photosynthesis and why the Z-scheme is important for energy transfer.",
  "Provide a short overview of reinforcement learning, including policy gradients and value-based methods, and compare their trade-offs."
]

# Nonadjustable MFG layer-as-agent params (small for speed; uses full forward, not KV)
MFG_M_SIB = 8                 # siblings per step in run_prompt_Xseq_layergain
MFG_STEPS = 8                 # default number of steps in run_prompt_Xseq_layergain

# @title
# ---------------- Utilities ----------------
def softmax_T32(logits, T):
    """Return the softmax of ``logits / T`` taken over the last dimension.

    The scaled logits are shifted by their per-row maximum before the softmax
    is applied.
    """
    scaled = logits / T
    row_max = scaled.amax(dim=-1, keepdim=True)
    return (scaled - row_max).softmax(dim=-1)

def js_div(p, q, eps=1e-7):
    """Jensen-Shannon divergence between ``p`` and ``q`` along the last axis.

    Both inputs are clamped to ``[eps, 1.0]`` (without renormalising) so the
    logarithms stay finite. Inputs are ``[..., V]``; the result drops the
    last axis.
    """
    p_safe = p.clamp(eps, 1.0)
    q_safe = q.clamp(eps, 1.0)
    log_mix = torch.log(0.5 * (p_safe + q_safe))
    # KL(p || m) and KL(q || m) against the shared mixture m.
    kl_p = torch.sum(p_safe * (torch.log(p_safe) - log_mix), dim=-1)
    kl_q = torch.sum(q_safe * (torch.log(q_safe) - log_mix), dim=-1)
    return 0.5 * kl_p + 0.5 * kl_q

def sample_with_uniform(P, u):
    """Inverse-CDF sampling: map each uniform in ``u`` to an index of ``P``.

    P: [V] probability vector; u: [m] values in [0, 1).
    Returns a long tensor of m indices, clamped into ``[0, V-1]``.
    """
    cdf = P.cumsum(dim=-1)
    queries = u.unsqueeze(-1)
    picks = torch.searchsorted(cdf, queries, right=False).squeeze(-1)
    last_index = P.shape[-1] - 1
    # Clamp because the search can return V when a query exceeds cdf[-1].
    return picks.clamp(min=0, max=last_index).long()

def e_process_and_pval_numpy(X):
    """Summarise a sequence ``X`` with a running e-value and a normal p-value.

    For every prefix of ``X`` the function evaluates
    ``exp(max_λ(λ·S − ½·λ²·V))`` over a 50-point geometric grid of λ in
    [1e-3, 1], where S is the running sum of X and V is the running sum of
    squared first differences (the first element counts as its own
    difference). ``Emax`` is the largest such value seen, starting from 1.

    The fixed-horizon statistic is ``Z = sum(X) / sqrt(sum(X**2) + 1e-12)``
    with a two-sided normal p-value.

    Returns a dict with keys T, mean, sd (sample std, ddof=1), Emax, Z, p.
    """
    lambdas = np.geomspace(1e-3, 1.0, 50)
    running_sum = 0.0
    sq_increments = 0.0
    e_max = 1.0
    previous = None
    for value in X:
        step = value if previous is None else value - previous
        sq_increments += step * step
        running_sum += value
        log_e = lambdas * running_sum - 0.5 * (lambdas ** 2) * max(sq_increments, 1e-12)
        e_max = max(e_max, math.exp(np.max(log_e)))
        previous = value

    total = float(np.sum(X))
    root_sum_sq = math.sqrt(np.sum(np.square(X)) + 1e-12)
    if root_sum_sq > 0:
        z_score = total / root_sum_sq
    else:
        z_score = 0.0
    normal_cdf = 0.5 * (1 + math.erf(abs(z_score) / math.sqrt(2)))
    p_two_sided = 2 * (1 - normal_cdf)

    return {
        "T": len(X),
        "mean": float(np.mean(X)),
        "sd": float(np.std(X, ddof=1)),
        "Emax": float(e_max),
        "Z": float(z_score),
        "p": float(p_two_sided),
    }

def randomization_pvalue(X, B=2000, np_rng = None):
    """Sign-flip randomisation p-value for the sum of ``X``.

    Draws ``B`` random ±1 sign vectors, recomputes the signed sum for each,
    and returns ``{"p_emp": fraction}`` where the fraction counts draws whose
    absolute sum is at least the observed absolute sum. A fresh unseeded
    generator is used when ``np_rng`` is None.
    """
    generator = np.random.default_rng() if np_rng is None else np_rng
    observed = float(np.sum(X))
    signs = generator.choice([-1.0, 1.0], size=(B, len(X)))
    flipped_sums = signs @ X
    at_least_as_extreme = np.abs(flipped_sums) >= abs(observed)
    return {"p_emp": float(np.mean(at_least_as_extreme))}

def power_plan_from_pilot(Xpilot, d_target=5e-3, alpha=0.05, power=0.8):
    """Estimate the sample size needed to detect a mean shift of ``d_target``.

    Uses the two-sided normal approximation
    ``N = ((z_{1-alpha/2} + z_{power}) * sd / d_target) ** 2`` where ``sd`` is
    the root-mean-square of the pilot sample (plus a 1e-12 guard).

    Args:
        Xpilot: pilot observations (array-like of floats, non-empty).
        d_target: effect size to detect; values below 1e-12 are floored.
        alpha: two-sided significance level, in (0, 1).
        power: desired power, in (0, 1).

    Returns:
        dict with ``sd`` (float) and ``N_required`` (int, rounded up).

    Raises:
        ValueError: if ``Xpilot`` is empty, or ``alpha``/``power`` lie outside
            (0, 1) (raised by ``NormalDist.inv_cdf``).
    """
    from statistics import NormalDist

    Xpilot = np.asarray(Xpilot, dtype=float)
    if Xpilot.size == 0:
        raise ValueError("Xpilot must contain at least one observation")

    sd = float(np.sqrt(np.mean(np.square(Xpilot)) + 1e-12))
    # Derive the quantiles from the arguments; previously the values for
    # alpha=0.05 / power=0.8 were hard-coded and the parameters were ignored.
    std_normal = NormalDist()
    z_alpha = std_normal.inv_cdf(1.0 - alpha / 2.0)
    z_power = std_normal.inv_cdf(power)
    N = ((z_alpha + z_power) * sd / max(d_target, 1e-12))**2
    return dict(sd=sd, N_required=int(math.ceil(N)))

def topk_project(P, k=None):
    """Keep the ``k`` largest entries along the last axis and renormalise.

    Args:
        P: tensor of probabilities, shape [V] or [..., V].
        k: number of entries to keep; ``None`` returns ``P`` unchanged. Values
            larger than V are capped at V instead of raising inside
            ``torch.topk``.

    Returns:
        Tensor shaped like ``P`` with all but the top-k entries zeroed and each
        row rescaled to sum to (approximately) one.
    """
    if k is None:
        return P
    k = min(int(k), P.shape[-1])
    vals, idx = torch.topk(P, k, dim=-1)
    # scatter_ along the last axis handles both the 1-D and batched cases.
    Pk = torch.zeros_like(P).scatter_(-1, idx, vals)
    return Pk / (Pk.sum(dim=-1, keepdim=True) + 1e-12)

import gc
import ctypes

def prompt_length(prompt, tokenizer, add_special_tokens=True):
    """
    Count the tokens the tokenizer produces for ``prompt``.

    add_special_tokens: forwarded to the tokenizer call, so it controls
    whether tokens such as <bos>/<eos> are included in the count.
    """
    encoding = tokenizer(prompt, add_special_tokens=add_special_tokens)
    input_ids = encoding["input_ids"]
    return len(input_ids)

def estimate_max_batch(model, device, tokz, prompts, seed=42, max_trials=10):
    """
    Estimate the largest batch size whose forward pass fits in memory.

    The batch size is doubled (at most ``max_trials`` times) until a forward
    pass raises an out-of-memory ``RuntimeError``; the exact limit is then
    located by bisection between the last success and the first failure.

    model      : causal LM to probe (moved to ``device``)
    device     : 'cuda' or 'cpu'
    tokz       : tokenizer used to measure the prompts
    prompts    : list of prompt strings; the longest one sets the length
    seed       : seed for the generator that draws the dummy token ids
    max_trials : upper bound on the number of doubling attempts, so the
                 search terminates even when no OOM is ever raised

    Returns (prompt_len, max_batch). ``max_batch`` is 0 if even a single
    sequence does not fit, and the largest size tried if nothing failed.
    Non-OOM ``RuntimeError``s are re-raised.
    """
    prompt_len = max(prompt_length(prompt, tokz) for prompt in prompts)
    # Dummy input is [batch, seq_len]: the prompt plus the decoding steps.
    # (The previous 3-D shape (batch, prompt_len, STEPS_MAX) did not match
    # the documented [batch, seq_len] layout.)
    seq_len = prompt_len + STEPS_MAX
    gen = torch.Generator(device=device).manual_seed(seed)
    model = model.to(device)

    def _fits(batch_size):
        # True if a no-grad forward with this batch size completes.
        try:
            dummy_input = torch.randint(
                low=0, high=1000, size=(batch_size, seq_len),
                device=device, generator=gen,
            )
            with torch.no_grad():
                model(dummy_input)
            return True
        except RuntimeError as e:
            if "out of memory" in str(e):
                return False
            raise
        finally:
            if device != 'cpu':
                torch.cuda.empty_cache()

    # Phase 1: exponential growth, bounded by max_trials.
    low, high = 0, None
    batch_size = 1
    for _ in range(max_trials):
        if _fits(batch_size):
            low = batch_size
            batch_size *= 2
        else:
            high = batch_size
            break

    if high is None:
        # Never ran out of memory within the allowed trials.
        return prompt_len, low

    # Phase 2: binary search for the exact max in (low, high).
    while low < high - 1:
        mid = (low + high) // 2
        if _fits(mid):
            low = mid
        else:
            high = mid

    return prompt_len, low

def unload_model(refs=None):
    """
    Cleanly unload a Hugging Face / PyTorch model to free memory.
    Works on both CPU and GPU setups.

    refs: optional iterable describing what to drop from this module's
    globals. Each item is either
      * a ``str``  -- treated as the name of a global to remove, or
      * an object -- every global bound to that exact object is removed.
    Plain scalars (int/float/bool/complex/bytes) and None are skipped, because
    identity matching on them could delete unrelated globals that merely hold
    the same cached value.

    After dropping the references the function runs the garbage collector,
    empties the CUDA cache when CUDA is available, and asks glibc to return
    freed heap memory to the OS when ``libc.so.6`` can be loaded.
    """
    module_globals = globals()
    scalar_types = (int, float, complex, bool, bytes)

    for ref in (refs or ()):
        if isinstance(ref, str):
            module_globals.pop(ref, None)
            continue
        if ref is None or isinstance(ref, scalar_types):
            continue
        # Collect names first: the dict must not change size while iterating.
        doomed = [name for name, obj in module_globals.items() if obj is ref]
        for name in doomed:
            del module_globals[name]

    gc.collect()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()  # helps free memory across processes

    try:
        ctypes.CDLL("libc.so.6").malloc_trim(0)
    except (OSError, AttributeError):
        # Not glibc (e.g. macOS/Windows/musl): nothing to trim.
        pass

def get_model_data(model_name):
    """Load ``model_name`` and return the pieces the probes use.

    Before loading, any module-level globals left over from a previous model
    are handed to ``unload_model`` by name.

    Returns:
        (tokz, model, blocks, ln_f, wte, wpe, max_prompt_length, max_batch)

    Raises:
        Any exception raised while loading or sizing the model. A message is
        printed and memory is cleaned up first; the error is then re-raised
        instead of silently returning ``None`` (which made callers fail later
        with an unrelated unpacking error).
    """
    existing = ['tokz', 'model', 'blocks', 'ln_f', 'wte', 'wpe', 'max_prompt_length', 'max_batch']
    to_del = [vname for vname in existing if vname in globals()]
    unload_model(to_del)
    try:
        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32).to(DEVICE).eval()
        tokz = AutoTokenizer.from_pretrained(model_name)
        if tokz.pad_token is None:
            tokz.pad_token = tokz.eos_token
        blocks = model.transformer.h
        ln_f = model.transformer.ln_f
        wte, wpe = model.transformer.wte, model.transformer.wpe
        max_prompt_length, max_batch = estimate_max_batch(model, DEVICE, tokz, PROMPTS)
        return tokz, model, blocks, ln_f, wte, wpe, max_prompt_length, max_batch
    except Exception:
        # `except Exception` (not bare) so KeyboardInterrupt/SystemExit pass through.
        unload_model()
        print(f"Model could not be loaded: {model_name}")
        raise

# @title
# ---------------- Model ----------------
# Warm-up: load MODEL_NAME once, expose its sub-modules as module-level
# globals (the probe functions below read `model`, `blocks`, `ln_f`, `wte`,
# `wpe` and `tokz` as globals), measure the batch capacity, then release it.
tokz  = AutoTokenizer.from_pretrained(MODEL_NAME)
# No pad token defined: fall back to the EOS token.
if tokz.pad_token is None: tokz.pad_token = tokz.eos_token
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float32).to(DEVICE).eval()
blocks = model.transformer.h      # transformer blocks
ln_f   = model.transformer.ln_f   # final layer norm
wte, wpe = model.transformer.wte, model.transformer.wpe  # token / position embeddings
max_prompt_length, max_batch = estimate_max_batch(model, DEVICE, tokz, PROMPTS)
# Release the warm-up model again; execute_tests reloads models per run.
unload_model([tokz, model, blocks, ln_f, wte, wpe, max_prompt_length, max_batch])

# @title
# ---------------- KV-cache helpers (fast trajectory tests) ----------------
@torch.no_grad()
def build_past_for_arm(model, tokens, inj_idx, v_dir, eps_h):
    """Run the prompt once and return its KV cache plus last-position logits.

    The input embeddings are built from the model's token and position
    tables. When ``eps_h`` is non-zero, ``eps_h * v_dir`` is added to the
    embedding of the first batch row at position ``inj_idx`` before the
    forward pass.

    Returns ``(past_key_values, logits[:, -1, :])``.
    """
    seq_len = tokens.shape[1]
    positions = torch.arange(seq_len, device=tokens.device).unsqueeze(0)
    transformer = model.transformer
    hidden = transformer.wte(tokens) + transformer.wpe(positions)
    if eps_h != 0.0:
        # Work on a copy so the nudge is written to a private tensor.
        hidden = hidden.clone()
        hidden[0, inj_idx, :] = hidden[0, inj_idx, :] + eps_h * v_dir
    result = model(inputs_embeds=hidden, use_cache=True)
    last_logits = result.logits[:, -1, :]
    return result.past_key_values, last_logits

@torch.no_grad()
def one_step_from_past(model, past_kv, next_ids):
    """Feed ``next_ids`` on top of a cached context.

    A plain tuple cache is first converted with
    ``DynamicCache.from_legacy_cache``; any other cache object is passed to
    the model as-is.

    Returns ``(past_key_values, logits[:, -1, :])`` from the forward pass.
    """
    cache = past_kv
    if type(cache) is tuple:
        cache = DynamicCache.from_legacy_cache(cache)
    result = model(input_ids=next_ids, past_key_values=cache, use_cache=True)
    return result.past_key_values, result.logits[:, -1, :]

def repeat_past(past, m):
    """Tile every (key, value) pair of ``past`` ``m`` times along dim 0.

    Each layer entry is a (k, v) pair of 4-D tensors [B, H, S, D]; the result
    is a tuple of pairs shaped [B*m, H, S, D].
    """
    tiled_layers = []
    for keys, values in past:
        tiled_layers.append((keys.repeat(m, 1, 1, 1), values.repeat(m, 1, 1, 1)))
    return tuple(tiled_layers)

# @title
# ---------------- Direction v (adversarial or random) ----------------
def get_v_dir(ids, inj_idx, torch_rng):
    """Return a unit-norm perturbation direction in embedding space.

    With ``ADV_DIR`` false the direction is a normalised Gaussian draw from
    ``torch_rng``. Otherwise it is the gradient, with respect to the input
    embedding at position ``inj_idx``, of the KL divergence between the
    tempered next-token distribution at that position and a one-hot
    distribution on the EOS token.

    Args:
        ids: token ids, shape [1, T0].
        inj_idx: position whose embedding gradient is taken.
        torch_rng: ``torch.Generator`` used for every random draw here.

    Returns:
        1-D tensor of size hidden_dim with (approximately) unit norm.
    """
    if not ADV_DIR:
        v = torch.randn(wte.weight.shape[1], device=DEVICE, generator=torch_rng)
        return v/(v.norm()+1e-12)

    # one backward pass to get adversarial direction at inj
    T0 = ids.shape[1]
    pos = torch.arange(T0, device=DEVICE).unsqueeze(0)
    with torch.enable_grad():
        emb = (wte(ids)+wpe(pos)).detach().clone().requires_grad_(True)
        x = emb
        for blk in blocks:
            x = blk(x)[0]
        logits = model.lm_head(ln_f(x))
        P_all = softmax_T32(logits[0], T_SOFT)
        Vocab = P_all.shape[-1]
        # `is None` rather than `or`: an EOS id of 0 is a valid id.
        eos = tokz.eos_token_id
        if eos is None:
            eos = Vocab - 1
        q = torch.zeros(Vocab, device=DEVICE)
        q[eos] = 1.0
        p_inj = P_all[inj_idx]
        KL = torch.sum(p_inj*(torch.log(p_inj+1e-16)-torch.log(q+1e-16)))
        # autograd.grad only differentiates w.r.t. `emb`; KL.backward() would
        # also accumulate .grad on every model parameter across calls.
        (grad_emb,) = torch.autograd.grad(KL, emb)
        v = grad_emb[0, inj_idx, :].detach()

    if v.norm() < 1e-12:
        # Degenerate gradient: fall back to a random direction drawn from the
        # caller's generator so runs stay reproducible.
        v = torch.randn(v.shape, device=v.device, dtype=v.dtype, generator=torch_rng)
    return v/(v.norm()+1e-12)

# @title
# ---------------- Trajectory runner with KV (mode='closed' or 'open') ----------------
def run_prompt_Xseq_kv(prompt, steps=STEPS_MAX, eps_h=EPS_H, mode="closed", rng=None):
    """Run one prompt through the three-arm KV-cache probe and return X_t.

    Three contexts are built from the same prompt: a baseline and two with the
    embedding at the middle token nudged by +eps_h / -eps_h along ``v_dir``.
    At each of ``steps`` steps, M_SIB sibling tokens are sampled with shared
    uniforms, next-step distributions are computed, and
    ``X_t = 0.5 * ((D_next_p - D_curr_p) - (D_next_m - D_curr_m))`` is
    recorded, where D is the JS divergence to the baseline arm.

    Args:
        prompt: prompt text (tokenised without special tokens, truncated to
            PROMPT_LEN tokens).
        steps: number of decoding steps.
        eps_h: nudge magnitude.
        mode: "closed" (each arm samples from its own distribution) or "open"
            (baseline tokens drive all arms).
        rng: None, or a 3-item (torch.Generator, numpy Generator, seed)
            sequence. Only the torch generator is used in this function.

    Returns:
        numpy float array of length ``steps``.
    """
    if rng is None:
        torch_rng = torch.Generator(device=DEVICE)
        np_rng = np.random.default_rng()
        rng = (torch_rng, np_rng)
    else:
        # NOTE(review): this unpack requires exactly three items; a 2-tuple
        # like the one built in the branch above would raise ValueError here.
        torch_rng, np_rng, seed = rng
    assert mode in {"closed","open"}
    ids = tokz(prompt, return_tensors="pt", add_special_tokens=False)["input_ids"][:,:PROMPT_LEN].to(DEVICE)
    # Nudge is injected at the middle token of the (truncated) prompt.
    T0  = ids.shape[1]; inj = T0//2
    v_dir = get_v_dir(ids, inj, torch_rng=torch_rng)

    # initial past for baseline and both nudges
    past0, L0 = build_past_for_arm(model, ids, inj, v_dir, eps_h=0.0)
    pastP, Lp = build_past_for_arm(model, ids, inj, v_dir, eps_h=+eps_h)
    pastM, Lm = build_past_for_arm(model, ids, inj, v_dir, eps_h=-eps_h)

    # Tempered (and optionally top-K projected) next-token distributions per arm.
    P0 = topk_project(softmax_T32(L0, T_SOFT).squeeze(0), TOPK)
    Pp = topk_project(softmax_T32(Lp, T_SOFT).squeeze(0), TOPK)
    Pm = topk_project(softmax_T32(Lm, T_SOFT).squeeze(0), TOPK)

    # Current divergence of each nudged arm from the baseline.
    D_curr_p = float(js_div(Pp.unsqueeze(0), P0.unsqueeze(0)))
    D_curr_m = float(js_div(Pm.unsqueeze(0), P0.unsqueeze(0)))

    X_list = []
    for t in range(steps):
        # One shared vector of uniforms per step: all arms sample with the
        # same random numbers.
        u = torch.rand(M_SIB, device=DEVICE, generator=torch_rng)
        if mode == "closed":
            tok0 = sample_with_uniform(P0, u)
            tokp = sample_with_uniform(Pp, u)
            tokm = sample_with_uniform(Pm, u)

            # Tile each single-sequence cache to M_SIB rows, then advance
            # every sibling by its sampled token (caches are discarded).
            past0_rep = repeat_past(past0, M_SIB)
            pastP_rep = repeat_past(pastP, M_SIB)
            pastM_rep = repeat_past(pastM, M_SIB)
            _, L02 = one_step_from_past(model, past0_rep, tok0.view(-1,1))
            _, Lp2 = one_step_from_past(model, pastP_rep, tokp.view(-1,1))
            _, Lm2 = one_step_from_past(model, pastM_rep, tokm.view(-1,1))

            Pb2 = topk_project(softmax_T32(L02, T_SOFT), TOPK)
            Pp2 = topk_project(softmax_T32(Lp2, T_SOFT), TOPK)
            Pm2 = topk_project(softmax_T32(Lm2, T_SOFT), TOPK)
        else:
            tok0 = sample_with_uniform(P0, u)
            past0_rep = repeat_past(past0, M_SIB)
            _, L02 = one_step_from_past(model, past0_rep, tok0.view(-1,1))
            Pb2 = topk_project(softmax_T32(L02, T_SOFT), TOPK)
            # NOTE(review): in open mode both nudged sibling distributions are
            # the baseline's, so D_next_p and D_next_m below are js_div of a
            # tensor with itself. Confirm this is intended rather than
            # stepping pastP/pastM with the baseline tokens.
            Pp2 = Pb2
            Pm2 = Pb2

        # Sibling-averaged divergences after one more token.
        D_next_p = float(js_div(Pp2, Pb2).mean())
        D_next_m = float(js_div(Pm2, Pb2).mean())
        X_t = 0.5*((D_next_p - D_curr_p) - (D_next_m - D_curr_m))
        X_list.append(X_t)

        # advance single-token contexts (first sibling) to refresh P0,Pp,Pm
        tok0_1 = tok0[0:1].view(1,1)
        past0, L0 = one_step_from_past(model, past0, tok0_1)
        if mode == "closed":
            pastP, Lp = one_step_from_past(model, pastP, tokp[0:1].view(1,1))
            pastM, Lm = one_step_from_past(model, pastM, tokm[0:1].view(1,1))
        else:
            # Open mode: the nudged contexts follow the baseline's token.
            pastP, Lp = one_step_from_past(model, pastP, tok0_1)
            pastM, Lm = one_step_from_past(model, pastM, tok0_1)

        P0 = topk_project(softmax_T32(L0, T_SOFT).squeeze(0), TOPK)
        Pp = topk_project(softmax_T32(Lp, T_SOFT).squeeze(0), TOPK)
        Pm = topk_project(softmax_T32(Lm, T_SOFT).squeeze(0), TOPK)
        D_curr_p = float(js_div(Pp.unsqueeze(0), P0.unsqueeze(0)))
        D_curr_m = float(js_div(Pm.unsqueeze(0), P0.unsqueeze(0)))

    return np.array(X_list, dtype=float)

# Old, sequential version
def pooled_X(prompts, rngs, mode="closed", eps_h=EPS_H):
    """Concatenate the X sequences of every (rng, prompt) pair.

    Runs are executed sequentially, iterating prompts inside each rng, with
    STEPS_MAX steps per run. Returns an empty array when there are no runs.
    """
    sequences = [
        run_prompt_Xseq_kv(prompt, steps=STEPS_MAX, eps_h=eps_h, mode=mode, rng=rng)
        for rng in rngs
        for prompt in prompts
    ]
    if not sequences:
        return np.array([])
    return np.concatenate(sequences, axis=0)


# def x_worker(args):
#     p, seed, mode, eps_h = args
#     torch_rng = torch.Generator(device=DEVICE).manual_seed(seed)
#     np_rng = np.random.default_rng(seed)
#     rng = (torch_rng,np_rng)
#     return run_prompt_Xseq_kv(
#         p,
#         steps=STEPS_MAX,
#         eps_h=eps_h,
#         mode=mode,
#         rng=rng,
#     )

# def pooled_X(prompts, rngs, mode="closed", eps_h=EPS_H):
#   from concurrent.futures import ProcessPoolExecutor
#   import numpy as np

#   tasks = [(p, seed[2], mode, eps_h) for seed in rngs for p in prompts]
#   print(tasks)
#   with ProcessPoolExecutor() as ex:
#       results = list(ex.map(x_worker, tasks))

#   return np.concatenate(results, axis=0) if results else np.array([])

# @title


# ---------------- Layer-as-agent MFG (finite-difference residual gain; small, non-KV) ----------------
def forward_with_gain_full(batch_tokens, layer_idx=None, delta_g=0.0):
    """Full (non-cached) forward pass with an optional residual gain.

    For the block at ``layer_idx`` the block's change to the hidden state
    (output minus input) is rescaled by ``1 + delta_g``. With
    ``layer_idx=None`` no block is rescaled. Returns the LM-head logits for
    every position.
    """
    seq_len = batch_tokens.shape[1]
    positions = torch.arange(seq_len, device=DEVICE).unsqueeze(0)
    hidden = wte(batch_tokens) + wpe(positions)
    for depth, block in enumerate(blocks):
        block_input = hidden
        hidden = block(hidden)[0]
        if layer_idx is None or depth != layer_idx:
            continue
        # Rescale only what this block added on top of its input.
        residual = hidden - block_input
        hidden = block_input + (1.0 + delta_g) * residual
    return model.lm_head(ln_f(hidden))

@torch.no_grad()
def run_prompt_Xseq_layergain(prompt, steps=MFG_STEPS, eps_h=EPS_H, layer_idx=None, delta_g=0.0, rng = None):
    """Non-KV probe runner with an optional residual gain on one block.

    Mirrors the closed-mode trajectory probe, but every distribution comes
    from a full forward pass in which block ``layer_idx`` has its residual
    update scaled by ``1 + delta_g``.

    Args:
        prompt: prompt text (tokenised without special tokens, truncated to
            PROMPT_LEN tokens).
        steps: number of decoding steps.
        eps_h: embedding nudge magnitude at the middle prompt token.
        layer_idx: block whose residual is rescaled, or None for no rescaling.
        delta_g: relative gain change applied at ``layer_idx``.
        rng: None, or a sequence whose first item is the ``torch.Generator``
            to draw from. Both the 2-item and the 3-item
            (torch, numpy, seed) forms used in this file are accepted.

    Returns:
        numpy float array of length ``steps`` holding X_t.

    The whole function runs under ``torch.no_grad()`` because none of its own
    forwards need gradients; ``get_v_dir`` re-enables grad internally when it
    needs the adversarial direction.
    """
    # Only the torch generator is used; index instead of unpacking a fixed
    # arity so a 2-tuple rng no longer raises ValueError.
    torch_rng = torch.Generator(device=DEVICE) if rng is None else rng[0]

    # Non-KV small runner for MFG actions
    ids = tokz(prompt, return_tensors="pt", add_special_tokens=False)["input_ids"][:,:PROMPT_LEN].to(DEVICE)
    T0  = ids.shape[1]; inj = T0//2

    # The random branch keeps its original two draws (vector, then scalar) so
    # the generator stream is unchanged; the scale cancels in the
    # normalisation on the next line.
    v_dir = get_v_dir(ids, inj, torch_rng=torch_rng) if ADV_DIR else (torch.randn(wte.weight.shape[1], device=DEVICE, generator=torch_rng)/(torch.randn(1,device=DEVICE, generator=torch_rng).abs()+1e-12))
    v_dir = v_dir/(v_dir.norm()+1e-12)

    base = ids.clone()

    # current distributions
    def arm_dist(sign):
        # Next-token distribution of the current `base` context for one arm
        # (sign 0 = baseline, +1 / -1 = nudged embedding at `inj`).
        pos = torch.arange(base.shape[1], device=DEVICE).unsqueeze(0)
        emb = wte(base)+wpe(pos)
        if sign!=0 and eps_h!=0.0:
            emb = emb.clone()
            emb[0, inj, :] = emb[0, inj, :] + sign*eps_h*v_dir
        # forward with gain
        x = emb
        for j in range(len(blocks)):
            x_in = x
            x = blocks[j](x)[0]
            if layer_idx is not None and j == layer_idx:
                delta = x - x_in
                x = x_in + (1.0 + delta_g) * delta
        h = ln_f(x); lg = model.lm_head(h)
        return softmax_T32(lg[0,-1,:], T_SOFT)

    P0 = topk_project(arm_dist(0), TOPK)
    Pp = topk_project(arm_dist(+1), TOPK)
    Pm = topk_project(arm_dist(-1), TOPK)
    D_curr_p = float(js_div(Pp.unsqueeze(0), P0.unsqueeze(0)))
    D_curr_m = float(js_div(Pm.unsqueeze(0), P0.unsqueeze(0)))

    X_list=[]
    for t in range(steps):
        # Shared uniforms across the three arms.
        u = torch.rand(MFG_M_SIB, device=DEVICE, generator=torch_rng)
        tok0 = sample_with_uniform(P0, u)
        tokp = sample_with_uniform(Pp, u)
        tokm = sample_with_uniform(Pm, u)

        # sibling batches
        mb = base.repeat(MFG_M_SIB,1)
        mb0 = torch.cat([mb, tok0.unsqueeze(1)],1)
        mbp = torch.cat([mb, tokp.unsqueeze(1)],1)
        mbm = torch.cat([mb, tokm.unsqueeze(1)],1)

        # next-step dists
        # NOTE(review): these forwards start from token ids, so the embedding
        # nudge is not re-applied for the +/- sibling batches; the arms differ
        # here only through their sampled tokens. Confirm this is intended.
        Pb2 = topk_project(softmax_T32(forward_with_gain_full(mb0, layer_idx, delta_g)[:,-1,:], T_SOFT), TOPK)
        Pp2 = topk_project(softmax_T32(forward_with_gain_full(mbp, layer_idx, delta_g)[:,-1,:], T_SOFT), TOPK)
        Pm2 = topk_project(softmax_T32(forward_with_gain_full(mbm, layer_idx, delta_g)[:,-1,:], T_SOFT), TOPK)

        D_next_p = float(js_div(Pp2, Pb2).mean())
        D_next_m = float(js_div(Pm2, Pb2).mean())
        X_t = 0.5*((D_next_p - D_curr_p) - (D_next_m - D_curr_m))
        X_list.append(X_t)

        # advance baseline context by first baseline sibling
        base = torch.cat([base, tok0[0:1].unsqueeze(0)], 1)

        # refresh current dists
        P0 = topk_project(arm_dist(0), TOPK)
        Pp = topk_project(arm_dist(+1), TOPK)
        Pm = topk_project(arm_dist(-1), TOPK)
        D_curr_p = float(js_div(Pp.unsqueeze(0), P0.unsqueeze(0)))
        D_curr_m = float(js_div(Pm.unsqueeze(0), P0.unsqueeze(0)))

    return np.array(X_list, dtype=float)

def mfg_actions(prompt, blocks_to_test, delta_g, rng = None):
    """Finite-difference "action" of each tested block.

    For every block index ``j`` in ``blocks_to_test`` the probe is run with
    the block's residual gain changed by ``-delta_g`` and by ``+delta_g``;
    the action is the central difference of the mean X:
    ``(mean(X+) - mean(X-)) / (2 * delta_g)``.

    Args:
        prompt: prompt text passed to ``run_prompt_Xseq_layergain``.
        blocks_to_test: iterable of block indices.
        delta_g: gain step; must be non-zero (it is a divisor).
        rng: None or a (torch.Generator, numpy Generator, seed) triple. The
            same generators are reused for every run, so each run continues
            the random stream left by the previous one.

    Returns:
        numpy float array with one action per tested block.
    """
    if rng is None:
        # Build the same 3-item shape the runner unpacks; the previous
        # 2-tuple default raised ValueError inside run_prompt_Xseq_layergain.
        rng = (torch.Generator(device=DEVICE), np.random.default_rng(), None)
    A=[]
    for j in blocks_to_test:
        Xm = run_prompt_Xseq_layergain(prompt, layer_idx=j, delta_g=-delta_g, rng=rng)
        Xp = run_prompt_Xseq_layergain(prompt, layer_idx=j, delta_g=+delta_g, rng=rng)
        a_j = (float(np.mean(Xp)) - float(np.mean(Xm))) / (2*delta_g)
        A.append(a_j)
    return np.array(A, dtype=float)

def bootstrap_mu(A, B, np_rng = None):
    """Bootstrap the mean of ``A`` with ``B`` resamples.

    Returns a dict with ``mu_hat`` (mean of A), ``se`` (sample std of the
    resampled means, ddof=1) and ``ci`` (2.5th and 97.5th percentiles of the
    resampled means). An empty ``A`` yields all zeros. A fresh unseeded
    generator is used when ``np_rng`` is None.
    """
    generator = np_rng if np_rng is not None else np.random.default_rng()
    count = len(A)
    if count == 0:
        return {"mu_hat": 0.0, "se": 0.0, "ci": (0.0, 0.0)}

    resampled_means = []
    for _ in range(B):
        draw = generator.choice(A, size=count, replace=True)
        resampled_means.append(float(np.mean(draw)))

    lower = float(np.percentile(resampled_means, 2.5))
    upper = float(np.percentile(resampled_means, 97.5))
    return {
        "mu_hat": float(np.mean(A)),
        "se": float(np.std(resampled_means, ddof=1)),
        "ci": (lower, upper),
    }

# @title
# -------- Tests

def probe_test(start_msg, prompts, rngs, eps_h, mode, is_pilot, needs_verdict):
    """Run one pooled probe, print a one-line summary and return the results.

    Prints ``start_msg``, pools X over ``prompts`` x ``rngs``, and summarises
    it. For a pilot the line reports the power plan (target effect 5e-3,
    power 0.8); otherwise it reports the test statistics. When
    ``needs_verdict`` is set, REJECT/NEUTRAL (p compared with ALPHA) is
    appended.

    Returns ``(msg, X, R, plan)``; ``plan`` is None unless ``is_pilot``.
    """
    print(start_msg)
    X = pooled_X(prompts, rngs, mode=mode, eps_h=eps_h)
    R = e_process_and_pval_numpy(X)

    plan = None
    if not is_pilot:
        msg = f"T={R['T']}  meanX={R['mean']:.3e}  sd={R['sd']:.3e}  Emax={R['Emax']:.3f}  Z={R['Z']:.3f}  p={R['p']:.3f}"
    else:
        plan = power_plan_from_pilot(X, d_target=5e-3, alpha=ALPHA, power=0.8)
        msg = f"T_pilot={R['T']}  sd≈{plan['sd']:.3e}  N_required(d=5e-3, α=0.05, 0.8 pow)≈{plan['N_required']}"

    if needs_verdict:
        if R['p'] <= ALPHA:
            verdict = "REJECT"
        else:
            verdict = "NEUTRAL"
        msg = f"{msg}  verdict={verdict}"

    print(msg)
    return msg, X, R, plan

def randomization_pvalue_test(start_msg, X_closed, B, np_rng):
    """Print and return the sign-flip randomisation check for ``X_closed``.

    Returns ``(msg, rand_chk)`` where ``rand_chk`` is the dict produced by
    ``randomization_pvalue``.
    """
    print(start_msg)
    rand_chk = randomization_pvalue(X_closed, B=B, np_rng=np_rng)
    msg = f"empirical p≈{rand_chk['p_emp']:.3f}  (≈0.5 under neutrality)"
    print(msg)
    return (msg, rand_chk)

def layer_as_agent_test(start_msg, prompt, blocks_to_test, delta_g, boot_b, rng, bootstrap_rng, needs_verdict):
    """Run the layer-as-agent test on one prompt and print a summary line.

    Computes one action per tested block, bootstraps the mean action, and
    classifies the result: NEUTRAL when the 95% interval contains zero,
    otherwise EXPANSIVE / CONTRACTIVE by the sign of the mean. The verdict is
    appended to the printed line only when ``needs_verdict`` is set.

    Returns ``(msg, A, MFG, verdict_mfg)``.
    """
    print(start_msg)
    A = mfg_actions(prompt, blocks_to_test, delta_g=delta_g, rng=rng)
    MFG = bootstrap_mu(A, B=boot_b, np_rng=bootstrap_rng)
    ci_lo, ci_hi = MFG["ci"]

    if ci_lo<=0.0<=ci_hi:
        verdict_mfg = "NEUTRAL"
    elif MFG["mu_hat"]>0:
        verdict_mfg = "EXPANSIVE"
    else:
        verdict_mfg = "CONTRACTIVE"

    msg = f"Blocks tested={len(A)}  μ̂={MFG['mu_hat']:.3e}  SE≈{MFG['se']:.3e}  95% CI=({ci_lo:.3e}, {ci_hi:.3e})"
    if needs_verdict:
        msg = f"{msg}  verdict={verdict_mfg}"
    print(msg)
    return msg, A, MFG, verdict_mfg

def execute_tests(models, test_params):
    """Run the configured test suite once per model.

    For each model name: load it into the module-level globals the probes
    read, run every probe in ``test_params['probe_test']``, feed the closed
    pooled probe's X into the randomisation checks, run the layer-as-agent
    tests, and unload the model again.

    Args:
        models: iterable of Hugging Face model names.
        test_params: dict with the keys 'probe_test',
            'randomization_pvalue_test' and 'layer_as_agent_test', each a list
            of keyword-argument dicts for the matching ``*_test`` function.
            The randomisation dicts are updated in place with 'X_closed'.

    NOTE(review): ``test_params['pilot_test']`` is not executed here, and the
    closed pooled probe is recognised by identity with the module-level
    ``closedpool`` dict.
    """
    global tokz, model, blocks, ln_f, wte, wpe, max_prompt_length, max_batch
    for modelname in models:
        print(f"\n=== Testing model: {modelname} ===\n")
        tokz, model, blocks, ln_f, wte, wpe, max_prompt_length, max_batch = get_model_data(modelname) # Ugly way (for now) to set global vars of model

        try:
            probes = test_params['probe_test']
            randomization_pvalue_params = test_params['randomization_pvalue_test']
            layer_as_agent_params = test_params['layer_as_agent_test']

            trajectory_msg = None
            for probe_cfg in probes:
                res = probe_test(**probe_cfg)
                if probe_cfg is closedpool:
                    trajectory_msg = res[0]
                    # Distinct loop variable: the original reused `p` here and
                    # shadowed the probe being iterated.
                    for rand_cfg in randomization_pvalue_params:
                        rand_cfg['X_closed'] = res[1]

            if trajectory_msg is not None:
                print("\n=== MFG (trajectory-as-agent) closed pooled ===")
                print(trajectory_msg)

            for rand_cfg in randomization_pvalue_params:
                if rand_cfg.get('X_closed') is None:
                    # No closed pooled probe ran, so there is nothing to test.
                    print("Skipping randomization p-value: no closed pooled probe result available")
                    continue
                randomization_pvalue_test(**rand_cfg)

            for agent_cfg in layer_as_agent_params:
                layer_as_agent_test(**agent_cfg)
        finally:
            # Unload to prevent OOM, even when one of the tests raised.
            unload_model([tokz, model, blocks, ln_f, wte, wpe, max_prompt_length, max_batch])

# @title

# ---------------- Test configurations ----------------
# Each dict below holds the keyword arguments for one call of the matching
# *_test function (probe_test, randomization_pvalue_test, layer_as_agent_test).
# The RNGS entries are shared, stateful generators: every probe that uses them
# continues the random stream where the previous one stopped.
pilot = {
    'start_msg': "=== Pilot (closed, pooled over first 2 prompts × 1 seed) ===",
    'prompts': PROMPTS[:2],
    'rngs': [RNGS[0]],
    'eps_h': EPS_H,
    'mode': 'closed',
    'is_pilot': True,
    'needs_verdict': False
}
# execute_tests recognises this dict by identity to forward its X to the
# randomisation check.
closedpool = {
    'start_msg': "\n=== Closed probe (pooled across prompts × seeds) ===",
    'prompts': PROMPTS,
    'rngs': RNGS,
    'eps_h': EPS_H,
    'mode': 'closed',
    'is_pilot': False,
    'needs_verdict': True,
}
openpool = {
    'start_msg': "\n=== Open probe (pooled across prompts × seeds) ===",
    'prompts': PROMPTS,
    'rngs': RNGS,
    'eps_h': EPS_H,
    'mode': 'open',
    'is_pilot': False,
    'needs_verdict': True,
}
# Same as the closed probe but with a zero nudge.
placebo = {
    'start_msg': "\n=== Placebo ε=0 (closed, pooled) ===",
    'prompts': PROMPTS,
    'rngs': RNGS,
    'eps_h': 0.0,
    'mode': 'closed',
    'is_pilot': False,
    'needs_verdict': False
    }
randomization_pvalue_params = {
    'start_msg': "\n=== Label randomization p-value (closed, pooled) ===",
    'X_closed': None,  # to be filled after closedpool probe
    'B': 2000,
    'np_rng': BASE_RNG.spawn(1)[0]  # child generator derived from BASE_RNG
}
layer_as_agent_params = {
    'start_msg': "\n=== MFG (layer-as-agent) actions and μ̂ CI (prompt 1) ===",
    'prompt': PROMPTS[0],
    'blocks_to_test': [0,2,4,6], # adjust to model depth
    'delta_g': 0.10,
    'boot_b': 1000, # bootstrap resamples for μ̂ CI
    'rng': RNGS[0],
    'bootstrap_rng': BASE_RNG.spawn(1)[0],
    'needs_verdict': True,
  }

# Probes run by execute_tests, in this order.
probe_params = [closedpool, openpool, placebo]

# NOTE(review): 'pilot_test' is stored here but execute_tests does not read it.
test_params = {
    'pilot_test': pilot,
    'probe_test': probe_params,
    'randomization_pvalue_test': [randomization_pvalue_params],
    'layer_as_agent_test': [layer_as_agent_params]
}

# models = ["distilgpt2","sshleifer/tiny-gpt2",'gpt2-medium','gpt2-large'] # add more model names here if desired
models = ['gpt2-medium','gpt2-large']

# Pre-load models to prevent printer interruption
# Load and immediately release every model once, so download output appears
# here rather than in the middle of the test report.
for model_name in models:
    try:
        loaded = get_model_data(model_name)
        if loaded is None:
            # get_model_data may signal a failed load by returning None.
            print(f"Pre-load failed for {model_name}")
            continue
        # get_model_data returns eight values; the previous six-name unpack
        # always raised ValueError, which the bare `except: pass` then hid.
        # The tuple is kept local so no module-level globals pin the weights.
        print(f"Loaded model: {model_name}")
        unload_model(list(loaded))
        del loaded
        print(f"Unloaded model: {model_name}")
    except Exception as exc:
        # Best effort: report the failure and continue with the next model.
        print(f"Pre-load failed for {model_name}: {exc!r}")

def sample_with_uniform_batch(P, u):
    """
    Batched version of sample_with_uniform.
    P: [B, V] probability distributions
    u: [B, M_SIB] uniform numbers in [0,1)
    Returns:
      idx: [B, M_SIB] sampled token indices (long), clamped to [0, V-1]

    torch.searchsorted already supports a batched sorted sequence whose
    leading dimensions match those of the queries, so each row of `u` is
    looked up in the matching row of the CDF directly. This avoids
    materialising a [B*M_SIB, V] copy of the CDF.
    """
    V = P.shape[-1]

    # cumulative sum along vocab
    cdf = torch.cumsum(P, dim=-1)  # [B, V]

    # row-wise inverse-CDF lookup
    idx = torch.searchsorted(cdf, u.contiguous(), right=False)  # [B, M_SIB]

    # the search returns V when a query exceeds the row's last CDF value
    return idx.clamp(0, V-1).long()

@torch.no_grad()
def build_past_for_arm_batch(model, tokens, inj_idx, v_dir, eps_h):
    """
    Batched counterpart of build_past_for_arm that adds a sibling axis.

    tokens:  [B, T] token ids
    inj_idx: [B] injection position per row
    v_dir:   [B, H] perturbation direction per row
    eps_h:   float nudge magnitude; 0.0 leaves the embeddings untouched

    Returns:
        past_key_values: tuple of per-layer (k, v), each with a singleton
            axis inserted at dim 1
        logits: [B, V] logits at the last position
    """
    n_rows, seq_len = tokens.shape
    dev = tokens.device

    positions = torch.arange(seq_len, device=dev).unsqueeze(0).expand(n_rows, seq_len)  # [B, T]
    transformer = model.transformer
    hidden = transformer.wte(tokens) + transformer.wpe(positions)  # [B, T, H]

    if eps_h != 0.0:
        # One (row, position) pair per batch row receives the nudge.
        rows = torch.arange(n_rows, device=dev)
        hidden[rows, inj_idx, :] += eps_h * v_dir

    result = model(inputs_embeds=hidden, use_cache=True)

    # Insert the explicit sibling axis (size 1) into every cached tensor.
    cache_with_sibling_axis = tuple(
        (keys.unsqueeze(1), values.unsqueeze(1))
        for keys, values in result.past_key_values
    )
    last_logits = result.logits[:, -1, :]
    return cache_with_sibling_axis, last_logits

@torch.no_grad()
def one_step_from_past_batch(model, past_batch, next_ids):
    """
    Advance every (row, sibling) context by one token.

    The batch and sibling axes are merged for the model call and split again
    afterwards; the cache is handed to the model as a plain tuple.

    Args:
      model: language model
      past_batch: tuple of (k, v) per layer, each [B, M_SIB, H, S, D]
      next_ids: [B, M_SIB] tensor of token IDs for the next step

    Returns:
      new_past: tuple of (k, v) per layer with the leading axis split back
        into [B, M_SIB, ...]
      logits: [B, M_SIB, V] tensor
    """
    n_rows, n_sib = next_ids.shape
    n_flat = n_rows * n_sib

    def _merge(t):
        # [B, M_SIB, ...] -> [B*M_SIB, ...]
        return t.reshape(n_flat, *t.shape[2:])

    def _split(t):
        # [B*M_SIB, ...] -> [B, M_SIB, ...]
        return t.reshape(n_rows, n_sib, *t.shape[1:])

    flat_past = tuple((_merge(keys), _merge(values)) for keys, values in past_batch)
    flat_ids = next_ids.reshape(n_flat, 1)

    result = model(input_ids=flat_ids, use_cache=True, past_key_values=flat_past)

    restored_past = tuple(
        (_split(keys), _split(values)) for keys, values in result.past_key_values
    )
    step_logits = result.logits[:, -1, :].reshape(n_rows, n_sib, -1)
    return restored_past, step_logits

def repeat_past_batch(past, M_SIB):
    """
    Repeat past for each sibling in a batched setup, keeping batch and sibling dimensions separate.

    Args:
      past: tuple of (k, v) tensors for each layer. Each tensor is either
        [B, 1, H, S, D] (as returned by build_past_for_arm_batch) or
        [B, H, S, D]; a 4-D tensor gets the sibling axis inserted first.
      M_SIB: number of siblings per batch row

    Returns:
      tuple of (k, v) for each layer, each with shape [B, M_SIB, H, S, D]
    """
    def _tile(t):
        # A 4-D input previously produced [1, B*M_SIB, H, S, D], because
        # repeat() with five factors prepends an axis instead of adding one
        # at dim 1.
        if t.dim() == 4:
            t = t.unsqueeze(1)
        return t.repeat(1, M_SIB, 1, 1, 1)

    return tuple((_tile(k), _tile(v)) for k, v in past)

def batch_get_v_dir(ids, inj_idx, torch_rngs):
    """
    Batched perturbation directions, one unit vector per batch row.

    ids: [B, T0]
    inj_idx: int (shared) or LongTensor [B] (per-example index)
    torch_rngs: list of torch.Generator, length B
    returns: [B, hidden_dim]

    With ADV_DIR false each row is a normalised Gaussian draw from that row's
    generator. Otherwise each row is the gradient, with respect to the row's
    input embedding at its injection index, of the KL divergence between the
    tempered distribution at that position and a one-hot EOS distribution.
    """
    B, T0 = ids.shape
    H = wte.weight.shape[1]

    # Case 1: Random vector only
    if not ADV_DIR:
        vs = []
        for rng in torch_rngs:
            v = torch.randn(H, device=DEVICE, generator=rng)
            v = v / (v.norm() + 1e-12)
            vs.append(v)
        return torch.stack(vs, dim=0)  # [B, H]

    # Case 2: Adversarial direction
    pos = torch.arange(T0, device=DEVICE).unsqueeze(0).expand(B, -1)
    rows = torch.arange(B, device=DEVICE)
    with torch.enable_grad():
        emb = (wte(ids) + wpe(pos)).detach().clone().requires_grad_(True)  # [B, T0, H]
        x = emb
        for blk in blocks:
            x = blk(x)[0]
        h = ln_f(x)
        logits = model.lm_head(h)  # [B, T0, V]
        P_all = softmax_T32(logits, T_SOFT)  # [B, T0, V]

        Vocab = P_all.shape[-1]
        # `is None` rather than `or`: an EOS id of 0 is a valid id.
        eos = tokz.eos_token_id
        if eos is None:
            eos = Vocab - 1

        if isinstance(inj_idx, int):
            inj_idx = torch.full((B,), inj_idx, dtype=torch.long, device=DEVICE)

        # Build q distribution: [B, V]
        q = torch.zeros(B, Vocab, device=DEVICE)
        q[:, eos] = 1.0

        # KL divergence per example
        gather_p = P_all[rows, inj_idx, :]  # [B, V]
        KL = torch.sum(
            gather_p * (torch.log(gather_p + 1e-16) - torch.log(q + 1e-16)),
            dim=-1
        )  # [B]

        # Differentiate the summed KL w.r.t. the embeddings only; calling
        # .backward() would also accumulate .grad on every model parameter.
        (grad_emb,) = torch.autograd.grad(KL.sum(), emb)
        v = grad_emb[rows, inj_idx, :].detach()  # [B, H]

    # Handle degenerate case: replace (near-)zero rows by a random draw from
    # that row's generator. The norm is taken without keepdim so nonzero()
    # yields row indices only; with a [B, 1] mask the flattened result also
    # contained the column index 0 and row 0 was overwritten by mistake.
    norms = v.norm(dim=1)  # [B]
    for i in torch.nonzero(norms < 1e-12, as_tuple=False).flatten().tolist():
        v[i] = torch.randn(H, device=DEVICE, generator=torch_rngs[i])

    v = v / (v.norm(dim=1, keepdim=True) + 1e-12)
    return v  # [B, H]

def run_prompt_Xseq_kv_batch(ids, batch_rngs, v_dir, inj, batch_prompts, steps=STEPS_MAX, eps_h=EPS_H, mode="closed"):
    """
    Batched version of run_prompt_Xseq_kv.

    Builds three arms per example (baseline, +eps_h, -eps_h along v_dir at
    position inj), rolls them forward for `steps` tokens and records, per
    step, the antisymmetric divergence increment

        X_t = 0.5 * ((D_next_p - D_curr_p) - (D_next_m - D_curr_m))

    where D_curr_* is js_div between a nudged arm and the baseline before the
    step and D_next_* is the same quantity one token later, averaged over
    M_SIB sibling continuations.

    Arguments:
      ids: token ids for the batch (the caller in this file passes [B, T0])
      batch_rngs: sequence of per-example entries; element 0 is used as the
          torch.Generator and element 2 as the seed recorded in the output
      v_dir: tensor [batch, hidden_dim]
      inj: injection index forwarded to build_past_for_arm_batch (the caller
          in this file passes a single int shared by the whole batch)
      batch_prompts: one prompt per example, recorded in the output rows
      steps: number of roll-out steps
      eps_h: nudge magnitude; the arms use 0.0, +eps_h and -eps_h
      mode: "closed" or "open"

    Returns:
      2-D numpy array built with np.column_stack, one row per
      (example, step), with columns (prompt, seed, step, X_t).
      NOTE(review): the columns mix prompts with numbers, so numpy will
      presumably coerce the array to one common dtype; the caller converts
      column 3 back with astype(float) — confirm precision is acceptable.
    """
    assert mode in {"closed", "open"}
    torch_rngs = [rng[0] for rng in batch_rngs]
    batch_seeds = [rng[2] for rng in batch_rngs]

    # === Build baseline and nudged pasts ===
    past0, L0 = build_past_for_arm_batch(model, ids, inj, v_dir, eps_h=0.0)
    pastP, Lp = build_past_for_arm_batch(model, ids, inj, v_dir, eps_h=+eps_h)
    pastM, Lm = build_past_for_arm_batch(model, ids, inj, v_dir, eps_h=-eps_h)
    # Initial next-token distributions of the three arms
    P0 = topk_project(softmax_T32(L0, T_SOFT), TOPK)
    Pp = topk_project(softmax_T32(Lp, T_SOFT), TOPK)
    Pm = topk_project(softmax_T32(Lm, T_SOFT), TOPK)

    # Initial divergences of each nudged arm from the baseline
    D_curr_p = js_div(Pp, P0)    # [batch]
    D_curr_m = js_div(Pm, P0)    # [batch]

    X_list = []

    for t in range(steps):
        # Uniform noise per example, per sibling: [batch, M_SIB]. Each example
        # draws from its own generator, and the same `u` is reused for every
        # arm below so the arms share their sampling randomness.
        u = torch.stack([torch.rand(M_SIB, device=DEVICE, generator=g) for g in torch_rngs], dim=0)

        if mode == "closed":
            # Every arm samples its own sibling tokens from its own distribution.
            tok0 = sample_with_uniform_batch(P0, u)  # [batch, M_SIB]
            tokp = sample_with_uniform_batch(Pp, u)
            tokm = sample_with_uniform_batch(Pm, u)

            past0_rep = repeat_past_batch(past0, M_SIB)
            pastP_rep = repeat_past_batch(pastP, M_SIB)
            pastM_rep = repeat_past_batch(pastM, M_SIB)

            # Only the logits are needed; the advanced sibling pasts are discarded.
            _, L02 = one_step_from_past_batch(model, past0_rep, tok0)
            _, Lp2 = one_step_from_past_batch(model, pastP_rep, tokp)
            _, Lm2 = one_step_from_past_batch(model, pastM_rep, tokm)

            Pb2 = topk_project(softmax_T32(L02, T_SOFT), TOPK)  # [B, M_SIB, V]
            Pp2 = topk_project(softmax_T32(Lp2, T_SOFT), TOPK)
            Pm2 = topk_project(softmax_T32(Lm2, T_SOFT), TOPK)

        else:  # "open"
            # Only the baseline arm is sampled and advanced for the siblings.
            tok0 = sample_with_uniform_batch(P0, u)       # [batch, M_SIB]
            past0_rep = repeat_past_batch(past0, M_SIB)
            past0_rep, L02 = one_step_from_past_batch(model, past0_rep, tok0)
            Pb2 = topk_project(softmax_T32(L02, T_SOFT), TOPK)
            # NOTE(review): both nudged sibling distributions alias the baseline
            # one, so D_next_p / D_next_m below compare Pb2 with itself and X_t
            # is driven by D_curr_* only. Confirm this is the intended
            # open-probe definition rather than a missing pastP/pastM step.
            Pp2 = Pb2
            Pm2 = Pb2

        # Sibling-averaged divergence after the step, then the antisymmetric
        # (+eps minus -eps) increment.
        D_next_p = js_div(Pp2, Pb2).mean(dim=1)
        D_next_m = js_div(Pm2, Pb2).mean(dim=1)
        X_t = 0.5*((D_next_p - D_curr_p) - (D_next_m - D_curr_m))
        X_list.append(X_t)

        # advance single-token contexts (first sibling) to refresh P0,Pp,Pm
        tok0_first = tok0[:, 0:1]
        past0, L0 = one_step_from_past_batch(model, past0, tok0_first)
        if mode == "closed":
            # Closed: each nudged arm follows its own first-sibling token.
            pastP, Lp = one_step_from_past_batch(model, pastP, tokp[:, 0:1])
            pastM, Lm = one_step_from_past_batch(model, pastM, tokm[:, 0:1])
        else:
            # Open: the nudged arms are fed the baseline's token.
            pastP, Lp = one_step_from_past_batch(model, pastP, tok0_first)
            pastM, Lm = one_step_from_past_batch(model, pastM, tok0_first)

        # squeeze(1) drops the length-1 token axis of the single-token step
        P0 = topk_project(softmax_T32(L0, T_SOFT), TOPK).squeeze(1)
        Pp = topk_project(softmax_T32(Lp, T_SOFT), TOPK).squeeze(1)
        Pm = topk_project(softmax_T32(Lm, T_SOFT), TOPK).squeeze(1)

        D_curr_p = js_div(Pp, P0)
        D_curr_m = js_div(Pm, P0)

    result = torch.stack(X_list, dim=1).cpu().numpy()  # [B, steps]
    # Rebinds `steps` to the stacked width (one column per loop iteration).
    B, steps = result.shape

    # Broadcast step numbers
    steps_arr = np.arange(steps, dtype=float)[None, :]  # [1, steps]
    steps_arr = np.repeat(steps_arr, B, axis=0)         # [B, steps]

    # Broadcast seeds
    seeds_arr = np.array(batch_seeds)[:, None]          # [B, 1]
    seeds_arr = np.repeat(seeds_arr, steps, axis=1)     # [B, steps]

    # Broadcast prompts
    prompts_arr = np.array(batch_prompts)[:, None]      # [B, 1]
    prompts_arr = np.repeat(prompts_arr, steps, axis=1) # [B, steps]

    # Flatten everything to [B*steps]
    flat_prompts = prompts_arr.reshape(-1)
    flat_seeds   = seeds_arr.reshape(-1)
    flat_steps   = steps_arr.reshape(-1)
    flat_result  = result.reshape(-1)

    # Build final 2D array [B*steps, 4]: (prompt, seed, step, X_t)
    final = np.column_stack([flat_prompts, flat_seeds, flat_steps, flat_result])

    return final

def batch_pooled_X(prompts, rngs, mode="closed", eps_h=EPS_H, max_batch=8):
    """
    Run the batched probe over every (prompt, rng) combination and pool the rows.

    prompts: list of strings
    rngs: list of rng entries; element 0 of each entry is the torch.Generator
        handed to batch_get_v_dir, and the whole entry is forwarded to
        run_prompt_Xseq_kv_batch
    mode, eps_h: forwarded to run_prompt_Xseq_kv_batch
    max_batch: maximum number of (prompt, rng) pairs processed per call
    returns: row-wise concatenation of the per-batch results, or an empty
        1-D array when there are no pairs
    """
    # Prompt-major grid: every prompt is paired with every rng entry.
    grid = []
    for prompt in prompts:
        for rng_entry in rngs:
            grid.append((prompt, rng_entry))

    pooled = []
    for lo in range(0, len(grid), max_batch):
        chunk = grid[lo:lo + max_batch]

        # Split the chunk into two aligned lists.
        batch_prompts = []
        batch_rngs = []
        for prompt, rng_entry in chunk:
            batch_prompts.append(prompt)
            batch_rngs.append(rng_entry)
        torch_rngs = [entry[0] for entry in batch_rngs]

        # Tokenize each prompt on its own, clip to max_prompt_length, then
        # stack the rows into one [B, T0] tensor.
        encoded = []
        for text in batch_prompts:
            enc = tokz(text, return_tensors="pt", padding="max_length", add_special_tokens=False)
            encoded.append(enc["input_ids"][:, :max_prompt_length])
        ids = torch.cat(encoded, dim=0).to(DEVICE)

        # Inject at the midpoint of the (padded/clipped) sequence.
        inj = ids.shape[1] // 2

        v_dir = batch_get_v_dir(ids, inj_idx=inj, torch_rngs=torch_rngs)
        pooled.append(
            run_prompt_Xseq_kv_batch(
                ids, batch_rngs, v_dir, inj, batch_prompts,
                steps=STEPS_MAX, eps_h=eps_h, mode=mode,
            )
        )

    if not pooled:
        return np.array([])
    return np.concatenate(pooled, axis=0)

def generate_test_rngs(rng,n_required):
    """
    Draw distinct seeds and build one rng entry per seed.

    rng: seed source exposing integers(low, high) (presumably a
        numpy.random.Generator — confirm against the caller)
    n_required: total requested count; it is divided by len(PROMPTS), and
        seeds are drawn until their number reaches that (possibly
        fractional) quotient
    returns: list of (torch.Generator seeded with s,
        np.random.default_rng(s), s) tuples, in draw order
    """
    n_seeds = n_required / len(PROMPTS)
    seeds = []
    seen = set()  # O(1) duplicate check; `seeds` keeps the draw order
    while len(seeds) < n_seeds:
        seed = rng.integers(0, 2**32-1)
        key = int(seed)
        if key not in seen:
            seen.add(key)
            seeds.append(seed)
    return [
        (torch.Generator(device=DEVICE).manual_seed(int(s)), np.random.default_rng(int(s)), s)
        for s in seeds
    ]

def probe_test_batch(start_msg, prompts, rngs, eps_h, mode, is_pilot, needs_verdict):
    """
    Run one pooled probe, summarise it and print the summary line.

    start_msg: banner printed first; also stored as the test label in
        column 0 of the returned table
    prompts, rngs, eps_h, mode: forwarded to batch_pooled_X
    is_pilot: when true, derive a power plan from the samples and report
        it instead of the test statistics
    needs_verdict: when true, append REJECT / NEUTRAL (p <= ALPHA) to the
        summary

    returns: (msg, X, R, plan, X_R) where
        msg  - the printed summary line
        X    - the batch_pooled_X table with the test label prepended
        R    - result of e_process_and_pval_numpy on X_R
        plan - result of power_plan_from_pilot, or None when not a pilot
        X_R  - column 3 of the pooled table as floats
    """
    print(start_msg)
    X = batch_pooled_X(prompts, rngs, mode=mode, eps_h=eps_h)
    X_R = X[:, 3].astype(float)
    # Prepend the test label as an object column so the table stays 2-D.
    X = np.concatenate([np.full((X.shape[0], 1), start_msg, dtype=object), X], axis=1)
    R = e_process_and_pval_numpy(X_R)
    if is_pilot:
        plan = power_plan_from_pilot(X_R, d_target=5e-3, alpha=ALPHA, power=0.8)
        # Report the alpha actually used for the plan rather than a fixed 0.05.
        msg = f"T_pilot={R['T']}  sd≈{plan['sd']:.3e}  N_required(d=5e-3, α={ALPHA}, 0.8 pow)≈{plan['N_required']}"
    else:
        plan = None
        msg = f"T={R['T']}  meanX={R['mean']:.3e}  sd={R['sd']:.3e}  Emax={R['Emax']:.3f}  Z={R['Z']:.3f}  p={R['p']:.3f}"
    if needs_verdict:
        verdict = ("REJECT" if R['p']<=ALPHA else "NEUTRAL")
        msg = f"{msg}  verdict={verdict}"
    print(msg)
    return msg, X, R, plan, X_R

def execute_tests_batch(models, test_params):
    """
    Run the full test suite for every model and collect rows and log lines.

    For each model name this loads the model into the module-level globals
    (tokz, model, blocks, ln_f, wte, wpe, max_prompt_length, max_batch),
    runs the pilot probe, sizes the seed set from the pilot's N_required,
    runs every probe, then the randomization p-value tests and the
    layer-as-agent tests, and finally unloads the model.

    models: iterable of model names accepted by get_model_data
    test_params: dict with keys 'pilot_test' (kwargs for probe_test_batch),
        'probe_test' (list of such kwargs dicts), 'randomization_pvalue_test'
        and 'layer_as_agent_test' (lists of kwargs dicts). The probe dicts
        are mutated: 'rngs' is overwritten, and 'X_closed' is set on the
        randomization dicts once the closed pooled probe has run.

    returns: (all_results, all_msgs) — the per-model record arrays with
        fields (model, test, prompt, seed, step, result) concatenated, and
        a numpy array of every logged message.
    """
    # Declared once up front; the loaded model is shared through module globals.
    global tokz, model, blocks, ln_f, wte, wpe, max_prompt_length, max_batch

    all_results = []
    all_msgs = []
    for modelname in models:
        modelmsg = f"\n=== Testing model: {modelname} ===\n"
        print(modelmsg)
        all_msgs.append(modelmsg)
        tokz, model, blocks, ln_f, wte, wpe, max_prompt_length, max_batch = get_model_data(modelname) # Ugly way (for now) to set global vars of model

        pilot = test_params['pilot_test']
        probes = test_params['probe_test']
        randomization_pvalue_params = test_params['randomization_pvalue_test']
        layer_as_agent_params = test_params['layer_as_agent_test']

        results = []
        all_msgs.append(pilot['start_msg'])
        pilot_result = probe_test_batch(**pilot)
        all_msgs.append(pilot_result[0])
        results.append(pilot_result[1])

        # Size the main run from the pilot's power plan.
        rngs = generate_test_rngs(BASE_RNG, pilot_result[3]['N_required'])
        test_msg = f'Number of test seeds: {len(rngs)}'
        print(test_msg)
        all_msgs.append(test_msg)

        # Reset per model so a missing closed probe can neither raise a
        # NameError nor reuse the previous model's summary.
        trajectory_msg = None
        for p in probes:
            p['rngs'] = rngs
            all_msgs.append(p['start_msg'])
            res = probe_test_batch(**p)
            all_msgs.append(res[0])
            results.append(res[1])
            if p is closedpool:
                trajectory_msg = res[0]
                # Distinct loop variable: the original reused `p` here and
                # shadowed the probe being iterated.
                for rand_params in randomization_pvalue_params:
                    rand_params['X_closed'] = res[4]

        probe_results = np.concatenate(results, axis=0)
        # np.rec is the public home of fromarrays; np.core is deprecated.
        probe_results = np.rec.fromarrays(
            [np.full(probe_results.shape[0], modelname),
             probe_results[:, 0],
             probe_results[:, 1],
             probe_results[:, 2],
             probe_results[:, 3],
             probe_results[:, 4]],
            names='model, test, prompt, seed, step, result'
        )
        pool_msg = "\n=== MFG (trajectory-as-agent) closed pooled ==="
        print(pool_msg)
        all_msgs.append(pool_msg)
        if trajectory_msg is not None:
            print(trajectory_msg)
            all_msgs.append(trajectory_msg)

        for rand_params in randomization_pvalue_params:
            all_msgs.append(rand_params['start_msg'])
            msg, chk = randomization_pvalue_test(**rand_params)
            all_msgs.append(msg)

        for layer_params in layer_as_agent_params:
            all_msgs.append(layer_params['start_msg'])
            msg, A, MFG, verdict_mfg = layer_as_agent_test(**layer_params)
            all_msgs.append(msg)

        all_results.append(probe_results)
        # Unload to prevent OOM
        unload_model([tokz, model, blocks, ln_f, wte, wpe, max_prompt_length, max_batch])
    all_results = np.concatenate(all_results, axis=0)
    all_msgs = np.array(all_msgs)
    return all_results, all_msgs

# Report the runtime environment, then launch the main test run.
for _env_label, _env_value in (
    ("Torch:", torch.__version__),
    ("CUDA available:", torch.cuda.is_available()),
    ("CUDA version:", torch.version.cuda),
    ("cuDNN version:", torch.backends.cudnn.version()),
    ("NumPy:", np.__version__),
):
    print(_env_label, _env_value)
print(f"Using device: {DEVICE}")
print(f"Using base seed: {BASE_SEED}")

test_result, msgs = execute_tests_batch(models, test_params)
print('finished')

import pandas as pd

# Persist the main run: one CSV for the per-step result rows, one for the log lines.
pd.DataFrame(
    test_result,
    columns=["model", "test", "prompt", "seed", "step", "result"],
).to_csv("results.csv", index=False)

df = pd.DataFrame(msgs, columns=["msg"])
df.to_csv("msgs.csv", index=False)

import pandas as pd
# ABLATION STUDY
# Two sweeps over the module-level knobs read by the test functions:
#   1. softmax temperature T_SOFT, with M_SIB left at its current value;
#   2. sibling count M_SIB, with T_SOFT fixed at 1.7.
# Each setting re-runs the full suite and writes its own pair of CSV files.
for t in [.5, 1, 2, 5]:
  # Assign before printing so the banner reports the temperature being run,
  # not the one left over from the previous run.
  T_SOFT = t
  print(f"ABLATION STUDY: {T_SOFT}, {M_SIB}")
  test_result, msgs = execute_tests_batch(models, test_params)
  df = pd.DataFrame(test_result, columns=["model", "test", "prompt", "seed", "step", "result"])
  df.to_csv(f"results-TEMP{str(T_SOFT)}-M_SIB-{str(M_SIB)}.csv", index=False)
  df = pd.DataFrame(msgs, columns=["msg"])
  df.to_csv(f"msgs-TEMP{str(T_SOFT)}-M_SIB-{str(M_SIB)}.csv", index=False)
# Fixed temperature for the sibling-count sweep.
T_SOFT = 1.7
for sib in [4, 8, 32, 64]:
  M_SIB = sib
  print(f"ABLATION STUDY: {T_SOFT}, {M_SIB}")
  test_result, msgs = execute_tests_batch(models, test_params)
  df = pd.DataFrame(test_result, columns=["model", "test", "prompt", "seed", "step", "result"])
  df.to_csv(f"results-TEMP{str(T_SOFT)}-M_SIB-{str(M_SIB)}.csv", index=False)
  df = pd.DataFrame(msgs, columns=["msg"])
  df.to_csv(f"msgs-TEMP{str(T_SOFT)}-M_SIB-{str(M_SIB)}.csv", index=False)
  print('finished')

