{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ===========================================\n",
    "#   Image Augmentation Selection with CNN + LAKCP\n",
    "#   - CNN from scratch for scoring & classification\n",
    "#   - Doc-level marginal CP and LAKCP conditional CP\n",
    "#   - Generated files: only \"*__gen_<id>__mask.(png|jpg|jpeg)\"\n",
    "# ===========================================\n",
    "\n",
    "from __future__ import annotations\n",
    "import os, re, math, random\n",
    "from dataclasses import dataclass, field\n",
    "from pathlib import Path\n",
    "from typing import List, Dict, Tuple, Optional, Iterable, Any\n",
    "import scipy.spatial as sp\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from PIL import Image\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.utils.data as tud\n",
    "import torchvision.transforms as T\n",
    "import torchvision.models as tv_models\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.decomposition import PCA\n",
    "\n",
    "# ---------- CONFIG (edit these paths & knobs) ----------\n",
    "SEED        = 42\n",
    "DEVICE      = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "EXP_ROOT    = \"/image_net/generated/\"     # folder with generated: <class>/<base_id>/*__gen_<id>__mask.png\n",
    "ORIG_ROOT   = \"/image_net/experiment/train/\"  # base/original images by class: <class>/<base_id>.png\n",
    "VAL_ROOT    = \"/image_net/experiment/val/\"\n",
    "TEST_ROOT   = \"/image_net/experiment/test/\"\n",
    "\n",
    "# Split fractions for building (train/calib/eval) groups by base image\n",
    "TRAIN_FRAC  = 0.3\n",
    "CALIB_FRAC  = (1- TRAIN_FRAC)/2  # remaining (1 - TRAIN_FRAC - CALIB_FRAC) is \"eval\" groups\n",
    "\n",
    "# CP knobs (work on [0,1] CNN probabilities)\n",
    "EPSILON     = 0.10\n",
    "ALPHA_CP    = EPSILON\n",
    "RHO         = 0           # per-doc allowed # bad accepts\n",
    "LAMBDA_OBS  = 0.80        # doc considers a candidate \"bad\" if obs < LAMBDA_OBS\n",
    "\n",
    "# LAKCP conditioning features\n",
    "PCA_DIM     = 3          # PCA dim for raw base image features\n",
    "RAW_SIDE    = 64          # downsample side for raw base images before PCA\n",
    "\n",
    "# Training knobs for CNNs\n",
    "IMG_SIZE    = 128\n",
    "BATCH_SIZE  = 64\n",
    "EPOCHS      = 20\n",
    "LR          = 1e-3\n",
    "WEIGHT_DECAY= 1e-4\n",
    "PATIENCE    = 5\n",
    "\n",
    "# ------------------------------------------------------\n",
    "\n",
    "# ----- Reproducibility -----\n",
    "def set_seed(seed: int = SEED):\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed_all(seed)\n",
    "\n",
    "set_seed(SEED)\n",
    "\n",
    "# ----- LAKCP import -----\n",
    "try:\n",
    "    from LatentKernCP.lakcp import LAKCP\n",
    "except Exception as e:\n",
    "    raise ImportError(\"Could not import LAKCP. Make sure lakcp.py is on your PYTHONPATH.\") from e\n",
    "\n",
    "# =========================\n",
    "# Dataset scanning\n",
    "# =========================\n",
    "\n",
    "@dataclass(frozen=True)\n",
    "class AugGroup:\n",
    "    \"\"\"One base/original image and all of its generated augmentations (mask-variant only).\"\"\"\n",
    "    cls: str\n",
    "    base_id: str\n",
    "    aug_paths: List[Path]\n",
    "\n",
    "@dataclass\n",
    "class ImageUnit:\n",
    "    doc_idx: int\n",
    "    class_name: str\n",
    "    base_id: str\n",
    "    cand_paths: List[str]\n",
    "    A_obs: Optional[List[float]] = field(default=None)  # observed quality (CNN prob)\n",
    "    A_hat: Optional[List[float]] = field(default=None)  # predicted quality (here we reuse prob)\n",
    "\n",
    "# Only keep \"*__gen_<id>__mask.(png|jpg|jpeg)\"\n",
    "_ID_RE       = re.compile(r\"gen[_-]?(\\d+)\", re.IGNORECASE)\n",
    "#_MASK_ONLY_RE= re.compile(r\"__gen[_-]?(\\d+?)\\.(png|jpg|jpeg)$\", re.IGNORECASE)\n",
    "\n",
    "_MASK_ONLY_RE = re.compile(\n",
    "    r\"__gen[_-]?(\\d+?)(?:__mask)?\\.(png|jpg|jpeg)$\",\n",
    "    re.IGNORECASE\n",
    ")\n",
    "\n",
    "\n",
    "def _is_mask_aug(p: Path) -> bool:\n",
    "    return bool(_MASK_ONLY_RE.search(p.name))\n",
    "\n",
    "def _extract_id(p: Path) -> int:\n",
    "    m = _ID_RE.search(p.name)\n",
    "    if m:\n",
    "        try: return int(m.group(1))\n",
    "        except ValueError: return 10**9\n",
    "    return 10**9\n",
    "\n",
    "def _sort_aug_paths(paths: Iterable[Path]) -> List[Path]:\n",
    "    paths = list(paths)\n",
    "    paths.sort(key=_extract_id)\n",
    "    return paths\n",
    "\n",
    "def scan_exp_generated(exp_root: str | Path) -> List[AugGroup]:\n",
    "    \"\"\"Scan classes/<base_id>/*__gen_<id>__mask.* under exp_root (or exp_root/generated).\"\"\"\n",
    "    root = Path(exp_root)\n",
    "    if (root / \"generated\").is_dir():\n",
    "        root = root / \"generated\"\n",
    "\n",
    "    groups: List[AugGroup] = []\n",
    "    if not root.is_dir():\n",
    "        return groups\n",
    "\n",
    "    for class_dir in sorted(d for d in root.iterdir() if d.is_dir()):\n",
    "        for base_dir in sorted(d for d in class_dir.iterdir() if d.is_dir()):\n",
    "            candidates = list(base_dir.glob(\"*.png\")) + list(base_dir.glob(\"*.jpg\")) + list(base_dir.glob(\"*.jpeg\"))\n",
    "            mask_only = [p for p in candidates if _is_mask_aug(p)]\n",
    "            if not mask_only:\n",
    "                continue\n",
    "            paths = _sort_aug_paths(mask_only)\n",
    "            groups.append(AugGroup(cls=class_dir.name, base_id=base_dir.name, aug_paths=paths))\n",
    "    return groups\n",
    "\n",
    "def group_key(g: AugGroup) -> str:\n",
    "    return f\"{g.cls}::{g.base_id}\"\n",
    "\n",
    "def split_by_group(\n",
    "    groups: List[AugGroup],\n",
    "    train_frac: float = TRAIN_FRAC,\n",
    "    calib_frac: float = CALIB_FRAC,\n",
    "    seed: int = SEED,\n",
    ") -> Tuple[List[AugGroup], List[AugGroup], List[AugGroup]]:\n",
    "    \"\"\"Deterministic split by hashing the group key.\"\"\"\n",
    "    assert train_frac < 1 and 0 < calib_frac < 1 and train_frac + calib_frac < 1\n",
    "    import hashlib\n",
    "    def bucket(k: str) -> float:\n",
    "        h = hashlib.sha1((str(seed) + \"|\" + k).encode(\"utf-8\")).hexdigest()\n",
    "        return int(h[:8], 16) / 0xFFFFFFFF\n",
    "    train, calib, eval_ = [], [], []\n",
    "    for g in groups:\n",
    "        r = bucket(group_key(g))\n",
    "        if r < train_frac: train.append(g)\n",
    "        elif r < train_frac + calib_frac: calib.append(g)\n",
    "        else: eval_.append(g)\n",
    "    return train, calib, eval_\n",
    "\n",
    "def build_image_units_from_disk(exp_root: str | Path) -> List[ImageUnit]:\n",
    "    groups = scan_exp_generated(exp_root)\n",
    "    units: List[ImageUnit] = []\n",
    "    for i, g in enumerate(groups):\n",
    "        units.append(ImageUnit(\n",
    "            doc_idx=i, class_name=g.cls, base_id=g.base_id,\n",
    "            cand_paths=[str(p) for p in g.aug_paths],\n",
    "        ))\n",
    "    return units\n",
    "\n",
    "def base_path_for_unit(u: ImageUnit, orig_root: Optional[str]) -> str:\n",
    "    assert orig_root is not None, \"ORIG_ROOT must be provided\"\n",
    "    return str(Path(orig_root) / u.class_name / f\"{u.base_id}.png\")\n",
    "\n",
    "# =========================\n",
    "# Simple CNN (from scratch)\n",
    "# =========================\n",
    "\n",
    "def make_transforms(train: bool) -> T.Compose:\n",
    "    if train:\n",
    "        return T.Compose([\n",
    "            T.Resize((IMG_SIZE, IMG_SIZE)),\n",
    "            T.ToTensor(),\n",
    "            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "        ])\n",
    "    else:\n",
    "        return T.Compose([\n",
    "            T.Resize((IMG_SIZE, IMG_SIZE)),\n",
    "            T.ToTensor(),\n",
    "            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "        ])\n",
    "\n",
    "class ImagePathDataset(tud.Dataset):\n",
    "    def __init__(self, paths: List[str], labels: List[int], transform: T.Compose):\n",
    "        self.paths = paths\n",
    "        self.labels = labels\n",
    "        self.transform = transform\n",
    "    def __len__(self): return len(self.paths)\n",
    "    def __getitem__(self, idx):\n",
    "        img = Image.open(self.paths[idx]).convert(\"RGB\")\n",
    "        x = self.transform(img)\n",
    "        y = self.labels[idx]\n",
    "        return x, y\n",
    "\n",
    "\n",
    "class MediumCNN(nn.Module):\n",
    "    def __init__(self, num_classes: int):\n",
    "        super().__init__()\n",
    "        self.features = nn.Sequential(\n",
    "            nn.Conv2d(3, 32, 3, padding=1),  # 32 filters\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(2),\n",
    "\n",
    "            nn.Conv2d(32, 64, 3, padding=1),  # 64 filters\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(2),\n",
    "\n",
    "            nn.Conv2d(64, 128, 3, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(2),\n",
    "\n",
    "            nn.Conv2d(128, 128, 3, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.AdaptiveAvgPool2d((1, 1)),   # global avg pool\n",
    "        )\n",
    "        self.classifier = nn.Linear(128, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.features(x)\n",
    "        x = x.view(x.size(0), -1)\n",
    "        return self.classifier(x)\n",
    "\n",
    "\n",
    "class SmallCNN(nn.Module):\n",
    "    \"\"\"\n",
    "    Lightweight CNN trained from scratch.\n",
    "    \"\"\"\n",
    "    def __init__(self, num_classes: int):\n",
    "        super().__init__()\n",
    "        def block(cin, cout):\n",
    "            return nn.Sequential(\n",
    "                nn.Conv2d(cin, cout, kernel_size=3, padding=1, bias=False),\n",
    "                nn.BatchNorm2d(cout),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.Conv2d(cout, cout, kernel_size=3, padding=1, bias=False),\n",
    "                nn.BatchNorm2d(cout),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.MaxPool2d(2)\n",
    "            )\n",
    "        self.features = nn.Sequential(\n",
    "            block(3, 32),     # 224 -> 112\n",
    "            block(32, 64),    # 112 -> 56\n",
    "            block(64, 128),   # 56 -> 28\n",
    "            block(128, 256),  # 28 -> 14\n",
    "        )\n",
    "        self.pool = nn.AdaptiveAvgPool2d((1, 1))\n",
    "        self.head = nn.Linear(256, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.features(x)\n",
    "        x = self.pool(x).flatten(1)\n",
    "        return self.head(x)\n",
    "\n",
    "\n",
    "\n",
    "import torch.nn as nn\n",
    "\n",
    "import torch.nn as nn\n",
    "\n",
    "class MuchSmallerCNN(nn.Module):\n",
    "    \"\"\"\n",
    "    A robust and simple CNN for a tiny dataset that handles any image size.\n",
    "    \"\"\"\n",
    "    def __init__(self, num_classes: int = 2):\n",
    "        super().__init__()\n",
    "        # A simple convolutional block\n",
    "        def block(cin, cout):\n",
    "            return nn.Sequential(\n",
    "                nn.Conv2d(cin, cout, kernel_size=3, padding=1),\n",
    "                nn.BatchNorm2d(cout),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.MaxPool2d(2) # Halve the image size\n",
    "            )\n",
    "\n",
    "        self.features = nn.Sequential(\n",
    "            block(3, 16),\n",
    "            block(16, 32),\n",
    "        )\n",
    "\n",
    "        # This adaptive pooling layer is the key change\n",
    "        self.pool = nn.AdaptiveAvgPool2d((1, 1))\n",
    "\n",
    "        self.head = nn.Sequential(\n",
    "            nn.Flatten(),\n",
    "            nn.Linear(32, num_classes) # The input is now just the number of channels (32)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.features(x)\n",
    "        x = self.pool(x)\n",
    "        return self.head(x)\n",
    "\n",
    "# Now this will work for any IMG_SIZE\n",
    "# model = MuchSmallerCNN(num_classes=2)\n",
    "\n",
    "# You would use it like this:\n",
    "# model = MuchSmallerCNN(num_classes=2)\n",
    "\n",
    "import torch, numpy as np, random\n",
    "\n",
    "def set_seed(seed=42):\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "    # fully deterministic (slower) – optional:\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "\n",
    "\n",
    "from copy import deepcopy\n",
    "import torch.nn as nn\n",
    "\n",
    "def he_init(m):\n",
    "    if isinstance(m, (nn.Conv2d, nn.Linear)):\n",
    "        nn.init.kaiming_normal_(m.weight, nonlinearity=\"relu\")\n",
    "        if m.bias is not None:\n",
    "            nn.init.zeros_(m.bias)\n",
    "    elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):\n",
    "        nn.init.ones_(m.weight); nn.init.zeros_(m.bias)\n",
    "\n",
    "def get_initial_state_for_seed(num_classes: int, seed: int):\n",
    "    # isolate RNG so nothing else is affected\n",
    "    with torch.random.fork_rng(devices=[]):\n",
    "        torch.manual_seed(seed)\n",
    "        model = make_model(num_classes)     # your SmallCNN/MediumCNN/etc.\n",
    "        model.apply(he_init)                # explicit He init\n",
    "    init_state = deepcopy(model.state_dict())  # freeze a copy\n",
    "    return init_state\n",
    "\n",
    "\n",
    "def make_model(num_classes: int) -> nn.Module:\n",
    "    # ResNet18 with random init:\n",
    "    # Implement a simple CNN from scratch CNN\n",
    "    # m = tv_models.resnet18(weights=None)\n",
    "    # m.fc = nn.Linear(m.fc.in_features, num_classes)\n",
    "    #m = tv_models.mobilenet_v2(weights=None)\n",
    "    #m.classifier[1] = nn.Linear(m.classifier[1].in_features, num_classes)\n",
    "    #m = MediumCNN(num_classes)\n",
    "    #m = SmallCNN(num_classes)\n",
    "    m = MuchSmallerCNN(num_classes)\n",
    "    \n",
    "    return m\n",
    "\n",
    "def class_weight_tensor(y: List[int], num_classes: int) -> torch.Tensor:\n",
    "    counts = np.bincount(np.asarray(y), minlength=num_classes).astype(np.float32)\n",
    "    counts[counts == 0] = 1.0\n",
    "    w = counts.sum() / counts\n",
    "    w = w / w.mean()\n",
    "    return torch.tensor(w, dtype=torch.float32, device=DEVICE)\n",
    "\n",
    "\n",
    "def make_loader_fixed_steps(paths, labels, batch_size, steps_per_epoch, seed=None):\n",
    "    ds = ImagePathDataset(paths, labels, make_transforms(train=True))\n",
    "    # oversample/undersample to hit steps_per_epoch exactly\n",
    "    N = len(ds)\n",
    "    rng = np.random.default_rng(seed)\n",
    "    idx = rng.integers(0, N, size=steps_per_epoch * batch_size) if N>0 else np.array([], int)\n",
    "    # create a subset “view” with replacement; fast & deterministic\n",
    "    xs = [paths[i] for i in idx]; ys = [labels[i] for i in idx]\n",
    "    sub = ImagePathDataset(xs, ys, make_transforms(train=True))\n",
    "    return tud.DataLoader(sub, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)\n",
    "\n",
    "\n",
    "\n",
    "def train_cnn_from_scratch(\n",
    "    train_paths: List[str], train_labels: List[int],\n",
    "    val_paths: List[str],   val_labels: List[int],\n",
    "    num_classes: int,\n",
    "    *,\n",
    "    epochs: int = EPOCHS, batch_size: int = BATCH_SIZE, lr: float = LR, weight_decay: float = WEIGHT_DECAY,\n",
    "    init_state=None, \n",
    "    seed = None\n",
    ") -> Tuple[nn.Module, float]:\n",
    "    \n",
    "    if seed is not None:\n",
    "        set_seed(seed)  # reproducible dataloader shuffle & torch ops\n",
    "\n",
    "    model = make_model(num_classes).to(DEVICE)\n",
    "    if init_state is not None:\n",
    "        missing, unexpected = model.load_state_dict(init_state, strict=False)\n",
    "        # optional: print mismatches if you expect exact match\n",
    "        # print(\"[init] missing:\", missing, \" unexpected:\", unexpected)\n",
    "\n",
    "    # ---- deterministic DataLoader shuffling (optional) ----\n",
    "    #   (on CUDA/Linux this is rock-solid; on Mac keep num_workers=0)\n",
    "    gen = torch.Generator()\n",
    "    if seed is not None:\n",
    "        gen.manual_seed(seed)\n",
    "\n",
    "    \n",
    "    num_workers = 2 if DEVICE == \"cuda\" else 0\n",
    "    pin = (DEVICE == \"cuda\")\n",
    "\n",
    "    tr_ds = ImagePathDataset(train_paths, train_labels, make_transforms(train=True))\n",
    "    va_ds = ImagePathDataset(val_paths,   val_labels,   make_transforms(train=False))\n",
    "\n",
    "    def seed_worker(worker_id):\n",
    "        worker_seed = seed + worker_id if seed is not None else worker_id\n",
    "        np.random.seed(worker_seed)\n",
    "        random.seed(worker_seed)\n",
    "\n",
    "    tr_ds = ImagePathDataset(train_paths, train_labels, make_transforms(train=True))\n",
    "    va_ds = ImagePathDataset(val_paths,   val_labels,   make_transforms(train=False))\n",
    "\n",
    "    tr_ld = tud.DataLoader(tr_ds, batch_size=batch_size, shuffle=True,\n",
    "                           num_workers=num_workers, pin_memory=pin,\n",
    "                           generator=gen, worker_init_fn=seed_worker if num_workers>0 else None,\n",
    "                           drop_last=False)\n",
    "    va_ld = tud.DataLoader(va_ds, batch_size=batch_size, shuffle=False,\n",
    "                           num_workers=num_workers, pin_memory=pin, drop_last=False)\n",
    "\n",
    "    # --- yo\n",
    "\n",
    "    crit  = nn.CrossEntropyLoss(weight=class_weight_tensor(train_labels, num_classes))\n",
    "    opt   = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n",
    "\n",
    "\n",
    "    best_val, best_state, bad = -1.0, None, 0\n",
    "    for epoch in range(epochs):\n",
    "        model.train()\n",
    "        for xb, yb in tr_ld:\n",
    "            xb, yb = xb.to(DEVICE), yb.to(DEVICE)\n",
    "            opt.zero_grad(set_to_none=True)\n",
    "            logits = model(xb)\n",
    "            loss   = crit(logits, yb)\n",
    "            loss.backward()\n",
    "            opt.step()\n",
    "\n",
    "        # validate\n",
    "        model.eval()\n",
    "        correct, total = 0, 0\n",
    "        with torch.no_grad():\n",
    "            for xb, yb in va_ld:\n",
    "                xb, yb = xb.to(DEVICE), yb.to(DEVICE)\n",
    "                logits = model(xb)\n",
    "                pred   = logits.argmax(dim=1)\n",
    "                correct+= int((pred == yb).sum().item())\n",
    "                total  += int(yb.numel())\n",
    "        val_acc = correct / max(1,total)\n",
    "        if val_acc > best_val:\n",
    "            best_val = val_acc\n",
    "            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}\n",
    "            bad = 0\n",
    "        else:\n",
    "            bad += 1\n",
    "            if bad >= PATIENCE:\n",
    "                break\n",
    "\n",
    "    if best_state is not None:\n",
    "        model.load_state_dict(best_state)\n",
    "    return model, float(best_val)\n",
    "\n",
    "def eval_cnn(model: nn.Module, paths: List[str], labels: List[int], batch_size: int=BATCH_SIZE) -> float:\n",
    "    ds = ImagePathDataset(paths, labels, make_transforms(train=False))\n",
    "    ld = tud.DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=0)\n",
    "    model.eval()\n",
    "    correct, total = 0, 0\n",
    "    with torch.no_grad():\n",
    "        for xb, yb in ld:\n",
    "            xb, yb = xb.to(DEVICE), yb.to(DEVICE)\n",
    "            pred = model(xb).argmax(dim=1)\n",
    "            correct += int((pred == yb).sum().item())\n",
    "            total   += int(yb.numel())\n",
    "    return correct / max(1,total)\n",
    "\n",
    "# =========================\n",
    "# Build classification datasets (paths/labels)\n",
    "# =========================\n",
    "\n",
    "def scan_labeled_folder(root: str, class_to_idx: Dict[str,int]) -> Tuple[List[str], List[int]]:\n",
    "    paths, labels = [], []\n",
    "    r = Path(root)\n",
    "    for class_dir in sorted(d for d in r.iterdir() if d.is_dir()):\n",
    "        c = class_dir.name\n",
    "        if c not in class_to_idx:\n",
    "            continue\n",
    "        for img in sorted(list(class_dir.glob(\"*.png\")) + list(class_dir.glob(\"*.jpg\")) + list(class_dir.glob(\"*.jpeg\"))):\n",
    "            paths.append(str(img))\n",
    "            labels.append(class_to_idx[c])\n",
    "    return paths, labels\n",
    "\n",
    "def make_label_map(units_all: List[ImageUnit]) -> Dict[str, int]:\n",
    "    classes = sorted({u.class_name for u in units_all})\n",
    "    return {c: i for i, c in enumerate(classes)}\n",
    "\n",
    "def build_train_paths_from_selected(\n",
    "    units: List[ImageUnit],\n",
    "    selected: Dict[int, List[str]],\n",
    "    class_to_idx: Dict[str, int],\n",
    ") -> Tuple[List[str], List[int]]:\n",
    "    paths, labels = [], []\n",
    "    for u in units:\n",
    "        y = class_to_idx[u.class_name]\n",
    "        kept = selected.get(u.doc_idx, [base_path_for_unit(u, ORIG_ROOT)])\n",
    "        # dedup preserve order\n",
    "        seen = set()\n",
    "        kept = [p for p in kept if not (p in seen or seen.add(p))]\n",
    "        for p in kept:\n",
    "            paths.append(p)\n",
    "            labels.append(y)\n",
    "    return paths, labels\n",
    "\n",
    "# =========================\n",
    "# Scoring (A_obs, A_hat) with CNN\n",
    "# =========================\n",
    "\n",
    "@torch.no_grad()\n",
    "def cnn_prob_for_paths(model: nn.Module, paths: List[str], class_index: int) -> List[float]:\n",
    "    tfm = make_transforms(train=False)\n",
    "    probs: List[float] = []\n",
    "    model.eval()\n",
    "    # batched for speed\n",
    "    bs = BATCH_SIZE\n",
    "    for i in range(0, len(paths), bs):\n",
    "        chunk = paths[i:i+bs]\n",
    "        batch = torch.stack([tfm(Image.open(p).convert(\"RGB\")) for p in chunk], dim=0).to(DEVICE)\n",
    "        logits = model(batch)\n",
    "        p = logits.softmax(dim=1)[:, class_index].detach().cpu().numpy().tolist()\n",
    "        probs.extend(p)\n",
    "    return probs\n",
    "\n",
    "\n",
    "\n",
    "# ===== Zero-shot CLIP classifier (image -> class probabilities) =====\n",
    "\n",
    "CLIP_IMG_CACHE: Dict[str, np.ndarray] = {}  # image path -> normalized image embedding\n",
    "\n",
    "class CLIPZeroShot:\n",
    "    \"\"\"\n",
    "    Zero-shot classifier with CLIP.\n",
    "    - Builds text embeddings for your class names via prompt templates\n",
    "    - Encodes images in batches\n",
    "    - Returns per-class probabilities (softmax over similarities)\n",
    "    Works with open_clip_torch if available; otherwise falls back to OpenAI's clip.\n",
    "    \"\"\"\n",
    "    def __init__(self, class_names: List[str], device: str = DEVICE, model_name: str = \"ViT-B/32\"):\n",
    "        self.device = device\n",
    "        self.class_names = class_names\n",
    "        self.model = None\n",
    "        self.preprocess = None\n",
    "        self.backend = None\n",
    "\n",
    "        # Prompt templates (average over a few to stabilize scores a bit)\n",
    "        self.templates = [\n",
    "            \"a photo of a {}\",\n",
    "            \"a close-up photo of a {}\",\n",
    "            \"a cropped photo of a {}\",\n",
    "            \"a photo of the {}\",\n",
    "            \"a good photo of a {}\"\n",
    "        ]\n",
    "\n",
    "        # Try open_clip first\n",
    "        try:\n",
    "            import open_clip\n",
    "            self.backend = \"open_clip\"\n",
    "            self.model, _, self.preprocess = open_clip.create_model_and_transforms(\n",
    "                model_name, pretrained=\"laion2b_s34b_b79k\", device=device\n",
    "            )\n",
    "            self.model.eval()\n",
    "            self._text_embed_openclip(class_names)\n",
    "        except Exception:\n",
    "            # Fallback to OpenAI's clip\n",
    "            import clip as clip_pkg\n",
    "            self.backend = \"clip\"\n",
    "            self.model, self.preprocess = clip_pkg.load(model_name, device=device, jit=False)\n",
    "            self.model.eval()\n",
    "            self._text_embed_clippkg(class_names, clip_pkg)\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def _text_embed_openclip(self, class_names: List[str]):\n",
    "        import open_clip\n",
    "        # build prompt list\n",
    "        prompts = []\n",
    "        for c in class_names:\n",
    "            for t in self.templates:\n",
    "                prompts.append(t.format(c))\n",
    "        tok = open_clip.tokenize(prompts).to(self.device)\n",
    "        text_feat = self.model.encode_text(tok)  # (T, D)\n",
    "        # group/average by class\n",
    "        T = len(self.templates)\n",
    "        text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)\n",
    "        text_feat = text_feat.view(len(class_names), T, -1).mean(dim=1)\n",
    "        self.text_emb = text_feat / text_feat.norm(dim=-1, keepdim=True)  # (C, D)\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def _text_embed_clippkg(self, class_names: List[str], clip_pkg):\n",
    "        # build prompt list\n",
    "        prompts = []\n",
    "        for c in class_names:\n",
    "            for t in self.templates:\n",
    "                prompts.append(t.format(c))\n",
    "        tok = clip_pkg.tokenize(prompts).to(self.device)\n",
    "        text_feat = self.model.encode_text(tok)  # (T, D)\n",
    "        text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)\n",
    "        T = len(self.templates)\n",
    "        text_feat = text_feat.view(len(class_names), T, -1).mean(dim=1)\n",
    "        self.text_emb = text_feat / text_feat.norm(dim=-1, keepdim=True)  # (C, D)\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def image_embeds_batched(self, paths: List[str], batch_size: int = 256) -> Dict[str, np.ndarray]:\n",
    "        out: Dict[str, np.ndarray] = {}\n",
    "        todo = [p for p in paths if p not in CLIP_IMG_CACHE]\n",
    "        if todo:\n",
    "            bs = batch_size\n",
    "            for i in range(0, len(todo), bs):\n",
    "                chunk = todo[i:i+bs]\n",
    "                imgs = [self.preprocess(Image.open(p).convert(\"RGB\")) for p in chunk]\n",
    "                batch = torch.stack(imgs, dim=0).to(self.device)\n",
    "                # AMP on CUDA gives a nice speedup\n",
    "                use_amp = (self.device == \"cuda\")\n",
    "                ctx = torch.cuda.amp.autocast() if use_amp else torch.no_grad()\n",
    "                with ctx:\n",
    "                    img = self.model.encode_image(batch)  # (B, D)\n",
    "                    img = img / img.norm(dim=-1, keepdim=True)\n",
    "                for pth, vec in zip(chunk, img):\n",
    "                    v = vec.detach().cpu().numpy().astype(np.float32)\n",
    "                    CLIP_IMG_CACHE[pth] = v\n",
    "                    out[pth] = v\n",
    "        # also fill from cache\n",
    "        for p in paths:\n",
    "            if p not in out:\n",
    "                out[p] = CLIP_IMG_CACHE[p]\n",
    "        return out\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def probs_for_paths(self, paths: List[str], batch_size: int = 256) -> np.ndarray:\n",
    "        \"\"\"\n",
    "        Returns an array of shape (N, C) with class probabilities for each path.\n",
    "        \"\"\"\n",
    "        embeds = self.image_embeds_batched(paths, batch_size=batch_size)  # dict\n",
    "        # stack in same order\n",
    "        X = np.vstack([embeds[p] for p in paths])  # (N, D)\n",
    "        # to torch for matmul/softmax\n",
    "        X_t = torch.from_numpy(X).to(self.device)\n",
    "        T_t = self.text_emb.to(self.device)        # (C, D)\n",
    "        sims = (100.0 * (X_t @ T_t.T)).softmax(dim=-1)  # (N, C)\n",
    "        return sims.detach().cpu().numpy().astype(np.float32)\n",
    "\n",
    "@torch.no_grad()\n",
    "def populate_scores_with_cnn(\n",
    "    cnn_model: nn.Module,\n",
    "    units: List[ImageUnit],\n",
    "    class_to_idx: Dict[str,int],\n",
    "    zeroshot: CLIPZeroShot,\n",
    "    *,\n",
    "    conf_from_cnn: bool = True,         # if True, transform CNN prob -> |p-0.5|*2\n",
    "    clip_bs: int = 512,                 # CLIP batch size (L4 should handle 512 fine)\n",
    "    cnn_bs: int  = BATCH_SIZE,\n",
    ") -> None:\n",
    "    \"\"\"\n",
    "    Observed (A_obs): CLIP zero-shot probability for the true class.\n",
    "    Predicted (A_hat): CNN proxy; optionally mapped to confidence |p-0.5|*2 in [0,1].\n",
    "    \"\"\"\n",
    "    tfm = make_transforms(train=False)\n",
    "    cnn_model.eval()\n",
    "\n",
    "    for u in units:\n",
    "        y_idx = class_to_idx[u.class_name]\n",
    "        paths = u.cand_paths\n",
    "        if not paths:\n",
    "            u.A_obs = []; u.A_hat = []; continue\n",
    "\n",
    "        # ----- CLIP observed: P(class | image) via zero-shot -----\n",
    "        # We get probs for all classes then take the column for y_idx.\n",
    "        probs_clip = zeroshot.probs_for_paths(paths, batch_size=clip_bs)  # (N, C)\n",
    "        A_obs = probs_clip[:, y_idx].astype(np.float32).tolist()\n",
    "\n",
    "        # ----- CNN proxy predicted -----\n",
    "        A_hat_list: List[float] = []\n",
    "        for i in range(0, len(paths), cnn_bs):\n",
    "            chunk = paths[i:i+cnn_bs]\n",
    "            batch = torch.stack([tfm(Image.open(p).convert(\"RGB\")) for p in chunk], dim=0).to(DEVICE)\n",
    "            logits = cnn_model(batch)\n",
    "            p = logits.softmax(dim=1)[:, y_idx].detach().cpu().numpy().astype(np.float32)\n",
    "            if conf_from_cnn:\n",
    "                p = np.abs(p - 0.5) * 2.0  # map to [0,1] with 0.5 -> 0, {0,1} -> 1\n",
    "            A_hat_list.extend(p.tolist())\n",
    "\n",
    "        u.A_obs = A_obs                # observed = CLIP zero-shot probability\n",
    "        u.A_hat = A_hat_list           # predicted = CNN proxy\n",
    "\n",
    "\n",
    "\n",
    "# def populate_scores_with_cnn(\n",
    "#     cnn_model: nn.Module,\n",
    "#     units: List[ImageUnit],\n",
    "#     class_to_idx: Dict[str,int],\n",
    "#     clip_model: \"CLIPEmbedder\",\n",
    "# ) -> None:\n",
    "#     \"\"\"\n",
    "#     For each unit:\n",
    "#       - A_obs = CLIP cosine similarity between base and candidate images\n",
    "#       - A_hat = CNN probability for the true class (cheap proxy)\n",
    "#     \"\"\"\n",
    "#     for u in units:\n",
    "#         y_idx = class_to_idx[u.class_name]\n",
    "#         paths = u.cand_paths\n",
    "#         if not paths:\n",
    "#             u.A_obs = []\n",
    "#             u.A_hat = []\n",
    "#             continue\n",
    "\n",
    "#         # --- 1. CLIP observed scores ---\n",
    "#         base_path = base_path_for_unit(u, orig_root=ORIG_ROOT)\n",
    "#         base_clip = clip_model.embed_path(base_path)\n",
    "#         print(base_clip)\n",
    "#         obs_scores = []\n",
    "#         for p in paths:\n",
    "#             cand_clip = clip_model.embed_path(p)\n",
    "#             sim = float(np.dot(base_clip, cand_clip))  # cosine since CLIP outputs are normalized\n",
    "#             obs_scores.append(sim)\n",
    "\n",
    "#         # --- 2. CNN proxy scores ---\n",
    "#         probs = cnn_prob_for_paths(cnn_model, paths, y_idx)  # your helper for CNN probabilities\n",
    "#         #probs = np.abs(probs - 0.5) * 2.0  # map [0,1] -> [0,1] with 0.5->0, 1.0->1.0\n",
    "#         probs = np.abs(np.asarray(probs) - 0.5)* 2.0  # map [0,1] -> [0,1] with 0.5->0, 1.0->1.0\n",
    "\n",
    "#         u.A_hat = obs_scores         # observed = CLIP similarity\n",
    "#         u.A_obs = list(map(float, probs))  # predicted = CNN probability\n",
    "\n",
    "\n",
    "# =========================\n",
    "# CP selection (doc-level)\n",
    "# =========================\n",
    "\n",
    "def _splitconformal_quantile(vals: np.ndarray, alpha: float) -> float:\n",
    "    vals = np.asarray(vals, dtype=np.float64)\n",
    "    n = int(vals.size)\n",
    "    if n <= 0:\n",
    "        return float(\"-inf\")\n",
    "    q_idx = int(math.ceil((n + 1) * (1.0 - alpha)))\n",
    "    q_idx = min(max(1, q_idx), n)\n",
    "    return float(np.sort(vals)[q_idx - 1])\n",
    "\n",
    "def _S_doc_from_unit(u: ImageUnit, lambda_obs: float, rho: int) -> float:\n",
    "    \"\"\"(rho+1)-th largest predicted score among BAD (obs<lambda). If #BAD<=rho => -inf.\"\"\"\n",
    "    if u.A_obs is None or u.A_hat is None:\n",
    "        raise RuntimeError(\"A_obs/A_hat must be populated before CP selection.\")\n",
    "    obs = np.asarray(u.A_obs, dtype=float)\n",
    "    hat = np.asarray(u.A_hat, dtype=float)\n",
    "    bad_mask = (obs < float(lambda_obs))\n",
    "    m_bad = int(bad_mask.sum())\n",
    "    if m_bad <= int(rho):\n",
    "        return float(\"-inf\")\n",
    "    bad_pred = np.sort(hat[bad_mask])[::-1]\n",
    "    return float(bad_pred[int(rho)])\n",
    "\n",
    "def select_by_cp_marginal_doclevel(\n",
    "    calib_units: List[ImageUnit],\n",
    "    target_units: List[ImageUnit],\n",
    "    *,\n",
    "    lambda_obs: float,\n",
    "    alpha_cp: float,\n",
    "    rho: int,\n",
    ") -> Dict[int, List[str]]:\n",
    "    S_list = []\n",
    "    for u in calib_units:\n",
    "        s = _S_doc_from_unit(u, lambda_obs, rho)\n",
    "        if np.isfinite(s): S_list.append(s)\n",
    "    s_global = _splitconformal_quantile(np.asarray(S_list, dtype=float), alpha=alpha_cp)\n",
    "\n",
    "    selected: Dict[int, List[str]] = {}\n",
    "    for u in target_units:\n",
    "        base = base_path_for_unit(u, ORIG_ROOT)\n",
    "        scores = u.A_hat or [0.0]*len(u.cand_paths)\n",
    "        kept = [p for (p,s) in zip(u.cand_paths, scores) if float(s) >= s_global]\n",
    "        if base not in kept: kept = [base] + kept\n",
    "        else: kept = [base] + [p for p in kept if p != base]\n",
    "        seen = set(); kept = [p for p in kept if not (p in seen or seen.add(p))]\n",
    "        selected[u.doc_idx] = kept\n",
    "\n",
    "    print(f\"[marg-doc] |calib docs|={len(calib_units)} |S_doc|={len(S_list)} s_global={s_global:.4f}\")\n",
    "    return selected\n",
    "\n",
    "# ---- LAKCP conditional doc-level CP (condition on PCA(raw base image)) ----\n",
    "\n",
    "def _base_image_raw_flat(path: str, size: int = RAW_SIDE) -> np.ndarray:\n",
    "    \"\"\"Downsampled raw RGB -> flattened [0,1].\"\"\"\n",
    "    im = Image.open(path).convert(\"RGB\").resize((size, size))\n",
    "    arr = np.asarray(im, dtype=np.float32) / 255.0\n",
    "    return arr.reshape(-1)\n",
    "\n",
    "def select_by_cp_conditional_doclevel(\n",
    "    calib_units: List[ImageUnit],\n",
    "    target_units: List[ImageUnit],\n",
    "    *,\n",
    "    lambda_obs: float,\n",
    "    alpha_cp: float,\n",
    "    rho: int,\n",
    "    pca_dim: int = PCA_DIM,\n",
    "    raw_side: int = RAW_SIDE,\n",
    ") -> Dict[int, List[str]]:\n",
    "    # 1) calibration S_doc and raw base features\n",
    "    Xc_raw, S_cal = [], []\n",
    "    for u in calib_units:\n",
    "        s = _S_doc_from_unit(u, lambda_obs=lambda_obs, rho=rho)\n",
    "        if not np.isfinite(s):\n",
    "            continue\n",
    "        base = base_path_for_unit(u, ORIG_ROOT)\n",
    "        Xc_raw.append(_base_image_raw_flat(base, size=raw_side))\n",
    "        S_cal.append(float(s))\n",
    "    if len(S_cal) == 0:\n",
    "        print(\"[cond-doc-lakcp] No finite S_doc on calib; falling back to marginal.\")\n",
    "        return select_by_cp_marginal_doclevel(calib_units, target_units, lambda_obs=lambda_obs, alpha_cp=alpha_cp, rho=rho)\n",
    "\n",
    "    Xc_raw = np.vstack(Xc_raw).astype(np.float32)\n",
    "    S_cal  = np.asarray(S_cal, dtype=np.float32)\n",
    "\n",
    "    # 2) target raw base features\n",
    "    Xt_raw = np.vstack([_base_image_raw_flat(base_path_for_unit(u, ORIG_ROOT), size=raw_side)\n",
    "                        for u in target_units]).astype(np.float32)\n",
    "\n",
    "    # 3) PCA on combined\n",
    "    X_all = np.vstack([Xc_raw, Xt_raw])\n",
    "    pca_k = int(min(pca_dim, X_all.shape[1], max(1, X_all.shape[0]-1)))\n",
    "    \n",
    "    X_all_std = StandardScaler(with_mean=True, with_std=True).fit_transform(X_all)\n",
    "    Z_all = PCA(n_components=pca_k, random_state=SEED).fit_transform(X_all_std)\n",
    "    Zcal  = Z_all[:Xc_raw.shape[0], :]\n",
    "    Ztest = Z_all[Xc_raw.shape[0]:, :]\n",
    "\n",
    "    \n",
    "    d = sp.distance.pdist(Zcal, 'euclidean')\n",
    "    sigma = np.median(d) if d.size else 1.0\n",
    "    # if sigma ~ 0 (identical points), set a small floor\n",
    "    sigma = max(sigma, 1e-6)\n",
    "    gamma_grid = (1.0 / (2.0 * sigma**2)) * np.logspace(-2, 2, 25)\n",
    "\n",
    "\n",
    "    # 4) LAKCP for per-target cutoffs\n",
    "    lakcp = LAKCP(\n",
    "        alpha=alpha_cp, max_steps=200, eps=1e-3, tol=1e-6, thres=10.0,\n",
    "        ridge=1e-8, start_side=\"left\", gamma=None, gamma_grid=np.logspace(-1,1,30),\n",
    "        randomize=True, verbose=False\n",
    "    )\n",
    "    Phi_cal  = np.ones((Zcal.shape[0], 1), dtype=np.float64)\n",
    "    Phi_test = np.ones((Ztest.shape[0], 1), dtype=np.float64)\n",
    "\n",
    "    cutoffs, _ = lakcp.fit(Zcal, Phi_cal, S_cal.ravel(), Ztest, Phi_test)\n",
    "    cutoffs = np.asarray(cutoffs).reshape(-1)\n",
    "    if cutoffs.size != len(target_units):\n",
    "        raise RuntimeError(f\"LAKCP returned {cutoffs.size} cutoffs for {len(target_units)} targets\")\n",
    "\n",
    "    # 5) apply thresholds\n",
    "    selected: Dict[int, List[str]] = {}\n",
    "    for u, thr in zip(target_units, cutoffs):\n",
    "        base = base_path_for_unit(u, ORIG_ROOT)\n",
    "        scores = u.A_hat or [0.0]*len(u.cand_paths)\n",
    "        kept = [p for (p,s) in zip(u.cand_paths, scores) if float(s) >= float(thr)]\n",
    "        if base not in kept: kept = [base] + kept\n",
    "        else: kept = [base] + [p for p in kept if p != base]\n",
    "        seen = set(); kept = [p for p in kept if not (p in seen or seen.add(p))]\n",
    "        selected[u.doc_idx] = kept\n",
    "\n",
    "    print(f\"[cond-doc-lakcp] |calib used|={Zcal.shape[0]} pca_dim={Zcal.shape[1]} -> thresholds for {Ztest.shape[0]} targets\")\n",
    "    return selected\n",
    "\n",
    "# ----- Simple filter baseline -----\n",
    "\n",
    "def select_aug_filtered(units: List[ImageUnit], lambda_obs: float) -> Dict[int, List[str]]:\n",
    "    selected: Dict[int, List[str]] = {}\n",
    "    for u in units:\n",
    "        base = base_path_for_unit(u, ORIG_ROOT)\n",
    "        scores = u.A_hat or [0.0]*len(u.cand_paths)\n",
    "        kept = [p for (p,s) in zip(u.cand_paths, scores) if float(s) >= float(lambda_obs)]\n",
    "        if base not in kept: kept = [base] + kept\n",
    "        else: kept = [base] + [p for p in kept if p != base]\n",
    "        seen = set(); kept = [p for p in kept if not (p in seen or seen.add(p))]\n",
    "        selected[u.doc_idx] = kept\n",
    "    return selected\n",
    "\n",
    "def acceptance_stats(units: List[ImageUnit], selected: Dict[int, List[str]]) -> Tuple[int,int,float]:\n",
    "    total_aug, kept_aug = 0, 0\n",
    "    for u in units:\n",
    "        base = base_path_for_unit(u, ORIG_ROOT)\n",
    "        base_in_cands = base in set(u.cand_paths)\n",
    "        total_doc = len(u.cand_paths) - (1 if base_in_cands else 0)\n",
    "        total_aug += max(0, total_doc)\n",
    "        paths = selected.get(u.doc_idx, [base])\n",
    "        kept_doc = sum(1 for p in paths if p != base)\n",
    "        kept_aug += kept_doc\n",
    "        if kept_doc > total_doc:\n",
    "            print(f\"[warn] doc_idx={u.doc_idx} kept {kept_doc} > total {total_doc} \"\n",
    "                  f\"(base_in_cands={base_in_cands}, cand_count={len(u.cand_paths)})\")\n",
    "    rate = kept_aug / max(1,total_aug)\n",
    "    return kept_aug, total_aug, rate\n",
    "\n",
    "# =========================\n",
    "# Main experiment\n",
    "# =========================\n",
    "\n",
    "\n",
    "from collections import Counter\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "def merge_eval_selection_with_all_bases(\n",
    "    selection_for_eval: Dict[int, List[str]],\n",
    "    eval_units: List[ImageUnit],\n",
    "    calib_units: List[ImageUnit],\n",
    ") -> Dict[int, List[str]]:\n",
    "    \"\"\"\n",
    "    Start from selections on EVAL docs (which may or may not contain the base),\n",
    "    then ensure:\n",
    "      1) every EVAL doc has its base (first position);\n",
    "      2) every CALIB doc has its base (first position);\n",
    "    Never add any CALIB augmentations.\n",
    "    \"\"\"\n",
    "    merged: Dict[int, List[str]] = {k: list(v) for k, v in selection_for_eval.items()}\n",
    "\n",
    "    def add_base_first(doc_idx: int, base_path: str):\n",
    "        cur = merged.get(doc_idx, [])\n",
    "        if base_path in cur:\n",
    "            cur = [base_path] + [p for p in cur if p != base_path]\n",
    "        else:\n",
    "            cur = [base_path] + cur\n",
    "        # de‑dup while preserving order\n",
    "        seen = set()\n",
    "        cur = [p for p in cur if not (p in seen or seen.add(p))]\n",
    "        merged[doc_idx] = cur\n",
    "\n",
    "    # 1) ensure ALL EVAL bases are present (even if doc had no augs or was missing in selection)\n",
    "    for u in eval_units:\n",
    "        base = base_path_for_unit(u, ORIG_ROOT)\n",
    "        add_base_first(u.doc_idx, base)\n",
    "\n",
    "    # 2) ensure ALL CALIB bases are present (no calib augs)\n",
    "    for u in calib_units:\n",
    "        base = base_path_for_unit(u, ORIG_ROOT)\n",
    "        add_base_first(u.doc_idx, base)\n",
    "\n",
    "    return merged\n",
    "\n",
    "\n",
    "\n",
    "def run_one_seed(seed: int, lambda_grid: List[float], save_each_csv: Optional[str] = None) -> pd.DataFrame:\n",
    "    set_seed(seed)\n",
    "\n",
    "    # 0) scan and split\n",
    "    groups = scan_exp_generated(EXP_ROOT)\n",
    "    if not groups:\n",
    "        raise ValueError(f\"No generated images found under {EXP_ROOT}\")\n",
    "\n",
    "    units_all = [ImageUnit(i, g.cls, g.base_id, [str(p) for p in g.aug_paths]) for i, g in enumerate(groups)]\n",
    "    train_g, calib_g, eval_g = split_by_group(groups, TRAIN_FRAC, CALIB_FRAC, seed)\n",
    "\n",
    "    # map (cls, base_id) -> unit\n",
    "    key2unit = {(u.class_name, u.base_id): u for u in units_all}\n",
    "    def pick(gs): return [key2unit[(g.cls, g.base_id)] for g in gs if (g.cls, g.base_id) in key2unit]\n",
    "    train_units = pick(train_g)\n",
    "    calib_units = pick(calib_g)\n",
    "    eval_units  = pick(eval_g)\n",
    "\n",
    "    class_to_idx = make_label_map(units_all)\n",
    "    num_classes  = len(class_to_idx)\n",
    "    print(f\"[info seed={seed}] #classes={num_classes}  #train_docs={len(train_units)}  \"\n",
    "          f\"#calib_docs={len(calib_units)}  #eval_docs={len(eval_units)}\")\n",
    "\n",
    "    # Sanity: counts per class\n",
    "    print(\"[train per class]\", dict(Counter([u.class_name for u in train_units])))\n",
    "    print(\"[calib per class]\", dict(Counter([u.class_name for u in calib_units])))\n",
    "    print(\"[eval  per class]\", dict(Counter([u.class_name for u in eval_units])))\n",
    "\n",
    "    # 1) Train a CNN from scratch on UN-AUGMENTED bases from calib + eval\n",
    "    #base_train_paths  = [base_path_for_unit(u, ORIG_ROOT) for u in (calib_units + eval_units)]\n",
    "    #base_train_labels = [class_to_idx[u.class_name]       for u in (calib_units + eval_units)]\n",
    "    base_train_paths  = [base_path_for_unit(u, ORIG_ROOT) for u in (train_units)]\n",
    "    base_train_labels = [class_to_idx[u.class_name]       for u in (train_units)]\n",
    "\n",
    "    # Validation set\n",
    "    val_paths, val_labels = scan_labeled_folder(VAL_ROOT, class_to_idx)\n",
    "    if len(val_paths) == 0:\n",
    "        n = len(base_train_paths)\n",
    "        idx = np.arange(n); np.random.shuffle(idx)\n",
    "        m = max(1, int(0.2 * n))\n",
    "        val_idx, tr_idx = idx[:m], idx[m:]\n",
    "        val_paths  = [base_train_paths[i]  for i in val_idx]\n",
    "        val_labels = [base_train_labels[i] for i in val_idx]\n",
    "        base_train_paths  = [base_train_paths[i]  for i in tr_idx]\n",
    "        base_train_labels = [base_train_labels[i] for i in tr_idx]\n",
    "        print(\"[warn] VAL_ROOT empty; using 20% of base_train as validation.\")\n",
    "\n",
    "    base_model, base_val_acc = train_cnn_from_scratch(\n",
    "        base_train_paths, base_train_labels, val_paths, val_labels, num_classes\n",
    "    )\n",
    "    print(f\"[base-cnn seed={seed}] validation accuracy (base-only): {base_val_acc:.3f}\")\n",
    "\n",
    "    # 2) Score candidates (observed = CLIP zero-shot class prob; predicted = CNN proxy)\n",
    "    idx_to_class = {v: k for k, v in class_to_idx.items()}\n",
    "    class_names  = [idx_to_class[i] for i in range(len(idx_to_class))]\n",
    "    zeroshot     = CLIPZeroShot(class_names=class_names, device=DEVICE, model_name=\"ViT-B/32\")\n",
    "\n",
    "    populate_scores_with_cnn(base_model, calib_units, class_to_idx, zeroshot)\n",
    "    populate_scores_with_cnn(base_model, eval_units,  class_to_idx, zeroshot)\n",
    "\n",
    "    # fixed validation/test sets for final CNNs\n",
    "    X_val,  y_val  = scan_labeled_folder(VAL_ROOT,  class_to_idx)\n",
    "    X_test, y_test = scan_labeled_folder(TEST_ROOT, class_to_idx)\n",
    "    print(f\"[clf-data] Val={len(y_val)}  Test={len(y_test)}\")\n",
    "\n",
    "    # list of result rows we’ll accumulate\n",
    "    rows = []\n",
    "\n",
    "    # build init state once for this seed\n",
    "    \n",
    "    init_state = get_initial_state_for_seed(num_classes, seed)\n",
    "\n",
    "\n",
    "\n",
    "    # sweep over lambda values\n",
    "    for lambda_obs in lambda_grid:\n",
    "        # selections computed on EVAL (targets), calibrated on CALIB\n",
    "        sel_unaug_eval = select_aug_filtered(eval_units, lambda_obs=1.2)  # reject all augs\n",
    "        sel_unfilt_eval = select_aug_filtered(eval_units, lambda_obs=0.0)  # accept all augs\n",
    "        sel_filt_eval   = select_aug_filtered(eval_units, lambda_obs=lambda_obs)\n",
    "        sel_marg_eval   = select_by_cp_marginal_doclevel(\n",
    "            calib_units, eval_units, lambda_obs=lambda_obs, alpha_cp=ALPHA_CP, rho=RHO\n",
    "        )\n",
    "        sel_lakcp_eval  = select_by_cp_conditional_doclevel(\n",
    "            calib_units, eval_units, lambda_obs=lambda_obs, alpha_cp=ALPHA_CP, rho=RHO,\n",
    "            pca_dim=PCA_DIM, raw_side=RAW_SIDE\n",
    "        )\n",
    "\n",
    "        # stats (over eval docs because that’s where we selected augs)\n",
    "        kept_u, tot_u, rate_u = acceptance_stats(eval_units, sel_unfilt_eval)\n",
    "        kept_f, tot_f, rate_f = acceptance_stats(eval_units, sel_filt_eval)\n",
    "        kept_m, tot_m, rate_m = acceptance_stats(eval_units, sel_marg_eval)\n",
    "        kept_c, tot_c, rate_c = acceptance_stats(eval_units, sel_lakcp_eval)\n",
    "        print(f\"[select seed={seed} λ={lambda_obs:.2f}] \"\n",
    "              f\"Unfiltered {kept_u}/{tot_u} ({rate_u:.1%})  \"\n",
    "              f\"Filtered {kept_f}/{tot_f} ({rate_f:.1%})  \"\n",
    "              f\"MargCP {kept_m}/{tot_m} ({rate_m:.1%})  \"\n",
    "              f\"LAKCP {kept_c}/{tot_c} ({rate_c:.1%})\")\n",
    "\n",
    "        # merge: keep eval selections, add ALL calib bases\n",
    "        train_sets = {\n",
    "            \"Aug-Unaugmented\": merge_eval_selection_with_all_bases(sel_unaug_eval, eval_units, calib_units),\n",
    "            \"Aug-Unfiltered\": merge_eval_selection_with_all_bases(sel_unfilt_eval, eval_units, calib_units),\n",
    "            \"Aug-Filtered\":   merge_eval_selection_with_all_bases(sel_filt_eval,   eval_units, calib_units),\n",
    "            \"CP-Marginal\":    merge_eval_selection_with_all_bases(sel_marg_eval,   eval_units, calib_units),\n",
    "            \"CP-LAKCP\":       merge_eval_selection_with_all_bases(sel_lakcp_eval,  eval_units, calib_units),\n",
    "        }\n",
    "\n",
    "        # After building train_sets\n",
    "        all_eval_ids  = {u.doc_idx for u in eval_units}\n",
    "        all_calib_ids = {u.doc_idx for u in calib_units}\n",
    "\n",
    "        for name, sel in train_sets.items():\n",
    "            have_eval_bases = all(base_path_for_unit(u, ORIG_ROOT) in sel[u.doc_idx]\n",
    "                                for u in eval_units if u.doc_idx in sel)\n",
    "            missing_eval    = sorted(list(all_eval_ids - set(sel.keys())))\n",
    "            have_calib_bases= all(base_path_for_unit(u, ORIG_ROOT) in sel.get(u.doc_idx, [])\n",
    "                                for u in calib_units)\n",
    "            print(f\"[{name}] eval_bases_ok={have_eval_bases}  \"\n",
    "                f\"missing_eval_docs={len(missing_eval)}  \"\n",
    "                f\"calib_bases_ok={have_calib_bases}\")\n",
    "    \n",
    "\n",
    "        # final train pool is EVAL + CALIB units (no TRAIN units)\n",
    "        units_train_all = eval_units + calib_units\n",
    "\n",
    "        # train/eval per regime\n",
    "        for name, selection_all in train_sets.items():\n",
    "            tr_paths, tr_labels = build_train_paths_from_selected(units_train_all, selection_all, class_to_idx)\n",
    "            if len(tr_paths) == 0:\n",
    "                rows.append({\n",
    "                    \"seed\": seed, \"lambda_obs\": float(lambda_obs), \"Regime\": name,\n",
    "                    \"TrainN\": 0, \"ValAcc\": float(\"nan\"), \"TestAcc\": float(\"nan\")\n",
    "                })\n",
    "                print(f\"[{name}] Train set empty; skipping.\")\n",
    "                continue\n",
    "\n",
    "            print(f\"[seed={seed} λ={lambda_obs:.2f} {name}] TrainN={len(tr_labels)} (docs={len(units_train_all)})\")\n",
    "            model, best_val = train_cnn_from_scratch(tr_paths, tr_labels, X_val, y_val, num_classes,\n",
    "                                                    init_state=init_state, seed=seed)\n",
    "            test_acc = eval_cnn(model, X_test, y_test)\n",
    "\n",
    "            rows.append({\n",
    "                \"seed\": seed, \"lambda_obs\": float(lambda_obs), \"Regime\": name,\n",
    "                \"TrainN\": int(len(tr_labels)), \"ValAcc\": float(best_val), \"TestAcc\": float(test_acc),\n",
    "                \"kept_unfiltered\": kept_u, \"tot_unfiltered\": tot_u,\n",
    "                \"kept_filtered\": kept_f,   \"tot_filtered\": tot_f,\n",
    "                \"kept_marg\": kept_m,       \"tot_marg\": tot_m,\n",
    "                \"kept_lakcp\": kept_c,      \"tot_lakcp\": tot_c,\n",
    "            })\n",
    "\n",
    "        # optional: save partial results after each lambda (useful for long runs)\n",
    "        if save_each_csv:\n",
    "            pd.DataFrame(rows).to_csv(save_each_csv, index=False)\n",
    "\n",
    "    df = pd.DataFrame(rows,\n",
    "                      columns=[\"seed\",\"lambda_obs\",\"Regime\",\"TrainN\",\"ValAcc\",\"TestAcc\",\n",
    "                               \"kept_unfiltered\",\"tot_unfiltered\",\n",
    "                               \"kept_filtered\",\"tot_filtered\",\n",
    "                               \"kept_marg\",\"tot_marg\",\n",
    "                               \"kept_lakcp\",\"tot_lakcp\"])\n",
    "    return df\n",
    "\n",
    "\n",
    "def run_grid(seeds: List[int], lambda_grid: List[float], out_csv: Optional[str] = None) -> pd.DataFrame:\n",
    "    all_rows = []\n",
    "    for seed in seeds:\n",
    "        df_seed = run_one_seed(seed, lambda_grid, save_each_csv=str(seed) + \"results_partial.csv\")\n",
    "        all_rows.append(df_seed)\n",
    "        # optional incremental save\n",
    "        if out_csv:\n",
    "            pd.concat(all_rows, ignore_index=True).to_csv(out_csv, index=False)\n",
    "    results = pd.concat(all_rows, ignore_index=True) if all_rows else pd.DataFrame()\n",
    "    if out_csv:\n",
    "        results.to_csv(out_csv, index=False)\n",
    "    print(\"\\n=== Aggregate over seeds & λ ===\")\n",
    "    if len(results):\n",
    "        # quick pivot summary\n",
    "        print(results.groupby([\"Regime\",\"lambda_obs\"])[[\"ValAcc\",\"TestAcc\",\"TrainN\"]].mean().round(3))\n",
    "    return results\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def main():\n",
    "    seeds = [12, 22, 32, 42, 52]                  # 5 seeds\n",
    "    lambda_grid = [0.50, 0.75, 0.80, 0.90]  # 5 λ values\n",
    "    results_df = run_grid(seeds, lambda_grid, out_csv=\"new_debug_2_results_seeds_lambdas.csv\")\n",
    "    results_df.to_csv(\"final_results.csv\", index=False)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "main()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "new-llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
