{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "oLiu3ohakQvG",
        "outputId": "0f9a0498-1bf7-451e-e2fd-9dae9a7919d1"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 148/148 [00:01<00:00, 82.81it/s]\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Epoch 001  train_mse=2.4833  val_mse=1.9020  bad=0\n",
            "[TRAIN] PRED  y — no-motif max=2.192 | motif min/median/max=-1.209/0.149/1.571 (n_neg=37367, n_pos=315)\n",
            "[TRAIN] TRUTH y — no-motif max=10.060 | motif min/median/max=-2.491/3.529/13.860\n",
            "[VAL  ] PRED  y — no-motif max=2.279 | motif min/median/max=-1.102/0.063/1.541 (n_neg=3734, n_pos=34)\n",
            "[VAL  ] TRUTH y — no-motif max=11.449 | motif min/median/max=-1.455/1.818/12.009\n",
            "[TEST ] PRED  y — no-motif max=1.464 | motif min/median/max=-0.031/0.188/0.406 (n_neg=417, n_pos=2)\n",
            "[TEST ] TRUTH y — no-motif max=5.665 | motif min/median/max=5.309/5.469/5.628\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 148/148 [00:01<00:00, 114.88it/s]\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Epoch 002  train_mse=1.7994  val_mse=1.7755  bad=0\n",
            "[TRAIN] PRED  y — no-motif max=1.989 | motif min/median/max=-1.132/0.282/1.585 (n_neg=37367, n_pos=315)\n",
            "[TRAIN] TRUTH y — no-motif max=10.060 | motif min/median/max=-2.491/3.529/13.860\n",
            "[VAL  ] PRED  y — no-motif max=2.050 | motif min/median/max=-1.032/0.087/1.610 (n_neg=3734, n_pos=34)\n",
            "[VAL  ] TRUTH y — no-motif max=11.449 | motif min/median/max=-1.455/1.818/12.009\n",
            "[TEST ] PRED  y — no-motif max=1.410 | motif min/median/max=0.315/0.349/0.383 (n_neg=417, n_pos=2)\n",
            "[TEST ] TRUTH y — no-motif max=5.665 | motif min/median/max=5.309/5.469/5.628\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 148/148 [00:01<00:00, 110.78it/s]\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Epoch 003  train_mse=1.7039  val_mse=1.7097  bad=0\n",
            "[TRAIN] PRED  y — no-motif max=2.228 | motif min/median/max=-0.962/0.517/1.718 (n_neg=37367, n_pos=315)\n",
            "[TRAIN] TRUTH y — no-motif max=10.060 | motif min/median/max=-2.491/3.529/13.860\n",
            "[VAL  ] PRED  y — no-motif max=2.073 | motif min/median/max=-0.641/0.252/1.886 (n_neg=3734, n_pos=34)\n",
            "[VAL  ] TRUTH y — no-motif max=11.449 | motif min/median/max=-1.455/1.818/12.009\n",
            "[TEST ] PRED  y — no-motif max=1.585 | motif min/median/max=0.596/0.659/0.721 (n_neg=417, n_pos=2)\n",
            "[TEST ] TRUTH y — no-motif max=5.665 | motif min/median/max=5.309/5.469/5.628\n",
            "Done. Wrote: run_max\n"
          ]
        }
      ],
      "source": [
        "# ---- All-in-one notebook cell with TRAIN/VAL/TEST motif monitoring (pred + truth) ----\n",
        "import os, re, json, math, argparse, random, numpy as np, pandas as pd\n",
        "from pathlib import Path\n",
        "\n",
        "import os\n",
        "os.environ[\"TORCHDYNAMO_DISABLE\"] = \"1\"  # set before importing torch\n",
        "import torch\n",
        "\n",
        "\n",
        "# Optional: tqdm (safe fallback)\n",
        "try:\n",
        "    from tqdm import tqdm\n",
        "except Exception:\n",
        "    def tqdm(x, **kwargs): return x\n",
        "\n",
        "import torch.nn as nn\n",
        "from torch.utils.data import Dataset, DataLoader, random_split\n",
        "\n",
        "\n",
        "# =========================\n",
        "# Helpers\n",
        "# =========================\n",
        "\n",
        "def set_seed(seed: int = 1337):\n",
        "    \"\"\"Seed Python, NumPy and PyTorch (CPU and every CUDA device), then force deterministic cuDNN.\"\"\"\n",
        "    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):\n",
        "        seeder(seed)\n",
        "    cudnn = torch.backends.cudnn\n",
        "    cudnn.deterministic, cudnn.benchmark = True, False\n",
        "\n",
        "def r2_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:\n",
        "    \"\"\"Coefficient of determination, 1 - SS_res / SS_tot, computed on flattened inputs.\n",
        "\n",
        "    SS_tot is padded by 1e-9 so a constant `y_true` does not divide by zero.\n",
        "    \"\"\"\n",
        "    truth = np.ravel(np.asarray(y_true))\n",
        "    guess = np.ravel(np.asarray(y_pred))\n",
        "    residual = float(((truth - guess) ** 2).sum())\n",
        "    total = float(((truth - truth.mean()) ** 2).sum()) + 1e-9\n",
        "    return 1.0 - residual / total\n",
        "\n",
        "def spearmanr_np(a: np.ndarray, b: np.ndarray) -> float:\n",
        "    \"\"\"Spearman rank correlation of two equal-length inputs (both flattened), NumPy only.\n",
        "\n",
        "    Tied values share their average 0-based rank. Each rank vector is z-scored with a\n",
        "    1e-9 guard on the std, so a constant input yields 0.0 instead of dividing by zero.\n",
        "    \"\"\"\n",
        "    a = np.asarray(a).flatten()\n",
        "    b = np.asarray(b).flatten()\n",
        "\n",
        "    def rankdata(x):\n",
        "        # Each distinct value occupies sorted positions [start, end - 1]; every copy of\n",
        "        # it receives the midpoint of that run as its (average) rank.\n",
        "        _, inv, counts = np.unique(x, return_inverse=True, return_counts=True)\n",
        "        end = np.cumsum(counts)\n",
        "        start = end - counts\n",
        "        return ((start + end - 1) / 2.0)[inv.reshape(-1)]\n",
        "\n",
        "    ra = rankdata(a); rb = rankdata(b)\n",
        "    ra = (ra - ra.mean()) / (ra.std() + 1e-9)\n",
        "    rb = (rb - rb.mean()) / (rb.std() + 1e-9)\n",
        "    return float(np.mean(ra * rb))\n",
        "\n",
        "# --- Robust PBM loader ---\n",
        "_float_pat = re.compile(r\"[-+]?\\d+(?:\\.\\d+)?(?:[eE][-+]?\\d+)?\")\n",
        "_seq_pat   = re.compile(r\"[ACGTacgt]{20,}\")\n",
        "\n",
        "def _parse_line(line: str):\n",
        "    \"\"\"Extract `(y, seq)` from one PBM text line, or return None if it is unusable.\n",
        "\n",
        "    `y` is the first numeric token on the line; `seq` is the longest run of at least\n",
        "    20 A/C/G/T characters, upper-cased. Blank lines and lines starting with `#` or\n",
        "    `//` are skipped.\n",
        "    \"\"\"\n",
        "    stripped = line.strip()\n",
        "    if not stripped or stripped.startswith((\"#\", \"//\")):\n",
        "        return None\n",
        "    value_match = _float_pat.search(line)\n",
        "    candidates = _seq_pat.findall(line)\n",
        "    if value_match is None or not candidates:\n",
        "        return None\n",
        "    try:\n",
        "        value = float(value_match.group(0))\n",
        "    except Exception:\n",
        "        return None\n",
        "    longest = max(candidates, key=len).upper().strip()\n",
        "    return value, re.sub(r\"[^ACGT]\", \"\", longest)\n",
        "\n",
        "def load_pbm(path: str) -> pd.DataFrame:\n",
        "    \"\"\"Parse a PBM text file into a DataFrame with columns `y_raw`, `seq` and `len`.\n",
        "\n",
        "    First tries pandas whitespace-delimited parsing and re-joins each row into a line for\n",
        "    `_parse_line`; if that raises, falls back to reading the file line by line.\n",
        "    Lines `_parse_line` rejects are dropped.\n",
        "    \"\"\"\n",
        "    rows = []\n",
        "    try:\n",
        "        table = pd.read_csv(path, sep=r\"\\s+\", header=None, engine=\"python\")\n",
        "        for _, record in table.iterrows():\n",
        "            parsed = _parse_line(\" \".join(str(v) for v in record if pd.notna(v)))\n",
        "            if parsed:\n",
        "                rows.append(parsed)\n",
        "    except Exception:\n",
        "        # NOTE(review): rows collected before a mid-loop failure are kept and the\n",
        "        # fallback then appends every line of the file again.\n",
        "        with open(path, \"r\") as handle:\n",
        "            rows.extend(parsed for parsed in map(_parse_line, handle) if parsed)\n",
        "    frame = pd.DataFrame(rows, columns=[\"y_raw\", \"seq\"])\n",
        "    frame[\"seq\"] = frame[\"seq\"].str.replace(r\"[^ACGT]\", \"\", regex=True).str.upper()\n",
        "    frame[\"len\"] = frame[\"seq\"].str.len()\n",
        "    return frame\n",
        "\n",
        "def normalize_y(y, method: str = \"robust\"):\n",
        "    \"\"\"Log-transform targets and standardize them.\n",
        "\n",
        "    Returns `(ylog, yn)` with `ylog = log1p(max(y, 0) + 1e-9)` (negatives clipped to 0):\n",
        "      - \"logz\":   yn = (ylog - mean) / (std + 1e-9)\n",
        "      - \"robust\": yn = (ylog - median) / (1.4826 * MAD), using std + 1e-9 when MAD is 0\n",
        "      - \"none\":   no transform at all; returns `(y, y)` as a float array.\n",
        "    Raises ValueError for any other `method`.\n",
        "    \"\"\"\n",
        "    eps = 1e-9\n",
        "    values = np.asarray(y, dtype=float)\n",
        "    if method == \"none\":\n",
        "        return values, values\n",
        "    log_values = np.log1p(np.maximum(values, 0.0) + eps)\n",
        "    if method == \"logz\":\n",
        "        center = float(np.mean(log_values))\n",
        "        spread = float(np.std(log_values) + eps)\n",
        "    elif method == \"robust\":\n",
        "        center = float(np.median(log_values))\n",
        "        mad = float(np.median(np.abs(log_values - center)))\n",
        "        spread = 1.4826 * mad if mad > 0 else float(np.std(log_values) + eps)\n",
        "        if not spread > 0:\n",
        "            spread = 1.0  # defensive guard, e.g. a NaN spread from NaN inputs\n",
        "    else:\n",
        "        raise ValueError(f\"Unknown normalize method: {method}\")\n",
        "    return log_values, (log_values - center) / spread\n",
        "\n",
        "# =========================\n",
        "# Data encoding & model\n",
        "# =========================\n",
        "DNA = \"ACGT\"\n",
        "BASE_TO_IDX = {b: i for i, b in enumerate(DNA)}\n",
        "\n",
        "def bipolar_encode(seq: str, alphabet: str = DNA) -> np.ndarray:\n",
        "    \"\"\"Encode `seq` as a float32 array of shape (len(alphabet), len(seq)) with values in {-1, +1}.\n",
        "\n",
        "    Column i is +1 in the row of `seq[i]`'s position in `alphabet` and -1 elsewhere;\n",
        "    characters not in `alphabet` (e.g. 'N', or lowercase with the default alphabet)\n",
        "    leave their column all -1.\n",
        "    \"\"\"\n",
        "    # Reuse the precomputed table for the default alphabet; for any other alphabet derive\n",
        "    # the mapping from it, so row indices always agree with the array height.\n",
        "    lookup = BASE_TO_IDX if alphabet == DNA else {b: i for i, b in enumerate(alphabet)}\n",
        "    x = -np.ones((len(alphabet), len(seq)), dtype=np.float32)\n",
        "    for i, ch in enumerate(seq):\n",
        "        j = lookup.get(ch)\n",
        "        if j is not None:\n",
        "            x[j, i] = 1.0\n",
        "    return x\n",
        "\n",
        "class SeqDataset(Dataset):\n",
        "    \"\"\"Map-style dataset yielding `(bipolar_encode(seq), target)` tensor pairs from a DataFrame.\n",
        "\n",
        "    `df` must have a `seq` column and the numeric column named by `target_col`. Targets are\n",
        "    stored as a float32 array of shape (n, 1), so each item's target has shape (1,).\n",
        "    \"\"\"\n",
        "    def __init__(self, df: pd.DataFrame, target_col: str = \"y_norm\"):\n",
        "        self.seqs = df[\"seq\"].tolist()\n",
        "        targets = df[target_col].astype(np.float32).values\n",
        "        self.y = targets.reshape(-1, 1)\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.seqs)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        encoded = torch.from_numpy(bipolar_encode(self.seqs[idx]))\n",
        "        target = torch.from_numpy(self.y[idx])\n",
        "        return encoded, target\n",
        "\n",
        "class MaxCNN(nn.Module):\n",
        "    \"\"\"One-layer CNN: Conv1d (no padding) -> ReLU -> global max over positions -> Linear.\n",
        "\n",
        "    Expects input shaped (batch, 4, length), as batched from SeqDataset, and returns\n",
        "    (batch, 1). `L` is accepted for interface compatibility but not used: the global max\n",
        "    pooling makes the head independent of sequence length.\n",
        "    \"\"\"\n",
        "    def __init__(self, L: int, k: int = 11, n_filters: int = 64):\n",
        "        super().__init__()\n",
        "        # Creation order (conv, act, head) fixes both the state_dict layout and how the\n",
        "        # default initializers consume the RNG; keep it stable for reproducibility.\n",
        "        self.conv = nn.Conv1d(4, n_filters, kernel_size=k, padding=0)\n",
        "        self.act  = nn.ReLU(inplace=True)\n",
        "        self.head = nn.Linear(n_filters, 1)\n",
        "        self._reset_weights()\n",
        "\n",
        "    def _reset_weights(self):\n",
        "        \"\"\"He-normal conv weights (for the ReLU), Xavier-uniform head weights, zero biases.\"\"\"\n",
        "        nn.init.kaiming_normal_(self.conv.weight, nonlinearity=\"relu\")\n",
        "        nn.init.zeros_(self.conv.bias)\n",
        "        nn.init.xavier_uniform_(self.head.weight)\n",
        "        nn.init.zeros_(self.head.bias)\n",
        "\n",
        "    def forward(self, x):\n",
        "        features = self.act(self.conv(x))\n",
        "        pooled = features.amax(dim=-1)  # strongest response of each filter over positions\n",
        "        return self.head(pooled)\n",
        "\n",
        "# =========================\n",
        "# Monitoring helpers (updated to report PRED + TRUE)\n",
        "# =========================\n",
        "def eval_all(model, loader, device):\n",
        "    \"\"\"Return (y_pred, y_true) as flat 1-D arrays for every item in `loader`, in loader order.\n",
        "\n",
        "    Use a non-shuffling loader when results must align with dataset indices. The model is\n",
        "    switched to eval mode and left there. An empty loader yields two empty float32 arrays\n",
        "    instead of raising (np.vstack of an empty list raises).\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    y_pred, y_true = [], []\n",
        "    with torch.no_grad():\n",
        "        for xb, yb in loader:\n",
        "            p = model(xb.to(device))\n",
        "            y_pred.append(p.cpu().numpy().reshape(-1))\n",
        "            y_true.append(yb.cpu().numpy().reshape(-1))\n",
        "    if not y_pred:\n",
        "        return np.empty(0, dtype=np.float32), np.empty(0, dtype=np.float32)\n",
        "    return np.concatenate(y_pred), np.concatenate(y_true)\n",
        "\n",
        "def _stats_tuple(arr):\n",
        "    \"\"\"Return (min, median, max, count) of `arr` as floats.\n",
        "\n",
        "    For an empty input the three statistics are NaN and the count is 0.0. A NaN count\n",
        "    would make `int(count)` raise when the report is printed for a split without motifs.\n",
        "    \"\"\"\n",
        "    arr = np.asarray(arr)\n",
        "    if arr.size == 0:\n",
        "        return (np.nan, np.nan, np.nan, 0.0)\n",
        "    return (float(np.min(arr)), float(np.median(arr)), float(np.max(arr)), float(arr.size))\n",
        "\n",
        "def motif_stats_pair(y_pred: np.ndarray, y_true: np.ndarray, subset, has_motif_all: np.ndarray):\n",
        "    \"\"\"\n",
        "    Compute motif-sliced stats for both predicted and true y on a subset.\n",
        "\n",
        "    If `subset` has `.indices` (e.g. a torch Subset), those indices select rows of\n",
        "    `has_motif_all` and must be in the same order as `y_pred` / `y_true`; otherwise\n",
        "    `y_pred` is taken to cover rows 0..len(y_pred)-1.\n",
        "    Returns two dicts: pred_stats and true_stats, each with keys:\n",
        "        - 'neg_max'  (no motif, max; NaN when there are no such rows)\n",
        "        - 'pos_min', 'pos_median', 'pos_max'  (NaN when there are no motif rows)\n",
        "        - 'n_neg', 'n_pos'  (row counts as floats, always finite)\n",
        "    \"\"\"\n",
        "    if hasattr(subset, \"indices\"):\n",
        "        # intp dtype so an empty index list still indexes (np.array([]) is float64).\n",
        "        idx = np.asarray(subset.indices, dtype=np.intp)\n",
        "    else:\n",
        "        idx = np.arange(len(y_pred))\n",
        "\n",
        "    pos = np.asarray(has_motif_all)[idx].astype(bool)\n",
        "    neg = ~pos\n",
        "    # Counts come straight from the masks, so they are finite even for an empty slice.\n",
        "    n_neg = float(neg.sum())\n",
        "    n_pos = float(pos.sum())\n",
        "\n",
        "    def _summarize(values):\n",
        "        values = np.asarray(values)\n",
        "        neg_max = float(np.max(values[neg])) if np.any(neg) else np.nan\n",
        "        pos_min, pos_med, pos_max, _ = _stats_tuple(values[pos])\n",
        "        return dict(\n",
        "            neg_max=neg_max,\n",
        "            pos_min=pos_min, pos_median=pos_med, pos_max=pos_max,\n",
        "            n_neg=n_neg, n_pos=n_pos,\n",
        "        )\n",
        "\n",
        "    return _summarize(y_pred), _summarize(y_true)\n",
        "\n",
        "def print_motif_report(split_name, pred_stats, true_stats):\n",
        "    \"\"\"Print two lines for one split: model predictions (with row counts), then ground truth.\n",
        "\n",
        "    Both dicts use the keys produced by `motif_stats_pair`; the row counts are read from\n",
        "    `pred_stats` and printed via int().\n",
        "    \"\"\"\n",
        "    pred_line = (\n",
        "        f\"[{split_name}] PRED  y — no-motif max={pred_stats['neg_max']:.3f} | \"\n",
        "        f\"motif min/median/max={pred_stats['pos_min']:.3f}/{pred_stats['pos_median']:.3f}/{pred_stats['pos_max']:.3f} \"\n",
        "        f\"(n_neg={int(pred_stats['n_neg'])}, n_pos={int(pred_stats['n_pos'])})\"\n",
        "    )\n",
        "    print(pred_line)\n",
        "    truth_line = (\n",
        "        f\"[{split_name}] TRUTH y — no-motif max={true_stats['neg_max']:.3f} | \"\n",
        "        f\"motif min/median/max={true_stats['pos_min']:.3f}/{true_stats['pos_median']:.3f}/{true_stats['pos_max']:.3f}\"\n",
        "    )\n",
        "    print(truth_line)\n",
        "\n",
        "# =========================\n",
        "# Training with per-split motif monitoring (pred + truth)\n",
        "# =========================\n",
        "def main():\n",
        "    \"\"\"Train MaxCNN on the PBM file with per-split motif monitoring and write all artifacts.\n",
        "\n",
        "    Writes processed.csv, last.pt, best.pt, model.pt, metrics.json and top200_predicted.csv\n",
        "    to `args.outdir` and one checkpoint per epoch to `args.checkpoint_dir`. Returns the\n",
        "    model with the best-validation-MSE weights loaded (if any epoch ran).\n",
        "    \"\"\"\n",
        "    # Notebook-style args\n",
        "    args = argparse.Namespace()\n",
        "    args.input = \"Max_3864.1_v2_deBruijn.txt\"\n",
        "    args.outdir = \"run_max\"\n",
        "    args.normalize = \"robust\"\n",
        "    args.epochs = 300\n",
        "    args.batch = 256\n",
        "    args.lr = 1e-3\n",
        "    args.seed = 1337\n",
        "    args.checkpoint_dir = \"ckpts\"\n",
        "    args.patience = 6\n",
        "    args.early_stop = False  # set True to stop after `patience` epochs without val improvement\n",
        "    motif = \"CACGTG\"\n",
        "\n",
        "    set_seed(args.seed)\n",
        "    outdir = Path(args.outdir); outdir.mkdir(parents=True, exist_ok=True)\n",
        "    ckpt_dir = Path(args.checkpoint_dir); ckpt_dir.mkdir(parents=True, exist_ok=True)\n",
        "\n",
        "    # Load & normalize.\n",
        "    # NOTE: normalization statistics use ALL rows, including the later val/test\n",
        "    # splits - a mild leakage to keep in mind when reading the test metrics.\n",
        "    raw = load_pbm(args.input)\n",
        "    ylog, ynorm = normalize_y(raw[\"y_raw\"].values, method=args.normalize)\n",
        "    raw[\"y_log1p\"] = ylog\n",
        "    raw[\"y_norm\"]  = ynorm\n",
        "    raw[\"has_motif\"] = raw[\"seq\"].str.contains(motif)\n",
        "    raw.to_csv(outdir / \"processed.csv\", index=False)\n",
        "\n",
        "    # Splits\n",
        "    frac_train, frac_val = 0.90, 0.09\n",
        "    N = len(raw)\n",
        "    n_train = int(N * frac_train)\n",
        "    n_val   = int(N * frac_val)\n",
        "    n_test  = N - n_train - n_val\n",
        "    lengths = [n_train, n_val, n_test]\n",
        "\n",
        "    ds = SeqDataset(raw, target_col=\"y_norm\")\n",
        "    train_ds, val_ds, test_ds = random_split(ds, lengths, generator=torch.Generator().manual_seed(args.seed))\n",
        "\n",
        "    # Model\n",
        "    L = len(raw[\"seq\"].iloc[0])\n",
        "    model = MaxCNN(L=L, k=11, n_filters=64)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    model.to(device)\n",
        "\n",
        "    opt = torch.optim.Adam(model.parameters(), lr=args.lr)\n",
        "    loss_fn = nn.MSELoss()\n",
        "\n",
        "    # Loaders\n",
        "    train_loader      = DataLoader(train_ds, batch_size=args.batch, shuffle=True,  drop_last=False)\n",
        "    train_eval_loader = DataLoader(train_ds, batch_size=args.batch, shuffle=False, drop_last=False)  # for monitoring\n",
        "    val_loader        = DataLoader(val_ds,   batch_size=args.batch, shuffle=False, drop_last=False)\n",
        "    test_loader       = DataLoader(test_ds,  batch_size=args.batch, shuffle=False, drop_last=False)\n",
        "\n",
        "    best_val = float(\"inf\"); best_state = None; bad = 0\n",
        "    has_motif_all = raw[\"has_motif\"].to_numpy()\n",
        "\n",
        "    for epoch in range(1, args.epochs + 1):\n",
        "        # ---- Train ----\n",
        "        model.train()\n",
        "        tr_loss = 0.0\n",
        "        for xb, yb in tqdm(train_loader):\n",
        "            xb = xb.to(device); yb = yb.to(device)\n",
        "            opt.zero_grad()\n",
        "            pred = model(xb)\n",
        "            loss = loss_fn(pred, yb)\n",
        "            loss.backward(); opt.step()\n",
        "            tr_loss += float(loss.item()) * len(xb)\n",
        "        tr_loss /= len(train_loader.dataset)\n",
        "\n",
        "        # ---- Validate loss (eval_all runs in eval mode under no_grad) ----\n",
        "        vpred, vtrue = eval_all(model, val_loader, device)\n",
        "        v_mse = float(np.mean((vpred - vtrue) ** 2))\n",
        "\n",
        "        improved = v_mse < best_val\n",
        "        if improved:\n",
        "            best_val = v_mse\n",
        "            # clone() matters: on CPU, Tensor.cpu() returns the SAME tensor, so without a\n",
        "            # copy this snapshot would keep tracking the live weights as training continues.\n",
        "            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}\n",
        "            bad = 0\n",
        "        else:\n",
        "            bad += 1\n",
        "\n",
        "        # ---- Per-split motif monitoring (PRED + TRUTH) ----\n",
        "        y_pred_train, y_true_train = eval_all(model, train_eval_loader, device)\n",
        "        y_pred_test,  y_true_test  = eval_all(model, test_loader, device)\n",
        "\n",
        "        tr_pred_stats, tr_true_stats = motif_stats_pair(y_pred_train, y_true_train, train_ds, has_motif_all)\n",
        "        va_pred_stats, va_true_stats = motif_stats_pair(vpred,          vtrue,       val_ds,   has_motif_all)\n",
        "        te_pred_stats, te_true_stats = motif_stats_pair(y_pred_test,    y_true_test, test_ds,  has_motif_all)\n",
        "\n",
        "        print(f\"Epoch {epoch:03d}  train_mse={tr_loss:.4f}  val_mse={v_mse:.4f}  bad={bad}\")\n",
        "        print_motif_report(\"TRAIN\", tr_pred_stats, tr_true_stats)\n",
        "        print_motif_report(\"VAL  \", va_pred_stats, va_true_stats)\n",
        "        print_motif_report(\"TEST \", te_pred_stats, te_true_stats)\n",
        "\n",
        "        # ---- Checkpoints every epoch ----\n",
        "        torch.save(model.state_dict(), ckpt_dir / f\"epoch_{epoch:03d}.pt\")\n",
        "        torch.save(model.state_dict(), outdir / \"last.pt\")\n",
        "        if improved:\n",
        "            torch.save(model.state_dict(), outdir / \"best.pt\")\n",
        "\n",
        "        # Optional early stop (disabled by default via args.early_stop)\n",
        "        if args.early_stop and bad >= args.patience:\n",
        "            break\n",
        "\n",
        "    # ---- Load best and final test metrics ----\n",
        "    if best_state is not None:\n",
        "        model.load_state_dict(best_state)\n",
        "\n",
        "    pred, true = eval_all(model, test_loader, device)\n",
        "    mse = float(np.mean((pred - true) ** 2))\n",
        "    r2  = r2_score(true, pred)\n",
        "    spr = spearmanr_np(true, pred)\n",
        "\n",
        "    metrics = {\"val_best_mse\": best_val, \"test_mse\": mse, \"test_r2\": r2, \"test_spearman\": spr}\n",
        "    with open(outdir / \"metrics.json\", \"w\") as f:\n",
        "        json.dump(metrics, f, indent=2)\n",
        "\n",
        "    torch.save(model.state_dict(), outdir / \"model.pt\")\n",
        "\n",
        "    # Score all sequences (in `raw` row order, since shuffle=False) and dump top hits\n",
        "    all_loader = DataLoader(ds, batch_size=args.batch, shuffle=False)\n",
        "    all_pred, _ = eval_all(model, all_loader, device)\n",
        "    full = raw.copy()\n",
        "    full[\"y_pred\"] = all_pred.astype(np.float64)\n",
        "    full.sort_values(\"y_pred\", ascending=False).head(200).to_csv(outdir / \"top200_predicted.csv\", index=False)\n",
        "\n",
        "    print(\"Done. Wrote:\", outdir)\n",
        "    return model\n",
        "\n",
        "# Run:\n",
        "model = main()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>y_raw</th>\n",
              "      <th>seq</th>\n",
              "      <th>len</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>230397.452833</td>\n",
              "      <td>CTGCCACGTGGTTGCAACTCGGTAACGATCTTGTTCGTCTGTGTTC...</td>\n",
              "      <td>60</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>170603.094203</td>\n",
              "      <td>CGGTCACGTGGAGCTCATCATAACATCTATGAGCACGTCTGTGTTC...</td>\n",
              "      <td>60</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>159794.900107</td>\n",
              "      <td>GATACAACACGTGTTTGTGACATGGGAGTTCAAACGGTCTGTGTTC...</td>\n",
              "      <td>60</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>155864.962040</td>\n",
              "      <td>GCCAATCACGTGACAGGCCTCCGTTGAGAGGATTGAGTCTGTGTTC...</td>\n",
              "      <td>60</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>131023.633683</td>\n",
              "      <td>AGAGCCACGTGGGACGGCCCCGCCTCATCGGTCAATGTCTGTGTTC...</td>\n",
              "      <td>60</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>...</th>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>41864</th>\n",
              "      <td>-37.047415</td>\n",
              "      <td>CGAGAGTGTAGTGGCTCTGCTTAAAGTCGGAAGTCAGTCTGTGTTC...</td>\n",
              "      <td>60</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>41865</th>\n",
              "      <td>-82.862175</td>\n",
              "      <td>GTGTCCACTCTTAACACCTGATGGAGGACAAAAGCCGTCTGTGTTC...</td>\n",
              "      <td>60</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>41866</th>\n",
              "      <td>-94.409077</td>\n",
              "      <td>ATCGTAACCGCCGGAGTAGTGTTGTGCCGAGCCAAAGTCTGTGTTC...</td>\n",
              "      <td>60</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>41867</th>\n",
              "      <td>-148.365057</td>\n",
              "      <td>GGCCCAAGGACCGTATTGAGGGTAGTAGACACGTAAGTCTGTGTTC...</td>\n",
              "      <td>60</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>41868</th>\n",
              "      <td>-528.289795</td>\n",
              "      <td>GCAAGGGGTCAGAGCGAGCGCTACAAAATTGACATCGTCTGTGTTC...</td>\n",
              "      <td>60</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "<p>41869 rows × 3 columns</p>\n",
              "</div>"
            ],
            "text/plain": [
              "               y_raw                                                seq  len\n",
              "0      230397.452833  CTGCCACGTGGTTGCAACTCGGTAACGATCTTGTTCGTCTGTGTTC...   60\n",
              "1      170603.094203  CGGTCACGTGGAGCTCATCATAACATCTATGAGCACGTCTGTGTTC...   60\n",
              "2      159794.900107  GATACAACACGTGTTTGTGACATGGGAGTTCAAACGGTCTGTGTTC...   60\n",
              "3      155864.962040  GCCAATCACGTGACAGGCCTCCGTTGAGAGGATTGAGTCTGTGTTC...   60\n",
              "4      131023.633683  AGAGCCACGTGGGACGGCCCCGCCTCATCGGTCAATGTCTGTGTTC...   60\n",
              "...              ...                                                ...  ...\n",
              "41864     -37.047415  CGAGAGTGTAGTGGCTCTGCTTAAAGTCGGAAGTCAGTCTGTGTTC...   60\n",
              "41865     -82.862175  GTGTCCACTCTTAACACCTGATGGAGGACAAAAGCCGTCTGTGTTC...   60\n",
              "41866     -94.409077  ATCGTAACCGCCGGAGTAGTGTTGTGCCGAGCCAAAGTCTGTGTTC...   60\n",
              "41867    -148.365057  GGCCCAAGGACCGTATTGAGGGTAGTAGACACGTAAGTCTGTGTTC...   60\n",
              "41868    -528.289795  GCAAGGGGTCAGAGCGAGCGCTACAAAATTGACATCGTCTGTGTTC...   60\n",
              "\n",
              "[41869 rows x 3 columns]"
            ]
          },
          "execution_count": 5,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "import pandas as pd\n",
        "import re\n",
        "\n",
        "_float_pat = re.compile(r\"[-+]?\\d+(?:\\.\\d+)?(?:[eE][-+]?\\d+)?\")\n",
        "_seq_pat   = re.compile(r\"[ACGTacgt]{20,}\")\n",
        "\n",
        "def _parse_line(line: str):\n",
        "    \"\"\"Extract `(y, seq)` from one PBM text line, or return None if it is unusable.\n",
        "\n",
        "    `y` is the first numeric token on the line; `seq` is the longest run of at least\n",
        "    20 A/C/G/T characters, upper-cased. Blank lines and lines starting with `#` or\n",
        "    `//` are skipped.\n",
        "    \"\"\"\n",
        "    stripped = line.strip()\n",
        "    if not stripped or stripped.startswith((\"#\", \"//\")):\n",
        "        return None\n",
        "    value_match = _float_pat.search(line)\n",
        "    candidates = _seq_pat.findall(line)\n",
        "    if value_match is None or not candidates:\n",
        "        return None\n",
        "    try:\n",
        "        value = float(value_match.group(0))\n",
        "    except Exception:\n",
        "        return None\n",
        "    longest = max(candidates, key=len).upper().strip()\n",
        "    return value, re.sub(r\"[^ACGT]\", \"\", longest)\n",
        "\n",
        "def load_pbm(path: str) -> pd.DataFrame:\n",
        "    \"\"\"Parse a PBM text file into a DataFrame with columns `y_raw`, `seq` and `len`.\n",
        "\n",
        "    First tries pandas whitespace-delimited parsing and re-joins each row into a line for\n",
        "    `_parse_line`; if that raises, falls back to reading the file line by line.\n",
        "    Lines `_parse_line` rejects are dropped.\n",
        "    \"\"\"\n",
        "    rows = []\n",
        "    try:\n",
        "        table = pd.read_csv(path, sep=r\"\\s+\", header=None, engine=\"python\")\n",
        "        for _, record in table.iterrows():\n",
        "            parsed = _parse_line(\" \".join(str(v) for v in record if pd.notna(v)))\n",
        "            if parsed:\n",
        "                rows.append(parsed)\n",
        "    except Exception:\n",
        "        # NOTE(review): rows collected before a mid-loop failure are kept and the\n",
        "        # fallback then appends every line of the file again.\n",
        "        with open(path, \"r\") as handle:\n",
        "            rows.extend(parsed for parsed in map(_parse_line, handle) if parsed)\n",
        "    frame = pd.DataFrame(rows, columns=[\"y_raw\", \"seq\"])\n",
        "    frame[\"seq\"] = frame[\"seq\"].str.replace(r\"[^ACGT]\", \"\", regex=True).str.upper()\n",
        "    frame[\"len\"] = frame[\"seq\"].str.len()\n",
        "    return frame\n",
        "\n",
        "# Parse the raw PBM file; the resulting DataFrame is this cell's displayed output.\n",
        "load_pbm(\"Max_3864.1_v2_deBruijn.txt\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "hrWWocandVXl"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import numpy as np\n",
        "from typing import List, Tuple\n",
        "\n",
        "# --- helpers for bipolar encoding (+1 for chosen base channel, -1 for others) ---\n",
        "\n",
        "import torch\n",
        "import numpy as np\n",
        "\n",
        "DNA = \"ACGT\"\n",
        "BASE_TO_IDX = {base: pos for pos, base in enumerate(DNA)}\n",
        "\n",
        "def random_bipolar_sequence(L: int, device=None, dtype=torch.float32) -> torch.Tensor:\n",
        "    \"\"\"\n",
        "    Sample a uniformly random DNA sequence as a (4, L) bipolar one-hot tensor.\n",
        "\n",
        "    Every column holds exactly one +1 (the sampled base) and -1 in the other rows.\n",
        "    \"\"\"\n",
        "    bases = torch.randint(0, 4, (L,), device=device)\n",
        "    x = -torch.ones((4, L), dtype=dtype, device=device)\n",
        "    # Along the base dim, write +1 at each position's sampled base index.\n",
        "    x.scatter_(0, bases.unsqueeze(0), 1.0)\n",
        "    return x\n",
        "\n",
        "def random_bipolar_batch(N: int, L: int, device=None, dtype=torch.float32) -> torch.Tensor:\n",
        "    \"\"\"\n",
        "    Sample N uniformly random DNA sequences as an (N, 4, L) bipolar one-hot tensor.\n",
        "\n",
        "    For every sample and position exactly one of the 4 channels is +1, the rest -1.\n",
        "    Fully vectorized via `scatter_` (no Python loops).\n",
        "    \"\"\"\n",
        "    bases = torch.randint(0, 4, (N, L), device=device)\n",
        "    x = -torch.ones((N, 4, L), dtype=dtype, device=device)\n",
        "    # Along the channel dim, write +1 at each (sample, position)'s sampled base index.\n",
        "    x.scatter_(1, bases.unsqueeze(1), 1.0)\n",
        "    return x\n",
        "\n",
        "# (optional) batch decoder and boolean motif mask\n",
        "def decode_sequences(X_4L: torch.Tensor) -> list[str]:\n",
        "    \"\"\"\n",
        "    Decode bipolar one-hot tensors back to DNA strings via a per-position argmax.\n",
        "\n",
        "    (N, 4, L) -> list of N strings; a single (4, L) tensor -> a one-element list.\n",
        "    \"\"\"\n",
        "    def _letters(indices):\n",
        "        return \"\".join(DNA[i] for i in indices)\n",
        "\n",
        "    if X_4L.dim() == 2:\n",
        "        return [_letters(torch.argmax(X_4L, dim=0).tolist())]\n",
        "    rows = torch.argmax(X_4L, dim=1).cpu().tolist()\n",
        "    return [_letters(row) for row in rows]\n",
        "\n",
        "def motif_mask(X_4L: torch.Tensor, motif: str = \"CACGTG\") -> torch.Tensor:\n",
        "    \"\"\"\n",
        "    Boolean tensor on X_4L's device: True where the decoded sequence contains `motif`.\n",
        "\n",
        "    Shape is (N,) for an (N, 4, L) input, or (1,) for a single (4, L) input.\n",
        "    \"\"\"\n",
        "    hits = [motif in s for s in decode_sequences(X_4L)]\n",
        "    return torch.tensor(hits, dtype=torch.bool, device=X_4L.device)\n",
        "\n",
        "\n",
        "# --- GWG sampler over categorical moves (per-position base changes) ---\n",
        "\n",
        "\n",
        "class GWGSamplerApprox:\n",
        "    \"\"\"\n",
        "    Gradient-based GWG Metropolis-Hastings kernel over categorical bases with ±1 bipolar encoding.\n",
        "\n",
        "    State x: (4, L) with exactly one +1 per column. Each proposal moves one column to a\n",
        "    different base, drawn with probability softmax(-beta * ΔE / 2), where ΔE is a\n",
        "    first-order estimate from ∇y and E(x) = -y(x) (target ∝ exp(beta * y)).\n",
        "    The MH correction also uses the first-order ΔE, so the chain targets exp(beta * y)\n",
        "    exactly only when y is linear in x.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, model: torch.nn.Module, beta: float = 1.0):\n",
        "        self.model = model\n",
        "        self.beta = float(beta)\n",
        "\n",
        "    def _grad_y(self, x: torch.Tensor) -> torch.Tensor:\n",
        "        \"\"\"\n",
        "        Compute ∇_x y(x) for x shaped (4,L). Returns (4,L).\n",
        "        Gradients are enabled locally, so this also works inside a no_grad() block;\n",
        "        the returned tensor carries no autograd graph (create_graph=False).\n",
        "        \"\"\"\n",
        "        with torch.enable_grad():\n",
        "            x = x.detach().clone().requires_grad_(True)\n",
        "            y = self.model(x.unsqueeze(0)).view(())  # scalar\n",
        "            (g,) = torch.autograd.grad(y, x, create_graph=False, retain_graph=False)\n",
        "        return g\n",
        "\n",
        "    def _approx_deltas_and_moves(self, x: torch.Tensor):\n",
        "        \"\"\"\n",
        "        Build approximate ΔE for every categorical move (3 per position) using g = ∇y.\n",
        "        Moving column j from base a to b changes x by +2 at (b,j) and -2 at (a,j),\n",
        "        so Δy ≈ 2*(g[b,j] - g[a,j]) and ΔE = -Δy.\n",
        "        Returns:\n",
        "            deltas : (M,) with ΔE ≈ -2*(g[b,j] - g[a,j]), M = 3*L\n",
        "            j_idx  : (M,) positions\n",
        "            b_idx  : (M,) target bases\n",
        "        \"\"\"\n",
        "        device = x.device\n",
        "        L = x.shape[1]\n",
        "        g = self._grad_y(x)                  # (4,L)\n",
        "        curr = torch.argmax(x, dim=0)        # (L,)\n",
        "\n",
        "        # Broadcast g[a_j, j] across rows to subtract\n",
        "        g_a = g.gather(0, curr.unsqueeze(0).expand(4, -1))  # (4,L)\n",
        "        delta_mat = -2.0 * (g - g_a)                        # (4,L)\n",
        "\n",
        "        bgrid = torch.arange(4, device=device)[:, None].expand(4, L)\n",
        "        jgrid = torch.arange(L, device=device)[None, :].expand(4, L)\n",
        "        mask = (bgrid != curr[None, :])  # exclude staying at the same base\n",
        "\n",
        "        deltas = delta_mat[mask]\n",
        "        j_idx  = jgrid[mask]\n",
        "        b_idx  = bgrid[mask]\n",
        "        return deltas, j_idx, b_idx\n",
        "\n",
        "    def step(self, x: torch.Tensor) -> torch.Tensor:\n",
        "        \"\"\"\n",
        "        One approximate GWG Metropolis-Hastings step from x (4, L).\n",
        "        Returns the proposed state if accepted, else a (detached) copy of x.\n",
        "\n",
        "        The ratio exp(-beta*ΔE) * q_rev / q_fwd is evaluated in log space via log_softmax:\n",
        "        in linear space a large beta overflows exp() to inf and/or underflows q_rev to 0,\n",
        "        giving inf * 0 = NaN and a silent rejection.\n",
        "        \"\"\"\n",
        "        x = x.detach().clone()\n",
        "        device = x.device\n",
        "\n",
        "        # Forward proposal at x\n",
        "        deltas, j_idx, b_idx = self._approx_deltas_and_moves(x)  # uses grad\n",
        "        logits = -self.beta * deltas / 2.0\n",
        "        probs  = torch.softmax(logits, dim=0)\n",
        "        i = torch.multinomial(probs, 1).item()\n",
        "\n",
        "        j = int(j_idx[i].item())\n",
        "        b = int(b_idx[i].item())\n",
        "        a = int(torch.argmax(x[:, j]).item())\n",
        "\n",
        "        # Proposed state x'\n",
        "        x_new = x.clone()\n",
        "        x_new[:, j] = -1.0\n",
        "        x_new[b, j] = 1.0\n",
        "\n",
        "        log_q_fwd = torch.log_softmax(logits, dim=0)[i]\n",
        "        delta_i   = deltas[i]  # approx ΔE(x->x')\n",
        "\n",
        "        # Reverse proposal at x' (move column j back from b to a)\n",
        "        deltas_p, j_idx_p, b_idx_p = self._approx_deltas_and_moves(x_new)  # uses grad at x'\n",
        "        rev_mask = (j_idx_p == j) & (b_idx_p == a)\n",
        "        if not torch.any(rev_mask):\n",
        "            # Shouldn't happen (b != a, so the reverse move exists); fall back to taking the move\n",
        "            return x_new\n",
        "        log_q_rev = torch.log_softmax(-self.beta * deltas_p / 2.0, dim=0)[rev_mask][0]\n",
        "\n",
        "        # MH accept with approximate ΔE and proposal ratio, all in log space\n",
        "        log_accept = -self.beta * delta_i + log_q_rev - log_q_fwd\n",
        "        if torch.rand((), device=device) < torch.exp(torch.clamp(log_accept, max=0.0)):\n",
        "            return x_new\n",
        "        return x\n",
        "\n",
        "    def sample(self, x0: torch.Tensor, steps: int, record_every: int = 1):\n",
        "        \"\"\"\n",
        "        Run the chain for `steps` steps starting from x0 (4, L).\n",
        "        Every `record_every` steps it records a copy of the state, its score y, and\n",
        "        1.0/0.0 for whether that step changed the state. Returns (xs, ys, accs).\n",
        "        Only the recording of y uses no_grad(); step() takes gradients internally.\n",
        "        \"\"\"\n",
        "        xs, ys, accs = [], [], []\n",
        "        x = x0.detach().clone()\n",
        "        self.model.eval()\n",
        "        for t in range(1, steps + 1):\n",
        "            x_prev = x\n",
        "            x = self.step(x)\n",
        "            accepted = float((x != x_prev).any().item())\n",
        "            if (t % record_every) == 0:\n",
        "                with torch.no_grad():\n",
        "                    y_val = self.model(x.unsqueeze(0)).view(()).item()\n",
        "                xs.append(x.clone())\n",
        "                ys.append(y_val)\n",
        "                accs.append(accepted)\n",
        "        return xs, ys, accs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {},
      "outputs": [],
      "source": [
        "def check_if_particle_contains_motif(particles_np, motif=\"CACGTG\"):\n",
        "    \"\"\"\n",
        "    Report, per particle, whether its decoded A/C/G/T string contains `motif`.\n",
        "\n",
        "    particles_np: (N, 4, L) bipolar one-hot, as a numpy array or a torch tensor\n",
        "                  (any device, may require grad). A single (4, L) particle is\n",
        "                  treated as N = 1.\n",
        "    Returns: numpy bool array of shape (N,) indicating presence of motif\n",
        "    \"\"\"\n",
        "    if isinstance(particles_np, torch.Tensor):\n",
        "        # detach() first: .numpy() raises on tensors that require grad\n",
        "        particles_np = particles_np.detach().cpu().numpy()\n",
        "    particles_np = np.asarray(particles_np)\n",
        "    if particles_np.ndim == 2:\n",
        "        particles_np = particles_np[None]                 # (4, L) -> (1, 4, L)\n",
        "    BASES = np.array(list(\"ACGT\"))\n",
        "    idx = particles_np.argmax(axis=1)                     # (N, L) best channel per position\n",
        "    seqs = [''.join(BASES[row]) for row in idx]            # list of N strings\n",
        "    return np.fromiter((motif in s for s in seqs), dtype=bool, count=len(seqs))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Shared benchmark configuration, read by the GWG, PT-GWG and AIS cells below.\n",
        "# N: particles per run | run_count: independent runs | L: sequence length (MaxCNN is built with L=60)\n",
        "N, run_count, L = 30, 10, 60"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {},
      "outputs": [],
      "source": [
        "def load_final_model(checkpoint_path=\"ckpts/epoch_300.pt\", L=60, k=11, n_filters=64, map_location=\"cpu\"):\n",
        "    \"\"\"\n",
        "    Build a MaxCNN and load the final trained weights from a raw state_dict checkpoint.\n",
        "\n",
        "    All arguments default to the previously hard-coded values, so load_final_model()\n",
        "    behaves as before; pass others to load a different checkpoint / architecture.\n",
        "    map_location only controls where the checkpoint tensors are deserialized; the\n",
        "    returned model (in eval() mode) stays wherever MaxCNN was constructed, presumably\n",
        "    CPU, so move it with .to(device) as needed.\n",
        "    NOTE: torch.load unpickles the file; only load checkpoints you trust.\n",
        "    \"\"\"\n",
        "    model = MaxCNN(L=L, k=k, n_filters=n_filters)\n",
        "    state = torch.load(checkpoint_path, map_location=map_location)\n",
        "    model.load_state_dict(state)\n",
        "    model.eval()\n",
        "    return model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch, numpy as np\n",
        "\n",
        "task_name = \"GWG\"\n",
        "\n",
        "# Build the scorer and the kernel once: checkpoint, hyper-parameters and beta are\n",
        "# loop-invariant, so there is no need to reload them for every particle.\n",
        "ckpt_path = \"ckpts/epoch_300.pt\"   # <- adjust if your path differs\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "model = MaxCNN(L=L, k=11, n_filters=64).to(device)   # must match training hyperparams\n",
        "state = torch.load(ckpt_path, map_location=device)\n",
        "model.load_state_dict(state)   # if your file was a dict with 'model_state_dict', use model.load_state_dict(state['model_state_dict'])\n",
        "model.eval()\n",
        "\n",
        "GWG_STEPS = 60                            # GWG steps per chain\n",
        "sampler = GWGSamplerApprox(model, beta=10.0)\n",
        "\n",
        "motif_fracs = []\n",
        "energies_per_run = []\n",
        "sequences_per_run = []\n",
        "\n",
        "for run in range(run_count):\n",
        "    torch.manual_seed(run)                # reproducible, per-run random stream\n",
        "\n",
        "    baseline_motif_count = 0\n",
        "    energies = []\n",
        "    sequences = []\n",
        "\n",
        "    for i in range(N):\n",
        "        x0 = random_bipolar_sequence(L, device=device)\n",
        "        xs, ys, accs = sampler.sample(x0, steps=GWG_STEPS, record_every=1)\n",
        "\n",
        "        # keep the best-scoring visited state of this chain\n",
        "        ys_np = np.array(ys)\n",
        "        best_i = int(np.argmax(ys_np))\n",
        "        sequences.append(decode_sequences(xs[best_i].reshape(1, 4, L))[0])\n",
        "        energies.append(ys_np[best_i])\n",
        "\n",
        "        # the chain counts as a hit if ANY visited state contains the motif\n",
        "        if any(\"CACGTG\" in s for s in decode_sequences(torch.stack(xs))):\n",
        "            baseline_motif_count += 1\n",
        "\n",
        "    motif_fracs.append(baseline_motif_count / N)\n",
        "    energies_per_run.append(energies)\n",
        "    sequences_per_run.append(sequences)\n",
        "\n",
        "np.savetxt(f\"motif_fracs_{task_name}.csv\", np.array(motif_fracs), delimiter=\",\")\n",
        "np.savetxt(f\"energies_per_run_{task_name}.csv\", np.array(energies_per_run), delimiter=\",\")\n",
        "np.savetxt(f\"sequences_per_run_{task_name}.csv\", np.array(sequences_per_run), delimiter=\",\", fmt=\"%s\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "import random\n",
        "from typing import Iterable, List, Dict, Any, Optional, Tuple\n",
        "import torch\n",
        "\n",
        "def make_beta_ladder(\n",
        "    n_replicas: int,\n",
        "    beta_min: float = 0.05,\n",
        "    beta_max: float = 1.0,\n",
        "    spacing: str = \"geometric\",\n",
        ") -> List[float]:\n",
        "    \"\"\"\n",
        "    Return a descending inverse-temperature ladder [β_max, ..., β_min] with n_replicas entries.\n",
        "\n",
        "    spacing=\"geometric\": constant ratio between neighbours (tends to equalise swap rates).\n",
        "    spacing=\"linear\":    constant difference between neighbours.\n",
        "    Asserts n_replicas >= 2 and 0 < β_min <= β_max <= 1e6; raises ValueError for an unknown spacing.\n",
        "    \"\"\"\n",
        "    assert n_replicas >= 2\n",
        "    assert 0.0 < beta_min <= beta_max <= 1e6\n",
        "    last = n_replicas - 1\n",
        "    if spacing == \"linear\":\n",
        "        return [float(beta_max - i * (beta_max - beta_min) / last) for i in range(n_replicas)]\n",
        "    if spacing != \"geometric\":\n",
        "        raise ValueError(f\"Unknown spacing: {spacing}\")\n",
        "    ratio = (beta_min / beta_max) ** (1.0 / last)\n",
        "    return [float(beta_max * (ratio ** i)) for i in range(n_replicas)]\n",
        "\n",
        "class ParallelTemperingGWG:\n",
        "    \"\"\"\n",
        "    Parallel Tempering (Replica Exchange) wrapper for GWGSamplerApprox.\n",
        "\n",
        "    - Local proposals: GWGSamplerApprox.step(x) at each β_k.\n",
        "    - Swap proposals: adjacent replicas (alternating even/odd pairs).\n",
        "      Replica k targets p_k(x) ∝ exp(β_k * y(x)) (energy E = -y), so exchanging the\n",
        "      states of replicas i and j is accepted with probability\n",
        "          a = min(1, exp((β_i - β_j) * (y_j - y_i)))\n",
        "      i.e. a better-scoring state is always allowed to move to the colder (larger-β) replica.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self,\n",
        "                 model: torch.nn.Module,\n",
        "                 betas: Iterable[float]):\n",
        "        betas = sorted([float(b) for b in betas], reverse=True)  # β_0 >= β_1 >= ... >= β_{K-1}\n",
        "        if len(betas) < 2:\n",
        "            raise ValueError(\"Need at least 2 replicas for parallel tempering.\")\n",
        "        self.model = model\n",
        "        self.betas = betas\n",
        "        self.samplers = [GWGSamplerApprox(model, beta=b) for b in betas]\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def _forward_y(self, x: torch.Tensor) -> float:\n",
        "        \"\"\"Exact score y(x) of a single (4, L) state, computed without autograd.\"\"\"\n",
        "        return float(self.model(x.unsqueeze(0)).view(()).item())\n",
        "\n",
        "    @staticmethod\n",
        "    def _swap_log_accept(beta_i: float, beta_j: float, y_i: float, y_j: float) -> float:\n",
        "        \"\"\"\n",
        "        Log MH ratio for exchanging the states of replicas i and j under p_k ∝ exp(β_k * y):\n",
        "            log[p_i(x_j) p_j(x_i) / (p_i(x_i) p_j(x_j))] = (β_i - β_j) * (y_j - y_i)\n",
        "        \"\"\"\n",
        "        return (beta_i - beta_j) * (y_j - y_i)\n",
        "\n",
        "    @staticmethod\n",
        "    def _mutated_copy(x0: torch.Tensor, m: int) -> torch.Tensor:\n",
        "        \"\"\"\n",
        "        Copy of the bipolar one-hot x0 (4, L) in which m distinct random columns are\n",
        "        switched to a different, randomly drawn base (m is clipped to [0, L]).\n",
        "        \"\"\"\n",
        "        xk = x0.detach().clone()\n",
        "        L = x0.shape[1]\n",
        "        m = max(0, min(int(m), L))\n",
        "        if m == 0:\n",
        "            return xk\n",
        "        device = x0.device\n",
        "        curr_idx = torch.argmax(x0, dim=0)                # (L,) current base per column\n",
        "        pos = torch.randperm(L, device=device)[:m]        # columns to mutate\n",
        "        new_b = torch.randint(0, 4, (m,), device=device)  # candidate new bases\n",
        "        same = new_b == curr_idx[pos]\n",
        "        new_b[same] = (new_b[same] + 1) % 4               # guarantee a real base change\n",
        "        xk[:, pos] = -1.0\n",
        "        xk[new_b, pos] = 1.0                              # +1 in the mutated columns themselves\n",
        "        return xk\n",
        "\n",
        "    def sample(self,\n",
        "               x0: torch.Tensor,\n",
        "               steps: int,\n",
        "               swap_every: int = 10,\n",
        "               record_every: int = 1,\n",
        "               record_all: bool = False,\n",
        "               init_mode: str = \"copy\",\n",
        "               rng: Optional[random.Random] = None) -> Dict[str, Any]:\n",
        "        \"\"\"\n",
        "        Run PT for `steps` local updates per replica.\n",
        "\n",
        "        Args\n",
        "        ----\n",
        "        x0 : (4, L) bipolar one-hot\n",
        "        steps : number of *local* GWG steps per replica\n",
        "        swap_every : attempt swaps after every `swap_every` local steps\n",
        "        record_every : record traces every N local steps\n",
        "        record_all : if False, record only the cold (highest β) chain\n",
        "        init_mode : 'copy' (all replicas start at x0) or 'mutate'\n",
        "                    (replica k gets ~25% * k/(K-1) of its columns re-drawn, so hotter\n",
        "                    replicas start further from x0)\n",
        "        rng : accepted for API compatibility; the init mutation and the MH/swap\n",
        "              decisions currently draw from torch's global RNG\n",
        "\n",
        "        Returns\n",
        "        -------\n",
        "        dict with:\n",
        "          - 'xs':      list (if record_all=False) or list-of-lists (if True)\n",
        "          - 'ys':      same structure as xs, storing y = model(x)\n",
        "          - 'local_accept': per-replica acceptance rates\n",
        "          - 'swap_accept': per-adjacent-pair acceptance rates\n",
        "          - 'final_states': list of final x per replica (cold..hot)\n",
        "          - 'betas':   the β ladder used (cold..hot)\n",
        "\n",
        "        The caller's global autograd mode is left unchanged.\n",
        "        \"\"\"\n",
        "        if rng is None:\n",
        "            rng = random\n",
        "\n",
        "        device = x0.device\n",
        "        K = len(self.betas)\n",
        "\n",
        "        # --- init states per replica ---\n",
        "        if init_mode == \"copy\":\n",
        "            xs: List[torch.Tensor] = [x0.detach().clone() for _ in range(K)]\n",
        "        elif init_mode == \"mutate\":\n",
        "            L = x0.shape[1]\n",
        "            xs = []\n",
        "            for k in range(K):\n",
        "                frac = min(0.25 * (k / (K - 1) + 1e-8), 0.5)  # 0 for the cold replica, ~25% for the hottest\n",
        "                xs.append(self._mutated_copy(x0, int(round(frac * L))))\n",
        "        else:\n",
        "            raise ValueError(\"init_mode must be 'copy' or 'mutate'.\")\n",
        "\n",
        "        # --- cache exact y for swap decisions ---\n",
        "        ys_exact = [self._forward_y(x) for x in xs]\n",
        "\n",
        "        # --- bookkeeping ---\n",
        "        local_accept_num = [0] * K\n",
        "        local_attempts   = [0] * K\n",
        "        # For K replicas, there are K-1 adjacent pairs (0,1), (1,2), ...\n",
        "        swap_accept_num  = [0] * (K - 1)\n",
        "        swap_attempts    = [0] * (K - 1)\n",
        "\n",
        "        # traces: per-replica lists when record_all, else the cold chain only\n",
        "        xs_trace: List[Any] = [[] for _ in range(K)] if record_all else []\n",
        "        ys_trace: List[Any] = [[] for _ in range(K)] if record_all else []\n",
        "\n",
        "        # control even/odd swap pattern\n",
        "        swap_phase_even = True\n",
        "\n",
        "        self.model.eval()\n",
        "\n",
        "        for t in range(1, steps + 1):\n",
        "            # --- local GWG moves ---\n",
        "            # GWGSamplerApprox takes its own gradients under enable_grad(), so the global\n",
        "            # autograd mode is never toggled (toggling it would leak into later cells).\n",
        "            for k, sampler in enumerate(self.samplers):\n",
        "                x_prev = xs[k]\n",
        "                x_new  = sampler.step(x_prev)\n",
        "                local_attempts[k] += 1\n",
        "                if (x_new != x_prev).any().item():\n",
        "                    local_accept_num[k] += 1\n",
        "                    xs[k] = x_new\n",
        "                    ys_exact[k] = self._forward_y(x_new)  # refresh exact y for this replica only\n",
        "\n",
        "            # --- swap step ---\n",
        "            if (t % swap_every) == 0:\n",
        "                start = 0 if swap_phase_even else 1  # even pairs then odd pairs\n",
        "                for i in range(start, K - 1, 2):\n",
        "                    j = i + 1\n",
        "                    log_a = self._swap_log_accept(self.betas[i], self.betas[j], ys_exact[i], ys_exact[j])\n",
        "                    swap_attempts[i] += 1\n",
        "                    u = float(torch.rand((), device=device).item())\n",
        "                    # u < min(1, exp(log_a)), evaluated without overflowing exp()\n",
        "                    if log_a >= 0.0 or u < math.exp(log_a):\n",
        "                        # swap states & cached y\n",
        "                        xs[i], xs[j] = xs[j], xs[i]\n",
        "                        ys_exact[i], ys_exact[j] = ys_exact[j], ys_exact[i]\n",
        "                        swap_accept_num[i] += 1\n",
        "                swap_phase_even = not swap_phase_even\n",
        "\n",
        "            # --- record ---\n",
        "            if (t % record_every) == 0:\n",
        "                if record_all:\n",
        "                    for k in range(K):\n",
        "                        xs_trace[k].append(xs[k].detach().clone())\n",
        "                        ys_trace[k].append(ys_exact[k])\n",
        "                else:\n",
        "                    # cold chain is index 0 (highest β)\n",
        "                    xs_trace.append(xs[0].detach().clone())\n",
        "                    ys_trace.append(ys_exact[0])\n",
        "\n",
        "        # --- finalize stats ---\n",
        "        local_accept = [a / max(1, n) for a, n in zip(local_accept_num, local_attempts)]\n",
        "        swap_accept  = [a / max(1, n) for a, n in zip(swap_accept_num, swap_attempts)]\n",
        "\n",
        "        return {\n",
        "            \"xs\": xs_trace,\n",
        "            \"ys\": ys_trace,\n",
        "            \"local_accept\": local_accept,         # per replica (cold..hot)\n",
        "            \"swap_accept\": swap_accept,           # per adjacent pair\n",
        "            \"final_states\": [x.detach().clone() for x in xs],\n",
        "            \"betas\": list(self.betas),\n",
        "        }\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "metadata": {},
      "outputs": [],
      "source": [
        "task_name = \"PT-GWG\"\n",
        "\n",
        "# Load the scorer once, before building the sampler that wraps it.\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "model = load_final_model().to(device).eval()\n",
        "\n",
        "# 3-replica geometric ladder: cold β=10 (same as the GWG baseline) down to hot β=0.05.\n",
        "betas = make_beta_ladder(n_replicas=3, beta_min=0.05, beta_max=10.0, spacing=\"geometric\")\n",
        "pt = ParallelTemperingGWG(model, betas)\n",
        "\n",
        "motif_fracs = []\n",
        "energies_per_run = []\n",
        "sequences_per_run = []\n",
        "\n",
        "for run in range(run_count):\n",
        "    torch.manual_seed(run)                # reproducible, per-run random stream\n",
        "\n",
        "    baseline_motif_count = 0\n",
        "    energies = []\n",
        "    sequences = []\n",
        "\n",
        "    for i in range(N):\n",
        "        x0 = random_bipolar_sequence(L=L, device=device)\n",
        "        # 20 local steps x 3 replicas = 60 GWG steps per particle\n",
        "        out = pt.sample(x0, steps=20, swap_every=5, record_every=1, record_all=True, init_mode=\"mutate\")\n",
        "\n",
        "        # record_all=True: out[\"xs\"] holds one trace per replica -> flatten to (K*steps, 4, L)\n",
        "        visited = torch.stack([torch.stack(trace) for trace in out[\"xs\"]]).reshape(-1, 4, L)\n",
        "        with torch.no_grad():\n",
        "            visited_y = model(visited).view(-1)\n",
        "\n",
        "        # best state = argmax over the SCORES of visited states (not over the one-hot entries)\n",
        "        best_idx = int(torch.argmax(visited_y).item())\n",
        "        sequences.append(decode_sequences(visited[best_idx].reshape(1, 4, L))[0])\n",
        "        energies.append(visited_y[best_idx].item())\n",
        "\n",
        "        # the chain counts as a hit if ANY visited state contains the motif\n",
        "        baseline_motif_count += int(check_if_particle_contains_motif(visited, motif=\"CACGTG\").any())\n",
        "\n",
        "    motif_fracs.append(baseline_motif_count / N)\n",
        "    energies_per_run.append(energies)\n",
        "    sequences_per_run.append(sequences)\n",
        "\n",
        "np.savetxt(f\"motif_fracs_{task_name}.csv\", np.array(motif_fracs), delimiter=\",\")\n",
        "np.savetxt(f\"energies_per_run_{task_name}.csv\", np.array(energies_per_run), delimiter=\",\")\n",
        "np.savetxt(f\"sequences_per_run_{task_name}.csv\", np.array(sequences_per_run), delimiter=\",\", fmt=\"%s\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 45,
      "metadata": {},
      "outputs": [],
      "source": [
        "# ais_autotemp_gwg.py\n",
        "# End-to-end AIS with Adaptive Tempering (no resampling) using your GWG kernel.\n",
        "\n",
        "import math\n",
        "from typing import Callable, Optional, Dict, Any, List, Tuple\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# AIS–Auto-Temp (no resampling)\n",
        "# ----------------------------\n",
        "\n",
        "class AISAutoTempGWG:\n",
        "    \"\"\"\n",
        "    AIS with ESS-targeted adaptive temperatures, using GWG for rejuvenation.\n",
        "    No resampling at any stage. Preserves one-to-one lineages with starting particles.\n",
        "\n",
        "    Target: pi_beta(x) ∝ exp(-beta * E(x)), default E(x) = -model(x).\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(\n",
        "        self,\n",
        "        model: nn.Module,\n",
        "        energy_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,  # (N,4,L)->(N,)\n",
        "        ess_target_frac: float = 0.6,\n",
        "        K: int = 5,                              # GWG steps per stage\n",
        "        max_stages: Optional[int] = None,        # cap on # of stages (for compute parity)\n",
        "        beta_final: float = 1.0,\n",
        "        line_search_tol: float = 1e-4,\n",
        "        dtype=torch.float32,\n",
        "    ):\n",
        "        self.model = model\n",
        "        self.energy_fn = energy_fn\n",
        "        self.ess_target_frac = ess_target_frac\n",
        "        self.K = K\n",
        "        self.max_stages = max_stages\n",
        "        self.beta_final = float(beta_final)\n",
        "        self.line_search_tol = line_search_tol\n",
        "        self.dtype = dtype\n",
        "\n",
        "    # ---- batch-safe scoring/energy helpers ----\n",
        "    def _score_batch_default(self, X: torch.Tensor) -> torch.Tensor:\n",
        "        \"\"\"\n",
        "        Score y(x) for a batch X (N,4,L) -> (N,). Tries one batched forward pass;\n",
        "        if the model raises, falls back to one call per sample.\n",
        "        \"\"\"\n",
        "        try:\n",
        "            y = self.model(X).view(-1)  # expect (N,)->(N,)\n",
        "            return y.to(X.device)\n",
        "        except Exception:\n",
        "            # best-effort fallback for models that only accept a batch of one\n",
        "            N = X.shape[0]\n",
        "            ys = []\n",
        "            for i in range(N):\n",
        "                ys.append(self.model(X[i].unsqueeze(0)).view(()).to(X.device))\n",
        "            return torch.stack(ys)\n",
        "\n",
        "    def _energies(self, X: torch.Tensor) -> torch.Tensor:\n",
        "        \"\"\"Energies E(X), shape (N,), computed without autograd.\"\"\"\n",
        "        with torch.no_grad():\n",
        "            if self.energy_fn is not None:\n",
        "                return self.energy_fn(X)\n",
        "            return -self._score_batch_default(X)\n",
        "\n",
        "    def _scores(self, X: torch.Tensor) -> torch.Tensor:\n",
        "        \"\"\"Scores y(X) = -E(X), shape (N,), computed without autograd (bookkeeping only).\"\"\"\n",
        "        with torch.no_grad():\n",
        "            if self.energy_fn is not None:\n",
        "                return -self.energy_fn(X)\n",
        "            return self._score_batch_default(X)\n",
        "\n",
        "    @staticmethod\n",
        "    def _motif_hits(motif_fn: Callable[[torch.Tensor], Any], X: torch.Tensor) -> np.ndarray:\n",
        "        \"\"\"\n",
        "        Evaluate motif_fn(X) and return its per-particle flags as a numpy bool array.\n",
        "        motif_fn may return a torch tensor (any device) or any array-like of bools.\n",
        "        \"\"\"\n",
        "        with torch.no_grad():\n",
        "            hits = motif_fn(X)\n",
        "        if isinstance(hits, torch.Tensor):\n",
        "            hits = hits.detach().cpu().numpy()\n",
        "        return np.asarray(hits, dtype=bool)\n",
        "\n",
        "    # ---- AIS math ----\n",
        "    def _predicted_logw(self, logw: torch.Tensor, E: torch.Tensor, beta_old: float, beta_new: float) -> torch.Tensor:\n",
        "        \"\"\"Incremental AIS update at fixed states: logw - (beta_new - beta_old) * E.\"\"\"\n",
        "        return logw - (beta_new - beta_old) * E\n",
        "\n",
        "    def _ess_from_logw(self, logw: torch.Tensor) -> float:\n",
        "        \"\"\"Effective sample size (sum w)^2 / sum w^2, computed after subtracting max(logw).\"\"\"\n",
        "        a = logw - torch.max(logw)\n",
        "        w = torch.exp(a)\n",
        "        s1 = torch.sum(w)\n",
        "        s2 = torch.sum(w * w)\n",
        "        return float((s1 * s1 / s2).item())\n",
        "\n",
        "    def _choose_next_beta(self, logw: torch.Tensor, E: torch.Tensor, beta: float, N: int) -> float:\n",
        "        \"\"\"\n",
        "        Largest beta' in (beta, beta_final] whose predicted ESS stays >= ess_target_frac * N,\n",
        "        found by bisection (at most 32 halvings or until the bracket < line_search_tol).\n",
        "        \"\"\"\n",
        "        target = self.ess_target_frac * N\n",
        "\n",
        "        # try jumping straight to beta_final\n",
        "        logw_full = self._predicted_logw(logw, E, beta, self.beta_final)\n",
        "        if self._ess_from_logw(logw_full) >= target:\n",
        "            return self.beta_final\n",
        "\n",
        "        low, high = beta, self.beta_final  # ESS(low) >= target (trivially); ESS(high) < target\n",
        "        for _ in range(32):\n",
        "            mid = 0.5 * (low + high)\n",
        "            logw_mid = self._predicted_logw(logw, E, beta, mid)\n",
        "            ess_mid = self._ess_from_logw(logw_mid)\n",
        "            if ess_mid >= target:\n",
        "                low = mid\n",
        "            else:\n",
        "                high = mid\n",
        "            if (high - low) < self.line_search_tol:\n",
        "                break\n",
        "        return max(low, beta + 1e-8)  # always make progress\n",
        "\n",
        "    # ---- main run ----\n",
        "    def run(\n",
        "        self,\n",
        "        x0: torch.Tensor,                                         # (N,4,L)\n",
        "        motif_fn: Optional[Callable[[torch.Tensor], Any]] = None,  # per-particle bools (tensor or array)\n",
        "        seed: Optional[int] = None,\n",
        "    ) -> Dict[str, Any]:\n",
        "        \"\"\"\n",
        "        Anneal N particles x0 (N,4,L) from beta = 0 towards beta_final.\n",
        "\n",
        "        Each stage picks the next beta so the predicted ESS stays >= ess_target_frac * N,\n",
        "        updates the AIS log-weights at the current states, then applies K GWG sweeps\n",
        "        (one step per particle) at the new beta. Stops at beta_final, or after\n",
        "        max_stages stages, in which case the weights are jumped straight to beta_final.\n",
        "        \"\"\"\n",
        "        if seed is not None:\n",
        "            torch.manual_seed(seed)\n",
        "\n",
        "        device = x0.device\n",
        "        N, _, L = x0.shape\n",
        "        self.model.eval()\n",
        "\n",
        "        # X is a plain tensor (no requires_grad): the GWG kernel differentiates its own\n",
        "        # detached copies, autograd rejects the in-place row updates X[i] = ... below on a\n",
        "        # leaf that requires grad, and .numpy() (used by motif functions) refuses such tensors.\n",
        "        X = x0.detach().clone().to(device, self.dtype)\n",
        "        logw = torch.zeros(N, device=device, dtype=self.dtype)\n",
        "        beta = 0.0\n",
        "\n",
        "        # init energies and per-lineage stats\n",
        "        E = self._energies(X).to(device, self.dtype)\n",
        "        best_y = self._scores(X)                     # (N,)\n",
        "        hit_ever = np.zeros(N, dtype=bool)           # builtin bool: the np.bool alias was removed in NumPy 1.24\n",
        "        if motif_fn is not None:\n",
        "            hit_ever |= self._motif_hits(motif_fn, X)\n",
        "\n",
        "        kernel = GWGSamplerApprox(self.model, beta=beta)\n",
        "\n",
        "        betas: List[float] = [beta]\n",
        "        ess_hist: List[float] = [self._ess_from_logw(logw)]\n",
        "        stage = 0\n",
        "        total_gwg_steps = 0\n",
        "\n",
        "        while beta < self.beta_final:\n",
        "            # 1) choose next beta' by ESS target\n",
        "            beta_next = self._choose_next_beta(logw, E, beta, N)\n",
        "\n",
        "            # 2) AIS weight update at old states\n",
        "            logw = self._predicted_logw(logw, E, beta, beta_next)\n",
        "\n",
        "            # 3) rejuvenation at beta' with K GWG sweeps\n",
        "            kernel.beta = float(beta_next)\n",
        "            for _ in range(self.K):\n",
        "                for i in range(N):\n",
        "                    X[i] = kernel.step(X[i])\n",
        "                total_gwg_steps += N\n",
        "                # per-sweep bookkeeping\n",
        "                if motif_fn is not None:\n",
        "                    hit_ever |= self._motif_hits(motif_fn, X)\n",
        "                best_y = torch.maximum(best_y, self._scores(X))\n",
        "\n",
        "            # 4) refresh energies for next stage’s update\n",
        "            E = self._energies(X).to(device, self.dtype)\n",
        "\n",
        "            beta = beta_next\n",
        "            betas.append(beta)\n",
        "            ess_hist.append(self._ess_from_logw(logw))\n",
        "            stage += 1\n",
        "\n",
        "            if self.max_stages is not None and stage >= self.max_stages:\n",
        "                if beta < self.beta_final:\n",
        "                    logw = self._predicted_logw(logw, E, beta, self.beta_final)\n",
        "                    beta = self.beta_final\n",
        "                    betas.append(beta)\n",
        "                    ess_hist.append(self._ess_from_logw(logw))\n",
        "                break\n",
        "\n",
        "        # final weighted particles at beta_final (no resampling)\n",
        "        return {\n",
        "            \"xs\": X,                     # (N,4,L) final states\n",
        "            \"logw\": logw,                # (N,) final log-weights\n",
        "            \"betas\": betas,              # list of betas visited\n",
        "            \"ess_history\": ess_hist,     # ESS after each stage\n",
        "            \"E\": E,                      # final energies\n",
        "            \"best_y\": best_y,            # (N,) best y along lineage\n",
        "            \"hit_ever\": hit_ever,        # (N,) ever-hit flag (numpy bool)\n",
        "            \"total_gwg_steps\": total_gwg_steps,\n",
        "        }\n",
        "\n",
        "\n",
        "task_name = \"AISAutoTemp-GWG\"\n",
        "\n",
        "# The scoring model is the same for every run, so load it once.\n",
        "final_scorer = load_final_model()\n",
        "\n",
        "motif_fracs = []\n",
        "energies_per_run = []\n",
        "sequences_per_run = []\n",
        "\n",
        "for run in range(run_count):\n",
        "\n",
        "    x0 = random_bipolar_batch(N, L, device=device)\n",
        "\n",
        "    # Compute parity with SMC-GWG: at most max_stages * K = 3 * 20 = 60 GWG sweeps per particle\n",
        "    ais = AISAutoTempGWG(\n",
        "            model=model,\n",
        "            energy_fn=None,           # default E = -model(x)\n",
        "            ess_target_frac=0.6,\n",
        "            K=20,\n",
        "            max_stages=3,            # max_stages * K ≈ 60 sweeps\n",
        "            beta_final=10.0,\n",
        "        )\n",
        "\n",
        "    out = ais.run(x0, motif_fn=check_if_particle_contains_motif)\n",
        "\n",
        "    # Score final states without recording an autograd graph; .numpy() below rejects grad-tracking tensors.\n",
        "    with torch.no_grad():\n",
        "        final_y_val = final_scorer(out[\"xs\"])\n",
        "\n",
        "    final_x_val = decode_sequences(out[\"xs\"].reshape(-1,4,L))\n",
        "\n",
        "    # Convert to a Python float so np.array(motif_fracs) works even when hit_ever lives on the GPU.\n",
        "    motif_frac = out[\"hit_ever\"].sum().item() / N\n",
        "\n",
        "    motif_fracs.append(motif_frac)\n",
        "    energies_per_run.append(final_y_val.detach().cpu().numpy().tolist())\n",
        "    sequences_per_run.append(final_x_val)\n",
        "\n",
        "\n",
        "np.savetxt(f\"motif_fracs_{task_name}.csv\", np.array(motif_fracs), delimiter=\",\")\n",
        "np.savetxt(f\"energies_per_run_{task_name}.csv\", np.array(energies_per_run).squeeze(), delimiter=\",\")\n",
        "np.savetxt(f\"sequences_per_run_{task_name}.csv\", np.array(sequences_per_run), delimiter=\",\", fmt=\"%s\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {},
      "outputs": [],
      "source": [
        "import os, re, glob, torch\n",
        "\n",
        "def resolve_checkpoint_path(ckpt_dir: str, epoch: int) -> str:\n",
        "    \"\"\"\n",
        "    Return a checkpoint path for the given epoch.\n",
        "\n",
        "    Lookup order:\n",
        "      1. Exact common names in ckpt_dir: epoch_NNN.pt, epoch_N.pt,\n",
        "         ckpt_epoch_NNN.pt, ckpt_epoch_N.pt (NNN = zero-padded to 3 digits).\n",
        "      2. Otherwise the first *.pt file (sorted by path) whose name contains the\n",
        "         epoch as a standalone number: optionally zero-padded, not preceded by\n",
        "         another digit, and followed only by non-digits up to '.pt'\n",
        "         (e.g. 'model_e025.pt', '25.pt', 'run_epoch_25_final.pt').\n",
        "\n",
        "    Args:\n",
        "        ckpt_dir: directory holding the checkpoint files.\n",
        "        epoch: epoch number to look up.\n",
        "\n",
        "    Returns:\n",
        "        Path of an existing checkpoint file.\n",
        "\n",
        "    Raises:\n",
        "        FileNotFoundError: if no file matches.\n",
        "    \"\"\"\n",
        "    candidates = [\n",
        "        os.path.join(ckpt_dir, f\"epoch_{epoch:03d}.pt\"),\n",
        "        os.path.join(ckpt_dir, f\"epoch_{epoch}.pt\"),\n",
        "        os.path.join(ckpt_dir, f\"ckpt_epoch_{epoch:03d}.pt\"),\n",
        "        os.path.join(ckpt_dir, f\"ckpt_epoch_{epoch}.pt\"),\n",
        "    ]\n",
        "    for p in candidates:\n",
        "        if os.path.exists(p):\n",
        "            return p\n",
        "\n",
        "    # Fallback: no digit may touch the number on either side (so epoch 5 never matches\n",
        "    # '15' or '50'), but leading zeros and a bare '<epoch>.pt' name are accepted.\n",
        "    pat = re.compile(rf\"(?:.*[^0-9])?0*{epoch}[^0-9]*\\.pt$\")\n",
        "    for p in sorted(glob.glob(os.path.join(ckpt_dir, \"*.pt\"))):\n",
        "        if pat.match(os.path.basename(p)):\n",
        "            return p\n",
        "\n",
        "    raise FileNotFoundError(f\"No checkpoint found for epoch {epoch} in {ckpt_dir}\")\n",
        "\n",
        "def load_weights_into(model: torch.nn.Module, ckpt_path: str, device: str = \"cuda\", strict: bool = False):\n",
        "    \"\"\"\n",
        "    Load weights from a checkpoint into `model` in place.\n",
        "\n",
        "    Supports either:\n",
        "      - raw state_dict\n",
        "      - dict with key 'model_state_dict' or 'state_dict' (checked in that order)\n",
        "    Also strips a leading 'module.' (from DataParallel) from every key that carries it;\n",
        "    keys without the prefix are left unchanged.\n",
        "\n",
        "    Args:\n",
        "        model: module whose parameters/buffers are overwritten via load_state_dict.\n",
        "        ckpt_path: path of a file readable by torch.load.\n",
        "        device: forwarded to torch.load as map_location.\n",
        "        strict: forwarded to load_state_dict; when False, mismatched keys are\n",
        "            printed instead of raising.\n",
        "    \"\"\"\n",
        "    # NOTE(security): torch.load can run pickled code unless weights_only=True is in effect\n",
        "    # (its default depends on the torch version) -- only load trusted checkpoints.\n",
        "    sd = torch.load(ckpt_path, map_location=device)\n",
        "\n",
        "    # unwrap container dicts (checkpoint bundles that also hold optimizer state etc.)\n",
        "    if isinstance(sd, dict):\n",
        "        for container_key in (\"model_state_dict\", \"state_dict\"):\n",
        "            if container_key in sd:\n",
        "                sd = sd[container_key]\n",
        "                break\n",
        "\n",
        "    # handle DataParallel prefixes if present; removeprefix leaves un-prefixed keys intact\n",
        "    if isinstance(sd, dict) and any(k.startswith(\"module.\") for k in sd.keys()):\n",
        "        sd = {k.removeprefix(\"module.\"): v for k, v in sd.items()}\n",
        "\n",
        "    missing, unexpected = model.load_state_dict(sd, strict=strict)\n",
        "    if missing or unexpected:\n",
        "        print(f\"[load_weights_into] missing keys: {missing} | unexpected keys: {unexpected}\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "import numpy as np\n",
        "\n",
        "@torch.no_grad()\n",
        "def model_y_batch(model: torch.nn.Module, X_4L: torch.Tensor, chunk: int = 0) -> torch.Tensor:\n",
        "    \"\"\"\n",
        "    Predict y values for one sequence or a batch of sequences with `model`.\n",
        "\n",
        "    Args:\n",
        "        model: module applied to a batched input.\n",
        "        X_4L: a single (4, L) sequence or a batch shaped (N, 4, L).\n",
        "        chunk: when truthy and smaller than N, score the batch in slices of `chunk`\n",
        "            rows (useful if N is large / low VRAM); 0 scores everything at once.\n",
        "\n",
        "    Returns:\n",
        "        The model output flattened to 1-D: (1,) for a single sequence and (N,) for a\n",
        "        batch, assuming the model emits one value per sequence.\n",
        "    \"\"\"\n",
        "    # Single sequence: add a batch axis before scoring.\n",
        "    if X_4L.dim() == 2:\n",
        "        return model(X_4L.unsqueeze(0)).view(-1)\n",
        "\n",
        "    # One forward pass unless chunking is requested and actually needed.\n",
        "    if not (chunk and X_4L.size(0) > chunk):\n",
        "        return model(X_4L).view(-1)\n",
        "\n",
        "    n_rows = X_4L.size(0)\n",
        "    pieces = [model(X_4L[start:start + chunk]).view(-1) for start in range(0, n_rows, chunk)]\n",
        "    return torch.cat(pieces, dim=0)\n",
        "\n",
        "@torch.no_grad()\n",
        "def energy_batch(model: torch.nn.Module, X_4L: torch.Tensor, chunk: int = 0) -> torch.Tensor:\n",
        "    \"\"\"\n",
        "    Energy of each sequence, defined as the negated prediction: E(x) = -y(x).\n",
        "\n",
        "    Takes the same arguments as model_y_batch (including `chunk`) and returns (N,).\n",
        "    \"\"\"\n",
        "    predictions = model_y_batch(model, X_4L, chunk=chunk)\n",
        "    return predictions.neg()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "import numpy as np\n",
        "\n",
        "def smc_via_checkpoints(\n",
        "    checkpoint_dir: str,\n",
        "    epochs: list[int],\n",
        "    model_factory,                 # e.g. lambda: MaxCNN(L=60, k=11, n_filters=64)\n",
        "    GWGSamplerClass,               # e.g. GWGSamplerApprox (needs grads ON)\n",
        "    L: int,\n",
        "    num_particles: int = 200,\n",
        "    mcmc_steps: int = 15,\n",
        "    resample_thresh: float = 0.5,\n",
        "    device: str = \"cuda\",\n",
        "    beta: float = 1.0,\n",
        "    last_epoch: int | None = None,  # required; epoch of the model callers use to score \"best y\"\n",
        "):\n",
        "    \"\"\"\n",
        "    Carry particles through a sequence of checkpoints with GWG rejuvenation and\n",
        "    return every state visited.\n",
        "\n",
        "    For each checkpoint (ascending epoch order) a fresh model is built with\n",
        "    `model_factory`, its weights are loaded, and each particle is advanced\n",
        "    `mcmc_steps` times by `GWGSamplerClass(model, beta=beta).step`. Particle states\n",
        "    carry over from one checkpoint to the next.\n",
        "\n",
        "    NOTE(review): despite the name, no importance reweighting or resampling happens\n",
        "    here; `resample_thresh` is accepted for interface compatibility but unused.\n",
        "    `last_epoch` must be given and its checkpoint must exist, but scoring is left\n",
        "    to the caller.\n",
        "\n",
        "    Args:\n",
        "        checkpoint_dir: directory searched by resolve_checkpoint_path.\n",
        "        epochs: checkpoint epochs to traverse (sorted before use).\n",
        "        model_factory: zero-arg callable returning a model to load weights into.\n",
        "        GWGSamplerClass: sampler built as GWGSamplerClass(model, beta=beta).\n",
        "        L: sequence length passed to random_bipolar_batch.\n",
        "        num_particles: number of independent chains.\n",
        "        mcmc_steps: sampler steps per particle per checkpoint.\n",
        "        resample_thresh: unused (kept for compatibility).\n",
        "        device: requested device; any non-'cpu' device falls back to CPU when CUDA\n",
        "            is unavailable.\n",
        "        beta: value passed as `beta` to the sampler.\n",
        "        last_epoch: epoch whose checkpoint must exist (validated, not loaded).\n",
        "\n",
        "    Returns:\n",
        "        list of detached tensor snapshots: first the initial state of every particle,\n",
        "        then every state produced by the sampler, in visit order (checkpoint,\n",
        "        particle, step). Each entry is one row of random_bipolar_batch(...)\n",
        "        (presumably (4, L) -- TODO confirm).\n",
        "\n",
        "    Raises:\n",
        "        AssertionError: if last_epoch is None.\n",
        "        FileNotFoundError: if a requested checkpoint cannot be found.\n",
        "    \"\"\"\n",
        "    device = torch.device(device if torch.cuda.is_available() or device == \"cpu\" else \"cpu\")\n",
        "    epochs = sorted(epochs)\n",
        "    ckpts  = [resolve_checkpoint_path(checkpoint_dir, e) for e in epochs]\n",
        "\n",
        "    assert last_epoch is not None, \"Please pass last_epoch (int).\"\n",
        "    # Only validate that the scoring checkpoint exists; its model was never used here.\n",
        "    resolve_checkpoint_path(checkpoint_dir, last_epoch)\n",
        "\n",
        "    particles = random_bipolar_batch(num_particles, L, device=device)\n",
        "\n",
        "    # Record detached clones: particles[i] is a view, so storing it directly would let\n",
        "    # the later `particles[i] = x` overwrite already-recorded states in place.\n",
        "    all_seqs = [particles[i].detach().clone() for i in range(num_particles)]\n",
        "\n",
        "    for ckpt in ckpts:\n",
        "        # Build this stage's model and a sampler bound to it.\n",
        "        model = model_factory().to(device).eval()\n",
        "        load_weights_into(model, ckpt, device=str(device))\n",
        "        sampler = GWGSamplerClass(model, beta=beta)\n",
        "\n",
        "        for i in range(num_particles):\n",
        "            x = particles[i].clone()\n",
        "            for _ in range(mcmc_steps):\n",
        "                # the gradient-based sampler needs autograd even if the caller disabled it\n",
        "                with torch.enable_grad():\n",
        "                    x = sampler.step(x)\n",
        "                all_seqs.append(x.detach().clone())\n",
        "            particles[i] = x.detach()\n",
        "\n",
        "    return all_seqs\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 28,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import torch\n",
        "\n",
        "task_name = \"SMC-GWG\"\n",
        "\n",
        "L = 60                          # sequence length of the checkpointed models\n",
        "BASES = np.array(list(\"ACGT\"))\n",
        "MOTIF = \"CACGTG\"  # palindromic, so no RC needed\n",
        "\n",
        "# The scoring model is the same for every chain, so load it once.\n",
        "final_scorer = load_final_model()\n",
        "\n",
        "motif_fracs = []\n",
        "energies_per_run = []\n",
        "sequences_per_run = []\n",
        "\n",
        "for run in range(run_count):\n",
        "    motif_count = 0\n",
        "    energies = []\n",
        "    sequences = []\n",
        "    for _ in range(N):\n",
        "        # One single-particle chain; returns every state it visited.\n",
        "        visited_states = smc_via_checkpoints(\n",
        "            checkpoint_dir=\"ckpts\",\n",
        "            epochs=[25, 50, 300],                         # stages you want to traverse\n",
        "            model_factory=lambda: MaxCNN(L=L, k=11, n_filters=64),\n",
        "            GWGSamplerClass=GWGSamplerApprox,                 # or your exact GWG sampler\n",
        "            L=L,\n",
        "            num_particles=1,\n",
        "            mcmc_steps=20,\n",
        "            resample_thresh=0.5,\n",
        "            device=\"cuda\",\n",
        "            beta=10.0,\n",
        "            last_epoch=300,                                   # <<< use this epoch to score best/final y\n",
        "        )\n",
        "        particles_smc = torch.stack(visited_states).reshape(-1, 4, L)\n",
        "\n",
        "        # Keep the best visited state under the final model (no autograd graph needed).\n",
        "        with torch.no_grad():\n",
        "            final_y_val = final_scorer(particles_smc).reshape(-1)\n",
        "        best_idx = int(final_y_val.argmax())\n",
        "        energies.append(final_y_val[best_idx].item())\n",
        "        sequences.append(decode_sequences(particles_smc[best_idx].reshape(1, 4, L))[0])\n",
        "\n",
        "        # The chain counts as a motif hit if ANY visited state contains MOTIF,\n",
        "        # decoding each position by its argmax channel.\n",
        "        channel_idx = particles_smc.argmax(dim=1).cpu().numpy()   # (num_states, L)\n",
        "        if any(MOTIF in \"\".join(BASES[row]) for row in channel_idx):\n",
        "            motif_count += 1\n",
        "\n",
        "    motif_fracs.append(motif_count / N)\n",
        "    energies_per_run.append(energies)\n",
        "    sequences_per_run.append(sequences)\n",
        "\n",
        "\n",
        "np.savetxt(f\"motif_fracs_{task_name}.csv\", np.array(motif_fracs), delimiter=\",\")\n",
        "np.savetxt(f\"energies_per_run_{task_name}.csv\", np.array(energies_per_run), delimiter=\",\")\n",
        "np.savetxt(f\"sequences_per_run_{task_name}.csv\", np.array(sequences_per_run), delimiter=\",\", fmt=\"%s\")\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "torch-gpu-env",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.13.5"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
