# ============================================================
# Idealized Self-Attention on Permutation Graph (Capacity vs D_K and h)
# Includes GPU determinism settings
# ============================================================

# --- Deterministic CUDA GEMM (must be set BEFORE importing torch) ---
# setdefault() leaves any value already exported in the environment untouched.
import os
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")  # use ":16:8" if memory is very tight

# --- Standard imports (torch comes AFTER the env var above) ---
import math, json, random, time, zipfile
from dataclasses import dataclass
from typing import List, Tuple, Dict, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import display

# --- Backend configuration (only touched when a CUDA device is present) ---
# Training needs autograd; make the global grad mode explicit up front.
torch.set_grad_enabled(True)

if torch.cuda.is_available():
    # Keep float32 math in full precision: no TF32 for cuDNN or matmul.
    torch.backends.cudnn.allow_tf32 = False
    torch.backends.cuda.matmul.allow_tf32 = False
    # Turn off the cuDNN autotuner and request deterministic cuDNN algorithms.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# =========================
# Config
# =========================
SEED = 42                     # base seed for the Python/NumPy/torch RNGs and for data generation
m = 256                       # nodes
d_model = 32                  # embedding dim

h = 8                         # (kept for single-run mode)

DK_TRIALS = [32, 64, 128, 256]     # single-run D_K grid (also offered to the sweep when INCLUDE_BASE_DKS is True)
alpha = 10.0                       # sharpness: training logits are alpha * (S_max - tau)
lr = 1e-3                          # AdamW lr
weight_decay = 0.0                 # AdamW weight decay
max_steps = 20_000                 # hard cap on training steps per configuration
check_every = 500                  # run validation every this many training steps
patience = 5                       # consecutive passing validation checks required to stop early
target_f1 = 0.995                  # micro-F1 threshold that counts as "passed"
target_pos_rate = 0.5              # ρ in (0,1)
VAL_CONTEXTS = 500                 # number of fixed validation contexts
TEST_CONTEXTS = 2000               # number of fixed test contexts

TRAIN_LENGTH_CHOICES = [16]        # context lengths sampled during training
TEST_LENGTH_CHOICES  = [16]        # context lengths used for both the validation and the test sets

SAVE_TO_DRIVE = True               # try to mount Google Drive and write under MyDrive; falls back to /content
BASE_NAME = "perm_graph_results"   # prefix of the single-run output directory name

# Optional: print reproducibility report at start
PRINT_REPRO_REPORT = True


# =========================
# Sweep config
# =========================
RUN_SWEEP = True                   # True -> run_sweep(); False -> run_experiment()
H_SWEEP = [1,2,4,8,16,32,64]       # head counts swept
M_SWEEP = [64,128,256,512]         # node counts swept
DMODEL_SWEEP = [16,32,64]          # embedding dims swept
REPEATS = 10                       # replicates (distinct seeds) per (m, d_model) block
DK_RATIOS = [0.5,0.75,1,1.5, 2.0]  # D_K candidates as multiples of m*log2(m)/d_model
MAX_DK = 2048                      # candidates above this are dropped
INCLUDE_BASE_DKS = False           # also try the DK_TRIALS values that are divisible by h
SWEEP_TAG = time.strftime("%Y%m%d_%H%M%S")   # timestamp taken once, when this module is executed
SWEEP_ROOT = f"perm_graph_sweep_{SWEEP_TAG}"  # directory name of the sweep root


# =========================
# Utilities
# =========================
def print_repro_report():
    """Print the determinism-related runtime settings currently in effect.

    Reports the device, whether deterministic algorithms are enforced, the
    cuBLAS workspace setting, and (on CUDA only) the TF32 / cuDNN switches.
    """
    on_gpu = torch.cuda.is_available()

    print("=== Reproducibility Report ===")
    print(f"Device: {'cuda' if on_gpu else 'cpu'}")

    try:
        det_state = torch.are_deterministic_algorithms_enabled()
    except AttributeError:
        # Older PyTorch builds do not expose the query function.
        det_state = "<unknown for this torch version>"
    print(f"Deterministic algos: {det_state}")

    print(f"CUBLAS_WORKSPACE_CONFIG: {os.getenv('CUBLAS_WORKSPACE_CONFIG')}")

    if on_gpu:
        cuda_flags = (
            ("TF32 matmul allowed", torch.backends.cuda.matmul.allow_tf32),
            ("TF32 cuDNN allowed", torch.backends.cudnn.allow_tf32),
            ("cuDNN deterministic", torch.backends.cudnn.deterministic),
            ("cuDNN benchmark", torch.backends.cudnn.benchmark),
        )
        for label, value in cuda_flags:
            print(f"{label}: {value}")

    print("==============================\n")

def set_seed(seed: int):
    """Seed every RNG in use and switch torch to deterministic execution.

    Seeds Python's `random`, NumPy's legacy global RNG and torch (CPU, plus all
    CUDA devices when available), pins the cuDNN flags, and asks torch to raise
    on non-deterministic ops. Only a warning is printed if that last request
    cannot be honoured.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Autotuner off, deterministic cuDNN kernels on.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Strict mode first; older releases do not know the warn_only keyword.
    try:
        torch.use_deterministic_algorithms(True, warn_only=False)
    except TypeError:
        torch.use_deterministic_algorithms(True)
    except Exception:
        print("Warning: torch.use_deterministic_algorithms could not be enabled.")

def device_str():
    """Return "cuda" when a CUDA device is available, otherwise "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"

def ensure_dir(path: str):
    """Create directory `path` (including missing parents); no-op if it already exists."""
    os.makedirs(name=path, exist_ok=True)

def rng_from(master: int, salt: int) -> np.random.Generator:
    """Build an independent NumPy generator keyed by a (master, salt) pair.

    The pair feeds a SeedSequence; four words of its generated state seed the
    returned generator, so equal pairs always reproduce the same stream.
    """
    seed_seq = np.random.SeedSequence(entropy=[master, salt])
    state_words = seed_seq.generate_state(4)
    return np.random.default_rng(state_words)

def get_outdir(save_to_drive: bool, rho: float) -> str:
    """Return a timestamped output directory path (the directory is not created).

    Args:
        save_to_drive: When True, try to mount Google Drive and place the
            directory under MyDrive; any failure falls back to "/content".
        rho: Positive rate; embedded in the name as a whole-number percentage.

    Returns:
        "<base>/<BASE_NAME>_rho<percent>_<YYYYmmdd_HHMMSS>".
    """
    ts = time.strftime("%Y%m%d_%H%M%S")
    base = "/content"
    if save_to_drive:
        try:
            from google.colab import drive  # type: ignore
            drive.mount("/content/drive", force_remount=False)
            base = "/content/drive/MyDrive"
            print("Saving to Google Drive under MyDrive.")
        except Exception as e:
            # Best effort: outside Colab (or on a mount error) keep the local base.
            print("Drive mount failed; falling back to /content. Reason:", repr(e))
            base = "/content"
    # Round instead of truncating: 100 * 0.29 == 28.999999999999996 in binary
    # floating point, which int() alone would label as "rho28".
    rho_pct = int(round(100 * rho))
    name = f"{BASE_NAME}_rho{rho_pct}_{ts}"
    return os.path.join(base, name)

# =========================
# Data generation
# =========================
@dataclass
class DataBundle:
    """Graph, embeddings and fixed evaluation contexts for one replicate."""
    # Node embeddings, one row per node (created with shape (m, d_model) and
    # unit-normalised rows in build_graph_and_embeddings).
    X: torch.Tensor
    # Permutation over node ids; perm[i] is the node that i points to.
    perm: torch.Tensor
    # Pre-sampled contexts (1-D tensors of node ids) used for validation checks.
    contexts_val: List[torch.Tensor]
    # Pre-sampled contexts (1-D tensors of node ids) used for the final test score.
    contexts_test: List[torch.Tensor]

def build_graph_and_embeddings(seed: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Create the node embeddings and the random permutation graph.

    Reads the module-level `m` (node count) and `d_model` (embedding dim).

    Returns:
        (X, perm): X holds one Gaussian embedding per node, scaled by
        1/sqrt(d_model) and then divided by its row norm (+1e-8); perm is a
        shuffled arange(m) as a long tensor.
    """
    # Permutation: shuffled in place by a NumPy generator salted with 100.
    order = np.arange(m)
    rng_from(seed, 100).shuffle(order)
    perm = torch.tensor(order, dtype=torch.long)

    # Embeddings: drawn from a dedicated torch generator seeded with `seed`.
    gen = torch.Generator()
    gen.manual_seed(seed)
    raw = torch.randn((m, d_model), generator=gen) / math.sqrt(d_model)
    row_norms = raw.norm(dim=1, keepdim=True) + 1e-8
    return raw / row_norms, perm

def sample_context_with_pos_rate(rng: np.random.Generator,
                                 perm: torch.Tensor,
                                 ell: int,
                                 rho: float) -> torch.Tensor:
    """Sample a context of `ell` distinct nodes, then plant permutation edges inside it.

    Args:
        rng: NumPy generator; every random draw below comes from it, in order.
        perm: Permutation tensor; perm[i] is the successor of node i.
        ell: Number of nodes in the context (drawn from range(m), module-level `m`).
        rho: Success probability of the binomial draw that decides how many
            context nodes get their successor forced into the context.

    Returns:
        1-D long tensor with the node ids of the final context, in the
        iteration order of the underlying Python set.
    """
    # Start from ell distinct node ids.
    S = rng.choice(m, size=ell, replace=False).astype(np.int64)
    S_set = set(int(x) for x in S)
    # Number of "source" nodes whose successor should also be present.
    b = int(rng.binomial(ell, rho))
    if b == 0:
        return torch.tensor(list(S_set), dtype=torch.long)
    # Pick the source nodes among the current members (without replacement).
    U = rng.choice(list(S_set), size=min(b, len(S_set)), replace=False).astype(np.int64)
    # Host copy of the permutation so successors can be looked up as Python ints.
    perm_np = perm.detach().cpu().numpy()
    for i in U:
        target = int(perm_np[i])
        if target not in S_set:
            # Swap a random other member out for the successor; the size stays ell.
            # NOTE(review): the evicted node may itself be a source or a successor
            # planted earlier in this loop, so the realised number of edges can
            # end up below b.
            pool = list(S_set - {int(i)})
            if not pool:
                continue
            j = int(rng.choice(pool))
            S_set.remove(j)
            S_set.add(target)
    return torch.tensor(list(S_set), dtype=torch.long)

def sample_context(rng: np.random.Generator,
                   perm: torch.Tensor,
                   rho: float,
                   length_choices: List[int]) -> torch.Tensor:
    """Draw a context length from `length_choices`, then sample a context of that length."""
    chosen_len = rng.choice(length_choices)
    return sample_context_with_pos_rate(rng, perm, int(chosen_len), rho)

def make_context_list(n: int,
                      seed: int,
                      salt: int,
                      perm: torch.Tensor,
                      rho: float,
                      length_choices: List[int]) -> List[torch.Tensor]:
    """Sample `n` contexts from one generator keyed by (seed, salt).

    All contexts share a single generator, so the list is reproducible for a
    given (seed, salt) and depends on the order in which contexts are drawn.
    """
    rng = rng_from(seed, salt)
    contexts: List[torch.Tensor] = []
    for _ in range(n):
        contexts.append(sample_context(rng, perm, rho, length_choices))
    return contexts

def make_data(seed: int, rho: float) -> DataBundle:
    """Build embeddings, permutation and the fixed validation/test contexts for `seed`.

    Validation contexts use salt 200 and test contexts salt 300; both draw
    their lengths from TEST_LENGTH_CHOICES.
    """
    X, perm = build_graph_and_embeddings(seed)
    val_ctxs = make_context_list(
        n=VAL_CONTEXTS, seed=seed, salt=200,
        perm=perm, rho=rho, length_choices=TEST_LENGTH_CHOICES,
    )
    test_ctxs = make_context_list(
        n=TEST_CONTEXTS, seed=seed, salt=300,
        perm=perm, rho=rho, length_choices=TEST_LENGTH_CHOICES,
    )
    return DataBundle(X=X, perm=perm, contexts_val=val_ctxs, contexts_test=test_ctxs)

def build_labels_for_context(ctx: torch.Tensor, perm: torch.Tensor, device) -> torch.Tensor:
    """Return the float32 label matrix y with y[i, j] = 1 iff perm[ctx[i]] == ctx[j].

    The result is square (len(ctx) x len(ctx)) and is placed on `device`.
    """
    successors = perm[ctx]                                   # successor of every context node
    hits = successors.unsqueeze(1) == ctx.unsqueeze(0)       # row i marks where successor i sits
    return hits.to(device=device, dtype=torch.float32)

# =========================
# Model (idealized SA)
# =========================
class IdealizedSelfAttention(nn.Module):
    """Multi-head query/key scorer with a learnable scalar threshold `tau`.

    Holds only the query and key projections (no values, no softmax). The
    pairwise score of two context nodes is the maximum over heads of the
    per-head query-key dot product.
    """

    def __init__(self, d_model: int, D_K: int, h: int, alpha: float = 10.0):
        super().__init__()
        assert D_K % h == 0, "D_K must be divisible by h"
        self.d_model, self.D_K, self.h = d_model, D_K, h
        self.d_k = D_K // h          # per-head key/query width
        self.alpha = alpha           # stored for callers; not used inside this class
        init_scale = 1.0 / math.sqrt(d_model)
        proj_shape = (d_model, D_K)
        # W_Q is drawn before W_K so the global torch RNG is consumed in a fixed order.
        self.W_Q = nn.Parameter(torch.randn(proj_shape) * init_scale)
        self.W_K = nn.Parameter(torch.randn(proj_shape) * init_scale)
        self.tau = nn.Parameter(torch.tensor(0.0))

    @torch.no_grad()
    def save_checkpoint(self, path: str, meta: Dict):
        """Save CPU copies of W_Q and W_K, tau as a float, and `meta` via torch.save."""
        weights = {name: getattr(self, name).detach().cpu() for name in ("W_Q", "W_K")}
        payload = {**weights, "tau": float(self.tau.detach().cpu()), "meta": meta}
        torch.save(payload, path)

    def scores_Smax(self, X_ctx: torch.Tensor) -> torch.Tensor:
        """Return the [ell, ell] matrix of head-wise maximum query-key scores.

        `X_ctx` carries one embedding row per context node (ell rows).
        """
        n_tok = X_ctx.shape[0]

        def split_heads(proj: torch.Tensor) -> torch.Tensor:
            # [ell, D_K] -> [h, ell, d_k]
            return proj.view(n_tok, self.h, self.d_k).transpose(0, 1)

        q_heads = split_heads(X_ctx @ self.W_Q)
        k_heads = split_heads(X_ctx @ self.W_K)
        per_head = q_heads @ k_heads.transpose(1, 2)      # [h, ell, ell]
        best, _ = per_head.max(dim=0)                     # [ell, ell]
        return best

# =========================
# Loss & metrics
# =========================
def weighted_bce_logits(z: torch.Tensor, y: torch.Tensor, pos_weight_scalar: float) -> torch.Tensor:
    """Mean binary cross-entropy on logits `z`, with positives up-weighted.

    Entries with y == 1 contribute softplus(-z) * pos_weight_scalar, entries
    with y == 0 contribute softplus(z); the mean runs over all entries.
    """
    negatives_mask = 1.0 - y
    cost_on_pos = F.softplus(z.neg()) * y * pos_weight_scalar
    cost_on_neg = F.softplus(z) * negatives_mask
    return torch.mean(cost_on_pos + cost_on_neg)

@torch.no_grad()
def evaluate_model(model: IdealizedSelfAttention,
                   X: torch.Tensor,
                   perm: torch.Tensor,
                   contexts: List[torch.Tensor],
                   device,
                   compute_hist: bool = False) -> Tuple[float, float, Optional[np.ndarray], Optional[np.ndarray]]:
    """Score `model` on a list of contexts.

    A pair (i, j) is predicted positive when its head-max score exceeds the
    model's threshold tau. TP/FP/FN are pooled over all contexts (micro-F1).

    Args:
        model: Scorer providing `scores_Smax` and the scalar parameter `tau`.
        X: Node embeddings, indexed by node id.
        perm: Permutation tensor used to build the labels.
        contexts: Contexts (1-D tensors of node ids) to evaluate on.
        device: Device the contexts and labels are moved to.
        compute_hist: When True, also return the raw score samples.

    Returns:
        (micro_f1, margin, pos_scores, hard_neg_scores). `margin` is the mean
        positive score minus the mean of the per-row top-5 negative scores
        (NaN when either group is empty). The last two entries are NumPy
        arrays when `compute_hist` is True and None otherwise.
    """
    # Remember the caller's mode: this function runs in the middle of training
    # and must not leave the model switched to eval mode afterwards.
    was_training = model.training
    model.eval()
    try:
        tau = float(model.tau.detach().cpu().item())
        TP = FP = FN = 0
        pos_scores_all = []
        hard_neg_scores_all = []

        for ctx in contexts:
            ctx = ctx.to(device)
            X_ctx = X[ctx]
            y = build_labels_for_context(ctx, perm, device)
            S_max = model.scores_Smax(X_ctx)
            preds = (S_max > tau).float()

            TP += (preds * y).sum().item()
            FP += (preds * (1.0 - y)).sum().item()
            FN += ((1.0 - preds) * y).sum().item()

            pos_scores = S_max[y == 1.0]
            if pos_scores.numel() > 0:
                pos_scores_all.append(pos_scores.detach().cpu())

            # Hard negatives: per row, the (up to) five highest-scoring non-edges.
            neg_matrix = S_max.clone()
            neg_matrix[y == 1.0] = -float('inf')
            for r in range(neg_matrix.shape[0]):
                row = neg_matrix[r]
                valid = torch.isfinite(row)
                if valid.any():
                    k = min(5, int(valid.sum().item()))
                    topk_vals = torch.topk(row[valid], k=k).values
                    hard_neg_scores_all.append(topk_vals.detach().cpu())

        denom = (2 * TP + FP + FN)
        micro_f1 = (2 * TP / denom) if denom > 0 else 0.0

        # Concatenate each score list once and reuse it for margin and histograms.
        pos_concat = torch.cat(pos_scores_all) if pos_scores_all else None
        hard_neg_concat = torch.cat(hard_neg_scores_all) if hard_neg_scores_all else None

        if pos_concat is None or hard_neg_concat is None:
            margin = float('nan')
        else:
            margin = float(pos_concat.mean().item() - hard_neg_concat.mean().item())

        if not compute_hist:
            return micro_f1, margin, None, None

        pos_np = pos_concat.numpy() if pos_concat is not None else np.array([])
        neg_np = hard_neg_concat.numpy() if hard_neg_concat is not None else np.array([])
        return micro_f1, margin, pos_np, neg_np
    finally:
        model.train(was_training)

# =========================
# Optimizer (determinism-friendly)
# =========================
def make_adamw(params, lr: float, weight_decay: float):
    """Create an AdamW optimizer, opting out of the fused/foreach paths when possible.

    The most explicit constructor call is tried first; keyword arguments that
    the installed torch does not accept (TypeError) are dropped step by step.
    """
    base_kwargs = {"lr": lr, "weight_decay": weight_decay}
    # Ordered from most to least explicit.
    preferred_extras = (
        {"fused": False, "foreach": False},
        {"foreach": False},
    )
    for extras in preferred_extras:
        try:
            return torch.optim.AdamW(params, **base_kwargs, **extras)
        except TypeError:
            continue
    return torch.optim.AdamW(params, **base_kwargs)

# =========================
# Training
# =========================
@dataclass
class TrainResult:
    """Outcome of training one (D_K, h) configuration."""
    DK: int                         # total query/key width D_K
    h: int                          # number of heads
    d_k: int                        # per-head width, DK // h
    passed: bool                    # True when the test micro-F1 reached target_f1
    test_micro_f1: float            # micro-F1 on the test contexts
    margin: float                   # mean positive score minus mean top-5 hard-negative score (test)
    tau: float                      # learned decision threshold after training
    steps_trained: int              # filled with the same value as stop_step by train_one_configuration
    first_pass_step: Optional[int]  # first validation check that met target_f1; None if never
    stop_step: int                  # step of the early stop, or max_steps when patience was never reached
    ckpt_path: str                  # train_one_configuration always stores "NA" (no checkpoint is written)

# === pass h_val explicitly and use deterministic-friendly AdamW ===
def train_one_configuration(DK: int,
                            h_val: int,
                            bundle: DataBundle,
                            device,
                            outdir: str,
                            rho: float) -> TrainResult:
    """Train one (D_K, h) model on freshly sampled contexts and score it on the test set.

    Each step samples one context, builds its label matrix and minimises a
    weighted BCE on the logits alpha * (S_max - tau). Every `check_every`
    steps the validation micro-F1 is computed; training stops early after
    `patience` consecutive checks at or above `target_f1`.

    Args:
        DK: Total query/key width.
        h_val: Number of heads (DK must be divisible by it).
        bundle: Embeddings, permutation and fixed validation/test contexts,
            already placed on `device` where needed.
        device: Device used for the model and the per-step tensors.
        outdir: Directory that receives the score histogram PNG.
        rho: Positive rate used when sampling training contexts.

    Returns:
        TrainResult with the test metrics and the stopping information.
    """
    d_k = DK // h_val
    model = IdealizedSelfAttention(d_model=d_model, D_K=DK, h=h_val, alpha=alpha).to(device)
    opt = make_adamw(model.parameters(), lr=lr, weight_decay=weight_decay)

    # Training contexts come from their own stream, salted by the configuration.
    train_rng = rng_from(SEED, 500 + DK + 13*h_val)

    consecutive_passes = 0
    first_pass_step = None
    stop_step = max_steps

    for step in tqdm(range(1, max_steps + 1), desc=f"Training h={h_val}, D_K={DK}", leave=False):
        ctx = sample_context(train_rng, bundle.perm, rho, TRAIN_LENGTH_CHOICES).to(device)
        X_ctx = bundle.X[ctx]
        y = build_labels_for_context(ctx, bundle.perm, device)
        ell = X_ctx.shape[0]

        S_max = model.scores_Smax(X_ctx)
        z = alpha * (S_max - model.tau)

        # Positives are up-weighted by (ell - 1) in the loss.
        pos_w = float(ell - 1)
        loss = weighted_bce_logits(z, y, pos_w)

        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()

        if step % check_every == 0:
            val_f1, _, _, _ = evaluate_model(model, bundle.X, bundle.perm,
                                             bundle.contexts_val, device, compute_hist=False)
            passed_now = (val_f1 >= target_f1)
            consecutive_passes = consecutive_passes + 1 if passed_now else 0
            if passed_now and first_pass_step is None:
                first_pass_step = step
            if consecutive_passes >= patience:
                stop_step = step
                break

    test_f1, margin, pos_hist, neg_hist = evaluate_model(
        model, bundle.X, bundle.perm, bundle.contexts_test, device, compute_hist=True
    )
    passed = (test_f1 >= target_f1)

    tag = f"DK{DK}_h{h_val}_dk{d_k}"

    # Plotting is best effort: a failure is reported but never aborts the run.
    try:
        if pos_hist is not None and neg_hist is not None and pos_hist.size > 0 and neg_hist.size > 0:
            fig = plt.figure(figsize=(6,4))
            try:
                plt.hist(pos_hist, bins=60, alpha=0.6, label="S+ (positives)", density=True)
                plt.hist(neg_hist, bins=60, alpha=0.6, label="Top-5 S- (hard negatives)", density=True)
                plt.xlabel("Score")
                plt.ylabel("Density")
                plt.title(f"Score separation — {tag} (test)")
                plt.legend()
                plt.tight_layout()
                hist_path = os.path.join(outdir, f"score_hist_{tag}.png")
                ensure_dir(outdir)
                plt.savefig(hist_path, dpi=150)
            finally:
                # Always release the figure; otherwise a failing save would
                # leave one open figure behind per configuration in a sweep.
                plt.close(fig)
    except Exception as e:
        print("Histogram plotting skipped due to:", repr(e))

    tau_val = float(model.tau.detach().cpu().item())
    return TrainResult(
        DK=DK, h=h_val, d_k=d_k, passed=passed, test_micro_f1=float(test_f1),
        margin=float(margin), tau=tau_val, steps_trained=stop_step,
        first_pass_step=first_pass_step, stop_step=stop_step, ckpt_path="NA"
    )

# =========================
# Original single-experiment 
# =========================
def run_experiment():
    """
    Single experiment using current globals: SEED, m, d_model, DK_TRIALS, h, etc.

    Trains one model per D_K in DK_TRIALS (with the fixed head count `h`),
    writes results.csv / results.json into a fresh output directory, displays
    the table and reports the first D_K in DK_TRIALS order whose test micro-F1
    reached target_f1. When SAVE_TO_DRIVE is False the output directory is also
    zipped to /content and a Colab download is attempted.

    Returns:
        (df, outdir, smallest_passing): results DataFrame, output directory
        path, and the first passing D_K (None if no configuration passed).
    """
    set_seed(SEED)
    device = device_str()
    print(f"Device: {device}")
    if PRINT_REPRO_REPORT:
        print_repro_report()

    outdir = get_outdir(SAVE_TO_DRIVE, target_pos_rate)
    ensure_dir(outdir)
    print("Output directory:", outdir)

    # Embeddings and permutation live on the device; contexts are moved per use.
    bundle = make_data(SEED, target_pos_rate)
    bundle = DataBundle(X=bundle.X.to(device),
                        perm=bundle.perm.to(device),
                        contexts_val=bundle.contexts_val,
                        contexts_test=bundle.contexts_test)

    results: List[TrainResult] = []
    smallest_passing: Optional[int] = None

    for DK in DK_TRIALS:
        print("="*80)
        print(f"Training configuration: D_K={DK} (h={h} ⇒ d_k={DK//h}), ρ={target_pos_rate}")
        tr = train_one_configuration(DK, h, bundle, device, outdir, target_pos_rate)
        results.append(tr)
        print(f"→ Test micro-F1: {tr.test_micro_f1:.6f} | margin: {tr.margin:.6f} | τ: {tr.tau:.6f}")
        print(f"   Steps trained: {tr.steps_trained} | first_pass_step: {tr.first_pass_step} | checkpoint: {tr.ckpt_path}")
        if tr.passed and smallest_passing is None:
            smallest_passing = DK

    rows = []
    for r in results:
        rows.append({
            "D_K": r.DK,
            "h": r.h,
            "d_k": r.d_k,
            "rho": target_pos_rate,
            "passed": r.passed,
            "microF1_test": r.test_micro_f1,
            "margin": r.margin,
            "tau": r.tau,
            "steps_trained": r.steps_trained,
            "first_pass_step": r.first_pass_step if r.first_pass_step is not None else "NA",
            "stop_step": r.stop_step,
            "checkpoint": r.ckpt_path
        })
    df = pd.DataFrame(rows).sort_values(["h","D_K"]).reset_index(drop=True)
    csv_path = os.path.join(outdir, "results.csv")
    json_path = os.path.join(outdir, "results.json")
    df.to_csv(csv_path, index=False)
    with open(json_path, "w") as f:
        json.dump(rows, f, indent=2)

    print("\nSaved results:")
    print(" CSV:", csv_path)
    print(" JSON:", json_path)

    print("\nResults table:")
    display(df)

    if smallest_passing is None:
        # Report the configured threshold rather than a hard-coded number, so
        # the message stays correct when target_f1 is changed.
        print(f"\nOutcome: No configuration reached micro‑F1 ≥ {target_f1} on the test set.")
    else:
        print(f"\nEmpirical hat D_K*: smallest passing D_K = {smallest_passing}")

    if not SAVE_TO_DRIVE:
        zip_path = os.path.join("/content", "perm_graph_results.zip")
        print("\nCreating zip for download at:", zip_path)
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
            for root, _, files in os.walk(outdir):
                for fn in files:
                    abspath = os.path.join(root, fn)
                    # Archive names are relative to the parent of outdir, so the
                    # zip contains the run directory itself.
                    relpath = os.path.relpath(abspath, os.path.dirname(outdir))
                    zf.write(abspath, arcname=relpath)
        try:
            from google.colab import files  # type: ignore
            files.download(zip_path)
        except Exception as e:
            print("Automatic download not available; you can manually fetch:", zip_path, "| Reason:", repr(e))

    return df, outdir, smallest_passing

# =========================
# Sweep helpers & driver 
# =========================
def _dk_candidates_for(m_val: int,
                       d_val: int,
                       h_val: int,
                       ratios: List[float],
                       base_dks: List[int],
                       max_dk: Optional[int]) -> List[int]:
    thr_l2 = (m_val * max(math.log2(max(m_val, 2)), 1e-9)) / max(d_val, 1e-9)
    cand = set()
    for r in ratios:
        dk = int(math.ceil(r * thr_l2))
        dk = max(dk, 1)
        dk = int(math.ceil(dk / h_val) * h_val)   # enforce divisibility by h
        cand.add(dk)
    if INCLUDE_BASE_DKS and base_dks:
        for dk in base_dks:
            if dk % h_val == 0:
                cand.add(dk)
    filtered = sorted(x for x in cand if x >= h_val and (max_dk is None or x <= max_dk))
    if not filtered:
        filtered = [h_val]
    return filtered

def _mount_and_get_sweep_root():
    """Mount Drive (if enabled) via get_outdir and create the sweep root directory.

    get_outdir is called only for its side effect (the Drive mount) and to
    learn which base directory is in use; the path it returns is not created.

    Returns:
        (sweep_root_dir, drive_base): the created sweep root and the base
        directory it lives in ("/content/drive/MyDrive" or "/content").
    """
    # bootstrap mount & yield sweep root dir
    _orig = BASE_NAME
    globals()["BASE_NAME"] = f"{SWEEP_ROOT}/_bootstrap_only"
    try:
        tmp = get_outdir(SAVE_TO_DRIVE, target_pos_rate)
    finally:
        # Restore the module-level name even if the bootstrap call raises, so a
        # failure here cannot leave BASE_NAME pointing at the bootstrap value.
        globals()["BASE_NAME"] = _orig
    drive_base = "/content/drive/MyDrive" if tmp.startswith("/content/drive/MyDrive") else "/content"
    sweep_root_dir = os.path.join(drive_base, SWEEP_ROOT)
    ensure_dir(sweep_root_dir)
    print(f"\nSweep root: {sweep_root_dir}")
    return sweep_root_dir, drive_base

def _ratio_log2(m_val: int, d_val: int, DK_val: int):
    thr_l2 = (m_val * math.log2(max(m_val, 2))) / max(d_val, 1e-9)
    return DK_val / max(thr_l2, 1e-9), thr_l2

def _plot_block_by_h(agg_block: pd.DataFrame, m_val: int, d_val: int, out_dir: str):
    """Save the per-block figure: mean test micro-F1 against the capacity ratio.

    One error-bar line is drawn per head count found in `agg_block`; the
    x-axis is DK / (m log2 m / d) and a dashed vertical line marks ratio 1.
    The PNG is written into `out_dir`, which is created if missing.
    """
    ensure_dir(out_dir)
    plt.figure(figsize=(7.2, 4.6))

    for heads, rows in agg_block.groupby("h"):
        ordered = rows.sort_values("ratio_log2")
        plt.errorbar(
            ordered["ratio_log2"].to_numpy(),
            ordered["microF1_mean"].to_numpy(),
            # Groups with a single run have NaN std; draw them without a bar.
            yerr=ordered["microF1_std"].fillna(0.0).to_numpy(),
            fmt="-o",
            capsize=3,
            label=f"h={heads}",
            alpha=0.95,
        )

    plt.axvline(1.0, linestyle="--", linewidth=1)
    plt.xlabel("DK / (m log2 m / d_model)")
    plt.ylabel("Mean micro‑F1 (test)")
    plt.title(f"F1 vs Capacity Ratio — m={m_val}, d_model={d_val}")
    plt.legend(title="Heads")
    plt.tight_layout()

    target_png = os.path.join(out_dir, f"f1_vs_ratio_m{m_val}_d{d_val}_byH.png")
    plt.savefig(target_png, dpi=150)
    plt.close()
    print("Saved block plot:", target_png)

def _run_sweep_impl():
    """Body of the sweep; rebinds the module globals m, d_model and SEED per replicate."""
    global m, d_model, SEED

    set_seed(SEED)
    device = device_str()
    print(f"Device: {device}")
    if PRINT_REPRO_REPORT:
        print_repro_report()

    sweep_root_dir, drive_base = _mount_and_get_sweep_root()

    all_rows: List[pd.DataFrame] = []
    block_min_rows = []  # for min-DK heatmap (over h)

    DK_TRIALS_ORIG = list(DK_TRIALS)
    SEED_BASE = SEED

    for m_val in M_SWEEP:
        for d_val in DMODEL_SWEEP:
            print(f"\n=== Block: m={m_val}, d_model={d_val} ===")

            # Per-block directories
            block_dir = os.path.join(sweep_root_dir, f"m{m_val}_d{d_val}")
            ensure_dir(block_dir)

            block_rows = []  # accumulate all runs for this (m,d)

            # Do all repeats (so we can compute means/std err bars)
            for rep in range(1, REPEATS + 1):
                # The data builders and the trainer read these module globals,
                # so they are rebound for every replicate.
                m = m_val
                d_model = d_val
                SEED = int(SEED_BASE + rep*997 + m_val*13 + d_val*23)

                # Build data once per replicate so all heads/DKs share it (fair comparison)
                bundle = make_data(SEED, target_pos_rate)
                bundle = DataBundle(X=bundle.X.to(device),
                                    perm=bundle.perm.to(device),
                                    contexts_val=bundle.contexts_val,
                                    contexts_test=bundle.contexts_test)

                for h_val in H_SWEEP:
                    dk_list = _dk_candidates_for(
                        m_val=m_val, d_val=d_val, h_val=h_val,
                        ratios=DK_RATIOS, base_dks=DK_TRIALS_ORIG, max_dk=MAX_DK
                    )
                    print(f"  • run{rep:02d} | h={h_val} | DK candidates={dk_list}")

                    run_outdir = os.path.join(block_dir, f"run{rep:02d}", f"h{h_val}")
                    ensure_dir(run_outdir)

                    for DK in dk_list:
                        tr = train_one_configuration(DK, h_val, bundle, device, run_outdir, target_pos_rate)
                        block_rows.append({
                            "m": m_val,
                            "d_model": d_val,
                            "seed": SEED,
                            "run": rep,
                            "h": h_val,
                            "D_K": tr.DK,
                            "d_k": tr.d_k,
                            "rho": target_pos_rate,
                            "passed": tr.passed,
                            "microF1_test": tr.test_micro_f1,
                            "margin": tr.margin,
                            "tau": tr.tau,
                            "steps_trained": tr.steps_trained,
                            "first_pass_step": tr.first_pass_step if tr.first_pass_step is not None else "NA",
                            "stop_step": tr.stop_step,
                            "outdir": run_outdir
                        })

            # Aggregate per-block and save plots immediately
            df_block = pd.DataFrame(block_rows)
            if len(df_block) == 0:
                print("No data collected for this block.")
                continue

            # normalize first_pass_step to numeric ("NA" and unparsable values become NaN)
            def _to_num(x):
                if x == "NA": return np.nan
                try:
                    return float(x)
                except Exception:
                    return np.nan
            df_block["first_pass_step_num"] = df_block["first_pass_step"].apply(_to_num)

            # Add ratio_log2 and thr for plotting
            thr_l2 = (m_val * math.log2(max(m_val, 2))) / max(d_val, 1e-9)
            df_block["thr_log2"] = thr_l2
            df_block["ratio_log2"] = df_block["D_K"] / max(thr_l2, 1e-9)

            # Per (h, D_K) aggregates
            group_cols = ["m", "d_model", "h", "D_K"]
            agg_block = df_block.groupby(group_cols).agg(
                runs=("passed", "count"),
                pass_rate=("passed", "mean"),
                microF1_mean=("microF1_test", "mean"),
                microF1_std=("microF1_test", "std"),
                margin_mean=("margin", "mean"),
                margin_std=("margin", "std"),
                steps_mean=("steps_trained", "mean"),
                steps_median=("steps_trained", "median"),
                first_pass_step_median=("first_pass_step_num", "median"),
                thr_log2=("thr_log2", "first"),
            ).reset_index()
            agg_block["ratio_log2"] = agg_block["D_K"] / agg_block["thr_log2"].replace(0, np.nan)

            # Save per-block CSVs
            block_all_csv = os.path.join(block_dir, "block_all_runs.csv")
            block_agg_csv = os.path.join(block_dir, "block_grouped_by_h_DK.csv")
            df_block.to_csv(block_all_csv, index=False)
            agg_block.to_csv(block_agg_csv, index=False)
            print("Saved block CSVs:", block_all_csv, "and", block_agg_csv)

            # Plot multi-line (one line per h) and save now
            _plot_block_by_h(agg_block, m_val, d_val, block_dir)

            # Compute minimum DK over all heads (exists some h with pass_rate == 1.0 / 0.8 / 0.5)
            mins = {"m": m_val, "d_model": d_val, "thr_log2": thr_l2}
            for cut, label in [(1.0, "100"), (0.80, "80"), (0.50, "50")]:
                cand = agg_block.loc[agg_block["pass_rate"] >= cut, ["D_K", "h"]].sort_values("D_K")
                mins[f"DK_star_anyh_pass{label}"] = (cand["D_K"].iloc[0] if len(cand) else np.nan)
                mins[f"h_for_DK_star_pass{label}"] = (cand["h"].iloc[0] if len(cand) else np.nan)
            block_min_rows.append(mins)

            # Keep for global concat
            all_rows.append(df_block)

    # === Global summaries ===
    if len(all_rows) == 0:
        print("No rows collected; sweep did not run.")
        return

    df_all = pd.concat(all_rows, ignore_index=True)

    # Global grouped with h included
    group_cols_global = ["m", "d_model", "h", "D_K"]
    agg_global = df_all.groupby(group_cols_global).agg(
        runs=("passed", "count"),
        pass_rate=("passed", "mean"),
        microF1_mean=("microF1_test", "mean"),
        microF1_std=("microF1_test", "std"),
        margin_mean=("margin", "mean"),
        margin_std=("margin", "std"),
        steps_mean=("steps_trained", "mean"),
        steps_median=("steps_trained", "median"),
        thr_log2=("thr_log2", "first")
    ).reset_index()
    agg_global["ratio_log2"] = agg_global["D_K"] / agg_global["thr_log2"].replace(0, np.nan)

    # Minimum DK per (m,d) over all heads, expressed also as a ratio to the threshold
    df_min_anyh = pd.DataFrame(block_min_rows)
    df_min_anyh["DK_star_anyh_ratio_log2_pass100"] = df_min_anyh["DK_star_anyh_pass100"] / df_min_anyh["thr_log2"]
    df_min_anyh["DK_star_anyh_ratio_log2_pass80"]  = df_min_anyh["DK_star_anyh_pass80"]  / df_min_anyh["thr_log2"]
    df_min_anyh["DK_star_anyh_ratio_log2_pass50"]  = df_min_anyh["DK_star_anyh_pass50"]  / df_min_anyh["thr_log2"]

    # Save global artifacts at sweep root
    all_csv = os.path.join(sweep_root_dir, "sweep_all_runs.csv")
    agg_csv = os.path.join(sweep_root_dir, "sweep_grouped_by_m_d_h_DK.csv")
    min_anyh_csv = os.path.join(sweep_root_dir, "sweep_min_DK_over_heads_per_m_d.csv")
    df_all.to_csv(all_csv, index=False)
    agg_global.to_csv(agg_csv, index=False)
    df_min_anyh.to_csv(min_anyh_csv, index=False)
    print("\nSweep saved:")
    print(" • All runs:", all_csv)
    print(" • Grouped per (m, d_model, h, D_K):", agg_csv)
    print(" • Min DK over heads per (m, d_model):", min_anyh_csv)

    # Optional combined overlay (all (m,d,h) curves) — can get busy
    try:
        plt.figure(figsize=(8.0,5.2))
        for (mm, dd, hh), sub in agg_global.groupby(["m", "d_model", "h"]):
            sub_sorted = sub.sort_values("ratio_log2")
            x = sub_sorted["ratio_log2"].to_numpy()
            y = sub_sorted["microF1_mean"].to_numpy()
            yerr = sub_sorted["microF1_std"].fillna(0.0).to_numpy()
            plt.errorbar(x, y, yerr=yerr, fmt="-o", capsize=3, alpha=0.80, label=f"m={mm}, d={dd}, h={hh}")
        plt.axvline(1.0, linestyle="--", linewidth=1)
        plt.xlabel("DK / (m log2 m / d_model)")
        plt.ylabel("Mean micro‑F1 (test)")
        plt.title("F1 vs Capacity Ratio — ALL (m, d_model, h)")
        plt.legend(ncol=2, fontsize=8)
        plt.tight_layout()
        fig_path_all = os.path.join(sweep_root_dir, "f1_vs_ratio_ALL_byH.png")
        plt.savefig(fig_path_all, dpi=150)
        plt.close()
        print(" • Combined plot:", fig_path_all)
    except Exception as e:
        print("Combined plotting skipped due to:", repr(e))

    # === Plot heatmap of min DK over heads per (m, d_model) ===
    try:
        # Only the pass100 criterion is drawn; the 80% / 50% variants are in the CSV
        pivot = df_min_anyh.pivot(index="m", columns="d_model", values="DK_star_anyh_pass100")
        plt.figure(figsize=(1.2 + 0.8*pivot.shape[1], 1.2 + 0.6*pivot.shape[0]))
        im = plt.imshow(pivot, aspect="auto", interpolation="nearest")
        plt.colorbar(im, label="Min DK (exists some h, pass rate = 100%)")
        plt.xticks(ticks=np.arange(pivot.shape[1]), labels=list(pivot.columns))
        plt.yticks(ticks=np.arange(pivot.shape[0]), labels=list(pivot.index))
        plt.xlabel("d_model")
        plt.ylabel("m")
        plt.title("Minimum DK over heads per (m, d_model)")
        # annotate cells ("—" marks blocks where no configuration passed)
        for i in range(pivot.shape[0]):
            for j in range(pivot.shape[1]):
                val = pivot.iloc[i, j]
                txt = "—" if (pd.isna(val)) else str(int(val))
                plt.text(j, i, txt, ha="center", va="center", fontsize=9)
        plt.tight_layout()
        heatmap_path = os.path.join(sweep_root_dir, "min_DK_over_heads_heatmap.png")
        plt.savefig(heatmap_path, dpi=150)
        plt.close()
        print(" • Min DK heatmap:", heatmap_path)
    except Exception as e:
        print("Heatmap plotting skipped due to:", repr(e))


def run_sweep():
    """
    Joint sweep over M_SWEEP × DMODEL_SWEEP × H_SWEEP with REPEATS.
    For each (m, d_model), run all heads & DKs, then immediately save
    the per-block CSVs and plot under the sweep root. Finally, save global
    summaries and a heatmap of min DK over h for each (m, d_model).

    The sweep temporarily rebinds the module globals `m`, `d_model` and `SEED`
    (the data builders and the trainer read them). They are restored on exit,
    including when the sweep fails part-way, so a later run_experiment() in the
    same session sees the configured values again.
    """
    global m, d_model, SEED
    saved_globals = (m, d_model, SEED)
    try:
        return _run_sweep_impl()
    finally:
        m, d_model, SEED = saved_globals

# Execute
if __name__ == "__main__":
    # RUN_SWEEP selects the full sweep; otherwise the single experiment runs.
    (run_sweep if RUN_SWEEP else run_experiment)()







#@title Aggregate permutation-graph sweep CSVs into a "min DK over heads" heatmap + line plot (log y with powers of 2)
#@markdown **Instructions:** Set `TARGET_DIR` to the folder in Google Drive that contains your *_grouped_by_m_d_h_DK*.csv summary files.
#@markdown <br/>By default we treat "minimum DK" as the smallest DK for which `microF1_mean >= THRESHOLD` (default `0.99`). If you intended the opposite, change `COMPARE_OP` to `"<"`.
#@markdown <br/>**New:** Line plot uses a log y-axis with **major ticks at powers of 2 only**, labeled as \(2^k\). No fractional ticks.
#@markdown <br/>**New:** Error-bar toggle on per-(m,d_model) plots — set `ERRORBAR_MODE` to `"std"` (±1 SD) or `"ci95"` (95% CI for the mean).
#@markdown <br/>**This version enforces a single selection policy everywhere:**
#@markdown <br/>&nbsp;&nbsp;&nbsp;&nbsp;• For each (m, d_model), pick the **min total D_K** that meets the F1 criterion;
#@markdown <br/>&nbsp;&nbsp;&nbsp;&nbsp;• Among rows at that min D_K, pick the **head with max F1** (ties → **more heads**);
#@markdown <br/>&nbsp;&nbsp;&nbsp;&nbsp;• Exact duplicate configs keep the **best row** (max F1, then higher pass_rate, then higher runs).
#@markdown <br/><br/> Provide exclusion lists at the top:
#@markdown <br/>&nbsp;&nbsp;&nbsp;&nbsp;• `EXCLUDE_M`: list of `m` values to exclude (e.g., `[16384, 32768]`);
#@markdown <br/>&nbsp;&nbsp;&nbsp;&nbsp;• `EXCLUDE_PAIRS`: list of `(m, d_model)` pairs to exclude (e.g., `[(16384, 2048), (32768, 4096)]`).
#@markdown <br/>All excluded values are removed **before** any computation, tables, or plots.

# === Parameters ===
# Folder expected to contain the *_grouped_by_m_d_h_DK*.csv summary files (see the form text above).
TARGET_DIR = "/content/drive/MyDrive/Self Attention/Ten trial sweep/Outputs"  #@param {type:"string"}
#TARGET_DIR = "/content/drive/MyDrive/Self Attention/Outputs2"  #@param {type:"string"}
# NOTE(review): presumably toggles a recursive search below TARGET_DIR; the consuming code is not in this chunk, so confirm.
SEARCH_SUBDIRS = False                                      #@param {type:"boolean"}
# F1 criterion as described in the form text: a row qualifies when microF1_mean COMPARE_OP THRESHOLD.
THRESHOLD = 0.99                                            #@param {type:"number"}
COMPARE_OP = ">="                                           #@param [">=", "<"] {allow-input: true}
#RANDOM_SEED = None                                         #@param {type:"raw"}  # (unused: dedup is deterministic by best F1)

# --- Exclusions ---
# Per the form text, excluded values are removed before any computation, tables, or plots.
#EXCLUDE_M = [1024,2048,4096]            #@param {type:"raw"}  # e.g., [128, 256]
EXCLUDE_M = []            #@param {type:"raw"}  # e.g., [128, 256]
#EXCLUDE_PAIRS = [(128,16),(256,16),(512,16)]        #@param {type:"raw"}  # e.g., [(128, 16), (256, 32)]
EXCLUDE_PAIRS = []        #@param {type:"raw"}  # e.g., [(128, 16), (256, 32)]

# Error-bar mode (affects only the per-(m,d_model) F1 vs DK figures)
ERRORBAR_MODE = "ci95"  #@param ["std", "ci95"] {allow-input: true}
# "std"  -> draw ±1 sample standard deviation across runs
# "ci95" -> draw two-sided 95% CI half-width for the mean
#          uses Student's t (df = runs-1) when 'runs' is available; otherwise falls back to z≈1.96 if SE is present
# --- Runs filter ---
EXCLUDE_SINGLE_RUN = False  #@param {type:"boolean"}  # If True, drop rows with runs == 1 before any computation



# Diagonal plots configuration (m/d_model ratios to consider)
DIAGONAL_RATIOS = [8]  # e.g., 8 means m == 8 * d_model

# Selection tolerance for floating F1 ties (treat values within this as equal)
F1_TOL = 1e-12

# === Imports ===
import os, glob, math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patheffects as pe
import matplotlib.ticker as mticker

# === Mount Drive (safe if already mounted) ===
# Best-effort mount: outside Colab the import (or the mount) fails, the reason
# is printed, and the script carries on with whatever paths are reachable.
try:
    from google.colab import drive  # type: ignore
    drive.mount("/content/drive", force_remount=False)
except Exception as e:
    print("Drive mount skipped or failed (running locally perhaps):", repr(e))

# === Validate directory ===
# Resolve TARGET_DIR: accept it as given, otherwise retry it as a path relative
# to the Drive root; raise FileNotFoundError when neither location exists.
if not os.path.isdir(TARGET_DIR):
    # try relative to MyDrive
    alt = os.path.join("/content/drive/MyDrive", TARGET_DIR.strip("/"))
    if os.path.isdir(alt):
        TARGET_DIR = alt
    else:
        raise FileNotFoundError(f"Directory not found: {TARGET_DIR}")

# NOTE(review): TARGET_DIR is already known to be an existing directory at this
# point, so this call is effectively a no-op; it does not test writability.
os.makedirs(TARGET_DIR, exist_ok=True)
print(f"\nReading summary CSVs from: {TARGET_DIR}")

# === Find candidate files ===
# Collect candidate summary CSVs. The "**/" prefix only recurses because
# recursive=True is passed together with it when SEARCH_SUBDIRS is enabled.
pattern = "**/*grouped_by_*.csv" if SEARCH_SUBDIRS else "*grouped_by_*.csv"
paths = glob.glob(os.path.join(TARGET_DIR, pattern), recursive=SEARCH_SUBDIRS)
paths = sorted(set(paths))  # de-duplicate and fix the processing order

if not paths:
    # Report the pattern that was actually globbed and where, so the message
    # matches what the code searched for.
    raise FileNotFoundError(
        f"No files matched {pattern!r} in {TARGET_DIR!r} "
        f"(SEARCH_SUBDIRS={SEARCH_SUBDIRS})."
    )

REQUIRED_COLUMNS = {"m", "d_model", "h", "D_K", "microF1_mean"}  # minimum set

# === Quick header validation (order-independent) ===
# Both accumulators must exist before the loop; without them the first append
# raises NameError (and the except branch would raise it again).
good_files = []   # paths whose header contains every required column
bad_files = []    # (path, reason) pairs for files that are skipped

for p in paths:
    try:
        # nrows=0 parses only the header line, so this stays cheap for big files.
        header = pd.read_csv(p, nrows=0)
    except Exception as e:
        bad_files.append((p, f"read_csv failed: {repr(e)}"))
        continue

    cols = set(header.columns)
    if REQUIRED_COLUMNS.issubset(cols):
        good_files.append(p)
    else:
        bad_files.append((
            p,
            f"missing required columns; found={sorted(header.columns)}, need≥{sorted(REQUIRED_COLUMNS)}"
        ))

if not good_files:
    raise RuntimeError("No valid CSVs found. All candidate files failed the format check.")

print(f"Found {len(good_files)} valid file(s), {len(bad_files)} skipped due to format.\n")
if bad_files:
    print("Skipped files (reason shown):")
    for b, why in bad_files:
        print(" -", os.path.basename(b), "->", why)
    print()

# === Load and normalize ===
# Read every validated file, coerce the known columns to numeric, and keep only
# rows that carry the full configuration key plus an F1 value.
dfs = []
for p in good_files:
    try:
        df = pd.read_csv(p)
        # enforce numeric types (unparseable cells become NaN)
        for c in ["m", "d_model", "h", "D_K", "runs", "pass_rate", "microF1_mean"]:
            if c in df.columns:
                df[c] = pd.to_numeric(df[c], errors="coerce")
        # keep only rows with required info
        df = df.dropna(subset=["m", "d_model", "h", "D_K", "microF1_mean"])
        # Skip files that contribute nothing so the check below really means
        # "no usable rows" rather than "no readable files".
        if not df.empty:
            dfs.append(df)
    except Exception as e:
        print("Read failed for", p, "->", repr(e))

if not dfs:
    raise RuntimeError("After reading, no usable rows were found.")

df_all = pd.concat(dfs, ignore_index=True)

# Cast to ints where appropriate. The key columns are NaN-free thanks to the
# dropna above; 'runs' is NOT part of that subset (and may be missing from some
# files entirely), so it can still hold NaN and must then stay floating point.
for c in ["m", "d_model", "h", "D_K", "runs"]:
    if c not in df_all.columns:
        continue
    n_missing = int(df_all[c].isna().sum())
    if n_missing:
        print(f"Column '{c}' has {n_missing} missing value(s); keeping it as float.")
        continue
    df_all[c] = df_all[c].astype(int)

# === Exclusion helpers (NEW) ===
def _normalize_exclude_m(val):
    """Return a set of ints for EXCLUDE_M.

    Accepted inputs:
      * None                      -> empty set
      * a single integer          -> one-element set
      * a string                  -> comma-separated integers ("128" or "128, 256")
      * any iterable              -> each item converted with int(); for list/tuple
                                     items only the first element is used
      * any other scalar          -> converted with int() if possible

    Entries that cannot be converted are reported and skipped individually, so
    one bad value never discards the valid ones that follow it.
    """
    if val is None:
        return set()
    if isinstance(val, (int, np.integer)):
        return {int(val)}

    if isinstance(val, str):
        # A str is iterable character by character ("128" -> 1, 2, 8), so it
        # has to be tokenised explicitly.
        items = [tok for tok in (part.strip() for part in val.split(",")) if tok]
    else:
        try:
            items = list(val)
        except TypeError:
            items = [val]  # non-iterable scalar such as 128.0

    result = set()
    for item in items:
        if item is None:
            continue
        if isinstance(item, (list, tuple)):
            if not item:
                continue
            item = item[0]  # e.g. an (m, d_model) pair: only m matters here
        try:
            result.add(int(item))
        except (TypeError, ValueError, OverflowError):
            print(f"EXCLUDE_M: ignoring entry that is not an integer: {item!r}")
    return result

def _normalize_exclude_pairs(val):
    """Return a set of (int, int) pairs for EXCLUDE_PAIRS.

    `val` is expected to be an iterable of pairs. Each item may be a
    list/tuple with at least two elements (the first two are used) or a
    string of the form "m, d_model". A bare string is treated as one item.
    None, items with a None component, and items of any other type are
    skipped. Items whose components are not integers are reported and skipped
    individually, so one malformed entry no longer discards the rest.
    """
    result = set()
    if val is None:
        return result
    if isinstance(val, str):
        val = [val]  # otherwise the string would be walked character by character

    try:
        items = list(val)
    except TypeError:
        return result  # not iterable: nothing to exclude

    for item in items:
        if item is None:
            continue
        if isinstance(item, str):
            parts = item.split(",")
        elif isinstance(item, (tuple, list)):
            parts = item
        else:
            continue
        if len(parts) < 2:
            continue
        first, second = parts[0], parts[1]
        if first is None or second is None:
            continue
        try:
            # int() tolerates surrounding whitespace in string components.
            result.add((int(first), int(second)))
        except (TypeError, ValueError, OverflowError):
            print(f"EXCLUDE_PAIRS: ignoring entry that is not an (int, int) pair: {item!r}")
    return result

def _apply_exclusions(df, excl_m_set, excl_pair_set):
    """Remove rows whose m is excluded or whose (m, d_model) pair is excluded.

    Returns a tuple (filtered copy of df, number of rows removed). When df is
    empty or both exclusion sets are empty, a plain copy and 0 are returned.
    """
    nothing_to_do = df.empty or not (excl_m_set or excl_pair_set)
    if nothing_to_do:
        return df.copy(), 0

    # Start from "drop nothing" and OR in each active criterion.
    drop = pd.Series(False, index=df.index)
    if excl_m_set:
        drop = df["m"].isin(excl_m_set) | drop
    if excl_pair_set:
        row_pairs = pd.Series(
            list(zip(df["m"].values, df["d_model"].values)),
            index=df.index,
        )
        drop = drop | row_pairs.isin(excl_pair_set)

    kept = df.loc[~drop].copy()
    return kept, int(drop.sum())

def _drop_single_run_rows(df: pd.DataFrame, enabled: bool):
    """
    Optionally discard rows that come from a single run.

    With `enabled` False the frame is copied unchanged. With `enabled` True,
    rows where runs == 1 are dropped; if there is no 'runs' column a note is
    printed and the frame is copied unchanged.
    Returns (filtered_df, removed_count).
    """
    has_runs = "runs" in df.columns
    if enabled and not has_runs:
        print("Single-run exclusion enabled but 'runs' column not found; skipping.")
    if not (enabled and has_runs):
        return df.copy(), 0

    single = df["runs"].eq(1)
    survivors = df.loc[~single].copy()
    return survivors, int(single.sum())


# Normalize + apply exclusions **before** any selection/plots (NEW)
# The raw form values are converted to sets so membership tests are cheap and
# the input format (list, tuple, ...) no longer matters downstream.
EXCLUDE_M_SET = _normalize_exclude_m(EXCLUDE_M)
EXCLUDE_PAIR_SET = _normalize_exclude_pairs(EXCLUDE_PAIRS)

if EXCLUDE_M_SET or EXCLUDE_PAIR_SET:
    print(f"Exclusion lists active — EXCLUDE_M: {sorted(EXCLUDE_M_SET)} ; EXCLUDE_PAIRS: {sorted(EXCLUDE_PAIR_SET)}")
else:
    print("Exclusion lists empty — no rows will be dropped.")

# df_all is rebound to the filtered copy; the removed-row count is not used further.
df_all, _removed_rows = _apply_exclusions(df_all, EXCLUDE_M_SET, EXCLUDE_PAIR_SET)
print(f"Rows remaining after exclusions: {len(df_all)}")

if df_all.empty:
    raise RuntimeError("All rows were excluded by EXCLUDE_M / EXCLUDE_PAIRS. Nothing to compute or plot.")

# --- Apply the single-run filter (NEW) ---
# No-op (returns a copy, 0 removed) unless EXCLUDE_SINGLE_RUN is True.
df_all, removed_single = _drop_single_run_rows(df_all, EXCLUDE_SINGLE_RUN)
if EXCLUDE_SINGLE_RUN:
    print(f"Single-run rows filter active — removed {removed_single} row(s) with runs == 1.")
else:
    print("Single-run rows filter inactive.")
print(f"Rows remaining after single-run filter: {len(df_all)}")

# === Build per-head F1 stats for the plotting script ===
# We aggregate over all rows with the same (m, d_model, h, D_K).
# If microF1_std and microF1_n (or runs) exist, we compute a pooled std; otherwise we fall back
# to sample std across rows of microF1_mean (NaN if only one row).

def _build_per_head(df: pd.DataFrame) -> pd.DataFrame:
    """Aggregate rows sharing (m, d_model, h, D_K) into one row of F1 statistics.

    For every configuration the result holds:
      F1_mean -- mean of microF1_mean weighted by the per-row sample count
      F1_std  -- standard deviation combining the within-row variances
                 (microF1_std) with the spread between row means; if no
                 microF1_std is available, the sample std of the row means
                 (NaN when there is a single row)
      F1_n    -- total sample count
    The per-row sample count is taken from 'microF1_n' if present, otherwise
    from 'runs', otherwise 1; missing or non-positive counts are treated as 1.

    Groups are iterated explicitly instead of going through
    ``groupby(...).apply``: how ``apply`` treats the grouping columns is
    deprecated in recent pandas, while plain iteration behaves the same in
    every version. The input frame is not modified.

    Returns a DataFrame with columns
    ['m', 'd_model', 'h', 'D_K', 'F1_mean', 'F1_std', 'F1_n'] sorted by key.
    """
    keys = ["m", "d_model", "h", "D_K"]
    stat_cols = ["F1_mean", "F1_std", "F1_n"]

    # Identify count column to use if present
    count_col = None
    if "microF1_n" in df.columns:
        count_col = "microF1_n"
    elif "runs" in df.columns:
        count_col = "runs"

    def _combine(group: pd.DataFrame) -> dict:
        mu = pd.to_numeric(group["microF1_mean"], errors="coerce").astype(float)
        # choose n_i for each row (>=1)
        if count_col is not None:
            n_i = pd.to_numeric(group[count_col], errors="coerce").fillna(1).astype(int).clip(lower=1)
        else:
            n_i = pd.Series(1, index=group.index)

        N = int(n_i.sum()) if len(n_i) else 0
        mu_bar = float((mu * n_i).sum() / N) if N > 0 else float("nan")

        # combined variance if we have per-row stds; else fallback to sample std across means
        if "microF1_std" in group.columns and group["microF1_std"].notna().any():
            s = pd.to_numeric(group["microF1_std"], errors="coerce").fillna(0.0).astype(float)
            # within-row sum of squares + between-row sum of squares, over N - 1
            within = ((n_i - 1).clip(lower=0) * (s ** 2)).sum()
            between = (n_i * (mu - mu_bar) ** 2).sum()
            denom = max(N - 1, 1)
            var = float((within + between) / denom)
            std = float(np.sqrt(max(var, 0.0)))  # guard against tiny negative round-off
        else:
            std = float(mu.std(ddof=1)) if len(mu) > 1 else float("nan")

        return {
            "F1_mean": mu_bar,
            "F1_std": std,
            "F1_n": int(N if N > 0 else len(group)),
        }

    records = []
    for key, group in df.groupby(keys, sort=True):
        record = dict(zip(keys, key))
        record.update(_combine(group))
        records.append(record)

    # Passing `columns` keeps the schema intact even when there are no groups.
    per_head = pd.DataFrame(records, columns=keys + stat_cols)

    # enforce dtypes expected by the plotting script
    per_head["m"] = per_head["m"].astype(int)
    per_head["d_model"] = per_head["d_model"].astype(int)
    per_head["h"] = per_head["h"].astype(int)
    per_head["D_K"] = pd.to_numeric(per_head["D_K"], errors="coerce").astype(float)
    per_head["F1_mean"] = pd.to_numeric(per_head["F1_mean"], errors="coerce").astype(float)
    per_head["F1_std"] = pd.to_numeric(per_head["F1_std"], errors="coerce").astype(float)
    per_head["F1_n"] = pd.to_numeric(per_head["F1_n"], errors="coerce").astype(int)

    return per_head

# Guard BEFORE aggregating: the single-run filter may have removed every row,
# and there is no point building or saving statistics for an empty frame.
if df_all.empty:
    raise RuntimeError("All rows were excluded by EXCLUDE_M / EXCLUDE_PAIRS / single-run filter. Nothing to compute or plot.")

# Per-(m, d_model, h, D_K) F1 statistics used by the head-count plot further down.
per_head = _build_per_head(df_all)
# Optional: save for inspection
per_head_csv = os.path.join(TARGET_DIR, "per_head_stats.csv")
per_head.to_csv(per_head_csv, index=False)
print(f"Built per_head stats with {len(per_head)} rows → {per_head_csv}")



# === Deduplicate exact configs by BEST row: (m, d_model, h, D_K) ===
# Keep the single row with:
#   1) highest microF1_mean
#   2) then highest pass_rate (if present)
#   3) then largest runs (if present)
def _pick_best(group: pd.DataFrame) -> pd.DataFrame:
    """Return the single best row of `group` as a one-row DataFrame.

    Ranking, all descending: microF1_mean, then pass_rate (if the column
    exists), then runs (if the column exists). Rows that tie on every
    criterion keep their original relative order, so the earliest one wins.
    """
    sort_cols = ["microF1_mean"] + [c for c in ("pass_rate", "runs") if c in group.columns]
    # A single-column sort defaults to quicksort, which is not stable; request
    # mergesort explicitly so the tie-break is deterministic in every case.
    ranked = group.sort_values(
        sort_cols,
        ascending=[False] * len(sort_cols),
        kind="mergesort",
    )
    return ranked.iloc[[0]]

keys = ["m", "d_model", "h", "D_K"]
# Reduce every exact configuration to its best row. The groups are iterated
# explicitly (sorted by key) instead of using groupby(...).apply: apply's
# handling of the grouping columns is deprecated in recent pandas, and the
# key columns must survive in the result for everything that follows.
unique_df = pd.concat(
    [_pick_best(group) for _, group in df_all.groupby(keys, sort=True)],
    ignore_index=True,
)

print(f"Loaded rows (post-exclusion): {len(df_all)} ; Unique configs retained (best per exact config): {len(unique_df)}")

# === Save aggregated (best one per (m, d_model, h, D_K)) ===
aggregated_csv = os.path.join(TARGET_DIR, "aggregated_grouped_by_m_d_h_DK.csv")
(
    unique_df
    .sort_values(["m", "d_model", "h", "D_K"])
    .to_csv(aggregated_csv, index=False)
)
# (No stdout on purpose)

# === Keep the configurations that satisfy the F1 criterion chosen by COMPARE_OP ===
_cmp_op = COMPARE_OP.strip()
if _cmp_op == ">=":
    _meets = unique_df["microF1_mean"] >= THRESHOLD
    criterion_desc = f"microF1_mean ≥ {THRESHOLD}"
elif _cmp_op == "<":
    _meets = unique_df["microF1_mean"] < THRESHOLD
    criterion_desc = f"microF1_mean < {THRESHOLD}"
else:
    raise ValueError("COMPARE_OP must be \">=\" or \"<\"")

# Independent copy so later column additions never touch unique_df.
ok = unique_df[_meets].copy()

if ok.empty:
    raise RuntimeError(f"No configurations satisfy the criterion ({criterion_desc}) after applying exclusions. Nothing to plot.")

# === For each (m, d_model), choose:
#     (1) the smallest DK meeting the criterion,
#     (2) among those rows at that DK, the head(s) with maximum F1,
#     (3) if still tied, prefer MORE heads (larger h).
#     This selection is used everywhere downstream.
# transform("min") broadcasts each pair's minimum back onto the rows of `ok`,
# so the result can be compared row by row against ok["D_K"].
minDK_for_pair = ok.groupby(["m", "d_model"])["D_K"].transform("min")

# Rows at the minimum D_K for each (m, d_model)
at_minDK = ok.loc[ok["D_K"] == minDK_for_pair].copy()

# Among those, filter to those achieving the max F1 (within tolerance)
maxF1_for_pair = at_minDK.groupby(["m", "d_model"])["microF1_mean"].transform("max")
cands = at_minDK.loc[at_minDK["microF1_mean"] >= (maxF1_for_pair - F1_TOL)].copy()

# Final tie-break on h: prefer MORE heads
# (h is sorted descending, so keep="first" retains the largest h per pair)
min_rows = (
    cands.sort_values(["m", "d_model", "h"], ascending=[True, True, False])
         .drop_duplicates(["m", "d_model"], keep="first")
         .reset_index(drop=True)
)

# === Determine if chosen DK equals the smallest measured DK for that chosen head (yellow flag) ===
# The minimum is taken over unique_df (all deduplicated rows, not only the
# passing ones), i.e. over every D_K present for that (m, d_model, h).
min_measured_per_head = unique_df.groupby(["m", "d_model", "h"], as_index=False)["D_K"].min()
min_measured_per_head = min_measured_per_head.rename(columns={"D_K": "min_measured_DK_for_head"})

min_rows = min_rows.merge(min_measured_per_head, on=["m", "d_model", "h"], how="left")
# yellow == True: no smaller D_K exists in the data for the selected head.
min_rows["yellow"] = (min_rows["D_K"] == min_rows["min_measured_DK_for_head"])

# === Pivot to a grid: rows=m, cols=d_model, values=DK, and a parallel boolean mask for yellow ===
# min_rows holds exactly one row per (m, d_model), so the pivots are unambiguous;
# pairs without a selected row end up as NaN cells.
vals = min_rows.pivot(index="m", columns="d_model", values="D_K")
flags = min_rows.pivot(index="m", columns="d_model", values="yellow")

# ensure consistent ordering
vals = vals.sort_index().sort_index(axis=1)
flags = flags.reindex_like(vals)

# Axis values in plotting order (rows = m, columns = d_model).
m_values = list(vals.index)
d_values = list(vals.columns)


# Also save the table used for the heatmap (with the chosen head and flag)
out_csv = os.path.join(TARGET_DIR, "min_DK_over_heads_table.csv")
min_rows_to_save = min_rows[["m", "d_model", "h", "D_K", "microF1_mean", "yellow"]].sort_values(["m","d_model"])
min_rows_to_save.to_csv(out_csv, index=False)
print("Saved table to:", out_csv, "| Exists?", os.path.exists(out_csv))


# ====================================================================================
# Combined heatmap: center = mean DK; upper-right = DK upper bound (CI worst-case);
# lower-left = DK lower bound (CI best-case). Replaces the two separate CI heatmaps.
# Relies on: unique_df, TARGET_DIR, COMPARE_OP, THRESHOLD, F1_TOL,
# and that the first (mean) heatmap has already produced `vals`, `m_values`, `d_values`.
# ====================================================================================

# --- Discover error-bar columns (same logic as before) ---
# Candidate column names, in order of preference within each kind.
_STD_COLS = ["microF1_std", "microF1_stddev"]
_SE_COLS  = ["microF1_stderr", "microF1_se", "microF1_sem"]
_CI_COLS  = ["microF1_ci", "microF1_ci95", "microF1_CI"]

# First candidate that exists in unique_df, or None when no such column is present.
STD_COL = next((c for c in _STD_COLS if c in unique_df.columns), None)
SE_COL  = next((c for c in _SE_COLS  if c in unique_df.columns), None)
CI_COL  = next((c for c in _CI_COLS  if c in unique_df.columns), None)

# 95% two-sided t criticals for df = 1..30
_T_CRIT_95 = {
    1:12.706, 2:4.303, 3:3.182, 4:2.776, 5:2.571, 6:2.447, 7:2.365, 8:2.306, 9:2.262,
    10:2.228, 11:2.201, 12:2.179, 13:2.160, 14:2.145, 15:2.131, 16:2.120, 17:2.110,
    18:2.101, 19:2.093, 20:2.086, 21:2.080, 22:2.074, 23:2.069, 24:2.064, 25:2.060,
    26:2.056, 27:2.052, 28:2.048, 29:2.045, 30:2.042
}

# Coarse brackets beyond the exact table: (exclusive upper df, critical value).
# Each bracket uses the tabulated value at its LOWEST df (30, 40, 60, 120), so
# the returned value is never smaller than the true one and the resulting
# confidence intervals are never too narrow.
_T_CRIT_95_TAIL = ((40, 2.042), (60, 2.021), (120, 2.000), (1000, 1.980))


def _tcrit_95_df(df_val: int) -> float:
    """Return the two-sided 95% Student-t critical value for `df_val` degrees of freedom.

    Values below 1 are treated as 1. df 1..30 come from the exact table;
    larger df use conservative brackets (see _T_CRIT_95_TAIL), and df >= 1000
    returns 1.962.
    """
    df_val = int(max(1, df_val))
    if df_val <= 30:
        return _T_CRIT_95[df_val]
    for upper_df, crit in _T_CRIT_95_TAIL:
        if df_val < upper_df:
            return crit
    return 1.962

def _sanitize_err(a):
    """Return a float copy of `a` with non-finite and negative entries set to NaN.

    The input is never modified: np.array always copies, whereas np.asarray
    would hand back the caller's own array when it is already float64.
    """
    out = np.array(a, dtype=float)
    with np.errstate(invalid="ignore"):  # NaN comparisons are expected here
        invalid = ~np.isfinite(out) | (out < 0)
    out[invalid] = np.nan
    return out

def _ci95_halfwidth_series(df: pd.DataFrame) -> pd.Series:
    """Return per-row 95% CI half-width for the mean F1.

    Sources, in order of preference (column names come from the module-level
    CI_COL / SE_COL / STD_COL discovered earlier):
      1. CI column   -> used as-is (taken to be a half-width already)
      2. SE column   -> t(df = runs - 1) * SE; z = 1.96 where 'runs' is
                        missing (whole column or individual rows)
      3. STD column  -> t(df = runs - 1) * std / sqrt(runs); NaN where 'runs'
                        is missing or not positive
    With none of these available an all-NaN Series is returned. Non-finite
    and negative half-widths are replaced by NaN.
    """
    def _as_float(col):
        return df[col].to_numpy(dtype=float) if col else None

    n = df["runs"].to_numpy(dtype=float) if ("runs" in df.columns) else None
    std, se, ci = _as_float(STD_COL), _as_float(SE_COL), _as_float(CI_COL)

    def _crit_per_row():
        # Rows without a finite run count fall back to the normal quantile;
        # casting NaN to int is undefined and must be avoided.
        crit = np.full(n.shape, 1.96)
        known = np.isfinite(n)
        dof = np.maximum(1, n[known].astype(int) - 1)
        crit[known] = [_tcrit_95_df(int(d)) for d in dof]
        return crit

    if ci is not None:
        yerr_ci95 = ci.copy()  # already half-width
    elif se is not None:
        yerr_ci95 = (_crit_per_row() if n is not None else 1.96) * se
    elif (std is not None) and (n is not None):
        tcrit = _crit_per_row()
        with np.errstate(divide="ignore", invalid="ignore"):
            # n <= 0 or NaN fails the comparison and yields NaN.
            yerr_ci95 = np.where(n > 0, tcrit * std / np.sqrt(n), np.nan)
    else:
        return pd.Series(np.nan, index=df.index)

    return pd.Series(_sanitize_err(yerr_ci95), index=df.index)

# Compute CI half-widths and bounds
# Work on a copy so the extra CI columns never leak into unique_df.
unique_df_ci = unique_df.copy()
unique_df_ci["microF1_ci95_hw"] = _ci95_halfwidth_series(unique_df_ci)

if unique_df_ci["microF1_ci95_hw"].isna().all():
    print("Combined CI heatmap: cannot compute 95% CI half-widths (no std/sem/ci available).")
    print("Will plot only the mean in the center; CI corners will show '—'.")
    # All-NaN bounds never satisfy either comparison, so the bound grids stay empty.
    unique_df_ci["microF1_lb95"] = np.nan
    unique_df_ci["microF1_ub95"] = np.nan
else:
    # Bounds are clipped to [0, 1]; rows with a NaN half-width keep NaN bounds.
    unique_df_ci["microF1_lb95"] = np.clip(unique_df_ci["microF1_mean"] - unique_df_ci["microF1_ci95_hw"], 0.0, 1.0)
    unique_df_ci["microF1_ub95"] = np.clip(unique_df_ci["microF1_mean"] + unique_df_ci["microF1_ci95_hw"], 0.0, 1.0)
    present_rows = int(unique_df_ci["microF1_ci95_hw"].notna().sum())
    print(f"Combined CI heatmap: computed 95% CI half-widths for {present_rows}/{len(unique_df_ci)} rows.")

def _minDK_grid_for_metric(df_src: pd.DataFrame, metric_col: str) -> pd.DataFrame:
    """Build the (m x d_model) grid of minimum D_K judged by `metric_col`.

    Rows of `df_src` are kept when `metric_col` satisfies COMPARE_OP against
    THRESHOLD. Per (m, d_model) the smallest passing D_K is selected; ties at
    that D_K go to the highest metric (within F1_TOL), then to the larger h.
    The grid is aligned to the index/columns of the module-level `vals`;
    cells without a passing row are NaN.
    """
    op = COMPARE_OP.strip()
    if op == ">=":
        passing = df_src[metric_col] >= THRESHOLD
    elif op == "<":
        passing = df_src[metric_col] < THRESHOLD
    else:
        raise ValueError("COMPARE_OP must be \">=\" or \"<\"")
    eligible = df_src[passing].copy()

    # Nothing passes: hand back an all-NaN grid on the same domain as `vals`.
    if eligible.empty:
        return pd.DataFrame(index=vals.index, columns=vals.columns, dtype=float)

    pair = ["m", "d_model"]

    # Smallest passing D_K per pair, broadcast back onto the rows.
    smallest_dk = eligible.groupby(pair)["D_K"].transform("min")
    at_smallest = eligible.loc[eligible["D_K"] == smallest_dk].copy()

    # Among those rows keep the ones reaching the best metric (within tolerance).
    best_metric = at_smallest.groupby(pair)[metric_col].transform("max")
    finalists = at_smallest.loc[at_smallest[metric_col] >= (best_metric - F1_TOL)].copy()

    # Remaining ties: more heads win (h descending, first row kept).
    ordered = finalists.sort_values(pair + ["h"], ascending=[True, True, False])
    picked = ordered.drop_duplicates(pair, keep="first").reset_index(drop=True)

    # Pivot and align to the ordering of the mean grid.
    grid = picked.pivot(index="m", columns="d_model", values="D_K")
    grid = grid.sort_index().sort_index(axis=1)
    return grid.reindex(index=vals.index, columns=vals.columns)

# Build DK grids:
# - mean DK grid is the earlier `vals` from your first heatmap
# - upper DK bound (pessimistic) uses the LOWER F1 bound
#   (with COMPARE_OP ">=", a lower F1 needs an equal or larger D_K to pass)
dk_upper = _minDK_grid_for_metric(unique_df_ci, "microF1_lb95")
# - lower DK bound (optimistic) uses the UPPER F1 bound
dk_lower = _minDK_grid_for_metric(unique_df_ci, "microF1_ub95")

# --- Plot the combined grid with three numbers per cell ---
# Figure size grows linearly with the number of columns / rows of the grid.
fig_w = 1.2 + 0.9 * max(1, len(d_values))
fig_h = 1.2 + 0.9 * max(1, len(m_values))
fig3, ax3 = plt.subplots(figsize=(fig_w, fig_h))

# subtle grey background to delineate cells
# (an all-zero image: the cells carry text only, no colour encoding)
bg = np.zeros((len(m_values), len(d_values)))
_ = ax3.imshow(bg, aspect="auto", vmin=0, vmax=1, cmap="Greys", alpha=0.08)

# grid / ticks
# Major ticks sit on cell centres (labels); minor ticks sit on cell edges and
# carry the grid lines.
ax3.set_xticks(np.arange(len(d_values)))
ax3.set_yticks(np.arange(len(m_values)))
ax3.set_xticklabels([str(int(x)) for x in d_values], rotation=0)
ax3.set_yticklabels([str(int(y)) for y in m_values])
ax3.set_xticks(np.arange(-.5, len(d_values), 1), minor=True)
ax3.set_yticks(np.arange(-.5, len(m_values), 1), minor=True)
ax3.grid(which="minor", color="black", linewidth=0.5, alpha=0.25)
ax3.tick_params(top=False, bottom=True, left=True, right=False)

# text sizes: center big, corners small
FS_CENTER = 13
FS_CORNER = 10
OFF = 0.40  # how far into the corner to place the small numbers

# Cell (i, j) is drawn at data coordinates x = j, y = i.
for i, m_val in enumerate(m_values):
    for j, d_val in enumerate(d_values):
        # center (mean)
        v_mean  = vals.iloc[i, j]
        # upper-right corner: DK upper bound (pessimistic) -> uses F1 lower bound
        v_upper = dk_upper.iloc[i, j]
        # lower-left corner: DK lower bound (optimistic) -> uses F1 upper bound
        v_lower = dk_lower.iloc[i, j]

        # formatters (missing values are shown as an em dash)
        c_txt = "—" if pd.isna(v_mean)  else str(int(v_mean))
        u_txt = "—" if pd.isna(v_upper) else str(int(v_upper))
        l_txt = "—" if pd.isna(v_lower) else str(int(v_lower))

        # draw center (big)
        ax3.text(j, i, c_txt, ha="center", va="center",
                 fontsize=FS_CENTER, fontweight="bold", color="black")

        # draw upper-right (small)
        ax3.text(j + OFF, i - OFF, u_txt, ha="right", va="top",
                 fontsize=FS_CORNER, color="black")

        # draw lower-left (small)
        ax3.text(j - OFF, i + OFF, l_txt, ha="left", va="bottom",
                 fontsize=FS_CORNER, color="black")

title = ("Minimum DK achieved")
ax3.set_title(title)
ax3.set_xlabel(r"$d_{\mathrm{model}}$")
ax3.set_ylabel("m")
plt.tight_layout()

# Save figure
combined_path = os.path.join(TARGET_DIR, "min_DK_over_heads_heatmap_mean_with_CI_bounds.png")
fig3.savefig(combined_path, dpi=150)
print("\nSaved heatmap to:", combined_path, "| Exists?", os.path.exists(combined_path))
plt.show()
plt.close(fig3)

# Save a tidy table with the three DK numbers for each (m, d_model)
def _stack_keep_nan(grid: pd.DataFrame, value_name: str) -> pd.DataFrame:
    """Return `grid` in long form (m, d_model, value_name), keeping NaN cells.

    Prefers the current stack implementation (future_stack=True, which never
    drops NaN) and falls back to the legacy dropna=False spelling on pandas
    versions that do not know the keyword.
    """
    try:
        stacked = grid.stack(future_stack=True)
    except TypeError:
        stacked = grid.stack(dropna=False)
    return stacked.rename(value_name).reset_index()


mean_long  = _stack_keep_nan(vals, "DK_mean")
upper_long = _stack_keep_nan(dk_upper, "DK_upper")
lower_long = _stack_keep_nan(dk_lower, "DK_lower")

# Left joins keep exactly the (m, d_model) cells of the mean grid.
triplet_df = (mean_long
              .merge(upper_long, on=["m","d_model"], how="left")
              .merge(lower_long, on=["m","d_model"], how="left")
              .sort_values(["m","d_model"]))

out_triplet_csv = os.path.join(TARGET_DIR, "min_DK_over_heads_table_mean_with_CI_bounds.csv")
triplet_df.to_csv(out_triplet_csv, index=False)
print("Saved combined table to:", out_triplet_csv, "| Exists?", os.path.exists(out_triplet_csv))


# --------------------------------------------------------------------------------
# === TABLE — Per-head size d_k for the SELECTED head (max F1 at min D_K)
# --------------------------------------------------------------------------------

# Per-head width d_k = total D_K divided by the number of heads h of the selected row.
selected_with_dk = min_rows.copy()
selected_with_dk["d_k"] = selected_with_dk["D_K"].astype(float) / selected_with_dk["h"].astype(float)

per_head_rows_to_save = (
    selected_with_dk[["m", "d_model", "h", "D_K", "d_k", "microF1_mean"]]
    .sort_values(["m", "d_model"])
    .copy()
)

# If all d_k are integers, cast to int for cleaner display; otherwise keep as float
dk_vals = per_head_rows_to_save["d_k"].to_numpy(dtype=float)
if np.all(np.isfinite(dk_vals)) and np.allclose(dk_vals, np.round(dk_vals)):
    per_head_rows_to_save["d_k"] = np.round(per_head_rows_to_save["d_k"]).astype(int)
else:
    per_head_rows_to_save["d_k"] = per_head_rows_to_save["d_k"].astype(float)

# NOTE: this rebinds `per_head_csv`, which earlier held the per_head_stats.csv path.
per_head_csv = os.path.join(TARGET_DIR, "min_per_head_dk_at_min_DK_table.csv")
per_head_rows_to_save.to_csv(per_head_csv, index=False)
print("Saved per-head table (selected head) to:", per_head_csv, "| Exists?", os.path.exists(per_head_csv))

print("\nPer-head size d_k for the selected head at the min total D_K (full table):")
print(per_head_rows_to_save.to_string(index=False))



# --------------------------------------------------------------------------------
# LINE PLOT with error bars — y = number of heads (h) that produced the minimum DK,
# per (m, d_model). Error bars computed exactly as in the scatter script:
#   ±10% DK window around the winner + one-sided Welch t-test (alpha=0.05)
# X-axis: log2 with powers-of-two tick labels; Y-axis: log with powers-of-two ticks
# --------------------------------------------------------------------------------

import os, math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker

# ---------------------------------------
# Pre-flight checks and light conveniences
# ---------------------------------------
# Allow this section to run on its own: default the output folder to the
# current directory when TARGET_DIR has not been defined yet.
if "TARGET_DIR" not in globals():
    TARGET_DIR = "."

# `min_rows` (the selected row per (m, d_model)) must already exist as a DataFrame.
if "min_rows" not in globals() or not isinstance(min_rows, pd.DataFrame):
    raise RuntimeError(
        "This script expects a DataFrame `min_rows` with columns ['m','d_model','h'] "
        "containing the chosen head per (m, d_model)."
    )

# ---------------------------------------
# Helper to recover per-head stats exactly like in the second script
# Required columns in the returned df:
#   ['m','d_model','h','D_K','F1_mean','F1_std','F1_n']
# ---------------------------------------
def _coerce_per_head_df():
    """Locate per-head F1 statistics among the module globals and normalise them.

    Two layouts are recognised:
      1. A tidy DataFrame bound to one of the names `per_head`, `stats_long`,
         `triplet_df_full`, `all_stats`, `summary_long`, `f1_stats_by_head`,
         holding the required columns (aliases are matched case-insensitively).
      2. Four wide DataFrames (mean, std, n, D_K), each with heads as columns;
         they are stacked and merged on (m, d_model, h).

    Returns a DataFrame with columns
    ['m', 'd_model', 'h', 'D_K', 'F1_mean', 'F1_std', 'F1_n'] (always in this
    order; ints for m/d_model/h/F1_n, floats otherwise), or None when nothing
    suitable is found.
    """
    wanted = ["m", "d_model", "h", "D_K", "F1_mean", "F1_std", "F1_n"]
    int_cols = ("m", "d_model", "h", "F1_n")
    float_cols = ("D_K", "F1_mean", "F1_std")

    def _finalize(frame):
        # An explicit list (not a set) keeps the column order deterministic.
        out = frame[wanted].copy()
        for col in int_cols:
            out[col] = out[col].astype(int)
        for col in float_cols:
            out[col] = out[col].astype(float)
        return out

    # ---- Layout 1: tidy frame --------------------------------------------
    aliases = {
        "D_K":     ("d_k", "dk"),
        "F1_mean": ("f1_mean", "f1_mu"),
        "F1_std":  ("f1_std", "f1_sigma", "std_f1"),
        "F1_n":    ("f1_n", "n_f1", "count_f1", "n"),
    }
    tidy_candidates = (
        "per_head", "stats_long", "triplet_df_full", "all_stats",
        "summary_long", "f1_stats_by_head"
    )
    for name in tidy_candidates:
        obj = globals().get(name)
        if not isinstance(obj, pd.DataFrame):
            continue
        rename = {}
        for target, names in aliases.items():
            if target in obj.columns:
                # The exact column already exists; renaming an alias onto it
                # (e.g. per-head 'd_k' onto total 'D_K') would create
                # duplicate columns.
                continue
            match = next((c for c in obj.columns if str(c).lower() in names), None)
            if match is not None:
                rename[match] = target
        df = obj.rename(columns=rename)
        if set(wanted).issubset(df.columns):
            return _finalize(df)

    # ---- Layout 2: four wide grids -----------------------------------------
    grid_aliases = {
        "F1_mean": ("F1_mean", "f1_mean", "F1_mu", "f1_mu", "mean_f1"),
        "F1_std":  ("F1_std", "f1_std", "F1_sigma", "f1_sigma", "std_f1"),
        "F1_n":    ("F1_n", "f1_n", "count_f1", "n_f1", "N_f1"),
        "D_K":     ("D_K", "DK", "DK_by_head", "D_K_by_head", "d_k", "dk_grid"),
    }
    grids = {}
    for target, names in grid_aliases.items():
        for key in names:
            obj = globals().get(key)
            if isinstance(obj, pd.DataFrame):
                grids[target] = obj
                break

    if len(grids) != len(grid_aliases):
        return None

    def _stack(grid, out_name):
        # assumes the grid index is (m, d_model) and its columns are head counts
        # -- TODO confirm against the producer of these grids
        tmp = grid.copy()
        tmp.columns.name = "h"
        try:
            stacked = tmp.stack(future_stack=True)  # current implementation keeps NaN
        except TypeError:
            stacked = tmp.stack(dropna=False)       # legacy pandas
        long = stacked.rename(out_name).reset_index()
        long["h"] = long["h"].astype(int)
        return long

    on = ["m", "d_model", "h"]
    df = (
        _stack(grids["F1_mean"], "F1_mean")
        .merge(_stack(grids["F1_std"], "F1_std"), on=on, how="left")
        .merge(_stack(grids["F1_n"], "F1_n"), on=on, how="left")
        .merge(_stack(grids["D_K"], "D_K"), on=on, how="left")
    )
    return _finalize(df)

# Rebind the module-level `per_head` to the normalised copy returned by the
# helper; fail loudly when no per-head statistics could be located.
per_head = _coerce_per_head_df()
if per_head is None:
    raise RuntimeError(
        "Could not find per-head F1 stats. Provide a DataFrame `per_head` with columns "
        "['m','d_model','h','D_K','F1_mean','F1_std','F1_n'], or adjust the auto-detection."
    )

# ---------------------------------------
# Welch t-test and DK window (same constants as scatter script)
# ---------------------------------------
def p_less_welch(mu_c, sd_c, n_c, mu_w, sd_w, n_w):
    """Return the one-sided Welch p-value P(T <= t) for candidate vs. winner.

    Parameters
    ----------
    mu_c, sd_c, n_c : mean, standard deviation and sample count of the candidate.
    mu_w, sd_w, n_w : mean, standard deviation and sample count of the winner.

    Returns
    -------
    float
        Student-t CDF evaluated at the Welch statistic (normal CDF when SciPy
        cannot be used). When the combined variance of the means is non-finite
        or at most 1e-16 the result is 1.0 if ``mu_c >= mu_w`` and 0.0 otherwise.
    """
    # Variance of each sample mean; the count is floored at 1 to avoid dividing by zero.
    var_c = (sd_c**2) / max(int(n_c), 1)
    var_w = (sd_w**2) / max(int(n_w), 1)
    var_sum = var_c + var_w

    # Degenerate variance: decide on the means alone.
    if (not np.isfinite(var_sum)) or var_sum <= 1e-16:
        if mu_c >= mu_w:
            return 1.0
        return 0.0

    t_stat = (mu_c - mu_w) / math.sqrt(var_sum)

    # Welch–Satterthwaite denominator; a sample contributes only when it has
    # more than one observation and a strictly positive variance.
    denom = 0.0
    for var_i, n_i in ((var_c, n_c), (var_w, n_w)):
        if n_i > 1 and var_i > 0:
            denom += (var_i**2) / (n_i - 1)
    dof = (var_sum**2) / denom if denom > 0 else 1e9

    try:
        from scipy.stats import t as student_t
        return float(student_t.cdf(t_stat, dof))
    except Exception:
        # Normal approximation when the Student-t CDF is unavailable.
        return 0.5 * (1.0 + math.erf(t_stat / math.sqrt(2.0)))

# NOTE(review): this rebinds the module-level `alpha` that the Config section
# defines as the attention sharpness — confirm nothing executed afterwards still
# expects the Config value.
alpha  = 0.05  # 95%
dk_tol = 0.10  # ±10%

# ---------------------------------------
# Compute kept-head bounds per (m, d_model) for winners in min_rows
# ---------------------------------------
winners = (
    min_rows[["m","d_model","h"]]
    .dropna()
    .astype({"m": int, "d_model": int, "h": int})
    .copy()
)

bounds = []
for _, r in winners.iterrows():
    m_val, d_val, h_star = int(r["m"]), int(r["d_model"]), int(r["h"])

    # Default to a zero-width interval at the winner; it is widened only when
    # per-head statistics exist and some head survives the test below.
    h_lo = h_hi = h_star

    sub = per_head[(per_head["m"]==m_val) & (per_head["d_model"]==d_val)].copy()
    win = sub[sub["h"]==h_star]

    if not win.empty:
        mu_w = float(win["F1_mean"].iloc[0])
        sd_w = float(win["F1_std"].iloc[0])
        n_w  = int(win["F1_n"].iloc[0])
        DK_w = float(win["D_K"].iloc[0])

        # Candidate heads: D_K within ±dk_tol of the winner's D_K.
        lo, hi = (1.0 - dk_tol) * DK_w, (1.0 + dk_tol) * DK_w
        near = sub[(sub["D_K"] >= lo) & (sub["D_K"] <= hi)].copy()

        if not near.empty:
            # Keep a candidate when the one-sided test cannot reject it at `alpha`.
            keep = [
                p_less_welch(
                    float(c["F1_mean"]), float(c["F1_std"]), int(c["F1_n"]),
                    mu_w, sd_w, n_w
                ) >= alpha
                for _, c in near.iterrows()
            ]
            kept = near.loc[keep]
            if not kept.empty:
                # The winner itself always lies inside the reported range.
                h_lo = min(int(np.nanmin(kept["h"].to_numpy())), h_star)
                h_hi = max(int(np.nanmax(kept["h"].to_numpy())), h_star)

    bounds.append({"m": m_val, "d_model": d_val, "h_lower": h_lo, "h_upper": h_hi})

bounds_df = pd.DataFrame(bounds)

# Merge bounds with winners and build asymmetric yerr
sc_df_ci = winners.merge(bounds_df, on=["m","d_model"], how="left")

# Asymmetric error-bar lengths in units of heads, never negative.
heads_below = sc_df_ci["h"] - sc_df_ci["h_lower"]
heads_above = sc_df_ci["h_upper"] - sc_df_ci["h"]
sc_df_ci["yerr_low"]  = heads_below.clip(lower=0).astype(float)
sc_df_ci["yerr_high"] = heads_above.clip(lower=0).astype(float)

# Pivot y and yerr to the same (m x d_model) layout as the line plot, with the
# columns aligned to the heatmap's d_model order (`d_values`).
h_grid = (
    min_rows.pivot(index="m", columns="d_model", values="h")
    .sort_index()
    .sort_index(axis=1)
    .reindex(columns=d_values)
)

yerr_low_grid, yerr_high_grid = (
    sc_df_ci.pivot(index="m", columns="d_model", values=err_col)
    .sort_index()
    .reindex(columns=d_values)
    for err_col in ("yerr_low", "yerr_high")
)

# ---------------------------------------
# Figure and plotting (line + markers + error bars)
# ---------------------------------------
# Figure width grows with the number of d_model columns.
fig2_w = 1.2 + 0.9 * max(1, len(d_values))
fig2_h = 4.8
fig2, ax2 = plt.subplots(figsize=(fig2_w, fig2_h))

xs = np.asarray([int(v) for v in d_values], dtype=float)

for m_val, grid_row in h_grid.iterrows():
    ys = grid_row.astype(float).values  # NaNs break the line where no data
    # Line + markers first, so the error bars can reuse its colour.
    (line_handle,) = ax2.plot(xs, ys, marker="o", linewidth=1.8, label=f"m={int(m_val)}")

    # No error-bar rows for this m: nothing more to draw.
    if (m_val not in yerr_low_grid.index) or (m_val not in yerr_high_grid.index):
        continue

    ylo = yerr_low_grid.loc[m_val].astype(float).values
    yhi = yerr_high_grid.loc[m_val].astype(float).values

    # Error bars only where the point and both bar lengths are finite.
    finite = np.isfinite(ys) & np.isfinite(ylo) & np.isfinite(yhi)
    if not np.any(finite):
        continue

    y_pts  = ys[finite]
    yerr_l = np.maximum(0.0, ylo[finite])
    yerr_h = np.maximum(0.0, yhi[finite])

    # On a log-y axis the lower end must stay > 0 (conservative epsilon guard).
    eps = 1e-12
    yerr_l = np.minimum(yerr_l, np.maximum(0.0, y_pts - eps))

    ax2.errorbar(
        xs[finite], y_pts,
        yerr=np.vstack([yerr_l, yerr_h]),
        fmt="none",
        ecolor=line_handle.get_color(),
        elinewidth=0.9,
        capsize=3,
        alpha=0.85,
        zorder=line_handle.get_zorder() + 1
    )

ax2.set_xlabel(r"$d_{\mathrm{model}}$")
ax2.set_ylabel("number of heads")
ax2.set_title(r"Heads producing minimum $D_K$")
ax2.grid(alpha=0.3, linewidth=0.6)

# Log scales on both axes (base 2 on x).
ax2.set_yscale("log")
ax2.set_xscale("log", base=2)

# --- helper: powers-of-two formatter ---
def _pow2_label(y, _):
    """Tick formatter: label exact powers of two as ``$2^{k}$``, everything else as ''.

    The second positional argument (the tick position matplotlib passes to a
    ``FuncFormatter`` callback) is ignored. Non-positive values yield ''.
    """
    if y <= 0:
        return ""
    exponent = np.log2(y)
    nearest = round(exponent)
    return fr"$2^{{{int(nearest)}}}$" if np.isclose(exponent, nearest) else ""

# Compute power-of-two major ticks that span the Y data range (no clipping)
# Power-of-two major ticks spanning the positive Y data (fixed default when none).
all_h = pd.Series(h_grid.values.ravel(), dtype="float64").dropna()
all_h = all_h[all_h > 0]
if all_h.empty:
    pow2_ticks_y = [1, 2, 4, 8, 16, 32]
else:
    k_min_y = int(np.floor(np.log2(all_h.min())))
    k_max_y = int(np.ceil(np.log2(all_h.max())))
    pow2_ticks_y = [2**k for k in range(k_min_y, k_max_y + 1)]

ax2.yaxis.set_major_locator(mticker.FixedLocator(pow2_ticks_y))
ax2.yaxis.set_minor_locator(mticker.NullLocator())
ax2.yaxis.set_major_formatter(mticker.FuncFormatter(_pow2_label))

# Same construction for the X axis, from the positive d_model values.
xs_pos = xs[xs > 0]
if xs_pos.size == 0:
    pow2_ticks_x = [1, 2, 4, 8, 16, 32]
else:
    k_min_x = int(np.floor(np.log2(xs_pos.min())))
    k_max_x = int(np.ceil(np.log2(xs_pos.max())))
    pow2_ticks_x = [2**k for k in range(k_min_x, k_max_x + 1)]

ax2.xaxis.set_major_locator(mticker.FixedLocator(pow2_ticks_x))
ax2.xaxis.set_minor_locator(mticker.NullLocator())
ax2.xaxis.set_major_formatter(mticker.FuncFormatter(_pow2_label))

# Legend goes outside the axes once there are more than 10 lines.
if len(h_grid.index) > 10:
    ax2.legend(title="m", fontsize=8, bbox_to_anchor=(1.02, 1.0), loc="upper left", borderaxespad=0.)
else:
    ax2.legend(title="m", fontsize=9)

plt.tight_layout()

# === Save line plot explicitly to Drive ===
lineplot_path = os.path.join(TARGET_DIR, "h_that_produced_min_DK_lineplot.png")
fig2.savefig(lineplot_path, dpi=150, bbox_inches="tight")
print("Saved line plot (with error bars) to:", lineplot_path, "| Exists?", os.path.exists(lineplot_path))
plt.show()
plt.close(fig2)



# --------------------------------------------------------------------------------
# Scatter: heads vs (m / d_model) with global linear fit + confidence bars
# Confidence bars: for each (m, d_model), consider heads whose DK is within ±10% of the
# chosen optimal DK. Keep heads that cannot be eliminated by a one-sided Welch t-test
# vs. the winner at alpha=0.05 (i.e., p >= 0.05 for H1: mu_c < mu_w). Plot the min/max
# kept head as vertical error bars around the winner's head.
# --------------------------------------------------------------------------------

import os
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# ---------------------------------------
# Pre-flight checks and light conveniences
# ---------------------------------------
if "TARGET_DIR" not in globals():
    TARGET_DIR = "."

# `min_rows` must exist and be a DataFrame; everything below reads from it.
if not isinstance(globals().get("min_rows"), pd.DataFrame):
    raise RuntimeError(
        "This script expects a DataFrame `min_rows` with columns ['m','d_model','h'] "
        "containing the chosen head per (m, d_model)."
    )

# ---------------------------------------
# Recover per-m colors from an earlier line plot (ax2), if available
# ---------------------------------------
color_by_m = {}
try:
    for line in ax2.get_lines():
        lbl = line.get_label()
        if not (isinstance(lbl, str) and lbl.startswith("m=")):
            continue
        try:
            mm = int(lbl.split("=")[1])
            color_by_m[mm] = line.get_color()
        except Exception:
            # Label did not parse as "m=<int>"; skip this line only.
            pass
except NameError:
    # ax2 is not defined; a fallback palette is chosen later.
    color_by_m = {}

# ---------------------------------------
# Build the scatter base table from the chosen rows
# ---------------------------------------
sc_df = min_rows[["m", "d_model", "h"]].dropna().copy()
for key_col in ("m", "d_model", "h"):
    sc_df[key_col] = sc_df[key_col].astype(int)
sc_df["ratio_m_over_d"] = sc_df["m"].astype(float) / sc_df["d_model"].astype(float)

# ---------------------------------------------------
# Locate / construct a tidy per-head table with stats
# Required columns: ['m','d_model','h','D_K','F1_mean','F1_std','F1_n']
# ---------------------------------------------------
def _coerce_per_head_df():
    """Locate or rebuild the tidy per-head statistics table from module globals.

    Lookup order:

    1. A DataFrame stored under one of several common names (``per_head``,
       ``stats_long``, ...). Known column aliases are renamed to the canonical
       names; the first candidate that then has every required column wins.
    2. Four pivot grids (F1 mean / std / n and D_K) stored under common names.
       They are expected to be indexed by ``(m, d_model)`` with heads as
       columns (assumed from the merge keys — the index levels must be named
       ``m`` and ``d_model``). Each grid is stacked to long form and the four
       are joined on ``(m, d_model, h)``.

    Rows with no key or no sample count (``F1_n`` is NaN, e.g. an empty grid
    cell) are dropped, because they carry no usable statistics and cannot be
    cast to integer.

    Returns
    -------
    pandas.DataFrame or None
        Columns ``['m','d_model','h','D_K','F1_mean','F1_std','F1_n']`` in that
        order (ints for m, d_model, h, F1_n; floats for the rest), or ``None``
        when neither source is available.
    """
    columns = ["m", "d_model", "h", "D_K", "F1_mean", "F1_std", "F1_n"]
    int_columns = ["m", "d_model", "h", "F1_n"]
    float_columns = ["D_K", "F1_mean", "F1_std"]

    def _finalize(frame):
        # Fixed column order (a set would make it vary between runs).
        out = frame[columns].copy()
        # NaN cannot be cast to int; such rows have no statistics to test anyway.
        out = out.dropna(subset=int_columns)
        for col in int_columns:
            out[col] = out[col].astype(int)
        for col in float_columns:
            out[col] = out[col].astype(float)
        return out

    # 1) If a tidy table already exists under common names, normalize its columns.
    aliases = {
        "d_k": "D_K", "dk": "D_K",
        "f1_mean": "F1_mean", "f1_mu": "F1_mean",
        "f1_std": "F1_std", "f1_sigma": "F1_std", "std_f1": "F1_std",
        "f1_n": "F1_n", "n_f1": "F1_n", "count_f1": "F1_n", "n": "F1_n",
    }
    tidy_candidates = (
        "per_head", "stats_long", "triplet_df_full", "all_stats", "summary_long", "f1_stats_by_head"
    )
    for name in tidy_candidates:
        obj = globals().get(name)
        if not isinstance(obj, pd.DataFrame):
            continue
        df = obj.copy()
        rename = {}
        taken = set(df.columns)
        for c in df.columns:
            target = aliases.get(str(c).lower())
            if target is None or c == target:
                continue
            if target in taken:
                # The canonical column (or an earlier alias) already provides
                # this field; renaming would create duplicate column names.
                continue
            rename[c] = target
            taken.add(target)
        df = df.rename(columns=rename)
        if set(columns).issubset(df.columns):
            return _finalize(df)

    # 2) Try to rebuild from pivot grids commonly used in earlier steps:
    #    dataframes with index = (m, d_model), columns = heads
    grid_names = {
        "F1_mean": ("F1_mean", "f1_mean", "F1_mu", "f1_mu", "mean_f1"),
        "F1_std": ("F1_std", "f1_std", "F1_sigma", "f1_sigma", "std_f1"),
        "F1_n": ("F1_n", "f1_n", "count_f1", "n_f1", "N_f1"),
        "D_K": ("D_K", "DK", "DK_by_head", "D_K_by_head", "d_k", "dk_grid"),
    }
    grids = {}
    for canonical, names in grid_names.items():
        for key in names:
            candidate = globals().get(key)
            if isinstance(candidate, pd.DataFrame):
                grids[canonical] = candidate
                break

    if len(grids) != len(grid_names):
        return None

    def _stack(grid, out_name):
        # Name the column axis so the stacked level comes out as 'h'.
        tmp = grid.rename_axis(columns="h")
        try:
            # pandas >= 2.1: the new implementation keeps NaN cells and must
            # not be given `dropna`.
            stacked = tmp.stack(future_stack=True)
        except TypeError:
            # Older pandas has no `future_stack`; keep NaN cells explicitly.
            stacked = tmp.stack(dropna=False)
        long = stacked.rename(out_name).reset_index()
        # expected columns: ['m','d_model','h', out_name]
        long["h"] = long["h"].astype(int)
        return long

    keys = ["m", "d_model", "h"]
    df = _stack(grids["F1_mean"], "F1_mean")
    for stat in ("F1_std", "F1_n", "D_K"):
        df = df.merge(_stack(grids[stat], stat), on=keys, how="left")

    return _finalize(df)

# Resolve the tidy per-head statistics table. `_coerce_per_head_df()` returns
# None when neither a tidy table nor a full set of pivot grids is found, and the
# confidence-bar computation below cannot proceed without it, so fail loudly.
per_head = _coerce_per_head_df()
if per_head is None:
    raise RuntimeError(
        "Could not find per-head F1 stats. Provide a DataFrame `per_head` with columns "
        "['m','d_model','h','D_K','F1_mean','F1_std','F1_n'], or adjust the auto-detection."
    )

# ---------------------------------------
# One-sided Welch t-test (candidate vs winner): H1: mu_c < mu_w
# Keep a candidate if p >= alpha (cannot eliminate at 95%)
# ---------------------------------------
def p_less_welch(mu_c, sd_c, n_c, mu_w, sd_w, n_w):
    """One-sided Welch t-test p-value, P(T <= t), for H1: mu_c < mu_w.

    ``*_c`` arguments describe the candidate (mean, std, count) and ``*_w`` the
    winner. A small return value means the candidate is significantly worse.

    When the combined variance of the two means is non-finite or at most 1e-16
    the test is skipped: the result is 1.0 if ``mu_c >= mu_w``, else 0.0.
    The Student-t CDF is used when SciPy works; otherwise the normal CDF.
    """
    se2_c = (sd_c**2) / max(int(n_c), 1)
    se2_w = (sd_w**2) / max(int(n_w), 1)
    se2_total = se2_c + se2_w

    degenerate = (not np.isfinite(se2_total)) or se2_total <= 1e-16
    if degenerate:
        # degenerate variance: conservative fallback based on means
        return 1.0 if (mu_c >= mu_w) else 0.0

    t_value = (mu_c - mu_w) / math.sqrt(se2_total)

    # Welch–Satterthwaite degrees of freedom; a term is included only for a
    # sample with more than one observation and positive variance.
    ws_den = 0.0
    for se2_i, n_i in ((se2_c, n_c), (se2_w, n_w)):
        if n_i > 1 and se2_i > 0:
            ws_den += (se2_i**2) / (n_i - 1)
    if ws_den > 0:
        dof = (se2_total**2) / ws_den
    else:
        dof = 1e9  # large df => normal approx

    try:
        from scipy.stats import t as student_t
        p_value = float(student_t.cdf(t_value, dof))  # one-sided: P(T <= t)
    except Exception:
        p_value = 0.5 * (1.0 + math.erf(t_value / math.sqrt(2.0)))
    return p_value

# NOTE(review): this rebinds the module-level `alpha` that the Config section
# defines as the attention sharpness — confirm nothing executed afterwards still
# expects the Config value.
alpha  = 0.05  # 95% rule
dk_tol = 0.10  # ±10% window on DK around the winner's DK

bounds = []
winners = sc_df[["m","d_model","h"]].dropna().astype(int)

for _, r in winners.iterrows():
    m_val, d_val, h_star = int(r["m"]), int(r["d_model"]), int(r["h"])

    # Start from a zero-width interval at the winner; widen it only when
    # per-head stats exist and at least one nearby head survives the test.
    h_lo = h_hi = h_star

    sub = per_head[(per_head["m"]==m_val) & (per_head["d_model"]==d_val)].copy()
    win = sub[sub["h"]==h_star]

    if not win.empty:
        mu_w = float(win["F1_mean"].iloc[0])
        sd_w = float(win["F1_std"].iloc[0])
        n_w  = int(win["F1_n"].iloc[0])
        DK_w = float(win["D_K"].iloc[0])

        # candidates with DK within ±10% of the winner's DK
        lo, hi = (1.0 - dk_tol) * DK_w, (1.0 + dk_tol) * DK_w
        near = sub[(sub["D_K"] >= lo) & (sub["D_K"] <= hi)].copy()

        if not near.empty:
            # Keep if NOT significantly worse than winner at alpha
            keep_mask = [
                p_less_welch(
                    float(c["F1_mean"]), float(c["F1_std"]), int(c["F1_n"]),
                    mu_w, sd_w, n_w
                ) >= alpha
                for _, c in near.iterrows()
            ]
            kept = near[keep_mask]
            if not kept.empty:
                # ensure the winner lies in range
                h_lo = min(int(np.nanmin(kept["h"].to_numpy())), h_star)
                h_hi = max(int(np.nanmax(kept["h"].to_numpy())), h_star)

    bounds.append({"m": m_val, "d_model": d_val, "h_lower": h_lo, "h_upper": h_hi})

bounds_df = pd.DataFrame(bounds)

# ---------------------------------------
# Merge bounds into the scatter table and compute vertical errors in "heads"
# ---------------------------------------
sc_df_ci = sc_df.merge(bounds_df, on=["m","d_model"], how="left")

# Asymmetric error-bar lengths in units of heads, never negative.
heads_below = sc_df_ci["h"] - sc_df_ci["h_lower"]
heads_above = sc_df_ci["h_upper"] - sc_df_ci["h"]
sc_df_ci["yerr_low"]  = heads_below.clip(lower=0).astype(float)
sc_df_ci["yerr_high"] = heads_above.clip(lower=0).astype(float)

# ---------------------------------------
# Fallback color palette if we couldn't read colors from ax2
# ---------------------------------------
if not color_by_m:
    # Use the active matplotlib color cycle, or "C0".."C9" when it has no colors.
    cycle_colors = (
        plt.rcParams["axes.prop_cycle"].by_key().get("color", [])
        or [f"C{i}" for i in range(10)]
    )
    m_unique = sorted(int(v) for v in sc_df_ci["m"].unique())
    color_by_m = {
        m_val: cycle_colors[i % len(cycle_colors)]
        for i, m_val in enumerate(m_unique)
    }

# ---------------------------------------
# Figure and plotting
# ---------------------------------------
fig_sc, ax_sc = plt.subplots(figsize=(5.0, 5.0))

# Pass 1: error bars (lower zorder) so the markers end up on top.
for m_val, grp in sc_df_ci.groupby("m", sort=True):
    m_key = int(m_val)
    e_kwargs = {"fmt": "none", "elinewidth": 0.9, "capsize": 3, "alpha": 0.85, "zorder": 1.5}
    if m_key in color_by_m:
        e_kwargs["ecolor"] = color_by_m[m_key]
    # Missing bar lengths are drawn as zero-length bars.
    yerr = np.vstack([
        grp["yerr_low"].fillna(0).to_numpy(dtype=float),
        grp["yerr_high"].fillna(0).to_numpy(dtype=float)
    ])
    ax_sc.errorbar(
        grp["ratio_m_over_d"].astype(float).values,
        grp["h"].astype(float).values,
        yerr=yerr,
        **e_kwargs
    )

# Pass 2: the scatter points themselves, color-coded by m.
for m_val, grp in sc_df_ci.groupby("m", sort=True):
    m_key = int(m_val)
    kwargs = {"s": 46, "alpha": 0.9, "label": f"m={m_key}", "zorder": 2}
    if m_key in color_by_m:
        kwargs["color"] = color_by_m[m_key]
    ax_sc.scatter(
        grp["ratio_m_over_d"].astype(float).values,
        grp["h"].astype(float).values,
        **kwargs
    )

# Global linear fit: h ≈ a * (m/d_model) + b, over finite points only
x_all = sc_df_ci["ratio_m_over_d"].to_numpy(dtype=float)
y_all = sc_df_ci["h"].to_numpy(dtype=float)
mask = np.isfinite(x_all) & np.isfinite(y_all)
x_all, y_all = x_all[mask], y_all[mask]

fit_done = False
if x_all.size < 2 or (x_all.max() - x_all.min()) <= 1e-12:
    print("Skipped linear fit: insufficient x-range or too few points.")
    a = b = r2 = float("nan")  # ensure defined for printing below
else:
    a, b = np.polyfit(x_all, y_all, 1)
    xs_fit = np.linspace(float(x_all.min()), float(x_all.max()), 200)
    ys_fit = a * xs_fit + b

    # Coefficient of determination (NaN when y has no spread).
    y_pred = a * x_all + b
    ss_res = float(np.sum((y_all - y_pred) ** 2))
    ss_tot = float(np.sum((y_all - np.mean(y_all)) ** 2))
    if ss_tot > 0:
        r2 = 1.0 - ss_res / ss_tot
    else:
        r2 = float("nan")

    ax_sc.plot(
        xs_fit, ys_fit,
        linestyle="--", linewidth=1.0, color="black",
        label=f"y = {a:.3g} x + {b:.3g}"  #  (R²={r2:.3f})
    )
    fit_done = True

# Axes, labels, styling (both axes linear)
ax_sc.set_xlabel(r"$m / d_{\mathrm{model}}$")
ax_sc.set_ylabel("Heads producing min DK")
ax_sc.set_title(r"Heads vs $m / d_{\mathrm{model}}$")
ax_sc.grid(alpha=0.3, linewidth=0.6)
ax_sc.set_xscale("linear")
ax_sc.set_yscale("linear")

# Nice bounds with a little padding
if x_all.size:
    xr = float(x_all.max() - x_all.min())
    if xr <= 0:
        ax_sc.set_xlim(float(x_all.min()) - 0.5, float(x_all.max()) + 0.5)
    else:
        ax_sc.set_xlim(float(x_all.min()) - 0.05 * xr, float(x_all.max()) + 0.05 * xr)
if y_all.size:
    yr_min, yr_max = float(np.min(y_all)), float(np.max(y_all))

    # The y-range must also cover the ends of the confidence bars; limits taken
    # from the points alone would clip any whisker that reaches past the
    # smallest/largest plotted head count. Missing bar lengths count as zero,
    # matching how the bars are drawn.
    ci_bottom = (sc_df_ci["h"] - sc_df_ci["yerr_low"].fillna(0)).to_numpy(dtype=float)
    ci_top = (sc_df_ci["h"] + sc_df_ci["yerr_high"].fillna(0)).to_numpy(dtype=float)
    ci_bottom = ci_bottom[np.isfinite(ci_bottom)]
    ci_top = ci_top[np.isfinite(ci_top)]
    if ci_bottom.size:
        yr_min = min(yr_min, float(ci_bottom.min()))
    if ci_top.size:
        yr_max = max(yr_max, float(ci_top.max()))

    if yr_max > yr_min:
        pad = 0.05 * (yr_max - yr_min)
        # Head counts are non-negative, so never extend the axis below zero.
        ax_sc.set_ylim(max(0, yr_min - pad), yr_max + pad)

# Legend placement similar to the line plot
if sc_df_ci["m"].nunique() <= 10:
    ax_sc.legend(title="m", fontsize=9)
else:
    ax_sc.legend(title="m", fontsize=8, bbox_to_anchor=(1.02, 1.0),
                 loc="upper left", borderaxespad=0.)

plt.tight_layout()

# Save + show
scatter_fit_path = os.path.join(TARGET_DIR, "heads_vs_m_over_d_scatter_linear_fit.png")
fig_sc.savefig(scatter_fit_path, dpi=150, bbox_inches="tight")
print("Saved scatter+fit+CI to:", scatter_fit_path, "| Exists?", os.path.exists(scatter_fit_path))
if fit_done:
    print(f"y = {a:.3g} x + {b:.3g} (R²={r2:.3f})")
plt.show()
plt.close(fig_sc)




# --------------------------------------------------------------------------------
# PLOT — Scatter: min DK vs (m * log m / d_model) + global linear fit + CI error bars
# Uses the same per‑m colors as the two graphs immediately before (via color_by_m)
# Confidence intervals are the same DK bounds used for the table (DK_lower, DK_upper)
# --------------------------------------------------------------------------------

# Reuse color_by_m if it exists; otherwise, try to rebuild it from ax2 (line plot).
# Rebuild the per-m colors from ax2 only when no usable (non-empty dict)
# `color_by_m` is already available.
if not isinstance(locals().get('color_by_m'), dict) or not color_by_m:
    color_by_m = {}
    try:
        for line in ax2.get_lines():
            lbl = line.get_label()
            if not isinstance(lbl, str) or not lbl.startswith("m="):
                continue
            mm = int(lbl.split("=")[1])
            color_by_m[mm] = line.get_color()
    except Exception:
        # Best effort: keep whatever was collected before the failure.
        pass

# Prepare the (m, d_model, D_K) table and x = m * log2(m) / d_model
sc2_df = min_rows[["m", "d_model", "D_K"]].dropna().copy()
m_as_float = sc2_df["m"].astype(float)
sc2_df["x_metric"] = m_as_float * np.log2(m_as_float) / sc2_df["d_model"].astype(float)

# Deterministic fallback mapping if we couldn't recover colors from ax2
if not color_by_m:
    cycle_colors = (
        plt.rcParams["axes.prop_cycle"].by_key().get("color", [])
        or [f"C{i}" for i in range(10)]
    )
    m_unique = sorted(int(v) for v in sc2_df["m"].unique())
    color_by_m = {
        m_val: cycle_colors[i % len(cycle_colors)]
        for i, m_val in enumerate(m_unique)
    }

# --- NEW: bring in DK CI bounds computed for the table (same logic, no recomputation) ---
# Prefer using the tidy table you just wrote; fallback to the pivot grids if needed.
if ("triplet_df" in locals()) and {"m","d_model","DK_upper","DK_lower"}.issubset(set(triplet_df.columns)):
    tmp_trip = triplet_df[["m","d_model","DK_upper","DK_lower"]].copy()
else:
    # Fallback: reconstruct from the already-built grids
    # NOTE(review): newer pandas releases reject the `dropna` keyword under the
    # new stack implementation — confirm against the pandas version in use.
    mean_long_f  = vals.stack(dropna=False).rename("DK_mean").reset_index()
    upper_long_f = dk_upper.stack(dropna=False).rename("DK_upper").reset_index()
    lower_long_f = dk_lower.stack(dropna=False).rename("DK_lower").reset_index()
    tmp_trip = mean_long_f.merge(upper_long_f, on=["m","d_model"], how="left")
    tmp_trip = tmp_trip.merge(lower_long_f, on=["m","d_model"], how="left")
    tmp_trip = tmp_trip[["m","d_model","DK_upper","DK_lower"]]

# Ensure join keys line up in type/format on both sides of the merge
sc2_df_ci = sc2_df.copy()
tmp_trip = tmp_trip.copy()
for key_col in ("m", "d_model"):
    sc2_df_ci[key_col] = sc2_df_ci[key_col].astype(int)
for key_col in ("m", "d_model"):
    tmp_trip[key_col] = tmp_trip[key_col].astype(int)

# Merge DK bounds into the scatter rows
sc2_df_ci = sc2_df_ci.merge(tmp_trip, on=["m","d_model"], how="left")

# Asymmetric vertical errors:
# bottom = D_K - DK_lower (optimistic); top = DK_upper - D_K (pessimistic)
sc2_df_ci["yerr_low"]  = (sc2_df_ci["D_K"] - sc2_df_ci["DK_lower"]).astype(float)
sc2_df_ci["yerr_high"] = (sc2_df_ci["DK_upper"] - sc2_df_ci["D_K"]).astype(float)

# sanitize: anything non-finite or negative becomes NaN
for col in ["yerr_low", "yerr_high"]:
    arr = sc2_df_ci[col].to_numpy(dtype=float)
    sc2_df_ci[col] = np.where(np.isfinite(arr) & (arr >= 0), arr, np.nan)

# --- Plot ---
fig_sc2, ax_sc2 = plt.subplots(figsize=(5.6, 5.0))

# Pass 1: error bars (lower zorder) so the markers end up on top.
for m_val, grp in sc2_df_ci.groupby("m", sort=True):
    m_key = int(m_val)
    e_kwargs = {"fmt": "none", "elinewidth": 0.9, "capsize": 3, "alpha": 0.85, "zorder": 1.5}
    if m_key in color_by_m:
        e_kwargs["ecolor"] = color_by_m[m_key]
    ax_sc2.errorbar(
        grp["x_metric"].astype(float).values,
        grp["D_K"].astype(float).values,
        yerr=np.vstack([grp["yerr_low"].values, grp["yerr_high"].values]),  # [2 x N]
        **e_kwargs
    )

# Pass 2: the points grouped by m, using the shared color scheme.
for m_val, grp in sc2_df_ci.groupby("m", sort=True):
    m_key = int(m_val)
    kwargs = {"s": 46, "alpha": 0.9, "label": f"m={m_key}", "zorder": 2}
    if m_key in color_by_m:
        kwargs["color"] = color_by_m[m_key]
    ax_sc2.scatter(
        grp["x_metric"].astype(float).values,
        grp["D_K"].astype(float).values,
        **kwargs
    )

# Global linear fit constrained through the origin: D_K ≈ a * (m log m / d_model) (unchanged)
# Through-origin fit D_K ≈ a * x, over finite points only.
x = sc2_df_ci["x_metric"].to_numpy(dtype=float)
y = sc2_df_ci["D_K"].to_numpy(dtype=float)
mask = np.isfinite(x) & np.isfinite(y)
x, y = x[mask], y[mask]

fit_done = False
# Sum of squares of x; the fit is defined only when this is positive.
sxx = float(np.dot(x, x)) if x.size >= 1 else 0.0
if sxx > 0:
    a = float(np.dot(x, y) / sxx)  # OLS slope with intercept fixed at 0

    # Draw the line from the origin (or the smallest x, if negative) to max x.
    xs_fit = np.linspace(min(0.0, float(x.min())), float(x.max()), 200)
    ys_fit = a * xs_fit

    # R^2 for a through-origin model: 1 - SSE / sum(y^2)
    ss_res = float(np.sum((y - a * x) ** 2))
    ss_tot0 = float(np.sum(y ** 2))
    if ss_tot0 > 0:
        r2_0 = 1.0 - ss_res / ss_tot0
    else:
        r2_0 = float("nan")

    ax_sc2.plot(
        xs_fit, ys_fit,
        linestyle="--", linewidth=1.0, color="black", zorder=1,
        label=f"y = {a:.3g} x"  #  (R²₀={r2_0:.3f})
    )
    fit_done = True

# Labels, title, and styling (unchanged)
ax_sc2.set_xlabel(r"$m \log m \,/\, d_{\mathrm{model}}$")
ax_sc2.set_ylabel(r"Minimum total key dimension $D_K$")
ax_sc2.set_title(r"Min $D_K$ vs $m \log m / d_{\mathrm{model}}$")
ax_sc2.grid(alpha=0.3, linewidth=0.6)

# Legend goes outside the axes once there are more than 10 distinct m values.
if sc2_df_ci["m"].nunique() > 10:
    ax_sc2.legend(fontsize=8, title="m", bbox_to_anchor=(1.02, 1.0),
                  loc="upper left", borderaxespad=0.)
else:
    ax_sc2.legend(fontsize=9, title="m")

plt.tight_layout()

# Save, report (with the fit equation when one was drawn), then show.
scatter2_ci_path = os.path.join(TARGET_DIR, "min_DK_vs_m_log_m_over_d_scatter_linear_fit.png")
fig_sc2.savefig(scatter2_ci_path, dpi=150, bbox_inches="tight")
if not fit_done:
    print(f"Saved scatter+CI (no fit drawn) to: {scatter2_ci_path} | Exists? {os.path.exists(scatter2_ci_path)}")
    print("Skipped linear fit through origin: insufficient variation or zero range in x.")
else:
    print(f"Saved scatter+fit+CI to: {scatter2_ci_path} | Exists? {os.path.exists(scatter2_ci_path)}")
    print(f"y = {a:.3g} x (R²₀ = {r2_0:.3f})")
plt.show()
plt.close(fig_sc2)



# =============================================================================
# DIAGONAL SWEEP PLOTS — with error intervals on the bars (heads)
# Colored so the line+its CI and bars+their CI are easy to distinguish.
#
# One chart per ratio in DIAGONAL_RATIOS:
#   • Left y-axis (line): minimum DK (with DK CI if available)
#   • Right y-axis (bars): selected heads h, with asymmetric CI error bars
# X-axis: the (m, d_model) pairs on the diagonal, increasing m from left to right
# =============================================================================

import os
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker

# ----------------------------
# Pre-flight & light defaults
# ----------------------------
if "TARGET_DIR" not in globals():
    TARGET_DIR = "."

# ---- Colors: each is taken from an existing global of the same name, if any ----
LINE_COLOR = globals().get("LINE_COLOR", "#1f77b4")  # line + its error bars (tab:blue)
LINE_ERR_COLOR = globals().get("LINE_ERR_COLOR", LINE_COLOR)

BAR_FACE_COLOR = globals().get("BAR_FACE_COLOR", "#ff7f0e")  # bars (tab:orange)
BAR_EDGE_COLOR = globals().get("BAR_EDGE_COLOR", "#ad5a0a")  # bar edges
BAR_ERR_COLOR = globals().get("BAR_ERR_COLOR", BAR_EDGE_COLOR)  # bar CIs

if not isinstance(globals().get("min_rows"), pd.DataFrame):
    raise RuntimeError(
        "This script expects a DataFrame `min_rows` with columns ['m','d_model','D_K','h']."
    )

# If DIAGONAL_RATIOS isn't provided (or is empty), infer it from min_rows:
# every distinct m / d_model value, ascending.
if "DIAGONAL_RATIOS" not in globals() or not DIAGONAL_RATIOS:
    tmp = min_rows.dropna(subset=["m","d_model"]).copy()
    tmp["ratio"] = tmp["m"].astype(float) / tmp["d_model"].astype(float)
    distinct_ratios = np.unique(tmp["ratio"].to_numpy(dtype=float))
    DIAGONAL_RATIOS = sorted(float(r) for r in distinct_ratios)

# Keep the selected min DK and its h, under explicit column names
minDK_and_h = (
    min_rows[["m", "d_model", "D_K", "h"]]
    .rename(columns={"D_K": "min_DK", "h": "h_selected"})
    .copy()
)
for key_col in ("m", "d_model"):
    minDK_and_h[key_col] = minDK_and_h[key_col].astype(int)

# ---------------------------------------
# Helpers
# ---------------------------------------
def _pairs_on_ratio(df_pairs: pd.DataFrame, ratio: float) -> pd.DataFrame:
    """Select the rows of *df_pairs* whose ``m / d_model`` equals *ratio*.

    Integer-valued ratios are matched exactly via ``m == ratio * d_model``;
    other ratios are matched with ``np.isclose`` (rtol=1e-9, atol=1e-12).
    The input is not modified; the returned copy has ``m`` and ``d_model``
    cast to int.
    """
    target = float(ratio)
    frame = df_pairs.copy()

    if target.is_integer():
        on_diagonal = frame["m"] == int(round(target)) * frame["d_model"]
    else:
        on_diagonal = np.isclose(frame["m"] / frame["d_model"], target, rtol=1e-9, atol=1e-12)

    selected = frame.loc[on_diagonal].copy()
    for key_col in ("m", "d_model"):
        selected[key_col] = selected[key_col].astype(int)
    return selected

def _pow2_label(y, _):
    """Tick formatter: label exact powers of two as ``$2^{k}$``, everything else as ''.

    The second positional argument (the tick position matplotlib passes to a
    ``FuncFormatter`` callback) is ignored. Non-positive values yield ''.
    """
    if y <= 0:
        return ""
    log2_y = np.log2(y)
    whole = round(log2_y)
    if not np.isclose(log2_y, whole):
        return ""
    return fr"$2^{{{int(whole)}}}$"

# For consistent ticks across plots: powers-of-two covering ALL chosen h
# Powers of two covering every positive selected head count (fixed default
# when there is none), shared by all diagonal plots.
_all_h_selected = pd.Series(min_rows["h"], dtype="float64").dropna()
_all_h_selected = _all_h_selected[_all_h_selected > 0]
if _all_h_selected.empty:
    pow2_ticks_h = [1, 2, 4, 8, 16, 32]
else:
    k_min_h = int(np.floor(np.log2(_all_h_selected.min())))
    k_max_h = int(np.ceil(np.log2(_all_h_selected.max())))
    pow2_ticks_h = [2**k for k in range(k_min_h, k_max_h + 1)]

# ---------------------------------------
# DK CI source for the line (optional)
# ---------------------------------------
# Resolve the DK confidence bounds, in order of preference: `triplet_df`,
# then `tmp_trip`, then a rebuild from the pivot grids. Stays None if none exist.
diag_ci_src = None
for ci_name in ("triplet_df", "tmp_trip"):
    ci_candidate = globals().get(ci_name)
    if isinstance(ci_candidate, pd.DataFrame) \
       and {"m","d_model","DK_upper","DK_lower"}.issubset(ci_candidate.columns):
        diag_ci_src = ci_candidate[["m","d_model","DK_upper","DK_lower"]].copy()
        break
else:
    # Fallback: reconstruct from pivot grids if available
    if all(isinstance(globals().get(name), pd.DataFrame)
           for name in ("vals","dk_upper","dk_lower")):
        # NOTE(review): newer pandas releases reject the `dropna` keyword under
        # the new stack implementation — confirm against the pandas version in use.
        mean_long_f  = vals.stack(dropna=False).rename("DK_mean").reset_index()
        upper_long_f = dk_upper.stack(dropna=False).rename("DK_upper").reset_index()
        lower_long_f = dk_lower.stack(dropna=False).rename("DK_lower").reset_index()
        diag_ci_src = (
            mean_long_f
            .merge(upper_long_f, on=["m","d_model"], how="left")
            .merge(lower_long_f, on=["m","d_model"], how="left")
        )[["m","d_model","DK_upper","DK_lower"]].copy()

# Integer join keys, matching `minDK_and_h`.
if diag_ci_src is not None:
    for key_col in ("m", "d_model"):
        diag_ci_src[key_col] = diag_ci_src[key_col].astype(int)

# ---------------------------------------
# Build/locate tidy per-head F1 stats to compute head-range CIs (bars)
# Required columns: ['m','d_model','h','D_K','F1_mean','F1_std','F1_n']
# ---------------------------------------
def _coerce_per_head_df():
    """Locate or rebuild the tidy per-head F1 table used for head-range CIs.

    Two sources are tried, in order:

    1. A tidy DataFrame already bound at module level under one of several
       common names. Known column aliases are mapped onto the canonical
       names; an alias is only renamed when the canonical column is absent,
       so a frame holding both (e.g. ``F1_n`` and ``n``) never ends up with
       duplicate column labels.
    2. Four per-head pivot grids (F1 mean / std / n and D_K) bound at module
       level. Each grid is expected to be indexed by ``m`` and ``d_model``
       with one column per head count.

    Rows missing any of ``m``, ``d_model``, ``h`` or ``F1_n`` are dropped
    before the integer casts, because NaN cannot be cast to ``int``.

    Returns:
        A DataFrame with columns
        ``['m','d_model','h','D_K','F1_mean','F1_std','F1_n']`` (in that
        order), or ``None`` when neither source is available.
    """
    columns = ["m", "d_model", "h", "D_K", "F1_mean", "F1_std", "F1_n"]
    int_cols = ["m", "d_model", "h", "F1_n"]
    float_cols = ["D_K", "F1_mean", "F1_std"]

    def _finalize(frame):
        # Fixed column order; incomplete key/count rows cannot be used downstream.
        out = frame[columns].dropna(subset=int_cols).copy()
        for col in int_cols:
            out[col] = out[col].astype(int)
        for col in float_cols:
            out[col] = out[col].astype(float)
        return out

    # 1) If a tidy table already exists under common names, normalize its columns.
    column_aliases = {
        "D_K": ("d_k", "dk"),
        "F1_mean": ("f1_mean", "f1_mu"),
        "F1_std": ("f1_std", "f1_sigma", "std_f1"),
        "F1_n": ("f1_n", "n_f1", "count_f1", "n"),
    }
    tidy_candidates = (
        "per_head", "stats_long", "triplet_df_full", "all_stats", "summary_long", "f1_stats_by_head"
    )
    for name in tidy_candidates:
        obj = globals().get(name)
        if not isinstance(obj, pd.DataFrame):
            continue
        rename = {}
        for canonical, aliases in column_aliases.items():
            if canonical in obj.columns:
                continue  # exact name wins; renaming an alias would duplicate it
            match = next((c for c in obj.columns if str(c).lower() in aliases), None)
            if match is not None:
                rename[match] = canonical
        df = obj.rename(columns=rename)
        if set(columns).issubset(set(df.columns)):
            return _finalize(df)

    # 2) Else rebuild from per-head pivot grids if present
    grid_aliases = {
        "F1_mean": ("F1_mean", "f1_mean", "F1_mu", "f1_mu", "mean_f1"),
        "F1_std": ("F1_std", "f1_std", "F1_sigma", "f1_sigma", "std_f1"),
        "F1_n": ("F1_n", "f1_n", "count_f1", "n_f1", "N_f1"),
        "D_K": ("D_K", "DK", "DK_by_head", "D_K_by_head", "d_k", "dk_grid"),
    }
    grids = {}
    for canonical, names in grid_aliases.items():
        for key in names:
            candidate = globals().get(key)
            if isinstance(candidate, pd.DataFrame):
                grids[canonical] = candidate
                break
    if len(grids) != len(grid_aliases):
        return None

    def _to_long(grid, value_name):
        # melt keeps NaN cells and avoids the deprecated DataFrame.stack code path.
        wide = grid.copy()
        wide.columns.name = "h"
        long = wide.reset_index().melt(
            id_vars=["m", "d_model"], var_name="h", value_name=value_name
        )
        long["h"] = long["h"].astype(int)
        return long

    keys = ["m", "d_model", "h"]
    tidy = _to_long(grids["F1_mean"], "F1_mean")
    for canonical in ("F1_std", "F1_n", "D_K"):
        tidy = tidy.merge(_to_long(grids[canonical], canonical), on=keys, how="left")
    return _finalize(tidy)

# Resolve the per-head table once; the head-range bounds computed below
# cannot be built without it, so a missing table is a hard error.
_PER_HEAD_MISSING_MSG = (
    "Could not find per-head F1 stats. Provide `per_head` with columns "
    "['m','d_model','h','D_K','F1_mean','F1_std','F1_n'], or the pivot grids."
)
per_head = _coerce_per_head_df()
if per_head is None:
    raise RuntimeError(_PER_HEAD_MISSING_MSG)

# ---------------------------------------
# Welch one-sided t-test (candidate vs winner): H1: mu_c < mu_w
# Keep a candidate if p >= alpha (cannot eliminate at 95%)
# ---------------------------------------
def p_less_welch(mu_c, sd_c, n_c, mu_w, sd_w, n_w):
    """One-sided Welch p-value for H1: mean(candidate) < mean(winner).

    Args:
        mu_c, sd_c, n_c: mean, standard deviation and sample count of the candidate.
        mu_w, sd_w, n_w: the same statistics for the winner.

    Returns:
        P(T <= t) for the Welch statistic t. When the combined variance of
        the mean difference is non-finite or effectively zero, the result
        degenerates to 1.0 if ``mu_c >= mu_w`` and 0.0 otherwise.
    """
    var_c = (sd_c ** 2) / max(int(n_c), 1)
    var_w = (sd_w ** 2) / max(int(n_w), 1)
    var_sum = var_c + var_w
    if (not np.isfinite(var_sum)) or var_sum <= 1e-16:
        # Degenerate variance: decide on the means alone.
        return 1.0 if (mu_c >= mu_w) else 0.0

    t_stat = (mu_c - mu_w) / math.sqrt(var_sum)

    # Welch–Satterthwaite degrees of freedom; a group contributes only when
    # it has more than one sample and a positive variance term.
    denom = sum(
        ((v ** 2) / (n - 1) for v, n in ((var_c, n_c), (var_w, n_w)) if n > 1 and v > 0),
        0.0,
    )
    dof = (var_sum ** 2) / denom if denom > 0 else 1e9  # huge dof ~ normal

    try:
        from scipy.stats import t as student_t
        return float(student_t.cdf(t_stat, dof))  # one-sided: P(T <= t)
    except Exception:
        # Normal approximation if SciPy is unavailable
        return 0.5 * (1.0 + math.erf(t_stat / math.sqrt(2.0)))

# Significance level for the Welch test and relative D_K window around the winner.
# NOTE(review): the configuration section at the top of this file also binds
# `alpha` (sharpness, 10.0); this assignment rebinds that module-level name.
# Confirm nothing executed afterwards still expects the earlier value.
alpha  = 0.05  # 95% rule
dk_tol = 0.10  # ±10% window on DK around the winner's DK

# ---------------------------------------
# Compute per-(m,d_model) head-range bounds & asymmetric y-errors
# ---------------------------------------
winners = (
    min_rows[["m","d_model","h"]]
    .dropna()
    .astype(int)
    .rename(columns={"h":"h_selected"})
)

bounds = []
for _, r in winners.iterrows():
    m_val, d_val, h_star = int(r["m"]), int(r["d_model"]), int(r["h_selected"])

    # Default range is the winner alone; it is widened only when other heads
    # with a similar D_K survive the significance test below.
    h_lo = h_hi = h_star

    sub = per_head[(per_head["m"] == m_val) & (per_head["d_model"] == d_val)]
    win = sub[sub["h"] == h_star]
    if not win.empty:
        mu_w = float(win["F1_mean"].iloc[0])
        sd_w = float(win["F1_std"].iloc[0])
        n_w  = int(win["F1_n"].iloc[0])
        DK_w = float(win["D_K"].iloc[0])

        # Candidates with DK within ±dk_tol of the winner's DK
        lo, hi = (1.0 - dk_tol) * DK_w, (1.0 + dk_tol) * DK_w
        near = sub[(sub["D_K"] >= lo) & (sub["D_K"] <= hi)]

        # Keep a candidate if it is NOT significantly worse than the winner at alpha
        keep_mask = np.array(
            [
                p_less_welch(
                    float(c["F1_mean"]), float(c["F1_std"]), int(c["F1_n"]),
                    mu_w, sd_w, n_w
                ) >= alpha
                for _, c in near.iterrows()
            ],
            dtype=bool,
        )
        kept = near[keep_mask]
        if not kept.empty:
            # The winner always lies within the reported range.
            h_lo = min(int(np.nanmin(kept["h"].to_numpy())), h_star)
            h_hi = max(int(np.nanmax(kept["h"].to_numpy())), h_star)

    bounds.append({"m": m_val, "d_model": d_val, "h_lower": h_lo, "h_upper": h_hi})

# Explicit columns keep the merge below valid even when there are no winners
# (a column-less frame would raise KeyError on the merge keys).
bounds_df = pd.DataFrame(
    bounds, columns=["m", "d_model", "h_lower", "h_upper"]
).astype(int)

# Tidy CI table with asymmetric y-errors relative to the winner head
h_ci_df = (
    winners.merge(bounds_df, on=["m","d_model"], how="left")
    .assign(
        yerr_low  = lambda df: (df["h_selected"] - df["h_lower"]).clip(lower=0).astype(float),
        yerr_high = lambda df: (df["h_upper"] - df["h_selected"]).clip(lower=0).astype(float),
    )
)

# =============================================================================
# Plot: one figure per diagonal ratio
# =============================================================================
# For every requested diagonal ratio, draw one twin-axis figure over the
# (m, d_model) pairs returned by `_pairs_on_ratio`:
#   - left axis: line of the minimum D_K, with optional CI error bars;
#   - right axis (log scale): bars of the selected head count, with the
#     asymmetric head-range error bars taken from `h_ci_df`.
# Each figure is saved under TARGET_DIR, shown, and then closed.
for ratio in DIAGONAL_RATIOS:
    diag_df = _pairs_on_ratio(minDK_and_h, ratio)
    if diag_df.empty:
        print(f"(Diagonal ratio {ratio}) No (m, d_model) pairs found on this diagonal; skipping.")
        continue

    # Order by increasing m
    diag_df = diag_df.sort_values(["m", "d_model"]).reset_index(drop=True)

    # X-axis labels "(m,d)"; points are placed at integer positions 0..N-1
    x_labels = [f"({int(m)},{int(d)})" for m, d in zip(diag_df["m"], diag_df["d_model"])]
    x = np.arange(len(diag_df), dtype=float)

    # Figure size scales with #points
    fig_w = max(6.5, 1.0 + 0.9 * len(diag_df))
    fig_h = 5.2
    fig_diag, ax_left = plt.subplots(figsize=(fig_w, fig_h))

    # ---------------------------------------
    # Left axis — line: min DK (with DK CI if available)
    # ---------------------------------------
    ax_left.plot(
        x, diag_df["min_DK"].astype(float).values,
        marker="o", linewidth=1.8, label="Min $D_K$",
        color=LINE_COLOR, zorder=2.2
    )

    # Merge DK CI bounds for these points (if we have them)
    # NOTE(review): assumes diag_ci_src holds at most one row per (m, d_model);
    # duplicates would make _diag_ci longer than x — confirm upstream.
    if diag_ci_src is not None:
        _diag_ci = diag_df.merge(diag_ci_src, on=["m","d_model"], how="left")
        y_vals = _diag_ci["min_DK"].astype(float).to_numpy()
        # Error-bar lengths are distances from the plotted value to each bound.
        yerr_low  = (y_vals - _diag_ci["DK_lower"].astype(float).to_numpy())
        yerr_high = (_diag_ci["DK_upper"].astype(float).to_numpy() - y_vals)

        # sanitize: nonnegative, finite only (anything else becomes NaN)
        yerr_low[~np.isfinite(yerr_low) | (yerr_low < 0)] = np.nan
        yerr_high[~np.isfinite(yerr_high) | (yerr_high < 0)] = np.nan

        ax_left.errorbar(
            x, y_vals, yerr=np.vstack([yerr_low, yerr_high]),
            fmt="none", elinewidth=1.0, capsize=3, alpha=0.95,
            ecolor=LINE_ERR_COLOR, zorder=2.1
        )

    # Axis labels/grid
    ax_left.set_ylabel("Minimum $D_K$")
    ax_left.set_xlabel(r"$(m, d_{\mathrm{model}})$")
    ax_left.grid(alpha=0.3, linewidth=0.6)
    ax_left.margins(y=0.2)

    # X ticks
    ax_left.set_xticks(x)
    ax_left.set_xticklabels(x_labels, rotation=45, ha="right")

    # ---------------------------------------
    # Right axis — bars: selected heads, with asymmetric CI error bars
    # ---------------------------------------
    ax_right = ax_left.twinx()
    bar_heights = diag_df["h_selected"].astype(float).values
    bars = ax_right.bar(
        x, bar_heights,
        alpha=0.35, width=0.8, label="Heads",
        color=BAR_FACE_COLOR, edgecolor=BAR_EDGE_COLOR, linewidth=1.0, zorder=1.0
    )
    ax_right.set_ylabel("Heads producing min $D_K$")
    ax_right.set_yscale("log")

    # Asymmetric CI error bars on the bars, from h_ci_df
    _diag_hci = diag_df.merge(
        h_ci_df[["m","d_model","yerr_low","yerr_high"]],
        on=["m","d_model"], how="left"
    )
    _yl = _diag_hci["yerr_low"].to_numpy(dtype=float)
    _yh = _diag_hci["yerr_high"].to_numpy(dtype=float)
    # Missing or invalid head-range errors are drawn as zero-length bars.
    _yl[~np.isfinite(_yl)] = 0.0
    _yh[~np.isfinite(_yh)] = 0.0
    _yl[_yl < 0] = 0.0
    _yh[_yh < 0] = 0.0

    ax_right.errorbar(
        x,
        bar_heights,
        yerr=np.vstack([_yl, _yh]),
        fmt="none", elinewidth=1.0, capsize=3, alpha=0.95,
        ecolor=BAR_ERR_COLOR, zorder=1.5
    )

    # Pad the log scale by one power of 2 on each side of the data
    diag_h = diag_df["h_selected"].astype(float)
    diag_h = diag_h[diag_h > 0]
    if not diag_h.empty:
        kmin = int(np.floor(np.log2(diag_h.min())))
        kmax = int(np.ceil(np.log2(diag_h.max())))
        y_min = 2 ** (kmin - 1)
        y_max = 2 ** (kmax + 1)
        ax_right.set_ylim(y_min, y_max)

    # Major ticks only at global powers-of-two for consistency
    ax_right.yaxis.set_major_locator(mticker.FixedLocator(pow2_ticks_h))
    ax_right.yaxis.set_minor_locator(mticker.NullLocator())
    ax_right.yaxis.set_major_formatter(mticker.FuncFormatter(_pow2_label))

    # ---------------------------------------
    # Title + legend (one legend combining the entries of both axes)
    # ---------------------------------------
    ax_left.set_title(fr"Diagonal: $m/d_{{\mathrm{{model}}}} = {ratio}$")
    handles_left, labels_left = ax_left.get_legend_handles_labels()
    handles_right, labels_right = ax_right.get_legend_handles_labels()
    ax_left.legend(handles_left + handles_right, labels_left + labels_right, fontsize=9)

    plt.tight_layout()

    # Save + show; "." in the ratio is replaced by "p" in the file name
    ratio_str = str(ratio).replace(".", "p")
    out_name = f"diagonal_ratio_{ratio_str}_minDK_and_selected_h.png"
    out_path = os.path.join(TARGET_DIR, out_name)
    fig_diag.savefig(out_path, dpi=150, bbox_inches="tight")
    print(f"Saved diagonal plot for ratio {ratio}: {out_path} | Exists? {os.path.exists(out_path)}")

    plt.show()
    plt.close(fig_diag)



# --------------------------------------------------------------------------------
# ERROR-BAR SOURCES DISCOVERY + HELPERS (used by per-(m,d_model) figures)
# --------------------------------------------------------------------------------

# We try to be flexible with column names
# Accepted spellings for each kind of dispersion column, in preference order.
_STD_COLS = ["microF1_std", "microF1_stddev"]
_SE_COLS  = ["microF1_stderr", "microF1_se", "microF1_sem"]
_CI_COLS  = ["microF1_ci", "microF1_ci95", "microF1_CI"]


def _first_present_col(candidates):
    """Return the first candidate that is a column of `unique_df`, else None."""
    for col in candidates:
        if col in unique_df.columns:
            return col
    return None


STD_COL = _first_present_col(_STD_COLS)
SE_COL  = _first_present_col(_SE_COLS)
CI_COL  = _first_present_col(_CI_COLS)

# Report which sources (if any) the per-combo plots can draw error bars from.
if any([STD_COL, SE_COL, CI_COL]):
    present = ", ".join([c for c in [STD_COL, SE_COL, CI_COL] if c])
    print(f"Per-combo plots: error-bar sources detected -> {present}")
else:
    print("Per-combo plots will NOT include error bars (no std/sem/ci column found).")

# 95% two-sided critical values for Student's t at df = 1..30
_T_CRIT_95 = {
    1:12.706, 2:4.303, 3:3.182, 4:2.776, 5:2.571, 6:2.447, 7:2.365, 8:2.306, 9:2.262,
    10:2.228, 11:2.201, 12:2.179, 13:2.160, 14:2.145, 15:2.131, 16:2.120, 17:2.110,
    18:2.101, 19:2.093, 20:2.086, 21:2.080, 22:2.074, 23:2.069, 24:2.064, 25:2.060,
    26:2.056, 27:2.052, 28:2.048, 29:2.045, 30:2.042
}

def _tcrit_95_df(df: int) -> float:
    """Return the two-sided 95% Student-t critical value for `df` degrees of freedom.

    Values for df = 1..30 come from the lookup table (df below 1 is clamped
    to 1). For larger df the quantile is computed from the Cornish–Fisher
    expansion around the normal quantile z, which is accurate to about three
    decimals there and decreases smoothly towards z ~ 1.960 instead of
    jumping between fixed constants.
    """
    df = int(max(1, df))
    if df <= 30:
        return _T_CRIT_95[df]
    z = 1.959963984540054  # standard normal 97.5% quantile
    first_order = (z ** 3 + z) / 4.0
    second_order = (5.0 * z ** 5 + 16.0 * z ** 3 + 3.0 * z) / 96.0
    return z + first_order / df + second_order / (df * df)

def _sanitize_err(arr):
    """Return a float copy of `arr` with unusable error values replaced by NaN.

    Non-finite and negative entries become NaN. The input is never modified:
    a fresh array is always allocated, even when `arr` is already a float
    ndarray. `None` is passed through unchanged.
    """
    if arr is None:
        return None
    a = np.array(arr, dtype=float)  # np.array copies; np.asarray could alias the caller's data
    a[~np.isfinite(a) | (a < 0)] = np.nan
    return a

def compute_errorbars_for_group(dh: pd.DataFrame, mode: str):
    """Return ``(yerr, label)`` for one group of rows.

    Args:
        dh: rows of a single group. The dispersion columns named by the
            module-level STD_COL / SE_COL / CI_COL are read when present,
            as is an optional ``runs`` column (sample count per row).
        mode: ``"std"`` for ±1 SD across runs; any other value yields the
            95% CI half-width for the mean.

    Returns:
        ``yerr`` is a float array (invalid entries set to NaN) or ``None``
        when the requested quantity cannot be derived from the available
        columns. ``label`` names the quantity and the column it was actually
        derived from; the SD and CI sources are tracked separately because
        they can differ.

    Derivation order:
        SD: STD_COL, else SE * sqrt(n), else CI * sqrt(n) / tcrit.
        CI: CI_COL, else tcrit * SE (1.96 * SE when ``runs`` is missing),
            else tcrit * SD / sqrt(n). With only SD and no ``runs`` a CI for
            the mean cannot be formed and ``None`` is returned.
    """
    def _col(name):
        # Column as a float array, or None when not configured / not present.
        return dh[name].to_numpy(dtype=float) if (name and name in dh.columns) else None

    n = dh["runs"].to_numpy(dtype=float) if "runs" in dh.columns else None
    std, se, ci = _col(STD_COL), _col(SE_COL), _col(CI_COL)

    def _tcrit():
        # Per-row two-sided 95% t critical value at df = runs - 1 (at least 1).
        df_arr = np.maximum(1, n.astype(int) - 1)
        return np.array([_tcrit_95_df(int(df)) for df in df_arr])

    # --- Build ±1 SD if possible ---
    yerr_std, std_src = None, None
    if std is not None:
        yerr_std, std_src = std.copy(), STD_COL
    elif se is not None and n is not None:
        yerr_std, std_src = se * np.sqrt(n), SE_COL
    elif ci is not None and n is not None:
        # SD ≈ CI * sqrt(n) / tcrit
        tcrit = _tcrit()
        with np.errstate(divide="ignore", invalid="ignore"):
            yerr_std = np.where((tcrit > 0), ci * np.sqrt(n) / tcrit, np.nan)
        std_src = CI_COL

    # --- Build 95% CI half-width for the mean if possible ---
    yerr_ci95, ci_src = None, None
    if ci is not None:
        yerr_ci95, ci_src = ci.copy(), CI_COL
    elif se is not None:
        # Without runs, fall back to z≈1.96 * SE (still a CI for the mean)
        yerr_ci95 = (_tcrit() if n is not None else 1.96) * se
        ci_src = SE_COL
    elif std is not None and n is not None:
        tcrit = _tcrit()
        with np.errstate(divide="ignore", invalid="ignore"):
            yerr_ci95 = np.where(n > 0, tcrit * std / np.sqrt(n), np.nan)
        ci_src = STD_COL
    # else: no usable column, or SD without a run count -> no CI for the mean

    if mode == "std":
        return (_sanitize_err(yerr_std), f"±1 SD (source={std_src})")
    return (_sanitize_err(yerr_ci95), f"95% CI (source={ci_src})")

# --------------------------------------------------------------------------------
# PER-(m, d_model) FIGURES — F1 vs DK (one line per head), saved to a subfolder
# Error-bar toggle: ERRORBAR_MODE in {"std", "ci95"}
# --------------------------------------------------------------------------------
# Output folder for the per-(m, d_model) figures, created under TARGET_DIR.
PER_COMBO_SUBDIR = "per_combo_F1_vs_DK_byH"  # change if you like
per_combo_dir = os.path.join(TARGET_DIR, PER_COMBO_SUBDIR)
os.makedirs(per_combo_dir, exist_ok=True)

# Announce whether (and in which mode) error bars will be drawn.
# NOTE(review): the "no error bars" notice is also printed by the column
# discovery code earlier in this file, so it can appear twice.
if not any([STD_COL, SE_COL, CI_COL]):
    print("Per-combo plots will NOT include error bars (no std/sem/ci column found).")
else:
    print(f"Per-combo plots error bars mode: {ERRORBAR_MODE!r} "
          f"(std→±1 SD, ci95→95% CI half-width)")

# One figure per (m, d_model) pair: mean micro-F1 against D_K, one line per head count.
pair_count = 0
for (m_val, d_val), sub in unique_df.groupby(["m", "d_model"], sort=True):
    heads = sorted(sub["h"].unique().tolist())

    fig_w, fig_h = 8.3, 5.2
    fig3, ax3 = plt.subplots(figsize=(fig_w, fig_h))

    # Sentinels; they stay at these values only if no head contributed data.
    ymin_seen, ymax_seen = +1e9, -1e9

    for h_val in heads:
        dh = sub.loc[sub["h"] == h_val].sort_values("D_K")
        if dh.empty:
            continue

        x = dh["D_K"].astype(float).values
        y = dh["microF1_mean"].astype(float).values
        # Track the overall y-range across heads for the y-limits below.
        ymin_seen = min(ymin_seen, float(np.nanmin(y)))
        ymax_seen = max(ymax_seen, float(np.nanmax(y)))

        # Compute error bars for this head, per the toggle
        # (err_label is returned but not used in this loop)
        yerr, err_label = compute_errorbars_for_group(dh, ERRORBAR_MODE)

        # Draw error bars only when at least one finite value exists; otherwise a plain line.
        if (yerr is not None) and np.any(np.isfinite(yerr)):
            ax3.errorbar(
                x, y, yerr=yerr, marker="o", linestyle="-", linewidth=1.8,
                capsize=3, label=f"h={int(h_val)}"
            )
        else:
            ax3.plot(x, y, marker="o", linestyle="-", linewidth=1.8, label=f"h={int(h_val)}")

    ax3.set_xlabel("Total Key Dim $D_K$")
    ax3.set_ylabel("Mean micro-F1 (test)")
    # NOTE(review): `suffix` is computed but not referenced by the title or
    # anywhere else in this loop — presumably meant for the title; confirm.
    suffix = "±1 SD" if ERRORBAR_MODE == "std" else "95% CI"
    ax3.set_title(fr"F1 vs DK — m={int(m_val)}, $d_{{\mathrm{{model}}}}$={int(d_val)}")
    ax3.grid(alpha=0.3, linewidth=0.6)

    # y-limits: give a little headroom, clip to [0, 1.02];
    # the upper limit never drops below 0.85
    if ymin_seen < 1e9:
        pad = 0.03
        ax3.set_ylim(max(0.0, ymin_seen - pad), min(1.02, max(0.85, ymax_seen + pad)))

    # legend like the reference figure; with more than 10 heads it is moved
    # outside the axes to the right
    if len(heads) <= 10:
        ax3.legend(title="Heads", fontsize=9)
    else:
        ax3.legend(title="Heads", fontsize=8, bbox_to_anchor=(1.02, 1.0),
                   loc="upper left", borderaxespad=0.)

    plt.tight_layout()

    # save to subfolder
    out_name = f"F1_vs_DK_m{int(m_val)}_d{int(d_val)}.png"
    out_path = os.path.join(per_combo_dir, out_name)
    fig3.savefig(out_path, dpi=150, bbox_inches="tight")
    print(f"Saved per-combo plot: {out_path} | Exists? {os.path.exists(out_path)}")

    plt.show()
    plt.close(fig3)
    pair_count += 1

print(f"\nCreated {pair_count} per-(m,d_model) figure(s) in: {per_combo_dir}")

# Summary of which files were included (`good_files` is defined elsewhere in the file)
print("\nIncluded files:")
for p in good_files:
    print(" •", os.path.basename(p))




    

# =========================
# Single-cell Colab script (GPU-optimized):
# Capacity threshold of QK channel in a single Transformer block
# with frozen GPT-2 token embeddings, plus multi-head advantage at fixed D_K.
#
# TASK:
# - Context has L items; each item is a PAIR of distinct tokens (a,b) sampled without replacement.
# - Context representation uses embedding SUM: E[a] + E[b].
# - Query is last token (after NULL separator), unpaired.
# - Target: let s = perm[q]. If s appears in context, return its PARTNER token in that pair; else NULL.
# - Pairs are chosen independently of the permutation; query is chosen conditional to achieve P_PRESENT.
#
# + Drive mirroring:
#   - Writes to local Colab directory as before
#   - Mirrors all key artifacts (runs.csv, thresholds.csv, plots) to Google Drive
#   - Restores runs.csv from Drive if local is missing (resume-safe across runtimes)
#
# =========================

# --- lightweight dependency installs (safe to re-run) ---
import importlib, sys, subprocess, os, math, time, gc, random, shutil

def _ensure(pkg, pip_name=None):
    """Install `pip_name` (default: `pkg`) with pip if module `pkg` is not importable.

    Args:
        pkg: import name checked with ``importlib.util.find_spec``.
        pip_name: distribution name passed to pip when it differs from `pkg`.

    Raises:
        subprocess.CalledProcessError: if the pip invocation fails.
    """
    # The submodule must be imported explicitly: a bare `import importlib`
    # does not guarantee that `importlib.util` is available as an attribute.
    import importlib.util

    pip_name = pip_name or pkg
    if importlib.util.find_spec(pkg) is not None:
        return
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", pip_name])
    # Let the import system notice the freshly installed package.
    importlib.invalidate_caches()

# Third-party packages this script imports below; each one is installed only
# when it cannot already be found.
for _required_pkg in ("numpy", "pandas", "matplotlib", "tqdm", "transformers", "datasets"):
    _ensure(_required_pkg)

# --- Imports ---
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForCausalLM
from IPython.display import display

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# =========================
# Google Drive mirroring
# =========================
# A successful import of `google.colab` is taken as the signal that we are
# running inside Colab; any failure disables Drive mirroring.
try:
    from google.colab import drive  # type: ignore
except Exception:
    IN_COLAB = False
else:
    IN_COLAB = True

# Drive locations used for mirroring outputs.
DRIVE_ENABLED = IN_COLAB
DRIVE_MOUNTPOINT = "/content/drive"
DRIVE_PROJECT_ROOT = os.path.join(DRIVE_MOUNTPOINT, "MyDrive", "capacity_threshold_outputs")

# How often updated files are synced to Drive during the sweep
SYNC_TO_DRIVE = True
SYNC_EVERY_N_RUNS = 1  # raise to 5/10 if Drive I/O becomes a bottleneck

def _sync_file_to_drive(local_path: str, drive_path: str | None):
    if not (DRIVE_ENABLED and SYNC_TO_DRIVE and drive_path is not None):
        return
    try:
        os.makedirs(os.path.dirname(drive_path), exist_ok=True)
        shutil.copy2(local_path, drive_path)
    except Exception as e:
        print(f"[WARN] Drive sync failed for {os.path.basename(local_path)}: {e}")

def _maybe_restore_from_drive(drive_path: str | None, local_path: str):
    """If local file missing but Drive has it, copy Drive->local to allow resume."""
    if not (DRIVE_ENABLED and drive_path is not None):
        return
    try:
        if (not os.path.exists(local_path)) and os.path.exists(drive_path):
            os.makedirs(os.path.dirname(local_path), exist_ok=True)
            shutil.copy2(drive_path, local_path)
            print(f"Restored from Drive: {drive_path} -> {local_path}")
    except Exception as e:
        print(f"[WARN] Drive restore failed for {os.path.basename(local_path)}: {e}")

# Mount Drive once up front when mirroring is enabled; otherwise just say so.
if not DRIVE_ENABLED:
    print("Not running in Colab; Drive mirroring disabled.")
else:
    drive.mount(DRIVE_MOUNTPOINT)
    print(f"Drive mounted at {DRIVE_MOUNTPOINT}")

# =========================
# Device / perf knobs
# =========================
# Prefer the GPU when one is visible to torch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

if device.type != "cuda":
    # CPU thread tuning (can help on Colab CPU); best effort only
    try:
        torch.set_num_threads(os.cpu_count() or 2)
        torch.set_num_interop_threads(1)
    except Exception:
        pass
else:
    # Speed-oriented settings: cuDNN autotuning plus TF32 matmuls/convolutions
    # (TF32 is a big win on Ampere+ GPUs).
    # NOTE(review): the section at the top of this file sets the opposite
    # values (TF32 off, benchmark off) for determinism; this block overrides them.
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    try:
        torch.set_float32_matmul_precision("high")
    except Exception:
        pass

print(f"Using device: {device}")

# =========================
# Experiment configuration
# =========================
FAST_DEV_RUN = True
# If True, the sweep lists, seeds, step budget and eval sizes defined below
# are replaced by the overrides in the `if FAST_DEV_RUN:` block.
# NOTE(review): described as a quick smoke test, but those overrides are not
# uniformly smaller than the defaults — see the note at that block.

# Hugging Face model id whose token embeddings are loaded and frozen.
MODEL_NAME = "gpt2"

# =========================
# Token selection mode
# =========================
# Options:
#   "random"     : current behavior (stable random pool, subset property via MAX_M slicing)
#   "top_freq"   : take top-m tokens by corpus frequency
#   "freq_sample": sample tokens without replacement so selection is proportional to corpus frequency
TOKEN_SELECTION_MODE = "top_freq"

# Frequency corpus settings (used when TOKEN_SELECTION_MODE != "random")
# Default corpus: WikiText-103 train.
# FREQ_SOURCE_TAG is a short label that ends up in output/cache names; the
# other three values are passed to the dataset loader.
FREQ_SOURCE_TAG = "wikitext103"
FREQ_DATASET_NAME = "wikitext"
FREQ_DATASET_CONFIG = "wikitext-103-raw-v1"
FREQ_DATASET_SPLIT = "train"

# Counting controls
FREQ_MAX_TOKENS = 10_000_000
FREQ_BATCH_TEXTS = 256
FREQ_STREAMING = True

# Only used if FREQ_STREAMING=False
FREQ_SHUFFLE_SEED = 0

# Weighted pool sampling parameters (used for "freq_sample")
FREQ_SAMPLE_SEED = 42
FREQ_SMOOTHING = 1.0  # pseudo-count so log(weights) is safe

# Task params
DEFAULT_L = 16
# Presumably the probability that the query's target appears in the context
# (see the task description in the header) — TODO confirm against the sampler.
P_PRESENT = 1
EXCLUDE_QUERY_FROM_CONTEXT = True   # if True, query token is guaranteed not to be one of the pair-tokens

# Embedding preprocessing
CENTER_EMBEDDINGS = False
L2_NORMALIZE_EMBEDDINGS = False

# -------------------------
# Training/perf knobs (easy to tweak)
# -------------------------
BATCH_SIZE = 64             # micro-batch size (per optimizer step is BATCH_SIZE * GRAD_ACCUM_STEPS)
GRAD_ACCUM_STEPS = 1        # set >1 to simulate larger effective batch without more memory

# Learning rate: REF_LR is the rate at REF_BATCH_SIZE; when scaling is on,
# _compute_lr() multiplies it by (effective batch size / REF_BATCH_SIZE).
LR_SCALE_WITH_EFFECTIVE_BS = True
REF_BATCH_SIZE = 256
REF_LR = 1e-3
WEIGHT_DECAY = 0.01

# Override feedforward width (MLP). None => default 4*d_model (GPT-2 style, very expensive).
D_FF = 32

# Step budget, evaluation cadence and early-stopping thresholds.
MAX_STEPS = 50_000
EVAL_INTERVAL = 500
VAL_EXAMPLES = 10_000
TEST_EXAMPLES = 50_000
EARLYSTOP_ACC = 0.995
EARLYSTOP_EVALS = 5

# Eval batch size (None => auto)
EVAL_BATCH_SIZE = None

# Batch generation & transfers
SAMPLE_ON_DEVICE = False     # if True and GPU, sampler generates batches directly on GPU
PIN_MEMORY = True            # only used when SAMPLE_ON_DEVICE=False and device is cuda

# AMP for speed on GPU; the grad scaler is enabled only for float16 autocast
AMP_DTYPE = torch.float16    # torch.float16 is safest on most Colab GPUs; bf16 may work on some GPUs
USE_AMP = (device.type == "cuda")
USE_GRAD_SCALER = USE_AMP and (AMP_DTYPE == torch.float16)

# Optional: torch.compile (often not worth it for sweeps with many fresh models)
COMPILE_MODEL = False

# Seeds
SEEDS = [0, 1, 2]  # default 3 seeds

# Sweep lists
m_list = [256, 512, 1024, 2048, 4096, 8192, 10_000]
L_list = [16]
D_K_list = [8, 16, 24, 32, 48, 64, 96, 128, 192, 256, 384, 512]
h_list = [1, 2, 4, 8, 16]

# FAST_DEV_RUN overrides.
# NOTE(review): the printed message calls this a smoke test, yet MAX_STEPS is
# raised from 50_000 to 100000, five seeds are used instead of three, and
# D_K_list extends to 1024; only m_list and the eval sizes shrink. Confirm
# these values are intended.
if FAST_DEV_RUN:
    m_list = [6144]
    D_K_list = [64,128,256,512,1024]
    h_list = [1,2,4,8,16]
    SEEDS = [20,21,22,23,24]
    MAX_STEPS = 100000
    EVAL_INTERVAL = 250
    VAL_EXAMPLES = 2_000
    TEST_EXAMPLES = 5_000
    print("FAST_DEV_RUN=True (smoke test mode)")

def _compute_lr():
    """Return the learning rate derived from the module-level batch settings.

    With LR_SCALE_WITH_EFFECTIVE_BS set, REF_LR is scaled linearly by
    (BATCH_SIZE * GRAD_ACCUM_STEPS) / REF_BATCH_SIZE; otherwise REF_LR is
    returned unchanged (as a float).
    """
    effective = int(BATCH_SIZE) * int(GRAD_ACCUM_STEPS)
    if not LR_SCALE_WITH_EFFECTIVE_BS:
        return float(REF_LR)
    ratio = effective / float(REF_BATCH_SIZE)
    return float(REF_LR) * ratio

# Resolved learning rate and effective batch size for this run.
LR = _compute_lr()
EFFECTIVE_BATCH_SIZE = int(BATCH_SIZE) * int(GRAD_ACCUM_STEPS)

# Auto output dir (prevents accidentally resuming mismatched hyperparams)
RUN_TAG = None  # set to a string if you want to force a fixed output directory name
TASK_TAG = "paired_partner"  # distinguishes this experiment from the old one
if RUN_TAG is None:
    dff_tag = str(D_FF) if D_FF is not None else "4x"
    # NOTE(review): ".0e" keeps a single significant digit, so different
    # learning rates (e.g. 2.1e-4 and 2.4e-4) produce the same tag and would
    # share an output directory — confirm this is acceptable.
    lr_tag = f"{LR:.0e}"  # e.g. 1e-03

    # The token-pool settings are part of the tag only for non-random selection modes.
    token_tag = ""
    if TOKEN_SELECTION_MODE != "random":
        token_tag = f"_tok{TOKEN_SELECTION_MODE}_{FREQ_SOURCE_TAG}_maxtok{int(FREQ_MAX_TOKENS)}"

    RUN_TAG = f"{TASK_TAG}_{MODEL_NAME}_L{DEFAULT_L}_bs{BATCH_SIZE}_ga{GRAD_ACCUM_STEPS}_eff{EFFECTIVE_BATCH_SIZE}_lr{lr_tag}_dff{dff_tag}_amp{int(USE_AMP)}{token_tag}"

# Local output dir (as before)
OUT_DIR = os.path.join("./capacity_threshold_outputs", RUN_TAG)
os.makedirs(OUT_DIR, exist_ok=True)

# Matching Drive output dir (None when Drive mirroring is disabled)
DRIVE_OUT_DIR = os.path.join(DRIVE_PROJECT_ROOT, RUN_TAG) if DRIVE_ENABLED else None
if DRIVE_OUT_DIR is not None:
    os.makedirs(DRIVE_OUT_DIR, exist_ok=True)

# Per-run results table, locally and (optionally) mirrored on Drive.
CSV_PATH = os.path.join(OUT_DIR, "runs.csv")
CSV_PATH_DRIVE = os.path.join(DRIVE_OUT_DIR, "runs.csv") if DRIVE_OUT_DIR is not None else None

# If we're in a fresh runtime, restore CSV from Drive so resume works
_maybe_restore_from_drive(CSV_PATH_DRIVE, CSV_PATH)

print(f"Config: BATCH_SIZE={BATCH_SIZE}, GRAD_ACCUM_STEPS={GRAD_ACCUM_STEPS}, EFFECTIVE_BS={EFFECTIVE_BATCH_SIZE}, LR={LR:.3g}, D_FF={D_FF}")
print(f"Token pool mode: {TOKEN_SELECTION_MODE}")
print(f"Local outputs: {OUT_DIR}")
if DRIVE_OUT_DIR is not None:
    print(f"Drive outputs: {DRIVE_OUT_DIR}")

# =========================
# Load GPT-2 embeddings (frozen)
# =========================
print("Loading GPT-2 tokenizer/model to extract frozen embeddings...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
gpt2 = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

# Keep only the input-embedding matrix, detached and on CPU: (vocab, d_model).
with torch.no_grad():
    _embedding_layer = gpt2.get_input_embeddings()
    WTE = _embedding_layer.weight.detach().cpu()
del _embedding_layer

# NOTE: this rebinds `d_model`, which the configuration section at the top
# of the file also defines.
vocab_size, d_model = WTE.shape[0], WTE.shape[1]

# The end-of-sequence token id doubles as the NULL token.
NULL_ID = int(tokenizer.eos_token_id)

# The full model is no longer needed; release it and any cached GPU memory.
del gpt2
gc.collect()
if device.type == "cuda":
    torch.cuda.empty_cache()

print(f"GPT-2 vocab_size={vocab_size}, d_model={d_model}, NULL_ID={NULL_ID}")

# =========================
# Utilities
# =========================
def set_all_seeds(seed: int):
    """Seed Python's `random`, NumPy and torch (plus all CUDA devices when on GPU)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if device.type == "cuda":
        torch.cuda.manual_seed_all(seed)

# -------------------------
# Frequency-aware token pool (NEW)
# -------------------------
def _slug(s: str) -> str:
    """Return `str(s)` with every character that is not alphanumeric or one of "-_." replaced by "_"."""
    allowed_punct = "-_."
    pieces = []
    for ch in str(s):
        pieces.append(ch if (ch.isalnum() or ch in allowed_punct) else "_")
    return "".join(pieces)

# Shared cache dir (not tied to RUN_TAG), so you compute counts once and reuse.
# Local cache for token counts; shared across runs (not tied to RUN_TAG).
TOKEN_FREQ_CACHE_DIR = os.path.join("./capacity_threshold_outputs", "_token_freq_cache")
os.makedirs(TOKEN_FREQ_CACHE_DIR, exist_ok=True)

# Drive-side mirror of the same cache, or None when Drive mirroring is disabled.
if DRIVE_ENABLED:
    TOKEN_FREQ_CACHE_DIR_DRIVE = os.path.join(DRIVE_PROJECT_ROOT, "_token_freq_cache")
    os.makedirs(TOKEN_FREQ_CACHE_DIR_DRIVE, exist_ok=True)
else:
    TOKEN_FREQ_CACHE_DIR_DRIVE = None

def _counts_cache_paths(model_name: str) -> tuple[str, str | None]:
    """Return ``(local_path, drive_path)`` of the token-count cache file.

    The file name encodes the model name, the frequency-corpus settings and
    the token budget, each component passed through `_slug`. `drive_path` is
    None when no Drive cache directory is configured.
    """
    name_parts = (
        model_name,
        FREQ_SOURCE_TAG,
        FREQ_DATASET_NAME,
        FREQ_DATASET_CONFIG,
        FREQ_DATASET_SPLIT,
    )
    stem = "_".join(_slug(part) for part in name_parts)
    fname = f"counts_{stem}_maxtok{int(FREQ_MAX_TOKENS)}.npz"

    local_path = os.path.join(TOKEN_FREQ_CACHE_DIR, fname)
    if TOKEN_FREQ_CACHE_DIR_DRIVE is None:
        return local_path, None
    return local_path, os.path.join(TOKEN_FREQ_CACHE_DIR_DRIVE, fname)

def _extract_text(example: dict) -> str | None:
    # Common field names first
    if "text" in example and isinstance(example["text"], str):
        return example["text"]
    if "content" in example and isinstance(example["content"], str):
        return example["content"]
    # Fallback: first string field
    for _, v in example.items():
        if isinstance(v, str):
            return v
    return None

def load_or_compute_token_counts(tokenizer, vocab_size: int, model_name: str) -> np.ndarray:
    """
    Return per-token corpus counts as an int64 array of shape (vocab_size,).

    The counts are read from the local `.npz` cache when it exists (after
    calling `_maybe_restore_from_drive` for the Drive copy). Otherwise they are
    computed by tokenizing the configured `FREQ_DATASET_*` corpus in batches of
    `FREQ_BATCH_TEXTS` texts, saved to the local cache and passed to
    `_sync_file_to_drive`.

    The token budget `FREQ_MAX_TOKENS` is only checked after each full batch,
    so the final total can exceed it by up to one batch.

    Raises:
        ValueError: if a cached array's length differs from `vocab_size`.
    """
    cache_local, cache_drive = _counts_cache_paths(model_name)

    # Restore cache from Drive if needed
    _maybe_restore_from_drive(cache_drive, cache_local)

    if os.path.exists(cache_local):
        # np.load on an .npz returns a lazy archive that keeps the file open;
        # read the array inside the context manager so the handle is released.
        with np.load(cache_local) as data:
            counts = data["counts"].astype(np.int64)
        if counts.shape[0] != vocab_size:
            raise ValueError(f"Cached counts vocab mismatch: {counts.shape[0]} != {vocab_size}")
        return counts

    # Otherwise compute
    max_tokens = int(FREQ_MAX_TOKENS)
    batch_texts = int(FREQ_BATCH_TEXTS)
    print(f"[TokenFreq] Computing token counts from corpus: {FREQ_DATASET_NAME}/{FREQ_DATASET_CONFIG}:{FREQ_DATASET_SPLIT}")
    print(f"[TokenFreq] streaming={FREQ_STREAMING}, max_tokens={max_tokens:,}, batch_texts={batch_texts}")

    from datasets import load_dataset  # lazy import so random-mode doesn't need it at runtime

    ds = load_dataset(
        FREQ_DATASET_NAME,
        FREQ_DATASET_CONFIG,
        split=FREQ_DATASET_SPLIT,
        streaming=bool(FREQ_STREAMING),
    )

    # Optional shuffle only if not streaming
    if (not FREQ_STREAMING) and hasattr(ds, "shuffle"):
        ds = ds.shuffle(seed=int(FREQ_SHUFFLE_SEED))

    counts = np.zeros((vocab_size,), dtype=np.int64)
    total_tokens = 0
    batch: list[str] = []

    def _process_batch(texts: list[str]):
        """Tokenize `texts` and add their token occurrences to the running totals."""
        nonlocal total_tokens, counts
        enc = tokenizer(
            texts,
            add_special_tokens=False,
            return_attention_mask=False,
            return_token_type_ids=False,
            truncation=False,
        )
        # Skip empty encodings; np.concatenate rejects an empty list.
        chunks = [np.asarray(ids, dtype=np.int64) for ids in enc["input_ids"] if len(ids) > 0]
        if not chunks:
            return
        flat = np.concatenate(chunks, axis=0)
        counts += np.bincount(flat, minlength=vocab_size)
        total_tokens += int(flat.size)

    # Iterate until we hit max token budget
    for ex in ds:
        txt = _extract_text(ex)
        if not txt:
            continue
        batch.append(txt)
        if len(batch) >= batch_texts:
            _process_batch(batch)
            batch = []
            if total_tokens >= max_tokens:
                break

    # Flush the trailing partial batch unless the budget is already spent.
    if batch and total_tokens < max_tokens:
        _process_batch(batch)

    print(f"[TokenFreq] Done. Counted ~{total_tokens:,} tokens.")

    np.savez_compressed(
        cache_local,
        counts=counts,
        total_tokens=np.array([total_tokens], dtype=np.int64),
        model_name=np.array([model_name]),
        dataset_name=np.array([FREQ_DATASET_NAME]),
        dataset_config=np.array([FREQ_DATASET_CONFIG]),
        split=np.array([FREQ_DATASET_SPLIT]),
        max_tokens=np.array([max_tokens], dtype=np.int64),
    )
    _sync_file_to_drive(cache_local, cache_drive)
    return counts

def build_item_token_ids_random(m: int, null_id: int, vocab_size: int) -> np.ndarray:
    """
    Draw `m` distinct token ids uniformly from [0, vocab_size) excluding `null_id`.

    The generator is always seeded with 42, so equal arguments give equal
    results. The function itself does no slicing: callers that need nested
    subsets for several `m` call it once with the largest `m` and take
    prefixes of the returned int64 array.

    Raises:
        ValueError: if fewer than `m` candidate ids exist.
    """
    candidates = np.arange(vocab_size)
    candidates = candidates[candidates != null_id]
    if len(candidates) < m:
        raise ValueError(f"Could not collect {m} token IDs (vocab_size={vocab_size}).")
    picked = np.random.default_rng(seed=42).choice(candidates, size=m, replace=False)
    return picked.astype(np.int64)

def build_item_token_pool(
    max_m: int,
    mode: str,
    null_id: int,
    vocab_size: int,
    tokenizer,
    model_name: str,
) -> np.ndarray:
    """
    Returns an ordered pool of length max_m such that smaller m is always a prefix slice.

    mode (case-insensitive):
      - "random": stable uniform random subset
      - "top_freq": top tokens by estimated corpus frequency
      - "freq_sample": weighted permutation (Plackett-Luce / Gumbel trick), proportional to corpus freq

    `null_id` is never part of the pool.

    Raises:
        ValueError: for an unknown mode, or when fewer than `max_m` non-NULL
            token ids exist. Both are checked before the corpus counts are
            loaded or computed.
    """
    mode = str(mode).lower()
    # Validate first: counting corpus tokens is expensive and pointless for a typo.
    if mode not in ("random", "top_freq", "freq_sample"):
        raise ValueError(f"Unknown TOKEN_SELECTION_MODE={mode!r}. Use 'random'|'top_freq'|'freq_sample'.")

    if mode == "random":
        return build_item_token_ids_random(m=max_m, null_id=null_id, vocab_size=vocab_size)

    ids = np.arange(vocab_size, dtype=np.int64)
    mask = (ids != int(null_id))
    ids = ids[mask]
    # Match the "random" mode: a pool shorter than requested is an error, not a
    # silent truncation that later mis-sizes the embedding table.
    if ids.shape[0] < max_m:
        raise ValueError(f"Could not collect {max_m} token IDs (vocab_size={vocab_size}).")

    counts = load_or_compute_token_counts(tokenizer=tokenizer, vocab_size=vocab_size, model_name=model_name)
    c = counts[mask].astype(np.int64)

    if mode == "top_freq":
        # Sort by (-count, token_id) for determinism under ties
        order = np.lexsort((ids, -c))
        pool = ids[order][:max_m]
        return pool.astype(np.int64)

    # mode == "freq_sample"
    # Weighted random permutation using Gumbel-top-k:
    # scores = log(w) + gumbel, sort desc => sequential sampling proportional to w
    w = c.astype(np.float64) + float(FREQ_SMOOTHING)
    rng = np.random.default_rng(int(FREQ_SAMPLE_SEED))
    u = rng.random(size=w.shape[0])
    # The 1e-12 offsets keep both logarithms finite when u is 0 or very close to 1.
    g = -np.log(-np.log(u + 1e-12) + 1e-12)
    scores = np.log(w) + g
    order = np.argsort(scores)[::-1]
    pool = ids[order][:max_m]
    return pool.astype(np.int64)

def build_restricted_embeddings_from_pool(
    m: int,
    center_embeddings: bool,
    l2_normalize: bool,
    device: torch.device,
    embed_dtype: torch.dtype,
):
    """
    Assemble the restricted embedding table for the first `m` pool tokens.

    Returns:
      E_restricted: (m+1, d_model) tensor on `device` with dtype `embed_dtype`;
        rows 0..m-1 are items, row m is NULL.
      restricted_ids: np array (m+1,) with the matching vocabulary ids.

    Optional preprocessing, applied to all m+1 rows: subtract the mean of the
    ITEM rows only (`center_embeddings`), then divide each row by its L2 norm,
    floored at 1e-8 (`l2_normalize`).
    """
    item_rows = E_ITEM_POOL_CPU[:m]                                  # (m, d_model)
    table = torch.cat((item_rows, E_NULL_CPU), dim=0).clone()        # (m+1, d_model)

    if center_embeddings:
        # The NULL row is shifted too, but does not contribute to the mean.
        table = table - item_rows.mean(dim=0, keepdim=True)

    if l2_normalize:
        row_norms = table.norm(dim=1, keepdim=True).clamp_min(1e-8)
        table = table / row_norms

    null_id_arr = np.array([NULL_ID], dtype=np.int64)
    restricted_ids = np.concatenate([ITEM_TOKEN_POOL[:m], null_id_arr], axis=0)

    return table.to(device=device, dtype=embed_dtype), restricted_ids

def make_perm(m: int, seed: int) -> torch.LongTensor:
    """Return a permutation of {0..m-1}, fully determined by `seed`, as a 1-D int64 tensor."""
    order = np.random.default_rng(seed).permutation(m)
    return torch.from_numpy(order.astype(np.int64)).long()

# =========================
# Build the token pool + embedding cache from GPT-2 embeddings
# =========================
# Largest m in the sweep: the pool is built once at this size and
# build_restricted_embeddings_from_pool() takes the first m entries per run.
MAX_M = int(max(m_list))

# Ordered vocabulary ids of the item tokens (NULL_ID excluded by the builder).
ITEM_TOKEN_POOL = build_item_token_pool(
    max_m=MAX_M,
    mode=TOKEN_SELECTION_MODE,
    null_id=NULL_ID,
    vocab_size=vocab_size,
    tokenizer=tokenizer,
    model_name=MODEL_NAME,
)

print(f"[TokenPool] mode={TOKEN_SELECTION_MODE}, MAX_M={MAX_M}, example_ids={ITEM_TOKEN_POOL[:10].tolist()}")

# float32 copies of the selected WTE rows, gathered once and sliced per run.
# NOTE(review): the *_CPU names assume WTE lives on the CPU — confirm where WTE is created.
ITEM_TOKEN_POOL_T = torch.from_numpy(ITEM_TOKEN_POOL).long()
E_ITEM_POOL_CPU = WTE[ITEM_TOKEN_POOL_T].float()                       # (MAX_M, d_model)
E_NULL_CPU = WTE[torch.tensor([NULL_ID]).long()].float()               # (1, d_model)

# =========================
# Vectorized batch sampler helpers
# =========================
# Memo of strict upper-triangular masks, keyed by (k, str(device)).
_UPPER_TRI_CACHE = {}
def _upper_tri_mask(k: int, device: torch.device):
    """Return the cached (k, k) boolean mask that is True strictly above the diagonal."""
    cache_key = (k, str(device))
    mask = _UPPER_TRI_CACHE.get(cache_key)
    if mask is None:
        mask = torch.ones(k, k, dtype=torch.bool, device=device).triu(diagonal=1)
        _UPPER_TRI_CACHE[cache_key] = mask
    return mask

def sample_unique_matrix(
    m: int,
    B: int,
    k: int,
    forbidden: torch.LongTensor | None,
    generator: torch.Generator,
    device: torch.device,
    max_tries: int = 50,
) -> torch.LongTensor:
    """
    Samples (B,k) integers in [0,m) such that:
      - no forbidden values appear (forbidden shape: (B,f))
      - no duplicates within each row
    Uses rejection/resampling, vectorized with O(B*k^2) duplicate checks.

    Only the offending positions are redrawn in each round, at most
    `max_tries` times, and the result of the last redraw is validated too.

    Raises:
        RuntimeError: if k > m (a non-empty batch can never be satisfied), or
            if a valid matrix was not found within `max_tries` redraws.
    """
    # Fail fast instead of burning max_tries rounds on an impossible request.
    # B == 0 is left alone: an empty batch is trivially valid.
    if B > 0 and k > m:
        raise RuntimeError(
            f"sample_unique_matrix: cannot draw k={k} distinct values from [0, {m})."
        )

    x = torch.randint(0, m, (B, k), generator=generator, device=device)
    upper = _upper_tri_mask(k, device=x.device)

    def _bad_positions() -> torch.Tensor:
        """(B,k) mask of entries that are forbidden or repeat an earlier entry of their row."""
        bad = torch.zeros((B, k), dtype=torch.bool, device=x.device)

        if forbidden is not None:
            bad |= (x[..., None] == forbidden[:, None, :]).any(dim=-1)

        # mark duplicates AFTER the first occurrence
        dup = (x[:, :, None] == x[:, None, :])  # (B,k,k)
        bad |= (dup & upper).any(dim=1)         # position j matches some i<j
        return bad

    for _ in range(max_tries):
        bad = _bad_positions()
        if not bad.any():
            return x

        nbad = int(bad.sum().item())
        x[bad] = torch.randint(0, m, (nbad,), generator=generator, device=x.device)

    # The final redraw above has not been checked yet; accept it if it is valid.
    if not _bad_positions().any():
        return x

    raise RuntimeError("sample_unique_matrix: failed to sample unique rows; consider increasing max_tries.")

# =========================
# NEW: Paired-context sampler
# =========================
class PairedRelationalBatchSampler:
    """
    Batch sampler for the paired-context task.

    For each example:
      - Sample 2L distinct item indices (0..m-1) uniformly without replacement.
      - Shuffle them; consecutive positions (2i, 2i+1) form pair i, giving L pairs.
      - Choose a query index q such that, with s = perm[q]:
          * With prob p_present, s IS one of the 2L context tokens.
          * With prob (1-p_present), s is NOT one of the 2L context tokens.
        Additionally, enforce perm[q] != q.
        If exclude_query_from_context=True, also enforce q is not among the 2L context tokens.

      - Target:
          * if s is present in context, the partner (the other token in s's pair)
          * else the NULL index (m)

    The sampler only produces indices. Turning pairs into embeddings and
    inserting the NULL separator is done by the model that consumes the batch.

    All random draws go through one seeded torch.Generator on the sampling
    device, so the sequence of batches depends on the seed and on the order
    and sizes of the sample() calls.
    """
    def __init__(
        self,
        m: int,
        L: int,
        perm: torch.LongTensor,
        p_present: float,
        exclude_query_from_context: bool,
        seed: int,
        sample_on_device: bool,
        target_device: torch.device,
    ):
        """
        m: number of items; valid item indices are 0..m-1 and m denotes NULL.
        L: number of context pairs (2L context tokens per example).
        perm: permutation over 0..m-1; a private int64 copy is kept on the sampling device.
        p_present: probability that the query's successor appears in the context.
        exclude_query_from_context: if True, q itself never appears among the context tokens.
        seed: seed of this sampler's private generator.
        sample_on_device: sample on `target_device` if True, otherwise on the CPU.
        target_device: device used when `sample_on_device` is True.
        """
        self.m = int(m)
        self.L = int(L)
        self.p_present = float(p_present)
        self.exclude_query = bool(exclude_query_from_context)

        self.sample_device = target_device if sample_on_device else torch.device("cpu")
        self.perm = perm.clone().long().to(self.sample_device)

        # Precompute inverse permutation on the sampling device:
        # perm_inv[perm[i]] = i, i.e. perm_inv[s] is the q with perm[q] == s.
        self.perm_inv = torch.empty_like(self.perm)
        self.perm_inv[self.perm] = torch.arange(self.m, device=self.sample_device)

        # Private generator so sampling does not touch the global RNG state.
        self.gen = torch.Generator(device=self.sample_device)
        self.gen.manual_seed(int(seed))

    def _sample_ctx_tokens(self, B: int) -> torch.LongTensor:
        """Sample (B,2L) distinct item indices and shuffle them for random pairing."""
        m, L = self.m, self.L
        dev = self.sample_device
        k = 2 * L

        ctx = sample_unique_matrix(m=m, B=B, k=k, forbidden=None, generator=self.gen, device=dev)
        # Argsort of i.i.d. uniform keys yields an independent random ordering per row.
        shuffle_idx = torch.rand(B, k, generator=self.gen, device=dev).argsort(dim=1)
        ctx = ctx.gather(1, shuffle_idx)
        return ctx

    def sample(self, B: int) -> tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor]:
        """
        Draw one batch of B examples on the sampling device.

        Returns (ctx_pairs, q, targets):
          ctx_pairs: (B, L, 2) item indices in [0, m)
          q:         (B,) query indices in [0, m)
          targets:   (B,) partner index in [0, m) for "present" rows, m (NULL) otherwise
        """
        m, L = self.m, self.L
        dev = self.sample_device
        k = 2 * L
        null_index = m  # index of NULL in restricted vocab

        # 1) Sample context token multiset (2L unique tokens), independent of perm
        ctx_tokens = self._sample_ctx_tokens(B)  # (B,2L)

        # 2) Decide whether successor is present
        present = (torch.rand(B, generator=self.gen, device=dev) < self.p_present)

        # 3) Allocate outputs (q is filled for every row by steps 4 and 5)
        q = torch.empty((B,), dtype=torch.long, device=dev)
        targets = torch.full((B,), fill_value=null_index, dtype=torch.long, device=dev)

        # 4) Handle present=True rows
        if present.any():
            rows_p = torch.nonzero(present, as_tuple=False).squeeze(1)
            ctx_p = ctx_tokens[rows_p].clone()

            # Ensure at least one valid successor token exists per row
            # NOTE(review): this loop has no retry cap; it only terminates if a valid
            # context exists for the given perm (e.g. an identity perm would never exit).
            while True:
                pre = self.perm_inv[ctx_p]            # predecessor of each ctx token, shape (Bp,2L)
                valid = (pre != ctx_p)                # excludes fixed points (perm[q]=q)
                if self.exclude_query:
                    # require q=pre not to be in the context token set
                    pre_in_ctx = (pre[:, :, None] == ctx_p[:, None, :]).any(dim=-1)  # (Bp,2L)
                    valid &= ~pre_in_ctx
                has_valid = valid.any(dim=1)
                if has_valid.all():
                    break

                # Extremely rare: resample context for rows with no valid option
                bad_rows = torch.nonzero(~has_valid, as_tuple=False).squeeze(1)
                nbad = int(bad_rows.numel())
                ctx_p[bad_rows] = self._sample_ctx_tokens(nbad)

            # Choose a valid position j (validity is recomputed on the final contexts)
            pre = self.perm_inv[ctx_p]
            valid = (pre != ctx_p)
            if self.exclude_query:
                pre_in_ctx = (pre[:, :, None] == ctx_p[:, None, :]).any(dim=-1)
                valid &= ~pre_in_ctx

            # Random scores lie in [0,1); invalid positions get -1 so argmax
            # picks a uniformly random valid position.
            scores = torch.rand(ctx_p.shape, generator=self.gen, device=dev)
            scores[~valid] = -1.0
            j = scores.argmax(dim=1)  # (Bp,)

            s = ctx_p.gather(1, j[:, None]).squeeze(1)             # successor token (in context)
            q_p = self.perm_inv[s]                                  # query token so that perm[q]=s
            # j ^ 1 flips the lowest bit: positions (2i, 2i+1) are partners,
            # matching the (B,L,2) view returned below.
            partner = ctx_p.gather(1, (j ^ 1)[:, None]).squeeze(1)  # partner in the same pair

            q[rows_p] = q_p
            targets[rows_p] = partner
            # Write back, since some rows may have had their context resampled above.
            ctx_tokens[rows_p] = ctx_p

        # 5) Handle present=False rows
        if (~present).any():
            rows_a = torch.nonzero(~present, as_tuple=False).squeeze(1)
            ctx_a = ctx_tokens[rows_a]  # (Ba,2L)
            Ba = int(ctx_a.shape[0])

            # Rejection sampling of q: redraw only the rows that violate a constraint.
            # NOTE(review): also uncapped; requires that some admissible q exists for every row.
            q_a = torch.randint(0, m, (Ba,), generator=self.gen, device=dev)
            while True:
                s_a = self.perm[q_a]
                bad = (s_a == q_a)  # exclude fixed points

                if self.exclude_query:
                    bad |= (ctx_a == q_a[:, None]).any(dim=1)

                # the successor must be absent from the context
                bad |= (ctx_a == s_a[:, None]).any(dim=1)

                if not bad.any():
                    break
                nbad = int(bad.sum().item())
                q_a[bad] = torch.randint(0, m, (nbad,), generator=self.gen, device=dev)

            q[rows_a] = q_a
            # targets already NULL

        # 6) Return context pairs as (B,L,2)
        ctx_pairs = ctx_tokens.view(B, L, 2)
        return ctx_pairs, q, targets

# =========================
# Model: single GPT-style block with controlled QK budget
# (Query-only forward: compute ONLY last token output)
# =========================
class MHA_QKBudget(nn.Module):
    """Bias-free multi-head attention with a configurable query/key width.

    Queries and keys use `h` heads of size `d_k` (total width h*d_k), while
    the value and output projections keep a total width of `d_model`
    (d_v = d_model / h per head).
    """

    def __init__(self, d_model: int, h: int, d_k: int):
        super().__init__()
        assert d_model % h == 0, "Need h | d_model for d_v = d_model/h."
        assert d_k >= 1
        self.d_model, self.h, self.d_k = int(d_model), int(h), int(d_k)
        self.d_v = int(d_model // h)

        qk_width = h * d_k
        self.W_Q = nn.Linear(d_model, qk_width, bias=False)
        self.W_K = nn.Linear(d_model, qk_width, bias=False)
        # Value/output width does not depend on d_k.
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def forward_last(self, x: torch.Tensor) -> torch.Tensor:
        """
        Attention output for the final position only.

        x: (B, T, d_model)
        returns: (B, d_model)

        No mask is applied: the last position attends to all T positions,
        which is what a causal mask would allow for it anyway.
        """
        batch, seq_len, width = x.shape
        n_heads, qk_dim, v_dim = self.h, self.d_k, self.d_v

        # Keys and values for every position, split into heads: (B,h,T,*)
        keys = self.W_K(x).view(batch, seq_len, n_heads, qk_dim).transpose(1, 2)
        values = self.W_V(x).view(batch, seq_len, n_heads, v_dim).transpose(1, 2)

        # A single query row taken from the last position: (B,h,1,d_k)
        query = self.W_Q(x[:, -1:, :]).view(batch, 1, n_heads, qk_dim).transpose(1, 2)

        logits = (query @ keys.transpose(-2, -1)) / math.sqrt(qk_dim)   # (B,h,1,T)
        weights = torch.softmax(logits, dim=-1)
        mixed = weights @ values                                         # (B,h,1,d_v)

        # Merge heads back to d_model before the output projection.
        merged = mixed.transpose(1, 2).contiguous().view(batch, 1, width)
        return self.W_O(merged).squeeze(1)

class TransformerBlock(nn.Module):
    """Pre-LayerNorm block (attention, then MLP) evaluated only at the last position."""

    def __init__(self, d_model: int, d_ff: int, h: int, d_k: int):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MHA_QKBudget(d_model=d_model, h=h, d_k=d_k)
        self.ln2 = nn.LayerNorm(d_model)
        feed_forward = [
            nn.Linear(d_model, d_ff),
            nn.GELU(approximate="tanh"),
            nn.Linear(d_ff, d_model),
        ]
        self.mlp = nn.Sequential(*feed_forward)

    def forward_last(self, x: torch.Tensor) -> torch.Tensor:
        """
        Block output for the final position only.

        x: (B,T,d_model)
        returns: (B,d_model)
        """
        # ln1 runs over the whole sequence because keys/values need every
        # position; the residual stream is tracked for the last token only.
        attn_out = self.attn.forward_last(self.ln1(x))   # (B,d_model)
        residual = x[:, -1, :] + attn_out

        # Position-wise MLP with its own residual connection.
        return residual + self.mlp(self.ln2(residual))

class SingleBlockPairedRelationalModel(nn.Module):
    """
    One transformer block over a frozen, tied embedding table.

    - `E` (V_restricted, d_model) is a non-persistent buffer used both to embed
      the inputs and to produce the output logits; its last row is NULL.
    - Input sequence of length L+2:
        positions 0..L-1: E[a]+E[b] for each context pair (a,b)
        position L:       E[NULL] separator
        position L+1:     E[q] query
    - Only the query position is pushed through the block; logits are
      h_q @ E^T over the restricted vocab (items + NULL).
    """

    def __init__(self, E_restricted: torch.Tensor, h: int, d_k: int, d_ff: int | None = None):
        super().__init__()
        self.register_buffer("E", E_restricted.contiguous(), persistent=False)

        d_model = E_restricted.shape[1]
        hidden = 4 * d_model if d_ff is None else d_ff
        self.block = TransformerBlock(d_model=d_model, d_ff=int(hidden), h=h, d_k=d_k)
        self.null_index = int(E_restricted.shape[0] - 1)

        # GPT-2-ish init for stability
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(module):
        """N(0, 0.02) weights and zero biases for Linear; identity affine for LayerNorm."""
        if isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
            return
        if not isinstance(module, nn.Linear):
            return
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

    def forward(self, ctx_pairs: torch.LongTensor, q_idx: torch.LongTensor) -> torch.Tensor:
        """
        ctx_pairs: (B, L, 2) indices in [0, m) (item indices only; never NULL)
        q_idx:     (B,) query indices in [0, m)
        returns logits: (B, V_restricted=m+1)
        """
        B, L, _ = ctx_pairs.shape

        # Each context position is the sum of its two item embeddings.
        pair_emb = self.E[ctx_pairs[:, :, 0]] + self.E[ctx_pairs[:, :, 1]]   # (B,L,d_model)
        null_emb = self.E[self.null_index].expand(B, 1, -1)                  # (B,1,d_model)
        query_emb = self.E[q_idx].unsqueeze(1)                               # (B,1,d_model)

        # Sequence layout: [pairs..., NULL, query]
        x = torch.cat((pair_emb, null_emb, query_emb), dim=1)                # (B,L+2,d_model)

        h_q = self.block.forward_last(x)    # (B,d_model)
        return F.linear(h_q, self.E)        # (B, m+1)

# =========================
# Train / eval
# =========================
def _autocast_ctx():
    """Return the autocast context for the module-level `device`.

    On CUDA it uses AMP_DTYPE and is enabled according to USE_AMP; on any other
    device it is a disabled CPU autocast context.
    """
    if device.type != "cuda":
        return torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=False)
    return torch.autocast(device_type="cuda", dtype=AMP_DTYPE, enabled=USE_AMP)

@torch.inference_mode()
def eval_accuracy(model: nn.Module, sampler: PairedRelationalBatchSampler, n_examples: int, batch_size: int):
    """
    Top-1 accuracy of `model` on `n_examples` freshly sampled examples.

    Batches of up to `batch_size` examples are drawn from `sampler`, moved to
    the module-level `device` when they were sampled elsewhere, and scored by
    argmax over the logits. The model is switched to eval mode for the
    duration of the call and put back into its previous mode afterwards, even
    if evaluation raises, so a training loop that calls this keeps training in
    train mode.

    Returns correct/total as a float (0.0 when no example was evaluated).
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    n_batches = int(math.ceil(n_examples / batch_size))

    try:
        for _ in range(n_batches):
            # The last batch may be smaller so that exactly n_examples are scored.
            bs = min(batch_size, n_examples - total)
            ctx_pairs, q, targets = sampler.sample(bs)

            if sampler.sample_device != device:
                if device.type == "cuda" and PIN_MEMORY:
                    # Pinned host memory is what makes non_blocking copies asynchronous.
                    ctx_pairs = ctx_pairs.pin_memory()
                    q = q.pin_memory()
                    targets = targets.pin_memory()
                ctx_pairs = ctx_pairs.to(device, non_blocking=True)
                q = q.to(device, non_blocking=True)
                targets = targets.to(device, non_blocking=True)

            with _autocast_ctx():
                logits = model(ctx_pairs, q)

            pred = logits.argmax(dim=-1)
            correct += (pred == targets).sum().item()
            total += bs
    finally:
        # Restore the caller's train/eval mode instead of leaving the model in eval.
        model.train(was_training)

    return correct / max(1, total)

def train_one_run(
    m: int,
    L: int,
    D_K: int,
    h: int,
    seed: int,
):
    """
    Train and evaluate one model for a single (m, L, D_K, h, seed) configuration.

    m: number of items (the restricted vocab has m+1 rows, NULL last).
    L: number of context pairs per example.
    D_K: total query/key width; the per-head width is d_k = D_K // h.
    h: number of attention heads.
    seed: seeds the global RNGs, the permutation and (with fixed offsets) the three samplers.

    Training runs for at most MAX_STEPS optimizer steps and stops early once
    validation accuracy has been >= EARLYSTOP_ACC for EARLYSTOP_EVALS
    consecutive evaluations.

    Returns dict with:
      train_steps, best_val_acc, test_acc, baseline_null_acc, time_seconds

    Raises AssertionError if D_K is not divisible by h or d_k < 4.
    """
    assert D_K % h == 0
    d_k = D_K // h
    assert d_k >= 4

    set_all_seeds(seed)

    # Build permutation graph (seed affects it)
    perm = make_perm(m=m, seed=seed)

    # Embeddings: store in AMP dtype on GPU for bandwidth wins; float32 on CPU
    embed_dtype = AMP_DTYPE if device.type == "cuda" else torch.float32

    # Build restricted embeddings (items + NULL)
    # NOTE(review): restricted_ids is not used anywhere below in this function.
    E_restricted, restricted_ids = build_restricted_embeddings_from_pool(
        m=m,
        center_embeddings=CENTER_EMBEDDINGS,
        l2_normalize=L2_NORMALIZE_EMBEDDINGS,
        device=device,
        embed_dtype=embed_dtype,
    )
    # E_restricted indices: 0..m-1 items, m is NULL

    # Samplers: same task and permutation, but distinct seed offsets so the
    # train / val / test streams use differently seeded generators.
    train_sampler = PairedRelationalBatchSampler(
        m=m, L=L, perm=perm, p_present=P_PRESENT,
        exclude_query_from_context=EXCLUDE_QUERY_FROM_CONTEXT,
        seed=seed + 10_000,
        sample_on_device=SAMPLE_ON_DEVICE,
        target_device=device,
    )
    val_sampler = PairedRelationalBatchSampler(
        m=m, L=L, perm=perm, p_present=P_PRESENT,
        exclude_query_from_context=EXCLUDE_QUERY_FROM_CONTEXT,
        seed=seed + 20_000,
        sample_on_device=SAMPLE_ON_DEVICE,
        target_device=device,
    )
    test_sampler = PairedRelationalBatchSampler(
        m=m, L=L, perm=perm, p_present=P_PRESENT,
        exclude_query_from_context=EXCLUDE_QUERY_FROM_CONTEXT,
        seed=seed + 30_000,
        sample_on_device=SAMPLE_ON_DEVICE,
        target_device=device,
    )

    # Model
    model = SingleBlockPairedRelationalModel(E_restricted=E_restricted, h=h, d_k=d_k, d_ff=D_FF).to(device)
    if COMPILE_MODEL and hasattr(torch, "compile"):
        model = torch.compile(model, mode="reduce-overhead")

    model.train()

    # Optimizer: try the fused CUDA implementation first; fall back when this
    # torch version's AdamW does not accept the `fused` keyword.
    try:
        optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY, fused=(device.type == "cuda"))
    except TypeError:
        optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

    # With enabled=False the scaler is inert; the USE_GRAD_SCALER branches below bypass it anyway.
    scaler = torch.cuda.amp.GradScaler(enabled=USE_GRAD_SCALER)

    # Baseline sanity-check (always predict NULL)
    baseline_null_acc = 1.0 - P_PRESENT

    best_val = 0.0
    consec = 0       # consecutive evaluations at or above EARLYSTOP_ACC
    steps_done = 0

    # Synchronize so the wall-clock measurement starts with an idle GPU.
    if device.type == "cuda":
        torch.cuda.synchronize()
    t0 = time.time()

    # Eval batch size auto
    eval_bs = EVAL_BATCH_SIZE
    if eval_bs is None:
        eval_bs = 1024 if device.type == "cuda" else max(256, BATCH_SIZE)

    pbar = tqdm(range(1, MAX_STEPS + 1), desc=f"train m={m} D_K={D_K} h={h} seed={seed}", leave=False)
    for step in pbar:
        steps_done = step

        optimizer.zero_grad(set_to_none=True)
        loss_val = 0.0

        # Gradient accumulation: each micro-batch loss is divided by GRAD_ACCUM_STEPS,
        # so the summed gradients (and loss_val) correspond to the mean over micro-batches.
        for _ in range(int(GRAD_ACCUM_STEPS)):
            ctx_pairs, q, targets = train_sampler.sample(BATCH_SIZE)

            if train_sampler.sample_device != device:
                if device.type == "cuda" and PIN_MEMORY:
                    ctx_pairs = ctx_pairs.pin_memory()
                    q = q.pin_memory()
                    targets = targets.pin_memory()
                ctx_pairs = ctx_pairs.to(device, non_blocking=True)
                q = q.to(device, non_blocking=True)
                targets = targets.to(device, non_blocking=True)

            with _autocast_ctx():
                logits = model(ctx_pairs, q)
                loss = F.cross_entropy(logits, targets) / float(GRAD_ACCUM_STEPS)

            if USE_GRAD_SCALER:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            # .item() forces a host sync every micro-batch.
            loss_val += float(loss.item())

        if USE_GRAD_SCALER:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()

        if step % EVAL_INTERVAL == 0:
            val_acc = eval_accuracy(model, val_sampler, n_examples=VAL_EXAMPLES, batch_size=int(eval_bs))
            best_val = max(best_val, val_acc)

            # Early stopping counts CONSECUTIVE passing evaluations; one miss resets it.
            if val_acc >= EARLYSTOP_ACC:
                consec += 1
            else:
                consec = 0

            pbar.set_postfix(loss=float(loss_val), val_acc=float(val_acc), best_val=float(best_val), consec=consec)

            if consec >= EARLYSTOP_EVALS:
                break

    # Test (uses the final weights; no best-checkpoint restore happens here)
    test_acc = eval_accuracy(model, test_sampler, n_examples=TEST_EXAMPLES, batch_size=int(eval_bs))

    # time_seconds covers training, all validation passes and the test pass.
    if device.type == "cuda":
        torch.cuda.synchronize()
    t1 = time.time()
    dt = t1 - t0

    # Cleanup
    del model
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()

    return {
        "train_steps": int(steps_done),
        "best_val_acc": float(best_val),
        "test_acc": float(test_acc),
        "baseline_null_acc": float(baseline_null_acc),
        "time_seconds": float(dt),
    }

# =========================
# Sweep runner + logging (resume-safe)
# =========================
def valid_head_partitions(D_K: int, h_list, d_model: int):
    """Return the `(h, d_k)` pairs from `h_list` that are usable for this `D_K`.

    A head count `h` is kept when it divides both `d_model` and `D_K` and the
    per-head width `d_k = D_K // h` is at least 4. The order of `h_list` is
    preserved.
    """
    return [
        (h, D_K // h)
        for h in h_list
        if d_model % h == 0 and D_K % h == 0 and D_K // h >= 4
    ]

def load_existing_runs(csv_path: str):
    """Load previously recorded runs from `csv_path` for resuming a sweep.

    Returns the parsed DataFrame, or an empty DataFrame when the file does not
    exist or cannot be read. An unreadable file is reported on stdout rather
    than ignored silently, because treating it as empty makes the caller
    re-run every configuration.
    """
    if not os.path.exists(csv_path):
        return pd.DataFrame()
    try:
        return pd.read_csv(csv_path)
    except Exception as exc:
        # Still best-effort (resume must not crash the sweep), but no longer silent.
        print(f"[load_existing_runs] Could not read {csv_path!r} ({exc!r}); treating it as empty.")
        return pd.DataFrame()

def _sweep_row_metadata() -> dict:
    """Configuration/provenance columns that are identical for every row of this sweep."""
    uses_corpus = TOKEN_SELECTION_MODE != "random"
    is_freq_sample = TOKEN_SELECTION_MODE == "freq_sample"
    return {
        "center_embeddings": int(CENTER_EMBEDDINGS),
        "l2_normalize_embeddings": int(L2_NORMALIZE_EMBEDDINGS),
        "exclude_query_from_context": int(EXCLUDE_QUERY_FROM_CONTEXT),
        "p_present": P_PRESENT,
        "model_name": MODEL_NAME,
        "task_tag": TASK_TAG,
        "use_amp": int(USE_AMP),
        "amp_dtype": str(AMP_DTYPE).replace("torch.", ""),
        "batch_size": int(BATCH_SIZE),
        "grad_accum_steps": int(GRAD_ACCUM_STEPS),
        "effective_batch_size": int(EFFECTIVE_BATCH_SIZE),
        "lr": float(LR),
        "weight_decay": float(WEIGHT_DECAY),
        "d_ff": int(D_FF) if D_FF is not None else int(4 * d_model),
        "sample_on_device": int(SAMPLE_ON_DEVICE),

        # Token pool provenance (blank/zero when the mode does not use it)
        "token_selection_mode": str(TOKEN_SELECTION_MODE),
        "freq_source_tag": str(FREQ_SOURCE_TAG) if uses_corpus else "",
        "freq_dataset_name": str(FREQ_DATASET_NAME) if uses_corpus else "",
        "freq_dataset_config": str(FREQ_DATASET_CONFIG) if uses_corpus else "",
        "freq_dataset_split": str(FREQ_DATASET_SPLIT) if uses_corpus else "",
        "freq_max_tokens": int(FREQ_MAX_TOKENS) if uses_corpus else 0,
        "freq_streaming": int(bool(FREQ_STREAMING)) if uses_corpus else 0,
        "freq_sample_seed": int(FREQ_SAMPLE_SEED) if is_freq_sample else 0,
        "freq_smoothing": float(FREQ_SMOOTHING) if is_freq_sample else 0.0,
    }


def _append_run_row(csv_path: str, row: dict) -> None:
    """Append one result row to `csv_path` while keeping the file rectangular.

    A missing or zero-byte file is created with a header. Otherwise the row is
    aligned to the existing header before it is appended. If the row carries
    columns the file does not have yet, the file is rewritten once with the
    union of columns.
    """
    df_new = pd.DataFrame([row])

    if not (os.path.exists(csv_path) and os.path.getsize(csv_path) > 0):
        df_new.to_csv(csv_path, mode="w", header=True, index=False)
        return

    header = pd.read_csv(csv_path, nrows=0).columns.tolist()
    if all(col in header for col in df_new.columns):
        # Appending with header=False is only safe in the header's column order.
        df_new.reindex(columns=header).to_csv(csv_path, mode="a", header=False, index=False)
        return

    # Schema grew (e.g. a file written before the "error" column existed).
    merged = pd.concat([pd.read_csv(csv_path), df_new], ignore_index=True)
    merged.to_csv(csv_path, mode="w", header=True, index=False)


def run_sweep():
    """Run every not-yet-recorded (m, L, D_K, h, seed) configuration and log it to CSV_PATH.

    Configurations already present in CSV_PATH are skipped, including rows
    that recorded a failure. Each finished run is appended to the CSV
    immediately, so an interrupted sweep can be resumed, and the CSV is synced
    to Drive every SYNC_EVERY_N_RUNS runs and once at the end. A run that
    raises is recorded with NaN metrics and the exception repr in the `error`
    column; successful rows leave `error` empty so all rows share one schema.

    Returns the full results DataFrame read back from CSV_PATH.
    """
    existing = load_existing_runs(CSV_PATH)
    done_keys = set()
    if len(existing) > 0:
        # Resume: skip configs already present in this CSV
        key_columns = ["m", "L", "D_K", "h", "seed"]
        for record in existing[key_columns].itertuples(index=False, name=None):
            done_keys.add(tuple(int(value) for value in record))
        print(f"Found existing CSV with {len(existing)} runs; will resume/skip completed configs.")

    planned = []
    for m in m_list:
        for L in L_list:
            for D_K in D_K_list:
                for (h, d_k) in valid_head_partitions(D_K, h_list, d_model):
                    for seed in SEEDS:
                        if (m, L, D_K, h, seed) not in done_keys:
                            planned.append((m, L, D_K, h, d_k, seed))
    print(f"Planned runs remaining: {len(planned)}")

    metadata = _sweep_row_metadata()
    sync_every = int(max(1, SYNC_EVERY_N_RUNS))

    rows = []
    outer = tqdm(planned, desc="sweep", leave=True)
    for (m, L, D_K, h, d_k, seed) in outer:
        outer.set_postfix(m=m, D_K=D_K, h=h, d_k=d_k, seed=seed)

        config = {"m": m, "L": L, "d_model": d_model, "D_K": D_K, "h": h, "d_k": d_k, "seed": seed}
        try:
            res = train_one_run(m=m, L=L, D_K=D_K, h=h, seed=seed)
            outcome = {
                "train_steps": res["train_steps"],
                "best_val_acc": res["best_val_acc"],
                "test_acc": res["test_acc"],
                "baseline_null_acc": res["baseline_null_acc"],
                "time_seconds": res["time_seconds"],
            }
            error = ""
        except Exception as e:
            # Keep the sweep going: record the failure and surface it right away.
            outcome = {
                "train_steps": 0,
                "best_val_acc": np.nan,
                "test_acc": np.nan,
                "baseline_null_acc": 1.0 - P_PRESENT,
                "time_seconds": np.nan,
            }
            error = repr(e)
            tqdm.write(f"[sweep] run failed for m={m} L={L} D_K={D_K} h={h} seed={seed}: {error}")

        row = {**config, **outcome, **metadata, "error": error}
        rows.append(row)

        # append to CSV incrementally (resume-safe)
        _append_run_row(CSV_PATH, row)

        # Sync runs.csv to Drive periodically (so progress is durable)
        if (len(rows) % sync_every) == 0:
            _sync_file_to_drive(CSV_PATH, CSV_PATH_DRIVE)

    # final sync
    _sync_file_to_drive(CSV_PATH, CSV_PATH_DRIVE)

    df = pd.read_csv(CSV_PATH) if os.path.exists(CSV_PATH) else pd.DataFrame(rows)
    return df

# =========================
# Summaries + plots
# =========================
def compute_thresholds(df: pd.DataFrame, acc_thresh: float = 0.99):
    """
    Summarize, per (m, L) cell, the smallest total key dimension that "works".

    D_K* is the smallest D_K for which at least one head partition (h, d_k)
    has a row for every seed in the module-level SEEDS, no missing test_acc,
    and test_acc >= acc_thresh on all of its rows. Among the partitions that
    qualify at D_K*, the one with the highest mean test_acc is reported
    (ties resolved towards larger h, then larger d_k).

    Returns a DataFrame with columns m, L, D_K_star, best_h_at_star,
    best_dk_at_star, mean_test_acc_at_star; the last four are NaN when no D_K
    qualifies for that cell.
    """
    summary = []
    for m_val in sorted(df["m"].unique().tolist()):
        for L_val in sorted(df["L"].unique().tolist()):
            cell = df[(df["m"] == m_val) & (df["L"] == L_val)].copy()
            if len(cell) == 0:
                continue

            # Defaults describe "no threshold found".
            star, star_h, star_dk, star_mean = np.nan, np.nan, np.nan, np.nan

            # Scan D_K upwards and stop at the first value with a qualifying partition.
            for total_dk in sorted(cell["D_K"].unique().tolist()):
                at_dk = cell[cell["D_K"] == total_dk]
                qualifying = [
                    (float(grp["test_acc"].mean()), int(n_heads), int(head_dim))
                    for (n_heads, head_dim), grp in at_dk.groupby(["h", "d_k"])
                    if set(SEEDS).issubset(set(map(int, grp["seed"].tolist())))
                    and not grp["test_acc"].isna().any()
                    and (grp["test_acc"] >= acc_thresh).all()
                ]
                if not qualifying:
                    continue

                star_mean, star_h, star_dk = max(qualifying)  # max mean test acc
                star = int(total_dk)
                break

            summary.append({
                "m": int(m_val),
                "L": int(L_val),
                "D_K_star": star,
                "best_h_at_star": star_h,
                "best_dk_at_star": star_dk,
                "mean_test_acc_at_star": star_mean,
            })

    return pd.DataFrame(summary)

def plot_accuracy_vs_DK(df: pd.DataFrame, out_path: str):
    """
    Draw one accuracy-vs-D_K curve per m and save it to out_path.

    test_acc is first averaged per (m, D_K, h, d_k); for every (m, D_K) only
    the head partition with the highest mean is kept, so each curve shows the
    best achievable accuracy at that total key dimension.
    """
    mean_acc = df.groupby(["m", "D_K", "h", "d_k"], as_index=False)["test_acc"].mean()
    # After an ascending sort, the last row of each (m, D_K) group is its best partition.
    ranked = mean_acc.sort_values("test_acc")
    winners = ranked.groupby(["m", "D_K"], as_index=False).tail(1)
    winners = winners.sort_values(["m", "D_K"])

    plt.figure()
    for m_val in sorted(winners["m"].unique().tolist()):
        curve = winners[winners["m"] == m_val]
        plt.plot(curve["D_K"], curve["test_acc"], marker="o", label=f"m={m_val}")
    plt.ylim(0.0, 1.02)
    plt.xlabel("Total key dimension D_K = h * d_k")
    plt.ylabel("Test accuracy (best head partition per D_K)")
    plt.title("Accuracy vs D_K (best head partition)")
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.savefig(out_path, dpi=200, bbox_inches="tight")
    plt.show()
    plt.close()

def plot_scaling(thresholds_df: pd.DataFrame, out_path: str):
    """
    Scatter D_K* against m * ln(m) / d_model and overlay a least-squares line
    through the origin, y ≈ C x; the fitted C is shown in the title.

    Rows without a threshold (NaN D_K_star) are ignored. If none remain, a
    message is printed and nothing is plotted or saved. d_model is read from
    the module-level global.
    """
    found = thresholds_df.copy()
    found = found[found["D_K_star"].notna()]
    if len(found) == 0:
        print("No thresholds found (D_K_star all NaN); skipping scaling plot.")
        return

    m_vals = found["m"].astype(float)
    x = m_vals * np.log(m_vals) / float(d_model)
    y = found["D_K_star"].astype(float)

    # Closed-form slope for a zero-intercept fit: C = sum(x*y) / sum(x*x).
    sum_xx = (x * x).sum()
    if float(sum_xx) > 0:
        C = float((x * y).sum() / sum_xx)
    else:
        C = float("nan")

    plt.figure()
    plt.scatter(x, y)
    x_line = np.linspace(0.0, float(x.max()) * 1.05, 200)
    plt.plot(x_line, C * x_line)
    plt.xlabel("m * log(m) / d_model")
    plt.ylabel("D_K* (threshold for test_acc >= 0.99 across seeds)")
    plt.title(f"Scaling of threshold D_K*: fit C≈{C:.3g}")
    plt.grid(True, alpha=0.3)
    plt.savefig(out_path, dpi=200, bbox_inches="tight")
    plt.show()
    plt.close()

def plot_multihead_advantage(df: pd.DataFrame, thresholds_df: pd.DataFrame, out_path: str):
    """
    Plot mean test accuracy (across seeds) against the number of heads h at one
    fixed (m, D_K) operating point and save the figure to out_path.

    Selection heuristic:
      - if thresholds_df contains any finite D_K_star, take the row with the
        largest m and use D_K = D_K_star for that m;
      - otherwise fall back to the largest m in df and the (lower) median of
        the D_K values that were actually run for that m.

    Prints which point was selected. If df has no rows for the selected
    (m, D_K), a message is printed and nothing is plotted.
    """
    th = thresholds_df.dropna(subset=["D_K_star"]).copy()
    if len(th) == 0:
        m_sel = int(df["m"].max())
        # np.median averages the two middle values when the count is even
        # (e.g. {32, 64, 128, 256} -> 96), which matches no row and used to make
        # this fallback skip the plot. Pick an observed D_K for this m instead.
        dk_run = np.sort(df.loc[df["m"] == m_sel, "D_K"].unique())
        D_K_sel = int(dk_run[(len(dk_run) - 1) // 2])
        print(f"No D_K* found; using fallback m={m_sel}, D_K={D_K_sel} for multi-head plot.")
    else:
        th = th.sort_values("m")
        m_sel = int(th.iloc[-1]["m"])
        D_K_sel = int(th.iloc[-1]["D_K_star"])
        print(f"Multi-head plot selection: m={m_sel}, D_K={D_K_sel} (near-threshold).")

    sub = df[(df["m"] == m_sel) & (df["D_K"] == D_K_sel)].copy()
    if len(sub) == 0:
        print("No data for selected (m, D_K); skipping multi-head plot.")
        return

    # One point per head count: average over every row (seed) sharing that h.
    mean_by_h = sub.groupby("h", as_index=False)["test_acc"].mean().sort_values("h")

    plt.figure()
    plt.plot(mean_by_h["h"], mean_by_h["test_acc"], marker="o")
    plt.ylim(0.0, 1.02)
    plt.xlabel("Number of heads h")
    plt.ylabel("Mean test accuracy (across seeds)")
    plt.title(f"Multi-head advantage at fixed D_K={D_K_sel} (m={m_sel})")
    plt.grid(True, alpha=0.3)
    plt.savefig(out_path, dpi=200, bbox_inches="tight")
    plt.show()
    plt.close()

# =========================
# Run sweep, then summarize + plot
# =========================
df_runs = run_sweep()

print("\n=== Completed runs (tail) ===")
display(df_runs.tail(10))

# Reduce the run log to one threshold row per (m, L) and persist it locally.
thresholds = compute_thresholds(df_runs, acc_thresh=0.99)
thresholds_path = os.path.join(OUT_DIR, "thresholds.csv")
thresholds.to_csv(thresholds_path, index=False)

# Mirror thresholds.csv (the Drive path is None when no Drive directory is configured)
thresholds_path_drive = (
    None if DRIVE_OUT_DIR is None else os.path.join(DRIVE_OUT_DIR, "thresholds.csv")
)
_sync_file_to_drive(thresholds_path, thresholds_path_drive)

print("\n=== Threshold summary (D_K*) ===")
display(thresholds)

# Local plot destinations
plot1_path, plot2_path, plot3_path = (
    os.path.join(OUT_DIR, _png_name)
    for _png_name in (
        "plot1_accuracy_vs_DK.png",
        "plot2_scaling_DKstar.png",
        "plot3_multihead_advantage.png",
    )
)

plot_accuracy_vs_DK(df_runs, plot1_path)
plot_scaling(thresholds, plot2_path)
plot_multihead_advantage(df_runs, thresholds, plot3_path)

# Drive mirrors use the same file names as the local plots
plot1_drive, plot2_drive, plot3_drive = (
    None if DRIVE_OUT_DIR is None else os.path.join(DRIVE_OUT_DIR, os.path.basename(_local_png))
    for _local_png in (plot1_path, plot2_path, plot3_path)
)

_sync_file_to_drive(plot1_path, plot1_drive)
_sync_file_to_drive(plot2_path, plot2_drive)
_sync_file_to_drive(plot3_path, plot3_drive)

print("\nSaved files (local):")
for _saved_path in (CSV_PATH, thresholds_path, plot1_path, plot2_path, plot3_path):
    print(" -", _saved_path)

if DRIVE_OUT_DIR is not None:
    print("\nSaved files (Drive mirror):")
    for _saved_path in (CSV_PATH_DRIVE, thresholds_path_drive, plot1_drive, plot2_drive, plot3_drive):
        print(" -", _saved_path)



# ============================================================
# Idealized Self-Attention on Permutation Graph (Capacity vs D_K and h)
# Here extended with a value channel and L2-based training/evaluation
# Includes GPU determinism
# ============================================================

# --- Deterministic CUDA GEMM (must be set BEFORE importing torch) ---
import os
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")  # use ":16:8" if memory is very tight

# --- Standard imports (torch comes AFTER the env var above) ---
import math, json, random, time, zipfile
from dataclasses import dataclass
from typing import List, Tuple, Dict, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import display

# --- CUDA precision & backend knobs (safe on CPU, effective on CUDA) ---
# Applied once when this section executes; set_seed() re-applies the two cuDNN flags.
if torch.cuda.is_available():
    # Deterministic cuDNN algorithms; disable autotuner
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Disable TF32 (for tighter reproducibility in float32)
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False

# Explicitly turn autograd on; the training loop below relies on loss.backward().
torch.set_grad_enabled(True)

# =========================
# Config
# =========================
# These are module-level globals that the functions below read directly by name
# (they are not passed as arguments), so changing one here changes every caller.
SEED = 42                     # master seed for set_seed(), make_data() and the training RNG (single-run mode)
m = 256                       # nodes
d_model = 32                  # embedding dim

h = 8                         # (kept for single-run mode)

DK_TRIALS = [32, 64, 128, 256]     # used by single-run mode only
alpha = 10.0                       # sharpness (used only by legacy BCE path)
lr = 1e-3                          # AdamW lr
weight_decay = 0.0
max_steps = 120_000                # upper bound on optimizer steps per configuration
check_every = 500                  # run validation every this many steps
patience = 5                       # consecutive passing validations required before early stop
target_f1 = 0.995                  # legacy; retained for backward compatibility
target_mse = 1e-2                  # desired upper bound on test MSE for "pass"
target_pos_rate = 0.5              # ρ in (0,1)
VAL_CONTEXTS = 500                 # number of validation contexts built by make_data()
TEST_CONTEXTS = 2000               # number of test contexts built by make_data()

# Candidate context lengths; one is drawn per sampled context.
# make_data() uses TEST_LENGTH_CHOICES for both the validation and the test contexts.
TRAIN_LENGTH_CHOICES = [16]
TEST_LENGTH_CHOICES  = [16]

SAVE_TO_DRIVE = True               # get_outdir() tries to mount Google Drive when True
BASE_NAME = "perm_graph_results"   # prefix of the output directory name built by get_outdir()

# Optional: print reproducibility report at start
PRINT_REPRO_REPORT = True


# =========================
# Sweep config
# =========================
# Per run_sweep()'s docstring the sweep is joint over M_SWEEP × DMODEL_SWEEP × H_SWEEP with REPEATS.
RUN_SWEEP = True                   # presumably selects the sweep driver over run_experiment(); dispatch not shown here
H_SWEEP = [1,2,4,8,16]
M_SWEEP = [64,128,256]
DMODEL_SWEEP = [16,32]
REPEATS = 3
DK_RATIOS = [0.5,1,2,4]            # presumably the `ratios` given to _dk_candidates_for (multiples of m*log2(m)/d_model) — confirm in run_sweep
MAX_DK = 2048                      # presumably the `max_dk` cap given to _dk_candidates_for — confirm in run_sweep
INCLUDE_BASE_DKS = False           # read by _dk_candidates_for: when True, base D_K values divisible by h are added
SWEEP_TAG = time.strftime("%Y%m%d_%H%M%S")   # timestamp fixed at the moment this line runs
SWEEP_ROOT = f"perm_graph_sweep_{SWEEP_TAG}" # directory name used by _mount_and_get_sweep_root()


# =========================
# Utilities
# =========================
def print_repro_report():
    """Print the determinism-related runtime settings (device, cuBLAS workspace, TF32, cuDNN)."""
    has_cuda = torch.cuda.is_available()

    print("=== Reproducibility Report ===")
    print("Device:", "cuda" if has_cuda else "cpu")

    # Older PyTorch releases do not expose the query function at all.
    query_deterministic = getattr(torch, "are_deterministic_algorithms_enabled", None)
    if query_deterministic is None:
        print("Deterministic algos: <unknown for this torch version>")
    else:
        print("Deterministic algos:", query_deterministic())

    print("CUBLAS_WORKSPACE_CONFIG:", os.getenv("CUBLAS_WORKSPACE_CONFIG"))
    if has_cuda:
        cudnn = torch.backends.cudnn
        print("TF32 matmul allowed:", torch.backends.cuda.matmul.allow_tf32)
        print("TF32 cuDNN allowed:", cudnn.allow_tf32)
        print("cuDNN deterministic:", cudnn.deterministic)
        print("cuDNN benchmark:", cudnn.benchmark)
    print("==============================\n")

def set_seed(seed: int):
    """
    Seed Python's `random`, NumPy's legacy global RNG and PyTorch (CPU and, when
    available, every CUDA device), then request deterministic execution.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Keep cuDNN behavior stable
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # Ask PyTorch to raise on non-deterministic ops rather than run them silently.
    try:
        torch.use_deterministic_algorithms(True, warn_only=False)
    except TypeError:
        # Older PyTorch: the warn_only keyword does not exist yet.
        torch.use_deterministic_algorithms(True)
    except Exception:
        # Could not be enabled for some other reason; warn and carry on.
        print("Warning: torch.use_deterministic_algorithms could not be enabled.")

def device_str():
    """Return "cuda" when a CUDA device is available, otherwise "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"

def ensure_dir(path: str) -> None:
    """Create `path` (and any missing parents); an already existing directory is fine."""
    os.makedirs(name=path, exist_ok=True)

def rng_from(master: int, salt: int) -> np.random.Generator:
    """
    Build an independent NumPy Generator for the (master, salt) pair.

    The pair is fed to a SeedSequence and four words of its generated state
    seed the Generator, so the same pair always yields the same stream.
    """
    seed_words = np.random.SeedSequence([master, salt]).generate_state(4)
    return np.random.default_rng(seed_words)

def get_outdir(save_to_drive: bool, rho: float) -> str:
    """
    Build (but do not create) a timestamped output directory path.

    The directory name is "<BASE_NAME>_rho<int(100*rho)>_<timestamp>"; note the
    percentage is truncated, not rounded. With save_to_drive the function tries
    to mount Google Drive and roots the path at MyDrive; if mounting fails for
    any reason it reports why and stays under /content.
    """
    stamp = time.strftime("%Y%m%d_%H%M%S")
    root = "/content"
    if save_to_drive:
        try:
            from google.colab import drive  # type: ignore
            drive.mount("/content/drive", force_remount=False)
        except Exception as e:
            print("Drive mount failed; falling back to /content. Reason:", repr(e))
        else:
            root = "/content/drive/MyDrive"
            print("Saving to Google Drive under MyDrive.")
    return os.path.join(root, f"{BASE_NAME}_rho{int(100*rho)}_{stamp}")

# =========================
# Data generation
# =========================
@dataclass
class DataBundle:
    """Everything one experiment needs besides the model, as assembled by make_data()."""
    # Node embeddings from build_graph_and_embeddings(): one unit-norm row per node, shape (m, d_model).
    X: torch.Tensor
    # Long tensor holding a permutation of range(m); perm[i] is the successor ("neighbor") of node i.
    perm: torch.Tensor
    # Pre-sampled evaluation contexts; each entry is a 1-D long tensor of node indices.
    contexts_val: List[torch.Tensor]
    contexts_test: List[torch.Tensor]

def build_graph_and_embeddings(seed: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Create the random permutation graph and the node embeddings for `seed`.

    Returns (X, perm): X has one Gaussian row per node, rescaled to unit L2
    norm, with shape (m, d_model); perm is a long tensor containing a shuffled
    range(m). Both m and d_model are read from module-level globals.
    """
    # Permutation: NumPy stream derived from (seed, salt=100).
    order = np.arange(m)
    rng_from(seed, 100).shuffle(order)
    perm = torch.tensor(order, dtype=torch.long)

    # Embeddings: a dedicated torch generator, so the global torch RNG is untouched.
    gen = torch.Generator().manual_seed(seed)
    raw = torch.randn((m, d_model), generator=gen) / math.sqrt(d_model)
    row_norms = raw.norm(dim=1, keepdim=True) + 1e-8
    return raw / row_norms, perm

def sample_context_with_pos_rate(rng: np.random.Generator,
                                 perm: torch.Tensor,
                                 ell: int,
                                 rho: float) -> torch.Tensor:
    """
    Sample a context of `ell` distinct node indices and then edit it so that some
    nodes have their permutation successor inside the context.

    Procedure: draw `ell` distinct nodes from range(m) (m is a module-level global),
    draw b ~ Binomial(ell, rho), pick b of the sampled nodes, and for each picked
    node i whose successor perm[i] is not currently in the context, evict one
    randomly chosen member other than i and insert perm[i] in its place.

    The returned 1-D long tensor lists the final set in Python set-iteration order.
    NOTE: the result depends on the exact sequence of rng calls and on set iteration
    order; do not reorder the statements if reproducibility across runs matters.
    """
    S = rng.choice(m, size=ell, replace=False).astype(np.int64)
    S_set = set(int(x) for x in S)
    # Number of positions for which a successor is forced into the context.
    b = int(rng.binomial(ell, rho))
    if b == 0:
        return torch.tensor(list(S_set), dtype=torch.long)
    U = rng.choice(list(S_set), size=min(b, len(S_set)), replace=False).astype(np.int64)
    perm_np = perm.detach().cpu().numpy()
    for i in U:
        # NOTE: an earlier iteration may have evicted i itself or a previously inserted
        # successor; the loop does not re-check this.
        target = int(perm_np[i])
        if target not in S_set:
            # Candidates for eviction: every current member except i.
            pool = list(S_set - {int(i)})
            if not pool:
                continue
            j = int(rng.choice(pool))
            # Swap j out for the successor, keeping the context size unchanged.
            S_set.remove(j)
            S_set.add(target)
    return torch.tensor(list(S_set), dtype=torch.long)

def sample_context(rng: np.random.Generator,
                   perm: torch.Tensor,
                   rho: float,
                   length_choices: List[int]) -> torch.Tensor:
    """Draw a context length from `length_choices`, then sample a context of that length."""
    return sample_context_with_pos_rate(rng, perm, int(rng.choice(length_choices)), rho)

def make_context_list(n: int,
                      seed: int,
                      salt: int,
                      perm: torch.Tensor,
                      rho: float,
                      length_choices: List[int]) -> List[torch.Tensor]:
    """Sample `n` contexts sequentially from a single RNG stream derived from (seed, salt)."""
    stream = rng_from(seed, salt)
    contexts: List[torch.Tensor] = []
    for _ in range(n):
        contexts.append(sample_context(stream, perm, rho, length_choices))
    return contexts

def make_data(seed: int, rho: float) -> DataBundle:
    """
    Build the graph, the embeddings and the fixed validation/test contexts for `seed`.

    Validation and test contexts come from separate RNG streams (salts 200 and
    300) and both draw their lengths from TEST_LENGTH_CHOICES.
    """
    X, perm = build_graph_and_embeddings(seed)
    return DataBundle(
        X=X,
        perm=perm,
        contexts_val=make_context_list(VAL_CONTEXTS, seed, 200, perm, rho, TEST_LENGTH_CHOICES),
        contexts_test=make_context_list(TEST_CONTEXTS, seed, 300, perm, rho, TEST_LENGTH_CHOICES),
    )

def build_labels_for_context(ctx: torch.Tensor, perm: torch.Tensor, device) -> torch.Tensor:
    """
    Return the [ell, ell] float32 edge-label matrix of a context, on `device`.

    Entry (i, j) is 1.0 exactly when the node at position j is the permutation
    successor of the node at position i, and 0.0 otherwise.
    """
    successors = perm[ctx]                                   # [ell]
    is_edge = successors.unsqueeze(1) == ctx.unsqueeze(0)    # [ell, ell] boolean
    return is_edge.to(device, dtype=torch.float32)

# =========================
# Model (idealized SA + value channel)
# =========================
class IdealizedSelfAttention(nn.Module):
    """
    Idealized multi-head attention scorer with a fixed value map.

    Scores are raw per-head query/key dot products (no softmax, no scaling)
    reduced with a max over heads. Messages are produced by W_V, a random
    linear map that is drawn once and excluded from gradient updates.
    """
    def __init__(self, d_model: int, D_K: int, h: int, alpha: float = 10.0, d_msg: Optional[int] = None):
        super().__init__()
        assert D_K % h == 0, "D_K must be divisible by h"
        self.d_model, self.D_K, self.h = d_model, D_K, h
        self.d_k = D_K // h          # per-head key/query width
        self.alpha = alpha

        # Message width falls back to the embedding width.
        self.d_msg = d_msg if d_msg is not None else d_model

        # All three maps share the same 1/sqrt(d_model) Gaussian init; the draw
        # order (Q, K, V) is part of the reproducible behavior.
        init_scale = 1.0 / math.sqrt(d_model)
        self.W_Q = nn.Parameter(torch.randn(d_model, D_K) * init_scale)
        self.W_K = nn.Parameter(torch.randn(d_model, D_K) * init_scale)

        # Frozen random value map (matches Assumption "Random value map")
        self.W_V = nn.Parameter(torch.randn(d_model, self.d_msg) * init_scale, requires_grad=False)

        # Legacy decision threshold; the value-channel training path never touches it.
        self.tau = nn.Parameter(torch.tensor(0.0))

    @torch.no_grad()
    def save_checkpoint(self, path: str, meta: Dict):
        """Write the weights (moved to CPU), tau as a plain float and `meta` to `path` via torch.save."""
        def _to_host(tensor: torch.Tensor) -> torch.Tensor:
            return tensor.detach().cpu()

        torch.save(
            {
                "W_Q": _to_host(self.W_Q),
                "W_K": _to_host(self.W_K),
                "W_V": _to_host(self.W_V),
                "tau": float(_to_host(self.tau)),
                "meta": meta,
            },
            path,
        )

    def scores_Smax(self, X_ctx: torch.Tensor) -> torch.Tensor:
        """
        Return the [ell, ell] matrix of head-wise maximal query/key scores for a
        context of ell embeddings. No scaling and no softmax are applied.
        """
        n_tokens = X_ctx.shape[0]

        def _split_heads(projected: torch.Tensor) -> torch.Tensor:
            # [ell, D_K] -> [h, ell, d_k]
            return projected.view(n_tokens, self.h, self.d_k).transpose(0, 1)

        q_heads = _split_heads(X_ctx @ self.W_Q)
        k_heads = _split_heads(X_ctx @ self.W_K)
        per_head = torch.matmul(q_heads, k_heads.transpose(1, 2))   # [h, ell, ell]
        return per_head.max(dim=0).values                           # [ell, ell]

    def value_messages(self, X_ctx: torch.Tensor) -> torch.Tensor:
        """Map embeddings to messages with the shared value map: y = x W_V."""
        return torch.matmul(X_ctx, self.W_V)

# =========================
# Loss & metrics
# =========================
def weighted_bce_logits(z: torch.Tensor, y: torch.Tensor, pos_weight_scalar: float) -> torch.Tensor:
    """
    Mean binary cross-entropy on logits `z` against 0/1 targets `y`, with the
    positive-class term multiplied by `pos_weight_scalar`.

    Legacy edge-recognition loss; the value-channel training path does not use it.
    """
    positive_part = F.softplus(-z).mul(y).mul(pos_weight_scalar)
    negative_part = F.softplus(z).mul(1.0 - y)
    return torch.mean(positive_part + negative_part)

@torch.no_grad()
def evaluate_model(model: IdealizedSelfAttention,
                   X: torch.Tensor,
                   perm: torch.Tensor,
                   contexts: List[torch.Tensor],
                   device,
                   compute_hist: bool = False):
    """
    Legacy edge-recognition evaluation over a list of contexts.

    An edge is predicted wherever the head-max score exceeds model.tau.
    Returns (micro_f1, margin, pos_scores, hard_neg_scores):
      - micro_f1 pools TP/FP/FN over all contexts (0.0 if there is nothing to count);
      - margin is mean(positive scores) minus mean(hard-negative scores), where the
        hard negatives are the up-to-5 largest non-edge scores of each row (NaN if
        either pool is empty);
      - the last two are NumPy arrays of those pools when compute_hist is True,
        otherwise None.
    """
    model.eval()
    threshold = float(model.tau.detach().cpu().item())
    tp = fp = fn = 0
    positive_chunks = []
    hard_negative_chunks = []

    for ctx in contexts:
        ctx = ctx.to(device)
        labels = build_labels_for_context(ctx, perm, device)
        scores = model.scores_Smax(X[ctx])
        predicted = (scores > threshold).float()

        tp += (predicted * labels).sum().item()
        fp += (predicted * (1.0 - labels)).sum().item()
        fn += ((1.0 - predicted) * labels).sum().item()

        is_edge = labels == 1.0
        edge_scores = scores[is_edge]
        if edge_scores.numel() > 0:
            positive_chunks.append(edge_scores.detach().cpu())

        # Knock out the true edges so that top-k only ever sees negatives.
        non_edge_scores = scores.clone()
        non_edge_scores[is_edge] = -float('inf')
        for row in non_edge_scores:
            usable = torch.isfinite(row)
            if not usable.any():
                continue
            k = min(5, int(usable.sum().item()))
            hardest = torch.topk(row[usable], k=k).values
            hard_negative_chunks.append(hardest.detach().cpu())

    f1_denominator = (2 * tp + fp + fn)
    micro_f1 = (2 * tp / f1_denominator) if f1_denominator > 0 else 0.0

    have_positives = len(positive_chunks) > 0
    have_negatives = len(hard_negative_chunks) > 0

    if have_positives and have_negatives:
        mean_positive = torch.cat(positive_chunks).mean().item()
        mean_negative = torch.cat(hard_negative_chunks).mean().item()
        margin = float(mean_positive - mean_negative)
    else:
        margin = float('nan')

    if not compute_hist:
        return micro_f1, margin, None, None

    pos_np = torch.cat(positive_chunks).numpy() if have_positives else np.array([])
    neg_np = torch.cat(hard_negative_chunks).numpy() if have_negatives else np.array([])
    return micro_f1, margin, pos_np, neg_np

@torch.no_grad()
def evaluate_message_retrieval(model: IdealizedSelfAttention,
                               X: torch.Tensor,
                               perm: torch.Tensor,
                               contexts: List[torch.Tensor],
                               device) -> Tuple[float, int]:
    """
    Message-retrieval error of the value channel over a list of contexts.

    The head-max QK scores (no softmax, no scaling) are used directly as mixing
    weights: the prediction for position i is sum_j S_max[i, j] * (x_j W_V).
    Only positions whose permutation successor is also present in the context
    are scored, against the successor's own message.

    Returns (mse, count): the squared error averaged over every scored scalar
    entry, and the number of such entries. Returns (nan, 0) if nothing was scored.
    """
    model.eval()
    squared_error_sum = 0.0
    scored_entries = 0

    for ctx in contexts:
        ctx = ctx.to(device)
        tokens = X[ctx]                                          # [ell, d_model]
        successors = perm[ctx]                                   # [ell]

        weights = model.scores_Smax(tokens)                      # [ell, ell]
        predicted = weights @ model.value_messages(tokens)       # [ell, d_msg]
        wanted = model.value_messages(X[successors])             # [ell, d_msg]

        # Keep position i only if pi(i) occurs somewhere in this context.
        supervised = (successors[:, None] == ctx[None, :]).any(dim=1)
        if not supervised.any():
            continue

        residual = predicted[supervised] - wanted[supervised]    # [n_valid, d_msg]
        squared_error_sum += float(residual.pow(2).sum().item())
        scored_entries += residual.numel()

    if scored_entries == 0:
        return float("nan"), 0

    return squared_error_sum / scored_entries, scored_entries

# =========================
# Optimizer (determinism-friendly)
# =========================
def make_adamw(params, lr: float, weight_decay: float):
    """
    Build an AdamW optimizer, preferring the plain per-parameter implementation.

    The fused and foreach code paths are switched off where the installed
    PyTorch accepts those keywords, to avoid nondeterministic atomics on some
    builds; older versions that reject them get progressively simpler calls.
    """
    base_kwargs = dict(lr=lr, weight_decay=weight_decay)
    preferred_extras = (
        dict(fused=False, foreach=False),
        dict(foreach=False),
    )
    for extras in preferred_extras:
        try:
            return torch.optim.AdamW(params, **base_kwargs, **extras)
        except TypeError:
            # Keyword unknown to this PyTorch version; try the next variant.
            pass
    return torch.optim.AdamW(params, **base_kwargs)

# =========================
# Training
# =========================
@dataclass
class TrainResult:
    """Summary of one train_one_configuration() run for a single (D_K, h) pair."""
    DK: int                # total key dimension D_K = h * d_k
    h: int                 # number of heads
    d_k: int               # per-head key dimension (DK // h)
    passed: bool           # True when the final test MSE is <= target_mse
    test_micro_f1: float   # legacy diagnostic: filled from evaluate_model on the test contexts; NaN if that diagnostic step fails
    margin: float          # legacy diagnostic: positive-vs-hard-negative score gap from evaluate_model; NaN if that step fails
    tau: float             # value of model.tau at the end of training
    steps_trained: int     # step at which training stopped (set to the same value as stop_step)
    first_pass_step: Optional[int]  # first validation step with val MSE <= target_mse, or None if never reached
    stop_step: int         # early-stop step, or max_steps when the patience criterion was never met
    ckpt_path: str         # checkpoint location; train_one_configuration stores the placeholder "NA"
    test_mse: float        # new: test-set message retrieval MSE
    val_mse: float         # new: best validation MSE observed during training (NaN if no finite value was recorded)

# === value-channel training: minimize L2 / MSE between predicted and true messages ===
def train_one_configuration(DK: int,
                            h_val: int,
                            bundle: DataBundle,
                            device,
                            outdir: str,
                            rho: float) -> TrainResult:
    """
    Train one IdealizedSelfAttention model with total key dimension `DK` split
    over `h_val` heads, minimizing the message-retrieval MSE of the value channel.

    Each step samples one training context (lengths from TRAIN_LENGTH_CHOICES,
    positive rate `rho`) and takes an AdamW step on the MSE between the
    score-weighted message mix and the successor messages, restricted to
    positions whose successor is in the context. Every `check_every` steps the
    validation MSE is computed; training stops early after `patience`
    consecutive validations with MSE <= target_mse, otherwise at `max_steps`.

    After training, the test MSE decides `passed`. As a diagnostic, the legacy
    edge-recognition metrics are evaluated and a positive vs. hard-negative
    score histogram is saved under `outdir`. A failure while drawing or saving
    the histogram no longer discards the already computed F1/margin; those are
    NaN only if evaluate_model itself fails.

    Hyper-parameters (d_model, alpha, lr, weight_decay, SEED, max_steps,
    check_every, patience, target_mse) are read from module-level globals.
    """
    d_k = DK // h_val
    model = IdealizedSelfAttention(d_model=d_model, D_K=DK, h=h_val, alpha=alpha).to(device)
    opt = make_adamw(model.parameters(), lr=lr, weight_decay=weight_decay)

    # Training contexts come from their own stream, keyed by the configuration.
    train_rng = rng_from(SEED, 500 + DK + 13*h_val)

    best_val_mse = float("inf")
    consecutive_passes = 0
    first_pass_step = None
    stop_step = max_steps

    for step in tqdm(range(1, max_steps + 1), desc=f"Training (values) h={h_val}, D_K={DK}", leave=False):
        ctx = sample_context(train_rng, bundle.perm, rho, TRAIN_LENGTH_CHOICES).to(device)
        X_ctx = bundle.X[ctx]                        # [ell, d_model]

        # QK scores aggregated over heads with max, exactly as before.
        S_max = model.scores_Smax(X_ctx)             # [ell, ell]

        # Value messages from context embeddings
        V_ctx = model.value_messages(X_ctx)          # [ell, d_msg]
        Y_hat = S_max @ V_ctx                        # [ell, d_msg]

        # Target messages: neighbor messages y_{pi(i)} = x_{pi(i)} W_V
        neighbors = bundle.perm[ctx]                 # [ell]
        Y_tgt_all = model.value_messages(bundle.X[neighbors])  # [ell, d_msg]

        # Only supervise positions where pi(i) is also in the context (matches theory)
        in_ctx = (neighbors[:, None] == ctx[None, :]).any(dim=1)  # [ell]
        if in_ctx.any():
            loss = F.mse_loss(Y_hat[in_ctx], Y_tgt_all[in_ctx])

            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()
        # else: nothing to supervise in this context (extremely unlikely). Skip the
        # update only; the validation schedule below must still run on this step.

        # Periodic validation based on message-retrieval MSE
        if step % check_every == 0:
            val_mse, _ = evaluate_message_retrieval(
                model, bundle.X, bundle.perm, bundle.contexts_val, device
            )

            # Track best validation MSE (for logging)
            if val_mse < best_val_mse:
                best_val_mse = val_mse

            # Only stop early once we're "good enough" for `patience` checks in a row
            passed_now = (val_mse <= target_mse)
            consecutive_passes = consecutive_passes + 1 if passed_now else 0

            if passed_now and first_pass_step is None:
                first_pass_step = step

            if consecutive_passes >= patience:
                stop_step = step
                break

    # Final test evaluation (message retrieval MSE)
    test_mse, _ = evaluate_message_retrieval(
        model, bundle.X, bundle.perm, bundle.contexts_test, device
    )
    passed = (test_mse <= target_mse)

    # Legacy QK-only diagnostics. Metric computation and plotting are guarded
    # separately so that a plotting/IO failure cannot wipe valid metrics.
    test_micro_f1, margin = float("nan"), float("nan")
    pos_hist = neg_hist = None
    try:
        test_micro_f1, margin, pos_hist, neg_hist = evaluate_model(
            model, bundle.X, bundle.perm, bundle.contexts_test, device, compute_hist=True
        )
    except Exception as e:
        # Purely diagnostic; report and carry on with NaN metrics
        print("Histogram plotting skipped due to:", repr(e))

    if pos_hist is not None and neg_hist is not None and pos_hist.size > 0 and neg_hist.size > 0:
        tag = f"DK{DK}_h{h_val}_dk{d_k}"
        fig = None
        try:
            fig = plt.figure(figsize=(6,4))
            plt.hist(pos_hist, bins=60, alpha=0.6, label="S+ (positives)", density=True)
            plt.hist(neg_hist, bins=60, alpha=0.6, label="Top-5 S- (hard negatives)", density=True)
            plt.xlabel("Score")
            plt.ylabel("Density")
            plt.title(f"Score separation — {tag} (test)")
            plt.legend()
            plt.tight_layout()
            hist_path = os.path.join(outdir, f"score_hist_{tag}.png")
            ensure_dir(outdir)
            plt.savefig(hist_path, dpi=150)
        except Exception as e:
            # Histogram plotting is purely diagnostic; ignore failures
            print("Histogram plotting skipped due to:", repr(e))
        finally:
            # Always release the figure, even when drawing or saving failed.
            if fig is not None:
                plt.close(fig)

    tau_val = float(model.tau.detach().cpu().item())
    return TrainResult(
        DK=DK,
        h=h_val,
        d_k=d_k,
        passed=passed,
        test_micro_f1=float(test_micro_f1),
        margin=float(margin),
        tau=tau_val,
        steps_trained=stop_step,
        first_pass_step=first_pass_step,
        stop_step=stop_step,
        ckpt_path="NA",
        test_mse=float(test_mse),
        val_mse=float(best_val_mse if best_val_mse < float("inf") else float("nan")),
    )


# =========================
# Original single-experiment (updated for value-channel MSE)
# =========================
def run_experiment():
    """
    Run the single-configuration experiment driven by the module-level globals
    (SEED, m, d_model, h, DK_TRIALS, target_pos_rate, ...).

    For every D_K in DK_TRIALS a model is trained and scored by value-channel
    message-retrieval MSE. Results are written to results.csv / results.json in
    a fresh output directory; when SAVE_TO_DRIVE is off the directory is also
    zipped under /content and a Colab download is attempted.

    Returns (df, outdir, smallest_passing), where smallest_passing is the first
    D_K in DK_TRIALS order whose test MSE met target_mse, or None.
    """
    set_seed(SEED)
    device = device_str()
    print(f"Device: {device}")
    if PRINT_REPRO_REPORT:
        print_repro_report()

    outdir = get_outdir(SAVE_TO_DRIVE, target_pos_rate)
    ensure_dir(outdir)
    print("Output directory:", outdir)

    # Embeddings and permutation live on the compute device; contexts stay as built.
    host_bundle = make_data(SEED, target_pos_rate)
    bundle = DataBundle(
        X=host_bundle.X.to(device),
        perm=host_bundle.perm.to(device),
        contexts_val=host_bundle.contexts_val,
        contexts_test=host_bundle.contexts_test,
    )

    results: List[TrainResult] = []
    smallest_passing: Optional[int] = None
    separator = "=" * 80

    for DK in DK_TRIALS:
        print(separator)
        print(f"Training configuration (values): D_K={DK} (h={h} ⇒ d_k={DK//h}), ρ={target_pos_rate}")
        outcome = train_one_configuration(DK, h, bundle, device, outdir, target_pos_rate)
        results.append(outcome)
        print(f"→ Test MSE: {outcome.test_mse:.6e} | best val MSE: {outcome.val_mse:.6e}")
        print(f"   Steps trained: {outcome.steps_trained} | "
              f"first_pass_step (val MSE ≤ {target_mse:.2e}): {outcome.first_pass_step} | "
              f"checkpoint: {outcome.ckpt_path}")
        if smallest_passing is None and outcome.passed:
            smallest_passing = DK

    # One flat record per configuration, shared by the CSV and the JSON dump.
    rows = [
        {
            "D_K": res.DK,
            "h": res.h,
            "d_k": res.d_k,
            "rho": target_pos_rate,
            "passed": res.passed,
            "test_mse": res.test_mse,
            "val_mse": res.val_mse,
            "tau": res.tau,
            "steps_trained": res.steps_trained,
            "first_pass_step": "NA" if res.first_pass_step is None else res.first_pass_step,
            "stop_step": res.stop_step,
            "checkpoint": res.ckpt_path,
        }
        for res in results
    ]
    df = pd.DataFrame(rows).sort_values(["h","D_K"]).reset_index(drop=True)

    csv_path = os.path.join(outdir, "results.csv")
    json_path = os.path.join(outdir, "results.json")
    df.to_csv(csv_path, index=False)
    with open(json_path, "w") as f:
        json.dump(rows, f, indent=2)

    print("\nSaved results:")
    print(" CSV:", csv_path)
    print(" JSON:", json_path)

    print("\nResults table:")
    display(df)

    if smallest_passing is not None:
        print(f"\nEmpirical hat D_K*: smallest passing D_K = {smallest_passing}")
    else:
        print(f"\nOutcome: No configuration reached test MSE ≤ {target_mse:.2e}.")

    if SAVE_TO_DRIVE:
        return df, outdir, smallest_passing

    # No Drive: bundle the output directory so it can be downloaded from Colab.
    zip_path = os.path.join("/content", "perm_graph_results.zip")
    print("\nCreating zip for download at:", zip_path)
    archive_root = os.path.dirname(outdir)
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
        for folder, _, file_names in os.walk(outdir):
            for file_name in file_names:
                full_path = os.path.join(folder, file_name)
                zf.write(full_path, arcname=os.path.relpath(full_path, archive_root))
    try:
        from google.colab import files  # type: ignore
        files.download(zip_path)
    except Exception as e:
        print("Automatic download not available; you can manually fetch:", zip_path, "| Reason:", repr(e))

    return df, outdir, smallest_passing

# =========================
# Sweep helpers & driver (updated to use MSE instead of F1)
# =========================
def _dk_candidates_for(m_val: int,
                       d_val: int,
                       h_val: int,
                       ratios: List[float],
                       base_dks: List[int],
                       max_dk: Optional[int]) -> List[int]:
    """
    List the total key dimensions D_K to try for one (m, d_model, h) setting.

    Each ratio is applied to the reference capacity m * log2(m) / d_model,
    rounded up to an integer >= 1 and then up to the next multiple of `h_val`.
    When the global INCLUDE_BASE_DKS is set, entries of `base_dks` divisible by
    `h_val` are added as well. Values below `h_val` or above `max_dk` (if given)
    are dropped; if nothing survives, [h_val] is returned. The result is sorted
    and free of duplicates.
    """
    log_m = max(math.log2(max(m_val, 2)), 1e-9)
    capacity = (m_val * log_m) / max(d_val, 1e-9)

    def _round_up_to_heads(raw_dk: int) -> int:
        # enforce divisibility by h
        return int(math.ceil(raw_dk / h_val) * h_val)

    proposed = {
        _round_up_to_heads(max(int(math.ceil(ratio * capacity)), 1))
        for ratio in ratios
    }
    if INCLUDE_BASE_DKS and base_dks:
        proposed.update(dk for dk in base_dks if dk % h_val == 0)

    admissible = sorted(
        dk for dk in proposed
        if dk >= h_val and (max_dk is None or dk <= max_dk)
    )
    return admissible if admissible else [h_val]

def _mount_and_get_sweep_root():
    """
    Resolve and create the root directory under which all sweep outputs go.

    The module-global ``BASE_NAME`` is temporarily pointed at a
    ``<SWEEP_ROOT>/_bootstrap_only`` name so that ``get_outdir`` can be called
    once purely to discover where outputs land (presumably this is also what
    triggers the Drive mount when ``SAVE_TO_DRIVE`` is set — confirm against
    ``get_outdir``). The original ``BASE_NAME`` is always restored, even if
    ``get_outdir`` or ``ensure_dir`` raises.

    Returns:
        (sweep_root_dir, drive_base): ``drive_base`` is "/content/drive/MyDrive"
        when the path returned by ``get_outdir`` starts with that prefix and
        "/content" otherwise; ``sweep_root_dir`` is ``drive_base/SWEEP_ROOT``.
    """
    saved_base_name = BASE_NAME
    globals()["BASE_NAME"] = f"{SWEEP_ROOT}/_bootstrap_only"
    try:
        probe_dir = get_outdir(SAVE_TO_DRIVE, target_pos_rate)
        on_drive = probe_dir.startswith("/content/drive/MyDrive")
        drive_base = "/content/drive/MyDrive" if on_drive else "/content"
        sweep_root_dir = os.path.join(drive_base, SWEEP_ROOT)
        ensure_dir(sweep_root_dir)
    finally:
        # Never leave the global pointing at the bootstrap name, even on failure.
        globals()["BASE_NAME"] = saved_base_name
    print(f"\nSweep root: {sweep_root_dir}")
    return sweep_root_dir, drive_base

def _ratio_log2(m_val: int, d_val: int, DK_val: int):
    """
    Return ``(ratio, threshold)`` for the capacity ratio used on plot x-axes.

    ``threshold`` is ``m_val * log2(max(m_val, 2)) / d_val`` (the denominator is
    floored at 1e-9), and ``ratio`` is ``DK_val`` divided by that threshold
    (itself floored at 1e-9 to avoid division by zero).
    """
    safe_m = max(m_val, 2)
    safe_d = max(d_val, 1e-9)
    threshold = (m_val * math.log2(safe_m)) / safe_d
    ratio = DK_val / max(threshold, 1e-9)
    return ratio, threshold

def _plot_block_by_h(agg_block: pd.DataFrame, m_val: int, d_val: int, out_dir: str):
    """
    Save one error-bar figure for a single (m, d_model) block.

    Draws one curve per distinct value of the ``h`` column: x is ``ratio_log2``
    (DK / (m log2 m / d)), y is ``test_mse_mean`` and the error bars come from
    ``test_mse_std`` (missing std treated as 0). A dashed vertical line marks
    ratio 1.0. The PNG is written into ``out_dir``, which is created first.
    """
    ensure_dir(out_dir)
    plt.figure(figsize=(7.2, 4.6))

    for heads, rows in agg_block.groupby("h"):
        ordered = rows.sort_values("ratio_log2")
        plt.errorbar(
            ordered["ratio_log2"].to_numpy(),
            ordered["test_mse_mean"].to_numpy(),
            yerr=ordered["test_mse_std"].fillna(0.0).to_numpy(),
            fmt="-o",
            capsize=3,
            label=f"h={heads}",
            alpha=0.95,
        )

    # Reference line where D_K equals the m*log2(m)/d_model threshold.
    plt.axvline(1.0, linestyle="--", linewidth=1)
    plt.xlabel("DK / (m log2 m / d_model)")
    plt.ylabel("Mean test MSE")
    plt.title(f"Test MSE vs Capacity Ratio — m={m_val}, d_model={d_val}")
    plt.legend(title="Heads")
    plt.tight_layout()

    png_name = f"mse_vs_ratio_m{m_val}_d{d_val}_byH.png"
    fig_path = os.path.join(out_dir, png_name)
    plt.savefig(fig_path, dpi=150)
    plt.close()
    print("Saved block plot:", fig_path)

def _sweep_first_pass_to_num(x):
    """Map a recorded ``first_pass_step`` to float; the "NA" marker or an unparsable value becomes NaN."""
    if x == "NA":
        return np.nan
    try:
        return float(x)
    except (TypeError, ValueError, OverflowError):
        return np.nan


def _sweep_aggregate_runs(df_runs: pd.DataFrame, with_first_pass: bool) -> pd.DataFrame:
    """Aggregate per-run rows by (m, d_model, h, D_K) and attach the ``ratio_log2`` column."""
    named_aggs = {
        "runs": ("passed", "count"),
        "pass_rate": ("passed", "mean"),
        "test_mse_mean": ("test_mse", "mean"),
        "test_mse_std": ("test_mse", "std"),
        "steps_mean": ("steps_trained", "mean"),
        "steps_median": ("steps_trained", "median"),
    }
    if with_first_pass:
        named_aggs["first_pass_step_median"] = ("first_pass_step_num", "median")
    named_aggs["thr_log2"] = ("thr_log2", "first")

    grouped = df_runs.groupby(["m", "d_model", "h", "D_K"]).agg(**named_aggs).reset_index()
    # A zero threshold would divide by zero; map it to NaN instead.
    grouped["ratio_log2"] = grouped["D_K"] / grouped["thr_log2"].replace(0, np.nan)
    return grouped


def _sweep_min_dk_over_heads(agg_block: pd.DataFrame, m_val: int, d_val: int, thr_l2: float) -> Dict:
    """Find, per pass-rate cutoff (100/80/50%), the smallest D_K reached by any head count."""
    mins = {"m": m_val, "d_model": d_val, "thr_log2": thr_l2}
    for cut, label in ((1.0, "100"), (0.80, "80"), (0.50, "50")):
        passing = agg_block.loc[agg_block["pass_rate"] >= cut, ["D_K", "h"]]
        # Sort on both keys: sorting on D_K alone uses a non-stable algorithm, so
        # the head reported for a tied minimal D_K would not be well defined.
        # Ties now deterministically resolve to the smallest h.
        passing = passing.sort_values(["D_K", "h"])
        found = len(passing) > 0
        mins[f"DK_star_anyh_pass{label}"] = passing["D_K"].iloc[0] if found else np.nan
        mins[f"h_for_DK_star_pass{label}"] = passing["h"].iloc[0] if found else np.nan
    return mins


def _sweep_plot_combined(agg_global: pd.DataFrame, sweep_root_dir: str) -> None:
    """Best-effort overlay of every (m, d_model, h) curve; failures are reported, not raised."""
    fig = None
    try:
        fig = plt.figure(figsize=(8.0, 5.2))
        for (mm, dd, hh), sub in agg_global.groupby(["m", "d_model", "h"]):
            sub_sorted = sub.sort_values("ratio_log2")
            x = sub_sorted["ratio_log2"].to_numpy()
            y = sub_sorted["test_mse_mean"].to_numpy()
            yerr = sub_sorted["test_mse_std"].fillna(0.0).to_numpy()
            plt.errorbar(x, y, yerr=yerr, fmt="-o", capsize=3, alpha=0.80,
                         label=f"m={mm}, d={dd}, h={hh}")
        plt.axvline(1.0, linestyle="--", linewidth=1)
        plt.xlabel("DK / (m log2 m / d_model)")
        plt.ylabel("Mean test MSE")
        plt.title("Test MSE vs Capacity Ratio — ALL (m, d_model, h)")
        plt.legend(ncol=2, fontsize=8)
        plt.tight_layout()
        fig_path_all = os.path.join(sweep_root_dir, "mse_vs_ratio_ALL_byH.png")
        plt.savefig(fig_path_all, dpi=150)
        print(" • Combined plot:", fig_path_all)
    except Exception as e:
        # Deliberately best-effort: a plotting failure must not lose the saved CSVs.
        print("Combined plotting skipped due to:", repr(e))
    finally:
        # Close even on failure so a broken plot does not leak an open figure.
        if fig is not None:
            plt.close(fig)


def _sweep_plot_min_dk_heatmap(df_min_anyh: pd.DataFrame, sweep_root_dir: str) -> None:
    """Best-effort heatmap of the pass-100% minimal D_K per (m, d_model); failures are reported, not raised."""
    fig = None
    try:
        # Uses the strictest (100% pass rate) criterion; the 80/50% variants are only in the CSV.
        pivot = df_min_anyh.pivot(index="m", columns="d_model", values="DK_star_anyh_pass100")
        n_rows, n_cols = pivot.shape
        fig = plt.figure(figsize=(1.2 + 0.8 * n_cols, 1.2 + 0.6 * n_rows))
        im = plt.imshow(pivot, aspect="auto", interpolation="nearest")
        plt.colorbar(im, label="Min DK (exists some h, pass rate = 100%)")
        plt.xticks(ticks=np.arange(n_cols), labels=list(pivot.columns))
        plt.yticks(ticks=np.arange(n_rows), labels=list(pivot.index))
        plt.xlabel("d_model")
        plt.ylabel("m")
        plt.title("Minimum DK over heads per (m, d_model)")
        # Annotate each cell with its value; a dash marks "no configuration passed".
        for i in range(n_rows):
            for j in range(n_cols):
                val = pivot.iloc[i, j]
                txt = "—" if pd.isna(val) else str(int(val))
                plt.text(j, i, txt, ha="center", va="center", fontsize=9)
        plt.tight_layout()
        heatmap_path = os.path.join(sweep_root_dir, "min_DK_over_heads_heatmap.png")
        plt.savefig(heatmap_path, dpi=150)
        print(" • Min DK heatmap:", heatmap_path)
    except Exception as e:
        # Deliberately best-effort: a plotting failure must not lose the saved CSVs.
        print("Heatmap plotting skipped due to:", repr(e))
    finally:
        if fig is not None:
            plt.close(fig)


def run_sweep():
    """
    Joint sweep over M_SWEEP × DMODEL_SWEEP × H_SWEEP with REPEATS replicates.

    For each (m, d_model) block: every replicate builds its data once, then
    trains every head count and every D_K candidate on that same data. The
    block's raw rows, its (h, D_K) aggregates and its per-head plot are saved
    into the block directory as soon as the block finishes. After all blocks,
    global CSV summaries, a combined overlay plot and a heatmap of the minimal
    passing D_K (over heads) per (m, d_model) are written to the sweep root.

    All training and evaluation are based on message-retrieval MSE.

    Side effects: overwrites the module globals ``m``, ``d_model`` and ``SEED``
    for each replicate (they are left at the last replicate's values), writes
    CSV/PNG files, and prints progress. Returns None.
    """
    global m, d_model, SEED

    set_seed(SEED)
    device = device_str()
    print(f"Device: {device}")
    if PRINT_REPRO_REPORT:
        print_repro_report()

    sweep_root_dir, _ = _mount_and_get_sweep_root()

    all_rows: List[pd.DataFrame] = []
    block_min_rows = []  # one dict per (m, d_model), feeds the min-DK heatmap

    dk_trials_orig = list(DK_TRIALS)
    seed_base = SEED  # captured now because the loop below overwrites the global SEED

    for m_val in M_SWEEP:
        for d_val in DMODEL_SWEEP:
            print(f"\n=== Block: m={m_val}, d_model={d_val} ===")

            # Per-block directories
            block_dir = os.path.join(sweep_root_dir, f"m{m_val}_d{d_val}")
            ensure_dir(block_dir)

            block_rows = []  # accumulate all runs for this (m, d)

            # Do all repeats (so means / std error bars can be computed)
            for rep in range(1, REPEATS + 1):
                # Point the module globals at this block's settings; a distinct
                # seed is derived per (rep, m, d_model).
                m = m_val
                d_model = d_val
                SEED = int(seed_base + rep * 997 + m_val * 13 + d_val * 23)

                # Build data once per replicate so all heads/DKs share it (fair comparison)
                bundle = make_data(SEED, target_pos_rate)
                bundle = DataBundle(X=bundle.X.to(device),
                                    perm=bundle.perm.to(device),
                                    contexts_val=bundle.contexts_val,
                                    contexts_test=bundle.contexts_test)

                for h_val in H_SWEEP:
                    dk_list = _dk_candidates_for(
                        m_val=m_val, d_val=d_val, h_val=h_val,
                        ratios=DK_RATIOS, base_dks=dk_trials_orig, max_dk=MAX_DK
                    )
                    print(f"  • run{rep:02d} | h={h_val} | DK candidates={dk_list}")

                    run_outdir = os.path.join(block_dir, f"run{rep:02d}", f"h{h_val}")
                    ensure_dir(run_outdir)

                    for DK in dk_list:
                        tr = train_one_configuration(DK, h_val, bundle, device, run_outdir, target_pos_rate)
                        block_rows.append({
                            "m": m_val,
                            "d_model": d_val,
                            "seed": SEED,
                            "run": rep,
                            "h": h_val,
                            "D_K": tr.DK,
                            "d_k": tr.d_k,
                            "rho": target_pos_rate,
                            "passed": tr.passed,
                            "test_mse": tr.test_mse,
                            "val_mse": tr.val_mse,
                            "tau": tr.tau,
                            "steps_trained": tr.steps_trained,
                            "first_pass_step": tr.first_pass_step if tr.first_pass_step is not None else "NA",
                            "stop_step": tr.stop_step,
                            "outdir": run_outdir
                        })

            # Aggregate per-block and save plots immediately
            df_block = pd.DataFrame(block_rows)
            if df_block.empty:
                print("No data collected for this block.")
                continue

            # Numeric copy of first_pass_step ("NA" -> NaN) so it can be aggregated
            df_block["first_pass_step_num"] = df_block["first_pass_step"].apply(_sweep_first_pass_to_num)

            # Threshold m*log2(m)/d_model and the capacity ratio used on plot x-axes
            thr_l2 = (m_val * math.log2(max(m_val, 2))) / max(d_val, 1e-9)
            df_block["thr_log2"] = thr_l2
            df_block["ratio_log2"] = df_block["D_K"] / max(thr_l2, 1e-9)

            # Per (h, D_K) aggregates
            agg_block = _sweep_aggregate_runs(df_block, with_first_pass=True)

            # Save per-block CSVs
            block_all_csv = os.path.join(block_dir, "block_all_runs.csv")
            block_agg_csv = os.path.join(block_dir, "block_grouped_by_h_DK.csv")
            df_block.to_csv(block_all_csv, index=False)
            agg_block.to_csv(block_agg_csv, index=False)
            print("Saved block CSVs:", block_all_csv, "and", block_agg_csv)

            # Plot multi-line (one line per h) and save now
            _plot_block_by_h(agg_block, m_val, d_val, block_dir)

            # Minimum DK over all heads (exists some h with pass_rate >= 1.0 / 0.8 / 0.5)
            block_min_rows.append(_sweep_min_dk_over_heads(agg_block, m_val, d_val, thr_l2))

            # Keep for global concat
            all_rows.append(df_block)

    # === Global summaries ===
    if not all_rows:
        print("No rows collected; sweep did not run.")
        return

    df_all = pd.concat(all_rows, ignore_index=True)

    # Global grouped with h included (no first-pass median column, as before)
    agg_global = _sweep_aggregate_runs(df_all, with_first_pass=False)

    # Minimum DK per (m, d) over all heads, also expressed as a ratio to the threshold
    df_min_anyh = pd.DataFrame(block_min_rows)
    for label in ("100", "80", "50"):
        df_min_anyh[f"DK_star_anyh_ratio_log2_pass{label}"] = (
            df_min_anyh[f"DK_star_anyh_pass{label}"] / df_min_anyh["thr_log2"]
        )

    # Save global artifacts at sweep root
    all_csv = os.path.join(sweep_root_dir, "sweep_all_runs.csv")
    agg_csv = os.path.join(sweep_root_dir, "sweep_grouped_by_m_d_h_DK.csv")
    min_anyh_csv = os.path.join(sweep_root_dir, "sweep_min_DK_over_heads_per_m_d.csv")
    df_all.to_csv(all_csv, index=False)
    agg_global.to_csv(agg_csv, index=False)
    df_min_anyh.to_csv(min_anyh_csv, index=False)
    print("\nSweep saved:")
    print(" • All runs:", all_csv)
    print(" • Grouped per (m, d_model, h, D_K):", agg_csv)
    print(" • Min DK over heads per (m, d_model):", min_anyh_csv)

    # Optional combined overlay (all (m,d,h) curves) — can get busy
    _sweep_plot_combined(agg_global, sweep_root_dir)

    # Heatmap of min DK over heads per (m, d_model)
    _sweep_plot_min_dk_heatmap(df_min_anyh, sweep_root_dir)

# Entry point: run the full sweep when RUN_SWEEP is set, otherwise a single experiment.
if __name__ == "__main__":
    (run_sweep if RUN_SWEEP else run_experiment)()
