{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f5b14808",
   "metadata": {},
   "source": [
    "# DFN notebook (cleaned)\n",
    "\n",
    "This is a simplified/cleaned rewrite of the original notebook. It keeps the same behavior and outputs, but the code is shorter, more consistent, and easier to read."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2efb5921",
   "metadata": {},
   "source": [
    "## 1) Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce8269a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import annotations\n",
    "\n",
    "import sys, time, math, io, contextlib\n",
    "from pathlib import Path\n",
    "from typing import Optional, List, Tuple, Dict, Any\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.optimize import linear_sum_assignment\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from IPython.display import display\n",
    "\n",
    "_TOL = 1e-9"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "417583e0",
   "metadata": {},
   "source": [
    "## 2) LEMON Min-Cost Flow binding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62767112",
   "metadata": {},
   "outputs": [],
   "source": [
    "import cppimport\n",
    "\n",
    "# The original notebook assumes `lemon_mcf` is available as a local cppimport module.\n",
    "repo = Path().resolve().parent\n",
    "if str(repo) not in sys.path:\n",
    "    sys.path.insert(0, str(repo))\n",
    "\n",
    "lemon_mcf = cppimport.imp(\"lemon_mcf\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15e6bd1b",
   "metadata": {},
   "source": [
    "### Quick sanity check (same as original)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88381d28",
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 3\n",
    "src    = np.array([0, 0, 1], dtype=np.int64)\n",
    "dst    = np.array([1, 2, 2], dtype=np.int64)\n",
    "cost   = np.array([2.0, 1.0, 3.0], dtype=np.float64)\n",
    "cap    = np.array([5.0, 2.0, 4.0], dtype=np.float64)\n",
    "supply = np.array([3.0, -1.0, -2.0], dtype=np.float64)\n",
    "\n",
    "out_min_cost_flow = lemon_mcf.solve_mcf(n, src, dst, cost, cap, supply)\n",
    "print(out_min_cost_flow)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75235271",
   "metadata": {},
   "source": [
    "## 3) Dataset generators"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67763cc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_mdvsp_dataset(\n",
    "    K: int,\n",
    "    filename: str,\n",
    "    x_min,\n",
    "    x_max,\n",
    "    noise_std: float = 0.0,\n",
    "    seed: int = 0,\n",
    "    max_trips=None,\n",
    "    max_succ=None,\n",
    "):\n",
    "    \"\"\"Multiple-Depot Vehicle Scheduling (MDVSP) dataset.\n",
    "\n",
    "    Returns:\n",
    "      X: (K, m) integer-ish capacities (float32)\n",
    "      y: (K,) min-cost-flow objective values (float32)\n",
    "      gt: dict containing the fixed network pieces (for later evaluation)\n",
    "    \"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "\n",
    "    with open(filename) as f:\n",
    "        m, n, l = map(int, f.readline().split())\n",
    "        f.readline()  # blank line\n",
    "        trips = np.loadtxt(f, max_rows=n, dtype=np.int64)[:max_trips]\n",
    "        D = np.loadtxt(f, max_rows=l, dtype=np.int64)\n",
    "\n",
    "    p, s, q, e = trips.T\n",
    "    ntr = len(trips)\n",
    "\n",
    "    # node ids\n",
    "    SS = 0\n",
    "    depS = 1 + np.arange(m)\n",
    "    depT = 1 + m + np.arange(m)\n",
    "    trS  = 1 + 2*m + np.arange(ntr)\n",
    "    trT  = 1 + 2*m + ntr + np.arange(ntr)\n",
    "    TT = 1 + 2*m + 2*ntr\n",
    "    N = TT + 1\n",
    "\n",
    "    src, dst, cost, cap = [], [], [], []\n",
    "\n",
    "    # depot boundary arcs (capacity is what we learn/provide per sample)\n",
    "    for d in range(m):\n",
    "        src += [SS, int(depT[d])]\n",
    "        dst += [int(depS[d]), TT]\n",
    "        cost += [0.0, 0.0]\n",
    "        cap  += [0.0, 0.0]\n",
    "    idxSS = np.arange(0, 2*m, 2)  # arcs SS->depS\n",
    "    idxTT = np.arange(1, 2*m, 2)  # arcs depT->TT\n",
    "\n",
    "    # trip arcs (always cap=1)\n",
    "    src += trS.tolist()\n",
    "    dst += trT.tolist()\n",
    "    cost += [0.0] * ntr\n",
    "    cap  += [1.0] * ntr\n",
    "\n",
    "    # depot-to-trip and trip-to-depot arcs\n",
    "    for d in range(m):\n",
    "        # depS -> trS\n",
    "        src += [int(depS[d])] * ntr\n",
    "        dst += trS.tolist()\n",
    "        cost += (5000 + 10 * D[d, p]).astype(float).tolist()\n",
    "        cap  += [1.0] * ntr\n",
    "\n",
    "        # trT -> depT\n",
    "        src += trT.tolist()\n",
    "        dst += [int(depT[d])] * ntr\n",
    "        cost += (5000 + 10 * D[q, d]).astype(float).tolist()\n",
    "        cap  += [1.0] * ntr\n",
    "\n",
    "    # feasible trip successor arcs (trT -> next trS)\n",
    "    order = np.argsort(s)\n",
    "    p2, s2 = p[order], s[order]\n",
    "    max_succ_eff = ntr if max_succ is None else int(max_succ)\n",
    "\n",
    "    for i in range(ntr):\n",
    "        travel = D[q[i], p2]                          # time from trip i end depot -> next trip start depot\n",
    "        feas = np.flatnonzero(s2 >= e[i] + travel)[:max_succ_eff]\n",
    "        j = order[feas]\n",
    "        if j.size:\n",
    "            src += [int(trT[i])] * j.size\n",
    "            dst += trS[j].tolist()\n",
    "            cost += (8 * travel[feas] + 2 * (s[j] - e[i])).astype(float).tolist()\n",
    "            cap  += [1.0] * j.size\n",
    "\n",
    "    src = np.asarray(src, dtype=np.int64)\n",
    "    dst = np.asarray(dst, dtype=np.int64)\n",
    "    cost = np.asarray(cost, dtype=np.float64)\n",
    "    cap0 = np.asarray(cap, dtype=np.float64)\n",
    "\n",
    "    # sample capacities X and compute y via max-flow + min-cost-flow\n",
    "    X = rng.integers(x_min, np.asarray(x_max) + 1, size=(K, m)).astype(np.float64)\n",
    "    y = np.empty(K, dtype=np.float64)\n",
    "\n",
    "    for k in range(K):\n",
    "        cap_k = cap0.copy()\n",
    "        cap_k[idxSS] = X[k]\n",
    "        cap_k[idxTT] = X[k]\n",
    "\n",
    "        Fmax = lemon_mcf.max_flow(N, src, dst, cap_k, SS, TT)[\"value\"]\n",
    "        supply = np.zeros(N, dtype=np.float64)\n",
    "        supply[SS] = Fmax\n",
    "        supply[TT] = -Fmax\n",
    "\n",
    "        y[k] = lemon_mcf.solve_mcf(N, src, dst, cost, cap_k, supply)[\"total_cost\"]\n",
    "        if noise_std:\n",
    "            y[k] += noise_std * rng.normal()\n",
    "\n",
    "    gt = dict(\n",
    "        type=\"mdvsp\", N=int(N), SS=int(SS), TT=int(TT),\n",
    "        src=src, dst=dst, cost=cost, cap0=cap0, idxSS=idxSS, idxTT=idxTT\n",
    "    )\n",
    "    return X.astype(np.float32), y.astype(np.float32), gt\n",
    "\n",
    "\n",
    "def generate_bipartite_subset_matching_dataset(\n",
    "    K: int, num_nodes: int, c_min: int, c_max: int, noise_std: float = 0.0, seed: int = 0\n",
    "):\n",
    "    \"\"\"Assignment-style dataset: choose a subset of left nodes, match to right nodes with min cost.\"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "    C = rng.integers(c_min, c_max + 1, size=(num_nodes, num_nodes)).astype(np.float32)\n",
    "\n",
    "    X = np.zeros((K, num_nodes), dtype=np.float32)\n",
    "    y = np.zeros((K,), dtype=np.float32)\n",
    "\n",
    "    for k in range(K):\n",
    "        s = int(rng.integers(1, num_nodes + 1))\n",
    "        idx = rng.choice(num_nodes, size=s, replace=False)\n",
    "        X[k, idx] = 1.0\n",
    "\n",
    "        r, c = linear_sum_assignment(C[idx, :])\n",
    "        y[k] = C[idx, :][r, c].sum()\n",
    "        if noise_std:\n",
    "            y[k] += noise_std * rng.normal()\n",
    "\n",
    "    gt = {\"type\": \"assignment\", \"C\": C.astype(np.float32)}\n",
    "    return X, y, gt\n",
    "\n",
    "\n",
    "def generate_convex_quadratic_dataset(\n",
    "    K: int,\n",
    "    dim: int,\n",
    "    eigen_min: float,\n",
    "    eigen_max: float,\n",
    "    x_min,\n",
    "    x_max,\n",
    "    noise_std: float = 0.0,\n",
    "    seed: int = 0,\n",
    "    x_star_zero: bool = True,\n",
    "):\n",
    "    \"\"\"y = (x - x*)^T Q (x - x*) + noise, with Q symmetric PSD.\"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "    U, R = np.linalg.qr(rng.standard_normal((dim, dim)))\n",
    "    U *= np.sign(np.diag(R) + 1e-12)\n",
    "    Q = U @ np.diag(rng.uniform(eigen_min, eigen_max, dim)) @ U.T\n",
    "\n",
    "    x_star = np.zeros(dim, dtype=np.int64) if x_star_zero else rng.integers(x_min, x_max + 1, size=dim)\n",
    "\n",
    "    X = rng.integers(x_min, x_max + 1, size=(K, dim)).astype(np.float32)\n",
    "    d = X - x_star\n",
    "    y = np.einsum(\"bi,ij,bj->b\", d, Q, d) + noise_std * rng.normal(size=K)\n",
    "\n",
    "    gt = {\"type\": \"quadratic\", \"Q\": Q.astype(np.float32), \"x_star\": x_star.astype(np.int64)}\n",
    "    return X.astype(np.float32), y.astype(np.float32), gt\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1dfc62e4",
   "metadata": {},
   "source": [
    "### Quick checks (same outputs as the original demo cells)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89201b4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: MDVSP requires that `filename` exists on disk.\n",
    "# If you don't have the .dat file, comment this block out.\n",
    "X, y, _ = make_mdvsp_dataset(\n",
    "    K=50, filename=\"RN-8-3000-03.dat\", x_min=0, x_max=6, noise_std=1.0, seed=1, max_trips=200, max_succ=5\n",
    ")\n",
    "print(X.shape, y.shape, y[:5])\n",
    "\n",
    "X, y, _ = generate_bipartite_subset_matching_dataset(\n",
    "    K=50, num_nodes=20, c_min=1, c_max=10, noise_std=0.0, seed=0\n",
    ")\n",
    "print(X.shape, y.shape, y[:5])\n",
    "\n",
    "X, y, _ = generate_convex_quadratic_dataset(\n",
    "    K=50, dim=10, eigen_min=1.0, eigen_max=20.0, x_min=-100, x_max=100, noise_std=0.0, seed=0\n",
    ")\n",
    "print(X.shape, y.shape, y[:5])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70071f5f",
   "metadata": {},
   "source": [
    "## 4) Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07278db8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _ste_round(x: torch.Tensor) -> torch.Tensor:\n",
    "    \"\"\"Straight-through estimator for rounding.\"\"\"\n",
    "    return x + (torch.round(x) - x).detach()\n",
    "\n",
    "\n",
    "class _MCFValue(torch.autograd.Function):\n",
    "    \"\"\"Differentiable min-cost-flow value using LEMON.\n",
    "\n",
    "    Gradients:\n",
    "      dV/dcost = flow\n",
    "      dV/dcap  = reduced_cost on arcs at capacity (0 otherwise)\n",
    "      dV/dsupply = potentials (with gauge fixing)\n",
    "    \"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    def forward(ctx, n_nodes, src, dst, cost, cap, supply):\n",
    "        n = int(n_nodes)\n",
    "\n",
    "        src = src.to(dtype=torch.int64).contiguous()\n",
    "        dst = dst.to(dtype=torch.int64).contiguous()\n",
    "\n",
    "        m = int(src.numel())\n",
    "        if (dst.numel() != m) or (cost.numel() != m) or (cap.numel() != m) or (supply.numel() != n):\n",
    "            raise ValueError(\"Bad shapes for MCF inputs.\")\n",
    "        if torch.abs(supply.double().sum()) > _TOL:\n",
    "            raise ValueError(\"Require sum(supply)=0\")\n",
    "\n",
    "        def as_np(t: torch.Tensor, dtype):\n",
    "            return t.detach().cpu().contiguous().view(-1).numpy().astype(dtype, copy=False)\n",
    "\n",
    "        out = lemon_mcf.solve_mcf(\n",
    "            n,\n",
    "            as_np(src, np.int64),\n",
    "            as_np(dst, np.int64),\n",
    "            as_np(cost, np.float64),\n",
    "            as_np(cap, np.float64),\n",
    "            as_np(supply, np.float64),\n",
    "            tol=_TOL,\n",
    "        )\n",
    "        if out[\"status\"] != 1:\n",
    "            raise RuntimeError(f\"LEMON failed (status={out['status']})\")\n",
    "\n",
    "        flow = out[\"flow\"]\n",
    "        pot = out[\"potential\"]\n",
    "        red = out[\"reduced_cost\"]\n",
    "        at = out.get(\"at_cap\", out.get(\"at_capacity\", None))\n",
    "        if at is None:\n",
    "            at = np.abs(flow - as_np(cap, np.float64)) <= _TOL\n",
    "\n",
    "        ctx.flow, ctx.pot, ctx.red, ctx.at = flow, pot, red, at\n",
    "        return cost.new_tensor(float(out[\"total_cost\"]))\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, g):\n",
    "        dev, dt = g.device, g.dtype\n",
    "        flow = torch.as_tensor(ctx.flow, device=dev, dtype=dt)\n",
    "        pot  = torch.as_tensor(ctx.pot,  device=dev, dtype=dt)\n",
    "        red  = torch.as_tensor(ctx.red,  device=dev, dtype=dt)\n",
    "        at   = torch.as_tensor(ctx.at,   device=dev, dtype=torch.bool)\n",
    "\n",
    "        grad_cost = flow\n",
    "        grad_cap  = torch.where(at, red, torch.zeros_like(red))\n",
    "        grad_sup  = pot.mean() - pot  # gauge-fix\n",
    "\n",
    "        return None, None, None, grad_cost * g, grad_cap * g, grad_sup * g\n",
    "\n",
    "\n",
    "class DFN(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        input_dim: int,\n",
    "        layer_sizes,\n",
    "        p_list,\n",
    "        big_cost: float = 1e6,\n",
    "        big_cap: float = 1e6,\n",
    "        seed: int = 0,\n",
    "        A_fixed=None,\n",
    "        alpha: float = 1e-6,\n",
    "        beta: float = -0.0,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.alpha = float(alpha)\n",
    "        self.beta  = float(beta)\n",
    "\n",
    "        layer_sizes = list(map(int, layer_sizes))\n",
    "        if len(layer_sizes) < 2 or len(p_list) != len(layer_sizes) - 1:\n",
    "            raise ValueError(\"Need len(layer_sizes)>=2 and len(p_list)=len(layer_sizes)-1\")\n",
    "\n",
    "        self.n = int(sum(layer_sizes))\n",
    "        if self.n <= 0:\n",
    "            raise ValueError(\"sum(layer_sizes) must be > 0\")\n",
    "\n",
    "        # node indices per layer\n",
    "        layers, off = [], 0\n",
    "        for s in layer_sizes:\n",
    "            layers.append(torch.arange(off, off + s, dtype=torch.long))\n",
    "            off += s\n",
    "\n",
    "        L1, LK = layers[0], layers[-1]\n",
    "        if L1.numel() == 0 or LK.numel() == 0:\n",
    "            raise ValueError(\"First/last layer must be non-empty.\")\n",
    "\n",
    "        # one node is used for the supply gauge constraint\n",
    "        self.fix_node = int(LK[-1].item())\n",
    "        boundary = torch.cat([L1, LK[:-1]], 0)  # boundary nodes for A w + b\n",
    "        self.register_buffer(\"boundary\", boundary)\n",
    "\n",
    "        gen = torch.Generator().manual_seed(int(seed))\n",
    "\n",
    "        def bipartite(U: torch.Tensor, V: torch.Tensor):\n",
    "            su, sv = int(U.numel()), int(V.numel())\n",
    "            return U.repeat_interleave(sv), V.repeat(su)\n",
    "\n",
    "        def sample_edges(U, V, p: float):\n",
    "            s, t = bipartite(U, V)\n",
    "            if p < 1.0:\n",
    "                keep = torch.rand(s.numel(), generator=gen) < float(p)\n",
    "                s, t = s[keep], t[keep]\n",
    "            return s, t\n",
    "\n",
    "        # learnable arcs between consecutive layers (both directions)\n",
    "        sf, tf, sb, tb = [], [], [], []\n",
    "        for i, p in enumerate(map(float, p_list)):\n",
    "            if not (0.0 <= p <= 1.0):\n",
    "                raise ValueError(\"p_list entries must be in [0,1]\")\n",
    "\n",
    "            s, t = sample_edges(layers[i], layers[i + 1], p)  # forward\n",
    "            sf.append(s); tf.append(t)\n",
    "\n",
    "            s, t = sample_edges(layers[i + 1], layers[i], p)  # backward\n",
    "            sb.append(s); tb.append(t)\n",
    "\n",
    "        src_param = torch.cat([torch.cat(sf, 0), torch.cat(sb, 0)], 0)\n",
    "        dst_param = torch.cat([torch.cat(tf, 0), torch.cat(tb, 0)], 0)\n",
    "        if src_param.numel() == 0:\n",
    "            raise ValueError(\"No learnable arcs (increase p_list / layer sizes).\")\n",
    "\n",
    "        # fixed big arcs L1 <-> LK (always present)\n",
    "        s1, t1 = bipartite(L1, LK)\n",
    "        s2, t2 = bipartite(LK, L1)\n",
    "        src_fixed = torch.cat([s1, s2], 0)\n",
    "        dst_fixed = torch.cat([t1, t2], 0)\n",
    "        m_fixed = int(src_fixed.numel())\n",
    "\n",
    "        self.register_buffer(\"src\", torch.cat([src_param, src_fixed], 0))\n",
    "        self.register_buffer(\"dst\", torch.cat([dst_param, dst_fixed], 0))\n",
    "        self.register_buffer(\"cap_fixed\",  torch.full((m_fixed,), float(big_cap),  dtype=torch.float32))\n",
    "        self.register_buffer(\"cost_fixed\", torch.full((m_fixed,), float(big_cost), dtype=torch.float32))\n",
    "\n",
    "        nb = int(boundary.numel())\n",
    "        input_dim = int(input_dim)\n",
    "\n",
    "        m_param = int(src_param.numel())\n",
    "        self.cap_raw  = nn.Parameter(torch.zeros(m_param) + 0.542)\n",
    "        self.cost_raw = nn.Parameter(torch.randn(m_param) + 1.0)\n",
    "        self.b_raw    = nn.Parameter(torch.zeros(nb))\n",
    "\n",
    "        if A_fixed is None:\n",
    "            # same construction as original: a sparse-ish default A\n",
    "            A = torch.zeros(nb, input_dim)\n",
    "            rows = torch.arange(nb)\n",
    "            A[rows, rows % input_dim] = 1.0\n",
    "            self.A = nn.Parameter(A)\n",
    "        else:\n",
    "            A_fixed = torch.as_tensor(A_fixed, dtype=torch.float32)\n",
    "            if A_fixed.shape != (nb, input_dim):\n",
    "                raise ValueError(f\"A_fixed must have shape {(nb, input_dim)}, got {tuple(A_fixed.shape)}\")\n",
    "            self.register_buffer(\"A\", A_fixed)\n",
    "\n",
    "    def forward(self, w: torch.Tensor) -> torch.Tensor:\n",
    "        capP  = _ste_round(F.softplus(self.cap_raw))\n",
    "        costP = self.cost_raw\n",
    "        b     = _ste_round(self.b_raw)\n",
    "        A     = _ste_round(self.A) if isinstance(self.A, nn.Parameter) else self.A\n",
    "\n",
    "        cap  = torch.cat([capP,  self.cap_fixed.to(w.device, w.dtype)], 0)\n",
    "        cost = torch.cat([costP, self.cost_fixed.to(w.device, w.dtype)], 0)\n",
    "\n",
    "        def one(w1: torch.Tensor) -> torch.Tensor:\n",
    "            supply = torch.zeros(self.n, device=w1.device, dtype=torch.float64)\n",
    "            supply[self.boundary] = (A.double() @ w1.double()) + b.double()\n",
    "            supply[self.fix_node] = -supply.sum()\n",
    "            return _MCFValue.apply(self.n, self.src, self.dst, cost, cap, supply)\n",
    "\n",
    "        out = one(w) if w.dim() == 1 else torch.stack([one(wi) for wi in w], 0)\n",
    "        return self.alpha * out + self.beta\n",
    "\n",
    "\n",
    "class MLP(nn.Module):\n",
    "    def __init__(self, in_dim, hidden_dims, out_dim):\n",
    "        super().__init__()\n",
    "        dims = [in_dim] + list(hidden_dims) + [out_dim]\n",
    "        layers = []\n",
    "        for a, b in zip(dims[:-2], dims[1:-1]):\n",
    "            layers += [nn.Linear(a, b), nn.ReLU()]\n",
    "        layers += [nn.Linear(dims[-2], dims[-1])]\n",
    "        self.net = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.net(x)\n",
    "\n",
    "\n",
    "class MaxAffine(nn.Module):\n",
    "    \"\"\"f(x) = max_k (w_k^T x + b_k), returns shape (batch,)\"\"\"\n",
    "\n",
    "    def __init__(self, in_dim: int, n_pieces: int):\n",
    "        super().__init__()\n",
    "        self.W = nn.Parameter(torch.randn(n_pieces, in_dim) / (in_dim**0.5))\n",
    "        self.b = nn.Parameter(torch.zeros(n_pieces))\n",
    "\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        return (x @ self.W.T + self.b).max(dim=1).values\n",
    "\n",
    "\n",
    "class LSET(nn.Module):\n",
    "    def __init__(self, in_dim: int, n_pieces: int, T: float = 0.01):\n",
    "        super().__init__()\n",
    "        self.T = float(T)\n",
    "        if self.T == 0.0:\n",
    "            raise ValueError(\"T must be nonzero\")\n",
    "        self.A = nn.Parameter(torch.randn(n_pieces, in_dim) / (in_dim**0.5))\n",
    "        self.b = nn.Parameter(torch.zeros(n_pieces))\n",
    "\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        z = (x @ self.A.t() + self.b) / self.T\n",
    "        return self.T * torch.logsumexp(z, dim=-1)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7cca0120",
   "metadata": {},
   "source": [
    "## 5) Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a7db872",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _split_train_val_test(X: torch.Tensor, y: torch.Tensor, *, val_frac: float, test_frac: float, seed: int):\n",
    "    N = int(X.shape[0])\n",
    "    n_test = int(round(test_frac * N))\n",
    "    n_val  = int(round(val_frac  * N))\n",
    "    n_train = N - n_val - n_test\n",
    "    if n_train <= 0:\n",
    "        raise ValueError(\"splits too large; train set would be empty\")\n",
    "\n",
    "    g = torch.Generator().manual_seed(int(seed))\n",
    "    perm = torch.randperm(N, generator=g)\n",
    "    i_tr = perm[:n_train]\n",
    "    i_va = perm[n_train:n_train + n_val]\n",
    "    i_te = perm[n_train + n_val:]\n",
    "\n",
    "    return (X[i_tr], y[i_tr]), (X[i_va], y[i_va]), (X[i_te], y[i_te])\n",
    "\n",
    "\n",
    "def _fit_standardizer(Xtr: torch.Tensor, ytr: torch.Tensor, *, eps: float):\n",
    "    x_mean = Xtr.mean(0, keepdim=True)\n",
    "    x_std  = Xtr.std(0, unbiased=False, keepdim=True)\n",
    "    x_std  = torch.where(x_std < eps, torch.ones_like(x_std), x_std)\n",
    "\n",
    "    y_mean = ytr.mean()\n",
    "    y_std  = ytr.std(unbiased=False).clamp_min(eps)\n",
    "\n",
    "    return {\"x_mean\": x_mean, \"x_std\": x_std, \"y_mean\": y_mean, \"y_std\": y_std}\n",
    "\n",
    "\n",
    "def _apply_standardizer(X: torch.Tensor, y: torch.Tensor, scaler):\n",
    "    Xn = (X - scaler[\"x_mean\"]) / scaler[\"x_std\"]\n",
    "    yn = (y - scaler[\"y_mean\"]) / scaler[\"y_std\"]\n",
    "    return Xn, yn\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def _mse_norm(model: nn.Module, loader: DataLoader, device: str):\n",
    "    model.eval()\n",
    "    tot, n = 0.0, 0\n",
    "    for xb, yb in loader:\n",
    "        xb, yb = xb.to(device), yb.to(device)\n",
    "        pred = model(xb).squeeze(-1)\n",
    "        tot += F.mse_loss(pred, yb, reduction=\"sum\").item()\n",
    "        n += yb.numel()\n",
    "    return tot / max(n, 1)\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def _predict_in_chunks(model: nn.Module, X: torch.Tensor, *, chunk: int):\n",
    "    out = []\n",
    "    for i in range(0, int(X.shape[0]), int(chunk)):\n",
    "        out.append(model(X[i:i + chunk]).squeeze(-1))\n",
    "    return torch.cat(out, 0)\n",
    "\n",
    "\n",
    "def generate_and_train_simple(dataset_type, dataset_params, model_type, model_params, train_params=None):\n",
    "    \"\"\"Train a model on a generated dataset.\n",
    "\n",
    "    Returns (same as original notebook):\n",
    "      model, data, history, spec\n",
    "    \"\"\"\n",
    "    tp = train_params or {}\n",
    "\n",
    "    epochs     = int(tp.get(\"epochs\", 200))\n",
    "    batch_sz   = int(tp.get(\"batch_size\", 32))\n",
    "    lr         = float(tp.get(\"lr\", 1e-3))\n",
    "    wd         = float(tp.get(\"weight_decay\", 0.0))\n",
    "    val_frac   = float(tp.get(\"val_frac\", 0.15))\n",
    "    test_frac  = float(tp.get(\"test_frac\", 0.15))\n",
    "    eps        = float(tp.get(\"eps\", 1e-8))\n",
    "    seed       = int(tp.get(\"seed\", 0))\n",
    "    device     = tp.get(\"device\", \"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    plot_every = int(tp.get(\"plot_every\", 0) or 0)\n",
    "    plot_points= int(tp.get(\"plot_points\", 2048))\n",
    "    plot_chunk = int(tp.get(\"plot_chunk\", 4096))\n",
    "    print_stats= bool(tp.get(\"print_stats\", True))\n",
    "\n",
    "    # ---- data ----\n",
    "    dt = str(dataset_type).lower()\n",
    "    if dt == \"mdvsp\":\n",
    "        X, y, gt = make_mdvsp_dataset(**dataset_params)\n",
    "    elif dt == \"assignment\":\n",
    "        X, y, gt = generate_bipartite_subset_matching_dataset(**dataset_params)\n",
    "    elif dt == \"quadratic\":\n",
    "        X, y, gt = generate_convex_quadratic_dataset(**dataset_params)\n",
    "    else:\n",
    "        raise ValueError(\"dataset_type must be: mdvsp | assignment | quadratic\")\n",
    "\n",
    "    X = torch.as_tensor(X, dtype=torch.float32)\n",
    "    y = torch.as_tensor(y, dtype=torch.float32).view(-1)\n",
    "\n",
    "    # ---- quick dataset stats (raw) ----\n",
    "    if print_stats:\n",
    "        with torch.no_grad():\n",
    "            xmn = X.mean(0)\n",
    "            xsd = X.std(0, unbiased=False)\n",
    "            print(\n",
    "                f\"\\n--- Dataset stats ({dataset_type}) ---\\n\"\n",
    "                f\"  X: shape={tuple(X.shape)}  mean(mean)={xmn.mean():.3g}  std(mean)={xsd.mean():.3g}  \"\n",
    "                f\"min={float(X.min()):.3g}  max={float(X.max()):.3g}\\n\"\n",
    "                f\"  y: shape={tuple(y.shape)}  mean={float(y.mean()):.3g}  std={float(y.std(unbiased=False)):.3g}  \"\n",
    "                f\"min={float(y.min()):.3g}  max={float(y.max()):.3g}\\n\"\n",
    "            )\n",
    "\n",
    "    (Xtr, ytr), (Xva, yva), (Xte, yte) = _split_train_val_test(X, y, val_frac=val_frac, test_frac=test_frac, seed=seed)\n",
    "\n",
    "    scaler = _fit_standardizer(Xtr, ytr, eps=eps)\n",
    "    XtrN, ytrN = _apply_standardizer(Xtr, ytr, scaler)\n",
    "    XvaN, yvaN = _apply_standardizer(Xva, yva, scaler)\n",
    "    XteN, yteN = _apply_standardizer(Xte, yte, scaler)\n",
    "\n",
    "    train_loader = DataLoader(TensorDataset(XtrN, ytrN), batch_size=batch_sz, shuffle=True)\n",
    "    val_loader   = DataLoader(TensorDataset(XvaN, yvaN), batch_size=batch_sz, shuffle=False)\n",
    "\n",
    "    # subset for plotting\n",
    "    Nv = int(XvaN.shape[0])\n",
    "    if plot_points <= 0 or plot_points >= Nv:\n",
    "        plot_idx = torch.arange(Nv)\n",
    "    else:\n",
    "        g_plot = torch.Generator().manual_seed(seed + 12345)\n",
    "        plot_idx = torch.randperm(Nv, generator=g_plot)[:plot_points]\n",
    "\n",
    "    # ---- model ----\n",
    "    mt = str(model_type)\n",
    "    mp = dict(model_params)\n",
    "\n",
    "    if mt == \"DFN\":\n",
    "        model = DFN(**mp)\n",
    "        extra = f\"layers={mp.get('layer_sizes')} p_list={mp.get('p_list')} alpha={getattr(model,'alpha',None)} beta={getattr(model,'beta',None)}\"\n",
    "    else:\n",
    "        # non-DFN models don't use alpha/beta in this notebook\n",
    "        mp.pop(\"alpha\", None)\n",
    "        mp.pop(\"beta\", None)\n",
    "        if mt == \"MLP\":\n",
    "            model = MLP(**mp)\n",
    "            extra = f\"hidden={mp.get('hidden_dims')}\"\n",
    "        elif mt == \"MaxAffine\":\n",
    "            model = MaxAffine(**mp)\n",
    "            extra = f\"n_pieces={mp.get('n_pieces')}\"\n",
    "        elif mt == \"LSET\":\n",
    "            model = LSET(**mp)\n",
    "            extra = f\"n_pieces={mp.get('n_pieces')} T={mp.get('T')}\"\n",
    "        else:\n",
    "            raise ValueError(\"model_type must be: DFN | MLP | MaxAffine | LSET\")\n",
    "\n",
    "    model = model.to(device)\n",
    "    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)\n",
    "\n",
    "    n_params = sum(p.numel() for p in model.parameters())\n",
    "    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "    print(\n",
    "        f\"\\n=== Run: {dataset_type} | {model_type} ===\\n\"\n",
    "        f\"  data: N={len(X)}  train/val/test={len(Xtr)}/{len(Xva)}/{len(Xte)}  dim={X.shape[1]}\\n\"\n",
    "        f\"  model: params={n_params:,} (trainable={n_trainable:,})  {extra}\\n\"\n",
    "        f\"  train: device={device}  epochs={epochs}  batch={batch_sz}  lr={lr:g}  wd={wd:g}  seed={seed}\\n\"\n",
    "    )\n",
    "\n",
    "    history = {\"train_mse_norm\": [], \"val_mse_norm\": []}\n",
    "    best_val, best_ep, best_state = float(\"inf\"), 0, None\n",
    "\n",
    "    live = display(None, display_id=True) if plot_every > 0 else None\n",
    "\n",
    "    for ep in range(1, epochs + 1):\n",
    "        model.train()\n",
    "        for xb, yb in train_loader:\n",
    "            xb, yb = xb.to(device), yb.to(device)\n",
    "            loss = F.mse_loss(model(xb).squeeze(-1), yb)\n",
    "            opt.zero_grad(set_to_none=True)\n",
    "            loss.backward()\n",
    "            opt.step()\n",
    "\n",
    "        tr_mse = _mse_norm(model, train_loader, device)\n",
    "        va_mse = _mse_norm(model, val_loader, device)\n",
    "        history[\"train_mse_norm\"].append(tr_mse)\n",
    "        history[\"val_mse_norm\"].append(va_mse)\n",
    "\n",
    "        if va_mse < best_val:\n",
    "            best_val, best_ep = va_mse, ep\n",
    "            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}\n",
    "\n",
    "        if plot_every > 0 and (ep == 1 or ep % plot_every == 0 or ep == epochs):\n",
    "            model.eval()\n",
    "            x_plot = XvaN[plot_idx].to(device)\n",
    "            y_true = yvaN[plot_idx].to(device)\n",
    "            y_pred = _predict_in_chunks(model, x_plot, chunk=plot_chunk)\n",
    "\n",
    "            fig, ax = plt.subplots(1, 2, figsize=(10, 4))\n",
    "            ax[0].plot(history[\"train_mse_norm\"], label=\"train\")\n",
    "            ax[0].plot(history[\"val_mse_norm\"], label=\"val\")\n",
    "            ax[0].set_yscale(\"log\")\n",
    "            ax[0].set_title(f\"Epoch {ep}/{epochs} | val MSE={va_mse:.3e} (norm)\")\n",
    "            ax[0].legend()\n",
    "\n",
    "            yt = y_true.detach().cpu().numpy()\n",
    "            yp = y_pred.detach().cpu().numpy()\n",
    "            ax[1].scatter(yt, yp, s=10)\n",
    "            lo = float(min(yt.min(), yp.min()))\n",
    "            hi = float(max(yt.max(), yp.max()))\n",
    "            ax[1].plot([lo, hi], [lo, hi])\n",
    "            ax[1].set_xlabel(\"y_true (norm)\")\n",
    "            ax[1].set_ylabel(\"y_pred (norm)\")\n",
    "            ax[1].set_title(f\"Val scatter (n={len(yt)})\")\n",
    "\n",
    "            plt.tight_layout()\n",
    "            live.update(fig)\n",
    "            plt.close(fig)\n",
    "\n",
    "    if best_state is not None:\n",
    "        model.load_state_dict(best_state)\n",
    "\n",
    "    if print_stats:\n",
    "        print(f\"[DONE] best val MSE (norm) = {best_val:.3e} @ epoch {best_ep}\\n\")\n",
    "\n",
    "    data = {\n",
    "        \"raw\":  {\"Xtr\": Xtr,  \"ytr\": ytr,  \"Xva\": Xva,  \"yva\": yva,  \"Xte\": Xte,  \"yte\": yte},\n",
    "        \"norm\": {\"Xtr\": XtrN, \"ytr\": ytrN, \"Xva\": XvaN, \"yva\": yvaN, \"Xte\": XteN, \"yte\": yteN},\n",
    "        \"scaler\": scaler,\n",
    "        \"best_val_mse_norm\": float(best_val),\n",
    "        \"best_epoch\": int(best_ep),\n",
    "        \"stats\": {\"n_params\": int(n_params), \"n_trainable\": int(n_trainable)},\n",
    "        \"true\": gt,\n",
    "        \"device\": device,\n",
    "    }\n",
    "\n",
    "    spec = {\n",
    "        \"model\": model_type,\n",
    "        \"extra\": extra,\n",
    "        \"n_params\": int(n_params),\n",
    "        \"n_trainable\": int(n_trainable),\n",
    "        \"model_params\": dict(model_params),\n",
    "        \"train_params\": {\n",
    "            \"epochs\": int(epochs),\n",
    "            \"batch_size\": int(batch_sz),\n",
    "            \"lr\": float(lr),\n",
    "            \"weight_decay\": float(wd),\n",
    "            \"seed\": int(seed),\n",
    "            \"device\": str(device),\n",
    "            \"val_frac\": float(val_frac),\n",
    "            \"test_frac\": float(test_frac),\n",
    "        },\n",
    "    }\n",
    "\n",
    "    return model, data, history, spec\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ace7f63f",
   "metadata": {},
   "source": [
    "## 6) Optimization helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38416530",
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_true_obj(gt, x):\n",
    "    \"\"\"Evaluate the *ground truth* objective for the dataset (when gt is provided).\"\"\"\n",
    "    if not isinstance(gt, dict):\n",
    "        return np.nan\n",
    "\n",
    "    x = np.asarray(x).reshape(-1)\n",
    "    t = gt.get(\"type\", None)\n",
    "\n",
    "    if t == \"quadratic\":\n",
    "        Q = np.asarray(gt[\"Q\"], float)\n",
    "        xs = np.asarray(gt[\"x_star\"], float).reshape(-1)\n",
    "        d = x.astype(float) - xs\n",
    "        return float(d @ Q @ d)\n",
    "\n",
    "    if t == \"assignment\":\n",
    "        C = np.asarray(gt[\"C\"], float)\n",
    "        idx = np.flatnonzero(x > 0.5)\n",
    "        if idx.size == 0:\n",
    "            return 0.0\n",
    "        r, c = linear_sum_assignment(C[idx, :])\n",
    "        return float(C[idx, :][r, c].sum())\n",
    "\n",
    "    if t == \"mdvsp\":\n",
    "        N, SS, TT = int(gt[\"N\"]), int(gt[\"SS\"]), int(gt[\"TT\"])\n",
    "        src = np.asarray(gt[\"src\"], np.int64)\n",
    "        dst = np.asarray(gt[\"dst\"], np.int64)\n",
    "        cost = np.asarray(gt[\"cost\"], float)\n",
    "        cap0 = np.asarray(gt[\"cap0\"], float)\n",
    "        idxSS = np.asarray(gt[\"idxSS\"], np.int64)\n",
    "        idxTT = np.asarray(gt[\"idxTT\"], np.int64)\n",
    "\n",
    "        cap = cap0.copy()\n",
    "        cap[idxSS] = x.astype(float)\n",
    "        cap[idxTT] = x.astype(float)\n",
    "\n",
    "        Fmax = lemon_mcf.max_flow(N, src, dst, cap, SS, TT)[\"value\"]\n",
    "        supply = np.zeros(N, float)\n",
    "        supply[SS] = Fmax\n",
    "        supply[TT] = -Fmax\n",
    "        return float(lemon_mcf.solve_mcf(N, src, dst, cost, cap, supply)[\"total_cost\"])\n",
    "\n",
    "    return np.nan\n",
    "\n",
    "\n",
    "def local_search_l1_int(\n",
    "    f,\n",
    "    x0,\n",
    "    x_min,\n",
    "    x_max,\n",
    "    delta: int,\n",
    "    sum_eq=None,\n",
    "    max_iters: int = 10_000,\n",
    "    print_every: int = 1,\n",
    "):\n",
    "    \"\"\"Naive local search over integer L1-ball neighbors.\"\"\"\n",
    "    x  = np.asarray(x0,    int).ravel()\n",
    "    lo = np.asarray(x_min, int).ravel()\n",
    "    hi = np.asarray(x_max, int).ravel()\n",
    "    n = int(x.size)\n",
    "    delta = int(delta)\n",
    "\n",
    "    assert lo.size == n and hi.size == n\n",
    "    assert np.all(x >= lo) and np.all(x <= hi)\n",
    "    if sum_eq is not None:\n",
    "        sum_eq = int(sum_eq)\n",
    "        assert int(x.sum()) == sum_eq\n",
    "\n",
    "    def eval_batch(X):\n",
    "        y = np.asarray(f(X), float).reshape(-1)\n",
    "        return y\n",
    "\n",
    "    def ok(z):\n",
    "        if np.any(z < lo) or np.any(z > hi):\n",
    "            return False\n",
    "        if sum_eq is not None and int(z.sum()) != sum_eq:\n",
    "            return False\n",
    "        return True\n",
    "\n",
    "    # integer deltas with exact L1 = k\n",
    "    def deltas_exact(k):\n",
    "        d = np.zeros(n, int)\n",
    "\n",
    "        def rec(i, rem):\n",
    "            if i == n:\n",
    "                if rem == 0:\n",
    "                    yield d.copy()\n",
    "                return\n",
    "            for t in range(rem + 1):\n",
    "                if t == 0:\n",
    "                    d[i] = 0\n",
    "                    yield from rec(i + 1, rem)\n",
    "                else:\n",
    "                    d[i] = +t\n",
    "                    yield from rec(i + 1, rem - t)\n",
    "                    d[i] = -t\n",
    "                    yield from rec(i + 1, rem - t)\n",
    "            d[i] = 0\n",
    "\n",
    "        yield from rec(0, k)\n",
    "\n",
    "    t0 = time.perf_counter()\n",
    "    y = float(eval_batch(x[None, :])[0])\n",
    "    hist = [{\"iter\": 0, \"t\": 0.0, \"best_y\": y, \"x\": x.copy()}]\n",
    "    print(f\"iter=0  t=0.00s  best_y={y:.6g}  x={x.tolist()}\")\n",
    "\n",
    "    for it in range(1, int(max_iters) + 1):\n",
    "        cand = []\n",
    "        for k in range(1, delta + 1):\n",
    "            for dlt in deltas_exact(k):\n",
    "                z = x + dlt\n",
    "                if ok(z):\n",
    "                    cand.append(z)\n",
    "\n",
    "        if not cand:\n",
    "            print(f\"STOP: no feasible neighbors. best_y={y:.6g} x={x.tolist()}\")\n",
    "            break\n",
    "\n",
    "        Y = eval_batch(np.stack(cand, 0))\n",
    "        j = int(np.argmin(Y))\n",
    "        if float(Y[j]) >= y:\n",
    "            print(f\"STOP: local minimum. best_y={y:.6g} x={x.tolist()}\")\n",
    "            break\n",
    "\n",
    "        x, y = cand[j], float(Y[j])\n",
    "        hist.append({\"iter\": it, \"t\": time.perf_counter() - t0, \"best_y\": y, \"x\": x.copy()})\n",
    "        if it % max(1, int(print_every)) == 0:\n",
    "            print(f\"iter={it}  t={hist[-1]['t']:.2f}s  best_y={y:.6g}  x={x.tolist()}\")\n",
    "\n",
    "    return x, y, hist\n",
    "\n",
    "\n",
    "def _scaler_np(scaler):\n",
    "    xm = scaler[\"x_mean\"].detach().cpu().numpy().reshape(-1)\n",
    "    xs = scaler[\"x_std\"].detach().cpu().numpy().reshape(-1)\n",
    "    ym = float(scaler[\"y_mean\"].detach().cpu())\n",
    "    ys = float(scaler[\"y_std\"].detach().cpu())\n",
    "    return xm, xs, ym, ys\n",
    "\n",
    "\n",
    "def solve_dfn_ip_gurobi(dfn, scaler, x_min, x_max, sum_eq, *, integer_x=True, verbose=False, time_limit=None):\n",
    "    import gurobipy as gp\n",
    "    from gurobipy import GRB\n",
    "\n",
    "    x_min = np.asarray(x_min, float).ravel()\n",
    "    x_max = np.asarray(x_max, float).ravel()\n",
    "    d = int(x_min.size)\n",
    "    sum_eq = float(sum_eq)\n",
    "\n",
    "    x_mean, x_std, y_mean, y_std = _scaler_np(scaler)\n",
    "\n",
    "    # match DFN forward: cap = round(softplus(cap_raw)), A/b rounded if learnable\n",
    "    cost = np.r_[dfn.cost_raw.detach().cpu().numpy(), dfn.cost_fixed.detach().cpu().numpy()]\n",
    "    cap  = np.r_[torch.round(F.softplus(dfn.cap_raw.detach())).cpu().numpy(), dfn.cap_fixed.detach().cpu().numpy()]\n",
    "\n",
    "    A = dfn.A.detach().cpu().numpy()\n",
    "    if isinstance(dfn.A, torch.nn.Parameter):\n",
    "        A = np.round(A)\n",
    "    b = np.round(dfn.b_raw.detach().cpu().numpy())\n",
    "\n",
    "    src = dfn.src.detach().cpu().numpy().astype(int)\n",
    "    dst = dfn.dst.detach().cpu().numpy().astype(int)\n",
    "    boundary = dfn.boundary.detach().cpu().numpy().astype(int)\n",
    "    fix = int(dfn.fix_node)\n",
    "    n = int(dfn.n)\n",
    "    m = int(src.size)\n",
    "    alpha = float(dfn.alpha)\n",
    "    beta  = float(dfn.beta)\n",
    "\n",
    "    out = [[] for _ in range(n)]\n",
    "    inn = [[] for _ in range(n)]\n",
    "    for e in range(m):\n",
    "        out[src[e]].append(e)\n",
    "        inn[dst[e]].append(e)\n",
    "\n",
    "    M = gp.Model(\"DFN_IP\")\n",
    "    M.Params.OutputFlag = 1 if verbose else 0\n",
    "    if time_limit is not None:\n",
    "        M.Params.TimeLimit = float(time_limit)\n",
    "\n",
    "    xt = GRB.INTEGER if integer_x else GRB.CONTINUOUS\n",
    "    x = M.addVars(d, lb=x_min.tolist(), ub=x_max.tolist(), vtype=xt, name=\"x\")\n",
    "    f = M.addVars(m, lb=0.0, ub=cap.tolist(), vtype=GRB.CONTINUOUS, name=\"f\")\n",
    "    M.addConstr(gp.quicksum(x[i] for i in range(d)) == sum_eq)\n",
    "\n",
    "    # supplies: s[boundary] = A*((x-x_mean)/x_std) + b, and s[fix] = -sum(s[boundary])\n",
    "    xm_over_xs = x_mean / x_std\n",
    "    s = [0] * n\n",
    "    s_boundary = []\n",
    "    for r, v in enumerate(boundary):\n",
    "        const = float(b[r] - (A[r] * xm_over_xs).sum())\n",
    "        expr = const + gp.quicksum((A[r, j] / x_std[j]) * x[j] for j in range(d) if A[r, j] != 0)\n",
    "        s[v] = expr\n",
    "        s_boundary.append(expr)\n",
    "    s[fix] = -gp.quicksum(s_boundary)\n",
    "\n",
    "    for v in range(n):\n",
    "        M.addConstr(gp.quicksum(f[e] for e in out[v]) - gp.quicksum(f[e] for e in inn[v]) == s[v])\n",
    "\n",
    "    flow_cost = gp.quicksum(cost[e] * f[e] for e in range(m))\n",
    "    M.setObjective((alpha * flow_cost + beta) * y_std + y_mean, GRB.MINIMIZE)\n",
    "\n",
    "    M.optimize()\n",
    "    if M.SolCount == 0:\n",
    "        raise RuntimeError(f\"No solution (Gurobi status {M.Status})\")\n",
    "\n",
    "    x_star = np.array([x[i].X for i in range(d)], float)\n",
    "    info = {\"status\": M.Status, \"runtime\": M.Runtime, \"gap\": getattr(M, \"MIPGap\", None)}\n",
    "    return x_star, float(M.ObjVal), info\n",
    "\n",
    "\n",
    "def solve_mlp_ip_gurobi(model, scaler, x_min, x_max, sum_eq, *, integer_x=True, verbose=False, time_limit=None):\n",
    "    import gurobipy as gp\n",
    "    from gurobipy import GRB\n",
    "\n",
    "    x_min = np.asarray(x_min, float).ravel()\n",
    "    x_max = np.asarray(x_max, float).ravel()\n",
    "    d = int(x_min.size)\n",
    "    sum_eq = float(sum_eq)\n",
    "\n",
    "    xm, xs, ym, ys = _scaler_np(scaler)\n",
    "\n",
    "    # unwrap optional affine wrapper (if present)\n",
    "    base, a_out, b_out = model, 1.0, 0.0\n",
    "    if hasattr(model, \"base\") and hasattr(model, \"a\") and hasattr(model, \"b\"):\n",
    "        base, a_out, b_out = model.base, float(model.a), float(model.b)\n",
    "\n",
    "    if not hasattr(base, \"net\"):\n",
    "        raise ValueError(\"Expected an MLP with attribute .net (nn.Sequential).\")\n",
    "    linears = [L for L in base.net if isinstance(L, torch.nn.Linear)]\n",
    "    if not linears:\n",
    "        raise ValueError(\"No Linear layers found in base.net\")\n",
    "\n",
    "    W = [L.weight.detach().cpu().numpy().astype(float) for L in linears]\n",
    "    b = [L.bias.detach().cpu().numpy().astype(float) for L in linears]\n",
    "\n",
    "    # absorb input normalization: x_norm = (x-xm)/xs\n",
    "    W[0] = W[0] / xs[None, :]\n",
    "    b[0] = b[0] - W[0] @ xm\n",
    "\n",
    "    # absorb output affine (if any)\n",
    "    W[-1] *= a_out\n",
    "    b[-1] = a_out * b[-1] + b_out\n",
    "\n",
    "    # safe big-M bounds for pre-activations\n",
    "    u = np.maximum(np.abs(x_min), np.abs(x_max))\n",
    "    preLU = []\n",
    "    for k in range(len(W) - 1):\n",
    "        U = np.abs(W[k]) @ u + np.abs(b[k])\n",
    "        preLU.append((-U, U))\n",
    "        u = np.maximum(0.0, U)\n",
    "\n",
    "    M = gp.Model(\"MLP_IP\")\n",
    "    M.Params.OutputFlag = 1 if verbose else 0\n",
    "    if time_limit is not None:\n",
    "        M.Params.TimeLimit = float(time_limit)\n",
    "\n",
    "    xt = GRB.INTEGER if integer_x else GRB.CONTINUOUS\n",
    "    x = M.addVars(d, lb=x_min.tolist(), ub=x_max.tolist(), vtype=xt, name=\"x\")\n",
    "    M.addConstr(gp.quicksum(x[i] for i in range(d)) == sum_eq, name=\"sum_eq\")\n",
    "\n",
    "    prev = [x[i] for i in range(d)]\n",
    "\n",
    "    # hidden ReLU layers\n",
    "    for k in range(len(W) - 1):\n",
    "        Lk, Uk = preLU[k]\n",
    "        h = W[k].shape[0]\n",
    "\n",
    "        a = [M.addVar(lb=float(Lk[j]), ub=float(Uk[j]), name=f\"a{k}_{j}\") for j in range(h)]\n",
    "        z = [M.addVar(lb=0.0, ub=float(max(0.0, Uk[j])), name=f\"z{k}_{j}\") for j in range(h)]\n",
    "\n",
    "        for j in range(h):\n",
    "            M.addConstr(a[j] == b[k][j] + gp.quicksum(W[k][j, i] * prev[i] for i in range(len(prev))))\n",
    "\n",
    "            Lj, Uj = float(Lk[j]), float(Uk[j])\n",
    "            if Uj <= 0.0:\n",
    "                M.addConstr(z[j] == 0.0)\n",
    "            elif Lj >= 0.0:\n",
    "                M.addConstr(z[j] == a[j])\n",
    "            else:\n",
    "                s = M.addVar(vtype=GRB.BINARY, name=f\"s{k}_{j}\")\n",
    "                M.addConstr(z[j] >= a[j])\n",
    "                M.addConstr(z[j] >= 0.0)\n",
    "                M.addConstr(z[j] <= Uj * s)\n",
    "                M.addConstr(z[j] <= a[j] - Lj * (1 - s))\n",
    "\n",
    "        prev = z\n",
    "\n",
    "    if W[-1].shape[0] != 1:\n",
    "        raise ValueError(f\"Expected scalar output, got out_dim={W[-1].shape[0]}\")\n",
    "\n",
    "    y_norm = M.addVar(lb=-GRB.INFINITY, vtype=GRB.CONTINUOUS, name=\"y_norm\")\n",
    "    M.addConstr(y_norm == b[-1][0] + gp.quicksum(W[-1][0, i] * prev[i] for i in range(len(prev))))\n",
    "\n",
    "    M.setObjective(ys * y_norm + ym, GRB.MINIMIZE)\n",
    "\n",
    "    M.optimize()\n",
    "    if M.SolCount == 0:\n",
    "        raise RuntimeError(f\"No feasible solution. Gurobi status {M.Status}\")\n",
    "\n",
    "    x_star = np.array([x[i].X for i in range(d)], float)\n",
    "    info = {\"status\": M.Status, \"runtime\": M.Runtime, \"gap\": getattr(M, \"MIPGap\", None), \"sol_count\": M.SolCount}\n",
    "    return x_star, float(M.ObjVal), info\n",
    "\n",
    "\n",
    "def solve_maxaffine_ip_gurobi(model, scaler, x_min, x_max, sum_eq, *, integer_x=True, verbose=False, time_limit=None):\n",
    "    import gurobipy as gp\n",
    "    from gurobipy import GRB\n",
    "\n",
    "    x_min = np.asarray(x_min, float).ravel()\n",
    "    x_max = np.asarray(x_max, float).ravel()\n",
    "    d = int(x_min.size)\n",
    "    sum_eq = float(sum_eq)\n",
    "\n",
    "    xm, xs, ym, ys = _scaler_np(scaler)\n",
    "\n",
    "    base, a_out, b_out = model, 1.0, 0.0\n",
    "    if hasattr(model, \"base\") and hasattr(model, \"a\") and hasattr(model, \"b\"):\n",
    "        base, a_out, b_out = model.base, float(model.a), float(model.b)\n",
    "\n",
    "    W = base.W.detach().cpu().numpy().astype(float)\n",
    "    b = base.b.detach().cpu().numpy().astype(float)\n",
    "\n",
    "    # absorb x normalization: x_norm = (x-xm)/xs\n",
    "    Weff = W / xs[None, :]\n",
    "    beff = b - (Weff @ xm)\n",
    "\n",
    "    # absorb output affine\n",
    "    Weff *= a_out\n",
    "    beff  = a_out * beff + b_out\n",
    "\n",
    "    K = int(Weff.shape[0])\n",
    "\n",
    "    M = gp.Model(\"MaxAffine_IP\")\n",
    "    M.Params.OutputFlag = 1 if verbose else 0\n",
    "    if time_limit is not None:\n",
    "        M.Params.TimeLimit = float(time_limit)\n",
    "\n",
    "    xt = GRB.INTEGER if integer_x else GRB.CONTINUOUS\n",
    "    x = M.addVars(d, lb=x_min.tolist(), ub=x_max.tolist(), vtype=xt, name=\"x\")\n",
    "    M.addConstr(gp.quicksum(x[i] for i in range(d)) == sum_eq, name=\"sum_eq\")\n",
    "\n",
    "    # t >= each affine piece\n",
    "    t = M.addVar(lb=-GRB.INFINITY, vtype=GRB.CONTINUOUS, name=\"t_norm\")\n",
    "    for k in range(K):\n",
    "        M.addConstr(t >= beff[k] + gp.quicksum(Weff[k, j] * x[j] for j in range(d) if Weff[k, j] != 0.0))\n",
    "\n",
    "    M.setObjective(ys * t + ym, GRB.MINIMIZE)\n",
    "    M.optimize()\n",
    "\n",
    "    if M.SolCount == 0:\n",
    "        raise RuntimeError(f\"No feasible solution. Gurobi status {M.Status}\")\n",
    "\n",
    "    x_star = np.array([x[i].X for i in range(d)], float)\n",
    "    info = {\"status\": M.Status, \"runtime\": M.Runtime, \"gap\": getattr(M, \"MIPGap\", None), \"sol_count\": M.SolCount}\n",
    "    return x_star, float(M.ObjVal), info\n",
    "\n",
    "\n",
    "def solve_lset_ip_gurobi(model, scaler, x_min, x_max, sum_eq, *, integer_x=True, verbose=False, time_limit=None):\n",
    "    import gurobipy as gp\n",
    "    from gurobipy import GRB\n",
    "\n",
    "    x_min = np.asarray(x_min, float).ravel()\n",
    "    x_max = np.asarray(x_max, float).ravel()\n",
    "    d = int(x_min.size)\n",
    "    sum_eq = float(sum_eq)\n",
    "\n",
    "    xm, xs, ym, ys = _scaler_np(scaler)\n",
    "\n",
    "    A = model.A.detach().cpu().numpy().astype(float)\n",
    "    b = model.b.detach().cpu().numpy().astype(float)\n",
    "    T = float(model.T)\n",
    "    if T == 0.0:\n",
    "        raise ValueError(\"T must be nonzero\")\n",
    "\n",
    "    # absorb x_norm = (x_raw - xm)/xs\n",
    "    Aeff = A / xs[None, :]\n",
    "    beff = b - (Aeff @ xm)\n",
    "    K = int(Aeff.shape[0])\n",
    "\n",
    "    # bounds for lin_k(x) in each piece\n",
    "    lin_lo = np.empty(K, dtype=float)\n",
    "    lin_hi = np.empty(K, dtype=float)\n",
    "    for k in range(K):\n",
    "        a = Aeff[k]\n",
    "        lo = beff[k]\n",
    "        hi = beff[k]\n",
    "        pos = a >= 0\n",
    "        lo += (a[pos] * x_min[pos]).sum() + (a[~pos] * x_max[~pos]).sum()\n",
    "        hi += (a[pos] * x_max[pos]).sum() + (a[~pos] * x_min[~pos]).sum()\n",
    "        lin_lo[k], lin_hi[k] = lo, hi\n",
    "\n",
    "    z_lo = lin_lo / T\n",
    "    z_hi = lin_hi / T\n",
    "    m_lo = float(np.max(z_lo))\n",
    "    m_hi = float(np.max(z_hi))\n",
    "\n",
    "    w_lo = float(np.min(z_lo - m_hi))  # <= 0\n",
    "    u_lo = float(np.exp(max(-700.0, w_lo)))\n",
    "    s_lo = max(1e-12, K * u_lo)\n",
    "    s_hi = float(K * 1.0)\n",
    "    v_lo = float(np.log(s_lo))\n",
    "    v_hi = float(np.log(s_hi))\n",
    "\n",
    "    yN_lo = float(T * (m_lo + v_lo))\n",
    "    yN_hi = float(T * (m_hi + v_hi))\n",
    "    y_lo  = float(ys * yN_lo + ym)\n",
    "    y_hi  = float(ys * yN_hi + ym)\n",
    "\n",
    "    M = gp.Model(\"lset_ip_stable_bounded\")\n",
    "    M.Params.OutputFlag = 1 if verbose else 0\n",
    "    if time_limit is not None:\n",
    "        M.Params.TimeLimit = float(time_limit)\n",
    "\n",
    "    M.Params.FuncNonlinear = 1\n",
    "    M.Params.FeasibilityTol = 1e-9\n",
    "    M.Params.OptimalityTol  = 1e-9\n",
    "    M.Params.IntFeasTol     = 1e-9\n",
    "    M.Params.NumericFocus   = 3\n",
    "\n",
    "    xt = GRB.INTEGER if integer_x else GRB.CONTINUOUS\n",
    "    x = M.addVars(d, lb=x_min.tolist(), ub=x_max.tolist(), vtype=xt, name=\"x\")\n",
    "    M.addConstr(gp.quicksum(x[i] for i in range(d)) == sum_eq, name=\"sum_eq\")\n",
    "\n",
    "    z = M.addVars(K, lb=z_lo.tolist(), ub=z_hi.tolist(), vtype=GRB.CONTINUOUS, name=\"z\")\n",
    "    for k in range(K):\n",
    "        lin = beff[k] + gp.quicksum(Aeff[k, j] * x[j] for j in range(d) if Aeff[k, j] != 0.0)\n",
    "        M.addConstr(z[k] == lin / T, name=f\"zdef_{k}\")\n",
    "\n",
    "    m = M.addVar(lb=m_lo, ub=m_hi, vtype=GRB.CONTINUOUS, name=\"m\")\n",
    "    w = M.addVars(K, lb=w_lo, ub=0.0, vtype=GRB.CONTINUOUS, name=\"w\")\n",
    "    for k in range(K):\n",
    "        M.addConstr(m >= z[k], name=f\"m_ge_z_{k}\")\n",
    "        M.addConstr(w[k] == z[k] - m, name=f\"wdef_{k}\")\n",
    "\n",
    "    u = M.addVars(K, lb=0.0, ub=1.0, vtype=GRB.CONTINUOUS, name=\"u\")\n",
    "    for k in range(K):\n",
    "        M.addGenConstrExp(w[k], u[k], name=f\"exp_{k}\")\n",
    "\n",
    "    s = M.addVar(lb=s_lo, ub=s_hi, vtype=GRB.CONTINUOUS, name=\"s\")\n",
    "    M.addConstr(s == gp.quicksum(u[k] for k in range(K)), name=\"sumexp_shifted\")\n",
    "\n",
    "    v = M.addVar(lb=v_lo, ub=v_hi, vtype=GRB.CONTINUOUS, name=\"v\")\n",
    "    M.addGenConstrLog(s, v, name=\"log_shifted\")\n",
    "\n",
    "    y_norm = M.addVar(lb=yN_lo, ub=yN_hi, vtype=GRB.CONTINUOUS, name=\"y_norm\")\n",
    "    M.addConstr(y_norm == T * (m + v), name=\"y_norm_def\")\n",
    "\n",
    "    y_raw = M.addVar(lb=y_lo, ub=y_hi, vtype=GRB.CONTINUOUS, name=\"y_raw\")\n",
    "    M.addConstr(y_raw == ys * y_norm + ym, name=\"y_raw_def\")\n",
    "\n",
    "    M.setObjective(y_raw, GRB.MINIMIZE)\n",
    "    M.optimize()\n",
    "\n",
    "    if M.SolCount == 0:\n",
    "        raise RuntimeError(f\"No feasible solution found. Gurobi status {M.Status}\")\n",
    "\n",
    "    x_star = np.array([x[i].X for i in range(d)], dtype=float)\n",
    "    y_star = float(y_raw.X)\n",
    "\n",
    "    info = {\n",
    "        \"status\": M.Status,\n",
    "        \"runtime\": M.Runtime,\n",
    "        \"gap\": getattr(M, \"MIPGap\", None),\n",
    "        \"sol_count\": M.SolCount,\n",
    "        \"obj_gurobi\": float(M.ObjVal),\n",
    "        \"obj_bound\": float(getattr(M, \"ObjBound\", float(\"nan\"))),\n",
    "    }\n",
    "    return x_star, y_star, info\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4858979",
   "metadata": {},
   "source": [
    "## 7) Benchmark / Tests (same behavior, but wrapped into one function)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "569f2165",
   "metadata": {},
   "outputs": [],
   "source": [
    "def solve_ip(model_type, model, scaler, xmin, xmax, sum_eq, *, time_limit=None, verbose=False):\n",
    "    if model_type == \"DFN\":\n",
    "        return solve_dfn_ip_gurobi(model, scaler, xmin, xmax, sum_eq, time_limit=time_limit, verbose=verbose)\n",
    "    if model_type == \"MLP\":\n",
    "        return solve_mlp_ip_gurobi(model, scaler, xmin, xmax, sum_eq, time_limit=time_limit, verbose=verbose)\n",
    "    if model_type == \"MaxAffine\":\n",
    "        return solve_maxaffine_ip_gurobi(model, scaler, xmin, xmax, sum_eq, time_limit=time_limit, verbose=verbose)\n",
    "    if model_type == \"LSET\":\n",
    "        return solve_lset_ip_gurobi(model, scaler, xmin, xmax, sum_eq, time_limit=time_limit, verbose=verbose)\n",
    "    raise ValueError(model_type)\n",
    "\n",
    "\n",
    "def make_obj(model, scaler, device, chunk=4096):\n",
    "    \"\"\"Robust objective wrapper returning a numpy array of shape (B,).\"\"\"\n",
    "    xm = scaler[\"x_mean\"].to(device)\n",
    "    xs = scaler[\"x_std\"].to(device)\n",
    "    ym = scaler[\"y_mean\"].to(device)\n",
    "    ys = scaler[\"y_std\"].to(device)\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def obj(Xraw):\n",
    "        Xraw = torch.as_tensor(Xraw, dtype=torch.float32, device=device)\n",
    "        if Xraw.dim() == 1:\n",
    "            Xraw = Xraw.unsqueeze(0)\n",
    "\n",
    "        outs = []\n",
    "        B = int(Xraw.shape[0])\n",
    "        for i in range(0, B, int(chunk)):\n",
    "            Xb = Xraw[i:i + chunk]\n",
    "            try:\n",
    "                Xn = (Xb - xm) / xs\n",
    "                yn = torch.as_tensor(model(Xn)).reshape(-1)\n",
    "                y  = (yn * ys + ym).reshape(-1).detach().cpu()\n",
    "                if y.numel() != Xb.shape[0]:\n",
    "                    y = y.expand(Xb.shape[0]).contiguous()\n",
    "                outs.append(y)\n",
    "            except Exception as e:\n",
    "                obj._err_count += int(Xb.shape[0])\n",
    "                if obj._err_first is None:\n",
    "                    obj._err_first = repr(e)\n",
    "                outs.append(torch.full((int(Xb.shape[0]),), float(\"inf\"), device=\"cpu\"))\n",
    "\n",
    "        return torch.cat(outs, 0).numpy()\n",
    "\n",
    "    obj._err_count = 0\n",
    "    obj._err_first = None\n",
    "    return obj\n",
    "\n",
    "\n",
    "def safe_raw_mse(obj, Xraw, yraw):\n",
    "    yp = np.asarray(obj(Xraw), dtype=float).reshape(-1)\n",
    "    yt = yraw.detach().cpu().numpy().reshape(-1) if torch.is_tensor(yraw) else np.asarray(yraw, float).reshape(-1)\n",
    "    if yp.shape[0] != yt.shape[0]:\n",
    "        return np.nan, f\"shape mismatch: pred {yp.shape} vs true {yt.shape}\"\n",
    "    mask = np.isfinite(yp)\n",
    "    if mask.sum() == 0:\n",
    "        return np.nan, \"all predictions were non-finite (inf/nan)\"\n",
    "    return float(np.mean((yp[mask] - yt[mask]) ** 2)), None\n",
    "\n",
    "\n",
    "def check_ip_matches_obj(name, obj, x_ip, y_ip, *, strict=False, tol=1e-4):\n",
    "    y_obj = float(np.asarray(obj(np.asarray(x_ip)), dtype=float).reshape(-1)[0])\n",
    "    y_ip = float(y_ip)\n",
    "    rel = abs(y_obj - y_ip) / (abs(y_obj) + 1e-12)\n",
    "    print(f\"[CHECK {name}] obj(x_ip)={y_obj:.6g}  ip_y={y_ip:.6g}  rel_err={rel:.3e}\")\n",
    "    if strict and rel > tol:\n",
    "        raise RuntimeError(f\"{name}: IP objective != obj() (rel_err={rel:.3e})\")\n",
    "    return rel\n",
    "\n",
    "\n",
    "def mean_se(x):\n",
    "    x = pd.to_numeric(pd.Series(x), errors=\"coerce\").to_numpy()\n",
    "    x = x[np.isfinite(x)]\n",
    "    n = int(x.shape[0])\n",
    "    if n == 0:\n",
    "        return np.nan, np.nan, 0\n",
    "    m = float(x.mean())\n",
    "    se = float(x.std(ddof=1) / np.sqrt(n)) if n > 1 else 0.0\n",
    "    return m, se, n\n",
    "\n",
    "\n",
    "def fmt_mean_se(m, se):\n",
    "    if not np.isfinite(m):\n",
    "        return \"nan\"\n",
    "    if not np.isfinite(se):\n",
    "        return f\"{m:.6g}\"\n",
    "    return f\"{m:.6g} ± {se:.3g}\"\n",
    "\n",
    "\n",
    "def run_benchmark(\n",
    "    *,\n",
    "    dataset_type: str,\n",
    "    dataset_params: dict,\n",
    "    runs: list[tuple[str, str, dict]],\n",
    "    train_base: dict,\n",
    "    lr_map: dict,\n",
    "    x0: np.ndarray,\n",
    "    xmin: np.ndarray,\n",
    "    xmax: np.ndarray,\n",
    "    delta: int,\n",
    "    sum_eq: int,\n",
    "    n_seeds: int = 1,\n",
    "    vary_dataset_seed: bool = False,\n",
    "    vary_model_init_seed: bool = True,\n",
    "    strict_ip_check: bool = False,\n",
    "    ip_check_tol: float = 1e-4,\n",
    "    silence_local_search: bool = False,\n",
    "    allow_plots_multi_seed: bool = True,\n",
    "    time_limit=None,\n",
    "):\n",
    "    \"\"\"Runs the same loop as the original 'Tests and Results' cell and prints the same tables.\"\"\"\n",
    "    seeds = list(range(int(n_seeds)))\n",
    "\n",
    "    learn_rows, opt_rows, spec_rows, fail_rows = [], [], [], []\n",
    "\n",
    "    def suppress_stdout(fn, *args, **kwargs):\n",
    "        if not silence_local_search:\n",
    "            return fn(*args, **kwargs), \"\"\n",
    "        buf = io.StringIO()\n",
    "        with contextlib.redirect_stdout(buf):\n",
    "            out = fn(*args, **kwargs)\n",
    "        return out, buf.getvalue()\n",
    "\n",
    "    def t_to_best(hist, y_best):\n",
    "        for r in hist:\n",
    "            if abs(float(r.get(\"best_y\", np.inf)) - float(y_best)) < 1e-12:\n",
    "                return float(r.get(\"t\", float(\"nan\")))\n",
    "        return float(\"nan\")\n",
    "\n",
    "    for seed in seeds:\n",
    "        print(f\"\\n\\n===================== SEED {seed} =====================\")\n",
    "\n",
    "        for name, model_type, model_params_base in runs:\n",
    "            dp = dict(dataset_params)\n",
    "            if vary_dataset_seed and \"seed\" in dp:\n",
    "                dp[\"seed\"] = int(seed)\n",
    "\n",
    "            mp = dict(model_params_base)\n",
    "            if vary_model_init_seed and \"seed\" in mp:\n",
    "                mp[\"seed\"] = int(seed)\n",
    "\n",
    "            tp = dict(train_base)\n",
    "            tp[\"seed\"] = int(seed)\n",
    "            tp[\"lr\"] = float(lr_map[model_type])\n",
    "\n",
    "            # avoid plot spam unless explicitly allowed\n",
    "            if (len(seeds) > 1) and (tp.get(\"plot_every\", 0) not in (0, None)) and (not allow_plots_multi_seed):\n",
    "                tp[\"plot_every\"] = 0\n",
    "\n",
    "            # ---- TRAIN ----\n",
    "            t0 = time.perf_counter()\n",
    "            try:\n",
    "                model, data, hist, spec = generate_and_train_simple(dataset_type, dp, model_type, mp, tp)\n",
    "            except Exception as e:\n",
    "                fail_rows.append(dict(seed=seed, model=name, stage=\"TRAIN\", error=repr(e)))\n",
    "                learn_rows.append(dict(\n",
    "                    seed=seed, model=name, train_time=np.nan, best_epoch=np.nan, best_val=np.nan,\n",
    "                    train_raw=np.nan, val_raw=np.nan, test_raw=np.nan, train_err=repr(e)\n",
    "                ))\n",
    "                continue\n",
    "\n",
    "            train_time = time.perf_counter() - t0\n",
    "\n",
    "            if seed == seeds[0]:\n",
    "                spec_rows.append(dict(\n",
    "                    model=name,\n",
    "                    n_params=int(spec.get(\"n_params\", np.nan)),\n",
    "                    n_trainable=int(spec.get(\"n_trainable\", np.nan)),\n",
    "                    extra=str(spec.get(\"extra\", \"\")),\n",
    "                    lr=float(spec.get(\"train_params\", {}).get(\"lr\", tp[\"lr\"])),\n",
    "                    batch_size=int(spec.get(\"train_params\", {}).get(\"batch_size\", tp[\"batch_size\"])),\n",
    "                    epochs=int(spec.get(\"train_params\", {}).get(\"epochs\", tp[\"epochs\"])),\n",
    "                    weight_decay=float(spec.get(\"train_params\", {}).get(\"weight_decay\", tp.get(\"weight_decay\", 0.0))),\n",
    "                ))\n",
    "\n",
    "            device = data[\"device\"]\n",
    "            scaler = data[\"scaler\"]\n",
    "            obj = make_obj(model, scaler, device, chunk=int(tp.get(\"plot_chunk\", 4096)))\n",
    "            gt = data.get(\"true\", None)\n",
    "\n",
    "            # ---- LEARNING METRICS (RAW MSE) ----\n",
    "            train_raw, err_tr = safe_raw_mse(obj, data[\"raw\"][\"Xtr\"], data[\"raw\"][\"ytr\"])\n",
    "            val_raw,   err_va = safe_raw_mse(obj, data[\"raw\"][\"Xva\"], data[\"raw\"][\"yva\"])\n",
    "            test_raw,  err_te = safe_raw_mse(obj, data[\"raw\"][\"Xte\"], data[\"raw\"][\"yte\"])\n",
    "\n",
    "            if obj._err_count > 0:\n",
    "                fail_rows.append(dict(seed=seed, model=name, stage=\"OBJ\",\n",
    "                                      error=f\"obj() had {obj._err_count} forward failures; first={obj._err_first}\"))\n",
    "\n",
    "            val_curve = np.asarray(hist.get(\"val_mse_norm\", []), dtype=float)\n",
    "            best_ep = int(np.argmin(val_curve) + 1) if val_curve.size else np.nan\n",
    "            best_val = float(np.min(val_curve)) if val_curve.size else np.nan\n",
    "\n",
    "            learn_rows.append(dict(\n",
    "                seed=seed, model=name,\n",
    "                train_time=float(train_time),\n",
    "                best_epoch=best_ep,\n",
    "                best_val=best_val,\n",
    "                train_raw=float(train_raw) if np.isfinite(train_raw) else np.nan,\n",
    "                val_raw=float(val_raw) if np.isfinite(val_raw) else np.nan,\n",
    "                test_raw=float(test_raw) if np.isfinite(test_raw) else np.nan,\n",
    "                train_err=(err_tr or err_va or err_te),\n",
    "            ))\n",
    "            if err_tr or err_va or err_te:\n",
    "                fail_rows.append(dict(seed=seed, model=name, stage=\"EVAL_MSE\", error=(err_tr or err_va or err_te)))\n",
    "\n",
    "            # ---- LOCAL SEARCH ----\n",
    "            t0 = time.perf_counter()\n",
    "            try:\n",
    "                (ls_out, _log) = suppress_stdout(\n",
    "                    local_search_l1_int, obj, x0, xmin, xmax, delta=delta, sum_eq=sum_eq, print_every=0\n",
    "                )\n",
    "                x_best, y_best, ls_hist = ls_out\n",
    "                ls_time = time.perf_counter() - t0\n",
    "                opt_rows.append(dict(\n",
    "                    seed=seed, model=name, method=\"LS\",\n",
    "                    y=float(y_best), runtime=float(ls_time),\n",
    "                    t_best=float(t_to_best(ls_hist, y_best)),\n",
    "                    iters=int(len(ls_hist) - 1),\n",
    "                    status=None, gap=None,\n",
    "                    x=str(np.asarray(x_best, int).tolist()),\n",
    "                    true_y=float(eval_true_obj(gt, x_best)),\n",
    "                    err=None,\n",
    "                ))\n",
    "            except Exception as e:\n",
    "                opt_rows.append(dict(\n",
    "                    seed=seed, model=name, method=\"LS\",\n",
    "                    y=np.nan, runtime=float(time.perf_counter() - t0),\n",
    "                    t_best=np.nan, iters=np.nan,\n",
    "                    status=None, gap=None, x=None, true_y=np.nan,\n",
    "                    err=repr(e),\n",
    "                ))\n",
    "                fail_rows.append(dict(seed=seed, model=name, stage=\"LS\", error=repr(e)))\n",
    "\n",
    "            # ---- IP SOLVE ----\n",
    "            t0 = time.perf_counter()\n",
    "            try:\n",
    "                x_ip, y_ip, info = solve_ip(\n",
    "                    model_type, model, scaler, xmin, xmax, sum_eq, time_limit=time_limit, verbose=False\n",
    "                )\n",
    "                ip_time = time.perf_counter() - t0\n",
    "                rel_err = check_ip_matches_obj(name, obj, x_ip, y_ip, strict=strict_ip_check, tol=ip_check_tol)\n",
    "                if rel_err > ip_check_tol:\n",
    "                    fail_rows.append(dict(seed=seed, model=name, stage=\"IP_CHECK\",\n",
    "                                          error=f\"rel_err={rel_err:.3e} (tol={ip_check_tol})\"))\n",
    "\n",
    "                opt_rows.append(dict(\n",
    "                    seed=seed, model=name, method=\"IP\",\n",
    "                    y=float(y_ip), runtime=float(ip_time),\n",
    "                    t_best=np.nan, iters=np.nan,\n",
    "                    status=(info.get(\"status\") if isinstance(info, dict) else None),\n",
    "                    gap=(info.get(\"gap\") if isinstance(info, dict) else None),\n",
    "                    x=str(np.asarray(x_ip, int).tolist()),\n",
    "                    true_y=float(eval_true_obj(gt, x_ip)),\n",
    "                    err=None,\n",
    "                    ip_rel_err=float(rel_err),\n",
    "                ))\n",
    "            except Exception as e:\n",
    "                opt_rows.append(dict(\n",
    "                    seed=seed, model=name, method=\"IP\",\n",
    "                    y=np.nan, runtime=float(time.perf_counter() - t0),\n",
    "                    t_best=np.nan, iters=np.nan,\n",
    "                    status=None, gap=None, x=None, true_y=np.nan, ip_rel_err=np.nan,\n",
    "                    err=repr(e),\n",
    "                ))\n",
    "                fail_rows.append(dict(seed=seed, model=name, stage=\"IP\", error=repr(e)))\n",
    "\n",
    "    # ---------- TABLES ----------\n",
    "    spec_df  = pd.DataFrame(spec_rows).drop_duplicates(\"model\").sort_values(\"model\").reset_index(drop=True)\n",
    "    learn_df = pd.DataFrame(learn_rows).sort_values([\"model\", \"seed\"]).reset_index(drop=True)\n",
    "    opt_df   = pd.DataFrame(opt_rows).sort_values([\"model\", \"seed\", \"method\"]).reset_index(drop=True)\n",
    "\n",
    "    fail_df = pd.DataFrame(fail_rows)\n",
    "    if fail_df.empty:\n",
    "        fail_df = pd.DataFrame(columns=[\"seed\", \"model\", \"stage\", \"error\"])\n",
    "    else:\n",
    "        for c in [\"stage\", \"model\", \"seed\"]:\n",
    "            if c not in fail_df.columns:\n",
    "                fail_df[c] = np.nan\n",
    "        fail_df = fail_df.sort_values([\"stage\", \"model\", \"seed\"]).reset_index(drop=True)\n",
    "\n",
    "    # per-seed LS vs IP % gap\n",
    "    gap_seed_df = pd.DataFrame(columns=[\"seed\", \"model\", \"ls_vs_ip_pct\"])\n",
    "    if not opt_df.empty:\n",
    "        pivot = opt_df.pivot_table(index=[\"seed\", \"model\"], columns=\"method\", values=\"y\", aggfunc=\"first\")\n",
    "        if (\"LS\" in pivot.columns) and (\"IP\" in pivot.columns):\n",
    "            pivot[\"ls_vs_ip_pct\"] = 100.0 * (pivot[\"LS\"] - pivot[\"IP\"]) / (np.abs(pivot[\"IP\"]) + 1e-12)\n",
    "            gap_seed_df = pivot.reset_index()[[\"seed\", \"model\", \"ls_vs_ip_pct\"]]\n",
    "            opt_df = opt_df.merge(gap_seed_df, on=[\"seed\", \"model\"], how=\"left\")\n",
    "        else:\n",
    "            opt_df[\"ls_vs_ip_pct\"] = np.nan\n",
    "\n",
    "    # aggregated summaries\n",
    "    learn_agg = []\n",
    "    for model in sorted(learn_df[\"model\"].unique()):\n",
    "        sub = learn_df[learn_df[\"model\"] == model]\n",
    "        row = {\"model\": model, \"n\": int(sub[\"seed\"].nunique())}\n",
    "        for col in [\"train_time\", \"best_val\", \"train_raw\", \"val_raw\", \"test_raw\"]:\n",
    "            m, se, n = mean_se(sub[col])\n",
    "            row[col] = fmt_mean_se(m, se)\n",
    "            row[col + \"_n\"] = n\n",
    "        row[\"n_eval_err\"] = int(sub[\"train_err\"].notna().sum())\n",
    "        learn_agg.append(row)\n",
    "    learn_agg_df = pd.DataFrame(learn_agg).sort_values(\"model\").reset_index(drop=True)\n",
    "\n",
    "    opt_agg = []\n",
    "    for model in sorted(opt_df[\"model\"].unique()):\n",
    "        sub = opt_df[opt_df[\"model\"] == model]\n",
    "        row = {\"model\": model, \"n_seeds\": int(sub[\"seed\"].nunique())}\n",
    "\n",
    "        for method in [\"LS\", \"IP\"]:\n",
    "            s = sub[sub[\"method\"] == method]\n",
    "            m, se, n = mean_se(s[\"y\"])\n",
    "            row[f\"{method}_y\"] = fmt_mean_se(m, se)\n",
    "            row[f\"{method}_y_n\"] = n\n",
    "\n",
    "            m, se, n = mean_se(s.get(\"true_y\", []))\n",
    "            row[f\"{method}_true_y\"] = fmt_mean_se(m, se)\n",
    "            row[f\"{method}_true_y_n\"] = n\n",
    "\n",
    "            m, se, n = mean_se(s[\"runtime\"])\n",
    "            row[f\"{method}_time\"] = fmt_mean_se(m, se)\n",
    "            row[f\"{method}_time_n\"] = n\n",
    "\n",
    "            row[f\"{method}_fails\"] = int(s[\"y\"].isna().sum())\n",
    "\n",
    "        g = gap_seed_df[gap_seed_df[\"model\"] == model][\"ls_vs_ip_pct\"] if not gap_seed_df.empty else []\n",
    "        m, se, n = mean_se(g)\n",
    "        row[\"LS_vs_IP_%\"] = fmt_mean_se(m, se)\n",
    "        row[\"LS_vs_IP_%_n\"] = n\n",
    "\n",
    "        ip_rel = sub[sub[\"method\"] == \"IP\"].get(\"ip_rel_err\", pd.Series([], dtype=float))\n",
    "        m, se, n = mean_se(ip_rel)\n",
    "        row[\"IP_obj_rel_err\"] = fmt_mean_se(m, se)\n",
    "        row[\"IP_obj_rel_err_n\"] = n\n",
    "\n",
    "        opt_agg.append(row)\n",
    "    opt_agg_df = pd.DataFrame(opt_agg).sort_values(\"model\").reset_index(drop=True)\n",
    "\n",
    "    print(\"\\n=== MODEL SPECS (from seed 0 run) ===\")\n",
    "    print(spec_df.to_string(index=False))\n",
    "\n",
    "    print(\"\\n=== LEARNING SUMMARY (mean ± SE over seeds) ===\")\n",
    "    print(learn_agg_df.to_string(index=False))\n",
    "\n",
    "    print(\"\\n=== OPTIMIZATION SUMMARY (mean ± SE over seeds) ===\")\n",
    "    print(opt_agg_df.to_string(index=False))\n",
    "\n",
    "    print(\"\\n=== FAILURES / WARNINGS (if any) ===\")\n",
    "    if fail_df.shape[0] == 0:\n",
    "        print(\"None\")\n",
    "    else:\n",
    "        print(fail_df.to_string(index=False))\n",
    "\n",
    "    return spec_df, learn_df, opt_df, fail_df, learn_agg_df, opt_agg_df\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e88e7a1f",
   "metadata": {},
   "source": [
    "### Example configuration (same as original)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8edf3f97",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---------- 0) CONFIG ----------\n",
    "N_SEEDS = 1\n",
    "\n",
    "VARY_DATASET_SEED = False\n",
    "VARY_MODEL_INIT_SEED = True\n",
    "STRICT_IP_CHECK = False\n",
    "IP_CHECK_TOL = 1e-4\n",
    "\n",
    "SILENCE_LOCAL_SEARCH = False\n",
    "ALLOW_PLOTS_MULTI_SEED = True\n",
    "\n",
    "dataset_type = \"quadratic\"\n",
    "dataset_params = dict(\n",
    "    K=5000, dim=10, eigen_min=1, eigen_max=20.0,\n",
    "    x_min=-100, x_max=100, noise_std=0.0, seed=0\n",
    ")\n",
    "in_dim = int(dataset_params[\"dim\"])\n",
    "\n",
    "train_base = dict(\n",
    "    epochs=1000,\n",
    "    batch_size=8,\n",
    "    val_frac=0.15,\n",
    "    test_frac=0.15,\n",
    "    seed=0,\n",
    "    device=(\"cuda\" if torch.cuda.is_available() else \"cpu\"),\n",
    "    eps=1e-8,\n",
    "    weight_decay=0.0,\n",
    "    plot_every=1,\n",
    "    plot_points=2048,\n",
    "    plot_chunk=4096,\n",
    ")\n",
    "\n",
    "# ---- DFN (learnable A) ----\n",
    "dfn_params = dict(\n",
    "    input_dim=in_dim, layer_sizes=[16, 128, 16], p_list=[1, 1],\n",
    "    seed=0, alpha=5e-3, beta=-2.0\n",
    ")\n",
    "\n",
    "# ---- DFN (fixed A = I) ----\n",
    "dfnA_layer_sizes = [5, 400, 6]\n",
    "assert dfnA_layer_sizes[0] + dfnA_layer_sizes[-1] == in_dim + 1, \"Need |L1|+|LK| = dim+1 for A_fixed=I\"\n",
    "\n",
    "dfn_Afix_params = dict(\n",
    "    input_dim=in_dim,\n",
    "    layer_sizes=dfnA_layer_sizes,\n",
    "    p_list=[1, 1],\n",
    "    seed=0,\n",
    "    alpha=5e-3,\n",
    "    beta=-2.0,\n",
    "    A_fixed=np.eye(in_dim, dtype=np.float32),  # shape (dim, dim)\n",
    ")\n",
    "\n",
    "# ---- other models ----\n",
    "mlp_params  = dict(in_dim=in_dim, hidden_dims=[128, 128], out_dim=1)\n",
    "maff_params = dict(in_dim=in_dim, n_pieces=1600)\n",
    "lset_params = dict(in_dim=in_dim, n_pieces=1600, T=0.05)\n",
    "\n",
    "lr_map = dict(DFN=1e-1, MLP=1e-3, MaxAffine=1e-2, LSET=1e-2)\n",
    "\n",
    "time_limit = 300\n",
    "x0    = np.array([50,-50,50,-50,50,50,-50,50,-50,-50], dtype=int)\n",
    "xmin  = np.full(in_dim, -100, dtype=int)\n",
    "xmax  = np.full(in_dim,  100, dtype=int)\n",
    "delta = 2\n",
    "sum_eq = 0\n",
    "\n",
    "runs = [\n",
    "    (\"DFN\",        \"DFN\", dfn_params),\n",
    "    (\"DFN_AfixI\",  \"DFN\", dfn_Afix_params),\n",
    "    (\"MLP\",        \"MLP\", mlp_params),\n",
    "    (\"MaxAffine\",  \"MaxAffine\", maff_params),\n",
    "    (\"LSET\",       \"LSET\", lset_params),\n",
    "]\n",
    "\n",
    "_ = run_benchmark(\n",
    "    dataset_type=dataset_type,\n",
    "    dataset_params=dataset_params,\n",
    "    runs=runs,\n",
    "    train_base=train_base,\n",
    "    lr_map=lr_map,\n",
    "    x0=x0, xmin=xmin, xmax=xmax,\n",
    "    delta=delta, sum_eq=sum_eq,\n",
    "    n_seeds=N_SEEDS,\n",
    "    vary_dataset_seed=VARY_DATASET_SEED,\n",
    "    vary_model_init_seed=VARY_MODEL_INIT_SEED,\n",
    "    strict_ip_check=STRICT_IP_CHECK,\n",
    "    ip_check_tol=IP_CHECK_TOL,\n",
    "    silence_local_search=SILENCE_LOCAL_SEARCH,\n",
    "    allow_plots_multi_seed=ALLOW_PLOTS_MULTI_SEED,\n",
    "    time_limit=time_limit,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ac79bcb",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:dfn]",
   "language": "python",
   "name": "conda-env-dfn-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
