# ============================================================
# Idealized Self-Attention on Permutation Graph (Capacity vs D_K and h)
# Includes GPU determinism settings (cuBLAS workspace, cuDNN, TF32, deterministic algorithms)
# ============================================================

# --- Deterministic CUDA GEMM ---
# The variable is exported before torch is imported so that it is already in the
# process environment when CUDA/cuBLAS is first used. setdefault() leaves any
# value the user exported beforehand untouched.
import os
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")  # use ":16:8" if memory is very tight

# --- Standard imports (torch comes AFTER the env var above) ---
import math, json, random, time, zipfile
from dataclasses import dataclass
from typing import List, Tuple, Dict, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import display

# --- CUDA precision & backend knobs ---
# These flags are only touched when a CUDA device is visible; on CPU-only
# machines this whole branch is skipped.
if torch.cuda.is_available():
    # Deterministic cuDNN algorithms; disable autotuner
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Disable TF32 (for tighter reproducibility in float32)
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False

# Make sure autograd is on globally; evaluation code opts out locally with
# @torch.no_grad().
torch.set_grad_enabled(True)

# =========================
# Config
# =========================
SEED = 42                     # master seed; run_sweep() reassigns it per replicate
m = 256                       # number of graph nodes (size of the permutation); reassigned by run_sweep()
d_model = 32                  # embedding dimension of each node; reassigned by run_sweep()

h = 8                         # number of heads (single-run mode; the sweep uses H_SWEEP)

# Total key/query widths D_K tried in single-run mode. The sweep only adds them
# to its candidates when INCLUDE_BASE_DKS is True.
DK_TRIALS = [32, 64, 128, 256]
alpha = 10.0                       # sharpness: logits are alpha * (S_max - tau)
lr = 1e-3                          # AdamW lr
weight_decay = 0.0                 # AdamW weight decay
max_steps = 20_000                 # hard cap on training steps per configuration
check_every = 500                  # validate every this many steps
patience = 5                       # stop after this many consecutive passing validations
target_f1 = 0.995                  # micro-F1 needed for a validation/test "pass"
target_pos_rate = 0.5              # ρ in (0,1): binomial rate used when planting edges into a context
VAL_CONTEXTS = 500                 # number of fixed validation contexts
TEST_CONTEXTS = 2000               # number of fixed test contexts

# Context lengths are drawn uniformly from these lists.
# NOTE: make_data() uses TEST_LENGTH_CHOICES for both validation and test contexts.
TRAIN_LENGTH_CHOICES = [16]
TEST_LENGTH_CHOICES  = [16]

SAVE_TO_DRIVE = True               # try to mount Google Drive and write results there
BASE_NAME = "perm_graph_results"   # prefix of the single-run output directory name

# Optional: print reproducibility report at start
PRINT_REPRO_REPORT = True


# =========================
# Sweep config
# =========================
RUN_SWEEP = True                     # True -> run_sweep(); False -> run_experiment()
H_SWEEP = [1,2,4,8,16,32,64]         # head counts tried in every block
M_SWEEP = [64,128,256,512]           # node counts m (one block per (m, d_model) pair)
DMODEL_SWEEP = [16,32,64]            # embedding dimensions d_model
REPEATS = 10                         # replicates (fresh seed + data) per block
# D_K candidates are these multiples of the reference width m*log2(m)/d_model,
# rounded up to a multiple of h (see _dk_candidates_for).
DK_RATIOS = [0.5,0.75,1,1.5, 2.0]
MAX_DK = 2048                        # candidates above this width are dropped
INCLUDE_BASE_DKS = False             # also try DK_TRIALS values that are divisible by h
SWEEP_TAG = time.strftime("%Y%m%d_%H%M%S")   # timestamp fixed once, at import time
SWEEP_ROOT = f"perm_graph_sweep_{SWEEP_TAG}"  # directory name of this sweep's outputs


# =========================
# Utilities
# =========================
def print_repro_report():
    """Print the determinism-related runtime settings that are currently in effect."""
    cuda_ok = torch.cuda.is_available()

    print("=== Reproducibility Report ===")
    print("Device:", "cuda" if cuda_ok else "cpu")

    try:
        det_enabled = torch.are_deterministic_algorithms_enabled()
    except AttributeError:
        # Older PyTorch releases do not expose this query.
        print("Deterministic algos: <unknown for this torch version>")
    else:
        print("Deterministic algos:", det_enabled)

    print("CUBLAS_WORKSPACE_CONFIG:", os.getenv("CUBLAS_WORKSPACE_CONFIG"))

    if cuda_ok:
        backend_flags = (
            ("TF32 matmul allowed:", torch.backends.cuda.matmul.allow_tf32),
            ("TF32 cuDNN allowed:", torch.backends.cudnn.allow_tf32),
            ("cuDNN deterministic:", torch.backends.cudnn.deterministic),
            ("cuDNN benchmark:", torch.backends.cudnn.benchmark),
        )
        for label, value in backend_flags:
            print(label, value)

    print("==============================\n")

def set_seed(seed: int):
    """Seed Python, NumPy and torch RNGs and switch torch to deterministic mode.

    Besides seeding, this pins cuDNN to deterministic, non-autotuned kernels and
    asks torch to raise on non-deterministic operations. If deterministic
    algorithms cannot be enabled, a warning is printed instead of failing.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Keep cuDNN behavior stable
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # Enforce deterministic ops (raise on non-determinism)
    try:
        torch.use_deterministic_algorithms(True, warn_only=False)
    except TypeError:
        # Older PyTorch: the warn_only keyword does not exist yet.
        torch.use_deterministic_algorithms(True)
    except Exception:
        # Anything else (e.g. the API is missing entirely): warn and carry on.
        print("Warning: torch.use_deterministic_algorithms could not be enabled.")

def device_str():
    """Return "cuda" when a CUDA device is available, otherwise "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"

def ensure_dir(path: str):
    """Create directory `path` (and missing parents); do nothing if it already exists."""
    os.makedirs(name=path, exist_ok=True)

def rng_from(master: int, salt: int) -> np.random.Generator:
    """Return a NumPy Generator derived from the pair (master, salt).

    The pair is fed into a SeedSequence, so different salts under the same
    master seed give separate, reproducible streams.
    """
    seed_words = np.random.SeedSequence([master, salt]).generate_state(4)
    return np.random.default_rng(seed_words)

def get_outdir(save_to_drive: bool, rho: float) -> str:
    """Build a timestamped output directory path (the directory is not created).

    When `save_to_drive` is true, Google Drive is mounted through google.colab
    and the path is placed under MyDrive; if that fails for any reason the path
    falls back to /content. The directory name encodes BASE_NAME, int(100*rho)
    and the current local time.
    """
    # Take the timestamp first so that a slow Drive mount does not shift it.
    stamp = time.strftime("%Y%m%d_%H%M%S")
    root = "/content"
    if save_to_drive:
        try:
            from google.colab import drive  # type: ignore
            drive.mount("/content/drive", force_remount=False)
        except Exception as e:
            print("Drive mount failed; falling back to /content. Reason:", repr(e))
        else:
            root = "/content/drive/MyDrive"
            print("Saving to Google Drive under MyDrive.")
    return os.path.join(root, f"{BASE_NAME}_rho{int(100*rho)}_{stamp}")

# =========================
# Data generation
# =========================
# Container for everything one replicate shares across configurations:
# the node embeddings, the permutation graph and the fixed evaluation contexts.
@dataclass
class DataBundle:
    # Unit-normalized node embeddings built by build_graph_and_embeddings, one row per node.
    X: torch.Tensor
    # Long tensor holding a permutation of the node ids; perm[i] is the target of node i.
    perm: torch.Tensor
    # Fixed validation contexts; each entry is a long tensor of node ids.
    contexts_val: List[torch.Tensor]
    # Fixed test contexts; each entry is a long tensor of node ids.
    contexts_test: List[torch.Tensor]

def build_graph_and_embeddings(seed: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Create the random permutation graph and the node embeddings.

    Uses the module-level `m` (node count) and `d_model` (embedding width).

    Returns:
        (X, perm): X is an [m, d_model] float tensor with rows scaled to unit
        L2 norm; perm is a long tensor containing a shuffled arange(m).
    """
    # Permutation: shuffle 0..m-1 with a NumPy stream salted by 100.
    node_order = np.arange(m)
    rng_from(seed, 100).shuffle(node_order)
    perm = torch.tensor(node_order, dtype=torch.long)

    # Embeddings: Gaussian rows drawn from a dedicated torch generator.
    gen = torch.Generator()
    gen.manual_seed(seed)
    raw = torch.randn((m, d_model), generator=gen) / math.sqrt(d_model)
    row_norms = raw.norm(dim=1, keepdim=True)
    X = raw / (row_norms + 1e-8)  # epsilon guards against a zero row
    return X, perm

def sample_context_with_pos_rate(rng: np.random.Generator,
                                 perm: torch.Tensor,
                                 ell: int,
                                 rho: float) -> torch.Tensor:
    """Sample a context of `ell` distinct nodes and plant permutation edges in it.

    The node universe is taken from `perm` itself (its length) rather than from
    a module-level constant, so the function always agrees with the graph it is
    given.

    Procedure:
      1. draw `ell` distinct nodes uniformly;
      2. draw b ~ Binomial(ell, rho) and pick b of those nodes as sources;
      3. for every source whose target perm[source] is not yet in the context,
         swap a random other member of the context for that target.
    A swap may evict a node that an earlier step relied on, so the final number
    of in-context edges is not guaranteed to equal b.

    Args:
        rng: NumPy generator; all randomness is drawn from it.
        perm: long tensor, perm[i] is the target of node i.
        ell: context length (must not exceed len(perm)).
        rho: binomial success probability.

    Returns:
        Long tensor with `ell` distinct node ids (CPU).
    """
    n_nodes = int(perm.shape[0])
    S = rng.choice(n_nodes, size=ell, replace=False).astype(np.int64)
    S_set = set(int(x) for x in S)
    b = int(rng.binomial(ell, rho))
    if b == 0:
        return torch.tensor(list(S_set), dtype=torch.long)

    U = rng.choice(list(S_set), size=min(b, len(S_set)), replace=False).astype(np.int64)
    perm_np = perm.detach().cpu().numpy()
    for i in U:
        source = int(i)
        target = int(perm_np[source])
        if target in S_set:
            continue  # edge already present in the context
        # Never evict the source itself, otherwise the planted edge is lost.
        pool = list(S_set - {source})
        if not pool:
            continue
        evicted = int(rng.choice(pool))
        S_set.remove(evicted)
        S_set.add(target)
    return torch.tensor(list(S_set), dtype=torch.long)

def sample_context(rng: np.random.Generator,
                   perm: torch.Tensor,
                   rho: float,
                   length_choices: List[int]) -> torch.Tensor:
    """Pick a context length uniformly from `length_choices`, then sample a context of that length."""
    chosen_len = int(rng.choice(length_choices))
    context = sample_context_with_pos_rate(rng, perm, chosen_len, rho)
    return context

def make_context_list(n: int,
                      seed: int,
                      salt: int,
                      perm: torch.Tensor,
                      rho: float,
                      length_choices: List[int]) -> List[torch.Tensor]:
    """Draw `n` contexts from one RNG stream derived from (seed, salt)."""
    stream = rng_from(seed, salt)
    contexts: List[torch.Tensor] = []
    for _ in range(n):
        contexts.append(sample_context(stream, perm, rho, length_choices))
    return contexts

def make_data(seed: int, rho: float) -> DataBundle:
    """Build the graph, embeddings and the fixed validation/test context sets.

    Validation and test contexts come from separate RNG streams (salts 200 and
    300); both use TEST_LENGTH_CHOICES for their lengths.
    """
    X, perm = build_graph_and_embeddings(seed)
    val_set = make_context_list(VAL_CONTEXTS, seed, 200, perm, rho, TEST_LENGTH_CHOICES)
    test_set = make_context_list(TEST_CONTEXTS, seed, 300, perm, rho, TEST_LENGTH_CHOICES)
    return DataBundle(
        X=X,
        perm=perm,
        contexts_val=val_set,
        contexts_test=test_set,
    )

def build_labels_for_context(ctx: torch.Tensor, perm: torch.Tensor, device) -> torch.Tensor:
    """Return the [ell, ell] float32 label matrix of a context.

    Entry (i, j) is 1.0 exactly when the permutation target of ctx[i] equals
    ctx[j], and 0.0 otherwise. The result is placed on `device`.
    """
    successors = perm[ctx]                                  # [ell]
    is_edge = successors.unsqueeze(1).eq(ctx.unsqueeze(0))  # [ell, ell] boolean
    return is_edge.to(device, dtype=torch.float32)

# =========================
# Model (idealized SA)
# =========================
class IdealizedSelfAttention(nn.Module):
    """Multi-head bilinear scorer with a single learned threshold.

    Learnable parameters are the query/key projections W_Q and W_K (each
    [d_model, D_K]) and the scalar threshold tau. D_K is split evenly into
    `h` heads of width d_k = D_K // h. `alpha` is stored for callers; it is not
    used inside this class.
    """

    def __init__(self, d_model: int, D_K: int, h: int, alpha: float = 10.0):
        super().__init__()
        assert not D_K % h, "D_K must be divisible by h"
        self.d_model, self.D_K, self.h = d_model, D_K, h
        self.d_k = D_K // h
        self.alpha = alpha
        # Draw W_Q before W_K so the global RNG is consumed in a fixed order.
        init_scale = 1.0 / math.sqrt(d_model)
        self.W_Q = nn.Parameter(torch.randn(d_model, D_K) * init_scale)
        self.W_K = nn.Parameter(torch.randn(d_model, D_K) * init_scale)
        self.tau = nn.Parameter(torch.tensor(0.0))

    @torch.no_grad()
    def save_checkpoint(self, path: str, meta: Dict):
        """Write CPU copies of W_Q, W_K, tau (as a float) and `meta` to `path` via torch.save."""
        payload = {}
        payload["W_Q"] = self.W_Q.detach().cpu()
        payload["W_K"] = self.W_K.detach().cpu()
        payload["tau"] = float(self.tau.detach().cpu())
        payload["meta"] = meta
        torch.save(payload, path)

    def scores_Smax(self, X_ctx: torch.Tensor) -> torch.Tensor:
        """Return the [ell, ell] matrix of pairwise scores, maximized over heads.

        X_ctx holds one embedding per context position ([ell, d_model]). Every
        head scores each (query i, key j) pair with a dot product of its
        projections; the element-wise maximum over the heads is returned.
        """
        n_tokens = X_ctx.shape[0]
        per_head = (n_tokens, self.h, self.d_k)
        queries = (X_ctx @ self.W_Q).view(*per_head).transpose(0, 1)  # [h, ell, d_k]
        keys = (X_ctx @ self.W_K).view(*per_head).transpose(0, 1)     # [h, ell, d_k]
        head_scores = torch.matmul(queries, keys.transpose(-2, -1))   # [h, ell, ell]
        best, _ = torch.max(head_scores, dim=0)                       # [ell, ell]
        return best

# =========================
# Loss & metrics
# =========================
def weighted_bce_logits(z: torch.Tensor, y: torch.Tensor, pos_weight_scalar: float) -> torch.Tensor:
    """Mean binary cross-entropy on logits `z`, with positives up-weighted.

    Each element contributes softplus(-z) * y * pos_weight_scalar for the
    positive part and softplus(z) * (1 - y) for the negative part; the mean is
    taken over all elements.
    """
    loss_if_positive = F.softplus(-z) * y * pos_weight_scalar
    loss_if_negative = F.softplus(z) * (1.0 - y)
    per_element = loss_if_positive + loss_if_negative
    return per_element.mean()

@torch.no_grad()
def evaluate_model(model: IdealizedSelfAttention,
                   X: torch.Tensor,
                   perm: torch.Tensor,
                   contexts: List[torch.Tensor],
                   device,
                   compute_hist: bool = False) -> Tuple[float, float, Optional[np.ndarray], Optional[np.ndarray]]:
    """Evaluate thresholded edge predictions over a list of contexts.

    A pair (i, j) is predicted as an edge when S_max[i, j] > tau. Counts are
    pooled over all contexts (micro averaging).

    The model is switched to eval mode for the duration of the call and its
    previous train/eval mode is restored afterwards, so calling this from a
    training loop does not change the model's mode.

    Args:
        model: scorer providing `tau` and `scores_Smax`.
        X: node embeddings, indexed by node id.
        perm: permutation defining the true edges.
        contexts: node-id tensors, one per context.
        device: device on which scores and labels are computed.
        compute_hist: when True, also return the raw score samples.

    Returns:
        (micro_f1, margin, pos_scores, hard_neg_scores)
        - micro_f1: 2*TP / (2*TP + FP + FN), or 0.0 if the denominator is 0.
        - margin: mean positive score minus mean hard-negative score, NaN if
          either set is empty. Hard negatives are, per row, the up-to-5 largest
          finite scores among non-edge entries.
        - pos_scores / hard_neg_scores: NumPy arrays when `compute_hist` is
          True, otherwise None.
    """
    was_training = model.training
    model.eval()
    try:
        tau = float(model.tau.detach().cpu().item())
        TP = FP = FN = 0
        pos_scores_all = []
        hard_neg_scores_all = []
        neg_inf = -float('inf')

        for ctx in contexts:
            ctx = ctx.to(device)
            X_ctx = X[ctx]
            y = build_labels_for_context(ctx, perm, device)
            S_max = model.scores_Smax(X_ctx)
            preds = (S_max > tau).float()

            TP += (preds * y).sum().item()
            FP += (preds * (1.0 - y)).sum().item()
            FN += ((1.0 - preds) * y).sum().item()

            pos_mask = (y == 1.0)
            pos_scores = S_max[pos_mask]
            if pos_scores.numel() > 0:
                pos_scores_all.append(pos_scores.detach().cpu())

            # Hard negatives for all rows at once: mask out edges and non-finite
            # scores, take the row-wise top-k, then drop the masked fillers.
            # Row-major boolean indexing keeps the same ordering as a per-row loop.
            n_cols = S_max.shape[1]
            if n_cols > 0:
                excluded = pos_mask | ~torch.isfinite(S_max)
                neg_matrix = S_max.masked_fill(excluded, neg_inf)
                top_vals = torch.topk(neg_matrix, k=min(5, n_cols), dim=1).values
                kept = top_vals[torch.isfinite(top_vals)]
                if kept.numel() > 0:
                    hard_neg_scores_all.append(kept.detach().cpu())
    finally:
        # Hand the model back in the mode the caller had it in.
        model.train(was_training)

    denom = (2 * TP + FP + FN)
    micro_f1 = (2 * TP / denom) if denom > 0 else 0.0

    if not pos_scores_all or not hard_neg_scores_all:
        margin = float('nan')
    else:
        pos_concat = torch.cat(pos_scores_all)
        hard_neg_concat = torch.cat(hard_neg_scores_all)
        margin = float(pos_concat.mean().item() - hard_neg_concat.mean().item())

    if not compute_hist:
        return micro_f1, margin, None, None

    pos_np = torch.cat(pos_scores_all).numpy() if pos_scores_all else np.array([])
    neg_np = torch.cat(hard_neg_scores_all).numpy() if hard_neg_scores_all else np.array([])
    return micro_f1, margin, pos_np, neg_np

# =========================
# Optimizer (determinism-friendly)
# =========================
def make_adamw(params, lr: float, weight_decay: float):
    """Create an AdamW optimizer, turning off fused/foreach kernels where supported.

    The keyword sets are tried from most to least specific; a TypeError (the
    installed torch does not know a keyword) moves on to the next one, ending
    with a plain AdamW.
    """
    base_kwargs = dict(lr=lr, weight_decay=weight_decay)
    # Prefer disabling fused/foreach paths to avoid nondeterministic atomics on some builds
    preferred_extras = (
        dict(fused=False, foreach=False),
        dict(foreach=False),
    )
    for extras in preferred_extras:
        try:
            return torch.optim.AdamW(params, **base_kwargs, **extras)
        except TypeError:
            continue
    return torch.optim.AdamW(params, **base_kwargs)

# =========================
# Training
# =========================
# Outcome of training and testing one (D_K, h) configuration.
@dataclass
class TrainResult:
    DK: int                          # total query/key width D_K
    h: int                           # number of heads
    d_k: int                         # per-head width, D_K // h
    passed: bool                     # True when test micro-F1 >= target_f1
    test_micro_f1: float             # micro-F1 on the test contexts
    margin: float                    # mean positive score minus mean hard-negative score (may be NaN)
    tau: float                       # learned threshold after training
    steps_trained: int               # step at which training stopped (set to the same value as stop_step)
    first_pass_step: Optional[int]   # first validation step that reached target_f1, None if never
    stop_step: int                   # early-stop step, or max_steps if patience was never met
    ckpt_path: str                   # checkpoint path; train_one_configuration stores "NA"

# === pass h_val explicitly and use deterministic-friendly AdamW ===
def train_one_configuration(DK: int,
                            h_val: int,
                            bundle: DataBundle,
                            device,
                            outdir: str,
                            rho: float) -> TrainResult:
    """Train one (D_K, h) model, test it, and save a score histogram.

    Training draws one fresh context per step from an RNG stream keyed by the
    global SEED and the configuration, and minimizes a weighted BCE on
    alpha * (S_max - tau). Every `check_every` steps the model is validated;
    training stops once `patience` consecutive validations reach `target_f1`,
    or after `max_steps`.

    Args:
        DK: total query/key width (must be divisible by `h_val`).
        h_val: number of heads.
        bundle: embeddings, permutation and fixed val/test contexts; X and perm
            are indexed with tensors on `device`.
        device: device for the model and the per-step tensors.
        outdir: directory for the histogram PNG (created if needed).
        rho: positive rate used when sampling training contexts.

    Returns:
        TrainResult with test metrics and stopping information. No checkpoint
        is written; `ckpt_path` is "NA".
    """
    d_k = DK // h_val
    model = IdealizedSelfAttention(d_model=d_model, D_K=DK, h=h_val, alpha=alpha).to(device)
    opt = make_adamw(model.parameters(), lr=lr, weight_decay=weight_decay)

    # Per-configuration stream of training contexts.
    train_rng = rng_from(SEED, 500 + DK + 13*h_val)

    consecutive_passes = 0
    first_pass_step = None
    stop_step = max_steps

    for step in tqdm(range(1, max_steps + 1), desc=f"Training h={h_val}, D_K={DK}", leave=False):
        ctx = sample_context(train_rng, bundle.perm, rho, TRAIN_LENGTH_CHOICES).to(device)
        X_ctx = bundle.X[ctx]
        y = build_labels_for_context(ctx, bundle.perm, device)
        ell = X_ctx.shape[0]

        S_max = model.scores_Smax(X_ctx)
        z = alpha * (S_max - model.tau)

        # Each row has at most one positive among ell entries, so positives are
        # up-weighted by (ell - 1) to balance them against the negatives.
        pos_w = float(ell - 1)
        loss = weighted_bce_logits(z, y, pos_w)

        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()

        if step % check_every == 0:
            val_f1, _, _, _ = evaluate_model(model, bundle.X, bundle.perm,
                                             bundle.contexts_val, device, compute_hist=False)
            passed_now = (val_f1 >= target_f1)
            consecutive_passes = consecutive_passes + 1 if passed_now else 0
            if passed_now and first_pass_step is None:
                first_pass_step = step
            if consecutive_passes >= patience:
                stop_step = step
                break

    test_f1, margin, pos_hist, neg_hist = evaluate_model(
        model, bundle.X, bundle.perm, bundle.contexts_test, device, compute_hist=True
    )
    passed = (test_f1 >= target_f1)

    tag = f"DK{DK}_h{h_val}_dk{d_k}"

    # Histogram of positive vs hard-negative scores. Plotting is best-effort:
    # a failure is reported but never aborts the run, and the figure is always
    # closed so that failures cannot pile up open figures during a long sweep.
    have_scores = (pos_hist is not None and neg_hist is not None
                   and pos_hist.size > 0 and neg_hist.size > 0)
    if have_scores:
        fig = None
        try:
            fig = plt.figure(figsize=(6,4))
            plt.hist(pos_hist, bins=60, alpha=0.6, label="S+ (positives)", density=True)
            plt.hist(neg_hist, bins=60, alpha=0.6, label="Top-5 S- (hard negatives)", density=True)
            plt.xlabel("Score")
            plt.ylabel("Density")
            plt.title(f"Score separation — {tag} (test)")
            plt.legend()
            plt.tight_layout()
            ensure_dir(outdir)
            hist_path = os.path.join(outdir, f"score_hist_{tag}.png")
            plt.savefig(hist_path, dpi=150)
        except Exception as e:
            print("Histogram plotting skipped due to:", repr(e))
        finally:
            if fig is not None:
                plt.close(fig)

    tau_val = float(model.tau.detach().cpu().item())
    return TrainResult(
        DK=DK, h=h_val, d_k=d_k, passed=passed, test_micro_f1=float(test_f1),
        margin=float(margin), tau=tau_val, steps_trained=stop_step,
        first_pass_step=first_pass_step, stop_step=stop_step, ckpt_path="NA"
    )

# =========================
# Original single-experiment (kept for convenience)
# =========================
def run_experiment():
    """
    Single experiment using current globals: SEED, m, d_model, DK_TRIALS, h, etc.

    Trains one model per entry of DK_TRIALS, writes results.csv / results.json
    into a fresh output directory, displays the table and reports the smallest
    D_K whose test micro-F1 reached target_f1. When SAVE_TO_DRIVE is False the
    output directory is additionally zipped under /content and a Colab download
    is attempted.

    Returns:
        (df, outdir, smallest_passing): the results DataFrame, the output
        directory and the smallest passing D_K (None if nothing passed).
    """
    set_seed(SEED)
    device = device_str()
    print(f"Device: {device}")
    if PRINT_REPRO_REPORT:
        print_repro_report()

    outdir = get_outdir(SAVE_TO_DRIVE, target_pos_rate)
    ensure_dir(outdir)
    print("Output directory:", outdir)

    bundle = make_data(SEED, target_pos_rate)
    # Embeddings and permutation live on the training device; contexts stay on CPU.
    bundle = DataBundle(X=bundle.X.to(device),
                        perm=bundle.perm.to(device),
                        contexts_val=bundle.contexts_val,
                        contexts_test=bundle.contexts_test)

    results: List[TrainResult] = []

    for DK in DK_TRIALS:
        print("="*80)
        print(f"Training configuration: D_K={DK} (h={h} ⇒ d_k={DK//h}), ρ={target_pos_rate}")
        tr = train_one_configuration(DK, h, bundle, device, outdir, target_pos_rate)
        results.append(tr)
        print(f"→ Test micro-F1: {tr.test_micro_f1:.6f} | margin: {tr.margin:.6f} | τ: {tr.tau:.6f}")
        print(f"   Steps trained: {tr.steps_trained} | first_pass_step: {tr.first_pass_step} | checkpoint: {tr.ckpt_path}")

    # Take the true minimum over passing widths, so the answer does not depend
    # on DK_TRIALS being listed in ascending order.
    passing_dks = [r.DK for r in results if r.passed]
    smallest_passing: Optional[int] = min(passing_dks) if passing_dks else None

    rows = [
        {
            "D_K": r.DK,
            "h": r.h,
            "d_k": r.d_k,
            "rho": target_pos_rate,
            "passed": r.passed,
            "microF1_test": r.test_micro_f1,
            "margin": r.margin,
            "tau": r.tau,
            "steps_trained": r.steps_trained,
            "first_pass_step": r.first_pass_step if r.first_pass_step is not None else "NA",
            "stop_step": r.stop_step,
            "checkpoint": r.ckpt_path
        }
        for r in results
    ]
    df = pd.DataFrame(rows).sort_values(["h","D_K"]).reset_index(drop=True)
    csv_path = os.path.join(outdir, "results.csv")
    json_path = os.path.join(outdir, "results.json")
    df.to_csv(csv_path, index=False)
    with open(json_path, "w") as f:
        json.dump(rows, f, indent=2)

    print("\nSaved results:")
    print(" CSV:", csv_path)
    print(" JSON:", json_path)

    print("\nResults table:")
    display(df)

    if smallest_passing is None:
        # Report the configured threshold rather than a hard-coded number.
        print(f"\nOutcome: No configuration reached micro‑F1 ≥ {target_f1} on the test set.")
    else:
        print(f"\nEmpirical hat D_K*: smallest passing D_K = {smallest_passing}")

    if not SAVE_TO_DRIVE:
        zip_path = os.path.join("/content", "perm_graph_results.zip")
        print("\nCreating zip for download at:", zip_path)
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
            for root, _, files in os.walk(outdir):
                for fn in files:
                    abspath = os.path.join(root, fn)
                    # Keep the output directory name as the top-level folder in the archive.
                    relpath = os.path.relpath(abspath, os.path.dirname(outdir))
                    zf.write(abspath, arcname=relpath)
        try:
            from google.colab import files  # type: ignore
            files.download(zip_path)
        except Exception as e:
            print("Automatic download not available; you can manually fetch:", zip_path, "| Reason:", repr(e))

    return df, outdir, smallest_passing

# =========================
# Sweep helpers & driver 
# =========================
def _dk_candidates_for(m_val: int,
                       d_val: int,
                       h_val: int,
                       ratios: List[float],
                       base_dks: List[int],
                       max_dk: Optional[int]) -> List[int]:
    """Return the sorted list of D_K values to try for one (m, d_model, h).

    Each ratio is multiplied by the reference width m*log2(m)/d_model, rounded
    up to an integer >= 1 and then up to the next multiple of `h_val`. When the
    module flag INCLUDE_BASE_DKS is set, entries of `base_dks` divisible by
    `h_val` are added too. Values below `h_val` or above `max_dk` (if given)
    are discarded; if nothing survives, `[h_val]` is returned.
    """
    thr_l2 = (m_val * max(math.log2(max(m_val, 2)), 1e-9)) / max(d_val, 1e-9)

    def _round_up_to_heads(ratio: float) -> int:
        width = max(int(math.ceil(ratio * thr_l2)), 1)
        return int(math.ceil(width / h_val) * h_val)   # enforce divisibility by h

    candidates = {_round_up_to_heads(r) for r in ratios}
    if INCLUDE_BASE_DKS and base_dks:
        candidates.update(dk for dk in base_dks if dk % h_val == 0)

    def _in_range(width: int) -> bool:
        if width < h_val:
            return False
        return max_dk is None or width <= max_dk

    kept = sorted(filter(_in_range, candidates))
    return kept if kept else [h_val]

def _mount_and_get_sweep_root():
    """Mount Drive (if configured) and create the sweep's root directory.

    get_outdir() is called only for its side effect (the Drive mount) and to
    learn which base directory is in use; the path it returns is not created.
    BASE_NAME is swapped temporarily for that call and is always restored, even
    if get_outdir() raises.

    Returns:
        (sweep_root_dir, drive_base): the created sweep root and the base
        directory it lives under ("/content/drive/MyDrive" or "/content").
    """
    original_base_name = BASE_NAME
    globals()["BASE_NAME"] = f"{SWEEP_ROOT}/_bootstrap_only"
    try:
        probe_path = get_outdir(SAVE_TO_DRIVE, target_pos_rate)
    finally:
        # Never leave the module-level name patched, whatever happened above.
        globals()["BASE_NAME"] = original_base_name

    drive_root = "/content/drive/MyDrive"
    drive_base = drive_root if probe_path.startswith(drive_root) else "/content"
    sweep_root_dir = os.path.join(drive_base, SWEEP_ROOT)
    ensure_dir(sweep_root_dir)
    print(f"\nSweep root: {sweep_root_dir}")
    return sweep_root_dir, drive_base

def _ratio_log2(m_val: int, d_val: int, DK_val: int):
    """Return (DK_val / threshold, threshold) with threshold = m*log2(m)/d_model.

    m is clamped to at least 2 inside the logarithm, and both divisors are
    clamped to at least 1e-9 to avoid division by zero.
    """
    log_term = math.log2(max(m_val, 2))
    thr_l2 = (m_val * log_term) / max(d_val, 1e-9)
    ratio = DK_val / max(thr_l2, 1e-9)
    return ratio, thr_l2

def _plot_block_by_h(agg_block: pd.DataFrame, m_val: int, d_val: int, out_dir: str):
    """Save the per-block figure: mean test micro-F1 against capacity ratio, one curve per h.

    Reads the columns "h", "ratio_log2", "microF1_mean" and "microF1_std" of
    `agg_block`; a missing std (NaN) is drawn as a zero-length error bar. The
    PNG is written into `out_dir`, which is created if necessary.
    """
    ensure_dir(out_dir)
    plt.figure(figsize=(7.2, 4.6))

    for heads, rows in agg_block.groupby("h"):
        ordered = rows.sort_values("ratio_log2")
        plt.errorbar(
            ordered["ratio_log2"].to_numpy(),
            ordered["microF1_mean"].to_numpy(),
            yerr=ordered["microF1_std"].fillna(0.0).to_numpy(),
            fmt="-o",
            capsize=3,
            label=f"h={heads}",
            alpha=0.95,
        )

    # Vertical guide where D_K equals the reference width.
    plt.axvline(1.0, linestyle="--", linewidth=1)
    plt.xlabel("DK / (m log2 m / d_model)")
    plt.ylabel("Mean micro‑F1 (test)")
    plt.title(f"F1 vs Capacity Ratio — m={m_val}, d_model={d_val}")
    plt.legend(title="Heads")
    plt.tight_layout()

    target_png = os.path.join(out_dir, f"f1_vs_ratio_m{m_val}_d{d_val}_byH.png")
    plt.savefig(target_png, dpi=150)
    plt.close()
    print("Saved block plot:", target_png)

def run_sweep():
    """
    Joint sweep over M_SWEEP × DMODEL_SWEEP × H_SWEEP with REPEATS.
    For each (m, d_model), run all heads & DKs, then immediately save
    the per-block plot to Drive. Finally, save global summaries and
    a heatmap of min DK over h for each (m, d_model).

    The sweep has to reassign the module globals m, d_model and SEED because
    the data-generation and training helpers read them. Their original values
    are restored when the sweep ends, including when it ends with an exception,
    so later code in the same session sees the configured values again.
    """
    global m, d_model, SEED
    saved_globals = (m, d_model, SEED)
    try:
        _run_sweep_impl()
    finally:
        m, d_model, SEED = saved_globals


def _run_sweep_impl():
    """Body of run_sweep(); mutates the globals m, d_model and SEED while it runs."""
    global m, d_model, SEED

    set_seed(SEED)
    device = device_str()
    print(f"Device: {device}")
    if PRINT_REPRO_REPORT:
        print_repro_report()

    sweep_root_dir, _ = _mount_and_get_sweep_root()

    all_rows: List[pd.DataFrame] = []
    block_min_rows = []  # for min-DK heatmap (over h)

    DK_TRIALS_ORIG = list(DK_TRIALS)
    SEED_BASE = SEED

    for m_val in M_SWEEP:
        for d_val in DMODEL_SWEEP:
            print(f"\n=== Block: m={m_val}, d_model={d_val} ===")

            # Per-block directories
            block_dir = os.path.join(sweep_root_dir, f"m{m_val}_d{d_val}")
            ensure_dir(block_dir)

            block_rows = []  # accumulate all runs for this (m,d)

            # Do all repeats (so we can compute means/std err bars)
            for rep in range(1, REPEATS + 1):
                # Set the globals that data generation and training read.
                m = m_val
                d_model = d_val
                SEED = int(SEED_BASE + rep*997 + m_val*13 + d_val*23)

                # Build data once per replicate so all heads/DKs share it (fair comparison)
                bundle = make_data(SEED, target_pos_rate)
                bundle = DataBundle(X=bundle.X.to(device),
                                    perm=bundle.perm.to(device),
                                    contexts_val=bundle.contexts_val,
                                    contexts_test=bundle.contexts_test)

                for h_val in H_SWEEP:
                    dk_list = _dk_candidates_for(
                        m_val=m_val, d_val=d_val, h_val=h_val,
                        ratios=DK_RATIOS, base_dks=DK_TRIALS_ORIG, max_dk=MAX_DK
                    )
                    print(f"  • run{rep:02d} | h={h_val} | DK candidates={dk_list}")

                    run_outdir = os.path.join(block_dir, f"run{rep:02d}", f"h{h_val}")
                    ensure_dir(run_outdir)

                    for DK in dk_list:
                        tr = train_one_configuration(DK, h_val, bundle, device, run_outdir, target_pos_rate)
                        block_rows.append({
                            "m": m_val,
                            "d_model": d_val,
                            "seed": SEED,
                            "run": rep,
                            "h": h_val,
                            "D_K": tr.DK,
                            "d_k": tr.d_k,
                            "rho": target_pos_rate,
                            "passed": tr.passed,
                            "microF1_test": tr.test_micro_f1,
                            "margin": tr.margin,
                            "tau": tr.tau,
                            "steps_trained": tr.steps_trained,
                            "first_pass_step": tr.first_pass_step if tr.first_pass_step is not None else "NA",
                            "stop_step": tr.stop_step,
                            "outdir": run_outdir
                        })

            # Aggregate per-block and save plots immediately
            df_block = pd.DataFrame(block_rows)
            if len(df_block) == 0:
                print("No data collected for this block.")
                continue

            # Numeric copy of first_pass_step: "NA" (never passed) becomes NaN.
            df_block["first_pass_step_num"] = pd.to_numeric(
                df_block["first_pass_step"], errors="coerce"
            ).astype(float)

            # Add ratio_log2 and thr for plotting
            thr_l2 = (m_val * math.log2(max(m_val, 2))) / max(d_val, 1e-9)
            df_block["thr_log2"] = thr_l2
            df_block["ratio_log2"] = df_block["D_K"] / max(thr_l2, 1e-9)

            # Per (h, D_K) aggregates
            group_cols = ["m", "d_model", "h", "D_K"]
            agg_block = df_block.groupby(group_cols).agg(
                runs=("passed", "count"),
                pass_rate=("passed", "mean"),
                microF1_mean=("microF1_test", "mean"),
                microF1_std=("microF1_test", "std"),
                margin_mean=("margin", "mean"),
                margin_std=("margin", "std"),
                steps_mean=("steps_trained", "mean"),
                steps_median=("steps_trained", "median"),
                first_pass_step_median=("first_pass_step_num", "median"),
                thr_log2=("thr_log2", "first"),
            ).reset_index()
            agg_block["ratio_log2"] = agg_block["D_K"] / agg_block["thr_log2"].replace(0, np.nan)

            # Save per-block CSVs
            block_all_csv = os.path.join(block_dir, "block_all_runs.csv")
            block_agg_csv = os.path.join(block_dir, "block_grouped_by_h_DK.csv")
            df_block.to_csv(block_all_csv, index=False)
            agg_block.to_csv(block_agg_csv, index=False)
            print("Saved block CSVs:", block_all_csv, "and", block_agg_csv)

            # Plot multi-line (one line per h) and save now
            _plot_block_by_h(agg_block, m_val, d_val, block_dir)

            # Compute minimum DK over all heads (exists some h with pass_rate == 1.0 / 0.8 / 0.5)
            mins = {"m": m_val, "d_model": d_val, "thr_log2": thr_l2}
            for cut, label in [(1.0, "100"), (0.80, "80"), (0.50, "50")]:
                cand = agg_block.loc[agg_block["pass_rate"] >= cut, ["D_K", "h"]].sort_values("D_K")
                mins[f"DK_star_anyh_pass{label}"] = (cand["D_K"].iloc[0] if len(cand) else np.nan)
                mins[f"h_for_DK_star_pass{label}"] = (cand["h"].iloc[0] if len(cand) else np.nan)
            block_min_rows.append(mins)

            # Keep for global concat
            all_rows.append(df_block)

    # === Global summaries ===
    if len(all_rows) == 0:
        print("No rows collected; sweep did not run.")
        return

    df_all = pd.concat(all_rows, ignore_index=True)

    # Global grouped with h included
    group_cols_global = ["m", "d_model", "h", "D_K"]
    agg_global = df_all.groupby(group_cols_global).agg(
        runs=("passed", "count"),
        pass_rate=("passed", "mean"),
        microF1_mean=("microF1_test", "mean"),
        microF1_std=("microF1_test", "std"),
        margin_mean=("margin", "mean"),
        margin_std=("margin", "std"),
        steps_mean=("steps_trained", "mean"),
        steps_median=("steps_trained", "median"),
        thr_log2=("thr_log2", "first")
    ).reset_index()
    agg_global["ratio_log2"] = agg_global["D_K"] / agg_global["thr_log2"].replace(0, np.nan)

    # Minimum DK per (m,d) over all heads, also expressed as a capacity ratio
    df_min_anyh = pd.DataFrame(block_min_rows)
    df_min_anyh["DK_star_anyh_ratio_log2_pass100"] = df_min_anyh["DK_star_anyh_pass100"] / df_min_anyh["thr_log2"]
    df_min_anyh["DK_star_anyh_ratio_log2_pass80"]  = df_min_anyh["DK_star_anyh_pass80"]  / df_min_anyh["thr_log2"]
    df_min_anyh["DK_star_anyh_ratio_log2_pass50"]  = df_min_anyh["DK_star_anyh_pass50"]  / df_min_anyh["thr_log2"]

    # Save global artifacts at sweep root
    all_csv = os.path.join(sweep_root_dir, "sweep_all_runs.csv")
    agg_csv = os.path.join(sweep_root_dir, "sweep_grouped_by_m_d_h_DK.csv")
    min_anyh_csv = os.path.join(sweep_root_dir, "sweep_min_DK_over_heads_per_m_d.csv")
    df_all.to_csv(all_csv, index=False)
    agg_global.to_csv(agg_csv, index=False)
    df_min_anyh.to_csv(min_anyh_csv, index=False)
    print("\nSweep saved:")
    print(" • All runs:", all_csv)
    print(" • Grouped per (m, d_model, h, D_K):", agg_csv)
    print(" • Min DK over heads per (m, d_model):", min_anyh_csv)

    # Optional combined overlay (all (m,d,h) curves) — can get busy.
    # Best-effort: failures are reported, and the figure is always closed.
    fig = None
    try:
        fig = plt.figure(figsize=(8.0,5.2))
        for (mm, dd, hh), sub in agg_global.groupby(["m", "d_model", "h"]):
            sub_sorted = sub.sort_values("ratio_log2")
            x = sub_sorted["ratio_log2"].to_numpy()
            y = sub_sorted["microF1_mean"].to_numpy()
            yerr = sub_sorted["microF1_std"].fillna(0.0).to_numpy()
            plt.errorbar(x, y, yerr=yerr, fmt="-o", capsize=3, alpha=0.80, label=f"m={mm}, d={dd}, h={hh}")
        plt.axvline(1.0, linestyle="--", linewidth=1)
        plt.xlabel("DK / (m log2 m / d_model)")
        plt.ylabel("Mean micro‑F1 (test)")
        plt.title("F1 vs Capacity Ratio — ALL (m, d_model, h)")
        plt.legend(ncol=2, fontsize=8)
        plt.tight_layout()
        fig_path_all = os.path.join(sweep_root_dir, "f1_vs_ratio_ALL_byH.png")
        plt.savefig(fig_path_all, dpi=150)
        print(" • Combined plot:", fig_path_all)
    except Exception as e:
        print("Combined plotting skipped due to:", repr(e))
    finally:
        if fig is not None:
            plt.close(fig)

    # === Plot heatmap of min DK over heads per (m, d_model) ===
    fig = None
    try:
        # Use the pass100 criterion here; the pass80/pass50 variants are in the CSV.
        pivot = df_min_anyh.pivot(index="m", columns="d_model", values="DK_star_anyh_pass100")
        fig = plt.figure(figsize=(1.2 + 0.8*pivot.shape[1], 1.2 + 0.6*pivot.shape[0]))
        im = plt.imshow(pivot, aspect="auto", interpolation="nearest")
        plt.colorbar(im, label="Min DK (exists some h, pass rate = 100%)")
        plt.xticks(ticks=np.arange(pivot.shape[1]), labels=list(pivot.columns))
        plt.yticks(ticks=np.arange(pivot.shape[0]), labels=list(pivot.index))
        plt.xlabel("d_model")
        plt.ylabel("m")
        plt.title("Minimum DK over heads per (m, d_model)")
        # annotate cells ("—" marks blocks where no configuration passed)
        for i in range(pivot.shape[0]):
            for j in range(pivot.shape[1]):
                val = pivot.iloc[i, j]
                txt = "—" if (pd.isna(val)) else str(int(val))
                plt.text(j, i, txt, ha="center", va="center", fontsize=9)
        plt.tight_layout()
        heatmap_path = os.path.join(sweep_root_dir, "min_DK_over_heads_heatmap.png")
        plt.savefig(heatmap_path, dpi=150)
        print(" • Min DK heatmap:", heatmap_path)
    except Exception as e:
        print("Heatmap plotting skipped due to:", repr(e))
    finally:
        if fig is not None:
            plt.close(fig)

# Execute
if __name__ == "__main__":
    # RUN_SWEEP selects between the full sweep and the single experiment.
    entry_point = run_sweep if RUN_SWEEP else run_experiment
    entry_point()







#@title Aggregate permutation-graph sweep CSVs into a "min DK over heads" heatmap + line plot (log y with powers of 2)
#@markdown **Instructions:** Set `TARGET_DIR` to the folder in Google Drive that contains your *_grouped_by_m_d_h_DK*.csv summary files.
#@markdown <br/>By default we treat "minimum DK" as the smallest DK for which `microF1_mean >= 0.9899`. If you intended the opposite, change `COMPARE_OP` to `"<"`.
#@markdown <br/>**New:** Line plot uses a log y-axis with **major ticks at powers of 2 only**, labeled as \(2^k\). No fractional ticks.
#@markdown <br/>**New:** Error-bar toggle on per-(m,d_model) plots — set `ERRORBAR_MODE` to `"std"` (±1 SD) or `"ci95"` (95% CI for the mean).
#@markdown <br/>**This version enforces a single selection policy everywhere:**
#@markdown <br/>&nbsp;&nbsp;&nbsp;&nbsp;• For each (m, d_model), pick the **min total D_K** that meets the F1 criterion;
#@markdown <br/>&nbsp;&nbsp;&nbsp;&nbsp;• Among rows at that min D_K, pick the **head with max F1** (ties → **more heads**);
#@markdown <br/>&nbsp;&nbsp;&nbsp;&nbsp;• Exact duplicate configs keep the **best row** (max F1, then higher pass_rate, then higher runs).
#@markdown <br/><br/> Provide exclusion lists at the top:
#@markdown <br/>&nbsp;&nbsp;&nbsp;&nbsp;• `EXCLUDE_M`: list of `m` values to exclude (e.g., `[16384, 32768]`);
#@markdown <br/>&nbsp;&nbsp;&nbsp;&nbsp;• `EXCLUDE_PAIRS`: list of `(m, d_model)` pairs to exclude (e.g., `[(16384, 2048), (32768, 4096)]`).
#@markdown <br/>All excluded values are removed **before** any computation, tables, or plots.

# === Parameters ===
# Folder that is scanned for summary CSVs; every table/figure produced below is written back into it.
TARGET_DIR = "/content/drive/MyDrive/Self Attention/Ten trial sweep/Outputs"  #@param {type:"string"}
#TARGET_DIR = "/content/drive/MyDrive/Self Attention/Outputs2"  #@param {type:"string"}
# When True the file search uses a recursive "**" glob under TARGET_DIR.
SEARCH_SUBDIRS = False                                      #@param {type:"boolean"}
# A configuration "passes" when microF1_mean <COMPARE_OP> THRESHOLD.
THRESHOLD = 0.99                                            #@param {type:"number"}
COMPARE_OP = ">="                                           #@param [">=", "<"] {allow-input: true}
#RANDOM_SEED = None                                         #@param {type:"raw"}  # (unused: dedup is deterministic by best F1)

# --- Exclusions (NEW) ---
# Both lists are normalized to sets of ints / (int, int) tuples and applied before any selection.
#EXCLUDE_M = [1024,2048,4096]            #@param {type:"raw"}  # e.g., [128, 256]
EXCLUDE_M = []            #@param {type:"raw"}  # e.g., [128, 256]
#EXCLUDE_PAIRS = [(128,16),(256,16),(512,16)]        #@param {type:"raw"}  # e.g., [(128, 16), (256, 32)]
EXCLUDE_PAIRS = []        #@param {type:"raw"}  # e.g., [(128, 16), (256, 32)]

# Error-bar mode (affects only the per-(m,d_model) F1 vs DK figures)
ERRORBAR_MODE = "ci95"  #@param ["std", "ci95"] {allow-input: true}
# "std"  -> draw ±1 sample standard deviation across runs
# "ci95" -> draw two-sided 95% CI half-width for the mean
#          uses Student's t (df = runs-1) when 'runs' is available; otherwise falls back to z≈1.96 if SE is present
# --- Runs filter ---
EXCLUDE_SINGLE_RUN = False  #@param {type:"boolean"}  # If True, drop rows with runs == 1 before any computation



# Diagonal plots configuration (m/d_model ratios to consider)
DIAGONAL_RATIOS = [8]  # e.g., 8 means m == 8 * d_model

# Selection tolerance for floating F1 ties (treat values within this as equal)
# Used when keeping the rows whose F1 is within F1_TOL of the per-(m, d_model) maximum.
F1_TOL = 1e-12

# === Imports ===
import os, glob, math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patheffects as pe
import matplotlib.ticker as mticker

# === Mount Drive (safe if already mounted) ===
# Best-effort mount: outside Colab the import fails, and that is reported rather than raised,
# so the rest of the cell can still run against a local TARGET_DIR.
try:
    from google.colab import drive  # type: ignore
    drive.mount("/content/drive", force_remount=False)
except Exception as e:
    print("Drive mount skipped or failed (running locally perhaps):", repr(e))

# === Validate directory ===
if not os.path.isdir(TARGET_DIR):
    # Second chance: interpret TARGET_DIR as a path relative to the Drive root.
    alt = os.path.join("/content/drive/MyDrive", TARGET_DIR.strip("/"))
    if not os.path.isdir(alt):
        raise FileNotFoundError(f"Directory not found: {TARGET_DIR}")
    TARGET_DIR = alt

# The directory is known to exist at this point, so this call is a no-op safeguard.
os.makedirs(TARGET_DIR, exist_ok=True)
print(f"\nReading summary CSVs from: {TARGET_DIR}")

# === Find candidate files ===
# The "**" component only recurses because recursive=True is passed alongside it.
pattern = "**/*grouped_by_*.csv" if SEARCH_SUBDIRS else "*grouped_by_*.csv"
# De-duplicate and sort so files are always processed in the same order.
paths = sorted(set(glob.glob(os.path.join(TARGET_DIR, pattern), recursive=SEARCH_SUBDIRS)))

if not paths:
    # Report the glob that was actually used (it is broader than "*grouped_by_m_d_h_DK*.csv").
    raise FileNotFoundError(f"No files matched {pattern} in the target directory: {TARGET_DIR}")

# REQUIRED_PREFIX = ["m", "d_model", "h", "D_K", "runs", "pass_rate"]
# NEED_COLUMN = "microF1_mean"

# good_files, bad_files = [], []

# # === Quick header validation ===
# for p in paths:
#     try:
#         header = pd.read_csv(p, nrows=0)
#         cols = list(header.columns)
#         if len(cols) >= 6 and cols[:6] == REQUIRED_PREFIX and NEED_COLUMN in cols:
#             good_files.append(p)
#         else:
#             bad_files.append((p, f"First 6 columns found: {cols[:6]} ; required: {REQUIRED_PREFIX} ; microF1_mean present? {NEED_COLUMN in cols}"))
#     except Exception as e:
#         bad_files.append((p, f"read_csv failed: {repr(e)}"))


REQUIRED_COLUMNS = {"m", "d_model", "h", "D_K", "microF1_mean"}  # minimum set

# Result accumulators for the validation pass. They must be created here: the only other
# initialisation in this cell lives in the commented-out legacy block above, so without
# this line the loop below fails with NameError on the first file.
good_files, bad_files = [], []

# === Quick header validation (order-independent) ===
# Only the header row is read (nrows=0); a file is accepted when it contains every
# required column, regardless of column order or extra columns.
for p in paths:
    try:
        header = pd.read_csv(p, nrows=0)
        cols = set(header.columns)
        if REQUIRED_COLUMNS.issubset(cols):
            good_files.append(p)
        else:
            bad_files.append((
                p,
                f"missing required columns; found={sorted(header.columns)}, need≥{sorted(REQUIRED_COLUMNS)}"
            ))
    except Exception as e:
        # Unreadable files are recorded with the reason instead of aborting the scan.
        bad_files.append((p, f"read_csv failed: {repr(e)}"))



# Abort early when nothing passed the header check.
if not good_files:
    raise RuntimeError("No valid CSVs found. All candidate files failed the format check.")

print(f"Found {len(good_files)} valid file(s), {len(bad_files)} skipped due to format.\n")

# List every rejected file together with the reason recorded during validation.
if bad_files:
    print("Skipped files (reason shown):")
    for b, why in bad_files:
        print(f" - {os.path.basename(b)} -> {why}")
    print()

# === Load and normalize ===
dfs = []
for p in good_files:
    try:
        df = pd.read_csv(p)
        # Enforce numeric types; unparsable cells become NaN.
        for c in ["m", "d_model", "h", "D_K", "runs", "pass_rate", "microF1_mean"]:
            if c in df.columns:
                df[c] = pd.to_numeric(df[c], errors="coerce")
        # Keep only rows with the required info. Note that 'runs' is NOT part of this
        # subset, so it may still contain NaN afterwards.
        df = df.dropna(subset=["m", "d_model", "h", "D_K", "microF1_mean"])
        dfs.append(df)
    except Exception as e:
        print("Read failed for", p, "->", repr(e))

if not dfs:
    raise RuntimeError("After reading, no usable rows were found.")

df_all = pd.concat(dfs, ignore_index=True)

# Key columns are guaranteed non-null by the dropna above, so the int cast is safe.
for c in ["m", "d_model", "h", "D_K"]:
    if c in df_all.columns:
        df_all[c] = df_all[c].astype(int)

# 'runs' is optional: it is absent from REQUIRED_COLUMNS, so concatenating files with and
# without it (or with unparsable cells) leaves NaN behind, and astype(int) would raise.
# Cast only when every value is a finite number; otherwise keep the float column so that
# NaN continues to mean "run count unknown".
if "runs" in df_all.columns:
    _runs_ok = np.isfinite(df_all["runs"].to_numpy(dtype=float))
    if _runs_ok.all():
        df_all["runs"] = df_all["runs"].astype(int)
    else:
        print(f"Note: 'runs' is missing for {int((~_runs_ok).sum())} row(s); "
              "leaving the column as float (NaN = unknown run count).")

# === Exclusion helpers (NEW) ===
def _normalize_exclude_m(val):
    """Return a set of ints for EXCLUDE_M.

    Accepted inputs:
      * None                      -> empty set
      * a single int              -> {int}
      * a string                  -> parsed as one value or a comma-separated list
                                     ("128" -> {128}, "128, 256" -> {128, 256})
      * an iterable of values     -> each item converted with int(); None items are
                                     skipped; for list/tuple items the first element is used
      * any other scalar          -> int(val) if that conversion works

    Conversion is best-effort: an item that cannot be converted stops processing and
    whatever was collected so far is returned (no exception is raised).
    """
    s = set()
    if val is None:
        return s
    if isinstance(val, (int, np.integer)):
        return {int(val)}
    if isinstance(val, str):
        # A string is a scalar (or comma-separated list), not an iterable of digits:
        # iterating "128" character by character would yield {1, 2, 8}.
        val = [tok for tok in (part.strip() for part in val.split(",")) if tok]
    try:
        for v in val:
            if v is None:
                continue
            if isinstance(v, (list, tuple)) and len(v) >= 1:
                s.add(int(v[0]))
            else:
                s.add(int(v))
    except Exception:
        # Not iterable (e.g. a float) or an unconvertible item: try val as a single scalar.
        try:
            s.add(int(val))
        except Exception:
            pass
    return s

def _normalize_exclude_pairs(val):
    """Return a set of (int, int) pairs for EXCLUDE_PAIRS.

    Each entry may be a list/tuple with at least two elements (the first two are used)
    or a string containing a comma ("128, 16"). None entries, entries with a None
    component, and entries of any other shape are ignored. Processing is best-effort:
    the first entry that fails int() conversion ends the scan and the pairs collected
    up to that point are returned.
    """
    pairs = set()
    if val is None:
        return pairs
    try:
        for entry in val:
            if entry is None:
                continue
            if isinstance(entry, str):
                if "," in entry:
                    first, second = entry.split(",")[:2]
                    pairs.add((int(first.strip()), int(second.strip())))
            elif isinstance(entry, (tuple, list)) and len(entry) >= 2:
                first, second = entry[0], entry[1]
                if first is not None and second is not None:
                    pairs.add((int(first), int(second)))
    except Exception:
        # Deliberately lenient: keep what was parsed before the failure.
        pass
    return pairs

def _apply_exclusions(df, excl_m_set, excl_pair_set):
    """Drop rows where m ∈ excl_m_set OR (m,d_model) ∈ excl_pair_set.

    Returns a tuple (filtered_copy, number_of_rows_removed). The input frame is never
    modified; when there is nothing to exclude an unfiltered copy is returned with 0.
    """
    nothing_to_drop = not excl_m_set and not excl_pair_set
    if df.empty or nothing_to_drop:
        return df.copy(), 0

    # Start from "keep everything" and OR in each active exclusion rule.
    drop = pd.Series(False, index=df.index)
    if excl_m_set:
        drop = drop | df["m"].isin(excl_m_set)
    if excl_pair_set:
        md_pairs = pd.Series(list(zip(df["m"].values, df["d_model"].values)), index=df.index)
        drop = drop | md_pairs.isin(excl_pair_set)

    return df.loc[~drop].copy(), int(drop.sum())

def _drop_single_run_rows(df: pd.DataFrame, enabled: bool):
    """
    Optionally remove rows whose 'runs' value equals 1.

    When `enabled` is False the frame is returned unchanged (as a copy). When it is
    True but there is no 'runs' column, a note is printed and nothing is removed.
    Returns (filtered_df, removed_count).
    """
    if not enabled:
        return df.copy(), 0

    if "runs" in df.columns:
        single = df["runs"].eq(1)
        return df.loc[~single].copy(), int(single.sum())

    print("Single-run exclusion enabled but 'runs' column not found; skipping.")
    return df.copy(), 0


# Normalize + apply exclusions **before** any selection/plots (NEW)
# Turn the user-facing lists into canonical sets once, up front.
EXCLUDE_M_SET = _normalize_exclude_m(EXCLUDE_M)
EXCLUDE_PAIR_SET = _normalize_exclude_pairs(EXCLUDE_PAIRS)

print(
    f"Exclusion lists active — EXCLUDE_M: {sorted(EXCLUDE_M_SET)} ; EXCLUDE_PAIRS: {sorted(EXCLUDE_PAIR_SET)}"
    if (EXCLUDE_M_SET or EXCLUDE_PAIR_SET)
    else "Exclusion lists empty — no rows will be dropped."
)

# Every later table and plot is derived from this filtered frame.
df_all, _removed_rows = _apply_exclusions(df_all, EXCLUDE_M_SET, EXCLUDE_PAIR_SET)
print(f"Rows remaining after exclusions: {len(df_all)}")

if df_all.empty:
    raise RuntimeError("All rows were excluded by EXCLUDE_M / EXCLUDE_PAIRS. Nothing to compute or plot.")

# --- Apply the single-run filter (NEW) ---
df_all, removed_single = _drop_single_run_rows(df_all, EXCLUDE_SINGLE_RUN)
print(
    f"Single-run rows filter active — removed {removed_single} row(s) with runs == 1."
    if EXCLUDE_SINGLE_RUN
    else "Single-run rows filter inactive."
)
print(f"Rows remaining after single-run filter: {len(df_all)}")

# === Build per-head F1 stats for the plotting script ===
# We aggregate over all rows with the same (m, d_model, h, D_K).
# If microF1_std and microF1_n (or runs) exist, we compute a pooled std; otherwise we fall back
# to sample std across rows of microF1_mean (NaN if only one row).

def _build_per_head(df: pd.DataFrame) -> pd.DataFrame:
    """Collapse rows sharing (m, d_model, h, D_K) into one F1 summary row each.

    For every configuration the result holds:
      * F1_mean - mean of `microF1_mean` weighted by the per-row sample count
      * F1_std  - combined standard deviation (see below)
      * F1_n    - total sample count (sum of per-row counts)

    The per-row sample count comes from `microF1_n` if present, else `runs`, else 1;
    missing counts are treated as 1 and counts are floored at 1.

    If the group has a `microF1_std` column with at least one non-null value, F1_std is
    the standard deviation of the pooled samples, rebuilt from the within-row variances
    (missing stds count as 0) plus the spread of the row means, divided by N - 1.
    Otherwise it is the sample std (ddof=1) of the row means, NaN for a single row.

    Groups are iterated explicitly instead of via `groupby(...).apply`, so the result
    does not depend on how pandas treats the grouping columns inside `apply`, and an
    empty input yields an empty frame that still has the expected columns.

    Returns a DataFrame with columns
    ['m', 'd_model', 'h', 'D_K', 'F1_mean', 'F1_std', 'F1_n'], sorted by the key columns.
    """
    key_cols = ["m", "d_model", "h", "D_K"]
    out_cols = key_cols + ["F1_mean", "F1_std", "F1_n"]

    # Identify the count column to use, if any.
    if "microF1_n" in df.columns:
        count_col = "microF1_n"
    elif "runs" in df.columns:
        count_col = "runs"
    else:
        count_col = None

    def _combine(group: pd.DataFrame) -> dict:
        mu = pd.to_numeric(group["microF1_mean"], errors="coerce").astype(float)
        # n_i for each row (>= 1)
        if count_col is not None:
            n_i = pd.to_numeric(group[count_col], errors="coerce").fillna(1).astype(int).clip(lower=1)
        else:
            n_i = pd.Series(1, index=group.index)

        N = int(n_i.sum()) if len(n_i) else 0
        mu_bar = float((mu * n_i).sum() / N) if N > 0 else float("nan")

        if "microF1_std" in group.columns and group["microF1_std"].notna().any():
            s = pd.to_numeric(group["microF1_std"], errors="coerce").fillna(0.0).astype(float)
            # Variance of the pooled samples = (within-row SS + between-row SS) / (N - 1).
            within = ((n_i - 1).clip(lower=0) * (s ** 2)).sum()
            between = (n_i * (mu - mu_bar) ** 2).sum()
            denom = max(N - 1, 1)  # avoid dividing by zero when N == 1
            var = float((within + between) / denom)
            std = float(np.sqrt(max(var, 0.0)))
        else:
            std = float(mu.std(ddof=1)) if len(mu) > 1 else float("nan")

        return {
            "F1_mean": mu_bar,
            "F1_std": std,
            "F1_n": int(N if N > 0 else len(group)),
        }

    records = []
    for key_vals, group in df.groupby(key_cols, sort=True):
        record = dict(zip(key_cols, key_vals))
        record.update(_combine(group))
        records.append(record)

    # Passing `columns` keeps the schema intact even when there are no records.
    per_head = pd.DataFrame(records, columns=out_cols)

    # enforce dtypes expected by the plotting script
    per_head["m"] = per_head["m"].astype(int)
    per_head["d_model"] = per_head["d_model"].astype(int)
    per_head["h"] = per_head["h"].astype(int)
    per_head["D_K"] = pd.to_numeric(per_head["D_K"], errors="coerce").astype(float)
    per_head["F1_mean"] = pd.to_numeric(per_head["F1_mean"], errors="coerce").astype(float)
    per_head["F1_std"] = pd.to_numeric(per_head["F1_std"], errors="coerce").astype(float)
    per_head["F1_n"] = pd.to_numeric(per_head["F1_n"], errors="coerce").astype(int)

    return per_head

# Guard against an empty table *before* aggregating: the single-run filter above may
# have removed every remaining row, and there is no point in aggregating or writing
# a stats file for zero rows.
if df_all.empty:
    raise RuntimeError("All rows were excluded by EXCLUDE_M / EXCLUDE_PAIRS / single-run filter. Nothing to compute or plot.")

per_head = _build_per_head(df_all)
# Optional: save for inspection
per_head_csv = os.path.join(TARGET_DIR, "per_head_stats.csv")
per_head.to_csv(per_head_csv, index=False)
print(f"Built per_head stats with {len(per_head)} rows → {per_head_csv}")



# === Deduplicate exact configs by BEST row: (m, d_model, h, D_K) ===
# Keep the single row with:
#   1) highest microF1_mean
#   2) then highest pass_rate (if present)
#   3) then largest runs (if present)
def _pick_best(group: pd.DataFrame) -> pd.DataFrame:
    """Return the single best row of `group` as a one-row DataFrame.

    Rows are ranked by, in order: highest `microF1_mean`, then highest `pass_rate`
    (if the column exists), then largest `runs` (if the column exists). Rows that tie
    on every ranking column keep their original relative order, so the first such row
    in `group` wins.
    """
    sort_cols = ["microF1_mean"] + [c for c in ("pass_rate", "runs") if c in group.columns]
    # sort_values defaults to quicksort, which is not a stable sort when a single column
    # is used; request mergesort explicitly so that ties are resolved deterministically.
    ranked = group.sort_values(sort_cols, ascending=[False] * len(sort_cols), kind="mergesort")
    return ranked.iloc[[0]]

keys = ["m", "d_model", "h", "D_K"]
# Iterate the groups explicitly instead of groupby(...).apply(_pick_best): each group
# obtained by iteration is a plain slice of df_all that still carries the key columns,
# so the result does not depend on how `apply` treats grouping columns (behaviour that
# recent pandas releases deprecate).
unique_df = pd.concat(
    [_pick_best(group) for _, group in df_all.groupby(keys, sort=True)],
    ignore_index=True,
)

print(f"Loaded rows (post-exclusion): {len(df_all)} ; Unique configs retained (best per exact config): {len(unique_df)}")

# === Save aggregated (best one per (m, d_model, h, D_K)) ===
aggregated_csv = os.path.join(TARGET_DIR, "aggregated_grouped_by_m_d_h_DK.csv")
(
    unique_df
    .sort_values(["m", "d_model", "h", "D_K"])
    .to_csv(aggregated_csv, index=False)
)
# (No stdout on purpose)
# (No stdout on purpose)

# === Select rows that meet (or fail) the threshold, per COMPARE_OP ===
# Resolve the comparison once; anything other than ">=" or "<" is rejected.
_cmp = COMPARE_OP.strip()
if _cmp not in (">=", "<"):
    raise ValueError('COMPARE_OP must be ">=" or "<"')

if _cmp == ">=":
    _meets = unique_df["microF1_mean"] >= THRESHOLD
    _symbol = "≥"
else:
    _meets = unique_df["microF1_mean"] < THRESHOLD
    _symbol = "<"

ok = unique_df[_meets].copy()
criterion_desc = f"microF1_mean {_symbol} {THRESHOLD}"

if ok.empty:
    raise RuntimeError(f"No configurations satisfy the criterion ({criterion_desc}) after applying exclusions. Nothing to plot.")

# === For each (m, d_model), choose:
#     (1) the smallest DK meeting the criterion,
#     (2) among those rows at that DK, the head(s) with maximum F1,
#     (3) if still tied, prefer MORE heads (larger h).
#     This selection is used everywhere downstream.
_pair_cols = ["m", "d_model"]

# (1) Smallest passing D_K per (m, d_model), broadcast back onto the rows of `ok`.
minDK_for_pair = ok.groupby(_pair_cols)["D_K"].transform("min")
at_minDK = ok[ok["D_K"].eq(minDK_for_pair)].copy()

# (2) Of those, keep the rows whose F1 is within F1_TOL of the best F1 for the pair.
maxF1_for_pair = at_minDK.groupby(_pair_cols)["microF1_mean"].transform("max")
_is_top = at_minDK["microF1_mean"].ge(maxF1_for_pair - F1_TOL)
cands = at_minDK[_is_top].copy()

# (3) Remaining ties go to the larger head count: sort h descending, keep the first row.
_by_heads_desc = cands.sort_values(["m", "d_model", "h"], ascending=[True, True, False])
min_rows = _by_heads_desc.drop_duplicates(["m", "d_model"], keep="first").reset_index(drop=True)

# === Determine if chosen DK equals the smallest measured DK for that chosen head (yellow flag) ===
# Smallest D_K present in the data for every (m, d_model, h), whether or not it passed.
min_measured_per_head = (
    unique_df.groupby(["m", "d_model", "h"], as_index=False)["D_K"]
    .min()
    .rename(columns={"D_K": "min_measured_DK_for_head"})
)

# "yellow" marks selections whose D_K is already the smallest one measured for that head.
min_rows = min_rows.merge(min_measured_per_head, on=["m", "d_model", "h"], how="left")
min_rows["yellow"] = min_rows["D_K"].eq(min_rows["min_measured_DK_for_head"])

# === Pivot to a grid: rows=m, cols=d_model, values=DK, and a parallel boolean mask for yellow ===
# Rows are m, columns are d_model; both axes sorted ascending. `flags` mirrors `vals`.
_grid_kwargs = {"index": "m", "columns": "d_model"}
vals = min_rows.pivot(values="D_K", **_grid_kwargs).sort_index().sort_index(axis=1)
flags = min_rows.pivot(values="yellow", **_grid_kwargs).reindex_like(vals)

m_values = list(vals.index)
d_values = list(vals.columns)


# Persist the selection behind the heatmap, including the chosen head and the flag.
out_csv = os.path.join(TARGET_DIR, "min_DK_over_heads_table.csv")
_table_cols = ["m", "d_model", "h", "D_K", "microF1_mean", "yellow"]
min_rows_to_save = min_rows[_table_cols].sort_values(["m", "d_model"])
min_rows_to_save.to_csv(out_csv, index=False)
print("Saved table to:", out_csv, "| Exists?", os.path.exists(out_csv))


# ====================================================================================
# Combined heatmap: center = mean DK; upper-right = DK upper bound (CI worst-case);
# lower-left = DK lower bound (CI best-case). Replaces the two separate CI heatmaps.
# Relies on: unique_df, TARGET_DIR, COMPARE_OP, THRESHOLD, F1_TOL,
# and that the first (mean) heatmap has already produced `vals`, `m_values`, `d_values`.
# ====================================================================================

# --- Discover error-bar columns (same logic as before) ---
# Accepted spellings for each kind of dispersion column, in order of preference.
_STD_COLS = ["microF1_std", "microF1_stddev"]
_SE_COLS  = ["microF1_stderr", "microF1_se", "microF1_sem"]
_CI_COLS  = ["microF1_ci", "microF1_ci95", "microF1_CI"]


def _first_present_column(candidates):
    """Return the first name in `candidates` that is a column of unique_df, else None."""
    for name in candidates:
        if name in unique_df.columns:
            return name
    return None


STD_COL = _first_present_column(_STD_COLS)
SE_COL  = _first_present_column(_SE_COLS)
CI_COL  = _first_present_column(_CI_COLS)

# 95% two-sided t criticals for df = 1..30
_T_CRIT_95 = dict(enumerate(
    (
        12.706, 4.303, 3.182, 2.776, 2.571, 2.447, 2.365, 2.306, 2.262, 2.228,  # df 1-10
        2.201, 2.179, 2.160, 2.145, 2.131, 2.120, 2.110, 2.101, 2.093, 2.086,   # df 11-20
        2.080, 2.074, 2.069, 2.064, 2.060, 2.056, 2.052, 2.048, 2.045, 2.042,   # df 21-30
    ),
    start=1,
))


def _tcrit_95_df(df_val: int) -> float:
    """Two-sided 95% critical value for the given degrees of freedom.

    Values below 1 are treated as 1. Degrees of freedom 1..30 come from the table
    above; 31..60 use 2.000 and anything larger uses 1.960.
    """
    dof = int(max(1, df_val))
    tabulated = _T_CRIT_95.get(dof)
    if tabulated is not None:
        return tabulated
    return 2.000 if dof <= 60 else 1.960

def _sanitize_err(a):
    """Return a float copy of `a` with non-finite and negative entries replaced by NaN.

    The input is never modified: np.array(...) always allocates a new array, whereas
    np.asarray would hand back the caller's own array when it is already float and the
    in-place masking below would then overwrite the caller's data.
    """
    arr = np.array(a, dtype=float)
    with np.errstate(invalid="ignore"):  # comparing NaN with 0 is expected here
        invalid = ~np.isfinite(arr) | (arr < 0)
    arr[invalid] = np.nan
    return arr

def _ci95_halfwidth_series(df: pd.DataFrame) -> pd.Series:
    """Return per-row 95% CI half-width for the mean F1.

    Sources are tried in this order, using whichever of CI_COL / SE_COL / STD_COL was
    discovered:
      1. an existing CI column, taken as the half-width as-is;
      2. a standard-error column, scaled by the t critical value for df = runs - 1
         when 'runs' exists, otherwise by 1.96;
      3. a standard-deviation column together with 'runs': t * std / sqrt(runs).
    If none applies, an all-NaN Series is returned. Non-finite or negative half-widths
    are replaced by NaN. The result is indexed like `df`.
    """
    def _as_float(column_name):
        return df[column_name].to_numpy(dtype=float) if column_name else None

    def _t_for(counts):
        # Degrees of freedom are runs - 1, floored at 1.
        dof = np.maximum(1, counts.astype(int) - 1)
        return np.array([_tcrit_95_df(int(k)) for k in dof])

    runs = df["runs"].to_numpy(dtype=float) if ("runs" in df.columns) else None
    sd = _as_float(STD_COL)
    sem = _as_float(SE_COL)
    given_hw = _as_float(CI_COL)

    if given_hw is not None:
        half = given_hw.copy()
    elif sem is not None:
        half = (_t_for(runs) if runs is not None else 1.96) * sem
    elif (sd is not None) and (runs is not None):
        crit = _t_for(runs)
        with np.errstate(divide="ignore", invalid="ignore"):
            half = np.where(runs > 0, crit * sd / np.sqrt(runs), np.nan)
    else:
        return pd.Series(np.nan, index=df.index)

    return pd.Series(_sanitize_err(half), index=df.index)

# Compute CI half-widths and bounds
# Work on a copy so the CI helper columns never leak into unique_df.
unique_df_ci = unique_df.copy()
unique_df_ci["microF1_ci95_hw"] = _ci95_halfwidth_series(unique_df_ci)

_hw = unique_df_ci["microF1_ci95_hw"]
if _hw.notna().any():
    # F1 is bounded, so both ends of the interval are clipped to [0, 1].
    _center = unique_df_ci["microF1_mean"]
    unique_df_ci["microF1_lb95"] = (_center - _hw).clip(lower=0.0, upper=1.0)
    unique_df_ci["microF1_ub95"] = (_center + _hw).clip(lower=0.0, upper=1.0)
    present_rows = int(_hw.notna().sum())
    print(f"Combined CI heatmap: computed 95% CI half-widths for {present_rows}/{len(unique_df_ci)} rows.")
else:
    print("Combined CI heatmap: cannot compute 95% CI half-widths (no std/sem/ci available).")
    print("Will plot only the mean in the center; CI corners will show '—'.")
    unique_df_ci["microF1_lb95"] = np.nan
    unique_df_ci["microF1_ub95"] = np.nan

def _minDK_grid_for_metric(df_src: pd.DataFrame, metric_col: str) -> pd.DataFrame:
    """Return a pivot grid (m x d_model) of the minimum DK meeting the threshold using `metric_col`.

    The selection mirrors the one used for the mean heatmap: rows are filtered with
    COMPARE_OP / THRESHOLD applied to `metric_col`; per (m, d_model) the smallest D_K is
    kept; among rows at that D_K those within F1_TOL of the best metric are retained and
    the one with the most heads wins. The grid is aligned to the index and columns of the
    module-level `vals`; if no row passes, an all-NaN grid of that shape is returned.

    Raises ValueError when COMPARE_OP is neither ">=" nor "<".
    """
    op = COMPARE_OP.strip()
    if op == ">=":
        passing = df_src.loc[df_src[metric_col] >= THRESHOLD].copy()
    elif op == "<":
        passing = df_src.loc[df_src[metric_col] < THRESHOLD].copy()
    else:
        raise ValueError('COMPARE_OP must be ">=" or "<"')

    if passing.empty:
        # Nothing qualifies: same (m, d_model) domain as the mean grid, all NaN.
        return pd.DataFrame(index=vals.index, columns=vals.columns, dtype=float)

    pair = ["m", "d_model"]

    # Smallest qualifying D_K per pair, then the rows sitting exactly at it.
    smallest_dk = passing.groupby(pair)["D_K"].transform("min")
    frontier = passing[passing["D_K"] == smallest_dk].copy()

    # Rows (near-)tied for the best metric value at that D_K.
    best_metric = frontier.groupby(pair)[metric_col].transform("max")
    shortlisted = frontier[frontier[metric_col] >= (best_metric - F1_TOL)].copy()

    # Prefer more heads among the remaining ties.
    ordered = shortlisted.sort_values(["m", "d_model", "h"], ascending=[True, True, False])
    chosen = ordered.drop_duplicates(pair, keep="first").reset_index(drop=True)

    grid = chosen.pivot(index="m", columns="d_model", values="D_K")
    grid = grid.sort_index().sort_index(axis=1)
    return grid.reindex(index=vals.index, columns=vals.columns)

# Build DK grids:
# - mean DK grid is the earlier `vals` from your first heatmap
# - upper DK bound (pessimistic) uses the LOWER F1 bound
# Thresholding on the lower F1 bound is the stricter test, so it can only select an equal
# or larger D_K than the mean does; thresholding on the upper bound is the looser one.
# (This reasoning assumes COMPARE_OP is ">=".)
dk_upper = _minDK_grid_for_metric(unique_df_ci, "microF1_lb95")
# - lower DK bound (optimistic) uses the UPPER F1 bound
dk_lower = _minDK_grid_for_metric(unique_df_ci, "microF1_ub95")

# --- Plot the combined grid with three numbers per cell ---
# Figure size grows linearly with the number of grid columns (d_model) and rows (m).
fig_w = 1.2 + 0.9 * max(1, len(d_values))
fig_h = 1.2 + 0.9 * max(1, len(m_values))
fig3, ax3 = plt.subplots(figsize=(fig_w, fig_h))

# subtle grey background to delineate cells
# (the image is all zeros; it only establishes the cell coordinate system for the text below)
bg = np.zeros((len(m_values), len(d_values)))
_ = ax3.imshow(bg, aspect="auto", vmin=0, vmax=1, cmap="Greys", alpha=0.08)

# grid / ticks
# Major ticks sit at cell centres and carry the labels; minor ticks sit on the cell
# borders (offset by 0.5) and are used only to draw the grid lines.
ax3.set_xticks(np.arange(len(d_values)))
ax3.set_yticks(np.arange(len(m_values)))
ax3.set_xticklabels([str(int(x)) for x in d_values], rotation=0)
ax3.set_yticklabels([str(int(y)) for y in m_values])
ax3.set_xticks(np.arange(-.5, len(d_values), 1), minor=True)
ax3.set_yticks(np.arange(-.5, len(m_values), 1), minor=True)
ax3.grid(which="minor", color="black", linewidth=0.5, alpha=0.25)
ax3.tick_params(top=False, bottom=True, left=True, right=False)

# text sizes: center big, corners small
FS_CENTER = 13
FS_CORNER = 10
OFF = 0.40  # how far into the corner to place the small numbers

# The grids are addressed positionally (iloc), which relies on dk_upper / dk_lower having
# been reindexed to the same row/column order as `vals`.
for i, m_val in enumerate(m_values):
    for j, d_val in enumerate(d_values):
        # center (mean)
        v_mean  = vals.iloc[i, j]
        # upper-right corner: DK upper bound (pessimistic) -> uses F1 lower bound
        v_upper = dk_upper.iloc[i, j]
        # lower-left corner: DK lower bound (optimistic) -> uses F1 upper bound
        v_lower = dk_lower.iloc[i, j]

        # formatters
        # Missing cells are rendered as an em dash instead of a number.
        c_txt = "—" if pd.isna(v_mean)  else str(int(v_mean))
        u_txt = "—" if pd.isna(v_upper) else str(int(v_upper))
        l_txt = "—" if pd.isna(v_lower) else str(int(v_lower))

        # draw center (big)
        ax3.text(j, i, c_txt, ha="center", va="center",
                 fontsize=FS_CENTER, fontweight="bold", color="black")

        # draw upper-right (small)
        # (imshow's default origin puts row 0 at the top, so "i - OFF" moves towards the top edge)
        ax3.text(j + OFF, i - OFF, u_txt, ha="right", va="top",
                 fontsize=FS_CORNER, color="black")

        # draw lower-left (small)
        ax3.text(j - OFF, i + OFF, l_txt, ha="left", va="bottom",
                 fontsize=FS_CORNER, color="black")

title = ("Minimum DK achieved")
ax3.set_title(title)
ax3.set_xlabel(r"$d_{\mathrm{model}}$")
ax3.set_ylabel("m")
plt.tight_layout()

# Save figure
# The file is written before plt.show() so it exists regardless of the display backend.
combined_path = os.path.join(TARGET_DIR, "min_DK_over_heads_heatmap_mean_with_CI_bounds.png")
fig3.savefig(combined_path, dpi=150)
print("\nSaved heatmap to:", combined_path, "| Exists?", os.path.exists(combined_path))
plt.show()
plt.close(fig3)

# Save a tidy table with the three DK numbers for each (m, d_model)
# Long-format view of each grid: one row per (m, d_model) cell, NaN cells included.
mean_long, upper_long, lower_long = (
    grid.stack(dropna=False).rename(label).reset_index()
    for grid, label in ((vals, "DK_mean"), (dk_upper, "DK_upper"), (dk_lower, "DK_lower"))
)

# Join the three views cell by cell.
_cell_keys = ["m", "d_model"]
triplet_df = mean_long.merge(upper_long, on=_cell_keys, how="left")
triplet_df = triplet_df.merge(lower_long, on=_cell_keys, how="left")
triplet_df = triplet_df.sort_values(_cell_keys)

out_triplet_csv = os.path.join(TARGET_DIR, "min_DK_over_heads_table_mean_with_CI_bounds.csv")
triplet_df.to_csv(out_triplet_csv, index=False)
print("Saved combined table to:", out_triplet_csv, "| Exists?", os.path.exists(out_triplet_csv))


# --------------------------------------------------------------------------------
# === TABLE — Per-head size d_k for the SELECTED head (max F1 at min D_K)
# --------------------------------------------------------------------------------

# d_k is the per-head width: the selected total D_K divided by the selected head count.
selected_with_dk = min_rows.assign(
    d_k=min_rows["D_K"].astype(float) / min_rows["h"].astype(float)
)

_dk_table_cols = ["m", "d_model", "h", "D_K", "d_k", "microF1_mean"]
per_head_rows_to_save = selected_with_dk[_dk_table_cols].sort_values(["m", "d_model"]).copy()

# Show d_k as integers only when every value is finite and (numerically) whole.
dk_vals = per_head_rows_to_save["d_k"].to_numpy(dtype=float)
_all_whole = np.all(np.isfinite(dk_vals)) and np.allclose(dk_vals, np.round(dk_vals))
if _all_whole:
    per_head_rows_to_save["d_k"] = np.round(per_head_rows_to_save["d_k"]).astype(int)
else:
    per_head_rows_to_save["d_k"] = per_head_rows_to_save["d_k"].astype(float)

per_head_csv = os.path.join(TARGET_DIR, "min_per_head_dk_at_min_DK_table.csv")
per_head_rows_to_save.to_csv(per_head_csv, index=False)
print("Saved per-head table (selected head) to:", per_head_csv, "| Exists?", os.path.exists(per_head_csv))

print("\nPer-head size d_k for the selected head at the min total D_K (full table):")
print(per_head_rows_to_save.to_string(index=False))



# --------------------------------------------------------------------------------
# LINE PLOT with error bars — y = number of heads (h) that produced the minimum DK,
# per (m, d_model). Error bars computed exactly as in the scatter script:
#   ±10% DK window around the winner + one-sided Welch t-test (alpha=0.05)
# X-axis: log2 with powers-of-two tick labels; Y-axis: log with powers-of-two ticks
# --------------------------------------------------------------------------------

import os, math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker

# ---------------------------------------
# Pre-flight checks and light conveniences
# ---------------------------------------
# Default the output folder to the working directory when this part runs on its own.
globals().setdefault("TARGET_DIR", ".")

# The selection table from the aggregation step is a hard prerequisite.
if not isinstance(globals().get("min_rows"), pd.DataFrame):
    raise RuntimeError(
        "This script expects a DataFrame `min_rows` with columns ['m','d_model','h'] "
        "containing the chosen head per (m, d_model)."
    )

# ---------------------------------------
# Helper to recover per-head stats exactly like in the second script
# Required columns in the returned df:
#   ['m','d_model','h','D_K','F1_mean','F1_std','F1_n']
# ---------------------------------------
def _coerce_per_head_df():
    """Locate per-head F1 statistics among the module globals and normalise them.

    Two layouts are recognised, in this order:

    1. A tidy DataFrame bound to one of the names `per_head`, `stats_long`,
       `triplet_df_full`, `all_stats`, `summary_long`, `f1_stats_by_head`. Common column
       aliases (matched case-insensitively) are renamed to the canonical names, and the
       first candidate that then has every required column is used.
    2. Four separate DataFrames (mean / std / n / D_K), each found under one of several
       accepted global names. Each one is stacked into long form with its column labels
       becoming `h`, and the four are merged on (m, d_model, h).
       # NOTE(review): this branch presumably expects each grid to be indexed by
       # (m, d_model) with head counts as columns - confirm against the producer.

    Returns:
        A DataFrame with columns, always in this order,
        ['m', 'd_model', 'h', 'D_K', 'F1_mean', 'F1_std', 'F1_n']
        (m, d_model, h, F1_n as int; D_K, F1_mean, F1_std as float),
        or None when neither layout is available.
    """
    # Fixed output order: selecting via a set would make the column order depend on
    # set iteration order and differ from the grid branch.
    ordered_cols = ["m", "d_model", "h", "D_K", "F1_mean", "F1_std", "F1_n"]
    aliases = {
        "d_k": "D_K", "dk": "D_K",
        "f1_mean": "F1_mean", "f1_mu": "F1_mean",
        "f1_std": "F1_std", "f1_sigma": "F1_std", "std_f1": "F1_std",
        "f1_n": "F1_n", "n_f1": "F1_n", "count_f1": "F1_n", "n": "F1_n",
    }

    def _typed(frame):
        """Select the canonical columns in canonical order and enforce their dtypes."""
        out = frame[ordered_cols].copy()
        for col in ("m", "d_model", "h", "F1_n"):
            out[col] = out[col].astype(int)
        for col in ("D_K", "F1_mean", "F1_std"):
            out[col] = out[col].astype(float)
        return out

    tidy_candidates = (
        "per_head", "stats_long", "triplet_df_full", "all_stats",
        "summary_long", "f1_stats_by_head"
    )
    for name in tidy_candidates:
        if name in globals():
            obj = globals()[name]
            if isinstance(obj, pd.DataFrame):
                rename = {
                    c: aliases[str(c).lower()]
                    for c in obj.columns
                    if str(c).lower() in aliases
                }
                df = obj.rename(columns=rename)
                if set(ordered_cols).issubset(set(df.columns)):
                    return _typed(df)

    grid_names = {
        "F1_mean": ("F1_mean", "f1_mean", "F1_mu", "f1_mu", "mean_f1"),
        "F1_std":  ("F1_std", "f1_std", "F1_sigma", "f1_sigma", "std_f1"),
        "F1_n":    ("F1_n", "f1_n", "count_f1", "n_f1", "N_f1"),
        "D_K":     ("D_K", "DK", "DK_by_head", "D_K_by_head", "d_k", "dk_grid"),
    }

    # For each quantity take the first accepted global name that is bound to a DataFrame.
    grids = {}
    for quantity, names in grid_names.items():
        for key in names:
            if key in globals() and isinstance(globals()[key], pd.DataFrame):
                grids[quantity] = globals()[key]
                break

    if set(grids) == {"F1_mean", "F1_std", "F1_n", "D_K"}:
        def _stack(df, out_name):
            tmp = df.copy()
            tmp.columns.name = "h"
            long = tmp.stack(dropna=False).rename(out_name).reset_index()
            long["h"] = long["h"].astype(int)
            return long

        df = _stack(grids["F1_mean"], "F1_mean") \
            .merge(_stack(grids["F1_std"], "F1_std"), on=["m", "d_model", "h"], how="left") \
            .merge(_stack(grids["F1_n"],   "F1_n"),   on=["m", "d_model", "h"], how="left") \
            .merge(_stack(grids["D_K"],    "D_K"),    on=["m", "d_model", "h"], how="left")
        return _typed(df)

    return None

# Rebind `per_head` to the normalised per-head stats; fail loudly when none can be found.
per_head = _coerce_per_head_df()
if per_head is None:
    raise RuntimeError(
        "Could not find per-head F1 stats. Provide a DataFrame `per_head` with columns "
        "['m','d_model','h','D_K','F1_mean','F1_std','F1_n'], or adjust the auto-detection."
    )

# ---------------------------------------
# Welch t-test and DK window (same constants as scatter script)
# ---------------------------------------
def p_less_welch(mu_c, sd_c, n_c, mu_w, sd_w, n_w):
    """Return the one-sided Welch p-value P(T <= t) for candidate vs. winner.

    The statistic is t = (mu_c - mu_w) / sqrt(sd_c^2/n_c + sd_w^2/n_w), so a
    small return value means the candidate mean is significantly below the
    winner mean.

    Args:
        mu_c, sd_c, n_c: mean, standard deviation and sample count of the candidate.
        mu_w, sd_w, n_w: mean, standard deviation and sample count of the winner.

    Returns:
        A float p-value. When the combined variance is non-finite or
        effectively zero, 1.0 is returned if mu_c >= mu_w and 0.0 otherwise.
    """
    # Variance of each mean; the sample count is floored at 1 to avoid dividing by zero.
    var_c = (sd_c**2) / max(int(n_c), 1)
    var_w = (sd_w**2) / max(int(n_w), 1)
    pooled = var_c + var_w

    # Degenerate spread: decide purely on the ordering of the means.
    if (not np.isfinite(pooled)) or pooled <= 1e-16:
        if mu_c >= mu_w:
            return 1.0
        return 0.0

    t_stat = (mu_c - mu_w) / math.sqrt(pooled)

    # Welch–Satterthwaite degrees of freedom; groups with n <= 1 or zero
    # variance contribute nothing to the denominator.
    denom = 0.0
    for var_i, n_i in ((var_c, n_c), (var_w, n_w)):
        if n_i > 1 and var_i > 0:
            denom += (var_i**2) / (n_i - 1)
    dof = (pooled**2) / denom if denom > 0 else 1e9  # huge dof ~ normal distribution

    # Student-t CDF when SciPy can be used; otherwise the normal CDF via erf.
    try:
        from scipy.stats import t as student_t
        return float(student_t.cdf(t_stat, dof))
    except Exception:
        return 0.5 * (1.0 + math.erf(t_stat / math.sqrt(2.0)))

# Significance level for the one-sided Welch test and the relative D_K window.
# NOTE(review): this rebinds the module-level `alpha`, which the Config section
# at the top of the file sets to 10.0 ("sharpness"). Any code that runs after
# this point and reads `alpha` as the sharpness will see 0.05 — confirm intended.
alpha  = 0.05  # 95%
dk_tol = 0.10  # ±10%

# ---------------------------------------
# Compute kept-head bounds per (m, d_model) for winners in min_rows
# ---------------------------------------
# One row per (m, d_model) with the selected head count `h`, all cast to int.
winners = (
    min_rows[["m","d_model","h"]]
    .dropna()
    .astype({"m": int, "d_model": int, "h": int})
    .copy()
)

bounds = []
for _, r in winners.iterrows():
    m_val, d_val, h_star = int(r["m"]), int(r["d_model"]), int(r["h"])
    # All per-head statistics recorded for this (m, d_model) cell.
    sub = per_head[(per_head["m"]==m_val) & (per_head["d_model"]==d_val)].copy()
    if sub.empty:
        # No statistics at all: the interval collapses onto the winner.
        bounds.append({"m": m_val, "d_model": d_val, "h_lower": h_star, "h_upper": h_star})
        continue

    win = sub[sub["h"]==h_star]
    if win.empty:
        # The winner itself has no statistics row: nothing to compare against.
        bounds.append({"m": m_val, "d_model": d_val, "h_lower": h_star, "h_upper": h_star})
        continue

    # Winner's F1 statistics and D_K (first matching row is used).
    mu_w = float(win["F1_mean"].iloc[0])
    sd_w = float(win["F1_std"].iloc[0])
    n_w  = int(win["F1_n"].iloc[0])
    DK_w = float(win["D_K"].iloc[0])

    # Candidates are heads whose D_K lies within ±dk_tol of the winner's D_K.
    lo, hi = (1.0 - dk_tol) * DK_w, (1.0 + dk_tol) * DK_w
    near = sub[(sub["D_K"] >= lo) & (sub["D_K"] <= hi)].copy()

    if near.empty:
        bounds.append({"m": m_val, "d_model": d_val, "h_lower": h_star, "h_upper": h_star})
        continue

    # Keep a candidate when the test cannot show it is worse than the winner
    # (p >= alpha); `keep` is a positional boolean mask over `near`.
    keep = []
    for _, c in near.iterrows():
        p_less = p_less_welch(
            float(c["F1_mean"]), float(c["F1_std"]), int(c["F1_n"]),
            mu_w, sd_w, n_w
        )
        keep.append(p_less >= alpha)

    kept = near.loc[keep]
    if kept.empty:
        h_lo = h_hi = h_star
    else:
        h_lo = int(np.nanmin(kept["h"].to_numpy()))
        h_hi = int(np.nanmax(kept["h"].to_numpy()))
        # Widen if necessary so the winner always lies inside [h_lo, h_hi].
        h_lo = min(h_lo, h_star)
        h_hi = max(h_hi, h_star)

    bounds.append({"m": m_val, "d_model": d_val, "h_lower": h_lo, "h_upper": h_hi})

bounds_df = pd.DataFrame(bounds)

# Merge bounds with winners and build asymmetric yerr
# (distances from the winner's h down to h_lower and up to h_upper, never negative).
sc_df_ci = winners.merge(bounds_df, on=["m","d_model"], how="left")
sc_df_ci["yerr_low"]  = (sc_df_ci["h"] - sc_df_ci["h_lower"]).clip(lower=0).astype(float)
sc_df_ci["yerr_high"] = (sc_df_ci["h_upper"] - sc_df_ci["h"]).clip(lower=0).astype(float)

# Pivot y and yerr to the same (m x d_model) layout as the line plot
h_grid = min_rows.pivot(index="m", columns="d_model", values="h").sort_index().sort_index(axis=1)
# Align x-axis with heatmap's d_model order for consistent comparison:
# (`d_values` is defined outside this section; columns missing from the pivot become NaN.)
h_grid = h_grid.reindex(columns=d_values)

yerr_low_grid  = sc_df_ci.pivot(index="m", columns="d_model", values="yerr_low").sort_index().reindex(columns=d_values)
yerr_high_grid = sc_df_ci.pivot(index="m", columns="d_model", values="yerr_high").sort_index().reindex(columns=d_values)

# ---------------------------------------
# Figure and plotting (line + markers + error bars)
# ---------------------------------------
# Figure width grows with the number of d_model columns; height is fixed.
fig2_w = 1.2 + 0.9 * max(1, len(d_values))
fig2_h = 4.8
fig2, ax2 = plt.subplots(figsize=(fig2_w, fig2_h))

# Shared x positions: the d_model values as floats.
xs = np.array([int(x) for x in d_values], dtype=float)

# One line per m; each line's label ("m=<value>") is what later sections parse
# to recover the per-m colors from ax2.
for m_val in h_grid.index:
    ys = h_grid.loc[m_val].astype(float).values  # NaNs break the line where no data
    # Draw the line + markers first
    (line_handle,) = ax2.plot(xs, ys, marker="o", linewidth=1.8, label=f"m={int(m_val)}")

    # Error bars: only at finite points, match line color
    if (m_val in yerr_low_grid.index) and (m_val in yerr_high_grid.index):
        ylo = yerr_low_grid.loc[m_val].astype(float).values
        yhi = yerr_high_grid.loc[m_val].astype(float).values

        # Skip points where the value or either error length is missing.
        mask = np.isfinite(ys) & np.isfinite(ylo) & np.isfinite(yhi)
        if np.any(mask):
            x_pts  = xs[mask]
            y_pts  = ys[mask]
            yerr_l = np.maximum(0.0, ylo[mask])
            yerr_h = np.maximum(0.0, yhi[mask])

            # On a log-y axis, ensure the lower end stays > 0 (conservative epsilon guard)
            eps = 1e-12
            yerr_l = np.minimum(yerr_l, np.maximum(0.0, y_pts - eps))

            ax2.errorbar(
                x_pts, y_pts,
                yerr=np.vstack([yerr_l, yerr_h]),
                fmt="none",
                ecolor=line_handle.get_color(),
                elinewidth=0.9,
                capsize=3,
                alpha=0.85,
                # Bars are drawn one z-level above their own line
                # (a fixed zorder of 1.5 was used here previously).
                zorder=line_handle.get_zorder() + 1
            )

ax2.set_xlabel(r"$d_{\mathrm{model}}$")
ax2.set_ylabel("number of heads")
ax2.set_title(r"Heads producing minimum $D_K$")
ax2.grid(alpha=0.3, linewidth=0.6)

# Log scales for both axes (base 2 for x to match your original)
ax2.set_yscale("log")
ax2.set_xscale("log", base=2)

# --- helper: powers-of-two formatter ---
def _pow2_label(y, _):
    """Format a tick value as a LaTeX power of two, or "" when it is not one.

    Args:
        y: tick value.
        _: tick position supplied by matplotlib's FuncFormatter (unused).

    Returns:
        A string of the form "$2^{k}$" when y is positive and log2(y) is
        (numerically) an integer k; otherwise the empty string.
    """
    # Non-positive values have no real base-2 logarithm; leave them unlabeled.
    if y <= 0:
        return ""
    k = np.log2(y)
    is_exact_power = np.isclose(k, round(k))
    return fr"$2^{{{int(round(k))}}}$" if is_exact_power else ""

# Compute power-of-two major ticks that span the Y data range (no clipping)
# Collect every positive, non-missing head count shown in the plot and derive
# the power-of-two exponents that bracket them.
all_h = pd.Series(h_grid.values.ravel(), dtype="float64").dropna()
all_h = all_h[all_h > 0]
if not all_h.empty:
    k_min_y = int(np.floor(np.log2(all_h.min())))
    k_max_y = int(np.ceil(np.log2(all_h.max())))
    pow2_ticks_y = [2**k for k in range(k_min_y, k_max_y + 1)]
else:
    # No data: fall back to a fixed set of ticks.
    pow2_ticks_y = [1, 2, 4, 8, 16, 32]

# NOTE(review): `mticker` is not imported in this section; presumably
# `matplotlib.ticker` was imported as `mticker` earlier in the file — confirm.
ax2.yaxis.set_major_locator(mticker.FixedLocator(pow2_ticks_y))
ax2.yaxis.set_minor_locator(mticker.NullLocator())
ax2.yaxis.set_major_formatter(mticker.FuncFormatter(_pow2_label))

# Compute power-of-two major ticks that span the X data range (no clipping)
xs_pos = xs[xs > 0]
if xs_pos.size > 0:
    k_min_x = int(np.floor(np.log2(xs_pos.min())))
    k_max_x = int(np.ceil(np.log2(xs_pos.max())))
    pow2_ticks_x = [2**k for k in range(k_min_x, k_max_x + 1)]
else:
    pow2_ticks_x = [1, 2, 4, 8, 16, 32]

ax2.xaxis.set_major_locator(mticker.FixedLocator(pow2_ticks_x))
ax2.xaxis.set_minor_locator(mticker.NullLocator())
ax2.xaxis.set_major_formatter(mticker.FuncFormatter(_pow2_label))

# Legend placement adapts to the number of m's
# (inside the axes for up to 10 lines, otherwise outside at the upper right).
if len(h_grid.index) <= 10:
    ax2.legend(title="m", fontsize=9)
else:
    ax2.legend(title="m", fontsize=8, bbox_to_anchor=(1.02, 1.0), loc="upper left", borderaxespad=0.)

plt.tight_layout()

# === Save line plot explicitly to Drive ===
# `TARGET_DIR` is defined outside this section.
lineplot_path = os.path.join(TARGET_DIR, "h_that_produced_min_DK_lineplot.png")
fig2.savefig(lineplot_path, dpi=150, bbox_inches="tight")
print("Saved line plot (with error bars) to:", lineplot_path, "| Exists?", os.path.exists(lineplot_path))
plt.show()
# Closing the figure releases it from pyplot; the `ax2` object itself stays
# bound and is read again below to recover line colors.
plt.close(fig2)



# --------------------------------------------------------------------------------
# Scatter: heads vs (m / d_model) with global linear fit + confidence bars
# Confidence bars: for each (m, d_model), consider heads whose DK is within ±10% of the
# chosen optimal DK. Keep heads that cannot be eliminated by a one-sided Welch t-test
# vs. the winner at alpha=0.05 (i.e., p >= 0.05 for H1: mu_c < mu_w). Plot the min/max
# kept head as vertical error bars around the winner's head.
# --------------------------------------------------------------------------------

import os
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# ---------------------------------------
# Pre-flight checks and light conveniences
# ---------------------------------------
# Default the output directory to the current directory when it was not set earlier.
if "TARGET_DIR" not in globals():
    TARGET_DIR = "."

# The scatter plot is built from `min_rows`; refuse to continue without it.
if "min_rows" not in globals() or not isinstance(min_rows, pd.DataFrame):
    raise RuntimeError(
        "This script expects a DataFrame `min_rows` with columns ['m','d_model','h'] "
        "containing the chosen head per (m, d_model)."
    )

# ---------------------------------------
# Recover per-m colors from an earlier line plot (ax2), if available
# ---------------------------------------
# Lines labelled "m=<int>" map that integer to the line's color so the scatter
# can reuse the same palette.
color_by_m = {}
try:
    for line in ax2.get_lines():
        lbl = line.get_label()
        if isinstance(lbl, str) and lbl.startswith("m="):
            try:
                mm = int(lbl.split("=")[1])
                color_by_m[mm] = line.get_color()
            except Exception:
                # Labels that do not parse as an integer are skipped (best effort).
                pass
except NameError:
    # If ax2 isn't defined, we'll fall back later
    color_by_m = {}

# ---------------------------------------
# Build the scatter base table from the chosen rows
# ---------------------------------------
sc_df = (
    min_rows[["m", "d_model", "h"]]
    .dropna()
    .copy()
)
sc_df["m"] = sc_df["m"].astype(int)
sc_df["d_model"] = sc_df["d_model"].astype(int)
sc_df["h"] = sc_df["h"].astype(int)
# x-coordinate of the scatter: m divided by d_model.
sc_df["ratio_m_over_d"] = sc_df["m"].astype(float) / sc_df["d_model"].astype(float)

# ---------------------------------------------------
# Locate / construct a tidy per-head table with stats
# Required columns: ['m','d_model','h','D_K','F1_mean','F1_std','F1_n']
# ---------------------------------------------------
def _coerce_per_head_df():
    """Locate or rebuild the tidy per-head statistics table from module globals.

    Two sources are tried, in order:

    1. A DataFrame bound to one of a few common global names ("per_head",
       "stats_long", ...). Column aliases such as "dk" or "f1_mu" are renamed
       to the canonical names; the first candidate that then has every
       required column wins.
    2. Four pivot grids (F1 mean, F1 std, F1 count, D_K) bound to common
       global names, with heads as columns. They are stacked to long form and
       merged on (m, d_model, h); the index levels are expected to be named
       'm' and 'd_model'.

    Rows whose 'm', 'd_model', 'h' or 'F1_n' is missing are dropped before the
    integer cast, because such rows cannot be cast and carry no sample count.

    Returns:
        A DataFrame with columns ['m','d_model','h','D_K','F1_mean','F1_std',
        'F1_n'] in exactly that order (ints for m, d_model, h, F1_n; floats
        for the rest), or None when neither source is available.
    """
    columns = ["m", "d_model", "h", "D_K", "F1_mean", "F1_std", "F1_n"]
    int_columns = ["m", "d_model", "h", "F1_n"]

    def _finalize(frame):
        # A fixed list (not a set) keeps the column order deterministic and
        # identical for both sources.
        out = frame[columns].copy()
        # NaN keys/counts would make astype(int) raise; such rows are unusable.
        out = out.dropna(subset=int_columns)
        for col in columns:
            out[col] = out[col].astype(int if col in int_columns else float)
        return out

    # 1) If a tidy table already exists under common names, normalize its columns.
    aliases = {
        "d_k": "D_K", "dk": "D_K",
        "f1_mean": "F1_mean", "f1_mu": "F1_mean",
        "f1_std": "F1_std", "f1_sigma": "F1_std", "std_f1": "F1_std",
        "f1_n": "F1_n", "n_f1": "F1_n", "count_f1": "F1_n", "n": "F1_n",
    }
    tidy_candidates = (
        "per_head", "stats_long", "triplet_df_full", "all_stats", "summary_long", "f1_stats_by_head"
    )
    for name in tidy_candidates:
        obj = globals().get(name)
        if not isinstance(obj, pd.DataFrame):
            continue
        rename = {
            c: aliases[str(c).lower()]
            for c in obj.columns
            if str(c).lower() in aliases
        }
        df = obj.rename(columns=rename)  # returns a new frame; the global is untouched
        if set(columns).issubset(df.columns):
            return _finalize(df)

    # 2) Try to rebuild from pivot grids commonly used in earlier steps:
    #    dataframes with index = (m, d_model), columns = heads
    grid_aliases = {
        "F1_mean": ("F1_mean", "f1_mean", "F1_mu", "f1_mu", "mean_f1"),
        "F1_std":  ("F1_std", "f1_std", "F1_sigma", "f1_sigma", "std_f1"),
        "F1_n":    ("F1_n", "f1_n", "count_f1", "n_f1", "N_f1"),
        "D_K":     ("D_K", "DK", "DK_by_head", "D_K_by_head", "d_k", "dk_grid"),
    }

    def _first_grid(names):
        # First global under any of `names` that is a DataFrame, else None.
        for key in names:
            candidate = globals().get(key)
            if isinstance(candidate, pd.DataFrame):
                return candidate
        return None

    grids = {stat: _first_grid(names) for stat, names in grid_aliases.items()}
    if any(grid is None for grid in grids.values()):
        return None

    def _stack(df, out_name):
        tmp = df.copy()
        tmp.columns.name = "h"  # ensure stacked column is named 'h'
        # NOTE(review): newer pandas releases deprecate the `dropna` argument
        # of `stack`; confirm against the pandas version this file targets.
        long = tmp.stack(dropna=False).rename(out_name).reset_index()
        # expected columns: ['m','d_model','h', out_name]
        long["h"] = long["h"].astype(int)
        return long

    merged = _stack(grids["F1_mean"], "F1_mean")
    for stat in ("F1_std", "F1_n", "D_K"):
        merged = merged.merge(_stack(grids[stat], stat), on=["m", "d_model", "h"], how="left")
    return _finalize(merged)

# Resolve the tidy per-head statistics table through the auto-detection helper.
# The helper returns None when no suitable source is found; in that case stop
# here with an actionable message instead of failing later inside the loops.
per_head = _coerce_per_head_df()
if per_head is None:
    raise RuntimeError(
        "Could not find per-head F1 stats. Provide a DataFrame `per_head` with columns "
        "['m','d_model','h','D_K','F1_mean','F1_std','F1_n'], or adjust the auto-detection."
    )

# ---------------------------------------
# One-sided Welch t-test (candidate vs winner): H1: mu_c < mu_w
# Keep a candidate if p >= alpha (cannot eliminate at 95%)
# ---------------------------------------
def p_less_welch(mu_c, sd_c, n_c, mu_w, sd_w, n_w):
    """One-sided Welch t-test of a candidate against the winner (H1: mu_c < mu_w).

    Args:
        mu_c, sd_c, n_c: candidate mean, standard deviation and sample count.
        mu_w, sd_w, n_w: winner mean, standard deviation and sample count.

    Returns:
        P(T <= t) as a float, where t is the Welch statistic of
        (mu_c - mu_w). If the combined variance of the two means is
        non-finite or at most 1e-16, returns 1.0 when mu_c >= mu_w, else 0.0.
    """
    size_c = max(int(n_c), 1)
    size_w = max(int(n_w), 1)
    v1 = (sd_c**2) / size_c
    v2 = (sd_w**2) / size_w
    se2 = v1 + v2

    degenerate = (not np.isfinite(se2)) or se2 <= 1e-16
    if degenerate:
        # No usable variance: fall back to comparing the means directly.
        return 1.0 if (mu_c >= mu_w) else 0.0

    t_value = (mu_c - mu_w) / math.sqrt(se2)

    # Welch–Satterthwaite degrees of freedom. A group only contributes when it
    # has more than one sample and a positive variance.
    candidate_term = (v1**2) / (n_c - 1) if (n_c > 1 and v1 > 0) else None
    winner_term = (v2**2) / (n_w - 1) if (n_w > 1 and v2 > 0) else None
    df_den = 0.0
    if candidate_term is not None:
        df_den += candidate_term
    if winner_term is not None:
        df_den += winner_term
    dof = (se2**2) / df_den if df_den > 0 else 1e9  # large dof => normal approx

    # Prefer the Student-t CDF from SciPy; use the normal CDF if that fails.
    try:
        from scipy.stats import t as student_t
        p_value = float(student_t.cdf(t_value, dof))  # one-sided: P(T <= t)
    except Exception:
        p_value = 0.5 * (1.0 + math.erf(t_value / math.sqrt(2.0)))
    return p_value

# NOTE(review): this rebinds the module-level `alpha`, which the Config section
# at the top of the file sets to 10.0 ("sharpness"). Any code that runs after
# this point and reads `alpha` as the sharpness will see 0.05 — confirm intended.
alpha  = 0.05  # 95% rule
dk_tol = 0.10  # ±10% window on DK around the winner's DK

# One {"m", "d_model", "h_lower", "h_upper"} record is appended per winner row.
bounds = []
winners = sc_df[["m","d_model","h"]].dropna().astype(int)

for _, r in winners.iterrows():
    m_val, d_val, h_star = int(r["m"]), int(r["d_model"]), int(r["h"])
    # All per-head statistics recorded for this (m, d_model) cell.
    sub = per_head[(per_head["m"]==m_val) & (per_head["d_model"]==d_val)].copy()
    if sub.empty:
        # No statistics at all: the interval collapses onto the winner.
        bounds.append({"m": m_val, "d_model": d_val, "h_lower": h_star, "h_upper": h_star})
        continue

    win = sub[sub["h"]==h_star]
    if win.empty:
        # The winner itself has no statistics row: nothing to compare against.
        bounds.append({"m": m_val, "d_model": d_val, "h_lower": h_star, "h_upper": h_star})
        continue

    # Winner's F1 statistics and D_K (first matching row is used).
    mu_w = float(win["F1_mean"].iloc[0])
    sd_w = float(win["F1_std"].iloc[0])
    n_w  = int(win["F1_n"].iloc[0])
    DK_w = float(win["D_K"].iloc[0])

    # candidates with DK within ±10% of the winner's DK
    lo, hi = (1.0 - dk_tol) * DK_w, (1.0 + dk_tol) * DK_w
    near = sub[(sub["D_K"] >= lo) & (sub["D_K"] <= hi)].copy()
    if near.empty:
        bounds.append({"m": m_val, "d_model": d_val, "h_lower": h_star, "h_upper": h_star})
        continue

    # Positional boolean mask over `near`, one entry per candidate row.
    keep_mask = []
    for _, c in near.iterrows():
        p_less = p_less_welch(
            float(c["F1_mean"]), float(c["F1_std"]), int(c["F1_n"]),
            mu_w, sd_w, n_w
        )
        # Keep if NOT significantly worse than winner at alpha
        keep_mask.append(p_less >= alpha)

    kept = near[keep_mask]

    if kept.empty:
        h_lo = h_hi = h_star
    else:
        h_lo = int(np.nanmin(kept["h"].to_numpy()))
        h_hi = int(np.nanmax(kept["h"].to_numpy()))
        # ensure the winner lies in range
        h_lo = min(h_lo, h_star)
        h_hi = max(h_hi, h_star)

    bounds.append({"m": m_val, "d_model": d_val, "h_lower": h_lo, "h_upper": h_hi})

bounds_df = pd.DataFrame(bounds)

# ---------------------------------------
# Merge bounds into the scatter table and compute vertical errors in "heads"
# ---------------------------------------
# yerr_low / yerr_high are the distances from the winner's h down to h_lower
# and up to h_upper, clipped so they are never negative.
sc_df_ci = sc_df.merge(bounds_df, on=["m","d_model"], how="left")
sc_df_ci["yerr_low"]  = (sc_df_ci["h"] - sc_df_ci["h_lower"]).clip(lower=0).astype(float)
sc_df_ci["yerr_high"] = (sc_df_ci["h_upper"] - sc_df_ci["h"]).clip(lower=0).astype(float)

# ---------------------------------------
# Fallback color palette if we couldn't read colors from ax2
# ---------------------------------------
if not color_by_m:
    cycle_colors = plt.rcParams["axes.prop_cycle"].by_key().get("color", [])
    if not cycle_colors:
        cycle_colors = [f"C{i}" for i in range(10)]
    # Assign cycle colors to the sorted m values, wrapping around if needed.
    m_unique = sorted(int(v) for v in sc_df_ci["m"].unique())
    color_by_m = {m_val: cycle_colors[i % len(cycle_colors)] for i, m_val in enumerate(m_unique)}

# ---------------------------------------
# Figure and plotting
# ---------------------------------------
# Scatter of the selected head count against m / d_model, with per-point
# confidence bars (from sc_df_ci), a global least-squares line, and the figure
# saved under TARGET_DIR.
fig_sc, ax_sc = plt.subplots(figsize=(5.0, 5.0))

# Error bars first so markers sit on top
for m_val, grp in sc_df_ci.groupby("m", sort=True):
    x_pts = grp["ratio_m_over_d"].astype(float).values
    y_pts = grp["h"].astype(float).values
    # Row 0 = lengths below the point, row 1 = lengths above; missing -> 0.
    yerr  = np.vstack([
        grp["yerr_low"].fillna(0).to_numpy(dtype=float),
        grp["yerr_high"].fillna(0).to_numpy(dtype=float)
    ])
    e_kwargs = dict(fmt="none", elinewidth=0.9, capsize=3, alpha=0.85, zorder=1.5)
    if int(m_val) in color_by_m:
        e_kwargs["ecolor"] = color_by_m[int(m_val)]
    ax_sc.errorbar(x_pts, y_pts, yerr=yerr, **e_kwargs)

# Scatter points, color-coded by m
for m_val, grp in sc_df_ci.groupby("m", sort=True):
    pts_x = grp["ratio_m_over_d"].astype(float).values
    pts_y = grp["h"].astype(float).values
    kwargs = dict(s=46, alpha=0.9, label=f"m={int(m_val)}", zorder=2)
    if int(m_val) in color_by_m:
        kwargs["color"] = color_by_m[int(m_val)]
    ax_sc.scatter(pts_x, pts_y, **kwargs)

# Global linear fit: h ≈ a * (m/d_model) + b
x_all = sc_df_ci["ratio_m_over_d"].to_numpy(dtype=float)
y_all = sc_df_ci["h"].to_numpy(dtype=float)
mask = np.isfinite(x_all) & np.isfinite(y_all)
# Error-bar lengths for the same rows that survive the finite mask; used below
# so the y-limits cover the bars and not only the markers.
err_lo_all = sc_df_ci["yerr_low"].fillna(0).to_numpy(dtype=float)[mask]
err_hi_all = sc_df_ci["yerr_high"].fillna(0).to_numpy(dtype=float)[mask]
x_all, y_all = x_all[mask], y_all[mask]

fit_done = False
if x_all.size >= 2 and (x_all.max() - x_all.min()) > 1e-12:
    a, b = np.polyfit(x_all, y_all, 1)
    xs_fit = np.linspace(float(x_all.min()), float(x_all.max()), 200)
    ys_fit = a * xs_fit + b

    # R^2
    y_pred = a * x_all + b
    ss_res = float(np.sum((y_all - y_pred) ** 2))
    ss_tot = float(np.sum((y_all - np.mean(y_all)) ** 2))
    r2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else float("nan")

    ax_sc.plot(
        xs_fit, ys_fit,
        linestyle="--", linewidth=1.0, color="black",
        label=f"y = {a:.3g} x + {b:.3g}"  #  (R²={r2:.3f})
    )
    fit_done = True
else:
    print("Skipped linear fit: insufficient x-range or too few points.")
    a = b = r2 = float("nan")  # ensure defined for printing below

# Axes, labels, styling (both axes linear)
ax_sc.set_xlabel(r"$m / d_{\mathrm{model}}$")
ax_sc.set_ylabel("Heads producing min DK")
ax_sc.set_title(r"Heads vs $m / d_{\mathrm{model}}$")
ax_sc.grid(alpha=0.3, linewidth=0.6)
ax_sc.set_xscale("linear")
ax_sc.set_yscale("linear")

# Nice bounds with a little padding
if x_all.size:
    xr = float(x_all.max() - x_all.min())
    if xr <= 0:
        ax_sc.set_xlim(float(x_all.min()) - 0.5, float(x_all.max()) + 0.5)
    else:
        ax_sc.set_xlim(float(x_all.min()) - 0.05 * xr, float(x_all.max()) + 0.05 * xr)
if y_all.size:
    # Span the full extent of the confidence bars. Using only the marker
    # positions would clip any bar reaching below the smallest or above the
    # largest selected head count.
    yr_min = float(np.min(y_all - err_lo_all))
    yr_max = float(np.max(y_all + err_hi_all))
    if yr_max > yr_min:
        pad = 0.05 * (yr_max - yr_min)
        ax_sc.set_ylim(max(0, yr_min - pad), yr_max + pad)

# Legend placement similar to the line plot
if sc_df_ci["m"].nunique() <= 10:
    ax_sc.legend(title="m", fontsize=9)
else:
    ax_sc.legend(title="m", fontsize=8, bbox_to_anchor=(1.02, 1.0),
                 loc="upper left", borderaxespad=0.)

plt.tight_layout()

# Save + show
scatter_fit_path = os.path.join(TARGET_DIR, "heads_vs_m_over_d_scatter_linear_fit.png")
fig_sc.savefig(scatter_fit_path, dpi=150, bbox_inches="tight")
print("Saved scatter+fit+CI to:", scatter_fit_path, "| Exists?", os.path.exists(scatter_fit_path))
if fit_done:
    print(f"y = {a:.3g} x + {b:.3g} (R²={r2:.3f})")
plt.show()
plt.close(fig_sc)




# --------------------------------------------------------------------------------
# PLOT — Scatter: min DK vs (m * log m / d_model) + global linear fit + CI error bars
# Uses the same per‑m colors as the two graphs immediately before (via color_by_m)
# Confidence intervals are the same DK bounds used for the table (DK_lower, DK_upper)
# --------------------------------------------------------------------------------

# Reuse color_by_m if it exists; otherwise, try to rebuild it from ax2 (line plot).
# At module level locals() is the module namespace, so this checks whether an
# earlier section already produced a non-empty color mapping.
if 'color_by_m' not in locals() or not isinstance(color_by_m, dict) or not color_by_m:
    color_by_m = {}
    try:
        for line in ax2.get_lines():
            lbl = line.get_label()
            if isinstance(lbl, str) and lbl.startswith("m="):
                mm = int(lbl.split("=")[1])
                color_by_m[mm] = line.get_color()
    except Exception:
        # Best effort: any failure (e.g. ax2 missing, unparsable label) leaves
        # whatever was collected so far; the fallback palette below covers the rest.
        pass

# Prepare the (m, d_model, D_K) table and x = m log m / d_model
sc2_df = (
    min_rows[["m", "d_model", "D_K"]]
    .dropna()
    .copy()
)
# The logarithm used here is base 2 (np.log2).
sc2_df["x_metric"] = (
    sc2_df["m"].astype(float) * np.log2(sc2_df["m"].astype(float))
    / sc2_df["d_model"].astype(float)
)

# Deterministic fallback mapping if we couldn't recover colors from ax2
if not color_by_m:
    cycle_colors = plt.rcParams["axes.prop_cycle"].by_key().get("color", [])
    if not cycle_colors:
        cycle_colors = [f"C{i}" for i in range(10)]
    m_unique = sorted(int(v) for v in sc2_df["m"].unique())
    color_by_m = {m_val: cycle_colors[i % len(cycle_colors)] for i, m_val in enumerate(m_unique)}

# --- NEW: bring in DK CI bounds computed for the table (same logic, no recomputation) ---
# Prefer using the tidy table you just wrote; fallback to the pivot grids if needed.
if ("triplet_df" in locals()) and {"m","d_model","DK_upper","DK_lower"}.issubset(set(triplet_df.columns)):
    tmp_trip = triplet_df[["m","d_model","DK_upper","DK_lower"]].copy()
else:
    # Fallback: reconstruct from the already-built grids
    # (`vals`, `dk_upper`, `dk_lower` are defined outside this section; this
    # branch raises NameError if they are missing).
    mean_long_f  = vals.stack(dropna=False).rename("DK_mean").reset_index()
    upper_long_f = dk_upper.stack(dropna=False).rename("DK_upper").reset_index()
    lower_long_f = dk_lower.stack(dropna=False).rename("DK_lower").reset_index()
    tmp_trip = (mean_long_f
                .merge(upper_long_f, on=["m","d_model"], how="left")
                .merge(lower_long_f, on=["m","d_model"], how="left"))[["m","d_model","DK_upper","DK_lower"]]

# Ensure join keys line up in type/format
sc2_df_ci = sc2_df.copy()
sc2_df_ci["m"] = sc2_df_ci["m"].astype(int)
sc2_df_ci["d_model"] = sc2_df_ci["d_model"].astype(int)

tmp_trip = tmp_trip.copy()
tmp_trip["m"] = tmp_trip["m"].astype(int)
tmp_trip["d_model"] = tmp_trip["d_model"].astype(int)

# Merge DK bounds into the scatter rows
# (left join: scatter rows without bounds get NaN in DK_upper / DK_lower).
sc2_df_ci = sc2_df_ci.merge(tmp_trip, on=["m","d_model"], how="left")

# Compute asymmetric vertical errors:
# bottom = DK_mean - DK_lower (optimistic); top = DK_upper - DK_mean (pessimistic)
sc2_df_ci["yerr_low"]  = (sc2_df_ci["D_K"] - sc2_df_ci["DK_lower"]).astype(float)
sc2_df_ci["yerr_high"] = (sc2_df_ci["DK_upper"] - sc2_df_ci["D_K"]).astype(float)

# sanitize: non-finite or negative lengths become NaN so no bar is drawn for them
for col in ["yerr_low", "yerr_high"]:
    # Work on an explicit copy: to_numpy() may hand back a view of the column,
    # which is read-only when pandas Copy-on-Write is active, and writing into
    # it in place would then raise.
    arr = sc2_df_ci[col].to_numpy(dtype=float, copy=True)
    arr[~np.isfinite(arr)] = np.nan
    arr[arr < 0] = np.nan
    sc2_df_ci[col] = arr

# --- Plot ---
fig_sc2, ax_sc2 = plt.subplots(figsize=(5.6, 5.0))

# Draw error bars first (slightly lower zorder) so markers sit on top
for m_val, grp in sc2_df_ci.groupby("m", sort=True):
    x_pts = grp["x_metric"].astype(float).values
    y_pts = grp["D_K"].astype(float).values
    # Row 0 = lengths below the point, row 1 = lengths above; entries may be NaN.
    yerr  = np.vstack([grp["yerr_low"].values, grp["yerr_high"].values])  # [2 x N]

    e_kwargs = dict(fmt="none", elinewidth=0.9, capsize=3, alpha=0.85, zorder=1.5)
    if int(m_val) in color_by_m:
        e_kwargs["ecolor"] = color_by_m[int(m_val)]
    ax_sc2.errorbar(x_pts, y_pts, yerr=yerr, **e_kwargs)

# Scatter the points grouped by m, using the shared color scheme (unchanged)
for m_val, grp in sc2_df_ci.groupby("m", sort=True):
    x_pts = grp["x_metric"].astype(float).values
    y_pts = grp["D_K"].astype(float).values
    kwargs = dict(s=46, alpha=0.9, label=f"m={int(m_val)}", zorder=2)
    if int(m_val) in color_by_m:
        kwargs["color"] = color_by_m[int(m_val)]
    ax_sc2.scatter(x_pts, y_pts, **kwargs)

# Global linear fit constrained through the origin: D_K ≈ a * (m log m / d_model) (unchanged)
x = sc2_df_ci["x_metric"].to_numpy(dtype=float)
y = sc2_df_ci["D_K"].to_numpy(dtype=float)
mask = np.isfinite(x) & np.isfinite(y)
x, y = x[mask], y[mask]

# `a` and `r2_0` are only bound when the fit runs; every later use is guarded
# by `fit_done`.
fit_done = False
if x.size >= 1:
    sxx = float(np.dot(x, x))
    if sxx > 0:
        a = float(np.dot(x, y) / sxx)  # OLS slope with intercept fixed at 0

        # Start the drawn line at the origin (or below it if x has negative values).
        x_start = min(0.0, float(x.min()))
        x_end = float(x.max())
        xs_fit = np.linspace(x_start, x_end, 200)
        ys_fit = a * xs_fit

        # R^2 for a through-origin model: 1 - SSE / sum(y^2)
        y_pred = a * x
        ss_res = float(np.sum((y - y_pred) ** 2))
        ss_tot0 = float(np.sum(y ** 2))
        r2_0 = 1.0 - ss_res / ss_tot0 if ss_tot0 > 0 else float("nan")

        ax_sc2.plot(
            xs_fit, ys_fit,
            linestyle="--", linewidth=1.0, color="black", zorder=1,
            label=f"y = {a:.3g} x"  #  (R²₀={r2_0:.3f})
        )
        fit_done = True

# Labels, title, and styling (unchanged)
ax_sc2.set_xlabel(r"$m \log m \,/\, d_{\mathrm{model}}$")
ax_sc2.set_ylabel(r"Minimum total key dimension $D_K$")
ax_sc2.set_title(r"Min $D_K$ vs $m \log m / d_{\mathrm{model}}$")
ax_sc2.grid(alpha=0.3, linewidth=0.6)

# Legend placement matches earlier behavior (unchanged)
if sc2_df_ci["m"].nunique() <= 10:
    ax_sc2.legend(fontsize=9, title="m")
else:
    ax_sc2.legend(fontsize=8, title="m", bbox_to_anchor=(1.02, 1.0),
                  loc="upper left", borderaxespad=0.)

plt.tight_layout()

# Save + show
scatter2_ci_path = os.path.join(TARGET_DIR, "min_DK_vs_m_log_m_over_d_scatter_linear_fit.png")
fig_sc2.savefig(scatter2_ci_path, dpi=150, bbox_inches="tight")
if fit_done:
    print(f"Saved scatter+fit+CI to: {scatter2_ci_path} | Exists? {os.path.exists(scatter2_ci_path)}")
    print(f"y = {a:.3g} x (R²₀ = {r2_0:.3f})")
else:
    print(f"Saved scatter+CI (no fit drawn) to: {scatter2_ci_path} | Exists? {os.path.exists(scatter2_ci_path)}")
    print("Skipped linear fit through origin: insufficient variation or zero range in x.")
plt.show()
plt.close(fig_sc2)



# =============================================================================
# DIAGONAL SWEEP PLOTS — with error intervals on the bars (heads)
# Colored so the line+its CI and bars+their CI are easy to distinguish.
#
# One chart per ratio in DIAGONAL_RATIOS:
#   • Left y-axis (line): minimum DK (with DK CI if available)
#   • Right y-axis (bars): selected heads h, with asymmetric CI error bars
# X-axis: the (m, d_model) pairs on the diagonal, increasing m from left to right
# =============================================================================

import os
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker

# ----------------------------
# Pre-flight & light defaults
# ----------------------------
# Default the output directory to the current directory when it was not set earlier.
if "TARGET_DIR" not in globals():
    TARGET_DIR = "."

# ---- Colors (choose anything you like) ----
# Each color keeps a value already present in globals() and otherwise takes
# the default given here; the error-bar colors default to their element's color.
LINE_COLOR      = globals().get("LINE_COLOR", "#1f77b4")  # line + its error bars (tab:blue)
LINE_ERR_COLOR  = globals().get("LINE_ERR_COLOR", LINE_COLOR)

BAR_FACE_COLOR  = globals().get("BAR_FACE_COLOR", "#ff7f0e")  # bars (tab:orange)
BAR_EDGE_COLOR  = globals().get("BAR_EDGE_COLOR", "#ad5a0a")  # bar edges
BAR_ERR_COLOR   = globals().get("BAR_ERR_COLOR", BAR_EDGE_COLOR)  # bar CIs

if "min_rows" not in globals() or not isinstance(min_rows, pd.DataFrame):
    raise RuntimeError(
        "This script expects a DataFrame `min_rows` with columns ['m','d_model','D_K','h']."
    )

# If DIAGONAL_RATIOS isn't provided, infer from the data you actually have in min_rows
# (also taken when DIAGONAL_RATIOS exists but is empty/falsy).
if "DIAGONAL_RATIOS" not in globals() or not DIAGONAL_RATIOS:
    tmp = min_rows.dropna(subset=["m","d_model"]).copy()
    tmp["ratio"] = tmp["m"].astype(float) / tmp["d_model"].astype(float)
    # Keep unique ratios in ascending order
    DIAGONAL_RATIOS = sorted(float(r) for r in np.unique(tmp["ratio"].to_numpy(dtype=float)))

# Keep the selected min DK and its h
minDK_and_h = (
    min_rows[["m", "d_model", "D_K", "h"]]
    .rename(columns={"D_K": "min_DK", "h": "h_selected"})
    .copy()
)
# NOTE(review): unlike the ratio inference above, no dropna() precedes these
# integer casts, so a missing m or d_model in min_rows would raise here.
minDK_and_h["m"] = minDK_and_h["m"].astype(int)
minDK_and_h["d_model"] = minDK_and_h["d_model"].astype(int)

# ---------------------------------------
# Helpers
# ---------------------------------------
def _pairs_on_ratio(df_pairs: pd.DataFrame, ratio: float) -> pd.DataFrame:
    """Select the rows of *df_pairs* whose m / d_model equals *ratio*.

    Args:
        df_pairs: table with at least the columns 'm' and 'd_model'.
        ratio: target value of m / d_model.

    Returns:
        A copy of the matching rows (original index preserved) with 'm' and
        'd_model' cast to int. The input frame is not modified.
    """
    target = float(ratio)
    pairs = df_pairs.copy()

    if target.is_integer():
        # Whole ratios are matched exactly as m == ratio * d_model, with no division.
        on_diagonal = pairs["m"] == int(round(target)) * pairs["d_model"]
    else:
        # Fractional ratios are compared with a tight floating-point tolerance.
        on_diagonal = np.isclose(pairs["m"] / pairs["d_model"], target, rtol=1e-9, atol=1e-12)

    selected = pairs.loc[on_diagonal].copy()
    for key in ("m", "d_model"):
        selected[key] = selected[key].astype(int)
    return selected

def _pow2_label(y, _):
    """Tick formatter: render exact powers of two as "$2^{k}$", everything else as "".

    Args:
        y: tick value.
        _: tick position passed by matplotlib's FuncFormatter (ignored).

    Returns:
        The LaTeX label for y when y > 0 and log2(y) is numerically an
        integer; the empty string otherwise.
    """
    label = ""
    if y <= 0:
        # No real base-2 logarithm exists for non-positive values.
        return label
    k = np.log2(y)
    if np.isclose(k, round(k)):
        label = fr"$2^{{{int(round(k))}}}$"
    return label

# For consistent ticks across plots: powers-of-two covering ALL chosen h
# Positive, non-missing selected head counts across all of min_rows; the
# power-of-two exponents bracketing them give the same tick set for every chart.
_all_h_selected = pd.Series(min_rows["h"], dtype="float64").dropna()
_all_h_selected = _all_h_selected[_all_h_selected > 0]
if not _all_h_selected.empty:
    k_min_h = int(np.floor(np.log2(_all_h_selected.min())))
    k_max_h = int(np.ceil(np.log2(_all_h_selected.max())))
    pow2_ticks_h = [2**k for k in range(k_min_h, k_max_h + 1)]
else:
    # No data: fall back to a fixed set of ticks.
    pow2_ticks_h = [1, 2, 4, 8, 16, 32]

# ---------------------------------------
# DK CI source for the line (optional)
# ---------------------------------------
# Sources are tried in order: `triplet_df`, then `tmp_trip`, then the pivot
# grids `vals` / `dk_upper` / `dk_lower`. If none is available, diag_ci_src
# stays None.
diag_ci_src = None
if ("triplet_df" in globals()) and isinstance(triplet_df, pd.DataFrame) \
   and {"m","d_model","DK_upper","DK_lower"}.issubset(triplet_df.columns):
    diag_ci_src = triplet_df[["m","d_model","DK_upper","DK_lower"]].copy()
elif ("tmp_trip" in globals()) and isinstance(tmp_trip, pd.DataFrame) \
     and {"m","d_model","DK_upper","DK_lower"}.issubset(tmp_trip.columns):
    diag_ci_src = tmp_trip[["m","d_model","DK_upper","DK_lower"]].copy()
else:
    # Fallback: reconstruct from pivot grids if available
    if all(name in globals() and isinstance(globals()[name], pd.DataFrame)
           for name in ("vals","dk_upper","dk_lower")):
        mean_long_f  = vals.stack(dropna=False).rename("DK_mean").reset_index()
        upper_long_f = dk_upper.stack(dropna=False).rename("DK_upper").reset_index()
        lower_long_f = dk_lower.stack(dropna=False).rename("DK_lower").reset_index()
        diag_ci_src = (
            mean_long_f
            .merge(upper_long_f, on=["m","d_model"], how="left")
            .merge(lower_long_f, on=["m","d_model"], how="left")
        )[["m","d_model","DK_upper","DK_lower"]].copy()

# Normalize the join keys to int so they match minDK_and_h.
if diag_ci_src is not None:
    diag_ci_src["m"] = diag_ci_src["m"].astype(int)
    diag_ci_src["d_model"] = diag_ci_src["d_model"].astype(int)

# ---------------------------------------
# Build/locate tidy per-head F1 stats to compute head-range CIs (bars)
# Required columns: ['m','d_model','h','D_K','F1_mean','F1_std','F1_n']
# ---------------------------------------
def _coerce_per_head_df():
    """Find per-head F1 statistics among the module globals and return them tidy.

    Two sources are tried, in this order:

    1. A tidy DataFrame bound to one of the global names ``per_head``,
       ``stats_long``, ``triplet_df_full``, ``all_stats``, ``summary_long`` or
       ``f1_stats_by_head``. Known column aliases (matched case-insensitively)
       are renamed to the canonical names. An alias is ignored when its
       canonical column already exists, so renaming never produces duplicate
       column labels. The first candidate that ends up with every required
       column is used.
    2. Four pivot grids (F1 mean, F1 std, F1 count, D_K), each looked up under
       several conventional global names. Every grid is stacked to long form
       with its columns becoming ``h`` and the results are merged on
       ``(m, d_model, h)``; the grids' index levels therefore have to be named
       ``m`` and ``d_model``.

    Returns
    -------
    pandas.DataFrame or None
        A new frame with columns
        ``['m', 'd_model', 'h', 'D_K', 'F1_mean', 'F1_std', 'F1_n']`` in exactly
        that order (ints for ``m``/``d_model``/``h``/``F1_n``, floats for the
        rest), or ``None`` when neither source is available.
    """
    # Canonical output schema; a list (not a set) so column order is stable.
    columns = ["m", "d_model", "h", "D_K", "F1_mean", "F1_std", "F1_n"]
    dtypes = {
        "m": int, "d_model": int, "h": int,
        "D_K": float, "F1_mean": float, "F1_std": float, "F1_n": int,
    }

    # 1) If a tidy table already exists under a common name, normalize its columns.
    aliases = {
        "d_k": "D_K", "dk": "D_K",
        "f1_mean": "F1_mean", "f1_mu": "F1_mean",
        "f1_std": "F1_std", "f1_sigma": "F1_std", "std_f1": "F1_std",
        "f1_n": "F1_n", "n_f1": "F1_n", "count_f1": "F1_n", "n": "F1_n",
    }
    tidy_candidates = (
        "per_head", "stats_long", "triplet_df_full", "all_stats", "summary_long", "f1_stats_by_head"
    )
    for name in tidy_candidates:
        obj = globals().get(name)
        if not isinstance(obj, pd.DataFrame):
            continue
        # Canonical names already present take precedence over any alias.
        taken = {c for c in obj.columns if c in columns}
        rename = {}
        for c in obj.columns:
            target = aliases.get(str(c).lower())
            if target is None or c == target or target in taken:
                continue
            rename[c] = target
            taken.add(target)
        df = obj.rename(columns=rename)
        if set(columns).issubset(df.columns):
            return df[columns].astype(dtypes)

    # 2) Else rebuild from per-head pivot grids if present.
    grid_aliases = {
        "F1_mean": ("F1_mean", "f1_mean", "F1_mu", "f1_mu", "mean_f1"),
        "F1_std": ("F1_std", "f1_std", "F1_sigma", "f1_sigma", "std_f1"),
        "F1_n": ("F1_n", "f1_n", "count_f1", "n_f1", "N_f1"),
        "D_K": ("D_K", "DK", "DK_by_head", "D_K_by_head", "d_k", "dk_grid"),
    }
    grids = {}
    for canonical, names in grid_aliases.items():
        for key in names:
            candidate = globals().get(key)
            if isinstance(candidate, pd.DataFrame):
                grids[canonical] = candidate
                break

    if len(grids) != len(grid_aliases):
        return None

    def _stack(grid, out_name):
        # Long form of one grid: index levels + 'h' + a single value column.
        tmp = grid.copy()
        tmp.columns.name = "h"  # ensure stacked column is 'h'
        long = tmp.stack(dropna=False).rename(out_name).reset_index()
        long["h"] = long["h"].astype(int)
        return long

    keys = ["m", "d_model", "h"]
    df = _stack(grids["F1_mean"], "F1_mean")
    for canonical in ("F1_std", "F1_n", "D_K"):
        df = df.merge(_stack(grids[canonical], canonical), on=keys, how="left")

    return df[columns].astype(dtypes)

# Resolve the per-head F1 table once; the head-range computation below reads
# it, so fail immediately with an actionable message if no source was found.
per_head = _coerce_per_head_df()
if per_head is None:
    raise RuntimeError(
        "Could not find per-head F1 stats. Provide `per_head` with columns "
        "['m','d_model','h','D_K','F1_mean','F1_std','F1_n'], or the pivot grids."
    )

# ---------------------------------------
# Welch one-sided t-test (candidate vs winner): H1: mu_c < mu_w
# Keep a candidate if p >= alpha (cannot eliminate at 95%)
# ---------------------------------------
def p_less_welch(mu_c, sd_c, n_c, mu_w, sd_w, n_w):
    """One-sided Welch p-value for H1: candidate mean < winner mean.

    Parameters are the mean, standard deviation and sample count of the
    candidate (``*_c``) and of the winner (``*_w``).

    Returns P(T <= t), where t is the Welch statistic of (mu_c - mu_w), using
    Student's t with Welch–Satterthwaite degrees of freedom when SciPy can be
    used and a standard-normal approximation otherwise. If the combined
    standard error is non-finite or effectively zero, the result degenerates
    to 1.0 when ``mu_c >= mu_w`` and 0.0 otherwise.
    """
    var_c = (sd_c ** 2) / max(int(n_c), 1)
    var_w = (sd_w ** 2) / max(int(n_w), 1)
    pooled = var_c + var_w

    # Degenerate variance: fall back to a plain comparison of the means.
    if not (np.isfinite(pooled) and pooled > 1e-16):
        if mu_c >= mu_w:
            return 1.0
        return 0.0

    t_stat = (mu_c - mu_w) / math.sqrt(pooled)

    # Welch–Satterthwaite denominator; a sample contributes only when it has
    # more than one observation and a positive variance term.
    denom = sum(
        ((v ** 2) / (count - 1)
         for v, count in ((var_c, n_c), (var_w, n_w))
         if count > 1 and v > 0),
        0.0,
    )
    if denom > 0:
        dof = (pooled ** 2) / denom
    else:
        dof = 1e9  # large df => normal approx

    try:
        from scipy.stats import t as student_t
        return float(student_t.cdf(t_stat, dof))  # one-sided: P(T <= t)
    except Exception:
        # Normal approximation if SciPy is unavailable
        return 0.5 * (1.0 + math.erf(t_stat / math.sqrt(2.0)))

# Significance level of the one-sided Welch test (95% rule). Deliberately NOT
# named `alpha`: the configuration section at the top of this file binds
# `alpha` to the attention sharpness (10.0), and rebinding it here would
# silently change that setting for any code executed afterwards.
H_CI_ALPHA = 0.05
dk_tol = 0.10  # ±10% window on DK around the winner's DK

# ---------------------------------------
# Compute per-(m,d_model) head-range bounds & asymmetric y-errors
# ---------------------------------------
# One row per (m, d_model) with the head count that was selected for it.
winners = (
    min_rows[["m","d_model","h"]]
    .dropna()
    .astype(int)
    .rename(columns={"h":"h_selected"})
)

bounds = []
for _, r in winners.iterrows():
    m_val, d_val, h_star = int(r["m"]), int(r["d_model"]), int(r["h_selected"])
    sub = per_head[(per_head["m"]==m_val) & (per_head["d_model"]==d_val)].copy()
    if sub.empty:
        # No per-head stats for this pair: the range collapses to the winner.
        bounds.append({"m": m_val, "d_model": d_val, "h_lower": h_star, "h_upper": h_star})
        continue

    win = sub[sub["h"]==h_star]
    if win.empty:
        # Winner head is missing from the stats table: same degenerate range.
        bounds.append({"m": m_val, "d_model": d_val, "h_lower": h_star, "h_upper": h_star})
        continue

    # NOTE(review): if `per_head` holds several D_K rows for the winning head,
    # only the first one is used as the reference — confirm row ordering.
    mu_w = float(win["F1_mean"].iloc[0])
    sd_w = float(win["F1_std"].iloc[0])
    n_w  = int(win["F1_n"].iloc[0])
    DK_w = float(win["D_K"].iloc[0])

    # Candidates with DK within ±10% of the winner's DK
    lo, hi = (1.0 - dk_tol) * DK_w, (1.0 + dk_tol) * DK_w
    near = sub[(sub["D_K"] >= lo) & (sub["D_K"] <= hi)].copy()
    if near.empty:
        bounds.append({"m": m_val, "d_model": d_val, "h_lower": h_star, "h_upper": h_star})
        continue

    keep_mask = []
    for _, c in near.iterrows():
        p_less = p_less_welch(
            float(c["F1_mean"]), float(c["F1_std"]), int(c["F1_n"]),
            mu_w, sd_w, n_w
        )
        # Keep if NOT significantly worse than winner at the chosen level
        keep_mask.append(p_less >= H_CI_ALPHA)

    kept = near.loc[keep_mask]
    if kept.empty:
        h_lo = h_hi = h_star
    else:
        h_lo = int(np.nanmin(kept["h"].to_numpy()))
        h_hi = int(np.nanmax(kept["h"].to_numpy()))
        # Ensure the winner lies within the range
        h_lo = min(h_lo, h_star)
        h_hi = max(h_hi, h_star)

    bounds.append({"m": m_val, "d_model": d_val, "h_lower": h_lo, "h_upper": h_hi})

bounds_df = pd.DataFrame(bounds)

# Tidy CI table with asymmetric y-errors relative to the winner head
h_ci_df = (
    winners.merge(bounds_df, on=["m","d_model"], how="left")
    .assign(
        yerr_low  = lambda df: (df["h_selected"] - df["h_lower"]).clip(lower=0).astype(float),
        yerr_high = lambda df: (df["h_upper"] - df["h_selected"]).clip(lower=0).astype(float),
    )
)

# =============================================================================
# Plot: one figure per diagonal ratio
# =============================================================================
# For every entry of DIAGONAL_RATIOS, draw one dual-axis figure: minimum D_K
# as a line on the left axis and the selected head count as bars on a
# log-scaled right axis. Each figure is saved as a PNG under TARGET_DIR,
# shown, and closed.
for ratio in DIAGONAL_RATIOS:
    # NOTE(review): `_pairs_on_ratio` is defined outside this section; presumably
    # it returns the rows of `minDK_and_h` whose m/d_model equals `ratio` — confirm.
    diag_df = _pairs_on_ratio(minDK_and_h, ratio)
    if diag_df.empty:
        print(f"(Diagonal ratio {ratio}) No (m, d_model) pairs found on this diagonal; skipping.")
        continue

    # Order by increasing m
    diag_df = diag_df.sort_values(["m", "d_model"]).reset_index(drop=True)

    # X-axis labels "(m,d)"; points are placed at integer positions 0..N-1
    x_labels = [f"({int(m)},{int(d)})" for m, d in zip(diag_df["m"], diag_df["d_model"])]
    x = np.arange(len(diag_df), dtype=float)

    # Figure size scales with #points
    fig_w = max(6.5, 1.0 + 0.9 * len(diag_df))
    fig_h = 5.2
    fig_diag, ax_left = plt.subplots(figsize=(fig_w, fig_h))

    # ---------------------------------------
    # Left axis — line: min DK (with DK CI if available)
    # ---------------------------------------
    ax_left.plot(
        x, diag_df["min_DK"].astype(float).values,
        marker="o", linewidth=1.8, label="Min $D_K$",
        color=LINE_COLOR, zorder=2.2
    )

    # Merge DK CI bounds for these points (if we have them)
    if diag_ci_src is not None:
        # NOTE(review): this assumes at most one row per (m, d_model) in
        # `diag_ci_src`; duplicates would make the merged frame longer than `x`.
        _diag_ci = diag_df.merge(diag_ci_src, on=["m","d_model"], how="left")
        y_vals = _diag_ci["min_DK"].astype(float).to_numpy()
        # Asymmetric whisker lengths measured from the plotted min D_K value
        yerr_low  = (y_vals - _diag_ci["DK_lower"].astype(float).to_numpy())
        yerr_high = (_diag_ci["DK_upper"].astype(float).to_numpy() - y_vals)

        # sanitize: nonnegative, finite only (anything else becomes NaN)
        yerr_low[~np.isfinite(yerr_low) | (yerr_low < 0)] = np.nan
        yerr_high[~np.isfinite(yerr_high) | (yerr_high < 0)] = np.nan

        ax_left.errorbar(
            x, y_vals, yerr=np.vstack([yerr_low, yerr_high]),
            fmt="none", elinewidth=1.0, capsize=3, alpha=0.95,
            ecolor=LINE_ERR_COLOR, zorder=2.1
        )

    # Axis labels/grid
    ax_left.set_ylabel("Minimum $D_K$")
    ax_left.set_xlabel(r"$(m, d_{\mathrm{model}})$")
    ax_left.grid(alpha=0.3, linewidth=0.6)
    ax_left.margins(y=0.2)

    # X ticks
    ax_left.set_xticks(x)
    ax_left.set_xticklabels(x_labels, rotation=45, ha="right")

    # ---------------------------------------
    # Right axis — bars: selected heads, with asymmetric CI error bars
    # ---------------------------------------
    ax_right = ax_left.twinx()
    bar_heights = diag_df["h_selected"].astype(float).values
    bars = ax_right.bar(
        x, bar_heights,
        alpha=0.35, width=0.8, label="Heads",
        color=BAR_FACE_COLOR, edgecolor=BAR_EDGE_COLOR, linewidth=1.0, zorder=1.0
    )
    ax_right.set_ylabel("Heads producing min $D_K$")
    ax_right.set_yscale("log")

    # Asymmetric head-range error bars on the bars, taken from h_ci_df
    _diag_hci = diag_df.merge(
        h_ci_df[["m","d_model","yerr_low","yerr_high"]],
        on=["m","d_model"], how="left"
    )
    _yl = _diag_hci["yerr_low"].to_numpy(dtype=float)
    _yh = _diag_hci["yerr_high"].to_numpy(dtype=float)
    # Missing, non-finite or negative extents are drawn as zero-length whiskers
    _yl[~np.isfinite(_yl)] = 0.0
    _yh[~np.isfinite(_yh)] = 0.0
    _yl[_yl < 0] = 0.0
    _yh[_yh < 0] = 0.0

    ax_right.errorbar(
        x,
        bar_heights,
        yerr=np.vstack([_yl, _yh]),
        fmt="none", elinewidth=1.0, capsize=3, alpha=0.95,
        ecolor=BAR_ERR_COLOR, zorder=1.5
    )

    # Pad the log scale by one power of 2 on each side of the data
    # NOTE(review): the limits are derived from h_selected only; an upper
    # whisker reaching beyond 2**(kmax + 1) would be clipped — confirm acceptable.
    diag_h = diag_df["h_selected"].astype(float)
    diag_h = diag_h[diag_h > 0]
    if not diag_h.empty:
        kmin = int(np.floor(np.log2(diag_h.min())))
        kmax = int(np.ceil(np.log2(diag_h.max())))
        y_min = 2 ** (kmin - 1)
        y_max = 2 ** (kmax + 1)
        ax_right.set_ylim(y_min, y_max)

    # Major ticks only at global powers-of-two for consistency
    ax_right.yaxis.set_major_locator(mticker.FixedLocator(pow2_ticks_h))
    ax_right.yaxis.set_minor_locator(mticker.NullLocator())
    ax_right.yaxis.set_major_formatter(mticker.FuncFormatter(_pow2_label))

    # ---------------------------------------
    # Title + legend (one combined legend for both axes)
    # ---------------------------------------
    ax_left.set_title(fr"Diagonal: $m/d_{{\mathrm{{model}}}} = {ratio}$")
    handles_left, labels_left = ax_left.get_legend_handles_labels()
    handles_right, labels_right = ax_right.get_legend_handles_labels()
    ax_left.legend(handles_left + handles_right, labels_left + labels_right, fontsize=9)

    plt.tight_layout()

    # Save + show ("." in the ratio is replaced by "p" to keep the filename clean)
    ratio_str = str(ratio).replace(".", "p")
    out_name = f"diagonal_ratio_{ratio_str}_minDK_and_selected_h.png"
    out_path = os.path.join(TARGET_DIR, out_name)
    fig_diag.savefig(out_path, dpi=150, bbox_inches="tight")
    print(f"Saved diagonal plot for ratio {ratio}: {out_path} | Exists? {os.path.exists(out_path)}")

    plt.show()
    plt.close(fig_diag)



# --------------------------------------------------------------------------------
# ERROR-BAR SOURCES DISCOVERY + HELPERS (used by per-(m,d_model) figures)
# --------------------------------------------------------------------------------

# We try to be flexible with column names
# Accepted spellings for each kind of dispersion column, in priority order.
_STD_COLS = ["microF1_std", "microF1_stddev"]
_SE_COLS  = ["microF1_stderr", "microF1_se", "microF1_sem"]
_CI_COLS  = ["microF1_ci", "microF1_ci95", "microF1_CI"]

# For each kind, keep the first spelling that exists in unique_df (else None).
STD_COL, SE_COL, CI_COL = [
    next((name for name in candidates if name in unique_df.columns), None)
    for candidates in (_STD_COLS, _SE_COLS, _CI_COLS)
]

if STD_COL or SE_COL or CI_COL:
    present = ", ".join(filter(None, (STD_COL, SE_COL, CI_COL)))
    print(f"Per-combo plots: error-bar sources detected -> {present}")
else:
    print("Per-combo plots will NOT include error bars (no std/sem/ci column found).")

# 95% two-sided critical values for Student's t at df = 1..30
_T_CRIT_95 = {
    1:12.706, 2:4.303, 3:3.182, 4:2.776, 5:2.571, 6:2.447, 7:2.365, 8:2.306, 9:2.262,
    10:2.228, 11:2.201, 12:2.179, 13:2.160, 14:2.145, 15:2.131, 16:2.120, 17:2.110,
    18:2.101, 19:2.093, 20:2.086, 21:2.080, 22:2.074, 23:2.069, 24:2.064, 25:2.060,
    26:2.056, 27:2.052, 28:2.048, 29:2.045, 30:2.042
}

# Standard table anchors beyond df = 30, keyed by 1/df in increasing order
# (df = inf, 120, 60, 40, 30). Critical values are close to linear in 1/df,
# so interpolating on that axis tracks the true curve closely.
_T_CRIT_95_TAIL_INV_DF = (0.0, 1.0 / 120.0, 1.0 / 60.0, 1.0 / 40.0, 1.0 / 30.0)
_T_CRIT_95_TAIL_VALUES = (1.960, 1.980, 2.000, 2.021, 2.042)

def _tcrit_95_df(df: int) -> float:
    """Return the two-sided 95% Student-t critical value for ``df`` degrees of freedom.

    ``df`` is floored at 1 and truncated to an integer. Values up to 30 come
    straight from ``_T_CRIT_95``. Larger values are interpolated linearly in
    ``1/df`` between the table anchors at df = 30, 40, 60, 120 and infinity,
    so the result decreases smoothly towards 1.960 instead of jumping
    straight to 2.000 and 1.960, which understates the critical value (and
    therefore the CI width) just above each threshold.
    """
    df = int(max(1, df))
    if df <= 30:
        return _T_CRIT_95[df]
    return float(np.interp(1.0 / df, _T_CRIT_95_TAIL_INV_DF, _T_CRIT_95_TAIL_VALUES))

def _sanitize_err(arr):
    """Return error-bar lengths as a float array with unusable entries set to NaN.

    Non-finite (NaN, ±inf) and negative values are replaced by NaN. ``None``
    is passed through unchanged. The result is always a new array, so the
    caller's data is never modified in place, even when ``arr`` is already a
    float ndarray.
    """
    if arr is None:
        return None
    values = np.asarray(arr, dtype=float)
    usable = np.isfinite(values) & (values >= 0)
    return np.where(usable, values, np.nan)

def compute_errorbars_for_group(dh: pd.DataFrame, mode: str):
    """
    Returns (yerr, label_str) for the requested mode:
      - mode == "std"  -> ±1 SD across runs
      - any other mode (normally "ci95") -> 95% CI half-width for the mean
    Tries to derive from whichever of {STD_COL, SE_COL, CI_COL} is available,
    using the 'runs' column of `dh` as the per-row sample count when needed.
    If 'runs' is missing and only STD is present, we cannot form a CI for the
    mean -> yerr is None for ci95.

    `yerr` is a float array aligned with the rows of `dh` (NaN where a value is
    unusable, including rows whose run count is not finite), or None when the
    quantity cannot be derived at all. `label_str` names the column that the
    returned quantity was actually derived from.
    """
    n = dh["runs"].to_numpy(dtype=float) if "runs" in dh.columns else None

    std = dh[STD_COL].to_numpy(dtype=float) if (STD_COL and STD_COL in dh.columns) else None
    se  = dh[SE_COL ].to_numpy(dtype=float) if (SE_COL  and SE_COL  in dh.columns) else None
    ci  = dh[CI_COL ].to_numpy(dtype=float) if (CI_COL  and CI_COL  in dh.columns) else None

    def _tcrit_per_row(runs):
        # Two-sided 95% t critical value per row for df = runs - 1 (floored at
        # 1). Rows with a non-finite run count yield NaN rather than relying
        # on the undefined result of casting NaN to int.
        known = np.isfinite(runs)
        dof = np.maximum(1, np.where(known, runs, 2.0).astype(int) - 1)
        crit = np.array([_tcrit_95_df(int(d)) for d in dof], dtype=float)
        return np.where(known, crit, np.nan)

    # --- Build ±1 SD if possible ---
    yerr_std = None
    src_std = None
    if std is not None:
        yerr_std = std.copy()
        src_std = STD_COL
    elif se is not None and n is not None:
        yerr_std = se * np.sqrt(n)
        src_std = SE_COL
    elif ci is not None and n is not None:
        # SD ≈ CI * sqrt(n) / tcrit
        tcrit = _tcrit_per_row(n)
        with np.errstate(divide="ignore", invalid="ignore"):
            yerr_std = np.where((tcrit > 0), ci * np.sqrt(n) / tcrit, np.nan)
        src_std = CI_COL

    # --- Build 95% CI half-width for the mean if possible ---
    # The source is tracked separately from the SD branch: the two quantities
    # can come from different columns (e.g. SD from STD_COL, CI from CI_COL).
    yerr_ci95 = None
    src_ci = None
    if ci is not None:
        yerr_ci95 = ci.copy()
        src_ci = CI_COL
    elif se is not None:
        if n is not None:
            yerr_ci95 = _tcrit_per_row(n) * se
        else:
            # Without runs, fall back to z≈1.96 * SE (still a CI for the mean)
            yerr_ci95 = 1.96 * se
        src_ci = SE_COL
    elif std is not None and n is not None:
        tcrit = _tcrit_per_row(n)
        with np.errstate(divide="ignore", invalid="ignore"):
            yerr_ci95 = np.where(n > 0, tcrit * std / np.sqrt(n), np.nan)
        src_ci = STD_COL
    # else: std may be present but runs is unknown -> no reliable CI for the mean

    if mode == "std":
        return (_sanitize_err(yerr_std), f"±1 SD (source={src_std})")
    return (_sanitize_err(yerr_ci95), f"95% CI (source={src_ci})")

# --------------------------------------------------------------------------------
# PER-(m, d_model) FIGURES — F1 vs DK (one line per head), saved to a subfolder
# Error-bar toggle: ERRORBAR_MODE in {"std", "ci95"}
# --------------------------------------------------------------------------------
# Output folder for the per-(m, d_model) figures; created if it does not exist.
PER_COMBO_SUBDIR = "per_combo_F1_vs_DK_byH"  # change if you like
per_combo_dir = os.path.join(TARGET_DIR, PER_COMBO_SUBDIR)
os.makedirs(per_combo_dir, exist_ok=True)

# NOTE(review): the "will NOT include error bars" message is also printed by
# the column-discovery code earlier in this section, so it can appear twice.
if not any([STD_COL, SE_COL, CI_COL]):
    print("Per-combo plots will NOT include error bars (no std/sem/ci column found).")
else:
    print(f"Per-combo plots error bars mode: {ERRORBAR_MODE!r} "
          f"(std→±1 SD, ci95→95% CI half-width)")

# One figure per (m, d_model) pair: mean micro-F1 against D_K, one line per
# head count, with error bars when they can be computed for that head.
pair_count = 0
for (m_val, d_val), sub in unique_df.groupby(["m", "d_model"], sort=True):
    heads = sorted(sub["h"].unique().tolist())

    fig_w, fig_h = 8.3, 5.2
    fig3, ax3 = plt.subplots(figsize=(fig_w, fig_h))

    # Running min/max of the plotted means; the sentinels mean "nothing seen yet".
    ymin_seen, ymax_seen = +1e9, -1e9

    for h_val in heads:
        dh = sub.loc[sub["h"] == h_val].sort_values("D_K")
        if dh.empty:
            continue

        x = dh["D_K"].astype(float).values
        y = dh["microF1_mean"].astype(float).values
        ymin_seen = min(ymin_seen, float(np.nanmin(y)))
        ymax_seen = max(ymax_seen, float(np.nanmax(y)))

        # Compute error bars for this head, per the toggle
        # NOTE(review): `err_label` is returned but not used anywhere below.
        yerr, err_label = compute_errorbars_for_group(dh, ERRORBAR_MODE)

        # Draw error bars only if at least one finite value exists; otherwise a plain line.
        if (yerr is not None) and np.any(np.isfinite(yerr)):
            ax3.errorbar(
                x, y, yerr=yerr, marker="o", linestyle="-", linewidth=1.8,
                capsize=3, label=f"h={int(h_val)}"
            )
        else:
            ax3.plot(x, y, marker="o", linestyle="-", linewidth=1.8, label=f"h={int(h_val)}")

    ax3.set_xlabel("Total Key Dim $D_K$")
    ax3.set_ylabel("Mean micro-F1 (test)")
    # NOTE(review): `suffix` is computed but never used (the title below does
    # not include it) — confirm whether it was meant to appear in the title.
    suffix = "±1 SD" if ERRORBAR_MODE == "std" else "95% CI"
    ax3.set_title(fr"F1 vs DK — m={int(m_val)}, $d_{{\mathrm{{model}}}}$={int(d_val)}")
    ax3.grid(alpha=0.3, linewidth=0.6)

    # y-limits: give a little headroom, clip to [0, 1.02]; the upper limit is
    # never below 0.85. Skipped when no head contributed any data.
    if ymin_seen < 1e9:
        pad = 0.03
        ax3.set_ylim(max(0.0, ymin_seen - pad), min(1.02, max(0.85, ymax_seen + pad)))

    # legend like the reference figure; with many heads it moves outside the axes
    if len(heads) <= 10:
        ax3.legend(title="Heads", fontsize=9)
    else:
        ax3.legend(title="Heads", fontsize=8, bbox_to_anchor=(1.02, 1.0),
                   loc="upper left", borderaxespad=0.)

    plt.tight_layout()

    # save to subfolder
    out_name = f"F1_vs_DK_m{int(m_val)}_d{int(d_val)}.png"
    out_path = os.path.join(per_combo_dir, out_name)
    fig3.savefig(out_path, dpi=150, bbox_inches="tight")
    print(f"Saved per-combo plot: {out_path} | Exists? {os.path.exists(out_path)}")

    plt.show()
    plt.close(fig3)
    pair_count += 1

print(f"\nCreated {pair_count} per-(m,d_model) figure(s) in: {per_combo_dir}")

# Summary of which files were included
# NOTE(review): `good_files` is defined outside this section; presumably the
# list of input file paths that were loaded — confirm.
print("\nIncluded files:")
for p in good_files:
    print(" •", os.path.basename(p))
