{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "oLiu3ohakQvG",
        "outputId": "0f9a0498-1bf7-451e-e2fd-9dae9a7919d1"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 148/148 [00:01<00:00, 94.73it/s] \n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Epoch 001  train_mse=2.4833  val_mse=1.9020  bad=0\n",
            "[TRAIN] PRED  y — no-motif max=2.192 | motif min/median/max=-1.209/0.149/1.571 (n_neg=37367, n_pos=315)\n",
            "[TRAIN] TRUTH y — no-motif max=10.060 | motif min/median/max=-2.491/3.529/13.860\n",
            "[VAL  ] PRED  y — no-motif max=2.279 | motif min/median/max=-1.102/0.063/1.541 (n_neg=3734, n_pos=34)\n",
            "[VAL  ] TRUTH y — no-motif max=11.449 | motif min/median/max=-1.455/1.818/12.009\n",
            "[TEST ] PRED  y — no-motif max=1.464 | motif min/median/max=-0.031/0.188/0.406 (n_neg=417, n_pos=2)\n",
            "[TEST ] TRUTH y — no-motif max=5.665 | motif min/median/max=5.309/5.469/5.628\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 148/148 [00:01<00:00, 109.74it/s]\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Epoch 002  train_mse=1.7994  val_mse=1.7755  bad=0\n",
            "[TRAIN] PRED  y — no-motif max=1.989 | motif min/median/max=-1.132/0.282/1.585 (n_neg=37367, n_pos=315)\n",
            "[TRAIN] TRUTH y — no-motif max=10.060 | motif min/median/max=-2.491/3.529/13.860\n",
            "[VAL  ] PRED  y — no-motif max=2.050 | motif min/median/max=-1.032/0.087/1.610 (n_neg=3734, n_pos=34)\n",
            "[VAL  ] TRUTH y — no-motif max=11.449 | motif min/median/max=-1.455/1.818/12.009\n",
            "[TEST ] PRED  y — no-motif max=1.410 | motif min/median/max=0.315/0.349/0.383 (n_neg=417, n_pos=2)\n",
            "[TEST ] TRUTH y — no-motif max=5.665 | motif min/median/max=5.309/5.469/5.628\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 148/148 [00:01<00:00, 108.28it/s]\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Epoch 003  train_mse=1.7039  val_mse=1.7097  bad=0\n",
            "[TRAIN] PRED  y — no-motif max=2.228 | motif min/median/max=-0.962/0.517/1.718 (n_neg=37367, n_pos=315)\n",
            "[TRAIN] TRUTH y — no-motif max=10.060 | motif min/median/max=-2.491/3.529/13.860\n",
            "[VAL  ] PRED  y — no-motif max=2.073 | motif min/median/max=-0.641/0.252/1.886 (n_neg=3734, n_pos=34)\n",
            "[VAL  ] TRUTH y — no-motif max=11.449 | motif min/median/max=-1.455/1.818/12.009\n",
            "[TEST ] PRED  y — no-motif max=1.585 | motif min/median/max=0.596/0.659/0.721 (n_neg=417, n_pos=2)\n",
            "[TEST ] TRUTH y — no-motif max=5.665 | motif min/median/max=5.309/5.469/5.628\n",
            "Done. Wrote: run_max\n"
          ]
        }
      ],
      "source": [
        "# ---- All-in-one notebook cell with TRAIN/VAL/TEST motif monitoring (pred + truth) ----\n",
        "import os, re, json, math, argparse, random, numpy as np, pandas as pd\n",
        "from pathlib import Path\n",
        "\n",
        "import os\n",
        "os.environ[\"TORCHDYNAMO_DISABLE\"] = \"1\"  # set before importing torch\n",
        "import torch\n",
        "\n",
        "\n",
        "# Optional: tqdm (safe fallback)\n",
        "try:\n",
        "    from tqdm import tqdm\n",
        "except Exception:\n",
        "    def tqdm(x, **kwargs): return x\n",
        "\n",
        "import torch.nn as nn\n",
        "from torch.utils.data import Dataset, DataLoader, random_split\n",
        "\n",
        "\n",
        "# =========================\n",
        "# Helpers\n",
        "# =========================\n",
        "\n",
        "def set_seed(seed: int = 1337):\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    torch.cuda.manual_seed_all(seed)\n",
        "    torch.backends.cudnn.deterministic = True\n",
        "    torch.backends.cudnn.benchmark = False\n",
        "\n",
        "def r2_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:\n",
        "    y_true = np.asarray(y_true).flatten()\n",
        "    y_pred = np.asarray(y_pred).flatten()\n",
        "    ss_res = float(np.sum((y_true - y_pred) ** 2))\n",
        "    ss_tot = float(np.sum((y_true - np.mean(y_true)) ** 2)) + 1e-9\n",
        "    return 1.0 - ss_res / ss_tot\n",
        "\n",
        "def spearmanr_np(a: np.ndarray, b: np.ndarray) -> float:\n",
        "    a = np.asarray(a).flatten()\n",
        "    b = np.asarray(b).flatten()\n",
        "    def rankdata(x):\n",
        "        order = np.argsort(x, kind=\"mergesort\")\n",
        "        ranks = np.empty_like(order, dtype=float)\n",
        "        ranks[order] = np.arange(len(x), dtype=float)\n",
        "        vals, inv, counts = np.unique(x, return_inverse=True, return_counts=True)\n",
        "        csum = np.cumsum(counts); start = np.concatenate(([0], csum[:-1]))\n",
        "        avg = (start + csum - 1) / 2.0\n",
        "        return avg[inv]\n",
        "    ra = rankdata(a); rb = rankdata(b)\n",
        "    ra = (ra - ra.mean()) / (ra.std() + 1e-9)\n",
        "    rb = (rb - rb.mean()) / (rb.std() + 1e-9)\n",
        "    return float(np.mean(ra * rb))\n",
        "\n",
        "# --- Robust PBM loader ---\n",
        "_float_pat = re.compile(r\"[-+]?\\d+(?:\\.\\d+)?(?:[eE][-+]?\\d+)?\")\n",
        "_seq_pat   = re.compile(r\"[ACGTacgt]{20,}\")\n",
        "\n",
        "def _parse_line(line: str):\n",
        "    if not line.strip() or line.lstrip().startswith((\"#\", \"//\")):\n",
        "        return None\n",
        "    ym = _float_pat.search(line)\n",
        "    if not ym: return None\n",
        "    y_str = ym.group(0)\n",
        "    seqs = _seq_pat.findall(line)\n",
        "    if not seqs: return None\n",
        "    seq = max(seqs, key=len).upper().strip()\n",
        "    try: y = float(y_str)\n",
        "    except Exception: return None\n",
        "    seq = re.sub(r\"[^ACGT]\", \"\", seq)\n",
        "    return y, seq\n",
        "\n",
        "def load_pbm(path: str) -> pd.DataFrame:\n",
        "    rows = []\n",
        "    try:\n",
        "        tmp = pd.read_csv(path, sep=r\"\\s+\", header=None, engine=\"python\")\n",
        "        for _, row in tmp.iterrows():\n",
        "            joined = \" \".join(str(x) for x in row if pd.notna(x))\n",
        "            parsed = _parse_line(joined)\n",
        "            if parsed: rows.append(parsed)\n",
        "    except Exception:\n",
        "        with open(path, \"r\") as f:\n",
        "            for line in f:\n",
        "                parsed = _parse_line(line)\n",
        "                if parsed: rows.append(parsed)\n",
        "    out = pd.DataFrame(rows, columns=[\"y_raw\", \"seq\"])\n",
        "    out[\"seq\"] = out[\"seq\"].str.replace(r\"[^ACGT]\", \"\", regex=True).str.upper()\n",
        "    out[\"len\"] = out[\"seq\"].str.len()\n",
        "    return out\n",
        "\n",
        "def normalize_y(y, method: str = \"robust\"):\n",
        "    y = np.asarray(y, dtype=float)\n",
        "    eps = 1e-9\n",
        "    if method == \"none\":\n",
        "        return y, y\n",
        "    ylog = np.log1p(np.maximum(y, 0.0) + eps)\n",
        "    if method == \"logz\":\n",
        "        mu, sd = float(np.mean(ylog)), float(np.std(ylog) + eps)\n",
        "        yn = (ylog - mu) / sd\n",
        "    elif method == \"robust\":\n",
        "        med = float(np.median(ylog))\n",
        "        mad = float(np.median(np.abs(ylog - med)))\n",
        "        scale = 1.4826 * mad if mad > 0 else float(np.std(ylog) + eps)\n",
        "        yn = (ylog - med) / (scale if scale > 0 else 1.0)\n",
        "    else:\n",
        "        raise ValueError(f\"Unknown normalize method: {method}\")\n",
        "    return ylog, yn\n",
        "\n",
        "# =========================\n",
        "# Data encoding & model\n",
        "# =========================\n",
        "DNA = \"ACGT\"\n",
        "BASE_TO_IDX = {b: i for i, b in enumerate(DNA)}\n",
        "\n",
        "def bipolar_encode(seq: str, alphabet: str = DNA) -> np.ndarray:\n",
        "    L = len(seq); C = len(alphabet)\n",
        "    x = -np.ones((C, L), dtype=np.float32)\n",
        "    for i, ch in enumerate(seq):\n",
        "        j = BASE_TO_IDX.get(ch)\n",
        "        if j is not None:\n",
        "            x[j, i] = 1.0\n",
        "    return x\n",
        "\n",
        "class SeqDataset(Dataset):\n",
        "    def __init__(self, df: pd.DataFrame, target_col: str = \"y_norm\"):\n",
        "        self.seqs = df[\"seq\"].tolist()\n",
        "        self.y    = df[target_col].astype(np.float32).values.reshape(-1, 1)\n",
        "    def __len__(self): return len(self.seqs)\n",
        "    def __getitem__(self, idx):\n",
        "        x = bipolar_encode(self.seqs[idx])\n",
        "        return torch.from_numpy(x), torch.from_numpy(self.y[idx])\n",
        "\n",
        "class MaxCNN(nn.Module):\n",
        "    def __init__(self, L: int, k: int = 11, n_filters: int = 64):\n",
        "        super().__init__()\n",
        "        self.conv = nn.Conv1d(4, n_filters, kernel_size=k, padding=0)\n",
        "        self.act  = nn.ReLU(inplace=True)\n",
        "        self.head = nn.Linear(n_filters, 1)\n",
        "        nn.init.kaiming_normal_(self.conv.weight, nonlinearity=\"relu\")\n",
        "        nn.init.zeros_(self.conv.bias)\n",
        "        nn.init.xavier_uniform_(self.head.weight)\n",
        "        nn.init.zeros_(self.head.bias)\n",
        "    def forward(self, x):\n",
        "        h = self.act(self.conv(x))\n",
        "        h = torch.amax(h, dim=-1)   # global max\n",
        "        return self.head(h)\n",
        "\n",
        "# =========================\n",
        "# Monitoring helpers (updated to report PRED + TRUE)\n",
        "# =========================\n",
        "def eval_all(model, loader, device):\n",
        "    \"\"\"Return (y_pred, y_true) for all items in the given loader (no shuffle).\"\"\"\n",
        "    model.eval()\n",
        "    y_pred, y_true = [], []\n",
        "    with torch.no_grad():\n",
        "        for xb, yb in loader:\n",
        "            xb = xb.to(device)\n",
        "            p  = model(xb).cpu().numpy()\n",
        "            y_pred.append(p)\n",
        "            y_true.append(yb.numpy())\n",
        "    return np.vstack(y_pred).flatten(), np.vstack(y_true).flatten()\n",
        "\n",
        "def _stats_tuple(arr):\n",
        "    if arr.size == 0: return (np.nan, np.nan, np.nan, np.nan)  # (min, median, max, count=nan here)\n",
        "    return (float(np.min(arr)), float(np.median(arr)), float(np.max(arr)), float(arr.size))\n",
        "\n",
        "def motif_stats_pair(y_pred: np.ndarray, y_true: np.ndarray, subset, has_motif_all: np.ndarray):\n",
        "    \"\"\"\n",
        "    Compute motif-sliced stats for both predicted and true y on a subset.\n",
        "    Returns two dicts: pred_stats and true_stats, each with keys:\n",
        "        - 'neg_max'  (no motif, max)\n",
        "        - 'pos_min', 'pos_median', 'pos_max'\n",
        "        - 'n_neg', 'n_pos'\n",
        "    \"\"\"\n",
        "    if hasattr(subset, \"indices\"):\n",
        "        idx = np.array(subset.indices)\n",
        "    else:\n",
        "        idx = np.arange(len(y_pred))\n",
        "\n",
        "    flags = has_motif_all[idx].astype(bool)\n",
        "    neg = ~flags\n",
        "    pos =  flags\n",
        "\n",
        "    # predicted\n",
        "    neg_max_pred = float(np.max(y_pred[neg])) if np.any(neg) else np.nan\n",
        "    pos_min_pred, pos_med_pred, pos_max_pred, n_pos = _stats_tuple(y_pred[pos])\n",
        "    n_neg = float(neg.sum())\n",
        "\n",
        "    pred_stats = dict(\n",
        "        neg_max=neg_max_pred,\n",
        "        pos_min=pos_min_pred, pos_median=pos_med_pred, pos_max=pos_max_pred,\n",
        "        n_neg=n_neg, n_pos=n_pos\n",
        "    )\n",
        "\n",
        "    # true\n",
        "    neg_max_true = float(np.max(y_true[neg])) if np.any(neg) else np.nan\n",
        "    pos_min_true, pos_med_true, pos_max_true, _ = _stats_tuple(y_true[pos])\n",
        "\n",
        "    true_stats = dict(\n",
        "        neg_max=neg_max_true,\n",
        "        pos_min=pos_min_true, pos_median=pos_med_true, pos_max=pos_max_true,\n",
        "        n_neg=n_neg, n_pos=n_pos\n",
        "    )\n",
        "\n",
        "    return pred_stats, true_stats\n",
        "\n",
        "def print_motif_report(split_name, pred_stats, true_stats):\n",
        "    print(\n",
        "        f\"[{split_name}] PRED  y — no-motif max={pred_stats['neg_max']:.3f} | \"\n",
        "        f\"motif min/median/max={pred_stats['pos_min']:.3f}/{pred_stats['pos_median']:.3f}/{pred_stats['pos_max']:.3f} \"\n",
        "        f\"(n_neg={int(pred_stats['n_neg'])}, n_pos={int(pred_stats['n_pos'])})\"\n",
        "    )\n",
        "    print(\n",
        "        f\"[{split_name}] TRUTH y — no-motif max={true_stats['neg_max']:.3f} | \"\n",
        "        f\"motif min/median/max={true_stats['pos_min']:.3f}/{true_stats['pos_median']:.3f}/{true_stats['pos_max']:.3f}\"\n",
        "    )\n",
        "\n",
        "# =========================\n",
        "# Training with per-split motif monitoring (pred + truth)\n",
        "# =========================\n",
        "def main():\n",
        "    # Notebook-style args\n",
        "    args = argparse.Namespace()\n",
        "    args.input = \"Max_3864.1_v2_deBruijn.txt\"\n",
        "    args.outdir = \"run_max\"\n",
        "    args.normalize = \"robust\"\n",
        "    args.epochs = 300\n",
        "    args.batch = 256\n",
        "    args.lr = 1e-3\n",
        "    args.seed = 1337\n",
        "    args.checkpoint_dir = \"ckpts\"\n",
        "    motif = \"CACGTG\"\n",
        "\n",
        "    set_seed(args.seed)\n",
        "    outdir = Path(args.outdir); outdir.mkdir(parents=True, exist_ok=True)\n",
        "    ckpt_dir = Path(args.checkpoint_dir); ckpt_dir.mkdir(parents=True, exist_ok=True)\n",
        "\n",
        "    # Load & normalize\n",
        "    raw = load_pbm(args.input)\n",
        "    ylog, ynorm = normalize_y(raw[\"y_raw\"].values, method=args.normalize)\n",
        "    raw[\"y_log1p\"] = ylog\n",
        "    raw[\"y_norm\"]  = ynorm\n",
        "    raw[\"has_motif\"] = raw[\"seq\"].str.contains(motif)\n",
        "    raw.to_csv(outdir / \"processed.csv\", index=False)\n",
        "\n",
        "    # Splits\n",
        "    frac_train, frac_val = 0.90, 0.09\n",
        "    N = len(raw)\n",
        "    n_train = int(N * frac_train)\n",
        "    n_val   = int(N * frac_val)\n",
        "    n_test  = N - n_train - n_val\n",
        "    lengths = [n_train, n_val, n_test]\n",
        "\n",
        "    ds = SeqDataset(raw, target_col=\"y_norm\")\n",
        "    train_ds, val_ds, test_ds = random_split(ds, lengths, generator=torch.Generator().manual_seed(args.seed))\n",
        "\n",
        "    # Model\n",
        "    L = len(raw[\"seq\"].iloc[0])\n",
        "    model = MaxCNN(L=L, k=11, n_filters=64)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    model.to(device)\n",
        "\n",
        "    opt = torch.optim.Adam(model.parameters(), lr=args.lr)\n",
        "    loss_fn = nn.MSELoss()\n",
        "\n",
        "    # Loaders\n",
        "    train_loader      = DataLoader(train_ds, batch_size=args.batch, shuffle=True,  drop_last=False)\n",
        "    train_eval_loader = DataLoader(train_ds, batch_size=args.batch, shuffle=False, drop_last=False)  # for monitoring\n",
        "    val_loader        = DataLoader(val_ds,   batch_size=args.batch, shuffle=False, drop_last=False)\n",
        "    test_loader       = DataLoader(test_ds,  batch_size=args.batch, shuffle=False, drop_last=False)\n",
        "\n",
        "    best_val = float(\"inf\"); best_state = None; patience = 6; bad = 0\n",
        "    has_motif_all = raw[\"has_motif\"].to_numpy()\n",
        "\n",
        "    for epoch in range(1, args.epochs + 1):\n",
        "        # ---- Train ----\n",
        "        model.train()\n",
        "        tr_loss = 0.0\n",
        "        for xb, yb in tqdm(train_loader):\n",
        "            xb = xb.to(device); yb = yb.to(device)\n",
        "            opt.zero_grad()\n",
        "            pred = model(xb)\n",
        "            loss = loss_fn(pred, yb)\n",
        "            loss.backward(); opt.step()\n",
        "            tr_loss += float(loss.item()) * len(xb)\n",
        "        tr_loss /= len(train_loader.dataset)\n",
        "\n",
        "        # ---- Validate loss ----\n",
        "        model.eval()\n",
        "        with torch.no_grad():\n",
        "            vs, vt = [], []\n",
        "            for xb, yb in val_loader:\n",
        "                xb = xb.to(device); yb = yb.to(device)\n",
        "                p = model(xb)\n",
        "                vs.append(p.cpu().numpy()); vt.append(yb.cpu().numpy())\n",
        "        vpred = np.vstack(vs).flatten()\n",
        "        vtrue = np.vstack(vt).flatten()\n",
        "        v_mse = float(np.mean((vpred - vtrue) ** 2))\n",
        "\n",
        "        improved = v_mse < best_val\n",
        "        if improved:\n",
        "            best_val = v_mse\n",
        "            best_state = {k: v.cpu() for k, v in model.state_dict().items()}\n",
        "            bad = 0\n",
        "        else:\n",
        "            bad += 1\n",
        "\n",
        "        # ---- Per-split motif monitoring (PRED + TRUTH) ----\n",
        "        y_pred_train, y_true_train = eval_all(model, train_eval_loader, device)\n",
        "        y_pred_test,  y_true_test  = eval_all(model, test_loader, device)\n",
        "\n",
        "        tr_pred_stats, tr_true_stats = motif_stats_pair(y_pred_train, y_true_train, train_ds, has_motif_all)\n",
        "        va_pred_stats, va_true_stats = motif_stats_pair(vpred,          vtrue,       val_ds,   has_motif_all)\n",
        "        te_pred_stats, te_true_stats = motif_stats_pair(y_pred_test,    y_true_test, test_ds,  has_motif_all)\n",
        "\n",
        "        print(f\"Epoch {epoch:03d}  train_mse={tr_loss:.4f}  val_mse={v_mse:.4f}  bad={bad}\")\n",
        "        print_motif_report(\"TRAIN\", tr_pred_stats, tr_true_stats)\n",
        "        print_motif_report(\"VAL  \", va_pred_stats, va_true_stats)\n",
        "        print_motif_report(\"TEST \", te_pred_stats, te_true_stats)\n",
        "\n",
        "        # ---- Checkpoints every epoch ----\n",
        "        torch.save(model.state_dict(), ckpt_dir / f\"epoch_{epoch:03d}.pt\")\n",
        "        torch.save(model.state_dict(), outdir / \"last.pt\")\n",
        "        if improved:\n",
        "            torch.save(model.state_dict(), outdir / \"best.pt\")\n",
        "\n",
        "        # Optional early stop:\n",
        "        # if bad >= patience: break\n",
        "\n",
        "    # ---- Load best and final test metrics ----\n",
        "    if best_state is not None:\n",
        "        model.load_state_dict(best_state)\n",
        "\n",
        "    model.eval()\n",
        "    with torch.no_grad():\n",
        "        ps, ts = [], []\n",
        "        for xb, yb in test_loader:\n",
        "            xb = xb.to(device); yb = yb.to(device)\n",
        "            p = model(xb)\n",
        "            ps.append(p.cpu().numpy()); ts.append(yb.cpu().numpy())\n",
        "    pred = np.vstack(ps); true = np.vstack(ts)\n",
        "    mse = float(np.mean((pred - true) ** 2))\n",
        "    r2  = r2_score(true, pred)\n",
        "    spr = spearmanr_np(true, pred)\n",
        "\n",
        "    metrics = {\"val_best_mse\": best_val, \"test_mse\": mse, \"test_r2\": r2, \"test_spearman\": spr}\n",
        "    with open(outdir / \"metrics.json\", \"w\") as f:\n",
        "        json.dump(metrics, f, indent=2)\n",
        "\n",
        "    torch.save(model.state_dict(), outdir / \"model.pt\")\n",
        "\n",
        "    # Score all sequences and dump top hits\n",
        "    all_loader = DataLoader(ds, batch_size=args.batch, shuffle=False)\n",
        "    with torch.no_grad():\n",
        "        all_pred = []\n",
        "        for xb, _ in all_loader:\n",
        "            xb = xb.to(device)\n",
        "            p = model(xb).cpu().numpy().flatten().tolist()\n",
        "            all_pred.extend(p)\n",
        "    full = raw.copy()\n",
        "    full[\"y_pred\"] = all_pred\n",
        "    full.sort_values(\"y_pred\", ascending=False).head(200).to_csv(outdir / \"top200_predicted.csv\", index=False)\n",
        "\n",
        "    print(\"Done. Wrote:\", outdir)\n",
        "    return model\n",
        "\n",
        "# Run:\n",
        "model = main()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "hrWWocandVXl"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import numpy as np\n",
        "from typing import List, Tuple\n",
        "\n",
        "# --- helpers for bipolar encoding (+1 for chosen base channel, -1 for others) ---\n",
        "\n",
        "import torch\n",
        "import numpy as np\n",
        "\n",
        "DNA = \"ACGT\"\n",
        "BASE_TO_IDX = {b:i for i,b in enumerate(DNA)}\n",
        "\n",
        "def random_bipolar_sequence(L: int, device=None, dtype=torch.float32) -> torch.Tensor:\n",
        "    \"\"\"\n",
        "    (4, L) in {-1,+1}, exactly one +1 per column.\n",
        "    \"\"\"\n",
        "    idx = torch.randint(0, 4, (L,), device=device)\n",
        "    x = -torch.ones((4, L), dtype=dtype, device=device)\n",
        "    x[idx, torch.arange(L, device=device)] = 1.0\n",
        "    return x\n",
        "\n",
        "def random_bipolar_batch(N: int, L: int, device=None, dtype=torch.float32) -> torch.Tensor:\n",
        "    \"\"\"\n",
        "    (N, 4, L) in {-1,+1}, exactly one +1 per column for each sample.\n",
        "    Fully vectorized (no Python loops).\n",
        "    \"\"\"\n",
        "    idx = torch.randint(0, 4, (N, L), device=device)            # base index per (n, pos)\n",
        "    x = -torch.ones((N, 4, L), dtype=dtype, device=device)      # start at -1\n",
        "    arN = torch.arange(N, device=device)[:, None]               # shape (N,1)\n",
        "    arL = torch.arange(L, device=device)[None, :]               # shape (1,L)\n",
        "    x[arN, idx, arL] = 1.0                                      # set chosen base to +1\n",
        "    return x\n",
        "\n",
        "# (optional) batch decoder and boolean motif mask\n",
        "def decode_sequences(X_4L: torch.Tensor) -> list[str]:\n",
        "    \"\"\"\n",
        "    X_4L: (N,4,L) tensor in {-1,+1} -> list of N strings over A/C/G/T\n",
        "    \"\"\"\n",
        "    if X_4L.dim() == 2:  # (4,L) -> single string\n",
        "        return [\"\".join(DNA[i] for i in torch.argmax(X_4L, dim=0).tolist())]\n",
        "    idx = torch.argmax(X_4L, dim=1).cpu().tolist()              # (N,L)\n",
        "    return [\"\".join(DNA[i] for i in row) for row in idx]\n",
        "\n",
        "def motif_mask(X_4L: torch.Tensor, motif: str = \"CACGTG\") -> torch.Tensor:\n",
        "    \"\"\"\n",
        "    Returns boolean tensor (N,) indicating motif presence in decoded sequences.\n",
        "    \"\"\"\n",
        "    seqs = decode_sequences(X_4L)\n",
        "    return torch.tensor([motif in s for s in seqs], dtype=torch.bool, device=X_4L.device)\n",
        "\n",
        "\n",
        "# --- GWG sampler over categorical moves (per-position base changes) ---\n",
        "\n",
        "import torch\n",
        "\n",
        "class GWGSamplerApproxConstrained:\n",
        "    \"\"\"\n",
        "    Gradient-based GWG over categorical bases with ±1 bipolar encoding,\n",
        "    under a hard Hamming-radius constraint around a reference sequence x_ref.\n",
        "\n",
        "    - Uses the same first-order ΔE approximation from ∇y as your original.\n",
        "    - If a sampled single-flip would exceed the Hamming cap, we *pair* it\n",
        "      with a gradient-guided reversion on some already-flipped position so\n",
        "      the net move stays on the Hamming boundary. Both legs are sampled\n",
        "      from the same GWG softmax over their respective candidate sets.\n",
        "    - MH acceptance uses the total approx ΔE and the exact q_rev / q_fwd,\n",
        "      mirroring whether the reverse path is 1-step or 2-step.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, model: torch.nn.Module, x_ref: torch.Tensor, max_hamming: int, beta: float = 1.0):\n",
        "        \"\"\"\n",
        "        Args:\n",
        "            model: scalar-output torch.nn.Module mapping (1, 4, L) -> scalar\n",
        "            x_ref: reference sequence (4, L) in ±1 one-hot/bipolar encoding\n",
        "            max_hamming: maximum allowed Hamming distance from x_ref\n",
        "            beta: inverse temperature for GWG proposals / MH acceptance\n",
        "        \"\"\"\n",
        "        self.model = model\n",
        "        self.beta  = float(beta)\n",
        "        self.x_ref = x_ref.detach().clone()\n",
        "        self.orig  = torch.argmax(self.x_ref, dim=0)  # (L,)\n",
        "        self.max_h = int(max_hamming)\n",
        "\n",
        "    # ---------- internals ----------\n",
        "    def _grad_y(self, x: torch.Tensor) -> torch.Tensor:\n",
        "        with torch.enable_grad():\n",
        "            x = x.detach().clone().requires_grad_(True)\n",
        "            y = self.model(x.unsqueeze(0)).view(())  # scalar\n",
        "            (g,) = torch.autograd.grad(y, x, create_graph=False, retain_graph=False)\n",
        "        return g  # (4, L)\n",
        "\n",
        "    def _approx_deltas_and_moves(self, x: torch.Tensor):\n",
        "        \"\"\"\n",
        "        All single-coordinate categorical moves (3 per position).\n",
        "        Returns:\n",
        "            deltas : (M,) with ΔE ≈ -2 * (g[b,j] - g[a,j])\n",
        "            j_idx  : (M,)\n",
        "            b_idx  : (M,)\n",
        "        \"\"\"\n",
        "        device = x.device\n",
        "        L = x.shape[1]\n",
        "        g = self._grad_y(x)                  # (4, L)\n",
        "        curr = torch.argmax(x, dim=0)        # (L,)\n",
        "\n",
        "        g_a = g.gather(0, curr.unsqueeze(0).expand(4, -1))  # (4, L)\n",
        "        delta_mat = -2.0 * (g - g_a)                        # (4, L)\n",
        "\n",
        "        bgrid = torch.arange(4, device=device)[:, None].expand(4, L)\n",
        "        jgrid = torch.arange(L, device=device)[None, :].expand(4, L)\n",
        "        mask  = (bgrid != curr[None, :])  # exclude staying at same base\n",
        "\n",
        "        deltas = delta_mat[mask]\n",
        "        j_idx  = jgrid[mask]\n",
        "        b_idx  = bgrid[mask]\n",
        "        return deltas, j_idx, b_idx\n",
        "\n",
        "    def _hamming(self, x: torch.Tensor) -> int:\n",
        "        curr = torch.argmax(x, dim=0)\n",
        "        return int((curr != self.orig.to(x.device)).sum().item())\n",
        "\n",
        "    # ---------- one step with constraint ----------\n",
        "    def step(self, x: torch.Tensor) -> torch.Tensor:\n",
        "        \"\"\"\n",
        "        One constrained GWG step. May perform a 1-flip or (if needed) a 2-flip\n",
        "        (increase + reversion) to respect the Hamming cap, with correct MH ratio.\n",
        "        \"\"\"\n",
        "        x = x.detach().clone()\n",
        "        device = x.device\n",
        "\n",
        "        # --- forward: choose a single flip by GWG ---\n",
        "        deltas, j_idx, b_idx = self._approx_deltas_and_moves(x)  # uses grad at x\n",
        "        logits = -self.beta * deltas / 2.0\n",
        "        probs  = torch.softmax(logits, dim=0)\n",
        "        i = torch.multinomial(probs, 1).item()\n",
        "\n",
        "        j = int(j_idx[i].item())\n",
        "        b = int(b_idx[i].item())\n",
        "        a = int(torch.argmax(x[:, j]).item())\n",
        "\n",
        "        # Hamming change if we applied (j, a->b)\n",
        "        orig = self.orig.to(device)\n",
        "        old_diff = int(a != int(orig[j].item()))\n",
        "        new_diff = int(b != int(orig[j].item()))\n",
        "        ham_now  = self._hamming(x)\n",
        "        ham_next = ham_now + (new_diff - old_diff)\n",
        "\n",
        "        # Apply first flip candidate\n",
        "        x1 = x.clone()\n",
        "        x1[:, j] = -1.0\n",
        "        x1[b, j] = 1.0\n",
        "\n",
        "        q1      = probs[i]\n",
        "        delta1  = deltas[i]\n",
        "\n",
        "        # --- if within cap: standard 1-flip MH ---\n",
        "        if ham_next <= self.max_h:\n",
        "            deltas_p, j_idx_p, b_idx_p = self._approx_deltas_and_moves(x1)\n",
        "            rev_mask = (j_idx_p == j) & (b_idx_p == a)\n",
        "            if not torch.any(rev_mask):\n",
        "                return x1  # safety fallback; should be rare\n",
        "            logits_p = -self.beta * deltas_p / 2.0\n",
        "            probs_p  = torch.softmax(logits_p, dim=0)\n",
        "            q_rev    = probs_p[rev_mask].item()\n",
        "\n",
        "            accept = torch.exp(-self.beta * delta1) * (q_rev / q1)\n",
        "            if torch.rand((), device=device) < torch.clamp(accept, max=1.0):\n",
        "                return x1\n",
        "            return x\n",
        "\n",
        "        # --- otherwise: we must pair with a reversion to stay on boundary ---\n",
        "        # Build reversion candidate set at x1: moves that set some j' to its original base\n",
        "        deltas2, j_idx2, b_idx2 = self._approx_deltas_and_moves(x1)  # grad at x1\n",
        "        curr1 = torch.argmax(x1, dim=0)\n",
        "\n",
        "        # valid reversion entries: currently different AND target is original base\n",
        "        mask_revert_cols = (curr1 != orig)  # (L,)\n",
        "        # For each move entry, we want b_idx2 == orig[j_idx2] and that column is currently different\n",
        "        mask_entries = (b_idx2 == orig[j_idx2]) & mask_revert_cols[j_idx2]\n",
        "\n",
        "        # Sanity: must have at least one candidate (including possibly j itself)\n",
        "        if not torch.any(mask_entries):\n",
        "            # If this ever triggers, fall back to rejecting the first flip\n",
        "            return x\n",
        "\n",
        "        deltas_rev = deltas2[mask_entries]\n",
        "        j_rev_all  = j_idx2[mask_entries]\n",
        "        b_rev_all  = b_idx2[mask_entries]  # equals orig[j_rev]\n",
        "\n",
        "        logits2 = -self.beta * deltas_rev / 2.0\n",
        "        probs2  = torch.softmax(logits2, dim=0)\n",
        "        k = torch.multinomial(probs2, 1).item()\n",
        "\n",
        "        j2 = int(j_rev_all[k].item())\n",
        "        b2 = int(b_rev_all[k].item())      # = orig[j2]\n",
        "        a2 = int(torch.argmax(x1[:, j2]).item())\n",
        "\n",
        "        # Apply the reversion to get x2 on the boundary\n",
        "        x2 = x1.clone()\n",
        "        x2[:, j2] = -1.0\n",
        "        x2[b2, j2] = 1.0\n",
        "\n",
        "        q2     = probs2[k]\n",
        "        delta2 = deltas_rev[k]\n",
        "\n",
        "        # --- reverse proposal prob q_rev: mirror the same rule from x2 back to x ---\n",
        "        # From x2, the reverse proceeds as:\n",
        "        #   (1) flip j2 from orig -> a2 (this *exceeds* the cap),\n",
        "        #   (2) then revert j back to its original base to come to x.\n",
        "        # Compute q1_rev at x2:\n",
        "        deltas_b1, j_idx_b1, b_idx_b1 = self._approx_deltas_and_moves(x2)\n",
        "        logits_b1 = -self.beta * deltas_b1 / 2.0\n",
        "        probs_b1  = torch.softmax(logits_b1, dim=0)\n",
        "        mask_b1   = (j_idx_b1 == j2) & (b_idx_b1 == a2)\n",
        "        if not torch.any(mask_b1):\n",
        "            # Safety fallback: accept the paired move without MH correction\n",
        "            return x2\n",
        "        q1_rev = probs_b1[mask_b1].item()\n",
        "\n",
        "        # After (1), construct the intermediate x2_1 and sample the reversion for j back to orig\n",
        "        x2_1 = x2.clone()\n",
        "        x2_1[:, j2] = -1.0\n",
        "        x2_1[a2, j2] = 1.0\n",
        "\n",
        "        deltas_b2, j_idx_b2, b_idx_b2 = self._approx_deltas_and_moves(x2_1)\n",
        "        logits_b2 = -self.beta * deltas_b2 / 2.0\n",
        "        probs_b2  = torch.softmax(logits_b2, dim=0)\n",
        "        mask_b2   = (j_idx_b2 == j) & (b_idx_b2 == int(orig[j].item()))\n",
        "        if not torch.any(mask_b2):\n",
        "            return x2\n",
        "        q2_rev = probs_b2[mask_b2].item()\n",
        "\n",
        "        q_fwd = q1 * q2\n",
        "        q_rev = q1_rev * q2_rev\n",
        "\n",
        "        accept = torch.exp(-self.beta * (delta1 + delta2)) * (q_rev / q_fwd)\n",
        "        if torch.rand((), device=device) < torch.clamp(accept, max=1.0):\n",
        "            return x2\n",
        "        return x\n",
        "\n",
        "    # ---------- run the chain ----------\n",
        "    def sample(self, x0: torch.Tensor, steps: int, record_every: int = 1):\n",
        "        \"\"\"\n",
        "        Advance the chain `steps` times, starting from a detached copy of `x0`.\n",
        "\n",
        "        After every `record_every`-th step three things are appended to the\n",
        "        returned lists `(xs, ys, accs)`: a clone of the current state, the\n",
        "        model's scalar output for it (evaluated under no_grad), and a 0.0/1.0\n",
        "        flag telling whether that step changed the state.\n",
        "        \"\"\"\n",
        "        state = x0.detach().clone()\n",
        "        self.model.eval()\n",
        "        states, scores, moved_flags = [], [], []\n",
        "        for it in range(1, steps + 1):\n",
        "            before = state\n",
        "            state = self.step(before)\n",
        "            # 1.0 when the step returned a tensor that differs anywhere from its input\n",
        "            moved = float((state != before).any().item())\n",
        "            if it % record_every != 0:\n",
        "                continue\n",
        "            with torch.no_grad():\n",
        "                score = self.model(state.unsqueeze(0)).view(()).item()\n",
        "            states.append(state.clone())\n",
        "            scores.append(score)\n",
        "            moved_flags.append(moved)\n",
        "        return states, scores, moved_flags"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 20,
      "metadata": {},
      "outputs": [],
      "source": [
        "def check_if_particle_contains_motif(particles_np, motif=\"CACGTG\"):\n",
        "    \"\"\"\n",
        "    Flag the particles whose decoded sequence contains `motif`.\n",
        "\n",
        "    particles_np: (N, 4, L) numpy array or torch.Tensor. Each position is decoded\n",
        "        as the argmax over axis 1, mapped to the letters \"ACGT\" in that order.\n",
        "    motif: substring searched for in every decoded sequence.\n",
        "    Returns: bool numpy array of shape (N,), True where the motif occurs.\n",
        "    \"\"\"\n",
        "    if isinstance(particles_np, torch.Tensor):\n",
        "        # detach() first: Tensor.numpy() raises on tensors that require grad\n",
        "        particles_np = particles_np.detach().cpu().numpy()\n",
        "    BASES = np.array(list(\"ACGT\"))\n",
        "    idx = particles_np.argmax(axis=1)                     # (N, L) best channel per position\n",
        "    seqs = [''.join(BASES[row]) for row in idx]            # N decoded strings\n",
        "    return np.fromiter((motif in s for s in seqs), dtype=bool, count=len(seqs))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Experiment-wide settings read by the sampling cells of this notebook.\n",
        "max_hamming_distance = 7  # Hamming cap handed to the constrained samplers\n",
        "run_count = 10            # how many times each experiment loop is repeated\n",
        "N = 30                    # number of particles"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {},
      "outputs": [],
      "source": [
        "def load_final_model(checkpoint_path=\"ckpts/epoch_300.pt\"):\n",
        "    \"\"\"\n",
        "    Build a MaxCNN(L=60, k=11, n_filters=64) and load trained weights into it.\n",
        "\n",
        "    checkpoint_path: file holding a raw state_dict; defaults to the epoch-300\n",
        "        checkpoint so existing zero-argument calls keep working.\n",
        "    Returns: the model in eval mode. The weights are mapped to CPU; move the\n",
        "        returned model to another device with `.to(device)` if needed.\n",
        "    \"\"\"\n",
        "    model = MaxCNN(L=60, k=11, n_filters=64)\n",
        "    state = torch.load(checkpoint_path, map_location=\"cpu\")\n",
        "    model.load_state_dict(state)\n",
        "    model.eval()\n",
        "    return model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 59,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch, numpy as np\n",
        "\n",
        "task_name = \"GWG_constrained\"\n",
        "\n",
        "# Assumes MaxCNN, GWGSamplerApproxConstrained, random_bipolar_sequence, decode_sequences are already defined\n",
        "\n",
        "# The model, device and sequence length do not change between particles, so they\n",
        "# are set up once here instead of once per particle. They remain module-level\n",
        "# names because the PT-GWG cell further down reads `L`, `device` and `model`.\n",
        "L = 60\n",
        "ckpt_path = \"ckpts/epoch_300.pt\"   # <- adjust if your path differs\n",
        "\n",
        "# build model to match training hyperparams\n",
        "model = MaxCNN(L=L, k=11, n_filters=64)\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "model.to(device)\n",
        "\n",
        "# load final weights (raw state_dict)\n",
        "state = torch.load(ckpt_path, map_location=device)\n",
        "model.load_state_dict(state)   # if your file was a dict with 'model_state_dict', use model.load_state_dict(state['model_state_dict'])\n",
        "model.eval()\n",
        "\n",
        "motif_fracs = []        # per run: fraction of particles that ever showed the motif\n",
        "energies_per_run = []   # per run: best recorded score of every particle\n",
        "sequences_per_run = []  # per run: decoded best sequence of every particle\n",
        "\n",
        "for run in range(run_count):\n",
        "\n",
        "    baseline_motif_count = 0\n",
        "    energies = []\n",
        "    sequences = []\n",
        "\n",
        "    for i in range(N):\n",
        "        # run constrained GWG (approximate) for 60 steps from a fresh random start;\n",
        "        # the start doubles as the Hamming reference\n",
        "        x0 = random_bipolar_sequence(L, device=device)\n",
        "        sampler = GWGSamplerApproxConstrained(model, x_ref = x0, max_hamming = max_hamming_distance, beta=10.0)\n",
        "\n",
        "        xs, ys, accs = sampler.sample(x0, steps=60, record_every=1)\n",
        "\n",
        "        # keep the best-scoring recorded state of this particle\n",
        "        ys_np = np.array(ys)\n",
        "        best_i = int(np.argmax(ys_np))\n",
        "\n",
        "        sequences.append(decode_sequences(xs[best_i].reshape(1,4,L))[0])\n",
        "\n",
        "        # a particle counts once if any of its recorded states contains the motif\n",
        "        for x in xs:\n",
        "            if \"CACGTG\" in decode_sequences(x.reshape(1,4,L))[0]:\n",
        "                baseline_motif_count += 1\n",
        "                break\n",
        "\n",
        "        energies.append(ys_np[best_i])\n",
        "    motif_fracs.append(baseline_motif_count / N)\n",
        "    energies_per_run.append(energies)\n",
        "    sequences_per_run.append(sequences)\n",
        "\n",
        "# persist the per-run summaries as CSV files\n",
        "np.savetxt(f\"motif_fracs_{task_name}.csv\", np.array(motif_fracs), delimiter=\",\")\n",
        "np.savetxt(f\"energies_per_run_{task_name}.csv\", np.array(energies_per_run), delimiter=\",\")\n",
        "np.savetxt(f\"sequences_per_run_{task_name}.csv\", np.array(sequences_per_run), delimiter=\",\", fmt=\"%s\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 65,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "import random\n",
        "from typing import Iterable, List, Dict, Any, Optional, Tuple\n",
        "import torch\n",
        "\n",
        "\n",
        "def make_beta_ladder(\n",
        "    n_replicas: int,\n",
        "    beta_min: float = 0.05,\n",
        "    beta_max: float = 1.0,\n",
        "    spacing: str = \"geometric\",\n",
        ") -> List[float]:\n",
        "    \"\"\"\n",
        "    Return `n_replicas` inverse temperatures running from beta_max down to beta_min.\n",
        "\n",
        "    spacing=\"geometric\" multiplies by a constant ratio from one rung to the next;\n",
        "    spacing=\"linear\" subtracts a constant step. Any other value raises ValueError.\n",
        "    The argument ranges (n_replicas >= 2, 0 < beta_min <= beta_max <= 1e6) are\n",
        "    checked with asserts.\n",
        "    \"\"\"\n",
        "    assert n_replicas >= 2\n",
        "    assert 0.0 < beta_min <= beta_max <= 1e6\n",
        "    rungs = range(n_replicas)\n",
        "    gaps = n_replicas - 1\n",
        "    if spacing == \"geometric\":\n",
        "        ratio = (beta_min / beta_max) ** (1.0 / gaps)\n",
        "        return [float(beta_max * (ratio ** i)) for i in rungs]\n",
        "    if spacing == \"linear\":\n",
        "        return [float(beta_max - i * (beta_max - beta_min) / gaps) for i in rungs]\n",
        "    raise ValueError(f\"Unknown spacing: {spacing}\")\n",
        "\n",
        "# assumes GWGSamplerApproxConstrained is defined (from earlier)\n",
        "\n",
        "def _hamming_to_ref(x: torch.Tensor, x_ref: torch.Tensor) -> int:\n",
        "    \"\"\"Count the columns whose argmax row (over dim 0) differs between x and x_ref.\"\"\"\n",
        "    mismatched = torch.argmax(x, dim=0) != torch.argmax(x_ref, dim=0)\n",
        "    return int(mismatched.sum().item())\n",
        "\n",
        "def _trim_init_to_cap(x: torch.Tensor, x_ref: torch.Tensor, max_hamming: int) -> torch.Tensor:\n",
        "    \"\"\"\n",
        "    Return a copy of x that differs from x_ref in at most max_hamming columns.\n",
        "\n",
        "    Columns are compared by their argmax over dim 0. When too many differ, the\n",
        "    lowest-indexed surplus columns are overwritten with -1.0 everywhere and\n",
        "    +1.0 at the reference's argmax row. The input x itself is left untouched.\n",
        "    \"\"\"\n",
        "    trimmed = x.clone()\n",
        "    ref_rows = torch.argmax(x_ref, dim=0)                    # (L,)\n",
        "    differs = torch.argmax(trimmed, dim=0) != ref_rows       # (L,) bool\n",
        "    differing = differs.nonzero(as_tuple=False).flatten()    # ascending column indices\n",
        "    surplus = differing.numel() - max_hamming\n",
        "    if surplus <= 0:\n",
        "        return trimmed\n",
        "    cols = differing[:surplus]\n",
        "    trimmed[:, cols] = -1.0\n",
        "    trimmed[ref_rows[cols], cols] = 1.0\n",
        "    return trimmed\n",
        "\n",
        "class ParallelTemperingGWGConstrained:\n",
        "    \"\"\"\n",
        "    Parallel Tempering (Replica Exchange) with constrained local GWG moves.\n",
        "\n",
        "      - Local proposals use GWGSamplerApproxConstrained(model, x_ref=x0, max_hamming, beta=β_k).\n",
        "      - The Hamming cap is enforced against a *global* reference x0 for all replicas,\n",
        "        so exchanged states are always measured against the same reference.\n",
        "      - Initial states (including 'mutate' mode) are clipped to the cap if needed.\n",
        "\n",
        "    Swap acceptance uses the exact model score y(x), with energy E = -y(x):\n",
        "        a = exp((β_i - β_j) * (E_i - E_j)) = exp((β_i - β_j) * (y_j - y_i)).\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self,\n",
        "                 model: torch.nn.Module,\n",
        "                 betas: Iterable[float],\n",
        "                 max_hamming: int):\n",
        "        \"\"\"\n",
        "        model: network called on a state with a leading batch axis of 1; its\n",
        "            output is reduced to one scalar score.\n",
        "        betas: inverse temperatures; stored sorted from largest to smallest.\n",
        "        max_hamming: non-negative Hamming cap forwarded to every local sampler.\n",
        "\n",
        "        Raises ValueError for fewer than two betas or a negative cap.\n",
        "        \"\"\"\n",
        "        betas = sorted([float(b) for b in betas], reverse=True)  # β_0 >= β_1 >= ...\n",
        "        if len(betas) < 2:\n",
        "            raise ValueError(\"Need at least 2 replicas for parallel tempering.\")\n",
        "        if max_hamming < 0:\n",
        "            raise ValueError(\"max_hamming must be >= 0.\")\n",
        "        self.model = model\n",
        "        self.betas = betas\n",
        "        self.max_h = int(max_hamming)\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def _forward_y(self, x: torch.Tensor) -> float:\n",
        "        \"\"\"Exact model score of a single state as a Python float (no autograd).\"\"\"\n",
        "        return float(self.model(x.unsqueeze(0)).view(()).item())\n",
        "\n",
        "    def sample(self,\n",
        "               x0: torch.Tensor,\n",
        "               steps: int,\n",
        "               swap_every: int = 10,\n",
        "               record_every: int = 1,\n",
        "               record_all: bool = False,\n",
        "               init_mode: str = \"copy\",\n",
        "               rng: Optional[random.Random] = None) -> Dict[str, Any]:\n",
        "        \"\"\"\n",
        "        Run PT for `steps` local updates per replica.\n",
        "\n",
        "        x0: start state, indexed here as (channels, L). It is also the Hamming\n",
        "            reference handed to every local sampler.\n",
        "        steps: number of local-move sweeps over all replicas.\n",
        "        swap_every: neighbour swaps are attempted after every `swap_every`-th sweep,\n",
        "            alternating between pairs (0,1),(2,3),... and (1,2),(3,4),...\n",
        "        record_every: states/scores are recorded after every `record_every`-th sweep.\n",
        "        record_all: if True, \"xs\"/\"ys\" hold one trace per replica; otherwise only\n",
        "            replica 0 (largest β) is traced.\n",
        "        init_mode: \"copy\" starts every replica at x0; \"mutate\" randomly changes a\n",
        "            growing share of positions for replicas further down the ladder and\n",
        "            then trims each start back to the cap.\n",
        "        rng: accepted for interface compatibility; not used, all randomness here\n",
        "            is drawn from torch.\n",
        "\n",
        "        Returns a dict with keys \"xs\", \"ys\", \"local_accept\", \"swap_accept\",\n",
        "        \"final_states\" and \"betas\".\n",
        "        Raises ValueError for an unknown init_mode.\n",
        "\n",
        "        The caller's autograd mode is left as it was found.\n",
        "        \"\"\"\n",
        "        device = x0.device\n",
        "        K = len(self.betas)\n",
        "\n",
        "        # --- init states per replica (then clip to Hamming cap vs x0) ---\n",
        "        xs: List[torch.Tensor] = []\n",
        "        if init_mode == \"copy\":\n",
        "            for _ in range(K):\n",
        "                xs.append(x0.detach().clone())\n",
        "        elif init_mode == \"mutate\":\n",
        "            L = x0.shape[1]\n",
        "            ref_idx = torch.argmax(x0, dim=0)  # (L,)\n",
        "            for k in range(K):\n",
        "                xk = x0.detach().clone()\n",
        "                # mutated share grows with k: ~0 for replica 0 up to ~25% of L for the last one\n",
        "                frac = min(0.25 * (k / (K - 1) + 1e-8), 0.5)\n",
        "                m = int(round(frac * L))\n",
        "                if m > 0:\n",
        "                    pos = torch.randperm(L, device=device)[:m]\n",
        "                    new_b = torch.randint(0, 4, (m,), device=device)\n",
        "                    # make sure every chosen position really changes its base\n",
        "                    same = (new_b == ref_idx[pos])\n",
        "                    if torch.any(same):\n",
        "                        new_b[same] = (new_b[same] + 1) % 4\n",
        "                    xk[:, pos] = -1.0\n",
        "                    # the +1 must land in the same columns that were just cleared\n",
        "                    xk[new_b, pos] = 1.0\n",
        "                # clip to cap if needed\n",
        "                xk = _trim_init_to_cap(xk, x0, self.max_h)\n",
        "                xs.append(xk)\n",
        "        else:\n",
        "            raise ValueError(\"init_mode must be 'copy' or 'mutate'.\")\n",
        "\n",
        "        # --- build constrained local samplers, one per replica (global x_ref = x0) ---\n",
        "        samplers = [\n",
        "            GWGSamplerApproxConstrained(model=self.model, x_ref=x0, max_hamming=self.max_h, beta=b)\n",
        "            for b in self.betas\n",
        "        ]\n",
        "\n",
        "        # --- cache exact y for swap decisions ---\n",
        "        self.model.eval()\n",
        "        ys_exact = [self._forward_y(x) for x in xs]\n",
        "\n",
        "        # --- bookkeeping ---\n",
        "        local_accept_num = [0 for _ in range(K)]\n",
        "        local_attempts   = [0 for _ in range(K)]\n",
        "        swap_accept_num  = [0 for _ in range(K - 1)]\n",
        "        swap_attempts    = [0 for _ in range(K - 1)]\n",
        "\n",
        "        # traces: a list per replica when record_all, else one flat list for replica 0\n",
        "        if record_all:\n",
        "            xs_trace = [[] for _ in range(K)]\n",
        "            ys_trace = [[] for _ in range(K)]\n",
        "        else:\n",
        "            xs_trace = []\n",
        "            ys_trace = []\n",
        "\n",
        "        # control even/odd swap pattern\n",
        "        swap_phase_even = True\n",
        "\n",
        "        for t in range(1, steps + 1):\n",
        "            # --- local constrained GWG moves (need grads ON) ---\n",
        "            # enable_grad() is scoped, so the global autograd mode is restored afterwards\n",
        "            with torch.enable_grad():\n",
        "                for k, sampler in enumerate(samplers):\n",
        "                    x_prev = xs[k]\n",
        "                    x_new  = sampler.step(x_prev)\n",
        "                    local_attempts[k] += 1\n",
        "                    if (x_new != x_prev).any().item():\n",
        "                        # the constrained sampler is relied on to keep x_new within the cap\n",
        "                        local_accept_num[k] += 1\n",
        "                        xs[k] = x_new\n",
        "                        ys_exact[k] = self._forward_y(xs[k])\n",
        "\n",
        "            # --- swap step ---\n",
        "            if (t % swap_every) == 0:\n",
        "                start = 0 if swap_phase_even else 1\n",
        "                for i in range(start, K - 1, 2):\n",
        "                    j = i + 1\n",
        "                    beta_i, beta_j = self.betas[i], self.betas[j]\n",
        "                    yi, yj = ys_exact[i], ys_exact[j]\n",
        "                    # E = -y, so (β_i - β_j)(E_i - E_j) = (β_i - β_j)(y_j - y_i):\n",
        "                    # a higher score on the hotter replica always moves to the colder one\n",
        "                    log_a = (beta_i - beta_j) * (yj - yi)\n",
        "                    a = math.exp(log_a) if log_a < 80 else float('inf')  # overflow guard\n",
        "                    swap_attempts[i] += 1\n",
        "                    u = float(torch.rand((), device=device).item())\n",
        "                    if u < min(1.0, a):\n",
        "                        # every replica shares x0 as reference, so a swap only exchanges states\n",
        "                        xs[i], xs[j] = xs[j], xs[i]\n",
        "                        ys_exact[i], ys_exact[j] = ys_exact[j], ys_exact[i]\n",
        "                        swap_accept_num[i] += 1\n",
        "                swap_phase_even = not swap_phase_even\n",
        "\n",
        "            # --- record ---\n",
        "            if (t % record_every) == 0:\n",
        "                if record_all:\n",
        "                    for k in range(K):\n",
        "                        xs_trace[k].append(xs[k].detach().clone())\n",
        "                        ys_trace[k].append(ys_exact[k])\n",
        "                else:\n",
        "                    xs_trace.append(xs[0].detach().clone())  # cold chain (highest β)\n",
        "                    ys_trace.append(ys_exact[0])\n",
        "\n",
        "        # --- finalize stats (max(1, n) avoids dividing by zero when nothing was attempted) ---\n",
        "        local_accept = [a / max(1, n) for a, n in zip(local_accept_num, local_attempts)]\n",
        "        swap_accept  = [a / max(1, n) for a, n in zip(swap_accept_num, swap_attempts)]\n",
        "\n",
        "        out: Dict[str, Any] = {\n",
        "            \"xs\": xs_trace,\n",
        "            \"ys\": ys_trace,\n",
        "            \"local_accept\": local_accept,         # per replica (cold..hot)\n",
        "            \"swap_accept\": swap_accept,           # per adjacent pair\n",
        "            \"final_states\": [x.detach().clone() for x in xs],\n",
        "            \"betas\": list(self.betas),\n",
        "        }\n",
        "        return out\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 66,
      "metadata": {},
      "outputs": [],
      "source": [
        "task_name = \"PT-GWG_constrained\"\n",
        "\n",
        "# NOTE: this cell reads `L`, `device` and `model` defined by the GWG_constrained cell.\n",
        "\n",
        "# Loop-invariant setup: the ladder and the scoring model are the same for every particle.\n",
        "betas = make_beta_ladder(n_replicas=3, beta_min=0.05, beta_max=10.0, spacing=\"geometric\")\n",
        "final_model = load_final_model().to(device)   # loaded once, on the same device as the samples\n",
        "\n",
        "motif_fracs = []        # per run: fraction of particles that ever showed the motif\n",
        "energies_per_run = []   # per run: best score of every particle\n",
        "sequences_per_run = []  # per run: decoded best sequence of every particle\n",
        "\n",
        "for run in range(run_count):\n",
        "    baseline_motif_count = 0\n",
        "    energies = []\n",
        "    sequences = []\n",
        "\n",
        "    for i in range(N):\n",
        "        x0 = random_bipolar_sequence(L, device=device)\n",
        "\n",
        "        # Use the constrained PT wrapper (cap relative to x0 for *all* replicas)\n",
        "        pt = ParallelTemperingGWGConstrained(model, betas, max_hamming=max_hamming_distance)\n",
        "\n",
        "        out = pt.sample(x0, steps=20, swap_every=5, record_every=1, record_all=True, init_mode=\"mutate\")\n",
        "\n",
        "        # flatten the per-replica traces into one (replicas * records, 4, L) batch\n",
        "        tensor_out_xs = torch.stack([torch.stack(x) for x in out[\"xs\"]]).reshape(-1,4,L)\n",
        "\n",
        "        with torch.no_grad():\n",
        "            final_y_val = final_model(tensor_out_xs)\n",
        "        best_seq_idx = int(torch.argmax(final_y_val).item())\n",
        "        best_seq = decode_sequences(tensor_out_xs[best_seq_idx].reshape(1,4,L))[0]\n",
        "        best_seq_y = final_y_val[best_seq_idx].item()\n",
        "\n",
        "        # the helper returns one flag per recorded state; a particle counts once\n",
        "        # if any of them contains the motif, so the counter stays a plain integer\n",
        "        has_motif = check_if_particle_contains_motif(tensor_out_xs.cpu().numpy(), motif=\"CACGTG\")\n",
        "        baseline_motif_count += int(has_motif.any())\n",
        "        energies.append(best_seq_y)\n",
        "        sequences.append(best_seq)\n",
        "\n",
        "    motif_fracs.append(baseline_motif_count / (N))\n",
        "    energies_per_run.append(energies)\n",
        "    sequences_per_run.append(sequences)\n",
        "\n",
        "# persist the per-run summaries as CSV files\n",
        "np.savetxt(f\"motif_fracs_{task_name}.csv\", np.array(motif_fracs), delimiter=\",\")\n",
        "np.savetxt(f\"energies_per_run_{task_name}.csv\", np.array(energies_per_run), delimiter=\",\")\n",
        "np.savetxt(f\"sequences_per_run_{task_name}.csv\", np.array(sequences_per_run), delimiter=\",\", fmt=\"%s\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [
        {
          "ename": "ValueError",
          "evalue": "Expected 1D or 2D array, got 3D array instead",
          "output_type": "error",
          "traceback": [
            "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
            "\u001b[31mValueError\u001b[39m                                Traceback (most recent call last)",
            "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[73]\u001b[39m\u001b[32m, line 262\u001b[39m\n\u001b[32m    259\u001b[39m     sequences_per_run.append(final_x_val)\n\u001b[32m    261\u001b[39m np.savetxt(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mmotif_fracs_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtask_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m.csv\u001b[39m\u001b[33m\"\u001b[39m, np.array(motif_fracs), delimiter=\u001b[33m\"\u001b[39m\u001b[33m,\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m--> \u001b[39m\u001b[32m262\u001b[39m \u001b[43mnp\u001b[49m\u001b[43m.\u001b[49m\u001b[43msavetxt\u001b[49m\u001b[43m(\u001b[49m\u001b[33;43mf\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43menergies_per_run_\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mtask_name\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[33;43m.csv\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnp\u001b[49m\u001b[43m.\u001b[49m\u001b[43marray\u001b[49m\u001b[43m(\u001b[49m\u001b[43menergies_per_run\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdelimiter\u001b[49m\u001b[43m=\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43m,\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[32m    263\u001b[39m np.savetxt(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33msequences_per_run_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtask_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m.csv\u001b[39m\u001b[33m\"\u001b[39m, np.array(sequences_per_run), delimiter=\u001b[33m\"\u001b[39m\u001b[33m,\u001b[39m\u001b[33m\"\u001b[39m, fmt=\u001b[33m\"\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[33m\"\u001b[39m)\n",
            "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/torch-gpu-env/lib/python3.13/site-packages/numpy/lib/_npyio_impl.py:1579\u001b[39m, in \u001b[36msavetxt\u001b[39m\u001b[34m(fname, X, fmt, delimiter, newline, header, footer, comments, encoding)\u001b[39m\n\u001b[32m   1577\u001b[39m \u001b[38;5;66;03m# Handle 1-dimensional arrays\u001b[39;00m\n\u001b[32m   1578\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m X.ndim == \u001b[32m0\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m X.ndim > \u001b[32m2\u001b[39m:\n\u001b[32m-> \u001b[39m\u001b[32m1579\u001b[39m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[32m   1580\u001b[39m         \u001b[33m\"\u001b[39m\u001b[33mExpected 1D or 2D array, got \u001b[39m\u001b[38;5;132;01m%d\u001b[39;00m\u001b[33mD array instead\u001b[39m\u001b[33m\"\u001b[39m % X.ndim)\n\u001b[32m   1581\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m X.ndim == \u001b[32m1\u001b[39m:\n\u001b[32m   1582\u001b[39m     \u001b[38;5;66;03m# Common case -- 1d array of numbers\u001b[39;00m\n\u001b[32m   1583\u001b[39m     \u001b[38;5;28;01mif\u001b[39;00m X.dtype.names \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
            "\u001b[31mValueError\u001b[39m: Expected 1D or 2D array, got 3D array instead"
          ]
        }
      ],
      "source": [
        "# ais_autotemp_gwg_constrained.py\n",
        "# End-to-end AIS with Adaptive Tempering (no resampling) using the *constrained* GWG kernel.\n",
        "\n",
        "import math\n",
        "from typing import Callable, Optional, Dict, Any, List, Tuple\n",
        "\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "\n",
        "# Assumes GWGSamplerApproxConstrained is available in scope:\n",
        "# from <your_module> import GWGSamplerApproxConstrained\n",
        "\n",
        "\n",
        "def _trim_init_to_cap_batch(X: torch.Tensor, X_ref: torch.Tensor, max_hamming: int) -> torch.Tensor:\n",
        "    \"\"\"\n",
        "    Per-particle version of the Hamming trim for (N, 4, L) batches.\n",
        "\n",
        "    Each particle X[i] is compared with X_ref[i] column by column (argmax over\n",
        "    the channel axis). Where more than max_hamming columns differ, the\n",
        "    lowest-indexed surplus columns are reset to the reference base (-1.0\n",
        "    everywhere, +1.0 at the reference row). Returns a new tensor, or X itself\n",
        "    when max_hamming is None.\n",
        "    \"\"\"\n",
        "    if max_hamming is None:\n",
        "        return X\n",
        "    trimmed = X.clone()\n",
        "    n_particles, _, _ = trimmed.shape\n",
        "    curr_rows = torch.argmax(trimmed, dim=1)   # (N, L)\n",
        "    ref_rows = torch.argmax(X_ref, dim=1)      # (N, L)\n",
        "    for i in range(n_particles):\n",
        "        differing = torch.nonzero(curr_rows[i] != ref_rows[i], as_tuple=False).flatten()\n",
        "        surplus = int(differing.numel()) - int(max_hamming)\n",
        "        if surplus <= 0:\n",
        "            continue\n",
        "        cols = differing[:surplus]\n",
        "        trimmed[i, :, cols] = -1.0\n",
        "        trimmed[i, ref_rows[i, cols], cols] = 1.0\n",
        "    return trimmed\n",
        "\n",
        "\n",
        "class AISAutoTempGWG:\n",
        "    \"\"\"\n",
        "    AIS with ESS-targeted adaptive temperatures, using *constrained* GWG for rejuvenation.\n",
        "    No resampling at any stage. Preserves one-to-one lineages with starting particles.\n",
        "\n",
        "    Target: pi_beta(x) ∝ exp(-beta * E(x)), default E(x) = -model(x).\n",
        "\n",
        "    Constraint: For each lineage i, we enforce Hamming(x_i, x_ref_i) <= max_hamming\n",
        "                at all times, where x_ref_i is the *initial* sequence x0[i].\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(\n",
        "        self,\n",
        "        model: nn.Module,\n",
        "        energy_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,  # (N,4,L)->(N,)\n",
        "        ess_target_frac: float = 0.6,\n",
        "        K: int = 5,                              # GWG steps per stage\n",
        "        max_stages: Optional[int] = None,        # cap on # of stages (compute parity)\n",
        "        beta_final: float = 1.0,\n",
        "        line_search_tol: float = 1e-4,\n",
        "        dtype=torch.float32,\n",
        "        max_hamming: int = 10,                   # per-lineage Hamming radius\n",
        "        trim_init_to_cap: bool = True,           # clip x0 to the cap if needed\n",
        "    ):\n",
        "        \"\"\"\n",
        "        Store the configuration; nothing is computed here.\n",
        "\n",
        "        beta_final, max_hamming and trim_init_to_cap are coerced to float, int\n",
        "        and bool respectively; every other argument is kept as passed.\n",
        "        \"\"\"\n",
        "        # scoring\n",
        "        self.model, self.energy_fn = model, energy_fn\n",
        "        # tempering schedule\n",
        "        self.ess_target_frac = ess_target_frac\n",
        "        self.K, self.max_stages = K, max_stages\n",
        "        self.beta_final = float(beta_final)\n",
        "        self.line_search_tol = line_search_tol\n",
        "        self.dtype = dtype\n",
        "        # Hamming constraint\n",
        "        self.max_hamming = int(max_hamming)\n",
        "        self.trim_init_to_cap = bool(trim_init_to_cap)\n",
        "\n",
        "    # ---- batch-safe scoring/energy helpers ----\n",
        "    def _score_batch_default(self, X: torch.Tensor) -> torch.Tensor:\n",
        "        \"\"\"\n",
        "        Model score for every row of X as a 1-D tensor on X's device.\n",
        "\n",
        "        The whole batch is tried first; if that call raises, each sample is\n",
        "        scored on its own (with a batch axis of 1) and the scalars are stacked.\n",
        "        \"\"\"\n",
        "        try:\n",
        "            return self.model(X).view(-1).to(X.device)\n",
        "        except Exception:\n",
        "            per_sample = [\n",
        "                self.model(X[i].unsqueeze(0)).view(()).to(X.device)\n",
        "                for i in range(X.shape[0])\n",
        "            ]\n",
        "            return torch.stack(per_sample)\n",
        "\n",
        "    def _energies(self, X: torch.Tensor) -> torch.Tensor:\n",
        "        \"\"\"Energy of every row of X under no_grad: energy_fn(X) if one was given, else minus the model score.\"\"\"\n",
        "        with torch.no_grad():\n",
        "            if self.energy_fn is None:\n",
        "                return -self._score_batch_default(X)\n",
        "            return self.energy_fn(X)\n",
        "\n",
        "    def _scores(self, X: torch.Tensor) -> torch.Tensor:\n",
        "        \"\"\"\n",
        "        Score y(X) for every row of X, computed without autograd tracking.\n",
        "\n",
        "        Returns -energy_fn(X) when an energy function was supplied, otherwise the\n",
        "        model's batched score. Both branches run under no_grad, matching\n",
        "        _energies, so the returned tensor never carries an autograd graph.\n",
        "        \"\"\"\n",
        "        with torch.no_grad():\n",
        "            if self.energy_fn is not None:\n",
        "                return -self.energy_fn(X)\n",
        "            return self._score_batch_default(X)\n",
        "\n",
        "    # ---- AIS math ----\n",
        "    def _predicted_logw(self, logw: torch.Tensor, E: torch.Tensor, beta_old: float, beta_new: float) -> torch.Tensor:\n",
        "        \"\"\"Log-weights after moving from beta_old to beta_new: logw minus the beta increment times E.\"\"\"\n",
        "        beta_step = beta_new - beta_old\n",
        "        return logw - beta_step * E\n",
        "\n",
        "    def _ess_from_logw(self, logw: torch.Tensor) -> float:\n",
        "        \"\"\"Effective sample size (sum w)^2 / sum w^2 of the weights exp(logw), as a float.\"\"\"\n",
        "        # shifting by the maximum leaves the ratio unchanged and keeps exp() from overflowing\n",
        "        weights = torch.exp(logw - torch.max(logw))\n",
        "        total = torch.sum(weights)\n",
        "        total_sq = torch.sum(weights * weights)\n",
        "        return float((total * total / total_sq).item())\n",
        "\n",
        "    def _choose_next_beta(self, logw: torch.Tensor, E: torch.Tensor, beta: float, N: int) -> float:\n",
        "        \"\"\"\n",
        "        Pick the next inverse temperature so the predicted ESS stays at the target.\n",
        "\n",
        "        The target is ess_target_frac * N. If jumping straight to beta_final keeps\n",
        "        the predicted ESS at or above it, beta_final is returned. Otherwise the\n",
        "        interval [beta, beta_final] is bisected (at most 32 halvings, stopping\n",
        "        once it is narrower than line_search_tol) and the lower end is returned,\n",
        "        but never less than beta + 1e-8.\n",
        "        \"\"\"\n",
        "        ess_goal = self.ess_target_frac * N\n",
        "\n",
        "        def ess_at(candidate: float) -> float:\n",
        "            return self._ess_from_logw(self._predicted_logw(logw, E, beta, candidate))\n",
        "\n",
        "        if ess_at(self.beta_final) >= ess_goal:\n",
        "            return self.beta_final\n",
        "\n",
        "        # invariant: predicted ESS at `lo` meets the goal, at `hi` it does not\n",
        "        lo, hi = beta, self.beta_final\n",
        "        for _ in range(32):\n",
        "            mid = 0.5 * (lo + hi)\n",
        "            if ess_at(mid) >= ess_goal:\n",
        "                lo = mid\n",
        "            else:\n",
        "                hi = mid\n",
        "            if (hi - lo) < self.line_search_tol:\n",
        "                break\n",
        "        return max(lo, beta + 1e-8)\n",
        "\n",
        "    # ---- main run ----\n",
        "    def run(\n",
        "        self,\n",
        "        x0: torch.Tensor,                                         # (N,4,L)\n",
        "        motif_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,  # returns Bool[N] (numpy array or torch tensor)\n",
        "        seed: Optional[int] = None,\n",
        "    ) -> Dict[str, Any]:\n",
        "        \"\"\"\n",
        "        Run adaptive-temperature AIS with constrained GWG rejuvenation.\n",
        "\n",
        "        Starting from beta=0, each stage picks the next beta with\n",
        "        `_choose_next_beta`, updates the log-weights at the current states, then\n",
        "        applies `self.K` sweeps of one `GWGSamplerApproxConstrained.step` per\n",
        "        particle at the new beta. The loop ends when beta reaches\n",
        "        `self.beta_final`, or after `self.max_stages` stages, in which case the\n",
        "        remaining gap to `beta_final` is closed with a single weight update and\n",
        "        no further rejuvenation.\n",
        "\n",
        "        Args:\n",
        "            x0: initial particles; its shape is unpacked as three dimensions\n",
        "                (N, 4, L). A detached copy serves as the frozen per-particle\n",
        "                reference for the Hamming cap.\n",
        "            motif_fn: optional callable mapping the particle batch to a length-N\n",
        "                boolean mask. The mask may be a numpy array or a torch tensor on\n",
        "                any device; it is moved to `x0.device`.\n",
        "            seed: if given, passed to `torch.manual_seed` (seeds the global RNG).\n",
        "\n",
        "        Returns:\n",
        "            Dict with keys \"xs\", \"logw\", \"betas\", \"ess_history\", \"E\", \"best_y\",\n",
        "            \"hit_ever\" and \"total_gwg_steps\".\n",
        "        \"\"\"\n",
        "        if seed is not None:\n",
        "            torch.manual_seed(seed)\n",
        "\n",
        "        device = x0.device\n",
        "        N, _, _ = x0.shape  # unpacking also rejects inputs that are not 3-D\n",
        "        self.model.eval()\n",
        "\n",
        "        def motif_hits(states: torch.Tensor) -> torch.Tensor:\n",
        "            # Device-agnostic conversion: as_tensor accepts numpy arrays and tensors alike.\n",
        "            with torch.no_grad():\n",
        "                return torch.as_tensor(motif_fn(states), dtype=torch.bool, device=device)\n",
        "\n",
        "        # Per-lineage references are *frozen* to the initial x0\n",
        "        X_ref = x0.detach().clone()\n",
        "\n",
        "        # Optionally trim the actual initial X to satisfy the cap\n",
        "        X = x0.detach().clone().to(device, self.dtype)\n",
        "        if self.trim_init_to_cap and (self.max_hamming is not None):\n",
        "            X = _trim_init_to_cap_batch(X, X_ref, self.max_hamming)\n",
        "\n",
        "        logw = torch.zeros(N, device=device, dtype=self.dtype)\n",
        "        beta = 0.0\n",
        "\n",
        "        # init energies and per-lineage stats\n",
        "        E = self._energies(X).to(device, self.dtype)\n",
        "        with torch.no_grad():\n",
        "            # bookkeeping only; keep best_y free of any autograd graph\n",
        "            best_y = self._scores(X)                 # (N,)\n",
        "        hit_ever = torch.zeros(N, dtype=torch.bool, device=device)\n",
        "        if motif_fn is not None:\n",
        "            hit_ever |= motif_hits(X)\n",
        "\n",
        "        # Build *per-particle* constrained kernels (all start at beta=0)\n",
        "        kernels: List[GWGSamplerApproxConstrained] = [\n",
        "            GWGSamplerApproxConstrained(model=self.model, x_ref=X_ref[i], max_hamming=self.max_hamming, beta=beta)\n",
        "            for i in range(N)\n",
        "        ]\n",
        "\n",
        "        betas: List[float] = [beta]\n",
        "        ess_hist: List[float] = [self._ess_from_logw(logw)]\n",
        "        stage = 0\n",
        "        total_gwg_steps = 0\n",
        "\n",
        "        while beta < self.beta_final:\n",
        "            # 1) choose next beta' by ESS target\n",
        "            beta_next = self._choose_next_beta(logw, E, beta, N)\n",
        "\n",
        "            # 2) AIS weight update at old states\n",
        "            logw = self._predicted_logw(logw, E, beta, beta_next)\n",
        "\n",
        "            # 3) rejuvenation at beta' with K constrained GWG sweeps\n",
        "            for kernel in kernels:\n",
        "                kernel.beta = float(beta_next)\n",
        "\n",
        "            for _ in range(self.K):\n",
        "                # step() is called outside no_grad on purpose: the sampler handles its own gradients\n",
        "                for i in range(N):\n",
        "                    X[i] = kernels[i].step(X[i])\n",
        "                total_gwg_steps += N\n",
        "\n",
        "                # per-sweep bookkeeping\n",
        "                if motif_fn is not None:\n",
        "                    hit_ever |= motif_hits(X)\n",
        "                with torch.no_grad():\n",
        "                    y_now = self._scores(X)\n",
        "                best_y = torch.maximum(best_y, y_now)\n",
        "\n",
        "            # 4) refresh energies for the next stage's weight update\n",
        "            E = self._energies(X).to(device, self.dtype)\n",
        "\n",
        "            beta = beta_next\n",
        "            betas.append(beta)\n",
        "            ess_hist.append(self._ess_from_logw(logw))\n",
        "            stage += 1\n",
        "\n",
        "            if self.max_stages is not None and stage >= self.max_stages:\n",
        "                # Stage budget exhausted: close the remaining gap with one weight update.\n",
        "                if beta < self.beta_final:\n",
        "                    logw = self._predicted_logw(logw, E, beta, self.beta_final)\n",
        "                    beta = self.beta_final\n",
        "                    betas.append(beta)\n",
        "                    ess_hist.append(self._ess_from_logw(logw))\n",
        "                break\n",
        "\n",
        "        # final weighted motif probability at beta=beta_final (pre-resample)\n",
        "        out: Dict[str, Any] = {\n",
        "            \"xs\": X,                     # (N,4,L) final states\n",
        "            \"logw\": logw,                # (N,) final log-weights\n",
        "            \"betas\": betas,              # list of betas visited\n",
        "            \"ess_history\": ess_hist,     # ESS after each stage\n",
        "            \"E\": E,                      # final energies\n",
        "            \"best_y\": best_y,            # (N,) best y along lineage\n",
        "            \"hit_ever\": hit_ever,        # (N,) ever-hit flag (torch.bool)\n",
        "            \"total_gwg_steps\": total_gwg_steps,\n",
        "        }\n",
        "        return out\n",
        "\n",
        "task_name = \"AISAutoTemp-GWG_constrained\"\n",
        "\n",
        "motif_fracs = []\n",
        "energies_per_run = []\n",
        "sequences_per_run = []\n",
        "\n",
        "# The scoring model is the same for every run, so load it once instead of per run.\n",
        "final_model = load_final_model()\n",
        "\n",
        "for run in range(run_count):\n",
        "    x0 = random_bipolar_batch(N, L, device=device)\n",
        "\n",
        "    # Compute parity example: ~3 stages * 20 GWG = 60 steps\n",
        "    ais = AISAutoTempGWG(\n",
        "        model=model,\n",
        "        energy_fn=None,           # default E = -model(x)\n",
        "        ess_target_frac=0.6,\n",
        "        K=20,\n",
        "        max_stages=3,             # keep total at about 60 steps\n",
        "        beta_final=10.0,\n",
        "        max_hamming=max_hamming_distance,           # <<< enforce cap per particle vs its own x0[i]\n",
        "        trim_init_to_cap=True,    # <<< ensure starting states satisfy the cap\n",
        "    )\n",
        "\n",
        "    out = ais.run(x0, motif_fn=check_if_particle_contains_motif)\n",
        "\n",
        "    # Scoring needs no gradients; no_grad + detach keeps the numpy conversion\n",
        "    # below from failing on a tensor that requires grad.\n",
        "    with torch.no_grad():\n",
        "        final_y_val = final_model(out[\"xs\"])\n",
        "    final_x_val = decode_sequences(out[\"xs\"].reshape(-1, 4, L))\n",
        "\n",
        "    # fraction of lineages that contained the motif at any recorded point\n",
        "    motif_frac = (out[\"hit_ever\"].sum().item()) / N\n",
        "\n",
        "    motif_fracs.append(motif_frac)\n",
        "    energies_per_run.append(final_y_val.detach().cpu().numpy().tolist())\n",
        "    sequences_per_run.append(final_x_val)\n",
        "\n",
        "np.savetxt(f\"motif_fracs_{task_name}.csv\", np.array(motif_fracs), delimiter=\",\")\n",
        "np.savetxt(f\"energies_per_run_{task_name}.csv\", np.array(energies_per_run).squeeze(), delimiter=\",\")\n",
        "np.savetxt(f\"sequences_per_run_{task_name}.csv\", np.array(sequences_per_run), delimiter=\",\", fmt=\"%s\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {},
      "outputs": [],
      "source": [
        "import os, re, glob, torch\n",
        "\n",
        "def resolve_checkpoint_path(ckpt_dir: str, epoch: int) -> str:\n",
        "    \"\"\"\n",
        "    Return a checkpoint path for the given epoch.\n",
        "    Tries common names, then falls back to a glob search.\n",
        "    \"\"\"\n",
        "    cand = [\n",
        "        os.path.join(ckpt_dir, f\"epoch_{epoch:03d}.pt\"),\n",
        "        os.path.join(ckpt_dir, f\"epoch_{epoch}.pt\"),\n",
        "        os.path.join(ckpt_dir, f\"ckpt_epoch_{epoch:03d}.pt\"),\n",
        "        os.path.join(ckpt_dir, f\"ckpt_epoch_{epoch}.pt\"),\n",
        "    ]\n",
        "    for p in cand:\n",
        "        if os.path.exists(p):\n",
        "            return p\n",
        "\n",
        "    # Fallback: any *.pt that clearly references this epoch number\n",
        "    pat = re.compile(rf\".*[^0-9]{epoch}[^0-9]*\\.pt$\")\n",
        "    for p in sorted(glob.glob(os.path.join(ckpt_dir, \"*.pt\"))):\n",
        "        if pat.match(os.path.basename(p)):\n",
        "            return p\n",
        "\n",
        "    raise FileNotFoundError(f\"No checkpoint found for epoch {epoch} in {ckpt_dir}\")\n",
        "\n",
        "def load_weights_into(model: torch.nn.Module, ckpt_path: str, device: str = \"cuda\", strict: bool = False):\n",
        "    \"\"\"\n",
        "    Load weights from a checkpoint into `model`.\n",
        "    Supports either:\n",
        "      - raw state_dict\n",
        "      - dict with key 'model_state_dict'\n",
        "    Also strips an optional leading 'module.' (from DataParallel).\n",
        "    \"\"\"\n",
        "    sd = torch.load(ckpt_path, map_location=device)\n",
        "\n",
        "    # unwrap container dicts\n",
        "    if isinstance(sd, dict) and \"model_state_dict\" in sd:\n",
        "        sd = sd[\"model_state_dict\"]\n",
        "\n",
        "    # handle DataParallel prefixes if present\n",
        "    if isinstance(sd, dict) and any(k.startswith(\"module.\") for k in sd.keys()):\n",
        "        sd = {k[len(\"module.\"):]: v for k, v in sd.items()}\n",
        "\n",
        "    missing, unexpected = model.load_state_dict(sd, strict=strict)\n",
        "    if missing or unexpected:\n",
        "        print(f\"[load_weights_into] missing keys: {missing} | unexpected keys: {unexpected}\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "import numpy as np\n",
        "\n",
        "@torch.no_grad()\n",
        "def model_y_batch(model: torch.nn.Module, X_4L: torch.Tensor, chunk: int = 0) -> torch.Tensor:\n",
        "    \"\"\"\n",
        "    X_4L: (N, 4, L) -> (N,) predicted y values.\n",
        "    Set chunk>0 to evaluate in mini-batches (useful if N is large / low VRAM).\n",
        "    \"\"\"\n",
        "    if X_4L.dim() == 2:  # (4, L) -> (1,)\n",
        "        return model(X_4L.unsqueeze(0)).view(-1)\n",
        "\n",
        "    if chunk and X_4L.size(0) > chunk:\n",
        "        outs = []\n",
        "        for i in range(0, X_4L.size(0), chunk):\n",
        "            outs.append(model(X_4L[i:i+chunk]).view(-1))\n",
        "        return torch.cat(outs, dim=0)\n",
        "    else:\n",
        "        return model(X_4L).view(-1)\n",
        "\n",
        "@torch.no_grad()\n",
        "def energy_batch(model: torch.nn.Module, X_4L: torch.Tensor, chunk: int = 0) -> torch.Tensor:\n",
        "    \"\"\"\n",
        "    Return the energy E(x) = -y(x), i.e. the negated output of `model_y_batch`\n",
        "    called with the same `model`, `X_4L` and `chunk`.\n",
        "    \"\"\"\n",
        "    predicted_y = model_y_batch(model, X_4L, chunk=chunk)\n",
        "    return -predicted_y\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "import numpy as np\n",
        "\n",
        "# Requires GWGSamplerApproxConstrained to be defined in an earlier cell of this notebook.\n",
        "\n",
        "def smc_via_checkpoints(\n",
        "    checkpoint_dir: str,\n",
        "    epochs: list[int],\n",
        "    model_factory,                 # zero-argument callable returning a fresh model, e.g. lambda: MaxCNN(L=60, k=11, n_filters=64)\n",
        "    L: int,\n",
        "    num_particles: int = 1,        # number of independent chains\n",
        "    mcmc_steps: int = 15,\n",
        "    resample_thresh: float = 0.5,  # not used in this function; kept for API compatibility\n",
        "    device: str = \"cuda\",\n",
        "    beta: float = 1.0,\n",
        "    last_epoch: int | None = None, # not used in this function; kept for API compatibility\n",
        "    max_hamming: int = 10,         # Hamming cap passed to the constrained sampler\n",
        "):\n",
        "    \"\"\"\n",
        "    Walk a sequence of training checkpoints, rejuvenating every particle with\n",
        "    the constrained GWG sampler at each one.\n",
        "\n",
        "    - Checkpoints are visited in ascending epoch order; all paths are resolved\n",
        "      up front, before any particle is initialised.\n",
        "    - For each checkpoint a fresh model is built with `model_factory`, loaded\n",
        "      with that checkpoint's weights, and a new\n",
        "      GWGSamplerApproxConstrained(model, x_ref, max_hamming, beta) is created\n",
        "      per particle.\n",
        "    - x_ref is a copy of the particle's *initial* sequence, so the same\n",
        "      reference is handed to the sampler at every checkpoint.\n",
        "    - If CUDA is unavailable and `device` is not \"cpu\", the CPU is used.\n",
        "\n",
        "    Returns:\n",
        "        all_seqs: list with every state recorded by `sampler.sample`\n",
        "        (record_every=1), concatenated over checkpoints and particles.\n",
        "    \"\"\"\n",
        "    # fall back to the CPU when the requested device cannot be used\n",
        "    can_use_requested = torch.cuda.is_available() or device == \"cpu\"\n",
        "    device = torch.device(device if can_use_requested else \"cpu\")\n",
        "\n",
        "    ckpt_paths = [resolve_checkpoint_path(checkpoint_dir, ep) for ep in sorted(list(epochs))]\n",
        "\n",
        "    # particles plus one frozen reference copy per particle\n",
        "    particles = random_bipolar_batch(num_particles, L, device=device)  # (N, 4, L)\n",
        "    frozen_refs = [particles[p].detach().clone() for p in range(num_particles)]\n",
        "\n",
        "    all_seqs = []\n",
        "\n",
        "    for ckpt_path in ckpt_paths:\n",
        "        # fresh model carrying this checkpoint's weights\n",
        "        stage_model = model_factory().to(device).eval()\n",
        "        load_weights_into(stage_model, ckpt_path, device=str(device))\n",
        "\n",
        "        # one independent chain per particle\n",
        "        for p in range(num_particles):\n",
        "            sampler = GWGSamplerApproxConstrained(\n",
        "                model=stage_model,\n",
        "                x_ref=frozen_refs[p],\n",
        "                max_hamming=max_hamming,\n",
        "                beta=beta,\n",
        "            )\n",
        "\n",
        "            # record_every=1: every step of the chain is recorded\n",
        "            recorded, _ys, _accs = sampler.sample(x0=particles[p], steps=mcmc_steps, record_every=1)\n",
        "            all_seqs.extend(recorded)\n",
        "\n",
        "            # carry the last recorded state forward; with nothing recorded the particle stays put\n",
        "            if len(recorded) > 0:\n",
        "                particles[p] = recorded[-1]\n",
        "\n",
        "    return all_seqs\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "task_name = \"SMC-GWG_constrained\"\n",
        "\n",
        "# Loop-invariant settings, defined once instead of on every inner iteration.\n",
        "L = 60\n",
        "BASES = np.array(list(\"ACGT\"))\n",
        "MOTIF = \"CACGTG\"  # palindromic, so no RC needed\n",
        "\n",
        "# The scoring model is the same for every chain, so load it once.\n",
        "final_model = load_final_model()\n",
        "\n",
        "motif_fracs = []\n",
        "energies_per_run = []\n",
        "sequences_per_run = []\n",
        "\n",
        "for run in range(run_count):\n",
        "    motif_count = 0\n",
        "    energies = []\n",
        "    sequences = []\n",
        "    for k in range(N):\n",
        "        # One independent single-particle chain; the result is its full state trace.\n",
        "        trace = smc_via_checkpoints(\n",
        "            checkpoint_dir=\"ckpts\",\n",
        "            epochs=[25, 50, 300],                         # stages you want to traverse\n",
        "            model_factory=lambda: MaxCNN(L=L, k=11, n_filters=64),\n",
        "            L=L,\n",
        "            num_particles=1,\n",
        "            mcmc_steps=20,\n",
        "            resample_thresh=0.5,\n",
        "            device=\"cuda\",\n",
        "            beta=10.0,\n",
        "            last_epoch=300,                               # accepted for API compatibility; unused by smc_via_checkpoints\n",
        "            max_hamming=7,                                # Hamming cap\n",
        "        )\n",
        "        particles_smc = torch.stack(trace).reshape(-1, 4, L)\n",
        "\n",
        "        # Score every recorded state with the final model; no gradients needed.\n",
        "        with torch.no_grad():\n",
        "            final_y_val = final_model(particles_smc)\n",
        "        best_idx = final_y_val.argmax(axis=0)\n",
        "        best_seq = decode_sequences(particles_smc[best_idx].reshape(1, 4, L))[0]\n",
        "        best_seq_y = final_y_val[best_idx].item()\n",
        "\n",
        "        energies.append(best_seq_y)\n",
        "        sequences.append(best_seq)\n",
        "\n",
        "        # Decode all recorded states at once: best channel per position -> base letters\n",
        "        base_idx = particles_smc.argmax(axis=1).cpu().numpy()\n",
        "        seqs = [''.join(BASES[row]) for row in base_idx]\n",
        "\n",
        "        # The chain counts as a hit if any recorded state contains the motif.\n",
        "        if any(MOTIF in s for s in seqs):\n",
        "            motif_count += 1\n",
        "\n",
        "    motif_fracs.append(motif_count / N)\n",
        "    energies_per_run.append(energies)\n",
        "    sequences_per_run.append(sequences)\n",
        "\n",
        "\n",
        "np.savetxt(f\"motif_fracs_{task_name}.csv\", np.array(motif_fracs), delimiter=\",\")\n",
        "np.savetxt(f\"energies_per_run_{task_name}.csv\", np.array(energies_per_run), delimiter=\",\")\n",
        "np.savetxt(f\"sequences_per_run_{task_name}.csv\", np.array(sequences_per_run), delimiter=\",\", fmt=\"%s\")"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "torch-gpu-env",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.13.5"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
