{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "machine_shape": "hm",
      "gpuType": "L4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "# ---------- 0) Colab/L4 env: auto-install PyG wheels ----------\n",
        "# Best-effort setup: failures are reported but never stop the notebook.\n",
        "# torch-geometric is installed on its own so that a missing prebuilt wheel for one of\n",
        "# the optional compiled extensions cannot also prevent the core package from installing.\n",
        "import os, re, sys, subprocess\n",
        "\n",
        "def _pip_install(*args):\n",
        "    \"\"\"Run `pip install -q <args>` with the kernel's interpreter; raise CalledProcessError on failure.\"\"\"\n",
        "    subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *args], check=True)\n",
        "\n",
        "try:\n",
        "    import torch\n",
        "    torch_ver = torch.__version__.split('+')[0]\n",
        "    cuda_raw  = getattr(torch.version, \"cuda\", None)\n",
        "    if cuda_raw is None:\n",
        "        cuda_tag = \"cpu\"\n",
        "    else:\n",
        "        # e.g. \"12.1\" -> \"cu121\"; fall back to cu121 if the version string is unexpected.\n",
        "        m = re.match(r\"(\\d+)\\.(\\d+)\", str(cuda_raw))\n",
        "        cuda_tag = f\"cu{m.group(1)}{m.group(2)}\" if m else \"cu121\"\n",
        "    wheels_url = f\"https://data.pyg.org/whl/torch-{torch_ver}+{cuda_tag}.html\"\n",
        "    print(f\"[Setup] torch={torch.__version__}, CUDA={cuda_raw}, PyG index={wheels_url}\")\n",
        "    # Required: the core library (every torch_geometric import below comes from it).\n",
        "    _pip_install(\"torch-geometric\")\n",
        "    # Optional compiled extensions: accept prebuilt wheels only, so pip never falls back\n",
        "    # to a long source build when the index has no wheel for this torch/CUDA combination.\n",
        "    try:\n",
        "        _pip_install(\"--only-binary=:all:\",\n",
        "                     \"torch-scatter\", \"torch-sparse\", \"torch-cluster\", \"torch-spline-conv\",\n",
        "                     \"-f\", wheels_url)\n",
        "    except subprocess.CalledProcessError as e:\n",
        "        print(\"[Setup warning] optional PyG extensions not installed:\", e)\n",
        "    _pip_install(\"networkx\", \"pandas\", \"matplotlib\")\n",
        "except Exception as e:\n",
        "    print(\"[Setup warning] PyG auto-install skipped or partially failed:\", e)\n",
        "\n",
        "# ---------- 1) Imports ----------\n",
        "import math, copy, random, warnings, json\n",
        "from pathlib import Path\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import matplotlib as mpl\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "import torch\n",
        "import torch.nn.functional as F\n",
        "from torch import nn\n",
        "from torch.nn.utils import spectral_norm, clip_grad_norm_\n",
        "\n",
        "from torch_geometric.nn import (\n",
        "    GINConv, GCNConv, SAGEConv, GATConv, PNAConv, GCN2Conv, APPNP, SGConv, TransformerConv,\n",
        "    global_add_pool, global_mean_pool\n",
        ")\n",
        "from torch_geometric.datasets import TUDataset\n",
        "from torch_geometric.loader import DataLoader\n",
        "from torch_geometric.data import Data, Batch\n",
        "from torch_geometric.utils import to_networkx, from_networkx, degree\n",
        "\n",
        "import networkx as nx\n",
        "\n",
        "# Blanket-suppress Python warnings for the rest of the session (this also hides useful ones).\n",
        "warnings.filterwarnings(\"ignore\")\n",
        "# Allow lower-precision internal float32 matmuls (speed/accuracy trade-off); older torch lacks this API.\n",
        "if getattr(torch, \"set_float32_matmul_precision\", None) is not None:\n",
        "    torch.set_float32_matmul_precision(\"medium\")\n",
        "\n",
        "# ---------- 2) Config ----------\n",
        "class Config:\n",
        "    \"\"\"All experiment settings as class attributes; read through the global instance `cfg`.\"\"\"\n",
        "\n",
        "    # --- experiment grid ---\n",
        "    DATASETS = ['PROTEINS', 'NCI1', 'COLLAB']     # 'MUTAG', 'IMDB-BINARY' can be added too (faster)\n",
        "    BACKBONES = ['GIN', 'GCN', 'GraphSAGE']       # more available: 'GAT', 'PNA', 'GCNII', 'APPNP', 'SGC', 'Transformer'\n",
        "    SEEDS = [41, 42, 43]                          # results are aggregated over several seeds\n",
        "\n",
        "    # --- backbone architecture ---\n",
        "    NUM_LAYERS, HIDDEN, DROPOUT = 3, 64, 0.2\n",
        "    HEADS = 4                                     # attention heads for GAT / Transformer\n",
        "    PNA_AGGRS = ['sum', 'mean', 'max']\n",
        "    PNA_SCALERS = ['identity', 'amplification', 'attenuation']\n",
        "    PNA_TOWERS, PNA_PRE_LAYERS, PNA_POST_LAYERS = 1, 1, 1\n",
        "    SGC_K = 3\n",
        "    APPNP_K, APPNP_ALPHA = 10, 0.1\n",
        "    GCNII_ALPHA, GCNII_THETA = 0.1, None\n",
        "\n",
        "    # --- optimisation ---\n",
        "    LR, EPOCHS, BATCH = 0.01, 80, 128\n",
        "    WD = 0.0\n",
        "    NUM_WORKERS = 0\n",
        "    MAX_GRAD_NORM = 1.0                           # gradient-clipping threshold\n",
        "\n",
        "    # --- watermark embedding ---\n",
        "    WM_BITS = 128\n",
        "    BETA_GRID = [0, 5e-7, 1e-6, 5e-6, 1e-5, 5e-5, 1e-4, 5e-4, 1e-3, 5e-3, 1e-2]\n",
        "    OWNER_ALPHA = 1e-6\n",
        "\n",
        "    # --- carrier graphs ---\n",
        "    CARRIER_SWAPS = [2, 4, 6, 8, 12, 16]\n",
        "    MAX_CARRIER_N_PERCENTILE = 25\n",
        "    BAND_LOW, BAND_HIGH = 0.35, 0.65\n",
        "    MAX_BAND_RELAX_STEPS = 4\n",
        "    SOFT_LABEL_EPS = 0.1\n",
        "    NUM_IMPOSTOR_KEYS = 200\n",
        "\n",
        "    # --- robustness experiments (toggles + settings) ---\n",
        "    RUN_PRUNE = True\n",
        "    PRUNE_FRACS = [0.2, 0.4, 0.5]\n",
        "\n",
        "    RUN_FINETUNE, FT_EPOCHS = True, 20\n",
        "\n",
        "    RUN_FINEPRUNE, FPRUNE_FRAC, FPRUNE_FT_EPOCHS = True, 0.4, 15\n",
        "\n",
        "    RUN_DISTILL = True\n",
        "    KD_T, KD_EPOCHS, KD_HIDDEN = 2.0, 20, 32\n",
        "    RUN_KD_DEFENSE, KD_WM_LAMBDA = True, 5e-3\n",
        "\n",
        "    RUN_QUANT, QUANT_BITS = True, [8, 6, 4]\n",
        "\n",
        "    RUN_WNOISE, WNOISE_SIGMAS = True, [0.01, 0.05, 0.1]\n",
        "\n",
        "    RUN_PGD_UNLEARN, PGD_STEPS = True, 30\n",
        "    PGD_EPS_LINF, PGD_STEP_LINF = 1e-2, 1e-3\n",
        "    PGD_EPS_L2, PGD_STEP_L2 = 1e-1, 5e-3\n",
        "\n",
        "    RUN_REINIT_HEAD = True\n",
        "\n",
        "    RUN_CARRIER_PERTURB = True\n",
        "    CARR_DEL_FRACS = [0.0, 0.05, 0.1, 0.2]\n",
        "    CARR_ADD_FRACS = [0.0, 0.05, 0.1, 0.2]\n",
        "    CARR_PERTURB_TRIALS = 6\n",
        "\n",
        "    RUN_CAPACITY_SWEEP = True\n",
        "    M_LIST = [32, 64, 128, 256]\n",
        "    CAP_BETA_GRID = [1e-7, 2e-7, 5e-7, 1e-6, 2e-6, 5e-6, 1e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4, 1e-3]\n",
        "\n",
        "    # --- reporting / output ---\n",
        "    ALLOW_ACC_DROP = 0.01\n",
        "    DPI_FIG, DPI_SAVE = 150, 300\n",
        "    SAVE_DIR = \"wm_iclr_full\"\n",
        "    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "\n",
        "cfg = Config()\n",
        "print(\"Device:\", cfg.DEVICE)\n",
        "\n",
        "def set_seed(s):\n",
        "    \"\"\"Seed Python, NumPy and torch (CPU and every CUDA device) RNGs; force deterministic cuDNN.\"\"\"\n",
        "    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):\n",
        "        seeder(s)\n",
        "    cudnn = torch.backends.cudnn\n",
        "    cudnn.deterministic, cudnn.benchmark = True, False\n",
        "\n",
        "# ---------- 3) Dataset & λ2 ----------\n",
        "def load_dataset(name: str, seed: int = 0):\n",
        "    \"\"\"Load TUDataset `name`, shuffle it and split it 80/20 into train/test.\n",
        "\n",
        "    The shuffle uses a private torch.Generator seeded with `seed`, so the split is\n",
        "    reproducible and independent of the global RNG state. Pass seed=None to shuffle\n",
        "    with the global torch RNG instead (what Dataset.shuffle() does).\n",
        "\n",
        "    Returns (dataset, train_split, test_split, in_dim, out_dim); in_dim is 1 when the\n",
        "    dataset has no node features (ensure_x then supplies degree features).\n",
        "    \"\"\"\n",
        "    ds = TUDataset(root='data/TUDataset', name=name, use_node_attr=True)\n",
        "    if seed is None:\n",
        "        ds = ds.shuffle()\n",
        "    else:\n",
        "        gen = torch.Generator().manual_seed(int(seed))\n",
        "        ds = ds[torch.randperm(len(ds), generator=gen)]\n",
        "    split = int(0.8 * len(ds))\n",
        "    tr, te = ds[:split], ds[split:]\n",
        "    in_dim = ds.num_node_features if ds.num_node_features > 0 else 1\n",
        "    out_dim = ds.num_classes\n",
        "    return ds, tr, te, in_dim, out_dim\n",
        "\n",
        "def pyg_edges_to_adj(edge_index, num_nodes):\n",
        "    \"\"\"Dense symmetric 0/1 adjacency matrix (float32, on CPU) from a (2, E) edge_index.\n",
        "\n",
        "    Every edge is set in both directions and self-loops are dropped. Uses vectorised\n",
        "    index assignment rather than a per-edge Python loop, since this runs for every\n",
        "    graph whose lambda_2 is computed.\n",
        "    \"\"\"\n",
        "    A = torch.zeros((num_nodes, num_nodes), dtype=torch.float32)\n",
        "    src, dst = edge_index.detach().cpu().long()\n",
        "    keep = src != dst  # drop self-loops\n",
        "    src, dst = src[keep], dst[keep]\n",
        "    A[src, dst] = 1.0\n",
        "    A[dst, src] = 1.0\n",
        "    return A\n",
        "\n",
        "def laplacian_lambda2(data: Data):\n",
        "    \"\"\"Algebraic connectivity of `data`: the 2nd-smallest eigenvalue of L = D - A.\n",
        "\n",
        "    A is the symmetric, self-loop-free adjacency built by pyg_edges_to_adj. Graphs\n",
        "    with at most one node give 0. Returns a 0-d float32 tensor.\n",
        "    \"\"\"\n",
        "    num = int(data.num_nodes)\n",
        "    if num <= 1:\n",
        "        return torch.tensor(0.0)\n",
        "    adj = pyg_edges_to_adj(data.edge_index, num)\n",
        "    lap = torch.diag(adj.sum(dim=1)) - adj\n",
        "    spectrum, _ = torch.sort(torch.linalg.eigvalsh(lap.cpu()))\n",
        "    if spectrum.numel() < 2:\n",
        "        return torch.tensor(0.0)\n",
        "    return spectrum[1].float()\n",
        "\n",
        "def estimate_lambda_stats(train_dataset):\n",
        "    \"\"\"Robust range of lambda_2 over the training graphs: its 5th and 95th percentiles.\n",
        "\n",
        "    Falls back to the plain (min, max) when that range is degenerate (95th <= 5th).\n",
        "    Returns (lam_min, lam_scale) as Python floats, the pair normalize_lambda2 takes.\n",
        "    \"\"\"\n",
        "    lam2 = torch.tensor([laplacian_lambda2(graph).item() for graph in train_dataset])\n",
        "    lo, hi = (torch.quantile(lam2, q).item() for q in (0.05, 0.95))\n",
        "    if hi <= lo:\n",
        "        return float(lam2.min().item()), float(lam2.max().item())\n",
        "    return lo, hi\n",
        "\n",
        "def normalize_lambda2(l2, lam_min, lam_scale):\n",
        "    \"\"\"Affinely map l2 so lam_min -> 0 and lam_scale -> 1 (span floored at 1e-8); returns a float.\"\"\"\n",
        "    span = max(1e-8, lam_scale - lam_min)\n",
        "    return float((l2 - lam_min) / span)\n",
        "\n",
        "def degree_histogram_from_dataset(train_dataset, max_degree: int = 512):\n",
        "    \"\"\"Histogram of node degrees (counted over edge_index[0]) across all training graphs.\n",
        "\n",
        "    This is the kind of histogram GNNBackbone's 'PNA' option expects as pna_deg.\n",
        "    Degrees are clipped to max_degree - 1 and trailing empty bins are trimmed\n",
        "    (at least one bin is always kept).\n",
        "    \"\"\"\n",
        "    counts = torch.zeros(max_degree, dtype=torch.long)\n",
        "    for graph in train_dataset:\n",
        "        deg = degree(graph.edge_index[0], num_nodes=graph.num_nodes).long()\n",
        "        counts += torch.bincount(deg.clamp(max=max_degree - 1), minlength=max_degree)\n",
        "    nonzero = (counts > 0).nonzero()\n",
        "    top = int(nonzero.max().item()) if nonzero.numel() > 0 else 0\n",
        "    return counts[:top + 1].contiguous()\n",
        "\n",
        "# ---------- 4) Carrier graphs ----------\n",
        "def _simple_edge_rewire(G: nx.Graph, target_swaps: int, max_attempts: int = 5000):\n",
        "    \"\"\"Fallback degree-preserving rewiring: swap (u, v), (x, y) -> (u, y), (x, v).\n",
        "\n",
        "    Works on a copy of G, draws edges with the global `random` module, and stops after\n",
        "    target_swaps successful swaps or max_attempts draws, whichever comes first.\n",
        "    Swaps touching fewer than 4 distinct nodes or creating an existing edge are skipped.\n",
        "    \"\"\"\n",
        "    H = G.copy()\n",
        "    edge_list = list(H.edges())\n",
        "    done = tries = 0\n",
        "    while done < target_swaps and tries < max_attempts and len(edge_list) >= 2:\n",
        "        u, v = random.choice(edge_list)\n",
        "        x, y = random.choice(edge_list)\n",
        "        tries += 1\n",
        "        if len({u, v, x, y}) != 4:\n",
        "            continue\n",
        "        if u == y or x == v or H.has_edge(u, y) or H.has_edge(x, v):\n",
        "            continue\n",
        "        H.remove_edges_from([(u, v), (x, y)])\n",
        "        H.add_edges_from([(u, y), (x, v)])\n",
        "        edge_list = list(H.edges())\n",
        "        done += 1\n",
        "    return H\n",
        "\n",
        "def _attach_node_features(out: Data, data: Data) -> Data:\n",
        "    \"\"\"Give the rewired graph `out` node features taken from its source graph `data`.\n",
        "\n",
        "    Node order is assumed to match (node ids survive to_networkx/from_networkx). If\n",
        "    `data` has no features, width-1 N(0, 1) features are drawn from the global torch RNG;\n",
        "    if node counts differ, features are truncated / zero-padded. Returns `out`.\n",
        "    \"\"\"\n",
        "    if data.x is None:\n",
        "        out.x = torch.randn((out.num_nodes, 1))\n",
        "    elif data.x.size(0) != out.num_nodes:\n",
        "        new_x = torch.zeros((out.num_nodes, data.x.size(1)))\n",
        "        take = min(out.num_nodes, data.x.size(0))\n",
        "        new_x[:take] = data.x[:take]\n",
        "        out.x = new_x\n",
        "    else:\n",
        "        out.x = data.x\n",
        "    return out\n",
        "\n",
        "def graph_double_edge_swap(data: Data, target_swaps: int, seed: int):\n",
        "    \"\"\"Rewire `data` with degree-preserving double-edge swaps; returns a new Data.\n",
        "\n",
        "    Tries networkx.double_edge_swap (seeded with `seed`) for min(target_swaps, E // 2)\n",
        "    swaps, halving the swap count and doubling max_tries after each failure; if every\n",
        "    attempt fails, falls back to _simple_edge_rewire. Graphs with < 2 edges or < 4 nodes\n",
        "    are returned un-rewired. In every case the result carries node features, so it\n",
        "    can be batched together with the other carriers.\n",
        "    \"\"\"\n",
        "    G = nx.Graph(to_networkx(data, to_undirected=True))\n",
        "    E = G.number_of_edges()\n",
        "    if E < 2 or G.number_of_nodes() < 4:\n",
        "        # Too small to rewire, but still attach `x` so the graph batches with the others.\n",
        "        return _attach_node_features(from_networkx(G), data)\n",
        "    nswap = max(1, min(target_swaps, E // 2))\n",
        "    tries = max(100 * nswap, 1000)\n",
        "    H = G.copy()\n",
        "    while nswap >= 1:\n",
        "        try:\n",
        "            nx.double_edge_swap(H, nswap=nswap, max_tries=tries, seed=seed)\n",
        "            break\n",
        "        except Exception:\n",
        "            # A failed attempt may leave H partially rewired: restart from G with fewer swaps.\n",
        "            nswap //= 2; tries *= 2; H = G.copy()\n",
        "    if nswap < 1:\n",
        "        H = _simple_edge_rewire(G, target_swaps=max(1, min(target_swaps, E // 2)))\n",
        "    return _attach_node_features(from_networkx(H), data)\n",
        "\n",
        "def _carrier_seed_pool(train_dataset, percentile_n):\n",
        "    \"\"\"Candidate source graphs for carriers, as (graph, undirected edge count) pairs.\n",
        "\n",
        "    Keeps training graphs whose node count is at or below the percentile_n-th percentile\n",
        "    and that have >= 2 edges and >= 4 nodes; if none qualify, the whole training set is used.\n",
        "    \"\"\"\n",
        "    def n_edges(g):\n",
        "        return nx.Graph(to_networkx(g, to_undirected=True)).number_of_edges()\n",
        "\n",
        "    ns = torch.tensor([g.num_nodes for g in train_dataset], dtype=torch.int)\n",
        "    cutoff = int(torch.quantile(ns.float(), percentile_n / 100.0).item())\n",
        "    pool = []\n",
        "    for g in train_dataset:\n",
        "        if g.num_nodes > cutoff:\n",
        "            continue\n",
        "        e = n_edges(g)\n",
        "        if e >= 2 and g.num_nodes >= 4:\n",
        "            pool.append((g, e))\n",
        "    if not pool:\n",
        "        pool = [(g, n_edges(g)) for g in train_dataset]\n",
        "    return pool\n",
        "\n",
        "def build_carriers(train_dataset, m_bits, percentile_n, lam_min, lam_scale, seed):\n",
        "    \"\"\"Build m_bits watermark carrier graphs and derive the owner key from their lambda_2.\n",
        "\n",
        "    Each carrier is a randomly chosen small training graph rewired by double-edge swaps.\n",
        "    Carriers whose normalised lambda_2 falls strictly inside (band_low, band_high), the\n",
        "    ambiguous zone around the 0.5 key threshold, are rejected. Every 200 attempts the band\n",
        "    is relaxed, i.e. shrunk by 0.05 per side toward 0.5. After cfg.MAX_BAND_RELAX_STEPS\n",
        "    relaxations every carrier is accepted, so the loop terminates.\n",
        "\n",
        "    Returns (wm_batch on cfg.DEVICE, normalised lambda_2 per carrier, owner_key), where\n",
        "    owner_key[i] = 1 iff carrier i has normalised lambda_2 >= 0.5.\n",
        "    \"\"\"\n",
        "    pool = _carrier_seed_pool(train_dataset, percentile_n)\n",
        "\n",
        "    carriers, lam2s = [], []\n",
        "    band_low, band_high = cfg.BAND_LOW, cfg.BAND_HIGH\n",
        "    relax_steps, attempts = 0, 0\n",
        "\n",
        "    while len(carriers) < m_bits:\n",
        "        seed_g, e = random.choice(pool)\n",
        "        req = max(1, min(random.choice(cfg.CARRIER_SWAPS), max(1, e // 2)))\n",
        "        # A constant swap seed rewires the same source graph identically every time it is\n",
        "        # drawn (duplicate carriers); derive a distinct but reproducible seed per attempt.\n",
        "        swap_seed = None if seed is None else seed + attempts\n",
        "        H = graph_double_edge_swap(seed_g, req, swap_seed)\n",
        "        l2n = normalize_lambda2(laplacian_lambda2(H), lam_min, lam_scale)\n",
        "        if (l2n <= band_low) or (l2n >= band_high) or (relax_steps >= cfg.MAX_BAND_RELAX_STEPS):\n",
        "            carriers.append(H); lam2s.append(l2n)\n",
        "        attempts += 1\n",
        "        if attempts % 200 == 0 and len(carriers) < m_bits:\n",
        "            relax_steps += 1\n",
        "            # Relaxing narrows the rejected band: move both edges toward 0.5.\n",
        "            band_low = min(0.5, band_low + 0.05)\n",
        "            band_high = max(0.5, band_high - 0.05)\n",
        "\n",
        "    wm_batch = Batch.from_data_list(carriers).to(cfg.DEVICE)\n",
        "    lam2s = torch.tensor(lam2s, dtype=torch.float32, device=cfg.DEVICE)\n",
        "    owner_key = (lam2s >= 0.5).long()\n",
        "    return wm_batch, lam2s, owner_key\n",
        "\n",
        "# ---------- 5) ensure_x (fixed version) ----------\n",
        "def ensure_x(x, edge_index, num_nodes=None, max_deg=10):\n",
        "    \"\"\"\n",
        "    Guarantee that node features exist. If x is None, each node's degree (counted over\n",
        "    edge_index[0]) becomes a single float32 feature; pass num_nodes=batch.size(0) so that\n",
        "    trailing isolated nodes are counted. Non-floating x is cast to float32, floating x is\n",
        "    returned as is. max_deg is currently unused.\n",
        "    \"\"\"\n",
        "    if x is not None:\n",
        "        if x.dtype not in (torch.float32, torch.float64, torch.float16, torch.bfloat16):\n",
        "            x = x.float()\n",
        "        return x\n",
        "    if num_nodes is None:\n",
        "        num_nodes = int(edge_index.max().item()) + 1 if edge_index.numel() > 0 else 0\n",
        "    return degree(edge_index[0], num_nodes=num_nodes, dtype=torch.float32).view(-1, 1)\n",
        "\n",
        "# ---------- 6) Models ----------\n",
        "class GNNBackbone(nn.Module):\n",
        "    \"\"\"Graph-level GNN with a task classification head and a watermark head.\n",
        "\n",
        "    `name` selects the node-level stack ('GIN', 'GCN', 'GraphSAGE', 'GAT', 'Transformer',\n",
        "    'PNA', 'GCNII', 'APPNP', 'SGC'). Node embeddings are sum-pooled per graph and the\n",
        "    pooled embedding feeds both `task_head` (class logits) and `wm_head` (one watermark\n",
        "    logit per graph).\n",
        "    \"\"\"\n",
        "    def __init__(self, in_channels, hidden, out_channels, name='GIN',\n",
        "                 num_layers=3, dropout=0.0, heads=4, pna_deg=None,\n",
        "                 pna_aggrs=None, pna_scalers=None, pna_towers=1, pna_pre=1, pna_post=1,\n",
        "                 appnp_k=10, appnp_alpha=0.1, sgc_k=3,\n",
        "                 gcnii_alpha=0.1, gcnii_theta=None):\n",
        "        \"\"\"Build both heads and the selected stack.\n",
        "\n",
        "        Raises ValueError for an unknown `name`; 'PNA' asserts that the degree histogram\n",
        "        `pna_deg` is given and non-empty. `heads` is only used by GAT/Transformer, and the\n",
        "        pna_* / appnp_* / sgc_k / gcnii_* options only by their respective backbones.\n",
        "        NOTE(review): submodules are created in a fixed order (heads first); changing that\n",
        "        order changes how the global RNG is consumed at init, so keep it stable.\n",
        "        \"\"\"\n",
        "        super().__init__()\n",
        "        self.name = name\n",
        "        self.dropout = dropout\n",
        "        self.num_layers = num_layers\n",
        "\n",
        "        # Both heads read the pooled graph embedding of width `hidden`.\n",
        "        self.task_head = nn.Linear(hidden, out_channels)\n",
        "        # Watermark head: 2-layer MLP with spectral-normalised Linear layers -> 1 logit.\n",
        "        self.wm_head = nn.Sequential(\n",
        "            spectral_norm(nn.Linear(hidden, hidden // 2)),\n",
        "            nn.ReLU(),\n",
        "            spectral_norm(nn.Linear(hidden // 2, 1))\n",
        "        )\n",
        "\n",
        "        # One conv + one BatchNorm per layer for the message-passing stacks below;\n",
        "        # GCNII rebuilds both lists, APPNP and SGC leave them empty.\n",
        "        self.convs = nn.ModuleList()\n",
        "        self.norms = nn.ModuleList()\n",
        "\n",
        "        if name == 'GIN':\n",
        "            # Each GIN layer's update function is a 2-layer MLP.\n",
        "            nn1 = nn.Sequential(nn.Linear(in_channels, hidden), nn.ReLU(), nn.Linear(hidden, hidden))\n",
        "            self.convs.append(GINConv(nn1)); self.norms.append(nn.BatchNorm1d(hidden))\n",
        "            for _ in range(num_layers - 1):\n",
        "                nnk = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden))\n",
        "                self.convs.append(GINConv(nnk)); self.norms.append(nn.BatchNorm1d(hidden))\n",
        "\n",
        "        elif name == 'GCN':\n",
        "            self.convs.append(GCNConv(in_channels, hidden)); self.norms.append(nn.BatchNorm1d(hidden))\n",
        "            for _ in range(num_layers - 1):\n",
        "                self.convs.append(GCNConv(hidden, hidden)); self.norms.append(nn.BatchNorm1d(hidden))\n",
        "\n",
        "        elif name == 'GraphSAGE':\n",
        "            self.convs.append(SAGEConv(in_channels, hidden)); self.norms.append(nn.BatchNorm1d(hidden))\n",
        "            for _ in range(num_layers - 1):\n",
        "                self.convs.append(SAGEConv(hidden, hidden)); self.norms.append(nn.BatchNorm1d(hidden))\n",
        "\n",
        "        elif name == 'GAT':\n",
        "            # concat=False averages the heads, so every layer outputs `hidden` channels.\n",
        "            self.convs.append(GATConv(in_channels, hidden, heads=heads, concat=False, dropout=dropout)); self.norms.append(nn.BatchNorm1d(hidden))\n",
        "            for _ in range(num_layers - 1):\n",
        "                self.convs.append(GATConv(hidden, hidden, heads=heads, concat=False, dropout=dropout)); self.norms.append(nn.BatchNorm1d(hidden))\n",
        "\n",
        "        elif name == 'Transformer':\n",
        "            # Same head handling as GAT: concat=False keeps the width at `hidden`.\n",
        "            self.convs.append(TransformerConv(in_channels, hidden, heads=heads, concat=False, dropout=dropout)); self.norms.append(nn.BatchNorm1d(hidden))\n",
        "            for _ in range(num_layers - 1):\n",
        "                self.convs.append(TransformerConv(hidden, hidden, heads=heads, concat=False, dropout=dropout)); self.norms.append(nn.BatchNorm1d(hidden))\n",
        "\n",
        "        elif name == 'PNA':\n",
        "            assert pna_deg is not None and pna_deg.numel() > 0, \"PNA needs a non-empty degree histogram.\"\n",
        "            aggrs = pna_aggrs or ['sum', 'mean', 'max']\n",
        "            scalers = pna_scalers or ['identity', 'amplification', 'attenuation']\n",
        "            self.convs.append(PNAConv(in_channels, hidden, aggrs, scalers, pna_deg,\n",
        "                                      towers=pna_towers, pre_layers=pna_pre, post_layers=pna_post, divide_input=False))\n",
        "            self.norms.append(nn.BatchNorm1d(hidden))\n",
        "            for _ in range(num_layers - 1):\n",
        "                self.convs.append(PNAConv(hidden, hidden, aggrs, scalers, pna_deg,\n",
        "                                          towers=pna_towers, pre_layers=pna_pre, post_layers=pna_post, divide_input=False))\n",
        "                self.norms.append(nn.BatchNorm1d(hidden))\n",
        "\n",
        "        elif name == 'GCNII':\n",
        "            # GCNII: project inputs to `hidden`, then stack GCN2Conv layers (layer index starts at 1).\n",
        "            self.lin_in = nn.Linear(in_channels, hidden)\n",
        "            self.convs = nn.ModuleList([GCN2Conv(hidden, alpha=gcnii_alpha, theta=gcnii_theta, layer=i+1)\n",
        "                                        for i in range(num_layers)])\n",
        "            self.norms = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(num_layers)])\n",
        "\n",
        "        elif name == 'APPNP':\n",
        "            # APPNP: 2-layer MLP on node features, then APPNP propagation (K=appnp_k, alpha=appnp_alpha).\n",
        "            self.lin1 = nn.Linear(in_channels, hidden)\n",
        "            self.lin2 = nn.Linear(hidden, hidden)\n",
        "            self.propagation = APPNP(K=appnp_k, alpha=appnp_alpha)\n",
        "\n",
        "        elif name == 'SGC':\n",
        "            # SGC: a single SGConv over K=sgc_k hops; not cached because the graph changes per batch.\n",
        "            self.sgc = SGConv(in_channels, hidden, K=sgc_k, cached=False)\n",
        "\n",
        "        else:\n",
        "            raise ValueError(f\"Unknown BACKBONE: {name}\")\n",
        "\n",
        "    def forward(self, x, edge_index, batch):\n",
        "        \"\"\"Return (class logits, sum-pooled graph embedding g_emb, watermark logit per graph).\n",
        "\n",
        "        If x is None, ensure_x builds degree features.\n",
        "        \"\"\"\n",
        "        # use batch size to determine num_nodes when x is None\n",
        "        num_nodes = batch.size(0) if isinstance(batch, torch.Tensor) else (x.size(0) if x is not None else int(edge_index.max().item()) + 1)\n",
        "        x = ensure_x(x, edge_index, num_nodes=num_nodes)\n",
        "\n",
        "        name = self.name\n",
        "        if name in ['GIN', 'GCN', 'GraphSAGE', 'GAT', 'Transformer', 'PNA']:\n",
        "            # Per layer: conv -> BatchNorm -> ReLU -> dropout; then sum-pool per graph.\n",
        "            h = x\n",
        "            for conv, bn in zip(self.convs, self.norms):\n",
        "                h = conv(h, edge_index); h = bn(h); h = F.relu(h)\n",
        "                h = F.dropout(h, p=self.dropout, training=self.training)\n",
        "            g_emb = global_add_pool(h, batch)\n",
        "\n",
        "        elif name == 'GCNII':\n",
        "            # Dropout before each GCN2Conv; every layer also receives the initial projection h0.\n",
        "            h0 = F.relu(self.lin_in(x))\n",
        "            h = h0\n",
        "            for conv, bn in zip(self.convs, self.norms):\n",
        "                h = F.dropout(h, p=self.dropout, training=self.training)\n",
        "                h = conv(h, h0, edge_index)\n",
        "                h = bn(h); h = F.relu(h)\n",
        "            g_emb = global_add_pool(h, batch)\n",
        "\n",
        "        elif name == 'APPNP':\n",
        "            h = F.dropout(x, p=self.dropout, training=self.training)\n",
        "            h = F.relu(self.lin1(h))\n",
        "            h = F.dropout(h, p=self.dropout, training=self.training)\n",
        "            h = self.lin2(h)\n",
        "            h = self.propagation(h, edge_index)\n",
        "            g_emb = global_add_pool(h, batch)\n",
        "\n",
        "        elif name == 'SGC':\n",
        "            h = self.sgc(x, edge_index); h = F.relu(h)\n",
        "            h = F.dropout(h, p=self.dropout, training=self.training)\n",
        "            g_emb = global_add_pool(h, batch)\n",
        "\n",
        "        else:\n",
        "            raise ValueError(f\"Unknown BACKBONE (forward): {name}\")\n",
        "\n",
        "        # Both heads read the same pooled embedding; wm_head's last dim of 1 is squeezed away.\n",
        "        logits = self.task_head(g_emb)\n",
        "        wm_logit = self.wm_head(g_emb).squeeze(-1)\n",
        "        return logits, g_emb, wm_logit\n",
        "\n",
        "class WatermarkedGNN(nn.Module):\n",
        "    \"\"\"Thin wrapper around GNNBackbone; forward returns (logits, g_emb, wm_logit).\"\"\"\n",
        "\n",
        "    def __init__(self, in_dim, hidden, out_dim, backbone='GIN', num_layers=3, dropout=0.0, **kwargs):\n",
        "        super().__init__()\n",
        "        # Any extra keyword options (heads, pna_*, appnp_*, sgc_k, gcnii_*) go straight to GNNBackbone.\n",
        "        self.backbone = GNNBackbone(\n",
        "            in_dim, hidden, out_dim,\n",
        "            name=backbone, num_layers=num_layers, dropout=dropout,\n",
        "            **kwargs,\n",
        "        )\n",
        "\n",
        "    def forward(self, x, edge_index, batch):\n",
        "        outputs = self.backbone(x, edge_index, batch)\n",
        "        return outputs\n",
        "\n",
        "# ---------- 7) Train / Eval / Verify ----------\n",
        "bce_logits = nn.BCEWithLogitsLoss(reduction='mean')  # shared loss: watermark logits vs. key bits\n",
        "\n",
        "def train_epoch(model, task_loader, optimizer, lambda_wm, wm_batch, owner_key, device):\n",
        "    \"\"\"Run one training epoch, optionally adding the watermark loss.\n",
        "\n",
        "    Per batch: loss = CE(task logits, y) + lambda_wm * BCE(wm logits on the carrier graphs,\n",
        "    owner key smoothed to {eps, 1 - eps} with eps = cfg.SOFT_LABEL_EPS). The watermark term\n",
        "    is only computed when lambda_wm > 0 and both wm_batch and owner_key are given.\n",
        "    Gradients are clipped to cfg.MAX_GRAD_NORM.\n",
        "\n",
        "    Returns the total loss averaged over the graphs of task_loader.dataset.\n",
        "    \"\"\"\n",
        "    model.train()\n",
        "    eps = cfg.SOFT_LABEL_EPS\n",
        "    y_soft = None\n",
        "    if owner_key is not None:\n",
        "        y_soft = owner_key.float() * (1 - 2 * eps) + eps\n",
        "    use_wm = lambda_wm > 0 and wm_batch is not None and y_soft is not None\n",
        "    loss_sum = 0.0\n",
        "    for data in task_loader:\n",
        "        data = data.to(device)\n",
        "        optimizer.zero_grad()\n",
        "        task_logits, _, _ = model(data.x, data.edge_index, data.batch)\n",
        "        task_loss = F.cross_entropy(task_logits, data.y)\n",
        "        if use_wm:\n",
        "            # Carriers go through the same model in train mode (BatchNorm / dropout active).\n",
        "            wm_logit = model(wm_batch.x, wm_batch.edge_index, wm_batch.batch)[2]\n",
        "            wm_loss = bce_logits(wm_logit, y_soft)\n",
        "        else:\n",
        "            wm_loss = torch.tensor(0.0, device=device)\n",
        "        loss = task_loss + lambda_wm * wm_loss\n",
        "        loss.backward()\n",
        "        clip_grad_norm_(model.parameters(), max_norm=cfg.MAX_GRAD_NORM)\n",
        "        optimizer.step()\n",
        "        loss_sum += loss.item() * data.num_graphs\n",
        "    return loss_sum / len(task_loader.dataset)\n",
        "\n",
        "@torch.no_grad()\n",
        "def eval_acc(model, loader, device):\n",
        "    \"\"\"Top-1 graph-classification accuracy of `model` over `loader` (eval mode, no grad).\"\"\"\n",
        "    model.eval()\n",
        "    hits = 0\n",
        "    for batch_data in loader:\n",
        "        batch_data = batch_data.to(device)\n",
        "        logits = model(batch_data.x, batch_data.edge_index, batch_data.batch)[0]\n",
        "        hits += int(logits.argmax(dim=1).eq(batch_data.y).sum())\n",
        "    return hits / len(loader.dataset)\n",
        "\n",
        "@torch.no_grad()\n",
        "def verify_watermark(model, wm_batch, key_to_verify, device):\n",
        "    \"\"\"Decode watermark bits from the carrier graphs and compare them with a key.\n",
        "\n",
        "    Returns (bit accuracy, smallest |sigmoid - 0.5| margin over all carriers, decoded bits\n",
        "    as a long tensor, per-carrier sigmoid scores). `device` is not used here.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    scores = torch.sigmoid(model(wm_batch.x, wm_batch.edge_index, wm_batch.batch)[2])\n",
        "    bits = (scores >= 0.5).long()\n",
        "    n_match = int(bits.eq(key_to_verify).sum().item())\n",
        "    margin = float((scores - 0.5).abs().min().item())\n",
        "    return n_match / len(key_to_verify), margin, bits, scores\n",
        "\n",
        "def tau_threshold(m: int, alpha: float):\n",
        "    \"\"\"Smallest number of matching bits out of m that certifies ownership at level alpha.\n",
        "\n",
        "    Hoeffding: if each of the m bits matched independently with probability 1/2,\n",
        "    P[matches >= m * (1/2 + e)] <= exp(-2 m e^2). Setting that bound to alpha gives\n",
        "    e = sqrt(ln(1/alpha) / (2 m)). The result can exceed m when alpha is very small.\n",
        "    \"\"\"\n",
        "    slack = math.sqrt(math.log(1 / alpha) / (2 * m))\n",
        "    return math.ceil(m * (0.5 + slack))\n",
        "\n",
        "# ---------- 8) Attacks ----------\n",
        "def global_prune_(model, frac: float):\n",
        "    \"\"\"Global magnitude pruning, in place.\n",
        "\n",
        "    Considers every trainable parameter whose name contains \"weight\" and that has at\n",
        "    least 2 dims, and zeroes all entries with |w| <= the k-th smallest magnitude, where\n",
        "    k = int(frac * total entries). Ties can push the pruned share above `frac`.\n",
        "    Returns that threshold, or 0.0 when nothing is pruned (no such weights, or k == 0).\n",
        "    \"\"\"\n",
        "    def _prunable(pname, param):\n",
        "        return (\"weight\" in pname) and param.requires_grad and param.dim() >= 2\n",
        "\n",
        "    targets = [param for pname, param in model.named_parameters() if _prunable(pname, param)]\n",
        "    if len(targets) == 0:\n",
        "        return 0.0\n",
        "    magnitudes = torch.cat([param.detach().abs().cpu().reshape(-1) for param in targets])\n",
        "    k = int(frac * magnitudes.numel())\n",
        "    if k <= 0:\n",
        "        return 0.0\n",
        "    thresh = torch.topk(magnitudes, k, largest=False).values.max().item()\n",
        "    with torch.no_grad():\n",
        "        for param in targets:\n",
        "            param[param.abs() <= thresh] = 0.0\n",
        "    return thresh\n",
        "\n",
        "def finetune(model, train_loader, epochs=10, lr=cfg.LR, device=cfg.DEVICE):\n",
        "    \"\"\"Fine-tuning attack: train a deep copy of `model` on task data only (no watermark loss).\n",
        "\n",
        "    `model` itself is left untouched; returns the fine-tuned copy.\n",
        "    \"\"\"\n",
        "    tuned = copy.deepcopy(model)\n",
        "    optimizer = torch.optim.Adam(tuned.parameters(), lr=lr)\n",
        "    for _epoch in range(epochs):\n",
        "        train_epoch(tuned, train_loader, optimizer, lambda_wm=0.0, wm_batch=None, owner_key=None, device=device)\n",
        "    return tuned\n",
        "\n",
        "def fineprune_and_finetune(model, train_loader, frac=0.4, ft_epochs=10, lr=cfg.LR, device=cfg.DEVICE):\n",
        "    \"\"\"Fine-pruning attack: magnitude-prune a copy of `model` (global_prune_ with `frac`),\n",
        "    then fine-tune it on task data only.\n",
        "\n",
        "    No pruning mask is kept, so fine-tuning may regrow pruned weights. `model` itself is\n",
        "    left untouched; returns the pruned and fine-tuned copy.\n",
        "    \"\"\"\n",
        "    victim = copy.deepcopy(model)\n",
        "    global_prune_(victim, frac)\n",
        "    optimizer = torch.optim.Adam(victim.parameters(), lr=lr)\n",
        "    for _epoch in range(ft_epochs):\n",
        "        train_epoch(victim, train_loader, optimizer, 0.0, None, None, device)\n",
        "    return victim\n",
        "\n",
        "def distill_student(teacher, train_loader, in_dim, out_dim, hidden=32, epochs=10, T=2.0, backbone='GIN', device=cfg.DEVICE,\n",
        "                    **backbone_kwargs):\n",
        "    \"\"\"Distillation attack: train a fresh student only on the teacher's softened predictions.\n",
        "\n",
        "    Per batch: loss = T^2 * KL(teacher || student) between temperature-T softmaxes. The\n",
        "    student never sees labels or carrier graphs. Extra keyword arguments are forwarded to\n",
        "    WatermarkedGNN/GNNBackbone, e.g. pna_deg=..., without which the 'PNA' backbone\n",
        "    cannot be built. They must not repeat num_layers or dropout, which come from cfg.\n",
        "    Returns the trained student.\n",
        "    \"\"\"\n",
        "    student = WatermarkedGNN(in_dim, hidden, out_dim, backbone=backbone, num_layers=cfg.NUM_LAYERS,\n",
        "                             dropout=cfg.DROPOUT, **backbone_kwargs).to(device)\n",
        "    opt = torch.optim.Adam(student.parameters(), lr=cfg.LR)\n",
        "    kl = nn.KLDivLoss(reduction='batchmean')\n",
        "    teacher.eval()\n",
        "    for _ in range(epochs):\n",
        "        student.train()\n",
        "        for data in train_loader:\n",
        "            data = data.to(device)\n",
        "            with torch.no_grad():\n",
        "                t_logits, _, _ = teacher(data.x, data.edge_index, data.batch)\n",
        "            s_logits, _, _ = student(data.x, data.edge_index, data.batch)\n",
        "            loss = kl(F.log_softmax(s_logits/T, dim=1), F.softmax(t_logits/T, dim=1)) * (T*T)\n",
        "            opt.zero_grad(); loss.backward(); clip_grad_norm_(student.parameters(), cfg.MAX_GRAD_NORM); opt.step()\n",
        "    return student\n",
        "\n",
        "def distill_student_with_wm(teacher, train_loader, wm_batch, owner_key, in_dim, out_dim,\n",
        "                            hidden=32, epochs=10, T=2.0, lambda_wm=5e-3, backbone='GIN', device=cfg.DEVICE,\n",
        "                            **backbone_kwargs):\n",
        "    \"\"\"Distil a fresh student from `teacher` while also training it on the watermark.\n",
        "\n",
        "    Per batch: loss = T^2 * KL(teacher || student) on temperature-T softmaxes\n",
        "    + lambda_wm * BCE(student wm logits on the carriers, owner key smoothed to\n",
        "    {eps, 1 - eps} with eps = cfg.SOFT_LABEL_EPS). Extra keyword arguments are forwarded\n",
        "    to WatermarkedGNN/GNNBackbone, e.g. pna_deg=..., which the 'PNA' backbone requires.\n",
        "    They must not repeat num_layers or dropout, which come from cfg.\n",
        "    Returns the trained student.\n",
        "    \"\"\"\n",
        "    student = WatermarkedGNN(in_dim, hidden, out_dim, backbone=backbone, num_layers=cfg.NUM_LAYERS,\n",
        "                             dropout=cfg.DROPOUT, **backbone_kwargs).to(device)\n",
        "    opt = torch.optim.Adam(student.parameters(), lr=cfg.LR)\n",
        "    kl = nn.KLDivLoss(reduction='batchmean')\n",
        "    teacher.eval()\n",
        "    eps = cfg.SOFT_LABEL_EPS\n",
        "    y_soft = owner_key.float() * (1 - 2 * eps) + eps\n",
        "    for _ in range(epochs):\n",
        "        student.train()\n",
        "        for data in train_loader:\n",
        "            data = data.to(device)\n",
        "            with torch.no_grad():\n",
        "                t_logits, _, _ = teacher(data.x, data.edge_index, data.batch)\n",
        "            s_logits, _, _ = student(data.x, data.edge_index, data.batch)\n",
        "            kd = kl(F.log_softmax(s_logits/T, dim=1), F.softmax(t_logits/T, dim=1)) * (T*T)\n",
        "            _, _, wm_logit = student(wm_batch.x, wm_batch.edge_index, wm_batch.batch)\n",
        "            wm = bce_logits(wm_logit, y_soft)\n",
        "            loss = kd + lambda_wm * wm\n",
        "            opt.zero_grad(); loss.backward(); clip_grad_norm_(student.parameters(), cfg.MAX_GRAD_NORM); opt.step()\n",
        "    return student\n",
        "\n",
        "def quantize_model_(model, nbits=8):\n",
        "    \"\"\"Simulated symmetric per-tensor quantisation, applied in place.\n",
        "\n",
        "    Every trainable fp32/fp16/bf16 parameter whose name contains \"weight\" is rounded to\n",
        "    the grid s * {-2^(nbits-1), ..., 2^(nbits-1) - 1} with s = max|w| / (2^(nbits-1) - 1)\n",
        "    and written back de-quantised. All-zero tensors are skipped. Returns the same model.\n",
        "    \"\"\"\n",
        "    levels = 2 ** (nbits - 1) - 1\n",
        "    float_types = (torch.float32, torch.float16, torch.bfloat16)\n",
        "    with torch.no_grad():\n",
        "        for pname, param in model.named_parameters():\n",
        "            eligible = param.requires_grad and param.data.dtype in float_types and \"weight\" in pname\n",
        "            if not eligible:\n",
        "                continue\n",
        "            peak = param.abs().max()\n",
        "            if peak == 0:\n",
        "                continue\n",
        "            step = peak / levels\n",
        "            codes = torch.clamp((param / (step + 1e-12)).round(), -levels - 1, levels)\n",
        "            param.copy_(codes * step)\n",
        "    return model\n",
        "\n",
        "def add_weight_noise_(model, sigma=0.01):\n",
        "    \"\"\"Add Gaussian noise to every trainable fp32/fp16/bf16 parameter, in place.\n",
        "\n",
        "    The noise std is relative to each tensor's mean |value| (sigma * mean|p|), and\n",
        "    all-zero tensors are left alone. Returns the same model.\n",
        "    \"\"\"\n",
        "    float_types = (torch.float32, torch.float16, torch.bfloat16)\n",
        "    with torch.no_grad():\n",
        "        for param in model.parameters():\n",
        "            if not (param.requires_grad and param.data.dtype in float_types):\n",
        "                continue\n",
        "            mean_abs = param.abs().mean().item()\n",
        "            if mean_abs == 0:\n",
        "                continue\n",
        "            param.add_(torch.randn_like(param) * (sigma * mean_abs))\n",
        "    return model\n",
        "\n",
        "def reinit_wm_head_(model):\n",
        "    \"\"\"Head re-initialisation attack: re-draw the watermark head's weights in place.\n",
        "\n",
        "    Every nn.Linear whose module path contains \"wm_head\" is reset exactly once, like\n",
        "    nn.Linear's default init: kaiming_uniform_(a=sqrt(5)) on the weight and\n",
        "    bias ~ U(-1/sqrt(fan_in), 1/sqrt(fan_in)) with fan_in = in_features.\n",
        "    For layers wrapped in spectral_norm, the trainable weight is `weight_orig`. The stored\n",
        "    power-iteration vectors `weight_u`/`weight_v` are reset to the new weight's top\n",
        "    singular vectors, because eval mode does not update them and stale vectors would divide\n",
        "    the new weight by a meaningless sigma. Returns the same model.\n",
        "    \"\"\"\n",
        "    with torch.no_grad():\n",
        "        for name, mod in model.named_modules():\n",
        "            if \"wm_head\" not in name or not isinstance(mod, nn.Linear):\n",
        "                continue\n",
        "            weight = mod.weight_orig if hasattr(mod, \"weight_orig\") else mod.weight\n",
        "            nn.init.kaiming_uniform_(weight, a=math.sqrt(5))\n",
        "            if mod.bias is not None:\n",
        "                fan_in = weight.size(1)\n",
        "                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0.0\n",
        "                nn.init.uniform_(mod.bias, -bound, bound)\n",
        "            u = getattr(mod, \"weight_u\", None)\n",
        "            v = getattr(mod, \"weight_v\", None)\n",
        "            if u is not None and v is not None:\n",
        "                U, _, Vh = torch.linalg.svd(weight, full_matrices=False)\n",
        "                u.copy_(U[:, 0])\n",
        "                v.copy_(Vh[0])\n",
        "    return model\n",
        "\n",
        "def pgd_unlearn_(model, wm_batch, owner_key, steps=30, eps=1e-2, step=1e-3, norm='linf', device=cfg.DEVICE):\n",
        "    \"\"\"\n",
        "    Adversarial watermark removal without task data: projected gradient *ascent* on the\n",
        "    watermark BCE (against the hard owner key), applied to a copy of `model`. Each\n",
        "    parameter tensor stays within a ball of radius `eps` around its original value:\n",
        "    element-wise clamp for norm='linf', per-tensor norm projection for norm='l2'.\n",
        "\n",
        "    Parameters that get no gradient from the watermark logit (e.g. the task head; their\n",
        "    .grad is None) are skipped. Raises ValueError for any other `norm`.\n",
        "    Returns the perturbed copy; `model` itself is not modified.\n",
        "    \"\"\"\n",
        "    if norm not in ('linf', 'l2'):\n",
        "        raise ValueError(\"norm must be 'linf' or 'l2'\")\n",
        "    adv = copy.deepcopy(model).to(device)\n",
        "    params = [p for p in adv.parameters() if p.requires_grad]\n",
        "    with torch.no_grad():\n",
        "        ref = [p.clone() for p in params]\n",
        "    opt = torch.optim.SGD(params, lr=1.0)  # only used to clear gradients; updates are applied by hand\n",
        "    y = owner_key.float()\n",
        "\n",
        "    for _ in range(steps):\n",
        "        opt.zero_grad()\n",
        "        _, _, wm_logit = adv(wm_batch.x, wm_batch.edge_index, wm_batch.batch)\n",
        "        loss = - bce_logits(wm_logit, y)           # maximize BCE\n",
        "        loss.backward()\n",
        "        with torch.no_grad():\n",
        "            for p, r in zip(params, ref):\n",
        "                g = p.grad\n",
        "                if g is None:\n",
        "                    continue  # not on the watermark path: nothing to ascend along\n",
        "                if norm == 'linf':\n",
        "                    p.add_(step * torch.sign(g))\n",
        "                    p.copy_(torch.max(torch.min(p, r + eps), r - eps))\n",
        "                else:  # 'l2'\n",
        "                    g_norm = g.norm()\n",
        "                    if g_norm > 0:\n",
        "                        p.add_(step * g / (g_norm + 1e-12))\n",
        "                    delta = p - r\n",
        "                    dnorm = delta.norm()\n",
        "                    if dnorm > eps:\n",
        "                        p.copy_(r + delta * (eps / (dnorm + 1e-12)))\n",
        "    return adv\n",
        "\n",
        "def perturb_carriers(wm_batch: Batch, p_del=0.05, p_add=0.05, device=cfg.DEVICE):\n",
        "    \"\"\"Return an edge-perturbed copy of every carrier graph in wm_batch.\n",
        "\n",
        "    Per graph: round(p_del*|E|) existing undirected edges are removed and\n",
        "    round(p_add*|E|) currently-absent node pairs are added, |E| being the edge count\n",
        "    before deletion; both are drawn with Python's global `random` module. Node\n",
        "    features are copied; no other per-graph attribute is carried over.\n",
        "\n",
        "    Args:\n",
        "        wm_batch: batched carrier graphs (reads .x, .edge_index, .batch, .num_graphs).\n",
        "        p_del: fraction of edges to delete per graph.\n",
        "        p_add: fraction (of the pre-deletion edge count) of edges to add per graph.\n",
        "        device: device of the returned Batch.\n",
        "\n",
        "    Returns:\n",
        "        A new Batch with the same number of graphs, on `device`.\n",
        "    \"\"\"\n",
        "    src, dst = wm_batch.edge_index\n",
        "    n_total = wm_batch.batch.numel()\n",
        "    out_graphs = []\n",
        "    for idx in range(wm_batch.num_graphs):\n",
        "        mask = wm_batch.batch == idx\n",
        "        x = wm_batch.x[mask].cpu()\n",
        "        node_idx = torch.where(mask)[0]\n",
        "        # global node id -> local id inside this graph (-1 for nodes of other graphs)\n",
        "        local_id = torch.full((n_total,), -1, dtype=torch.long, device=mask.device)\n",
        "        local_id[node_idx] = torch.arange(node_idx.numel(), device=mask.device)\n",
        "        # Vectorized edge selection (both endpoints in this graph); avoids a per-edge\n",
        "        # Python loop with one .item() host sync per edge and works for 0 edges.\n",
        "        e_mask = mask[src] & mask[dst]\n",
        "        u = local_id[src[e_mask]].tolist()\n",
        "        v = local_id[dst[e_mask]].tolist()\n",
        "        G = nx.Graph(); G.add_nodes_from(range(x.size(0))); G.add_edges_from(zip(u, v))\n",
        "        E = list(G.edges())\n",
        "        ndel = max(0, int(round(p_del * len(E))))\n",
        "        if ndel > 0 and len(E) > 0:\n",
        "            del_edges = random.sample(E, min(ndel, len(E))); G.remove_edges_from(del_edges)\n",
        "        nadd = max(0, int(round(p_add * len(E))))\n",
        "        possible = list(nx.non_edges(G))\n",
        "        if nadd > 0 and len(possible) > 0:\n",
        "            add_edges = random.sample(possible, min(nadd, len(possible))); G.add_edges_from(add_edges)\n",
        "        D = from_networkx(G); D.x = x; out_graphs.append(D)\n",
        "    pert = Batch.from_data_list(out_graphs).to(device)\n",
        "    return pert\n",
        "\n",
        "# ---------- 9) Helpers: backbone kwargs ----------\n",
        "def _backbone_kwargs(backbone, train_dataset):\n",
        "    \"\"\"Build the backbone-specific keyword arguments for WatermarkedGNN from cfg.\n",
        "\n",
        "    Backbones without extra options get an empty dict. For 'PNA' the degree\n",
        "    histogram is computed from train_dataset with max_degree=512.\n",
        "    \"\"\"\n",
        "    if backbone in ('GAT', 'Transformer'):\n",
        "        return {\"heads\": cfg.HEADS}\n",
        "    if backbone == 'PNA':\n",
        "        return {\n",
        "            \"pna_deg\": degree_histogram_from_dataset(train_dataset, max_degree=512),\n",
        "            \"pna_aggrs\": cfg.PNA_AGGRS,\n",
        "            \"pna_scalers\": cfg.PNA_SCALERS,\n",
        "            \"pna_towers\": cfg.PNA_TOWERS,\n",
        "            \"pna_pre\": cfg.PNA_PRE_LAYERS,\n",
        "            \"pna_post\": cfg.PNA_POST_LAYERS,\n",
        "        }\n",
        "    if backbone == 'APPNP':\n",
        "        return {\"appnp_k\": cfg.APPNP_K, \"appnp_alpha\": cfg.APPNP_ALPHA}\n",
        "    if backbone == 'SGC':\n",
        "        return {\"sgc_k\": cfg.SGC_K}\n",
        "    if backbone == 'GCNII':\n",
        "        return {\"gcnii_alpha\": cfg.GCNII_ALPHA, \"gcnii_theta\": cfg.GCNII_THETA}\n",
        "    return {}\n",
        "\n",
        "# ---------- 10) Single experiment ----------\n",
        "def run_single_experiment(dataset_name, backbone, seed, betas, m_bits, owner_alpha, save_dir):\n",
        "    \"\"\"Run the complete watermark evaluation for one (dataset, backbone, seed).\n",
        "\n",
        "    10.1 trains one model per beta in `betas` and keeps the beta > 0 model with the\n",
        "         highest WM-ACC whose test accuracy is within cfg.ALLOW_ACC_DROP of the\n",
        "         beta == 0 baseline (or of the first sweep entry if there is no beta == 0);\n",
        "    10.2 checks uniqueness: owner key vs cfg.NUM_IMPOSTOR_KEYS random keys at tau;\n",
        "    10.3 runs the robustness checks / attacks enabled by the cfg.RUN_* flags;\n",
        "    10.4 measures WM-ACC under random carrier edge perturbation.\n",
        "    Figures (pdf/svg/png), CSVs and a JSON report are written to save_dir with the\n",
        "    prefix <dataset>_<backbone>_s<seed>.\n",
        "\n",
        "    Returns:\n",
        "        dict with the sweep rows, chosen_beta, uniqueness stats, and\n",
        "        (name, acc, wm_acc, kappa) tuples for robustness and attacks, plus\n",
        "        (pdel, padd, mean, std, ci95, n) carrier-perturbation tuples.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if `betas` has no value > 0 (nothing watermarked to choose).\n",
        "    \"\"\"\n",
        "    betas = list(betas)\n",
        "    if not any(b > 0 for b in betas):\n",
        "        raise ValueError(\"betas must contain at least one value > 0\")\n",
        "    set_seed(seed)\n",
        "    dataset_all, train_dataset, test_dataset, IN_DIM, OUT_DIM = load_dataset(dataset_name, seed)\n",
        "    lam_min, lam_scale = estimate_lambda_stats(train_dataset)\n",
        "    print(f\"[{dataset_name}|{backbone}|seed={seed}] λ2: min={lam_min:.4f} scale={lam_scale:.4f}\")\n",
        "\n",
        "    wm_batch, wm_targets, owner_key = build_carriers(\n",
        "        train_dataset, m_bits, cfg.MAX_CARRIER_N_PERCENTILE, lam_min, lam_scale, seed\n",
        "    )\n",
        "    print(f\"[Carriers] m={m_bits}, key ones={int(owner_key.sum().item())}\")\n",
        "\n",
        "    g = torch.Generator(); g.manual_seed(seed)\n",
        "    train_loader = DataLoader(train_dataset, batch_size=cfg.BATCH, shuffle=True,\n",
        "                              num_workers=cfg.NUM_WORKERS, generator=g)\n",
        "    test_loader  = DataLoader(test_dataset,  batch_size=cfg.BATCH, shuffle=False,\n",
        "                              num_workers=cfg.NUM_WORKERS)\n",
        "\n",
        "    extra_kw = _backbone_kwargs(backbone, train_dataset)\n",
        "\n",
        "    # 10.1 Imperceptibility sweep\n",
        "    results = []; models_at_beta = {}\n",
        "    for beta in betas:\n",
        "        print(f\"\\n--- Training with β_wm = {beta} ---\")\n",
        "        model = WatermarkedGNN(IN_DIM, cfg.HIDDEN, OUT_DIM, backbone=backbone,\n",
        "                               num_layers=cfg.NUM_LAYERS, dropout=cfg.DROPOUT, **extra_kw).to(cfg.DEVICE)\n",
        "        opt = torch.optim.Adam(model.parameters(), lr=cfg.LR, weight_decay=cfg.WD)\n",
        "        for epoch in range(1, cfg.EPOCHS+1):\n",
        "            _ = train_epoch(model, train_loader, opt, beta, wm_batch, owner_key, cfg.DEVICE)\n",
        "        task_acc = eval_acc(model, test_loader, cfg.DEVICE)\n",
        "        wm_acc, kappa_marg, *_ = verify_watermark(model, wm_batch, owner_key, cfg.DEVICE) if beta > 0 else (0.0, 0.0, None, None)\n",
        "        results.append({'beta': beta, 'task_acc': task_acc, 'wm_acc': wm_acc, 'kappa_marg': kappa_marg})\n",
        "        models_at_beta[beta] = model\n",
        "        print(f\"-> Test ACC {task_acc:.4f} | WM-ACC {wm_acc:.4f} | κ_marg {kappa_marg:.4f}\")\n",
        "\n",
        "    # Selection: best WM-ACC among beta > 0 runs within ALLOW_ACC_DROP of the baseline,\n",
        "    # falling back to the best WM-ACC among all beta > 0 runs when none qualifies.\n",
        "    base_acc = next((r['task_acc'] for r in results if r['beta']==0), results[0]['task_acc'])\n",
        "    cand = [r for r in results if r['beta']>0 and r['task_acc'] >= base_acc - cfg.ALLOW_ACC_DROP]\n",
        "    chosen_row = max(cand, key=lambda r: r['wm_acc']) if len(cand)>0 else max([r for r in results if r['beta']>0], key=lambda r: r['wm_acc'])\n",
        "    chosen_beta = float(chosen_row['beta'])\n",
        "    wm_model = models_at_beta[chosen_beta]\n",
        "    print(f\"\\n[Chosen β_wm] {chosen_beta}  (base_acc={base_acc:.4f}, chosen_acc={chosen_row['task_acc']:.4f}, chosen_wm={chosen_row['wm_acc']:.4f})\")\n",
        "\n",
        "    # 10.2 Uniqueness\n",
        "    owner_wa, owner_kappa, owner_pred, _ = verify_watermark(wm_model, wm_batch, owner_key, cfg.DEVICE)\n",
        "    m = len(owner_key); tau = tau_threshold(m, owner_alpha)\n",
        "    T_owner = int((owner_pred == owner_key).sum().item())\n",
        "    impostor_Ts = []\n",
        "    for _ in range(cfg.NUM_IMPOSTOR_KEYS):\n",
        "        rand_key = torch.randint(0, 2, (m,), device=cfg.DEVICE).long()\n",
        "        wa_imp, _, pred_imp, _ = verify_watermark(wm_model, wm_batch, rand_key, cfg.DEVICE)\n",
        "        impostor_Ts.append(int((pred_imp == rand_key).sum().item()))\n",
        "    # guard NUM_IMPOSTOR_KEYS == 0: max([]) raises and np.mean([]) is NaN\n",
        "    avg_imp_wa = float(np.mean(np.array(impostor_Ts) / m)) if impostor_Ts else 0.0\n",
        "    max_imp_T = max(impostor_Ts) if impostor_Ts else 0\n",
        "    uniq_pass = (T_owner >= tau) and (max_imp_T < tau)\n",
        "    print(f\"[Uniqueness] Owner T={T_owner} vs τ={tau}; Impostor max T={max_imp_T} -> {'PASS' if uniq_pass else 'FAIL'}\")\n",
        "\n",
        "    Path(save_dir).mkdir(parents=True, exist_ok=True)\n",
        "\n",
        "    # 10.3 Robustness & attacks\n",
        "    base_task_acc = eval_acc(wm_model, test_loader, cfg.DEVICE)\n",
        "    base_wa, base_kappa, *_ = verify_watermark(wm_model, wm_batch, owner_key, cfg.DEVICE)\n",
        "    robust_rows = [(\"Baseline\", base_task_acc, base_wa, base_kappa)]\n",
        "\n",
        "    if cfg.RUN_PRUNE:\n",
        "        for pf in cfg.PRUNE_FRACS:\n",
        "            pruned = copy.deepcopy(wm_model)\n",
        "            _ = global_prune_(pruned, pf)\n",
        "            acc_p = eval_acc(pruned, test_loader, cfg.DEVICE)\n",
        "            wa_p, km_p, *_ = verify_watermark(pruned, wm_batch, owner_key, cfg.DEVICE)\n",
        "            robust_rows.append((f\"Pr-{int(pf*100)}%\", acc_p, wa_p, km_p))\n",
        "\n",
        "    if cfg.RUN_FINEPRUNE:\n",
        "        fpr = fineprune_and_finetune(wm_model, train_loader, frac=cfg.FPRUNE_FRAC, ft_epochs=cfg.FPRUNE_FT_EPOCHS, lr=cfg.LR, device=cfg.DEVICE)\n",
        "        acc_fp = eval_acc(fpr, test_loader, cfg.DEVICE)\n",
        "        wa_fp, km_fp, *_ = verify_watermark(fpr, wm_batch, owner_key, cfg.DEVICE)\n",
        "        robust_rows.append((f\"FPr-{int(cfg.FPRUNE_FRAC*100)}%+FT\", acc_fp, wa_fp, km_fp))\n",
        "\n",
        "    if cfg.RUN_FINETUNE:\n",
        "        ft = finetune(wm_model, train_loader, epochs=cfg.FT_EPOCHS, lr=cfg.LR, device=cfg.DEVICE)\n",
        "        acc_ft = eval_acc(ft, test_loader, cfg.DEVICE)\n",
        "        wa_ft, km_ft, *_ = verify_watermark(ft, wm_batch, owner_key, cfg.DEVICE)\n",
        "        robust_rows.append((f\"FT-{cfg.FT_EPOCHS}e\", acc_ft, wa_ft, km_ft))\n",
        "\n",
        "    attack_rows = []\n",
        "    if cfg.RUN_DISTILL:\n",
        "        stu = distill_student(wm_model, train_loader, IN_DIM, OUT_DIM,\n",
        "                              hidden=cfg.KD_HIDDEN, epochs=cfg.KD_EPOCHS, T=cfg.KD_T,\n",
        "                              backbone=backbone, device=cfg.DEVICE)\n",
        "        acc_stu = eval_acc(stu, test_loader, cfg.DEVICE)\n",
        "        wa_stu, km_stu, *_ = verify_watermark(stu, wm_batch, owner_key, cfg.DEVICE)\n",
        "        attack_rows.append((f\"KD-{cfg.KD_EPOCHS}e-T{cfg.KD_T}\", acc_stu, wa_stu, km_stu))\n",
        "\n",
        "    if cfg.RUN_KD_DEFENSE:\n",
        "        stu_def = distill_student_with_wm(wm_model, train_loader, wm_batch, owner_key,\n",
        "                                          IN_DIM, OUT_DIM, hidden=cfg.KD_HIDDEN, epochs=cfg.KD_EPOCHS,\n",
        "                                          T=cfg.KD_T, lambda_wm=cfg.KD_WM_LAMBDA, backbone=backbone, device=cfg.DEVICE)\n",
        "        acc_sd = eval_acc(stu_def, test_loader, cfg.DEVICE)\n",
        "        wa_sd, km_sd, *_ = verify_watermark(stu_def, wm_batch, owner_key, cfg.DEVICE)\n",
        "        attack_rows.append((f\"KD+WM-{cfg.KD_EPOCHS}e-T{cfg.KD_T}\", acc_sd, wa_sd, km_sd))\n",
        "\n",
        "    if cfg.RUN_QUANT:\n",
        "        for qbits in cfg.QUANT_BITS:\n",
        "            q_model = copy.deepcopy(wm_model)\n",
        "            quantize_model_(q_model, nbits=qbits)\n",
        "            acc_q = eval_acc(q_model, test_loader, cfg.DEVICE)\n",
        "            wa_q, km_q, *_ = verify_watermark(q_model, wm_batch, owner_key, cfg.DEVICE)\n",
        "            attack_rows.append((f\"Quant-{qbits}b\", acc_q, wa_q, km_q))\n",
        "\n",
        "    if cfg.RUN_WNOISE:\n",
        "        for s in cfg.WNOISE_SIGMAS:\n",
        "            n_model = copy.deepcopy(wm_model)\n",
        "            add_weight_noise_(n_model, sigma=s)\n",
        "            acc_n = eval_acc(n_model, test_loader, cfg.DEVICE)\n",
        "            wa_n, km_n, *_ = verify_watermark(n_model, wm_batch, owner_key, cfg.DEVICE)\n",
        "            attack_rows.append((f\"WNoise-{s}\", acc_n, wa_n, km_n))\n",
        "\n",
        "    if cfg.RUN_PGD_UNLEARN:\n",
        "        # L_inf\n",
        "        pgd = pgd_unlearn_(wm_model, wm_batch, owner_key, steps=cfg.PGD_STEPS, eps=cfg.PGD_EPS_LINF, step=cfg.PGD_STEP_LINF, norm='linf', device=cfg.DEVICE)\n",
        "        acc_pgd = eval_acc(pgd, test_loader, cfg.DEVICE)\n",
        "        wa_pgd, km_pgd, *_ = verify_watermark(pgd, wm_batch, owner_key, cfg.DEVICE)\n",
        "        attack_rows.append((f\"PGD-Linf-{cfg.PGD_STEPS}\", acc_pgd, wa_pgd, km_pgd))\n",
        "        # L2\n",
        "        pgd2 = pgd_unlearn_(wm_model, wm_batch, owner_key, steps=cfg.PGD_STEPS, eps=cfg.PGD_EPS_L2, step=cfg.PGD_STEP_L2, norm='l2', device=cfg.DEVICE)\n",
        "        acc_pgd2 = eval_acc(pgd2, test_loader, cfg.DEVICE)\n",
        "        wa_pgd2, km_pgd2, *_ = verify_watermark(pgd2, wm_batch, owner_key, cfg.DEVICE)\n",
        "        attack_rows.append((f\"PGD-L2-{cfg.PGD_STEPS}\", acc_pgd2, wa_pgd2, km_pgd2))\n",
        "\n",
        "    if cfg.RUN_REINIT_HEAD:\n",
        "        rh = copy.deepcopy(wm_model)\n",
        "        reinit_wm_head_(rh)\n",
        "        acc_rh = eval_acc(rh, test_loader, cfg.DEVICE)\n",
        "        wa_rh, km_rh, *_ = verify_watermark(rh, wm_batch, owner_key, cfg.DEVICE)\n",
        "        attack_rows.append((\"Reinit-Head\", acc_rh, wa_rh, km_rh))\n",
        "\n",
        "    # 10.4 Carrier perturbation\n",
        "    carr_rows = []\n",
        "    if cfg.RUN_CARRIER_PERTURB:\n",
        "        for pdel in cfg.CARR_DEL_FRACS:\n",
        "            for padd in cfg.CARR_ADD_FRACS:\n",
        "                wa_list = []\n",
        "                for _ in range(cfg.CARR_PERTURB_TRIALS):\n",
        "                    pert = perturb_carriers(wm_batch, p_del=pdel, p_add=padd, device=cfg.DEVICE)\n",
        "                    wa_p, km_p, _, _ = verify_watermark(wm_model, pert, owner_key, cfg.DEVICE)\n",
        "                    wa_list.append(wa_p)\n",
        "                wa_arr = np.array(wa_list, dtype=float)\n",
        "                wa_mean = float(wa_arr.mean()); wa_std = float(wa_arr.std(ddof=1)) if len(wa_arr) > 1 else 0.0\n",
        "                ci95 = float(1.96 * wa_std / math.sqrt(max(1, len(wa_arr))))\n",
        "                carr_rows.append((pdel, padd, wa_mean, wa_std, ci95, len(wa_arr)))\n",
        "\n",
        "    # ---------- Figures & CSVs ----------\n",
        "    Path(save_dir).mkdir(parents=True, exist_ok=True)\n",
        "    mpl.rcParams.update({\n",
        "        \"figure.dpi\": cfg.DPI_FIG, \"savefig.dpi\": cfg.DPI_SAVE, \"font.size\": 12,\n",
        "        \"axes.titlesize\": 14, \"axes.labelsize\": 12, \"legend.fontsize\": 11,\n",
        "        \"xtick.labelsize\": 11, \"ytick.labelsize\": 11, \"lines.linewidth\": 2.0,\n",
        "    })\n",
        "\n",
        "    def dual_axis_plot(x, y_left, y_right, x_label, y_left_label, y_right_label, title, x_is_log=False, x_ticks=None):\n",
        "        fig, ax1 = plt.subplots(figsize=(6.5, 3.8), constrained_layout=True)\n",
        "        if x_is_log: ax1.set_xscale(\"log\")\n",
        "        ax1.plot(x, y_left, marker=\"o\", label=y_left_label)\n",
        "        ax1.set_xlabel(x_label); ax1.set_ylabel(y_left_label)\n",
        "        ax1.grid(True, which=\"both\", linestyle=\"--\", alpha=0.4)\n",
        "        ax2 = ax1.twinx()\n",
        "        ax2.plot(x, y_right, linestyle=\"--\", marker=\"s\", label=y_right_label)\n",
        "        ax2.set_ylabel(y_right_label)\n",
        "        lines, labels = [], []\n",
        "        for ax in (ax1, ax2):\n",
        "            L = ax.get_lines(); lines += L; labels += [l.get_label() for l in L]\n",
        "        ax1.legend(lines, labels, loc=\"best\", frameon=False)\n",
        "        if x_ticks is not None:\n",
        "            ax1.set_xticks(range(len(x_ticks))); ax1.set_xticklabels(x_ticks, rotation=20)\n",
        "        plt.title(title)\n",
        "        return fig\n",
        "\n",
        "    betas_arr  = np.array([r['beta'] for r in results], dtype=float)\n",
        "    acc_arr    = np.array([r['task_acc'] for r in results], dtype=float)\n",
        "    wmacc_arr  = np.array([r['wm_acc'] for r in results], dtype=float)\n",
        "    kappa_arr  = np.array([r['kappa_marg'] for r in results], dtype=float)\n",
        "\n",
        "    tag = f\"{dataset_name}_{backbone}_s{seed}\"\n",
        "\n",
        "    # Fig1: imperceptibility\n",
        "    fig1 = dual_axis_plot(\n",
        "        x=betas_arr + 1e-12, y_left=acc_arr, y_right=wmacc_arr,\n",
        "        x_label=r\"$\\beta_{\\mathrm{wm}}$\", y_left_label=\"Task ACC\", y_right_label=\"WM-ACC\",\n",
        "        title=f\"Imperceptibility ({dataset_name}/{backbone})\", x_is_log=True\n",
        "    )\n",
        "    for ext in (\"pdf\",\"svg\",\"png\"):\n",
        "        fig1.savefig(os.path.join(save_dir, f\"{tag}_fig1_imperceptibility.{ext}\"), bbox_inches=\"tight\")\n",
        "    plt.close(fig1)\n",
        "\n",
        "    # Fig2: robustness + attacks (regular size)\n",
        "    x_items = [r[0] for r in robust_rows] + [r[0] for r in attack_rows]\n",
        "    x_acc   = [r[1] for r in robust_rows] + [r[1] for r in attack_rows]\n",
        "    x_wm    = [r[2] for r in robust_rows] + [r[2] for r in attack_rows]\n",
        "    xpos = np.arange(len(x_items))\n",
        "    fig3 = dual_axis_plot(\n",
        "        x=xpos, y_left=np.array(x_acc), y_right=np.array(x_wm),\n",
        "        x_label=\"Modification\", y_left_label=\"Task ACC\", y_right_label=\"WM-ACC\",\n",
        "        title=f\"Robustness & Attacks ({dataset_name}/{backbone})\", x_is_log=False, x_ticks=x_items\n",
        "    )\n",
        "    for ext in (\"pdf\",\"svg\",\"png\"):\n",
        "        fig3.savefig(os.path.join(save_dir, f\"{tag}_fig2_robust_attacks.{ext}\"), bbox_inches=\"tight\")\n",
        "    plt.close(fig3)\n",
        "\n",
        "    # Save CSVs\n",
        "    imp_df = pd.DataFrame(results)\n",
        "    rb_df  = pd.DataFrame(robust_rows, columns=[\"name\",\"acc\",\"wm_acc\",\"kappa\"])\n",
        "    at_df  = pd.DataFrame(attack_rows, columns=[\"name\",\"acc\",\"wm_acc\",\"kappa\"]) if len(attack_rows)>0 else pd.DataFrame(columns=[\"name\",\"acc\",\"wm_acc\",\"kappa\"])\n",
        "    carr_df = pd.DataFrame(carr_rows, columns=[\"pdel\",\"padd\",\"wm_acc_mean\",\"wm_acc_std\",\"wm_acc_ci95\",\"n_trials\"]) if len(carr_rows)>0 else None\n",
        "\n",
        "    for df, nm in [(imp_df,\"imperceptibility\"),(rb_df,\"robust\"),(at_df,\"attacks\")]:\n",
        "        df.to_csv(os.path.join(save_dir, f\"{tag}_{nm}.csv\"), index=False)\n",
        "    if carr_df is not None:\n",
        "        carr_df.to_csv(os.path.join(save_dir, f\"{tag}_carrier_perturb.csv\"), index=False)\n",
        "\n",
        "    # Short JSON report. chosen_acc / chosen_wm_acc describe chosen_row, i.e. the model\n",
        "    # actually evaluated above (not simply the highest-accuracy beta > 0 row).\n",
        "    rep = {\n",
        "        \"dataset\": dataset_name, \"backbone\": backbone, \"seed\": seed,\n",
        "        \"base_acc\": float(base_acc), \"chosen_beta\": chosen_beta,\n",
        "        \"chosen_acc\": float(chosen_row[\"task_acc\"]), \"chosen_wm_acc\": float(chosen_row[\"wm_acc\"]),\n",
        "        \"uniqueness\": {\"m\": int(m), \"alpha\": owner_alpha, \"tau\": int(tau), \"T_owner\": int(T_owner),\n",
        "                       \"imp_max_T\": int(max_imp_T), \"pass\": bool(uniq_pass)}\n",
        "    }\n",
        "    with open(os.path.join(save_dir, f\"{tag}_report.json\"), \"w\") as f:\n",
        "        json.dump(rep, f, indent=2)\n",
        "\n",
        "    return {\n",
        "        \"dataset\": dataset_name, \"backbone\": backbone, \"seed\": seed,\n",
        "        \"IN_DIM\": IN_DIM, \"OUT_DIM\": OUT_DIM,\n",
        "        \"imperceptibility\": results, \"chosen_beta\": chosen_beta,\n",
        "        \"uniqueness\": {\"m\": m, \"tau\": tau, \"T_owner\": T_owner, \"imp_avg_wa\": avg_imp_wa, \"imp_max_T\": max_imp_T, \"pass\": uniq_pass},\n",
        "        \"robust\": robust_rows, \"attacks\": attack_rows, \"carrier_perturb\": carr_rows\n",
        "    }\n",
        "\n",
        "# ---------- 11) β_min vs m ----------\n",
        "def run_beta_min_vs_m(dataset_name, backbone, seed, m_list, beta_list, owner_alpha, save_dir):\n",
        "    \"\"\"Capacity sweep: smallest beta_wm that reaches the ownership threshold, per key length m.\n",
        "\n",
        "    A fresh WatermarkedGNN is trained for every (m, beta) pair. Per-run metrics go to\n",
        "    <tag>_beta_sweep_per_m.csv and the minimum passing beta per m is plotted to\n",
        "    <tag>_beta_min_vs_m.{pdf,svg,png} under save_dir; m values for which no beta in\n",
        "    beta_list passes are drawn as hollow red markers.\n",
        "\n",
        "    Returns:\n",
        "        DataFrame with columns m, beta, ACC, WM-ACC, kappa, tau, T_est, where T_est is\n",
        "        the estimated number of correctly decoded bits (round(WM-ACC * m)).\n",
        "    \"\"\"\n",
        "    set_seed(seed)\n",
        "    _, train_dataset, test_dataset, IN_DIM, OUT_DIM = load_dataset(dataset_name, seed)\n",
        "    lam_min, lam_scale = estimate_lambda_stats(train_dataset)\n",
        "    g = torch.Generator(); g.manual_seed(seed)\n",
        "    train_loader = DataLoader(train_dataset, batch_size=cfg.BATCH, shuffle=True, num_workers=cfg.NUM_WORKERS, generator=g)\n",
        "    test_loader  = DataLoader(test_dataset,  batch_size=cfg.BATCH, shuffle=False, num_workers=cfg.NUM_WORKERS)\n",
        "\n",
        "    rows = []\n",
        "    extra_kw = _backbone_kwargs(backbone, train_dataset)\n",
        "    for m_bits in m_list:\n",
        "        wm_batch, wm_targets, owner_key = build_carriers(train_dataset, m_bits, cfg.MAX_CARRIER_N_PERCENTILE, lam_min, lam_scale, seed)\n",
        "        tau = tau_threshold(m_bits, owner_alpha)\n",
        "        for beta in beta_list:\n",
        "            model = WatermarkedGNN(IN_DIM, cfg.HIDDEN, OUT_DIM, backbone=backbone,\n",
        "                                   num_layers=cfg.NUM_LAYERS, dropout=cfg.DROPOUT, **extra_kw).to(cfg.DEVICE)\n",
        "            opt = torch.optim.Adam(model.parameters(), lr=cfg.LR, weight_decay=cfg.WD)\n",
        "            for _ in range(cfg.EPOCHS):\n",
        "                train_epoch(model, train_loader, opt, beta, wm_batch, owner_key, cfg.DEVICE)\n",
        "            acc = eval_acc(model, test_loader, cfg.DEVICE)\n",
        "            wa, km, *_ = verify_watermark(model, wm_batch, owner_key, cfg.DEVICE)\n",
        "            # round(), not int(): wa * m can land just below an integer (0.29*100 -> 28.99...)\n",
        "            rows.append([m_bits, beta, acc, wa, km, tau, int(round(wa * m_bits))])\n",
        "        print(f\"[β_min sweep] m={m_bits} done\")\n",
        "\n",
        "    df = pd.DataFrame(rows, columns=[\"m\",\"beta\",\"ACC\",\"WM-ACC\",\"kappa\",\"tau\",\"T_est\"])\n",
        "    Path(save_dir).mkdir(parents=True, exist_ok=True)\n",
        "    tag = f\"{dataset_name}_{backbone}_s{seed}\"\n",
        "    df.to_csv(os.path.join(save_dir, f\"{tag}_beta_sweep_per_m.csv\"), index=False)\n",
        "\n",
        "    # smallest beta (in sorted order) whose run decodes at least tau bits correctly\n",
        "    beta_min_vals = []\n",
        "    for m_bits in m_list:\n",
        "        sub = df[df[\"m\"] == m_bits].sort_values(\"beta\")\n",
        "        # Compare integer bit counts: WM-ACC >= tau/m in floating point can miss T == tau\n",
        "        # (e.g. if WM-ACC comes from a float32 mean). An empty `sub` simply yields NaN.\n",
        "        sat = sub[sub[\"T_est\"] >= sub[\"tau\"]]\n",
        "        beta_min_vals.append(sat[\"beta\"].iloc[0] if len(sat)>0 else np.nan)\n",
        "\n",
        "    # plot\n",
        "    fig, ax = plt.subplots(figsize=(7.2, 4.2), constrained_layout=True)\n",
        "    ax.set_yscale(\"log\")\n",
        "    ms = np.array(m_list, dtype=float); bs = np.array(beta_min_vals, dtype=float)\n",
        "    reached = ~np.isnan(bs)\n",
        "    ax.plot(ms[reached], bs[reached], marker=\"o\", label=r\"min $\\beta_{\\mathrm{wm}}$\")\n",
        "    if (~reached).any():\n",
        "        y_mark = max(1e-7, float(np.min(bs[reached])) if reached.any() else 1e-4)\n",
        "        # scatter needs one y per x: broadcast the marker height explicitly\n",
        "        ax.scatter(ms[~reached], np.full(int((~reached).sum()), y_mark),\n",
        "                   facecolors='none', edgecolors='red', label=\"not reached\")\n",
        "    ax.set_xlabel(\"Watermark bits (m)\")\n",
        "    ax.set_ylabel(r\"Minimum $\\beta_{\\mathrm{wm}}$ to meet $\\tau(\\alpha)$\")\n",
        "    ax.grid(True, linestyle=\"--\", alpha=0.4)\n",
        "    ax.set_title(f\"β_min vs m at α={owner_alpha} ({dataset_name}/{backbone})\")\n",
        "    ax.legend(frameon=False)\n",
        "    for ext in (\"pdf\",\"svg\",\"png\"):\n",
        "        fig.savefig(os.path.join(save_dir, f\"{tag}_beta_min_vs_m.{ext}\"), bbox_inches=\"tight\")\n",
        "    plt.close(fig)\n",
        "    return df\n",
        "\n",
        "# ---------- 12) Main table aggregator ----------\n",
        "def collect_main_table(all_runs, out_dir):\n",
        "    \"\"\"Summarize per-run results into one table and save it as out_dir/Table_Main.csv.\n",
        "\n",
        "    Each item of all_runs is a dict as returned by run_single_experiment. The chosen\n",
        "    beta is the one that run recorded under chosen_beta, so this table agrees with the\n",
        "    per-run report; when it is missing it is re-derived with the same rule\n",
        "    (baseline = the beta == 0 row, else the first sweep row). worst_wm_acc is the\n",
        "    minimum WM-ACC over all robustness + attack rows, Baseline included.\n",
        "\n",
        "    Returns:\n",
        "        The summary DataFrame, one row per run.\n",
        "    \"\"\"\n",
        "    rows = []\n",
        "    for r in all_runs:\n",
        "        ds, bb, sd = r[\"dataset\"], r[\"backbone\"], r[\"seed\"]\n",
        "        imp = r[\"imperceptibility\"]\n",
        "        base = [x for x in imp if x[\"beta\"]==0]\n",
        "        base_acc = base[0][\"task_acc\"] if base else np.nan\n",
        "        betas_pos = [x for x in imp if x[\"beta\"]>0]\n",
        "        # Prefer the beta the run actually used, so the table cannot drift from the run.\n",
        "        chosen = None\n",
        "        if r.get(\"chosen_beta\") is not None:\n",
        "            chosen = next((x for x in betas_pos if float(x[\"beta\"]) == float(r[\"chosen_beta\"])), None)\n",
        "        if chosen is None:\n",
        "            ref_acc = base[0][\"task_acc\"] if base else imp[0][\"task_acc\"]\n",
        "            chosen = max([x for x in betas_pos if x[\"task_acc\"] >= ref_acc - cfg.ALLOW_ACC_DROP] or betas_pos, key=lambda x: x[\"wm_acc\"])\n",
        "        chosen_beta = float(chosen[\"beta\"]); chosen_acc = float(chosen[\"task_acc\"]); chosen_wm = float(chosen[\"wm_acc\"])\n",
        "        uniq = r[\"uniqueness\"]\n",
        "        uniq_pass = uniq.get(\"pass\", False)\n",
        "        # worst wm-acc over robustness checks and attacks\n",
        "        all_atk = r[\"robust\"] + r[\"attacks\"]\n",
        "        worst_wm = min([x[2] for x in all_atk]) if all_atk else np.nan\n",
        "        rows.append([ds, bb, sd, base_acc, chosen_beta, chosen_acc, chosen_wm, int(uniq[\"tau\"]), int(uniq[\"T_owner\"]), int(uniq[\"imp_max_T\"]), bool(uniq_pass), float(worst_wm)])\n",
        "    df = pd.DataFrame(rows, columns=[\"dataset\",\"backbone\",\"seed\",\"base_acc\",\"chosen_beta\",\"chosen_acc\",\"chosen_wm_acc\",\"tau\",\"T_owner\",\"imp_max_T\",\"uniq_pass\",\"worst_wm_acc\"])\n",
        "    Path(out_dir).mkdir(parents=True, exist_ok=True)\n",
        "    df.to_csv(os.path.join(out_dir, \"Table_Main.csv\"), index=False)\n",
        "    print(\"Saved:\", os.path.join(out_dir, \"Table_Main.csv\"))\n",
        "    return df\n",
        "\n",
        "# ---------- 13) Run all ----------\n",
        "def run_all_and_make_outputs():\n",
        "    \"\"\"Run every configured experiment and write all outputs under cfg.SAVE_DIR.\n",
        "\n",
        "    For each dataset x backbone: one run_single_experiment per seed in cfg.SEEDS\n",
        "    (saved to <SAVE_DIR>/<ds>_<bb>_s<seed>), then, if cfg.RUN_CAPACITY_SWEEP is set,\n",
        "    a single beta_min-vs-m sweep on the first seed only, to save time. Finally the\n",
        "    aggregated Table_Main.csv is written via collect_main_table.\n",
        "    \"\"\"\n",
        "    Path(cfg.SAVE_DIR).mkdir(parents=True, exist_ok=True)\n",
        "    collected = []\n",
        "    for ds in cfg.DATASETS:\n",
        "        for bb in cfg.BACKBONES:\n",
        "            for sd in cfg.SEEDS:\n",
        "                print(f\"\\n== Running {ds} / {bb} / seed={sd} ==\")\n",
        "                collected.append(run_single_experiment(\n",
        "                    dataset_name=ds, backbone=bb, seed=sd,\n",
        "                    betas=cfg.BETA_GRID, m_bits=cfg.WM_BITS,\n",
        "                    owner_alpha=cfg.OWNER_ALPHA,\n",
        "                    save_dir=os.path.join(cfg.SAVE_DIR, f\"{ds}_{bb}_s{sd}\"),\n",
        "                ))\n",
        "\n",
        "            # capacity sweep: first seed only (saves time)\n",
        "            if not (cfg.RUN_CAPACITY_SWEEP and len(cfg.SEEDS) > 0):\n",
        "                continue\n",
        "            run_beta_min_vs_m(\n",
        "                dataset_name=ds, backbone=bb, seed=cfg.SEEDS[0],\n",
        "                m_list=cfg.M_LIST, beta_list=cfg.CAP_BETA_GRID,\n",
        "                owner_alpha=cfg.OWNER_ALPHA,\n",
        "                save_dir=os.path.join(cfg.SAVE_DIR, f\"{ds}_{bb}_capacity\"),\n",
        "            )\n",
        "\n",
        "    # Main table\n",
        "    collect_main_table(collected, cfg.SAVE_DIR)\n",
        "    print(\"\\nAll outputs under:\", os.path.abspath(cfg.SAVE_DIR))\n",
        "\n",
        "# ---------- 14) Go ----------\n",
        "# Launch the whole pipeline when this cell/script is executed directly\n",
        "# (a notebook kernel also runs cells with __name__ == \"__main__\").\n",
        "if __name__ == \"__main__\":\n",
        "    run_all_and_make_outputs()\n"
      ],
      "metadata": {
        "id": "yNb1YcK-hSMv"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}