{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f5b14808",
   "metadata": {},
   "source": [
    "# DFN"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2efb5921",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce8269a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import annotations\n",
    "\n",
    "import cppimport\n",
    "\n",
    "import sys, time, math, io, contextlib\n",
    "from pathlib import Path\n",
    "from typing import Optional, List, Tuple, Dict, Any\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.optimize import linear_sum_assignment\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from IPython.display import display\n",
    "\n",
    "import gurobipy as gp\n",
    "from gurobipy import GRB\n",
    "\n",
    "_TOL = 1e-9"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "417583e0",
   "metadata": {},
   "source": [
    "## LEMON"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62767112",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "repo = Path().resolve().parent\n",
    "if str(repo) not in sys.path:\n",
    "    sys.path.insert(0, str(repo))\n",
    "\n",
    "lemon_mcf = cppimport.imp(\"lemon_mcf\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15e6bd1b",
   "metadata": {},
   "source": [
    "### Quick sanity checks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88381d28",
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 3\n",
    "src    = np.array([0, 0, 1], dtype=np.int64)\n",
    "dst    = np.array([1, 2, 2], dtype=np.int64)\n",
    "cost   = np.array([2.0, 1.0, 3.0], dtype=np.float64)\n",
    "cap    = np.array([5.0, 2.0, 4.0], dtype=np.float64)\n",
    "supply = np.array([3.0, -1.0, -2.0], dtype=np.float64)\n",
    "\n",
    "out_min_cost_flow = lemon_mcf.solve_mcf(n, src, dst, cost, cap, supply)\n",
    "print(out_min_cost_flow)\n",
    "\n",
    "n = 4\n",
    "src = np.array([0,0,1,2], dtype=np.int64)\n",
    "dst = np.array([1,2,3,3], dtype=np.int64)\n",
    "cap = np.array([3.0,2.0,2.0,4.0], dtype=np.float64)\n",
    "out_max_flow = lemon_mcf.max_flow(n, src, dst, cap, 0, 3)\n",
    "print(out_max_flow)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75235271",
   "metadata": {},
   "source": [
    "## Dataset generators"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67763cc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def check_convex_extendible(X, y, tol=1e-8, topk=5):\n",
    "    X = np.asarray(X, float)\n",
    "    y = np.asarray(y, float).reshape(-1)\n",
    "    n, d = X.shape\n",
    "\n",
    "    _, inv = np.unique(X, axis=0, return_inverse=True)\n",
    "    for g in range(inv.max() + 1):\n",
    "        idx = np.where(inv == g)[0]\n",
    "        if len(idx) > 1 and not np.all(y[idx] == y[idx][0]):\n",
    "            ys = y[idx]\n",
    "            return {\"convex_extendible\": False, \"reason\": \"duplicate X, conflicting y\",\n",
    "                    \"conflict\": (idx.tolist(), ys.tolist(), float(ys.max() - ys.min()))}\n",
    "\n",
    "    D = X[None, :, :] - X[:, None, :]      # x^j - x^i\n",
    "    R = y[None, :] - y[:, None]            # y^j - y^i\n",
    "\n",
    "    m = gp.Model()\n",
    "    m.Params.OutputFlag = 0\n",
    "    g = m.addVars(n, d, lb=-GRB.INFINITY)\n",
    "    for i in range(n):\n",
    "        for j in range(n):\n",
    "            if i != j:\n",
    "                m.addConstr(gp.quicksum(g[i,k] * D[i,j,k] for k in range(d)) <= R[i,j])\n",
    "    m.setObjective(0.0)\n",
    "    m.optimize()\n",
    "    if m.Status != GRB.OPTIMAL:\n",
    "        return {\"convex_extendible\": False, \"reason\": \"LP infeasible/failed\", \"status\": int(m.Status)}\n",
    "\n",
    "    G = np.array([[g[i,k].X for k in range(d)] for i in range(n)])\n",
    "    V = np.einsum(\"ijd,id->ij\", D, G) - R  # lhs - rhs\n",
    "    max_v = float(V.max())\n",
    "    out = {\"convex_extendible\": (max_v <= tol), \"max_violation\": max_v, \"tol\": float(tol)}\n",
    "    if topk:\n",
    "        flat = np.argsort(V.ravel())[::-1]\n",
    "        worst = []\n",
    "        for p in flat:\n",
    "            v = float(V.ravel()[p])\n",
    "            if v <= tol: break\n",
    "            worst.append(((int(p // n), int(p % n)), v))\n",
    "            if len(worst) >= topk: break\n",
    "        if worst: out[\"worst_pairs\"] = worst\n",
    "    return out\n",
    "\n",
    "def make_mdvsp_dataset(\n",
    "    K: int,\n",
    "    filename: str,\n",
    "    x_min,\n",
    "    x_max,\n",
    "    noise_std: float = 0.0,\n",
    "    seed: int = 0,\n",
    "    max_trips=None,\n",
    "    max_succ=None,\n",
    "):\n",
    "    \"\"\"Multiple-Depot Vehicle Scheduling (MDVSP) dataset.\n",
    "\n",
    "    Returns:\n",
    "      X: (K, m) integer-ish capacities (float32)\n",
    "      y: (K,) min-cost-flow objective values (float32)\n",
    "      gt: dict containing the fixed network pieces (for later evaluation)\n",
    "    \"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "\n",
    "    with open(filename) as f:\n",
    "        m, n, l = map(int, f.readline().split())\n",
    "        f.readline()  # blank line\n",
    "        trips = np.loadtxt(f, max_rows=n, dtype=np.int64)[:max_trips]\n",
    "        D = np.loadtxt(f, max_rows=l, dtype=np.int64)\n",
    "\n",
    "    p, s, q, e = trips.T\n",
    "    ntr = len(trips)\n",
    "\n",
    "    # node ids\n",
    "    SS = 0\n",
    "    depS = 1 + np.arange(m)\n",
    "    depT = 1 + m + np.arange(m)\n",
    "    trS  = 1 + 2*m + np.arange(ntr)\n",
    "    trT  = 1 + 2*m + ntr + np.arange(ntr)\n",
    "    TT = 1 + 2*m + 2*ntr\n",
    "    N = TT + 1\n",
    "\n",
    "    src, dst, cost, cap = [], [], [], []\n",
    "\n",
    "    for d in range(m):\n",
    "        src += [SS, int(depT[d])]\n",
    "        dst += [int(depS[d]), TT]\n",
    "        cost += [0.0, 0.0]\n",
    "        cap  += [0.0, 0.0]\n",
    "    idxSS = np.arange(0, 2*m, 2)  # arcs SS->depS\n",
    "    idxTT = np.arange(1, 2*m, 2)  # arcs depT->TT\n",
    "\n",
    "    # trip arcs (always cap=1)\n",
    "    src += trS.tolist()\n",
    "    dst += trT.tolist()\n",
    "    cost += [0.0] * ntr\n",
    "    cap  += [1.0] * ntr\n",
    "\n",
    "    # depot-to-trip and trip-to-depot arcs\n",
    "    for d in range(m):\n",
    "        # depS -> trS\n",
    "        src += [int(depS[d])] * ntr\n",
    "        dst += trS.tolist()\n",
    "        cost += (5000 + 10 * D[d, p]).astype(float).tolist()\n",
    "        cap  += [1.0] * ntr\n",
    "\n",
    "        # trT -> depT\n",
    "        src += trT.tolist()\n",
    "        dst += [int(depT[d])] * ntr\n",
    "        cost += (5000 + 10 * D[q, d]).astype(float).tolist()\n",
    "        cap  += [1.0] * ntr\n",
    "\n",
    "    # feasible trip successor arcs (trT -> next trS)\n",
    "    order = np.argsort(s)\n",
    "    p2, s2 = p[order], s[order]\n",
    "    max_succ_eff = ntr if max_succ is None else int(max_succ)\n",
    "\n",
    "    for i in range(ntr):\n",
    "        travel = D[q[i], p2]                          # time from trip i end depot -> next trip start depot\n",
    "        feas = np.flatnonzero(s2 >= e[i] + travel)[:max_succ_eff]\n",
    "        j = order[feas]\n",
    "        if j.size:\n",
    "            src += [int(trT[i])] * j.size\n",
    "            dst += trS[j].tolist()\n",
    "            cost += (8 * travel[feas] + 2 * (s[j] - e[i])).astype(float).tolist()\n",
    "            cap  += [1.0] * j.size\n",
    "\n",
    "    src = np.asarray(src, dtype=np.int64)\n",
    "    dst = np.asarray(dst, dtype=np.int64)\n",
    "    cost = np.asarray(cost, dtype=np.float64)\n",
    "    cap0 = np.asarray(cap, dtype=np.float64)\n",
    "\n",
    "    X = rng.integers(x_min, np.asarray(x_max) + 1, size=(K, m)).astype(np.float64)\n",
    "    y = np.empty(K, dtype=np.float64)\n",
    "\n",
    "    for k in range(K):\n",
    "        cap_k = cap0.copy()\n",
    "        cap_k[idxSS] = X[k]\n",
    "        cap_k[idxTT] = X[k]\n",
    "\n",
    "        Fmax = lemon_mcf.max_flow(N, src, dst, cap_k, SS, TT)[\"value\"]\n",
    "        supply = np.zeros(N, dtype=np.float64)\n",
    "        supply[SS] = Fmax\n",
    "        supply[TT] = -Fmax\n",
    "\n",
    "        y[k] = lemon_mcf.solve_mcf(N, src, dst, cost, cap_k, supply)[\"total_cost\"]\n",
    "        if noise_std:\n",
    "            y[k] += noise_std * rng.normal()\n",
    "\n",
    "    gt = dict(\n",
    "        type=\"mdvsp\", N=int(N), SS=int(SS), TT=int(TT),\n",
    "        src=src, dst=dst, cost=cost, cap0=cap0, idxSS=idxSS, idxTT=idxTT\n",
    "    )\n",
    "\n",
    "    print(check_convex_extendible(X.astype(np.float32), y.astype(np.float32)))\n",
    "    \n",
    "    return X.astype(np.float32), y.astype(np.float32), gt\n",
    "\n",
    "\n",
    "def generate_bmatching_dataset(\n",
    "    K: int,\n",
    "    num_nodes: int,\n",
    "    c_min: int,\n",
    "    c_max: int,\n",
    "    x_max: int,\n",
    "    noise_std: float = 0.0,\n",
    "    seed: int = 0,\n",
    "):\n",
    "    rng = np.random.default_rng(seed)\n",
    "    C = rng.integers(c_min, c_max + 1, size=(num_nodes, num_nodes)).astype(np.float32)\n",
    "\n",
    "    X = rng.integers(0, x_max + 1, size=(K, num_nodes), dtype=np.int32)\n",
    "    y = np.zeros((K,), dtype=np.float32)\n",
    "\n",
    "    col_map = np.repeat(np.arange(num_nodes), x_max)\n",
    "    R = col_map.size\n",
    "    C_cols = C[:, col_map]  # shape: (num_nodes, R)\n",
    "\n",
    "    for k in range(K):\n",
    "        x = X[k]\n",
    "        while x.sum() == 0:\n",
    "            x = rng.integers(0, x_max + 1, size=num_nodes, dtype=np.int32)\n",
    "            X[k] = x\n",
    "\n",
    "        row_idx = np.repeat(np.arange(num_nodes), x)\n",
    "        m = row_idx.size\n",
    "\n",
    "        cost_rect = C_cols[row_idx, :]\n",
    "\n",
    "        if m < R:\n",
    "            pad = np.zeros((R - m, R), dtype=np.float32)\n",
    "            cost_sq = np.vstack([cost_rect, pad])\n",
    "        else:\n",
    "            cost_sq = cost_rect\n",
    "\n",
    "        r, c = linear_sum_assignment(cost_sq)\n",
    "\n",
    "        real = r < m\n",
    "        y[k] = cost_sq[r[real], c[real]].sum()\n",
    "\n",
    "        if noise_std:\n",
    "            y[k] += noise_std * rng.normal()\n",
    "\n",
    "    gt = {\"type\": \"assignment\", \"C\": C.astype(np.float32), \"x_max\": int(x_max)}\n",
    "    return X.astype(np.float32), y, gt\n",
    "\n",
    "\n",
    "def generate_convex_quadratic_dataset(\n",
    "    K: int,\n",
    "    dim: int,\n",
    "    eigen_min: float,\n",
    "    eigen_max: float,\n",
    "    x_min,\n",
    "    x_max,\n",
    "    noise_std: float = 0.0,\n",
    "    seed: int = 0,\n",
    "    x_star_zero: bool = False,\n",
    "):\n",
    "    \"\"\"y = (x - x*)^T Q (x - x*) + noise, with Q symmetric PSD.\"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "    U, R = np.linalg.qr(rng.standard_normal((dim, dim)))\n",
    "    U *= np.sign(np.diag(R) + 1e-12)\n",
    "    Q = U @ np.diag(rng.uniform(eigen_min, eigen_max, dim)) @ U.T\n",
    "\n",
    "    x_star = np.zeros(dim, dtype=np.int64) if x_star_zero else rng.integers(x_min, x_max + 1, size=dim)\n",
    "\n",
    "    X = rng.integers(x_min, x_max + 1, size=(K, dim)).astype(np.float32)\n",
    "    d = X - x_star\n",
    "    y = np.einsum(\"bi,ij,bj->b\", d, Q, d) + noise_std * rng.normal(size=K)\n",
    "\n",
    "    gt = {\"type\": \"quadratic\", \"Q\": Q.astype(np.float32), \"x_star\": x_star.astype(np.int64)}\n",
    "    return X.astype(np.float32), y.astype(np.float32), gt\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70071f5f",
   "metadata": {},
   "source": [
    "## Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07278db8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _ste_round(x: torch.Tensor) -> torch.Tensor:\n",
    "    return x + (torch.round(x) - x).detach()\n",
    "\n",
    "class _MCFValue(torch.autograd.Function):\n",
    "    \n",
    "    @staticmethod\n",
    "    def forward(ctx, n_nodes, src, dst, cost, cap, supply):\n",
    "        n = int(n_nodes)\n",
    "\n",
    "        src = src.to(dtype=torch.int64).contiguous()\n",
    "        dst = dst.to(dtype=torch.int64).contiguous()\n",
    "\n",
    "        m = int(src.numel())\n",
    "        if (dst.numel() != m) or (cost.numel() != m) or (cap.numel() != m) or (supply.numel() != n):\n",
    "            raise ValueError(\"Bad shapes for MCF inputs.\")\n",
    "        if torch.abs(supply.double().sum()) > _TOL:\n",
    "            raise ValueError(\"Require sum(supply)=0\")\n",
    "\n",
    "        def as_np(t: torch.Tensor, dtype):\n",
    "            return t.detach().cpu().contiguous().view(-1).numpy().astype(dtype, copy=False)\n",
    "\n",
    "        out = lemon_mcf.solve_mcf(\n",
    "            n,\n",
    "            as_np(src, np.int64),\n",
    "            as_np(dst, np.int64),\n",
    "            as_np(cost, np.float64),\n",
    "            as_np(cap, np.float64),\n",
    "            as_np(supply, np.float64),\n",
    "            tol=_TOL,\n",
    "        )\n",
    "        if out[\"status\"] != 1:\n",
    "            raise RuntimeError(f\"LEMON failed (status={out['status']})\")\n",
    "\n",
    "        flow = out[\"flow\"]\n",
    "        pot = out[\"potential\"]\n",
    "        red = out[\"reduced_cost\"]\n",
    "        at = out.get(\"at_cap\", out.get(\"at_capacity\", None))\n",
    "        if at is None:\n",
    "            at = np.abs(flow - as_np(cap, np.float64)) <= _TOL\n",
    "\n",
    "        ctx.flow, ctx.pot, ctx.red, ctx.at = flow, pot, red, at\n",
    "        return cost.new_tensor(float(out[\"total_cost\"]))\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, g):\n",
    "        dev, dt = g.device, g.dtype\n",
    "        flow = torch.as_tensor(ctx.flow, device=dev, dtype=dt)\n",
    "        pot  = torch.as_tensor(ctx.pot,  device=dev, dtype=dt)\n",
    "        red  = torch.as_tensor(ctx.red,  device=dev, dtype=dt)\n",
    "        at   = torch.as_tensor(ctx.at,   device=dev, dtype=torch.bool)\n",
    "\n",
    "        grad_cost = flow\n",
    "        grad_cap  = torch.where(at, red, torch.zeros_like(red))\n",
    "        grad_sup  = pot.mean() - pot\n",
    "\n",
    "        return None, None, None, grad_cost * g, grad_cap * g, grad_sup * g\n",
    "\n",
    "\n",
    "class DFN(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        input_dim: int,\n",
    "        layer_sizes,\n",
    "        p_list,\n",
    "        big_cost: float = 1e6,\n",
    "        big_cap: float = 1e6,\n",
    "        seed: int = 0,\n",
    "        A_fixed=None,\n",
    "        alpha: float = 1e-6,\n",
    "        beta: float = -0.0,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.alpha = float(alpha)\n",
    "        self.beta  = float(beta)\n",
    "\n",
    "        layer_sizes = list(map(int, layer_sizes))\n",
    "        if len(layer_sizes) < 2 or len(p_list) != len(layer_sizes) - 1:\n",
    "            raise ValueError(\"Need len(layer_sizes)>=2 and len(p_list)=len(layer_sizes)-1\")\n",
    "\n",
    "        self.n = int(sum(layer_sizes))\n",
    "        if self.n <= 0:\n",
    "            raise ValueError(\"sum(layer_sizes) must be > 0\")\n",
    "\n",
    "        # node indices per layer\n",
    "        layers, off = [], 0\n",
    "        for s in layer_sizes:\n",
    "            layers.append(torch.arange(off, off + s, dtype=torch.long))\n",
    "            off += s\n",
    "\n",
    "        L1, LK = layers[0], layers[-1]\n",
    "        if L1.numel() == 0 or LK.numel() == 0:\n",
    "            raise ValueError(\"First/last layer must be non-empty.\")\n",
    "\n",
    "        self.fix_node = int(LK[-1].item())\n",
    "        boundary = torch.cat([L1, LK[:-1]], 0)\n",
    "        self.register_buffer(\"boundary\", boundary)\n",
    "\n",
    "        gen = torch.Generator().manual_seed(int(seed))\n",
    "\n",
    "        def bipartite(U: torch.Tensor, V: torch.Tensor):\n",
    "            su, sv = int(U.numel()), int(V.numel())\n",
    "            return U.repeat_interleave(sv), V.repeat(su)\n",
    "\n",
    "        def sample_edges(U, V, p: float):\n",
    "            s, t = bipartite(U, V)\n",
    "            if p < 1.0:\n",
    "                keep = torch.rand(s.numel(), generator=gen) < float(p)\n",
    "                s, t = s[keep], t[keep]\n",
    "            return s, t\n",
    "\n",
    "        # learnable arcs between consecutive layers (both directions)\n",
    "        sf, tf, sb, tb = [], [], [], []\n",
    "        for i, p in enumerate(map(float, p_list)):\n",
    "            if not (0.0 <= p <= 1.0):\n",
    "                raise ValueError(\"p_list entries must be in [0,1]\")\n",
    "\n",
    "            s, t = sample_edges(layers[i], layers[i + 1], p)  # forward\n",
    "            sf.append(s); tf.append(t)\n",
    "\n",
    "            s, t = sample_edges(layers[i + 1], layers[i], p)  # backward\n",
    "            sb.append(s); tb.append(t)\n",
    "\n",
    "        src_param = torch.cat([torch.cat(sf, 0), torch.cat(sb, 0)], 0)\n",
    "        dst_param = torch.cat([torch.cat(tf, 0), torch.cat(tb, 0)], 0)\n",
    "        if src_param.numel() == 0:\n",
    "            raise ValueError(\"No learnable arcs (increase p_list / layer sizes).\")\n",
    "\n",
    "        s1, t1 = bipartite(L1, LK)\n",
    "        s2, t2 = bipartite(LK, L1)\n",
    "        src_fixed = torch.cat([s1, s2], 0)\n",
    "        dst_fixed = torch.cat([t1, t2], 0)\n",
    "        m_fixed = int(src_fixed.numel())\n",
    "\n",
    "        self.register_buffer(\"src\", torch.cat([src_param, src_fixed], 0))\n",
    "        self.register_buffer(\"dst\", torch.cat([dst_param, dst_fixed], 0))\n",
    "        self.register_buffer(\"cap_fixed\",  torch.full((m_fixed,), float(big_cap),  dtype=torch.float32))\n",
    "        self.register_buffer(\"cost_fixed\", torch.full((m_fixed,), float(big_cost), dtype=torch.float32))\n",
    "\n",
    "        nb = int(boundary.numel())\n",
    "        input_dim = int(input_dim)\n",
    "\n",
    "        m_param = int(src_param.numel())\n",
    "        self.cap_raw  = nn.Parameter(torch.zeros(m_param) + 0.542)\n",
    "        self.cost_raw = nn.Parameter(torch.randn(m_param) + 1.0)\n",
    "        self.b_raw    = nn.Parameter(torch.zeros(nb))\n",
    "\n",
    "        if A_fixed is None:\n",
    "            A = torch.zeros(nb, input_dim)\n",
    "            rows = torch.arange(nb)\n",
    "            A[rows, rows % input_dim] = 1.0\n",
    "            self.A = nn.Parameter(A)\n",
    "        else:\n",
    "            A_fixed = torch.as_tensor(A_fixed, dtype=torch.float32)\n",
    "            if A_fixed.shape != (nb, input_dim):\n",
    "                raise ValueError(f\"A_fixed must have shape {(nb, input_dim)}, got {tuple(A_fixed.shape)}\")\n",
    "            self.register_buffer(\"A\", A_fixed)\n",
    "\n",
    "    def forward(self, w: torch.Tensor) -> torch.Tensor:\n",
    "        capP  = _ste_round(F.softplus(self.cap_raw))\n",
    "        costP = self.cost_raw\n",
    "        b     = _ste_round(self.b_raw)\n",
    "        A     = _ste_round(self.A) if isinstance(self.A, nn.Parameter) else self.A\n",
    "\n",
    "        cap  = torch.cat([capP,  self.cap_fixed.to(w.device, w.dtype)], 0)\n",
    "        cost = torch.cat([costP, self.cost_fixed.to(w.device, w.dtype)], 0)\n",
    "\n",
    "        def one(w1: torch.Tensor) -> torch.Tensor:\n",
    "            supply = torch.zeros(self.n, device=w1.device, dtype=torch.float64)\n",
    "            supply[self.boundary] = (A.double() @ w1.double()) + b.double()\n",
    "            supply[self.fix_node] = -supply.sum()\n",
    "            return _MCFValue.apply(self.n, self.src, self.dst, cost, cap, supply)\n",
    "\n",
    "        out = one(w) if w.dim() == 1 else torch.stack([one(wi) for wi in w], 0)\n",
    "        return self.alpha * out + self.beta\n",
    "\n",
    "\n",
    "class MLP(nn.Module):\n",
    "    def __init__(self, in_dim, hidden_dims, out_dim):\n",
    "        super().__init__()\n",
    "        dims = [in_dim] + list(hidden_dims) + [out_dim]\n",
    "        layers = []\n",
    "        for a, b in zip(dims[:-2], dims[1:-1]):\n",
    "            layers += [nn.Linear(a, b), nn.ReLU()]\n",
    "        layers += [nn.Linear(dims[-2], dims[-1])]\n",
    "        self.net = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.net(x)\n",
    "\n",
    "\n",
    "class LSET(nn.Module):\n",
    "    def __init__(self, in_dim: int, n_pieces: int, T: float = 0.01):\n",
    "        super().__init__()\n",
    "        self.T = float(T)\n",
    "        if self.T == 0.0:\n",
    "            raise ValueError(\"T must be nonzero\")\n",
    "        self.A = nn.Parameter(torch.randn(n_pieces, in_dim) / (in_dim**0.5))\n",
    "        self.b = nn.Parameter(torch.zeros(n_pieces))\n",
    "\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        z = (x @ self.A.t() + self.b) / self.T\n",
    "        return self.T * torch.logsumexp(z, dim=-1)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7cca0120",
   "metadata": {},
   "source": [
    "## Training Helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a7db872",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _split_train_val_test(X: torch.Tensor, y: torch.Tensor, *, val_frac: float, test_frac: float, seed: int):\n",
    "    N = int(X.shape[0])\n",
    "    n_test = int(round(test_frac * N))\n",
    "    n_val  = int(round(val_frac  * N))\n",
    "    n_train = N - n_val - n_test\n",
    "    if n_train <= 0:\n",
    "        raise ValueError(\"splits too large; train set would be empty\")\n",
    "\n",
    "    g = torch.Generator().manual_seed(int(seed))\n",
    "    perm = torch.randperm(N, generator=g)\n",
    "    i_tr = perm[:n_train]\n",
    "    i_va = perm[n_train:n_train + n_val]\n",
    "    i_te = perm[n_train + n_val:]\n",
    "\n",
    "    return (X[i_tr], y[i_tr]), (X[i_va], y[i_va]), (X[i_te], y[i_te])\n",
    "\n",
    "\n",
    "def _fit_standardizer(Xtr: torch.Tensor, ytr: torch.Tensor, *, eps: float):\n",
    "    x_mean = Xtr.mean(0, keepdim=True)\n",
    "    x_std  = Xtr.std(0, unbiased=False, keepdim=True)\n",
    "    x_std  = torch.where(x_std < eps, torch.ones_like(x_std), x_std)\n",
    "\n",
    "    y_mean = ytr.mean()\n",
    "    y_std  = ytr.std(unbiased=False).clamp_min(eps)\n",
    "\n",
    "    return {\"x_mean\": x_mean, \"x_std\": x_std, \"y_mean\": y_mean, \"y_std\": y_std}\n",
    "\n",
    "\n",
    "def _apply_standardizer(X: torch.Tensor, y: torch.Tensor, scaler):\n",
    "    Xn = (X - scaler[\"x_mean\"]) / scaler[\"x_std\"]\n",
    "    yn = (y - scaler[\"y_mean\"]) / scaler[\"y_std\"]\n",
    "    return Xn, yn\n",
    "\n",
    "@torch.no_grad()\n",
    "def _mse_norm(model: nn.Module, loader: DataLoader, device: str):\n",
    "    model.eval()\n",
    "    tot, n = 0.0, 0\n",
    "    for xb, yb in loader:\n",
    "        xb, yb = xb.to(device), yb.to(device)\n",
    "        pred = model(xb).squeeze(-1)\n",
    "        tot += F.mse_loss(pred, yb, reduction=\"sum\").item()\n",
    "        n += yb.numel()\n",
    "    return tot / max(n, 1)\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def _predict_in_chunks(model: nn.Module, X: torch.Tensor, *, chunk: int):\n",
    "    out = []\n",
    "    for i in range(0, int(X.shape[0]), int(chunk)):\n",
    "        out.append(model(X[i:i + chunk]).squeeze(-1))\n",
    "    return torch.cat(out, 0)\n",
    "\n",
    "\n",
    "import json, hashlib\n",
    "\n",
    "# ---- Caching / persistence helpers ----\n",
    "_CACHE_ROOT_DEFAULT = Path(\"saved_runs\")\n",
    "\n",
    "def _to_hashable(x):\n",
    "    \"\"\"Convert nested objects to a deterministic, hashable (and mostly JSON-friendly) form.\"\"\"\n",
    "    import numpy as _np\n",
    "    import torch as _torch\n",
    "    from pathlib import Path as _Path\n",
    "\n",
    "    if x is None or isinstance(x, (bool, int, str)):\n",
    "        return x\n",
    "    if isinstance(x, float):\n",
    "        return float(x)\n",
    "    if isinstance(x, _Path):\n",
    "        return str(x)\n",
    "    if isinstance(x, (list, tuple)):\n",
    "        return [_to_hashable(v) for v in x]\n",
    "    if isinstance(x, dict):\n",
    "        return {str(k): _to_hashable(x[k]) for k in sorted(x.keys(), key=lambda z: str(z))}\n",
    "    if isinstance(x, _np.ndarray):\n",
    "        h = hashlib.sha256(x.tobytes(order=\"C\")).hexdigest()\n",
    "        return {\"__ndarray__\": True, \"dtype\": str(x.dtype), \"shape\": list(x.shape), \"sha256\": h}\n",
    "    if isinstance(x, _torch.Tensor):\n",
    "        t = x.detach().cpu()\n",
    "        h = hashlib.sha256(t.numpy().tobytes(order=\"C\")).hexdigest()\n",
    "        return {\"__tensor__\": True, \"dtype\": str(t.dtype), \"shape\": list(t.shape), \"sha256\": h}\n",
    "    return {\"__repr__\": repr(x)}\n",
    "\n",
    "def _run_signature(dataset_type, dataset_params, model_type, model_params, train_sig):\n",
    "    return {\n",
    "        \"dataset_type\": str(dataset_type).lower(),\n",
    "        \"dataset_params\": _to_hashable(dict(dataset_params)),\n",
    "        \"model_type\": str(model_type),\n",
    "        \"model_params\": _to_hashable(dict(model_params)),\n",
    "        \"train_params\": _to_hashable(dict(train_sig)),\n",
    "    }\n",
    "\n",
    "def _run_id_from_signature(sig) -> str:\n",
    "    blob = json.dumps(sig, sort_keys=True, separators=(\",\", \":\"), ensure_ascii=True).encode(\"utf-8\")\n",
    "    return hashlib.sha256(blob).hexdigest()[:16]\n",
    "\n",
    "def _get_run_dir(cache_root: Path, dataset_type, model_type, run_id: str) -> Path:\n",
    "    return Path(cache_root) / str(dataset_type).lower() / str(model_type) / run_id\n",
    "\n",
    "def _cpuify(obj):\n",
    "    import torch as _torch\n",
    "    if isinstance(obj, _torch.Tensor):\n",
    "        return obj.detach().cpu()\n",
    "    if isinstance(obj, dict):\n",
    "        return {k: _cpuify(v) for k, v in obj.items()}\n",
    "    if isinstance(obj, (list, tuple)):\n",
    "        return [_cpuify(v) for v in obj]\n",
    "    return obj\n",
    "\n",
    "def _set_full_determinism(seed: int):\n",
    "    \"\"\"CPU-only determinism: seed controls Python/NumPy/Torch RNGs.\n",
    "    \"\"\"\n",
    "    import os, random\n",
    "    import numpy as np\n",
    "    import torch\n",
    "    seed = int(seed)\n",
    "\n",
    "    os.environ[\"PYTHONHASHSEED\"] = str(seed)\n",
    "\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "\n",
    "    try:\n",
    "        torch.use_deterministic_algorithms(True)\n",
    "    except Exception:\n",
    "        pass\n",
    "\n",
    "def _seed_worker(worker_id: int):\n",
    "    import random\n",
    "    import numpy as np\n",
    "    import torch\n",
    "    worker_seed = torch.initial_seed() % 2**32\n",
    "    np.random.seed(worker_seed)\n",
    "    random.seed(worker_seed)\n",
    "\n",
    "def _try_load_cached_run(run_dir: Path):\n",
    "    artifact_path = Path(run_dir) / \"artifact.pt\"\n",
    "    if not artifact_path.exists():\n",
    "        return None\n",
    "    try:\n",
    "        return torch.load(str(artifact_path), map_location=\"cpu\")\n",
    "    except Exception:\n",
    "        return None\n",
    "\n",
    "def _build_model_from_params(model_type, model_params):\n",
    "    mt = str(model_type)\n",
    "    mp = dict(model_params)\n",
    "    if mt == \"DFN\":\n",
    "        model = DFN(**mp)\n",
    "    else:\n",
    "        mp.pop(\"alpha\", None)\n",
    "        mp.pop(\"beta\", None)\n",
    "        if mt == \"MLP\":\n",
    "            model = MLP(**mp)\n",
    "        elif mt == \"LSET\":\n",
    "            model = LSET(**mp)\n",
    "        else:\n",
    "            raise ValueError(\"model_type must be: DFN | MLP | LSET\")\n",
    "    return model\n",
    "\n",
    "def _save_run(run_dir: Path, signature: dict, model: nn.Module, data: dict, history: dict, spec: dict):\n",
    "    run_dir = Path(run_dir)\n",
    "    run_dir.mkdir(parents=True, exist_ok=True)\n",
    "    artifact_path = run_dir / \"artifact.pt\"\n",
    "    if artifact_path.exists():\n",
    "        return\n",
    "    state_dict = {k: v.detach().cpu() for k, v in model.state_dict().items()}\n",
    "    artifact = {\n",
    "        \"signature\": signature,\n",
    "        \"state_dict\": state_dict,\n",
    "        \"data\": _cpuify(data),\n",
    "        \"history\": history,\n",
    "        \"spec\": spec,\n",
    "    }\n",
    "    torch.save(artifact, str(artifact_path))\n",
    "    try:\n",
    "        with open(run_dir / \"signature.json\", \"w\", encoding=\"utf-8\") as f:\n",
    "            json.dump(signature, f, indent=2, sort_keys=True)\n",
    "    except Exception:\n",
    "        pass\n",
    "\n",
    "def generate_and_train_simple(dataset_type, dataset_params, model_type, model_params, train_params=None):\n",
    "    \n",
    "    tp = train_params or {}\n",
    "\n",
    "    epochs     = int(tp.get(\"epochs\", 200))\n",
    "    batch_sz   = int(tp.get(\"batch_size\", 32))\n",
    "    lr         = float(tp.get(\"lr\", 1e-3))\n",
    "    wd         = float(tp.get(\"weight_decay\", 0.0))\n",
    "    val_frac   = float(tp.get(\"val_frac\", 0.15))\n",
    "    test_frac  = float(tp.get(\"test_frac\", 0.15))\n",
    "    eps        = float(tp.get(\"eps\", 1e-8))\n",
    "    seed       = int(tp.get(\"seed\", 0))\n",
    "    device     = tp.get(\"device\", \"cpu\")\n",
    "    plot_every = int(tp.get(\"plot_every\", 0) or 0)\n",
    "    plot_points= int(tp.get(\"plot_points\", 2048))\n",
    "    plot_chunk = int(tp.get(\"plot_chunk\", 4096))\n",
    "    print_stats= bool(tp.get(\"print_stats\", True))\n",
    "\n",
    "\n",
    "    _set_full_determinism(seed)\n",
    "    dataset_params = dict(dataset_params)\n",
    "    dataset_params.setdefault(\"seed\", int(seed))\n",
    "\n",
    "\n",
    "    train_sig = {\n",
    "        \"epochs\": int(epochs),\n",
    "        \"batch_size\": int(batch_sz),\n",
    "        \"lr\": float(lr),\n",
    "        \"weight_decay\": float(wd),\n",
    "        \"val_frac\": float(val_frac),\n",
    "        \"test_frac\": float(test_frac),\n",
    "        \"eps\": float(eps),\n",
    "        \"seed\": int(seed),\n",
    "        \"deterministic\": True,\n",
    "    }\n",
    "    cache_root = Path(tp.get(\"cache_root\", _CACHE_ROOT_DEFAULT))\n",
    "    signature = _run_signature(dataset_type, dataset_params, model_type, model_params, train_sig)\n",
    "    run_id = _run_id_from_signature(signature)\n",
    "    run_dir = _get_run_dir(cache_root, dataset_type, model_type, run_id)\n",
    "\n",
    "    cached = _try_load_cached_run(run_dir)\n",
    "    if cached is not None and isinstance(cached, dict) and \"state_dict\" in cached and \"data\" in cached and \"history\" in cached and \"spec\" in cached:\n",
    "        model = _build_model_from_params(model_type, model_params)\n",
    "        model.load_state_dict(cached[\"state_dict\"])\n",
    "        model = model.to(device)\n",
    "        return model, cached[\"data\"], cached[\"history\"], cached[\"spec\"]\n",
    "\n",
    "    # ---- data ----\n",
    "    dt = str(dataset_type).lower()\n",
    "    if dt == \"mdvsp\":\n",
    "        X, y, gt = make_mdvsp_dataset(**dataset_params)\n",
    "    elif dt == \"assignment\":\n",
    "        X, y, gt = generate_bmatching_dataset(**dataset_params)\n",
    "    elif dt == \"quadratic\":\n",
    "        X, y, gt = generate_convex_quadratic_dataset(**dataset_params)\n",
    "    else:\n",
    "        raise ValueError(\"dataset_type must be: mdvsp | assignment | quadratic\")\n",
    "\n",
    "    X = torch.as_tensor(X, dtype=torch.float32)\n",
    "    y = torch.as_tensor(y, dtype=torch.float32).view(-1)\n",
    "\n",
    "    # ---- print a quick dataset stats ----\n",
    "    if print_stats:\n",
    "        with torch.no_grad():\n",
    "            xmn = X.mean(0)\n",
    "            xsd = X.std(0, unbiased=False)\n",
    "            print(\n",
    "                f\"\\n--- Dataset stats ({dataset_type}) ---\\n\"\n",
    "                f\"  X: shape={tuple(X.shape)}  mean(mean)={xmn.mean():.3g}  std(mean)={xsd.mean():.3g}  \"\n",
    "                f\"min={float(X.min()):.3g}  max={float(X.max()):.3g}\\n\"\n",
    "                f\"  y: shape={tuple(y.shape)}  mean={float(y.mean()):.3g}  std={float(y.std(unbiased=False)):.3g}  \"\n",
    "                f\"min={float(y.min()):.3g}  max={float(y.max()):.3g}\\n\"\n",
    "            )\n",
    "\n",
    "    (Xtr, ytr), (Xva, yva), (Xte, yte) = _split_train_val_test(X, y, val_frac=val_frac, test_frac=test_frac, seed=seed)\n",
    "\n",
    "    scaler = _fit_standardizer(Xtr, ytr, eps=eps)\n",
    "    XtrN, ytrN = _apply_standardizer(Xtr, ytr, scaler)\n",
    "    XvaN, yvaN = _apply_standardizer(Xva, yva, scaler)\n",
    "    XteN, yteN = _apply_standardizer(Xte, yte, scaler)\n",
    "\n",
    "    g_dl = torch.Generator()\n",
    "    g_dl.manual_seed(seed)\n",
    "    train_loader = DataLoader(\n",
    "        TensorDataset(XtrN, ytrN),\n",
    "        batch_size=batch_sz,\n",
    "        shuffle=True,\n",
    "        generator=g_dl,\n",
    "        worker_init_fn=_seed_worker,\n",
    "    )\n",
    "    val_loader   = DataLoader(TensorDataset(XvaN, yvaN), batch_size=batch_sz, shuffle=False)\n",
    "\n",
    "    # subset for plotting\n",
    "    Nv = int(XvaN.shape[0])\n",
    "    if plot_points <= 0 or plot_points >= Nv:\n",
    "        plot_idx = torch.arange(Nv)\n",
    "    else:\n",
    "        g_plot = torch.Generator().manual_seed(seed + 12345)\n",
    "        plot_idx = torch.randperm(Nv, generator=g_plot)[:plot_points]\n",
    "\n",
    "    # ---- model ----\n",
    "    mt = str(model_type)\n",
    "    mp = dict(model_params)\n",
    "\n",
    "    if mt == \"DFN\":\n",
    "        model = DFN(**mp)\n",
    "        extra = f\"layers={mp.get('layer_sizes')} p_list={mp.get('p_list')} alpha={getattr(model,'alpha',None)} beta={getattr(model,'beta',None)}\"\n",
    "    else:\n",
    "        mp.pop(\"alpha\", None)\n",
    "        mp.pop(\"beta\", None)\n",
    "        if mt == \"MLP\":\n",
    "            model = MLP(**mp)\n",
    "            extra = f\"hidden={mp.get('hidden_dims')}\"\n",
    "        elif mt == \"LSET\":\n",
    "            model = LSET(**mp)\n",
    "            extra = f\"n_pieces={mp.get('n_pieces')} T={mp.get('T')}\"\n",
    "        else:\n",
    "            raise ValueError(\"model_type must be: DFN | MLP | LSET\")\n",
    "\n",
    "    model = model.to(device)\n",
    "    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)\n",
    "\n",
    "    n_params = sum(p.numel() for p in model.parameters())\n",
    "    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "    print(\n",
    "        f\"\\n=== Run: {dataset_type} | {model_type} ===\\n\"\n",
    "        f\"  data: N={len(X)}  train/val/test={len(Xtr)}/{len(Xva)}/{len(Xte)}  dim={X.shape[1]}\\n\"\n",
    "        f\"  model: params={n_params:,} {extra}\\n\"\n",
    "        f\"  train: device={device}  epochs={epochs}  batch={batch_sz}  lr={lr:g}  wd={wd:g}  seed={seed}\\n\"\n",
    "    )\n",
    "\n",
    "    history = {\"train_mse_norm\": [], \"val_mse_norm\": []}\n",
    "    best_val, best_ep, best_state = float(\"inf\"), 0, None\n",
    "\n",
    "    live = display(None, display_id=True) if plot_every > 0 else None\n",
    "\n",
    "    for ep in range(1, epochs + 1):\n",
    "        model.train()\n",
    "        for xb, yb in train_loader:\n",
    "            xb, yb = xb.to(device), yb.to(device)\n",
    "            loss = F.mse_loss(model(xb).squeeze(-1), yb)\n",
    "            opt.zero_grad(set_to_none=True)\n",
    "            loss.backward()\n",
    "            opt.step()\n",
    "\n",
    "        tr_mse = _mse_norm(model, train_loader, device)\n",
    "        va_mse = _mse_norm(model, val_loader, device)\n",
    "        history[\"train_mse_norm\"].append(tr_mse)\n",
    "        history[\"val_mse_norm\"].append(va_mse)\n",
    "\n",
    "        if va_mse < best_val:\n",
    "            best_val, best_ep = va_mse, ep\n",
    "            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}\n",
    "\n",
    "        if plot_every > 0 and (ep == 1 or ep % plot_every == 0 or ep == epochs):\n",
    "            model.eval()\n",
    "            x_plot = XvaN[plot_idx].to(device)\n",
    "            y_true = yvaN[plot_idx].to(device)\n",
    "            y_pred = _predict_in_chunks(model, x_plot, chunk=plot_chunk)\n",
    "\n",
    "            fig, ax = plt.subplots(1, 2, figsize=(10, 4))\n",
    "            ax[0].plot(history[\"train_mse_norm\"], label=\"train\")\n",
    "            ax[0].plot(history[\"val_mse_norm\"], label=\"val\")\n",
    "            ax[0].set_yscale(\"log\")\n",
    "            ax[0].set_title(f\"Epoch {ep}/{epochs} | val MSE={va_mse:.3e} (norm)\")\n",
    "            ax[0].legend()\n",
    "\n",
    "            yt = y_true.detach().cpu().numpy()\n",
    "            yp = y_pred.detach().cpu().numpy()\n",
    "            ax[1].scatter(yt, yp, s=10)\n",
    "            lo = float(min(yt.min(), yp.min()))\n",
    "            hi = float(max(yt.max(), yp.max()))\n",
    "            ax[1].plot([lo, hi], [lo, hi])\n",
    "            ax[1].set_xlabel(\"y_true (norm)\")\n",
    "            ax[1].set_ylabel(\"y_pred (norm)\")\n",
    "            ax[1].set_title(f\"Val scatter (n={len(yt)})\")\n",
    "\n",
    "            plt.tight_layout()\n",
    "            live.update(fig)\n",
    "            plt.close(fig)\n",
    "\n",
    "    if best_state is not None:\n",
    "        model.load_state_dict(best_state)\n",
    "\n",
    "    if print_stats:\n",
    "        print(f\"[DONE] best val MSE (norm) = {best_val:.3e} @ epoch {best_ep}\\n\")\n",
    "\n",
    "    data = {\n",
    "        \"raw\":  {\"Xtr\": Xtr,  \"ytr\": ytr,  \"Xva\": Xva,  \"yva\": yva,  \"Xte\": Xte,  \"yte\": yte},\n",
    "        \"norm\": {\"Xtr\": XtrN, \"ytr\": ytrN, \"Xva\": XvaN, \"yva\": yvaN, \"Xte\": XteN, \"yte\": yteN},\n",
    "        \"scaler\": scaler,\n",
    "        \"best_val_mse_norm\": float(best_val),\n",
    "        \"best_epoch\": int(best_ep),\n",
    "        \"stats\": {\"n_params\": int(n_params), \"n_trainable\": int(n_trainable)},\n",
    "        \"true\": gt,\n",
    "        \"device\": device,\n",
    "    }\n",
    "\n",
    "    spec = {\n",
    "        \"model\": model_type,\n",
    "        \"extra\": extra,\n",
    "        \"n_params\": int(n_params),\n",
    "        \"n_trainable\": int(n_trainable),\n",
    "        \"model_params\": dict(model_params),\n",
    "        \"train_params\": {\n",
    "            \"epochs\": int(epochs),\n",
    "            \"batch_size\": int(batch_sz),\n",
    "            \"lr\": float(lr),\n",
    "            \"weight_decay\": float(wd),\n",
    "            \"seed\": int(seed),\n",
    "            \"device\": str(device),\n",
    "            \"val_frac\": float(val_frac),\n",
    "            \"test_frac\": float(test_frac),\n",
    "        },\n",
    "    }\n",
    "\n",
    "\n",
    "    # ---- save artifacts for this run ----\n",
    "\n",
    "    _save_run(run_dir, signature, model, data, history, spec)\n",
    "\n",
    "\n",
    "    return model, data, history, spec\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ace7f63f",
   "metadata": {},
   "source": [
    "## Optimization Helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38416530",
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_true_obj(gt, x):\n",
    "    if not isinstance(gt, dict):\n",
    "        return np.nan\n",
    "\n",
    "    x = np.asarray(x).reshape(-1)\n",
    "    t = gt.get(\"type\", None)\n",
    "\n",
    "    if t == \"quadratic\":\n",
    "        Q = np.asarray(gt[\"Q\"], float)\n",
    "        xs = np.asarray(gt[\"x_star\"], float).reshape(-1)\n",
    "        d = x.astype(float) - xs\n",
    "        return float(d @ Q @ d)\n",
    "\n",
    "    if t == \"assignment\":\n",
    "        C = np.asarray(gt[\"C\"], float)\n",
    "        x_max = int(gt.get(\"x_max\", 1))\n",
    "\n",
    "        # expand right side (capacity x_max per column)\n",
    "        col_map = np.repeat(np.arange(C.shape[1]), x_max)\n",
    "        C_cols = C[:, col_map]\n",
    "        R = col_map.size\n",
    "\n",
    "        # expand left side by multiplicities in x\n",
    "        x_int = np.asarray(x, int).clip(min=0)\n",
    "        row_idx = np.repeat(np.arange(C.shape[0]), x_int)\n",
    "        m = row_idx.size\n",
    "        if m == 0:\n",
    "            return 0.0\n",
    "\n",
    "        cost_rect = C_cols[row_idx, :]  # (m, R)\n",
    "        if m < R:\n",
    "            cost_sq = np.vstack([cost_rect, np.zeros((R - m, R))])\n",
    "        else:\n",
    "            cost_sq = cost_rect  # m==R is possible\n",
    "\n",
    "        r, c = linear_sum_assignment(cost_sq)\n",
    "        real = r < m\n",
    "        return float(cost_sq[r[real], c[real]].sum())\n",
    "\n",
    "    if t == \"mdvsp\":\n",
    "        N, SS, TT = int(gt[\"N\"]), int(gt[\"SS\"]), int(gt[\"TT\"])\n",
    "        src = np.asarray(gt[\"src\"], np.int64)\n",
    "        dst = np.asarray(gt[\"dst\"], np.int64)\n",
    "        cost = np.asarray(gt[\"cost\"], float)\n",
    "        cap0 = np.asarray(gt[\"cap0\"], float)\n",
    "        idxSS = np.asarray(gt[\"idxSS\"], np.int64)\n",
    "        idxTT = np.asarray(gt[\"idxTT\"], np.int64)\n",
    "\n",
    "        cap = cap0.copy()\n",
    "        cap[idxSS] = x.astype(float)\n",
    "        cap[idxTT] = x.astype(float)\n",
    "\n",
    "        Fmax = lemon_mcf.max_flow(N, src, dst, cap, SS, TT)[\"value\"]\n",
    "        supply = np.zeros(N, float)\n",
    "        supply[SS] = Fmax\n",
    "        supply[TT] = -Fmax\n",
    "        return float(lemon_mcf.solve_mcf(N, src, dst, cost, cap, supply)[\"total_cost\"])\n",
    "\n",
    "    return np.nan\n",
    "\n",
    "\n",
    "    t0 = time.perf_counter()\n",
    "    y = float(eval_batch(x[None, :])[0])\n",
    "    hist = [{\"iter\": 0, \"t\": 0.0, \"best_y\": y, \"x\": x.copy()}]\n",
    "    print(f\"iter=0  t=0.00s  best_y={y:.6g}  x={x.tolist()}\")\n",
    "\n",
    "    for it in range(1, int(max_iters) + 1):\n",
    "        cand = []\n",
    "        for k in range(1, delta + 1):\n",
    "            for dlt in deltas_exact(k):\n",
    "                z = x + dlt\n",
    "                if ok(z):\n",
    "                    cand.append(z)\n",
    "\n",
    "        if not cand:\n",
    "            print(f\"STOP: no feasible neighbors. best_y={y:.6g} x={x.tolist()}\")\n",
    "            break\n",
    "\n",
    "        Y = eval_batch(np.stack(cand, 0))\n",
    "        j = int(np.argmin(Y))\n",
    "        if float(Y[j]) >= y:\n",
    "            print(f\"STOP: local minimum. best_y={y:.6g} x={x.tolist()}\")\n",
    "            break\n",
    "\n",
    "        x, y = cand[j], float(Y[j])\n",
    "        hist.append({\"iter\": it, \"t\": time.perf_counter() - t0, \"best_y\": y, \"x\": x.copy()})\n",
    "        if it % max(1, int(print_every)) == 0:\n",
    "            print(f\"iter={it}  t={hist[-1]['t']:.2f}s  best_y={y:.6g}  x={x.tolist()}\")\n",
    "\n",
    "    return x, y, hist\n",
    "\n",
    "\n",
    "def _scaler_np(scaler):\n",
    "    xm = scaler[\"x_mean\"].detach().cpu().numpy().reshape(-1)\n",
    "    xs = scaler[\"x_std\"].detach().cpu().numpy().reshape(-1)\n",
    "    ym = float(scaler[\"y_mean\"].detach().cpu())\n",
    "    ys = float(scaler[\"y_std\"].detach().cpu())\n",
    "    return xm, xs, ym, ys\n",
    "\n",
    "\n",
    "def solve_dfn_ip_gurobi(dfn, scaler, x_min, x_max, sum_eq, *, integer_x=True, verbose=False, time_limit=None, seed=None):\n",
    "    import gurobipy as gp\n",
    "    from gurobipy import GRB\n",
    "\n",
    "    x_min = np.asarray(x_min, float).ravel()\n",
    "    x_max = np.asarray(x_max, float).ravel()\n",
    "    d = int(x_min.size)\n",
    "    sum_eq = float(sum_eq)\n",
    "\n",
    "    x_mean, x_std, y_mean, y_std = _scaler_np(scaler)\n",
    "\n",
    "    cost = np.r_[dfn.cost_raw.detach().cpu().numpy(), dfn.cost_fixed.detach().cpu().numpy()]\n",
    "    cap  = np.r_[torch.round(F.softplus(dfn.cap_raw.detach())).cpu().numpy(), dfn.cap_fixed.detach().cpu().numpy()]\n",
    "\n",
    "    A = dfn.A.detach().cpu().numpy()\n",
    "    if isinstance(dfn.A, torch.nn.Parameter):\n",
    "        A = np.round(A)\n",
    "    b = np.round(dfn.b_raw.detach().cpu().numpy())\n",
    "\n",
    "    src = dfn.src.detach().cpu().numpy().astype(int)\n",
    "    dst = dfn.dst.detach().cpu().numpy().astype(int)\n",
    "    boundary = dfn.boundary.detach().cpu().numpy().astype(int)\n",
    "    fix = int(dfn.fix_node)\n",
    "    n = int(dfn.n)\n",
    "    m = int(src.size)\n",
    "    alpha = float(dfn.alpha)\n",
    "    beta  = float(dfn.beta)\n",
    "\n",
    "    out = [[] for _ in range(n)]\n",
    "    inn = [[] for _ in range(n)]\n",
    "    for e in range(m):\n",
    "        out[src[e]].append(e)\n",
    "        inn[dst[e]].append(e)\n",
    "\n",
    "    M = gp.Model(\"DFN_IP\")\n",
    "    M.Params.OutputFlag = 1 if verbose else 0\n",
    "    # Make the IP solve reproducible across runs\n",
    "    try:\n",
    "        M.Params.Threads = 1\n",
    "    except Exception:\n",
    "        pass\n",
    "    if seed is not None:\n",
    "        try:\n",
    "            M.Params.Seed = int(seed)\n",
    "        except Exception:\n",
    "            pass\n",
    "    if time_limit is not None:\n",
    "        M.Params.TimeLimit = float(time_limit)\n",
    "\n",
    "    xt = GRB.INTEGER if integer_x else GRB.CONTINUOUS\n",
    "    x = M.addVars(d, lb=x_min.tolist(), ub=x_max.tolist(), vtype=xt, name=\"x\")\n",
    "    f = M.addVars(m, lb=0.0, ub=cap.tolist(), vtype=GRB.CONTINUOUS, name=\"f\")\n",
    "    M.addConstr(gp.quicksum(x[i] for i in range(d)) == sum_eq)\n",
    "\n",
    "    xm_over_xs = x_mean / x_std\n",
    "    s = [0] * n\n",
    "    s_boundary = []\n",
    "    for r, v in enumerate(boundary):\n",
    "        const = float(b[r] - (A[r] * xm_over_xs).sum())\n",
    "        expr = const + gp.quicksum((A[r, j] / x_std[j]) * x[j] for j in range(d) if A[r, j] != 0)\n",
    "        s[v] = expr\n",
    "        s_boundary.append(expr)\n",
    "    s[fix] = -gp.quicksum(s_boundary)\n",
    "\n",
    "    for v in range(n):\n",
    "        M.addConstr(gp.quicksum(f[e] for e in out[v]) - gp.quicksum(f[e] for e in inn[v]) == s[v])\n",
    "\n",
    "    flow_cost = gp.quicksum(cost[e] * f[e] for e in range(m))\n",
    "    M.setObjective((alpha * flow_cost + beta) * y_std + y_mean, GRB.MINIMIZE)\n",
    "\n",
    "    M.optimize()\n",
    "    if M.SolCount == 0:\n",
    "        raise RuntimeError(f\"No solution (Gurobi status {M.Status})\")\n",
    "\n",
    "    x_star = np.array([x[i].X for i in range(d)], float)\n",
    "    info = {\"status\": M.Status, \"runtime\": M.Runtime, \"gap\": getattr(M, \"MIPGap\", None)}\n",
    "    return x_star, float(M.ObjVal), info\n",
    "\n",
    "\n",
    "def solve_mlp_ip_gurobi(model, scaler, x_min, x_max, sum_eq, *, integer_x=True, verbose=False, time_limit=None, seed=None):\n",
    "    import gurobipy as gp\n",
    "    from gurobipy import GRB\n",
    "\n",
    "    x_min = np.asarray(x_min, float).ravel()\n",
    "    x_max = np.asarray(x_max, float).ravel()\n",
    "    d = int(x_min.size)\n",
    "    sum_eq = float(sum_eq)\n",
    "\n",
    "    xm, xs, ym, ys = _scaler_np(scaler)\n",
    "\n",
    "    base, a_out, b_out = model, 1.0, 0.0\n",
    "    if hasattr(model, \"base\") and hasattr(model, \"a\") and hasattr(model, \"b\"):\n",
    "        base, a_out, b_out = model.base, float(model.a), float(model.b)\n",
    "\n",
    "    if not hasattr(base, \"net\"):\n",
    "        raise ValueError(\"Expected an MLP with attribute .net (nn.Sequential).\")\n",
    "    linears = [L for L in base.net if isinstance(L, torch.nn.Linear)]\n",
    "    if not linears:\n",
    "        raise ValueError(\"No Linear layers found in base.net\")\n",
    "\n",
    "    W = [L.weight.detach().cpu().numpy().astype(float) for L in linears]\n",
    "    b = [L.bias.detach().cpu().numpy().astype(float) for L in linears]\n",
    "\n",
    "    W[0] = W[0] / xs[None, :]\n",
    "    b[0] = b[0] - W[0] @ xm\n",
    "\n",
    "    W[-1] *= a_out\n",
    "    b[-1] = a_out * b[-1] + b_out\n",
    "\n",
    "    u = np.maximum(np.abs(x_min), np.abs(x_max))\n",
    "    preLU = []\n",
    "    for k in range(len(W) - 1):\n",
    "        U = np.abs(W[k]) @ u + np.abs(b[k])\n",
    "        preLU.append((-U, U))\n",
    "        u = np.maximum(0.0, U)\n",
    "\n",
    "    M = gp.Model(\"MLP_IP\")\n",
    "    M.Params.OutputFlag = 1 if verbose else 0\n",
    "    # Make the IP solve reproducible across runs\n",
    "    try:\n",
    "        M.Params.Threads = 1\n",
    "    except Exception:\n",
    "        pass\n",
    "    if seed is not None:\n",
    "        try:\n",
    "            M.Params.Seed = int(seed)\n",
    "        except Exception:\n",
    "            pass\n",
    "    if time_limit is not None:\n",
    "        M.Params.TimeLimit = float(time_limit)\n",
    "\n",
    "    xt = GRB.INTEGER if integer_x else GRB.CONTINUOUS\n",
    "    x = M.addVars(d, lb=x_min.tolist(), ub=x_max.tolist(), vtype=xt, name=\"x\")\n",
    "    M.addConstr(gp.quicksum(x[i] for i in range(d)) == sum_eq, name=\"sum_eq\")\n",
    "\n",
    "    prev = [x[i] for i in range(d)]\n",
    "\n",
    "    for k in range(len(W) - 1):\n",
    "        Lk, Uk = preLU[k]\n",
    "        h = W[k].shape[0]\n",
    "\n",
    "        a = [M.addVar(lb=float(Lk[j]), ub=float(Uk[j]), name=f\"a{k}_{j}\") for j in range(h)]\n",
    "        z = [M.addVar(lb=0.0, ub=float(max(0.0, Uk[j])), name=f\"z{k}_{j}\") for j in range(h)]\n",
    "\n",
    "        for j in range(h):\n",
    "            M.addConstr(a[j] == b[k][j] + gp.quicksum(W[k][j, i] * prev[i] for i in range(len(prev))))\n",
    "\n",
    "            Lj, Uj = float(Lk[j]), float(Uk[j])\n",
    "            if Uj <= 0.0:\n",
    "                M.addConstr(z[j] == 0.0)\n",
    "            elif Lj >= 0.0:\n",
    "                M.addConstr(z[j] == a[j])\n",
    "            else:\n",
    "                s = M.addVar(vtype=GRB.BINARY, name=f\"s{k}_{j}\")\n",
    "                M.addConstr(z[j] >= a[j])\n",
    "                M.addConstr(z[j] >= 0.0)\n",
    "                M.addConstr(z[j] <= Uj * s)\n",
    "                M.addConstr(z[j] <= a[j] - Lj * (1 - s))\n",
    "\n",
    "        prev = z\n",
    "\n",
    "    if W[-1].shape[0] != 1:\n",
    "        raise ValueError(f\"Expected scalar output, got out_dim={W[-1].shape[0]}\")\n",
    "\n",
    "    y_norm = M.addVar(lb=-GRB.INFINITY, vtype=GRB.CONTINUOUS, name=\"y_norm\")\n",
    "    M.addConstr(y_norm == b[-1][0] + gp.quicksum(W[-1][0, i] * prev[i] for i in range(len(prev))))\n",
    "\n",
    "    M.setObjective(ys * y_norm + ym, GRB.MINIMIZE)\n",
    "\n",
    "    M.optimize()\n",
    "    if M.SolCount == 0:\n",
    "        raise RuntimeError(f\"No feasible solution. Gurobi status {M.Status}\")\n",
    "\n",
    "    x_star = np.array([x[i].X for i in range(d)], float)\n",
    "    info = {\"status\": M.Status, \"runtime\": M.Runtime, \"gap\": getattr(M, \"MIPGap\", None), \"sol_count\": M.SolCount}\n",
    "    return x_star, float(M.ObjVal), info\n",
    "\n",
    "\n",
    "def solve_lset_ip_gurobi(model, scaler, x_min, x_max, sum_eq, *, integer_x=True, verbose=False, time_limit=None, seed=None):\n",
    "    import gurobipy as gp\n",
    "    from gurobipy import GRB\n",
    "\n",
    "    x_min = np.asarray(x_min, float).ravel()\n",
    "    x_max = np.asarray(x_max, float).ravel()\n",
    "    d = int(x_min.size)\n",
    "    sum_eq = float(sum_eq)\n",
    "\n",
    "    xm, xs, ym, ys = _scaler_np(scaler)\n",
    "\n",
    "    A = model.A.detach().cpu().numpy().astype(float)\n",
    "    b = model.b.detach().cpu().numpy().astype(float)\n",
    "    T = float(model.T)\n",
    "    if T == 0.0:\n",
    "        raise ValueError(\"T must be nonzero\")\n",
    "\n",
    "    Aeff = A / xs[None, :]\n",
    "    beff = b - (Aeff @ xm)\n",
    "    K = int(Aeff.shape[0])\n",
    "\n",
    "    lin_lo = np.empty(K, dtype=float)\n",
    "    lin_hi = np.empty(K, dtype=float)\n",
    "    for k in range(K):\n",
    "        a = Aeff[k]\n",
    "        lo = beff[k]\n",
    "        hi = beff[k]\n",
    "        pos = a >= 0\n",
    "        lo += (a[pos] * x_min[pos]).sum() + (a[~pos] * x_max[~pos]).sum()\n",
    "        hi += (a[pos] * x_max[pos]).sum() + (a[~pos] * x_min[~pos]).sum()\n",
    "        lin_lo[k], lin_hi[k] = lo, hi\n",
    "\n",
    "    z_lo = lin_lo / T\n",
    "    z_hi = lin_hi / T\n",
    "    m_lo = float(np.max(z_lo))\n",
    "    m_hi = float(np.max(z_hi))\n",
    "\n",
    "    w_lo = float(np.min(z_lo - m_hi))  # <= 0\n",
    "    u_lo = float(np.exp(max(-700.0, w_lo)))\n",
    "    s_lo = max(1e-12, K * u_lo)\n",
    "    s_hi = float(K * 1.0)\n",
    "    v_lo = float(np.log(s_lo))\n",
    "    v_hi = float(np.log(s_hi))\n",
    "\n",
    "    yN_lo = float(T * (m_lo + v_lo))\n",
    "    yN_hi = float(T * (m_hi + v_hi))\n",
    "    y_lo  = float(ys * yN_lo + ym)\n",
    "    y_hi  = float(ys * yN_hi + ym)\n",
    "\n",
    "    M = gp.Model(\"lset_ip_stable_bounded\")\n",
    "    M.Params.OutputFlag = 1 if verbose else 0\n",
    "    # Make the IP solve reproducible across runs\n",
    "    try:\n",
    "        M.Params.Threads = 1\n",
    "    except Exception:\n",
    "        pass\n",
    "    if seed is not None:\n",
    "        try:\n",
    "            M.Params.Seed = int(seed)\n",
    "        except Exception:\n",
    "            pass\n",
    "    if time_limit is not None:\n",
    "        M.Params.TimeLimit = float(time_limit)\n",
    "\n",
    "    M.Params.FuncNonlinear = 1\n",
    "    M.Params.FeasibilityTol = 1e-9\n",
    "    M.Params.OptimalityTol  = 1e-9\n",
    "    M.Params.IntFeasTol     = 1e-9\n",
    "    M.Params.NumericFocus   = 3\n",
    "\n",
    "    xt = GRB.INTEGER if integer_x else GRB.CONTINUOUS\n",
    "    x = M.addVars(d, lb=x_min.tolist(), ub=x_max.tolist(), vtype=xt, name=\"x\")\n",
    "    M.addConstr(gp.quicksum(x[i] for i in range(d)) == sum_eq, name=\"sum_eq\")\n",
    "\n",
    "    z = M.addVars(K, lb=z_lo.tolist(), ub=z_hi.tolist(), vtype=GRB.CONTINUOUS, name=\"z\")\n",
    "    for k in range(K):\n",
    "        lin = beff[k] + gp.quicksum(Aeff[k, j] * x[j] for j in range(d) if Aeff[k, j] != 0.0)\n",
    "        M.addConstr(z[k] == lin / T, name=f\"zdef_{k}\")\n",
    "\n",
    "    m = M.addVar(lb=m_lo, ub=m_hi, vtype=GRB.CONTINUOUS, name=\"m\")\n",
    "    w = M.addVars(K, lb=w_lo, ub=0.0, vtype=GRB.CONTINUOUS, name=\"w\")\n",
    "    for k in range(K):\n",
    "        M.addConstr(m >= z[k], name=f\"m_ge_z_{k}\")\n",
    "        M.addConstr(w[k] == z[k] - m, name=f\"wdef_{k}\")\n",
    "\n",
    "    u = M.addVars(K, lb=0.0, ub=1.0, vtype=GRB.CONTINUOUS, name=\"u\")\n",
    "    for k in range(K):\n",
    "        M.addGenConstrExp(w[k], u[k], name=f\"exp_{k}\")\n",
    "\n",
    "    s = M.addVar(lb=s_lo, ub=s_hi, vtype=GRB.CONTINUOUS, name=\"s\")\n",
    "    M.addConstr(s == gp.quicksum(u[k] for k in range(K)), name=\"sumexp_shifted\")\n",
    "\n",
    "    v = M.addVar(lb=v_lo, ub=v_hi, vtype=GRB.CONTINUOUS, name=\"v\")\n",
    "    M.addGenConstrLog(s, v, name=\"log_shifted\")\n",
    "\n",
    "    y_norm = M.addVar(lb=yN_lo, ub=yN_hi, vtype=GRB.CONTINUOUS, name=\"y_norm\")\n",
    "    M.addConstr(y_norm == T * (m + v), name=\"y_norm_def\")\n",
    "\n",
    "    y_raw = M.addVar(lb=y_lo, ub=y_hi, vtype=GRB.CONTINUOUS, name=\"y_raw\")\n",
    "    M.addConstr(y_raw == ys * y_norm + ym, name=\"y_raw_def\")\n",
    "\n",
    "    M.setObjective(y_raw, GRB.MINIMIZE)\n",
    "    M.optimize()\n",
    "\n",
    "    if M.SolCount == 0:\n",
    "        raise RuntimeError(f\"No feasible solution found. Gurobi status {M.Status}\")\n",
    "\n",
    "    x_star = np.array([x[i].X for i in range(d)], dtype=float)\n",
    "    y_star = float(y_raw.X)\n",
    "\n",
    "    info = {\n",
    "        \"status\": M.Status,\n",
    "        \"runtime\": M.Runtime,\n",
    "        \"gap\": getattr(M, \"MIPGap\", None),\n",
    "        \"sol_count\": M.SolCount,\n",
    "        \"obj_gurobi\": float(M.ObjVal),\n",
    "        \"obj_bound\": float(getattr(M, \"ObjBound\", float(\"nan\"))),\n",
    "    }\n",
    "    return x_star, y_star, info\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4858979",
   "metadata": {},
   "source": [
    "## Test Helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "569f2165",
   "metadata": {},
   "outputs": [],
   "source": [
    "def solve_ip(model_type, model, scaler, xmin, xmax, sum_eq, *, time_limit=None, verbose=False, seed=None):\n",
    "    if model_type == \"DFN\":\n",
    "        return solve_dfn_ip_gurobi(model, scaler, xmin, xmax, sum_eq, time_limit=time_limit, verbose=verbose, seed=seed)\n",
    "    if model_type == \"MLP\":\n",
    "        return solve_mlp_ip_gurobi(model, scaler, xmin, xmax, sum_eq, time_limit=time_limit, verbose=verbose, seed=seed)\n",
    "    if model_type == \"LSET\":\n",
    "        return solve_lset_ip_gurobi(model, scaler, xmin, xmax, sum_eq, time_limit=time_limit, verbose=verbose, seed=seed)\n",
    "    raise ValueError(model_type)\n",
    "\n",
    "\n",
    "def suppress_stdout(fn, *args, silence: bool = False, **kwargs):\n",
    "    if not silence:\n",
    "        return fn(*args, **kwargs), \"\"\n",
    "    buf = io.StringIO()\n",
    "    with contextlib.redirect_stdout(buf):\n",
    "        out = fn(*args, **kwargs)\n",
    "    return out, buf.getvalue()\n",
    "\n",
    "\n",
    "def make_obj(model, scaler, device, chunk: int = 4096):\n",
    "    xm = scaler[\"x_mean\"].to(device)\n",
    "    xs = scaler[\"x_std\"].to(device)\n",
    "    ym = scaler[\"y_mean\"].to(device)\n",
    "    ys = scaler[\"y_std\"].to(device)\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def obj(Xraw):\n",
    "        Xraw = torch.as_tensor(Xraw, dtype=torch.float32, device=device)\n",
    "        if Xraw.dim() == 1:\n",
    "            Xraw = Xraw.unsqueeze(0)\n",
    "\n",
    "        outs = []\n",
    "        B = int(Xraw.shape[0])\n",
    "        for i in range(0, B, int(chunk)):\n",
    "            Xb = Xraw[i:i+chunk]\n",
    "            try:\n",
    "                Xn = (Xb - xm) / xs\n",
    "                yn = model(Xn)\n",
    "                yn = torch.as_tensor(yn).reshape(-1)  \n",
    "                y  = (yn * ys + ym).reshape(-1).detach().cpu() \n",
    "                if y.numel() != Xb.shape[0]:\n",
    "                    y = y.expand(Xb.shape[0]).contiguous()\n",
    "                outs.append(y)\n",
    "            except Exception as e:\n",
    "                obj._err_count += int(Xb.shape[0])\n",
    "                if obj._err_first is None:\n",
    "                    obj._err_first = repr(e)\n",
    "                outs.append(torch.full((int(Xb.shape[0]),), float(\"inf\"), device=\"cpu\"))\n",
    "\n",
    "        return torch.cat(outs, 0).numpy()\n",
    "\n",
    "    obj._err_count = 0\n",
    "    obj._err_first = None\n",
    "    return obj\n",
    "\n",
    "\n",
    "def safe_raw_mse(obj, Xraw, yraw):\n",
    "    yp = np.asarray(obj(Xraw), dtype=float).reshape(-1)\n",
    "    yt = yraw.detach().cpu().numpy().reshape(-1) if torch.is_tensor(yraw) else np.asarray(yraw, float).reshape(-1)\n",
    "    if yp.shape[0] != yt.shape[0]:\n",
    "        return np.nan, f\"shape mismatch: pred {yp.shape} vs true {yt.shape}\"\n",
    "    mask = np.isfinite(yp)\n",
    "    if mask.sum() == 0:\n",
    "        return np.nan, \"all predictions were non-finite (inf/nan)\"\n",
    "    return float(np.mean((yp[mask] - yt[mask])**2)), None\n",
    "\n",
    "\n",
    "def safe_norm_mse(model, scaler, device, Xraw, yraw, *, chunk=4096):\n",
    "    xm = scaler[\"x_mean\"].to(device)\n",
    "    xs = scaler[\"x_std\"].to(device)\n",
    "    ym = scaler[\"y_mean\"].to(device)\n",
    "    ys = scaler[\"y_std\"].to(device)\n",
    "\n",
    "    Xraw = torch.as_tensor(Xraw, dtype=torch.float32, device=device)\n",
    "    if Xraw.dim() == 1:\n",
    "        Xraw = Xraw.unsqueeze(0)\n",
    "\n",
    "    yraw_t = yraw\n",
    "    if torch.is_tensor(yraw_t):\n",
    "        yraw_t = yraw_t.to(device=device, dtype=torch.float32)\n",
    "    else:\n",
    "        yraw_t = torch.as_tensor(yraw_t, dtype=torch.float32, device=device)\n",
    "    yraw_t = yraw_t.reshape(-1)\n",
    "\n",
    "    preds = []\n",
    "    B = int(Xraw.shape[0])\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            for i in range(0, B, int(chunk)):\n",
    "                Xb = Xraw[i:i+chunk]\n",
    "                Xn = (Xb - xm) / xs\n",
    "                yn = model(Xn)\n",
    "                yn = torch.as_tensor(yn).reshape(-1)\n",
    "                if yn.numel() != Xb.shape[0]:\n",
    "                    yn = yn.expand(Xb.shape[0]).contiguous()\n",
    "                preds.append(yn.detach().cpu())\n",
    "    except Exception as e:\n",
    "        return np.nan, f\"model forward failed in safe_norm_mse: {repr(e)}\"\n",
    "\n",
    "    yp = torch.cat(preds, 0).numpy().reshape(-1)\n",
    "    yt = ((yraw_t - ym) / ys).detach().cpu().numpy().reshape(-1)\n",
    "\n",
    "    if yp.shape[0] != yt.shape[0]:\n",
    "        return np.nan, f\"shape mismatch: pred {yp.shape} vs true {yt.shape}\"\n",
    "\n",
    "    mask = np.isfinite(yp) & np.isfinite(yt)\n",
    "    if mask.sum() == 0:\n",
    "        return np.nan, \"all predictions or targets were non-finite (inf/nan)\"\n",
    "\n",
    "    return float(np.mean((yp[mask] - yt[mask])**2)), None\n",
    "\n",
    "\n",
    "def t_to_best(hist, y_best):\n",
    "    for r in hist:\n",
    "        if abs(float(r.get(\"best_y\", np.inf)) - float(y_best)) < 1e-12:\n",
    "            return float(r.get(\"t\", float(\"nan\")))\n",
    "    return float(\"nan\")\n",
    "\n",
    "\n",
    "def check_ip_matches_obj(name, obj, x_ip, y_ip, *, strict: bool, tol: float):\n",
    "    y_obj = float(np.asarray(obj(np.asarray(x_ip)), dtype=float).reshape(-1)[0])\n",
    "    y_ip  = float(y_ip)\n",
    "    rel = abs(y_obj - y_ip) / (abs(y_obj) + 1e-12)\n",
    "    print(f\"[CHECK {name}] obj(x_ip)={y_obj:.6g}  ip_y={y_ip:.6g}  rel_err={rel:.3e}\")\n",
    "    if strict and rel > tol:\n",
    "        raise RuntimeError(f\"{name}: IP objective != obj() (rel_err={rel:.3e})\")\n",
    "    return rel\n",
    "\n",
    "\n",
    "def mean_se(x):\n",
    "    s = pd.Series(x)\n",
    "    \n",
    "    x_num = pd.to_numeric(s, errors=\"coerce\").to_numpy()\n",
    "    x_num = x_num[np.isfinite(x_num)]\n",
    "    n = int(x_num.shape[0])\n",
    "    \n",
    "    if n == 0:\n",
    "        return np.nan, np.nan\n",
    "\n",
    "    m = float(x_num.mean())\n",
    "    std = float(x_num.std(ddof=1)) if n > 1 else 0.0\n",
    "    se = float(std / np.sqrt(n)) if n > 1 else 0.0\n",
    "\n",
    "    return m, se\n",
    "\n",
    "\n",
    "\n",
    "def fmt_mean_se(m, se):\n",
    "    if not np.isfinite(m):\n",
    "        return \"nan\"\n",
    "    if not np.isfinite(se):\n",
    "        return f\"{m:.6g}\"\n",
    "    return f\"{m:.6g} ± {se:.3g}\"\n",
    "\n",
    "\n",
    "def repr_solution(xs: pd.Series, seeds=None):\n",
    "    xs = xs.dropna().astype(str)\n",
    "    xs = xs[xs != \"None\"]\n",
    "    if xs.empty:\n",
    "        return None\n",
    "\n",
    "    if seeds is not None:\n",
    "        df = pd.DataFrame({\"seed\": pd.Series(seeds), \"x\": xs})\n",
    "        df = df.dropna()\n",
    "        df[\"x\"] = df[\"x\"].astype(str)\n",
    "        if df.empty:\n",
    "            return None\n",
    "        if df[\"x\"].nunique() == 1:\n",
    "            return df[\"x\"].iloc[0]\n",
    "        df = df.sort_values(\"seed\")\n",
    "        return \"\\n\".join([f\"seed={int(r.seed)}: {r.x}\" for r in df.itertuples(index=False)])\n",
    "\n",
    "    if xs.nunique() == 1:\n",
    "        return xs.iloc[0]\n",
    "    vc = xs.value_counts()\n",
    "    top = vc.index[0]\n",
    "    n_unique = int(vc.shape[0])\n",
    "    return f\"{top} (+{n_unique-1} other)\"\n",
    "    \n",
    "\n",
    "# -----------------------------\n",
    "# Ground-truth optimum solvers\n",
    "# -----------------------------\n",
    "\n",
    "def solve_true_opt_quadratic(gt, xmin, xmax, sum_eq, *, time_limit=None, verbose=False, seed=None):\n",
    "    Q  = np.asarray(gt[\"Q\"], float)\n",
    "    xs = np.asarray(gt[\"x_star\"], float).reshape(-1)\n",
    "    n  = int(xs.shape[0])\n",
    "    assert Q.shape == (n, n)\n",
    "\n",
    "    m = gp.Model(\"gt_quadratic\")\n",
    "    m.Params.OutputFlag = 1 if verbose else 0\n",
    "    # Make the (ground-truth) solve reproducible across runs\n",
    "    try:\n",
    "        m.Params.Threads = 1\n",
    "    except Exception:\n",
    "        pass\n",
    "    if seed is not None:\n",
    "        try:\n",
    "            m.Params.Seed = int(seed)\n",
    "        except Exception:\n",
    "            pass\n",
    "    if time_limit is not None:\n",
    "        m.Params.TimeLimit = float(time_limit)\n",
    "\n",
    "    x = m.addVars(n, vtype=GRB.INTEGER, name=\"x\")\n",
    "    for i in range(n):\n",
    "        x[i].LB = int(xmin[i])\n",
    "        x[i].UB = int(xmax[i])\n",
    "    m.addConstr(gp.quicksum(x[i] for i in range(n)) == int(sum_eq), name=\"sum_eq\")\n",
    "\n",
    "    expr = gp.quicksum(float(Q[i, j]) * (x[i] - float(xs[i])) * (x[j] - float(xs[j])) for i in range(n) for j in range(n))\n",
    "    m.setObjective(expr, GRB.MINIMIZE)\n",
    "    m.optimize()\n",
    "\n",
    "    if m.Status not in (GRB.OPTIMAL, GRB.TIME_LIMIT, GRB.SUBOPTIMAL, GRB.INTERRUPTED):\n",
    "        raise RuntimeError(f\"GT quadratic solve failed, status={m.Status}\")\n",
    "\n",
    "    xsol = np.array([int(round(x[i].X)) for i in range(n)], dtype=int)\n",
    "    return xsol, float(m.ObjVal), {\"status\": int(m.Status)}\n",
    "\n",
    "\n",
    "def solve_true_opt_assignment(gt, xmin, xmax, sum_eq, *, time_limit=None, verbose=False, seed=None):\n",
    "    C = np.asarray(gt[\"C\"], float)\n",
    "    n_rows, n_cols = C.shape\n",
    "    x_max_right = int(gt.get(\"x_max\", int(np.max(xmax))))\n",
    "    k = int(sum_eq)\n",
    "\n",
    "    xmin = np.asarray(xmin, int).ravel()\n",
    "    xmax = np.asarray(xmax, int).ravel()\n",
    "\n",
    "    m = gp.Model(\"gt_bmatching\")\n",
    "    m.Params.OutputFlag = 1 if verbose else 0\n",
    "    try:\n",
    "        m.Params.Threads = 1\n",
    "    except Exception:\n",
    "        pass\n",
    "    if seed is not None:\n",
    "        try:\n",
    "            m.Params.Seed = int(seed)\n",
    "        except Exception:\n",
    "            pass\n",
    "    if time_limit is not None:\n",
    "        m.Params.TimeLimit = float(time_limit)\n",
    "\n",
    "    f = m.addVars(n_rows, n_cols, vtype=GRB.INTEGER, lb=0.0, name=\"f\")\n",
    "\n",
    "    for i in range(n_rows):\n",
    "        m.addConstr(gp.quicksum(f[i, j] for j in range(n_cols)) >= int(xmin[i]))\n",
    "        m.addConstr(gp.quicksum(f[i, j] for j in range(n_cols)) <= int(xmax[i]))\n",
    "\n",
    "    # right capacities (b-matching capacity)\n",
    "    for j in range(n_cols):\n",
    "        m.addConstr(gp.quicksum(f[i, j] for i in range(n_rows)) <= x_max_right)\n",
    "\n",
    "    # total units must equal sum_eq\n",
    "    m.addConstr(gp.quicksum(f[i, j] for i in range(n_rows) for j in range(n_cols)) == k)\n",
    "\n",
    "    m.setObjective(gp.quicksum(float(C[i, j]) * f[i, j] for i in range(n_rows) for j in range(n_cols)), GRB.MINIMIZE)\n",
    "    m.optimize()\n",
    "\n",
    "    if m.Status not in (GRB.OPTIMAL, GRB.TIME_LIMIT, GRB.SUBOPTIMAL, GRB.INTERRUPTED):\n",
    "        raise RuntimeError(f\"GT b-matching solve failed, status={m.Status}\")\n",
    "\n",
    "    xsol = np.array([int(round(sum(f[i, j].X for j in range(n_cols)))) for i in range(n_rows)], dtype=int)\n",
    "    return xsol, float(m.ObjVal), {\"status\": int(m.Status)}\n",
    "\n",
    "\n",
    "def solve_true_opt_mdvsp(gt, xmin, xmax, sum_eq, *, time_limit=None, verbose=False, seed=None):\n",
    "    N  = int(gt[\"N\"])\n",
    "    SS = int(gt[\"SS\"])\n",
    "    TT = int(gt[\"TT\"])\n",
    "    src  = np.asarray(gt[\"src\"], np.int64)\n",
    "    dst  = np.asarray(gt[\"dst\"], np.int64)\n",
    "    cost = np.asarray(gt[\"cost\"], float)\n",
    "    cap0 = np.asarray(gt[\"cap0\"], float)\n",
    "    idxSS = np.asarray(gt[\"idxSS\"], np.int64)\n",
    "    idxTT = np.asarray(gt[\"idxTT\"], np.int64)\n",
    "\n",
    "    E = int(src.shape[0])\n",
    "    k = int(idxSS.shape[0])\n",
    "    assert idxTT.shape[0] == k, \"Expect idxSS and idxTT to be same length\"\n",
    "\n",
    "    out_edges = [[] for _ in range(N)]\n",
    "    in_edges  = [[] for _ in range(N)]\n",
    "    for e in range(E):\n",
    "        out_edges[int(src[e])].append(e)\n",
    "        in_edges[int(dst[e])].append(e)\n",
    "\n",
    "    idxSS = idxSS.astype(int)\n",
    "    idxTT = idxTT.astype(int)\n",
    "    var_arc_to_i = {int(idxSS[i]): i for i in range(k)}\n",
    "    var_arc_to_i.update({int(idxTT[i]): i for i in range(k)})\n",
    "    var_arcs = sorted(var_arc_to_i.keys())\n",
    "\n",
    "    m = gp.Model(\"gt_mdvsp_true_opt\")\n",
    "    m.Params.OutputFlag = 1 if verbose else 0\n",
    "    # Make the (ground-truth) solve reproducible across runs\n",
    "    try:\n",
    "        m.Params.Threads = 1\n",
    "    except Exception:\n",
    "        pass\n",
    "    if seed is not None:\n",
    "        try:\n",
    "            m.Params.Seed = int(seed)\n",
    "        except Exception:\n",
    "            pass\n",
    "    if time_limit is not None:\n",
    "        m.Params.TimeLimit = float(time_limit)\n",
    "\n",
    "    m.Params.NonConvex = 2\n",
    "\n",
    "    x = m.addVars(k, vtype=GRB.INTEGER, name=\"x\")\n",
    "    for i in range(k):\n",
    "        x[i].LB = int(xmin[i])\n",
    "        x[i].UB = int(xmax[i])\n",
    "    m.addConstr(gp.quicksum(x[i] for i in range(k)) == int(sum_eq), name=\"sum_eq\")\n",
    "\n",
    "    f = m.addVars(E, vtype=GRB.CONTINUOUS, lb=0.0, name=\"f\")\n",
    "\n",
    "    for e in range(E):\n",
    "        if e in var_arc_to_i:\n",
    "            i = var_arc_to_i[e]\n",
    "            m.addConstr(f[e] <= x[i], name=f\"cap_var_{e}\")\n",
    "        else:\n",
    "            m.addConstr(f[e] <= float(cap0[e]), name=f\"cap_fix_{e}\")\n",
    "\n",
    "    F = m.addVar(vtype=GRB.CONTINUOUS, lb=0.0, name=\"F\")\n",
    "\n",
    "    # flow conservation: out - in = F at SS, = -F at TT, else 0\n",
    "    for v in range(N):\n",
    "        out_sum = gp.quicksum(f[e] for e in out_edges[v])\n",
    "        in_sum  = gp.quicksum(f[e] for e in in_edges[v])\n",
    "        rhs = F if v == SS else (-F if v == TT else 0.0)\n",
    "        m.addConstr(out_sum - in_sum == rhs, name=f\"flow_{v}\")\n",
    "\n",
    "    y = m.addVars(N, vtype=GRB.CONTINUOUS, lb=0.0, ub=1.0, name=\"y\")\n",
    "    m.addConstr(y[SS] == 1.0, name=\"ySS\")\n",
    "    m.addConstr(y[TT] == 0.0, name=\"yTT\")\n",
    "\n",
    "    z = m.addVars(E, vtype=GRB.CONTINUOUS, lb=0.0, ub=1.0, name=\"z\")\n",
    "    for e in range(E):\n",
    "        m.addConstr(z[e] >= y[int(src[e])] - y[int(dst[e])], name=f\"dual_{e}\")\n",
    "\n",
    "    dual_obj = gp.QuadExpr()\n",
    "    for e in range(E):\n",
    "        if e in var_arc_to_i:\n",
    "            i = var_arc_to_i[e]\n",
    "            dual_obj += x[i] * z[e]\n",
    "        else:\n",
    "            dual_obj += float(cap0[e]) * z[e]\n",
    "    m.addConstr(F == dual_obj, name=\"strong_duality\")\n",
    "\n",
    "    m.setObjective(gp.quicksum(float(cost[e]) * f[e] for e in range(E)), GRB.MINIMIZE)\n",
    "    m.optimize()\n",
    "\n",
    "    if m.Status not in (GRB.OPTIMAL, GRB.TIME_LIMIT, GRB.SUBOPTIMAL, GRB.INTERRUPTED):\n",
    "        raise RuntimeError(f\"GT mdvsp true-opt failed, status={m.Status}\")\n",
    "\n",
    "    xsol = np.array([int(round(x[i].X)) for i in range(k)], dtype=int)\n",
    "    return xsol, float(m.ObjVal), {\"status\": int(m.Status), \"F\": float(F.X)}\n",
    "    \n",
    "\n",
    "\n",
    "def solve_true_opt(gt, xmin, xmax, sum_eq, *, time_limit=None, verbose=False, seed=None):\n",
    "    \"\"\"Top-level dispatcher for ground-truth optimum.\"\"\"\n",
    "    if not isinstance(gt, dict):\n",
    "        raise RuntimeError(\"No ground-truth structure found (gt must be a dict).\")\n",
    "    t = gt.get(\"type\", None)\n",
    "    if t == \"quadratic\":\n",
    "        return solve_true_opt_quadratic(gt, xmin, xmax, sum_eq, time_limit=time_limit, verbose=verbose, seed=seed)\n",
    "    if t == \"assignment\":\n",
    "        return solve_true_opt_assignment(gt, xmin, xmax, sum_eq, time_limit=time_limit, verbose=verbose, seed=seed)\n",
    "    if t == \"mdvsp\":\n",
    "        return solve_true_opt_mdvsp(gt, xmin, xmax, sum_eq, time_limit=time_limit, verbose=verbose, seed=seed)\n",
    "    raise RuntimeError(f\"Unknown gt type: {t}\")\n",
    "\n",
    "\n",
    "# -----------------------------\n",
    "# Benchmark runner\n",
    "# -----------------------------\n",
    "def run_benchmark(\n",
    "    *,\n",
    "    dataset_type: str,\n",
    "    dataset_params: dict,\n",
    "    runs: list[tuple[str, str, dict]],\n",
    "    train_base: dict,\n",
    "    lr_map: dict,\n",
    "    xmin: np.ndarray,\n",
    "    xmax: np.ndarray,\n",
    "    n_seeds: int = 1,\n",
    "    strict_ip_check: bool = False,\n",
    "    ip_check_tol: float = 1e-4,\n",
    "    allow_plots_multi_seed: bool = True,\n",
    "    time_limit=None,\n",
    "):\n",
    "    seeds = list(range(int(n_seeds)))\n",
    "\n",
    "    learn_rows, opt_rows, spec_rows, fail_rows = [], [], [], []\n",
    "    gt_rows = []\n",
    "\n",
    "    gt_cache_by_seed = {}\n",
    "\n",
    "    in_dim = int(np.asarray(xmin).size)\n",
    "\n",
    "    for seed in seeds:\n",
    "        print(f\"\\n\\n===================== SEED {seed} =====================\")\n",
    "        rng = np.random.default_rng(int(seed))\n",
    "        if dataset_type == \"assignment\":\n",
    "            x0_seed = rng.integers(0, 30, size=in_dim, dtype=int)\n",
    "        elif dataset_type == \"mdvsp\":\n",
    "            x0_seed = rng.integers(0, 15, size=in_dim, dtype=int)\n",
    "        elif dataset_type == \"quadratic\":\n",
    "            x0_seed = rng.integers(-30, 30, size=in_dim, dtype=int)\n",
    "        else:\n",
    "            lo = int(np.min(np.asarray(xmin)))\n",
    "            hi = int(np.max(np.asarray(xmax))) + 1\n",
    "            x0_seed = rng.integers(lo, hi, size=in_dim, dtype=int)\n",
    "\n",
    "        x0_seed = np.asarray(x0_seed, int).ravel()\n",
    "        sum_eq_seed = int(x0_seed.sum())\n",
    "\n",
    "\n",
    "        for name, model_type, model_params_base in runs:\n",
    "            # ---- per-run params ----\n",
    "            dp = dict(dataset_params)\n",
    "            dp[\"seed\"] = int(seed)\n",
    "\n",
    "            mp = dict(model_params_base)\n",
    "            if \"seed\" in mp:\n",
    "                mp[\"seed\"] = int(seed)\n",
    "\n",
    "            tp = dict(train_base)\n",
    "            tp[\"seed\"] = int(seed)\n",
    "            tp[\"lr\"] = float(lr_map[model_type])\n",
    "\n",
    "            if (n_seeds > 1) and (tp.get(\"plot_every\", 0) not in (0, None)) and (not allow_plots_multi_seed):\n",
    "                tp[\"plot_every\"] = 0\n",
    "\n",
    "            # ---- TRAIN ----\n",
    "            t0 = time.perf_counter()\n",
    "            try:\n",
    "                out = generate_and_train_simple(dataset_type, dp, model_type, mp, tp)\n",
    "            except Exception as e:\n",
    "                fail_rows.append(dict(seed=seed, model=name, stage=\"TRAIN\", error=repr(e)))\n",
    "                learn_rows.append(dict(seed=seed, model=name, train_time=np.nan, best_epoch=np.nan,\n",
    "                                       best_val=np.nan, test=np.nan, train_err=repr(e)))\n",
    "                continue\n",
    "            train_time = time.perf_counter() - t0\n",
    "\n",
    "            if isinstance(out, tuple) and len(out) == 4:\n",
    "                model, data, hist, spec = out\n",
    "            elif isinstance(out, tuple) and len(out) == 3:\n",
    "                model, data, hist = out\n",
    "                n_params = sum(p.numel() for p in model.parameters())\n",
    "                spec = dict(n_params=int(n_params), extra=\"\", train_params=tp)\n",
    "            else:\n",
    "                fail_rows.append(dict(seed=seed, model=name, stage=\"TRAIN\", error=f\"Unexpected return: {type(out)}\"))\n",
    "                continue\n",
    "\n",
    "            # ---- model spec (seed 0 only) ----\n",
    "            if seed == seeds[0]:\n",
    "                spec_rows.append(dict(\n",
    "                    model=name,\n",
    "                    n_params=int(spec.get(\"n_params\", np.nan)),\n",
    "                    details=str(spec.get(\"extra\", \"\")),\n",
    "                    lr=float(spec.get(\"train_params\", {}).get(\"lr\", tp[\"lr\"])),\n",
    "                    batch_size=int(spec.get(\"train_params\", {}).get(\"batch_size\", tp[\"batch_size\"])),\n",
    "                    epochs=int(spec.get(\"train_params\", {}).get(\"epochs\", tp[\"epochs\"])),\n",
    "                ))\n",
    "\n",
    "            device = data[\"device\"]\n",
    "            scaler = data[\"scaler\"]\n",
    "            obj = make_obj(model, scaler, device, chunk=int(tp.get(\"plot_chunk\", 4096)))\n",
    "            gt = data.get(\"true\", None)\n",
    "\n",
    "            # ---- compute GT optimum (once per seed) ----\n",
    "            if seed not in gt_cache_by_seed:\n",
    "                t_gt0 = time.perf_counter()\n",
    "                try:\n",
    "                    x_gt, y_gt, gt_info = solve_true_opt(gt, xmin, xmax, sum_eq_seed, time_limit=time_limit, verbose=False, seed=seed)\n",
    "                    gt_time = time.perf_counter() - t_gt0\n",
    "                    y_gt_check = float(eval_true_obj(gt, x_gt))\n",
    "                    if np.isfinite(y_gt_check) and abs(y_gt_check - float(y_gt)) / (abs(float(y_gt_check)) + 1e-12) > 1e-6:\n",
    "                        fail_rows.append(dict(seed=seed, model=name, stage=\"GT_OPT_CHECK\",\n",
    "                                              error=f\"gt solver mismatch: solver={y_gt:.6g} eval_true_obj={y_gt_check:.6g}\"))\n",
    "                    gt_cache_by_seed[seed] = dict(x=str(np.asarray(x_gt, int).tolist()),\n",
    "                                                  true_y=float(y_gt_check if np.isfinite(y_gt_check) else y_gt),\n",
    "                                                  runtime=float(gt_time),\n",
    "                                                  err=None)\n",
    "                except Exception as e:\n",
    "                    gt_time = time.perf_counter() - t_gt0\n",
    "                    gt_cache_by_seed[seed] = dict(x=None, true_y=np.nan, runtime=float(gt_time), err=repr(e))\n",
    "                    fail_rows.append(dict(seed=seed, model=name, stage=\"GT_OPT\", error=repr(e)))\n",
    "\n",
    "            # ---- learning metrics: best val (normalized) + test loss (normalized) ----\n",
    "            test_norm, err_te = safe_norm_mse(model, scaler, device, data['raw']['Xte'], data['raw']['yte'], chunk=int(tp.get('plot_chunk', 4096)))\n",
    "            if err_te:\n",
    "                fail_rows.append(dict(seed=seed, model=name, stage=\"EVAL_TEST\", error=err_te))\n",
    "\n",
    "            if obj._err_count > 0:\n",
    "                fail_rows.append(dict(seed=seed, model=name, stage=\"OBJ\",\n",
    "                                      error=f\"obj() had {obj._err_count} forward failures; first={obj._err_first}\"))\n",
    "\n",
    "            val_curve = np.asarray(hist.get(\"val_mse_norm\", []), dtype=float)\n",
    "            best_ep = int(np.argmin(val_curve) + 1) if val_curve.size else np.nan\n",
    "            best_val = float(np.min(val_curve)) if val_curve.size else np.nan\n",
    "\n",
    "            learn_rows.append(dict(\n",
    "                seed=seed, model=name,\n",
    "                train_time=float(train_time),\n",
    "                best_epoch=best_ep,\n",
    "                best_val=float(best_val),\n",
    "                test=float(test_norm) if np.isfinite(test_norm) else np.nan,\n",
    "                train_err=(err_te),\n",
    "            ))\n",
    "\n",
    "            # ---- IP SOLVE (learned objective) ----\n",
    "            t0 = time.perf_counter()\n",
    "            try:\n",
    "                x_ip, y_ip, info = solve_ip(\n",
    "                    model_type, model, scaler, xmin, xmax, sum_eq_seed,\n",
    "                    time_limit=time_limit, verbose=True, seed=seed\n",
    "                )\n",
    "                ip_time = time.perf_counter() - t0\n",
    "\n",
    "                rel_err = check_ip_matches_obj(name, obj, x_ip, y_ip, strict=strict_ip_check, tol=ip_check_tol)\n",
    "                if rel_err > ip_check_tol:\n",
    "                    fail_rows.append(dict(seed=seed, model=name, stage=\"IP_CHECK\",\n",
    "                                          error=f\"rel_err={rel_err:.3e} (tol={ip_check_tol})\"))\n",
    "\n",
    "                opt_rows.append(dict(\n",
    "                    seed=seed, model=name, method=\"IP\",\n",
    "                    x=str(np.asarray(x_ip, int).tolist()),\n",
    "                    y=float(y_ip),\n",
    "                    true_y=float(eval_true_obj(gt, x_ip)),\n",
    "                    runtime=float(ip_time),\n",
    "                    status=(info.get(\"status\") if isinstance(info, dict) else None),\n",
    "                    gap=(info.get(\"gap\") if isinstance(info, dict) else None),\n",
    "                    ip_rel_err=float(rel_err),\n",
    "                    err=None,\n",
    "                ))\n",
    "            except Exception as e:\n",
    "                ip_time = time.perf_counter() - t0\n",
    "                opt_rows.append(dict(\n",
    "                    seed=seed, model=name, method=\"IP\",\n",
    "                    x=None, y=np.nan, true_y=np.nan,\n",
    "                    runtime=float(ip_time),\n",
    "                    status=None, gap=None, ip_rel_err=np.nan,\n",
    "                    err=repr(e),\n",
    "                ))\n",
    "                fail_rows.append(dict(seed=seed, model=name, stage=\"IP\", error=repr(e)))\n",
    "\n",
    "    # ---- build DFs ----\n",
    "    spec_df  = pd.DataFrame(spec_rows).drop_duplicates(\"model\").sort_values(\"model\").reset_index(drop=True)\n",
    "    learn_df = pd.DataFrame(learn_rows).sort_values([\"model\", \"seed\"]).reset_index(drop=True)\n",
    "    opt_df   = pd.DataFrame(opt_rows).sort_values([\"model\", \"seed\", \"method\"]).reset_index(drop=True)\n",
    "\n",
    "    fail_df = pd.DataFrame(fail_rows)\n",
    "    if fail_df.empty:\n",
    "        fail_df = pd.DataFrame(columns=[\"seed\", \"model\", \"stage\", \"error\"])\n",
    "    else:\n",
    "        for c in [\"stage\", \"model\", \"seed\"]:\n",
    "            if c not in fail_df.columns:\n",
    "                fail_df[c] = np.nan\n",
    "        fail_df = fail_df.sort_values([\"stage\", \"model\", \"seed\"]).reset_index(drop=True)\n",
    "\n",
    "    gt_df = pd.DataFrame([dict(seed=s, **d) for s, d in sorted(gt_cache_by_seed.items())]).sort_values(\"seed\")\n",
    "\n",
    "    # ---- per-seed gaps ----\n",
    "    gap_seed_df = pd.DataFrame(columns=[\"seed\", \"model\", \"ip_true_over_gt\"])\n",
    "    if not opt_df.empty:\n",
    "        pivot_y = opt_df.pivot_table(index=[\"seed\", \"model\"], columns=\"method\", values=\"y\", aggfunc=\"first\")\n",
    "        pivot_true = opt_df.pivot_table(index=[\"seed\", \"model\"], columns=\"method\", values=\"true_y\", aggfunc=\"first\")\n",
    "\n",
    "        # IP true vs GT optimum true\n",
    "        gt_true = gt_df.set_index(\"seed\")[\"true_y\"] if not gt_df.empty else pd.Series(dtype=float)\n",
    "        ip_true_over_gt = []\n",
    "        for (seed, model), row in pivot_true.iterrows():\n",
    "            ipt = float(row.get(\"IP\", np.nan))\n",
    "            gtt = float(gt_true.get(seed, np.nan))\n",
    "            if np.isfinite(ipt) and np.isfinite(gtt):\n",
    "                ip_true_over_gt.append(((seed, model), ipt / (gtt + 1e-12)))\n",
    "            else:\n",
    "                ip_true_over_gt.append(((seed, model), np.nan))\n",
    "        ip_true_over_gt = pd.Series({k: v for k, v in ip_true_over_gt})\n",
    "        \n",
    "        gap_seed_df = pd.DataFrame({\n",
    "            \"seed\": [k[0] for k in ip_true_over_gt.index],\n",
    "            \"model\": [k[1] for k in ip_true_over_gt.index],\n",
    "            \"ip_true_over_gt\": [float(ip_true_over_gt.get(k, np.nan)) for k in ip_true_over_gt.index],\n",
    "        })\n",
    "\n",
    "\n",
    "    # ---- LEARNING SUMMARY (mean ± SE over seeds) ----\n",
    "    learn_sum_rows = []\n",
    "    for model in sorted(learn_df[\"model\"].unique()):\n",
    "        sub = learn_df[learn_df[\"model\"] == model]\n",
    "        m, se = mean_se(sub[\"train_time\"]); train_time_s = fmt_mean_se(m, se)\n",
    "        m, se = mean_se(sub[\"best_val\"]);   best_val_s   = fmt_mean_se(m, se)\n",
    "        m, se = mean_se(sub[\"test\"]);       test_s       = fmt_mean_se(m, se)\n",
    "        learn_sum_rows.append(dict(model=model, train_time=train_time_s, best_val=best_val_s, test=test_s))\n",
    "    learn_summary_df = pd.DataFrame(learn_sum_rows).sort_values(\"model\").reset_index(drop=True)\n",
    "\n",
    "    # ---- OPTIMIZATION SUMMARY ----\n",
    "    opt_sum_rows = []\n",
    "    gt_x_repr = (repr_solution(gt_df.get(\"x\", pd.Series(dtype=str)), gt_df.get(\"seed\", None))\n",
    "                if not gt_df.empty else None)\n",
    "    m, se = mean_se(gt_df.get(\"true_y\", pd.Series(dtype=float))); gt_true_s = fmt_mean_se(m, se)\n",
    "    m, se = mean_se(gt_df.get(\"runtime\", pd.Series(dtype=float))); gt_time_s = fmt_mean_se(m, se)\n",
    "\n",
    "    for model in sorted(opt_df[\"model\"].unique()):\n",
    "        sub = opt_df[opt_df[\"model\"] == model]\n",
    "        row = {\"model\": model}\n",
    "\n",
    "        ip = sub[sub[\"method\"] == \"IP\"]\n",
    "\n",
    "        row[\"IP_x\"] = repr_solution(ip[\"x\"], ip.get(\"seed\", None))\n",
    "        m, se = mean_se(ip[\"y\"]);       row[\"IP_y\"] = fmt_mean_se(m, se)\n",
    "        m, se = mean_se(ip[\"true_y\"]);  row[\"IP_true_y\"] = fmt_mean_se(m, se)\n",
    "        m, se = mean_se(ip[\"runtime\"]); row[\"IP_time\"] = fmt_mean_se(m, se)\n",
    "\n",
    "        row[\"GT_x\"] = gt_x_repr\n",
    "        row[\"GT_true_y\"] = gt_true_s\n",
    "        row[\"GT_time\"] = gt_time_s\n",
    "\n",
    "        gsub = gap_seed_df[gap_seed_df[\"model\"] == model] if not gap_seed_df.empty else pd.DataFrame()\n",
    "        m, se = mean_se(gsub.get(\"ip_true_over_gt\", pd.Series(dtype=float))); row[\"IP_true/GT\"] = fmt_mean_se(m, se)\n",
    "\n",
    "        opt_sum_rows.append(row)\n",
    "\n",
    "    opt_summary_df = pd.DataFrame(opt_sum_rows).sort_values(\"model\").reset_index(drop=True)\n",
    "\n",
    "    # ---- Print tables ----\n",
    "    print(\"\\n=== MODEL SPECS (from seed 0 run) ===\")\n",
    "    if not spec_df.empty:\n",
    "        print(spec_df[[\"model\", \"n_params\", \"details\", \"lr\", \"batch_size\", \"epochs\"]].to_string(index=False))\n",
    "    else:\n",
    "        print(\"None\")\n",
    "\n",
    "    print(\"\\n=== LEARNING SUMMARY (mean ± SE over seeds) ===\")\n",
    "    if not learn_summary_df.empty:\n",
    "        print(learn_summary_df[[\"model\", \"test\"]].to_string(index=False))\n",
    "\n",
    "\n",
    "    else:\n",
    "        print(\"None\")\n",
    "\n",
    "    print(\"\\n=== OPTIMIZATION SUMMARY (mean ± SE over seeds) ===\")\n",
    "    if not opt_summary_df.empty:\n",
    "        cols = [\n",
    "            \"model\", \"IP_time\", \"IP_true/GT\",\n",
    "        ]\n",
    "        print(opt_summary_df[cols].to_string(index=False))\n",
    "    else:\n",
    "        print(\"None\")\n",
    "\n",
    "    print(\"\\n=== FAILURES / WARNINGS (if any) ===\")\n",
    "    if fail_df.shape[0] == 0:\n",
    "        print(\"None\")\n",
    "    else:\n",
    "        print(fail_df.to_string(index=False))\n",
    "\n",
    "    return spec_df, learn_df, opt_df, fail_df, learn_summary_df, opt_summary_df, gt_df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e88e7a1f",
   "metadata": {},
   "source": [
    "## Tests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ac79bcb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ===================== ASSIGNMENT DATASET =====================\n",
    "N_SEEDS = 5\n",
    "\n",
    "IP_CHECK_TOL = 1e-4\n",
    "\n",
    "dataset_type = \"assignment\"\n",
    "dataset_params = dict(\n",
    "    K=2000,\n",
    "    num_nodes=16,\n",
    "    c_min=1,\n",
    "    c_max=1000,\n",
    "    x_max=50,\n",
    "    noise_std=0.0,\n",
    "    seed=0,\n",
    ")\n",
    "\n",
    "in_dim = int(dataset_params[\"num_nodes\"])\n",
    "print(f\"[ASSIGNMENT] in_dim={in_dim}\")\n",
    "\n",
    "train_base = dict(\n",
    "    epochs=1000,\n",
    "    batch_size=8,\n",
    "    val_frac=0.15,\n",
    "    test_frac=0.15,\n",
    "    seed=0,\n",
    "    device=\"cpu\",\n",
    "    eps=1e-8,\n",
    ")\n",
    "\n",
    "dfn_params = dict(\n",
    "    input_dim=in_dim, layer_sizes=[32, 100, 32], p_list=[1, 1],\n",
    "    seed=0, alpha=5e-3, beta=-2.0\n",
    ")\n",
    "mlp_params  = dict(in_dim=in_dim, hidden_dims=[150, 150], out_dim=1)\n",
    "lset_params = dict(in_dim=in_dim, n_pieces=1570, T=0.05)\n",
    "\n",
    "lr_map = dict(DFN=1e-2, MLP=1e-3, LSET=1e-3)\n",
    "time_limit = 3600\n",
    "\n",
    "xmin  = np.full(in_dim, 0, dtype=int)\n",
    "xmax  = np.full(in_dim, 50, dtype=int)\n",
    "\n",
    "runs = [\n",
    "    (\"DFN\",        \"DFN\", dfn_params),\n",
    "    (\"MLP\",        \"MLP\", mlp_params),\n",
    "    (\"LSET\",       \"LSET\", lset_params),\n",
    "]\n",
    "\n",
    "_ = run_benchmark(\n",
    "    dataset_type=dataset_type,\n",
    "    dataset_params=dataset_params,\n",
    "    runs=runs,\n",
    "    train_base=train_base,\n",
    "    lr_map=lr_map,\n",
    "    xmin=xmin, \n",
    "    xmax=xmax,\n",
    "    n_seeds=N_SEEDS,\n",
    "    time_limit=time_limit,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92dbc0ee-bbe1-4fdd-8782-44ece09dda4b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:dfn]",
   "language": "python",
   "name": "conda-env-dfn-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
