{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f5b14808",
   "metadata": {},
   "source": [
    "# DFN"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2efb5921",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ce8269a4",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "from __future__ import annotations\n",
    "\n",
    "import cppimport\n",
    "\n",
    "import sys, time, math, io, contextlib\n",
    "from pathlib import Path\n",
    "from typing import Optional, List, Tuple, Dict, Any\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.optimize import linear_sum_assignment\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from IPython.display import display\n",
    "\n",
    "import gurobipy as gp\n",
    "from gurobipy import GRB\n",
    "\n",
    "_TOL = 1e-9"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "417583e0",
   "metadata": {},
   "source": [
    "## LEMON"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "62767112",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "repo = Path().resolve().parent\n",
    "if str(repo) not in sys.path:\n",
    "    sys.path.insert(0, str(repo))\n",
    "\n",
    "lemon_mcf = cppimport.imp(\"lemon_mcf\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15e6bd1b",
   "metadata": {},
   "source": [
    "### Quick sanity checks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "88381d28",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'status': 1, 'flow': array([1., 2., 0.]), 'potential': array([-2.,  0., -1.]), 'reduced_cost': array([0., 0., 4.]), 'at_capacity': array([False,  True, False]), 'total_cost': 4.0}\n",
      "{'value': 4.0, 'flow': array([2., 2., 2., 2.])}\n"
     ]
    }
   ],
   "source": [
    "n = 3\n",
    "src    = np.array([0, 0, 1], dtype=np.int64)\n",
    "dst    = np.array([1, 2, 2], dtype=np.int64)\n",
    "cost   = np.array([2.0, 1.0, 3.0], dtype=np.float64)\n",
    "cap    = np.array([5.0, 2.0, 4.0], dtype=np.float64)\n",
    "supply = np.array([3.0, -1.0, -2.0], dtype=np.float64)\n",
    "\n",
    "out_min_cost_flow = lemon_mcf.solve_mcf(n, src, dst, cost, cap, supply)\n",
    "print(out_min_cost_flow)\n",
    "\n",
    "n = 4\n",
    "src = np.array([0,0,1,2], dtype=np.int64)\n",
    "dst = np.array([1,2,3,3], dtype=np.int64)\n",
    "cap = np.array([3.0,2.0,2.0,4.0], dtype=np.float64)\n",
    "out_max_flow = lemon_mcf.max_flow(n, src, dst, cap, 0, 3)\n",
    "print(out_max_flow)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75235271",
   "metadata": {},
   "source": [
    "## Dataset generators"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "67763cc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_mdvsp_dataset(\n",
    "    K: int,\n",
    "    filename: str,\n",
    "    x_min,\n",
    "    x_max,\n",
    "    noise_std: float = 0.0,\n",
    "    seed: int = 0,\n",
    "    max_trips=None,\n",
    "    max_succ=None,\n",
    "):\n",
    "    \"\"\"Multiple-Depot Vehicle Scheduling (MDVSP) dataset.\n",
    "\n",
    "    Returns:\n",
    "      X: (K, m) integer-ish capacities (float32)\n",
    "      y: (K,) min-cost-flow objective values (float32)\n",
    "      gt: dict containing the fixed network pieces (for later evaluation)\n",
    "    \"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "\n",
    "    with open(filename) as f:\n",
    "        m, n, l = map(int, f.readline().split())\n",
    "        f.readline()  # blank line\n",
    "        trips = np.loadtxt(f, max_rows=n, dtype=np.int64)[:max_trips]\n",
    "        D = np.loadtxt(f, max_rows=l, dtype=np.int64)\n",
    "\n",
    "    p, s, q, e = trips.T\n",
    "    ntr = len(trips)\n",
    "\n",
    "    # node ids\n",
    "    SS = 0\n",
    "    depS = 1 + np.arange(m)\n",
    "    depT = 1 + m + np.arange(m)\n",
    "    trS  = 1 + 2*m + np.arange(ntr)\n",
    "    trT  = 1 + 2*m + ntr + np.arange(ntr)\n",
    "    TT = 1 + 2*m + 2*ntr\n",
    "    N = TT + 1\n",
    "\n",
    "    src, dst, cost, cap = [], [], [], []\n",
    "\n",
    "    # depot boundary arcs (capacity is what we learn/provide per sample)\n",
    "    for d in range(m):\n",
    "        src += [SS, int(depT[d])]\n",
    "        dst += [int(depS[d]), TT]\n",
    "        cost += [0.0, 0.0]\n",
    "        cap  += [0.0, 0.0]\n",
    "    idxSS = np.arange(0, 2*m, 2)  # arcs SS->depS\n",
    "    idxTT = np.arange(1, 2*m, 2)  # arcs depT->TT\n",
    "\n",
    "    # trip arcs (always cap=1)\n",
    "    src += trS.tolist()\n",
    "    dst += trT.tolist()\n",
    "    cost += [0.0] * ntr\n",
    "    cap  += [1.0] * ntr\n",
    "\n",
    "    # depot-to-trip and trip-to-depot arcs\n",
    "    for d in range(m):\n",
    "        # depS -> trS\n",
    "        src += [int(depS[d])] * ntr\n",
    "        dst += trS.tolist()\n",
    "        cost += (5000 + 10 * D[d, p]).astype(float).tolist()\n",
    "        cap  += [1.0] * ntr\n",
    "\n",
    "        # trT -> depT\n",
    "        src += trT.tolist()\n",
    "        dst += [int(depT[d])] * ntr\n",
    "        cost += (5000 + 10 * D[q, d]).astype(float).tolist()\n",
    "        cap  += [1.0] * ntr\n",
    "\n",
    "    # feasible trip successor arcs (trT -> next trS)\n",
    "    order = np.argsort(s)\n",
    "    p2, s2 = p[order], s[order]\n",
    "    max_succ_eff = ntr if max_succ is None else int(max_succ)\n",
    "\n",
    "    for i in range(ntr):\n",
    "        travel = D[q[i], p2]                          # time from trip i end depot -> next trip start depot\n",
    "        feas = np.flatnonzero(s2 >= e[i] + travel)[:max_succ_eff]\n",
    "        j = order[feas]\n",
    "        if j.size:\n",
    "            src += [int(trT[i])] * j.size\n",
    "            dst += trS[j].tolist()\n",
    "            cost += (8 * travel[feas] + 2 * (s[j] - e[i])).astype(float).tolist()\n",
    "            cap  += [1.0] * j.size\n",
    "\n",
    "    src = np.asarray(src, dtype=np.int64)\n",
    "    dst = np.asarray(dst, dtype=np.int64)\n",
    "    cost = np.asarray(cost, dtype=np.float64)\n",
    "    cap0 = np.asarray(cap, dtype=np.float64)\n",
    "\n",
    "    # sample capacities X and compute y via max-flow + min-cost-flow\n",
    "    X = rng.integers(x_min, np.asarray(x_max) + 1, size=(K, m)).astype(np.float64)\n",
    "    y = np.empty(K, dtype=np.float64)\n",
    "\n",
    "    for k in range(K):\n",
    "        cap_k = cap0.copy()\n",
    "        cap_k[idxSS] = X[k]\n",
    "        cap_k[idxTT] = X[k]\n",
    "\n",
    "        Fmax = lemon_mcf.max_flow(N, src, dst, cap_k, SS, TT)[\"value\"]\n",
    "        supply = np.zeros(N, dtype=np.float64)\n",
    "        supply[SS] = Fmax\n",
    "        supply[TT] = -Fmax\n",
    "\n",
    "        y[k] = lemon_mcf.solve_mcf(N, src, dst, cost, cap_k, supply)[\"total_cost\"]\n",
    "        if noise_std:\n",
    "            y[k] += noise_std * rng.normal()\n",
    "\n",
    "    gt = dict(\n",
    "        type=\"mdvsp\", N=int(N), SS=int(SS), TT=int(TT),\n",
    "        src=src, dst=dst, cost=cost, cap0=cap0, idxSS=idxSS, idxTT=idxTT\n",
    "    )\n",
    "    return X.astype(np.float32), y.astype(np.float32), gt\n",
    "\n",
    "\n",
    "def generate_bipartite_subset_matching_dataset(\n",
    "    K: int, num_nodes: int, c_min: int, c_max: int, noise_std: float = 0.0, seed: int = 0\n",
    "):\n",
    "    \"\"\"Assignment-style dataset: choose a subset of left nodes, match to right nodes with min cost.\"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "    C = rng.integers(c_min, c_max + 1, size=(num_nodes, num_nodes)).astype(np.float32)\n",
    "\n",
    "    X = np.zeros((K, num_nodes), dtype=np.float32)\n",
    "    y = np.zeros((K,), dtype=np.float32)\n",
    "\n",
    "    for k in range(K):\n",
    "        mask = rng.integers(0, 2, size=num_nodes, dtype=np.int8)\n",
    "        while not mask.any():\n",
    "            mask = rng.integers(0, 2, size=num_nodes, dtype=np.int8)\n",
    "        idx = np.flatnonzero(mask)\n",
    "        X[k, idx] = 1.0\n",
    "\n",
    "        r, c = linear_sum_assignment(C[idx, :])\n",
    "        y[k] = C[idx, :][r, c].sum()\n",
    "        if noise_std:\n",
    "            y[k] += noise_std * rng.normal()\n",
    "\n",
    "    gt = {\"type\": \"assignment\", \"C\": C.astype(np.float32)}\n",
    "    return X, y, gt\n",
    "\n",
    "\n",
    "def generate_convex_quadratic_dataset(\n",
    "    K: int,\n",
    "    dim: int,\n",
    "    eigen_min: float,\n",
    "    eigen_max: float,\n",
    "    x_min,\n",
    "    x_max,\n",
    "    noise_std: float = 0.0,\n",
    "    seed: int = 0,\n",
    "    x_star_zero: bool = False,\n",
    "):\n",
    "    \"\"\"y = (x - x*)^T Q (x - x*) + noise, with Q symmetric PSD.\"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "    U, R = np.linalg.qr(rng.standard_normal((dim, dim)))\n",
    "    U *= np.sign(np.diag(R) + 1e-12)\n",
    "    Q = U @ np.diag(rng.uniform(eigen_min, eigen_max, dim)) @ U.T\n",
    "\n",
    "    x_star = np.zeros(dim, dtype=np.int64) if x_star_zero else rng.integers(x_min, x_max + 1, size=dim)\n",
    "\n",
    "    X = rng.integers(x_min, x_max + 1, size=(K, dim)).astype(np.float32)\n",
    "    d = X - x_star\n",
    "    y = np.einsum(\"bi,ij,bj->b\", d, Q, d) + noise_std * rng.normal(size=K)\n",
    "\n",
    "    gt = {\"type\": \"quadratic\", \"Q\": Q.astype(np.float32), \"x_star\": x_star.astype(np.int64)}\n",
    "    return X.astype(np.float32), y.astype(np.float32), gt\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1dfc62e4",
   "metadata": {},
   "source": [
    "### Quick sanity checks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "89201b4a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(50, 16) (50,) \n",
      " [[ 5.  5.  8. 10.  0.  1.  9. 10.  2.  3.  9.  4.  3.  9.  2.  4.]\n",
      " [ 7.  6.  0.  0.  9.  8.  9.  5.  8.  3.  4.  8.  1.  3.  1.  4.]\n",
      " [10.  1.  4.  4.  9.  2.  5.  2.  0.  8.  0.  3.  5.  5.  1. 10.]\n",
      " [ 8. 10.  1.  7.  3.  5. 10.  3.  7.  1.  3. 10.  4.  5.  3.  1.]\n",
      " [ 4.  6.  5.  8.  3.  6.  8. 10.  4.  0.  7.  5.  9.  5.  4.  0.]] \n",
      " [853669.8  770258.3  700160.2  822280.25 852159.1 ] \n",
      "\n",
      "(50, 10) (50,) \n",
      " [[0. 1. 1. 0. 0. 0. 0. 1. 0. 1.]\n",
      " [1. 0. 0. 0. 0. 0. 1. 1. 0. 0.]\n",
      " [0. 1. 0. 0. 1. 0. 1. 0. 1. 1.]\n",
      " [1. 1. 0. 0. 1. 1. 1. 1. 0. 1.]\n",
      " [0. 1. 0. 1. 1. 0. 0. 0. 1. 0.]] \n",
      " [ 9.  4. 13. 17. 11.] \n",
      "\n",
      "(50, 10) (50,) \n",
      " [[-47. -87.  -1.  69.  25. -87.  30. -31. -55. -14.]\n",
      " [ 75.  94. -72.  13.  53. -48. -46. -52. -58.  78.]\n",
      " [-57. -55. -75. -75.  56. -43.  61.  17.  72.  11.]\n",
      " [ 53.  62. -88.  12.  -9. -43. -10. -18.  -2.  64.]\n",
      " [ 65.  25.  42.  92.  28. -26. -84.  11. -54.  19.]] \n",
      " [1.5913905e+06 9.7035825e+05 7.8526031e+05 5.9535256e+05 1.0773938e+06] \n",
      "\n"
     ]
    }
   ],
   "source": [
    "# NOTE: MDVSP requires that `filename` exists on disk.\n",
    "X, y, _ = make_mdvsp_dataset(\n",
    "    K=50, filename=\"RN-16-3000-05.dat\", x_min=0, x_max=10, noise_std=1.0, seed=1, max_trips=200, max_succ=5\n",
    ")\n",
    "print(X.shape, y.shape, \"\\n\", X[:5], \"\\n\", y[:5], \"\\n\")\n",
    "\n",
    "X, y, _ = generate_bipartite_subset_matching_dataset(\n",
    "    K=50, num_nodes=10, c_min=1, c_max=10, noise_std=0.0, seed=0\n",
    ")\n",
    "print(X.shape, y.shape, \"\\n\", X[:5], \"\\n\", y[:5], \"\\n\")\n",
    "\n",
    "X, y, _ = generate_convex_quadratic_dataset(\n",
    "    K=50, dim=10, eigen_min=1.0, eigen_max=20.0, x_min=-100, x_max=100, noise_std=0.0, seed=0\n",
    ")\n",
    "print(X.shape, y.shape, \"\\n\", X[:5], \"\\n\", y[:5], \"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70071f5f",
   "metadata": {},
   "source": [
    "## Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "07278db8",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "def _ste_round(x: torch.Tensor) -> torch.Tensor:\n",
    "    return x + (torch.round(x) - x).detach()\n",
    "\n",
    "class _MCFValue(torch.autograd.Function):\n",
    "    \n",
    "    @staticmethod\n",
    "    def forward(ctx, n_nodes, src, dst, cost, cap, supply):\n",
    "        n = int(n_nodes)\n",
    "\n",
    "        src = src.to(dtype=torch.int64).contiguous()\n",
    "        dst = dst.to(dtype=torch.int64).contiguous()\n",
    "\n",
    "        m = int(src.numel())\n",
    "        if (dst.numel() != m) or (cost.numel() != m) or (cap.numel() != m) or (supply.numel() != n):\n",
    "            raise ValueError(\"Bad shapes for MCF inputs.\")\n",
    "        if torch.abs(supply.double().sum()) > _TOL:\n",
    "            raise ValueError(\"Require sum(supply)=0\")\n",
    "\n",
    "        def as_np(t: torch.Tensor, dtype):\n",
    "            return t.detach().cpu().contiguous().view(-1).numpy().astype(dtype, copy=False)\n",
    "\n",
    "        out = lemon_mcf.solve_mcf(\n",
    "            n,\n",
    "            as_np(src, np.int64),\n",
    "            as_np(dst, np.int64),\n",
    "            as_np(cost, np.float64),\n",
    "            as_np(cap, np.float64),\n",
    "            as_np(supply, np.float64),\n",
    "            tol=_TOL,\n",
    "        )\n",
    "        if out[\"status\"] != 1:\n",
    "            raise RuntimeError(f\"LEMON failed (status={out['status']})\")\n",
    "\n",
    "        flow = out[\"flow\"]\n",
    "        pot = out[\"potential\"]\n",
    "        red = out[\"reduced_cost\"]\n",
    "        at = out.get(\"at_cap\", out.get(\"at_capacity\", None))\n",
    "        if at is None:\n",
    "            at = np.abs(flow - as_np(cap, np.float64)) <= _TOL\n",
    "\n",
    "        ctx.flow, ctx.pot, ctx.red, ctx.at = flow, pot, red, at\n",
    "        return cost.new_tensor(float(out[\"total_cost\"]))\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, g):\n",
    "        dev, dt = g.device, g.dtype\n",
    "        flow = torch.as_tensor(ctx.flow, device=dev, dtype=dt)\n",
    "        pot  = torch.as_tensor(ctx.pot,  device=dev, dtype=dt)\n",
    "        red  = torch.as_tensor(ctx.red,  device=dev, dtype=dt)\n",
    "        at   = torch.as_tensor(ctx.at,   device=dev, dtype=torch.bool)\n",
    "\n",
    "        grad_cost = flow\n",
    "        grad_cap  = torch.where(at, red, torch.zeros_like(red))\n",
    "        grad_sup  = pot.mean() - pot\n",
    "\n",
    "        return None, None, None, grad_cost * g, grad_cap * g, grad_sup * g\n",
    "\n",
    "\n",
    "class DFN(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        input_dim: int,\n",
    "        layer_sizes,\n",
    "        p_list,\n",
    "        big_cost: float = 1e6,\n",
    "        big_cap: float = 1e6,\n",
    "        seed: int = 0,\n",
    "        A_fixed=None,\n",
    "        alpha: float = 1e-6,\n",
    "        beta: float = -0.0,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.alpha = float(alpha)\n",
    "        self.beta  = float(beta)\n",
    "\n",
    "        layer_sizes = list(map(int, layer_sizes))\n",
    "        if len(layer_sizes) < 2 or len(p_list) != len(layer_sizes) - 1:\n",
    "            raise ValueError(\"Need len(layer_sizes)>=2 and len(p_list)=len(layer_sizes)-1\")\n",
    "\n",
    "        self.n = int(sum(layer_sizes))\n",
    "        if self.n <= 0:\n",
    "            raise ValueError(\"sum(layer_sizes) must be > 0\")\n",
    "\n",
    "        # node indices per layer\n",
    "        layers, off = [], 0\n",
    "        for s in layer_sizes:\n",
    "            layers.append(torch.arange(off, off + s, dtype=torch.long))\n",
    "            off += s\n",
    "\n",
    "        L1, LK = layers[0], layers[-1]\n",
    "        if L1.numel() == 0 or LK.numel() == 0:\n",
    "            raise ValueError(\"First/last layer must be non-empty.\")\n",
    "\n",
    "        self.fix_node = int(LK[-1].item())\n",
    "        boundary = torch.cat([L1, LK[:-1]], 0)\n",
    "        self.register_buffer(\"boundary\", boundary)\n",
    "\n",
    "        gen = torch.Generator().manual_seed(int(seed))\n",
    "\n",
    "        def bipartite(U: torch.Tensor, V: torch.Tensor):\n",
    "            su, sv = int(U.numel()), int(V.numel())\n",
    "            return U.repeat_interleave(sv), V.repeat(su)\n",
    "\n",
    "        def sample_edges(U, V, p: float):\n",
    "            s, t = bipartite(U, V)\n",
    "            if p < 1.0:\n",
    "                keep = torch.rand(s.numel(), generator=gen) < float(p)\n",
    "                s, t = s[keep], t[keep]\n",
    "            return s, t\n",
    "\n",
    "        # learnable arcs between consecutive layers (both directions)\n",
    "        sf, tf, sb, tb = [], [], [], []\n",
    "        for i, p in enumerate(map(float, p_list)):\n",
    "            if not (0.0 <= p <= 1.0):\n",
    "                raise ValueError(\"p_list entries must be in [0,1]\")\n",
    "\n",
    "            s, t = sample_edges(layers[i], layers[i + 1], p)  # forward\n",
    "            sf.append(s); tf.append(t)\n",
    "\n",
    "            s, t = sample_edges(layers[i + 1], layers[i], p)  # backward\n",
    "            sb.append(s); tb.append(t)\n",
    "\n",
    "        src_param = torch.cat([torch.cat(sf, 0), torch.cat(sb, 0)], 0)\n",
    "        dst_param = torch.cat([torch.cat(tf, 0), torch.cat(tb, 0)], 0)\n",
    "        if src_param.numel() == 0:\n",
    "            raise ValueError(\"No learnable arcs (increase p_list / layer sizes).\")\n",
    "\n",
    "        s1, t1 = bipartite(L1, LK)\n",
    "        s2, t2 = bipartite(LK, L1)\n",
    "        src_fixed = torch.cat([s1, s2], 0)\n",
    "        dst_fixed = torch.cat([t1, t2], 0)\n",
    "        m_fixed = int(src_fixed.numel())\n",
    "\n",
    "        self.register_buffer(\"src\", torch.cat([src_param, src_fixed], 0))\n",
    "        self.register_buffer(\"dst\", torch.cat([dst_param, dst_fixed], 0))\n",
    "        self.register_buffer(\"cap_fixed\",  torch.full((m_fixed,), float(big_cap),  dtype=torch.float32))\n",
    "        self.register_buffer(\"cost_fixed\", torch.full((m_fixed,), float(big_cost), dtype=torch.float32))\n",
    "\n",
    "        nb = int(boundary.numel())\n",
    "        input_dim = int(input_dim)\n",
    "\n",
    "        m_param = int(src_param.numel())\n",
    "        self.cap_raw  = nn.Parameter(torch.zeros(m_param) + 0.542)\n",
    "        self.cost_raw = nn.Parameter(torch.randn(m_param) + 1.0)\n",
    "        self.b_raw    = nn.Parameter(torch.zeros(nb))\n",
    "\n",
    "        if A_fixed is None:\n",
    "            A = torch.zeros(nb, input_dim)\n",
    "            rows = torch.arange(nb)\n",
    "            A[rows, rows % input_dim] = 1.0\n",
    "            self.A = nn.Parameter(A)\n",
    "        else:\n",
    "            A_fixed = torch.as_tensor(A_fixed, dtype=torch.float32)\n",
    "            if A_fixed.shape != (nb, input_dim):\n",
    "                raise ValueError(f\"A_fixed must have shape {(nb, input_dim)}, got {tuple(A_fixed.shape)}\")\n",
    "            self.register_buffer(\"A\", A_fixed)\n",
    "\n",
    "    def forward(self, w: torch.Tensor) -> torch.Tensor:\n",
    "        capP  = _ste_round(F.softplus(self.cap_raw))\n",
    "        costP = self.cost_raw\n",
    "        b     = _ste_round(self.b_raw)\n",
    "        A     = _ste_round(self.A) if isinstance(self.A, nn.Parameter) else self.A\n",
    "\n",
    "        cap  = torch.cat([capP,  self.cap_fixed.to(w.device, w.dtype)], 0)\n",
    "        cost = torch.cat([costP, self.cost_fixed.to(w.device, w.dtype)], 0)\n",
    "\n",
    "        def one(w1: torch.Tensor) -> torch.Tensor:\n",
    "            supply = torch.zeros(self.n, device=w1.device, dtype=torch.float64)\n",
    "            supply[self.boundary] = (A.double() @ w1.double()) + b.double()\n",
    "            supply[self.fix_node] = -supply.sum()\n",
    "            return _MCFValue.apply(self.n, self.src, self.dst, cost, cap, supply)\n",
    "\n",
    "        out = one(w) if w.dim() == 1 else torch.stack([one(wi) for wi in w], 0)\n",
    "        return self.alpha * out + self.beta\n",
    "\n",
    "\n",
    "class MLP(nn.Module):\n",
    "    def __init__(self, in_dim, hidden_dims, out_dim):\n",
    "        super().__init__()\n",
    "        dims = [in_dim] + list(hidden_dims) + [out_dim]\n",
    "        layers = []\n",
    "        for a, b in zip(dims[:-2], dims[1:-1]):\n",
    "            layers += [nn.Linear(a, b), nn.ReLU()]\n",
    "        layers += [nn.Linear(dims[-2], dims[-1])]\n",
    "        self.net = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.net(x)\n",
    "\n",
    "\n",
    "class MaxAffine(nn.Module):\n",
    "    def __init__(self, in_dim: int, n_pieces: int):\n",
    "        super().__init__()\n",
    "        self.W = nn.Parameter(torch.randn(n_pieces, in_dim) / (in_dim**0.5))\n",
    "        self.b = nn.Parameter(torch.zeros(n_pieces))\n",
    "\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        return (x @ self.W.T + self.b).max(dim=1).values\n",
    "\n",
    "\n",
    "class LSET(nn.Module):\n",
    "    def __init__(self, in_dim: int, n_pieces: int, T: float = 0.01):\n",
    "        super().__init__()\n",
    "        self.T = float(T)\n",
    "        if self.T == 0.0:\n",
    "            raise ValueError(\"T must be nonzero\")\n",
    "        self.A = nn.Parameter(torch.randn(n_pieces, in_dim) / (in_dim**0.5))\n",
    "        self.b = nn.Parameter(torch.zeros(n_pieces))\n",
    "\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        z = (x @ self.A.t() + self.b) / self.T\n",
    "        return self.T * torch.logsumexp(z, dim=-1)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7cca0120",
   "metadata": {},
   "source": [
    "## Training Helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0a7db872",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "def _split_train_val_test(X: torch.Tensor, y: torch.Tensor, *, val_frac: float, test_frac: float, seed: int):\n",
    "    N = int(X.shape[0])\n",
    "    n_test = int(round(test_frac * N))\n",
    "    n_val  = int(round(val_frac  * N))\n",
    "    n_train = N - n_val - n_test\n",
    "    if n_train <= 0:\n",
    "        raise ValueError(\"splits too large; train set would be empty\")\n",
    "\n",
    "    g = torch.Generator().manual_seed(int(seed))\n",
    "    perm = torch.randperm(N, generator=g)\n",
    "    i_tr = perm[:n_train]\n",
    "    i_va = perm[n_train:n_train + n_val]\n",
    "    i_te = perm[n_train + n_val:]\n",
    "\n",
    "    return (X[i_tr], y[i_tr]), (X[i_va], y[i_va]), (X[i_te], y[i_te])\n",
    "\n",
    "\n",
    "def _fit_standardizer(Xtr: torch.Tensor, ytr: torch.Tensor, *, eps: float):\n",
    "    x_mean = Xtr.mean(0, keepdim=True)\n",
    "    x_std  = Xtr.std(0, unbiased=False, keepdim=True)\n",
    "    x_std  = torch.where(x_std < eps, torch.ones_like(x_std), x_std)\n",
    "\n",
    "    y_mean = ytr.mean()\n",
    "    y_std  = ytr.std(unbiased=False).clamp_min(eps)\n",
    "\n",
    "    return {\"x_mean\": x_mean, \"x_std\": x_std, \"y_mean\": y_mean, \"y_std\": y_std}\n",
    "\n",
    "\n",
    "def _apply_standardizer(X: torch.Tensor, y: torch.Tensor, scaler):\n",
    "    Xn = (X - scaler[\"x_mean\"]) / scaler[\"x_std\"]\n",
    "    yn = (y - scaler[\"y_mean\"]) / scaler[\"y_std\"]\n",
    "    return Xn, yn\n",
    "\n",
    "@torch.no_grad()\n",
    "def _mse_norm(model: nn.Module, loader: DataLoader, device: str):\n",
    "    model.eval()\n",
    "    tot, n = 0.0, 0\n",
    "    for xb, yb in loader:\n",
    "        xb, yb = xb.to(device), yb.to(device)\n",
    "        pred = model(xb).squeeze(-1)\n",
    "        tot += F.mse_loss(pred, yb, reduction=\"sum\").item()\n",
    "        n += yb.numel()\n",
    "    return tot / max(n, 1)\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def _predict_in_chunks(model: nn.Module, X: torch.Tensor, *, chunk: int):\n",
    "    out = []\n",
    "    for i in range(0, int(X.shape[0]), int(chunk)):\n",
    "        out.append(model(X[i:i + chunk]).squeeze(-1))\n",
    "    return torch.cat(out, 0)\n",
    "\n",
    "\n",
    "import json, hashlib\n",
    "\n",
    "# ---- Caching / persistence helpers ----\n",
    "_CACHE_ROOT_DEFAULT = Path(\"saved_runs\")\n",
    "\n",
    "def _to_hashable(x):\n",
    "    \"\"\"Convert nested objects to a deterministic, hashable (and mostly JSON-friendly) form.\"\"\"\n",
    "    import numpy as _np\n",
    "    import torch as _torch\n",
    "    from pathlib import Path as _Path\n",
    "\n",
    "    if x is None or isinstance(x, (bool, int, str)):\n",
    "        return x\n",
    "    if isinstance(x, float):\n",
    "        return float(x)\n",
    "    if isinstance(x, _Path):\n",
    "        return str(x)\n",
    "    if isinstance(x, (list, tuple)):\n",
    "        return [_to_hashable(v) for v in x]\n",
    "    if isinstance(x, dict):\n",
    "        return {str(k): _to_hashable(x[k]) for k in sorted(x.keys(), key=lambda z: str(z))}\n",
    "    if isinstance(x, _np.ndarray):\n",
    "        h = hashlib.sha256(x.tobytes(order=\"C\")).hexdigest()\n",
    "        return {\"__ndarray__\": True, \"dtype\": str(x.dtype), \"shape\": list(x.shape), \"sha256\": h}\n",
    "    if isinstance(x, _torch.Tensor):\n",
    "        t = x.detach().cpu()\n",
    "        h = hashlib.sha256(t.numpy().tobytes(order=\"C\")).hexdigest()\n",
    "        return {\"__tensor__\": True, \"dtype\": str(t.dtype), \"shape\": list(t.shape), \"sha256\": h}\n",
    "    return {\"__repr__\": repr(x)}\n",
    "\n",
    "def _run_signature(dataset_type, dataset_params, model_type, model_params, train_sig):\n",
    "    return {\n",
    "        \"dataset_type\": str(dataset_type).lower(),\n",
    "        \"dataset_params\": _to_hashable(dict(dataset_params)),\n",
    "        \"model_type\": str(model_type),\n",
    "        \"model_params\": _to_hashable(dict(model_params)),\n",
    "        \"train_params\": _to_hashable(dict(train_sig)),\n",
    "    }\n",
    "\n",
    "def _run_id_from_signature(sig) -> str:\n",
    "    blob = json.dumps(sig, sort_keys=True, separators=(\",\", \":\"), ensure_ascii=True).encode(\"utf-8\")\n",
    "    return hashlib.sha256(blob).hexdigest()[:16]\n",
    "\n",
    "def _get_run_dir(cache_root: Path, dataset_type, model_type, run_id: str) -> Path:\n",
    "    return Path(cache_root) / str(dataset_type).lower() / str(model_type) / run_id\n",
    "\n",
    "def _cpuify(obj):\n",
    "    import torch as _torch\n",
    "    if isinstance(obj, _torch.Tensor):\n",
    "        return obj.detach().cpu()\n",
    "    if isinstance(obj, dict):\n",
    "        return {k: _cpuify(v) for k, v in obj.items()}\n",
    "    if isinstance(obj, (list, tuple)):\n",
    "        return [_cpuify(v) for v in obj]\n",
    "    return obj\n",
    "\n",
    "def _set_full_determinism(seed: int):\n",
    "    \"\"\"CPU-only determinism: seed controls Python/NumPy/Torch RNGs.\n",
    "    (CUDA is intentionally not used in these notebooks.)\n",
    "    \"\"\"\n",
    "    import os, random\n",
    "    import numpy as np\n",
    "    import torch\n",
    "    seed = int(seed)\n",
    "\n",
    "    os.environ[\"PYTHONHASHSEED\"] = str(seed)\n",
    "\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "\n",
    "    # Determinism flags (will error if a non-deterministic op is used)\n",
    "    try:\n",
    "        torch.use_deterministic_algorithms(True)\n",
    "    except Exception:\n",
    "        pass\n",
    "\n",
    "def _seed_worker(worker_id: int):\n",
    "    import random\n",
    "    import numpy as np\n",
    "    import torch\n",
    "    worker_seed = torch.initial_seed() % 2**32\n",
    "    np.random.seed(worker_seed)\n",
    "    random.seed(worker_seed)\n",
    "\n",
    "def _try_load_cached_run(run_dir: Path):\n",
    "    artifact_path = Path(run_dir) / \"artifact.pt\"\n",
    "    if not artifact_path.exists():\n",
    "        return None\n",
    "    try:\n",
    "        return torch.load(str(artifact_path), map_location=\"cpu\")\n",
    "    except Exception:\n",
    "        return None\n",
    "\n",
    "def _build_model_from_params(model_type, model_params):\n",
    "    mt = str(model_type)\n",
    "    mp = dict(model_params)\n",
    "    if mt == \"DFN\":\n",
    "        model = DFN(**mp)\n",
    "    else:\n",
    "        mp.pop(\"alpha\", None)\n",
    "        mp.pop(\"beta\", None)\n",
    "        if mt == \"MLP\":\n",
    "            model = MLP(**mp)\n",
    "        elif mt == \"MaxAffine\":\n",
    "            model = MaxAffine(**mp)\n",
    "        elif mt == \"LSET\":\n",
    "            model = LSET(**mp)\n",
    "        else:\n",
    "            raise ValueError(\"model_type must be: DFN | MLP | MaxAffine | LSET\")\n",
    "    return model\n",
    "\n",
    "def _save_run(run_dir: Path, signature: dict, model: nn.Module, data: dict, history: dict, spec: dict):\n",
    "    run_dir = Path(run_dir)\n",
    "    run_dir.mkdir(parents=True, exist_ok=True)\n",
    "    artifact_path = run_dir / \"artifact.pt\"\n",
    "    if artifact_path.exists():\n",
    "        return\n",
    "    state_dict = {k: v.detach().cpu() for k, v in model.state_dict().items()}\n",
    "    artifact = {\n",
    "        \"signature\": signature,\n",
    "        \"state_dict\": state_dict,\n",
    "        \"data\": _cpuify(data),\n",
    "        \"history\": history,\n",
    "        \"spec\": spec,\n",
    "    }\n",
    "    torch.save(artifact, str(artifact_path))\n",
    "    try:\n",
    "        with open(run_dir / \"signature.json\", \"w\", encoding=\"utf-8\") as f:\n",
    "            json.dump(signature, f, indent=2, sort_keys=True)\n",
    "    except Exception:\n",
    "        pass\n",
    "\n",
    "def generate_and_train_simple(dataset_type, dataset_params, model_type, model_params, train_params=None):\n",
    "    \n",
    "    tp = train_params or {}\n",
    "\n",
    "    epochs     = int(tp.get(\"epochs\", 200))\n",
    "    batch_sz   = int(tp.get(\"batch_size\", 32))\n",
    "    lr         = float(tp.get(\"lr\", 1e-3))\n",
    "    wd         = float(tp.get(\"weight_decay\", 0.0))\n",
    "    val_frac   = float(tp.get(\"val_frac\", 0.15))\n",
    "    test_frac  = float(tp.get(\"test_frac\", 0.15))\n",
    "    eps        = float(tp.get(\"eps\", 1e-8))\n",
    "    seed       = int(tp.get(\"seed\", 0))\n",
    "    device     = tp.get(\"device\", \"cpu\")\n",
    "    plot_every = int(tp.get(\"plot_every\", 0) or 0)\n",
    "    plot_points= int(tp.get(\"plot_points\", 2048))\n",
    "    plot_chunk = int(tp.get(\"plot_chunk\", 4096))\n",
    "    print_stats= bool(tp.get(\"print_stats\", True))\n",
    "\n",
    "\n",
    "    # ---- determinism (seed controls data, model init, dataloader order, etc.) ----\n",
    "    _set_full_determinism(seed)\n",
    "    # dataset seed: caller may set it; otherwise default to the training seed\n",
    "    dataset_params = dict(dataset_params)\n",
    "    dataset_params.setdefault(\"seed\", int(seed))\n",
    "\n",
    "\n",
    "    # ---- cache check (skip training if identical run already saved) ----\n",
    "    train_sig = {\n",
    "        \"epochs\": int(epochs),\n",
    "        \"batch_size\": int(batch_sz),\n",
    "        \"lr\": float(lr),\n",
    "        \"weight_decay\": float(wd),\n",
    "        \"val_frac\": float(val_frac),\n",
    "        \"test_frac\": float(test_frac),\n",
    "        \"eps\": float(eps),\n",
    "        \"seed\": int(seed),\n",
    "        \"deterministic\": True,\n",
    "    }\n",
    "    cache_root = Path(tp.get(\"cache_root\", _CACHE_ROOT_DEFAULT))\n",
    "    signature = _run_signature(dataset_type, dataset_params, model_type, model_params, train_sig)\n",
    "    run_id = _run_id_from_signature(signature)\n",
    "    run_dir = _get_run_dir(cache_root, dataset_type, model_type, run_id)\n",
    "\n",
    "    cached = _try_load_cached_run(run_dir)\n",
    "    if cached is not None and isinstance(cached, dict) and \"state_dict\" in cached and \"data\" in cached and \"history\" in cached and \"spec\" in cached:\n",
    "        model = _build_model_from_params(model_type, model_params)\n",
    "        model.load_state_dict(cached[\"state_dict\"])\n",
    "        model = model.to(device)\n",
    "        return model, cached[\"data\"], cached[\"history\"], cached[\"spec\"]\n",
    "\n",
    "    # ---- data ----\n",
    "    dt = str(dataset_type).lower()\n",
    "    if dt == \"mdvsp\":\n",
    "        X, y, gt = make_mdvsp_dataset(**dataset_params)\n",
    "    elif dt == \"assignment\":\n",
    "        X, y, gt = generate_bipartite_subset_matching_dataset(**dataset_params)\n",
    "    elif dt == \"quadratic\":\n",
    "        X, y, gt = generate_convex_quadratic_dataset(**dataset_params)\n",
    "    else:\n",
    "        raise ValueError(\"dataset_type must be: mdvsp | assignment | quadratic\")\n",
    "\n",
    "    X = torch.as_tensor(X, dtype=torch.float32)\n",
    "    y = torch.as_tensor(y, dtype=torch.float32).view(-1)\n",
    "\n",
    "    # ---- print a quick dataset stats ----\n",
    "    if print_stats:\n",
    "        with torch.no_grad():\n",
    "            xmn = X.mean(0)\n",
    "            xsd = X.std(0, unbiased=False)\n",
    "            print(\n",
    "                f\"\\n--- Dataset stats ({dataset_type}) ---\\n\"\n",
    "                f\"  X: shape={tuple(X.shape)}  mean(mean)={xmn.mean():.3g}  std(mean)={xsd.mean():.3g}  \"\n",
    "                f\"min={float(X.min()):.3g}  max={float(X.max()):.3g}\\n\"\n",
    "                f\"  y: shape={tuple(y.shape)}  mean={float(y.mean()):.3g}  std={float(y.std(unbiased=False)):.3g}  \"\n",
    "                f\"min={float(y.min()):.3g}  max={float(y.max()):.3g}\\n\"\n",
    "            )\n",
    "\n",
    "    (Xtr, ytr), (Xva, yva), (Xte, yte) = _split_train_val_test(X, y, val_frac=val_frac, test_frac=test_frac, seed=seed)\n",
    "\n",
    "    scaler = _fit_standardizer(Xtr, ytr, eps=eps)\n",
    "    XtrN, ytrN = _apply_standardizer(Xtr, ytr, scaler)\n",
    "    XvaN, yvaN = _apply_standardizer(Xva, yva, scaler)\n",
    "    XteN, yteN = _apply_standardizer(Xte, yte, scaler)\n",
    "\n",
    "    g_dl = torch.Generator()\n",
    "    g_dl.manual_seed(seed)\n",
    "    train_loader = DataLoader(\n",
    "        TensorDataset(XtrN, ytrN),\n",
    "        batch_size=batch_sz,\n",
    "        shuffle=True,\n",
    "        generator=g_dl,\n",
    "        worker_init_fn=_seed_worker,\n",
    "    )\n",
    "    val_loader   = DataLoader(TensorDataset(XvaN, yvaN), batch_size=batch_sz, shuffle=False)\n",
    "\n",
    "    # subset for plotting\n",
    "    Nv = int(XvaN.shape[0])\n",
    "    if plot_points <= 0 or plot_points >= Nv:\n",
    "        plot_idx = torch.arange(Nv)\n",
    "    else:\n",
    "        g_plot = torch.Generator().manual_seed(seed + 12345)\n",
    "        plot_idx = torch.randperm(Nv, generator=g_plot)[:plot_points]\n",
    "\n",
    "    # ---- model ----\n",
    "    mt = str(model_type)\n",
    "    mp = dict(model_params)\n",
    "\n",
    "    if mt == \"DFN\":\n",
    "        model = DFN(**mp)\n",
    "        extra = f\"layers={mp.get('layer_sizes')} p_list={mp.get('p_list')} alpha={getattr(model,'alpha',None)} beta={getattr(model,'beta',None)}\"\n",
    "    else:\n",
    "        mp.pop(\"alpha\", None)\n",
    "        mp.pop(\"beta\", None)\n",
    "        if mt == \"MLP\":\n",
    "            model = MLP(**mp)\n",
    "            extra = f\"hidden={mp.get('hidden_dims')}\"\n",
    "        elif mt == \"MaxAffine\":\n",
    "            model = MaxAffine(**mp)\n",
    "            extra = f\"n_pieces={mp.get('n_pieces')}\"\n",
    "        elif mt == \"LSET\":\n",
    "            model = LSET(**mp)\n",
    "            extra = f\"n_pieces={mp.get('n_pieces')} T={mp.get('T')}\"\n",
    "        else:\n",
    "            raise ValueError(\"model_type must be: DFN | MLP | MaxAffine | LSET\")\n",
    "\n",
    "    model = model.to(device)\n",
    "    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)\n",
    "\n",
    "    n_params = sum(p.numel() for p in model.parameters())\n",
    "    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "    print(\n",
    "        f\"\\n=== Run: {dataset_type} | {model_type} ===\\n\"\n",
    "        f\"  data: N={len(X)}  train/val/test={len(Xtr)}/{len(Xva)}/{len(Xte)}  dim={X.shape[1]}\\n\"\n",
    "        f\"  model: params={n_params:,} {extra}\\n\"\n",
    "        f\"  train: device={device}  epochs={epochs}  batch={batch_sz}  lr={lr:g}  wd={wd:g}  seed={seed}\\n\"\n",
    "    )\n",
    "\n",
    "    history = {\"train_mse_norm\": [], \"val_mse_norm\": []}\n",
    "    best_val, best_ep, best_state = float(\"inf\"), 0, None\n",
    "\n",
    "    live = display(None, display_id=True) if plot_every > 0 else None\n",
    "\n",
    "    for ep in range(1, epochs + 1):\n",
    "        model.train()\n",
    "        for xb, yb in train_loader:\n",
    "            xb, yb = xb.to(device), yb.to(device)\n",
    "            loss = F.mse_loss(model(xb).squeeze(-1), yb)\n",
    "            opt.zero_grad(set_to_none=True)\n",
    "            loss.backward()\n",
    "            opt.step()\n",
    "\n",
    "        tr_mse = _mse_norm(model, train_loader, device)\n",
    "        va_mse = _mse_norm(model, val_loader, device)\n",
    "        history[\"train_mse_norm\"].append(tr_mse)\n",
    "        history[\"val_mse_norm\"].append(va_mse)\n",
    "\n",
    "        if va_mse < best_val:\n",
    "            best_val, best_ep = va_mse, ep\n",
    "            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}\n",
    "\n",
    "        if plot_every > 0 and (ep == 1 or ep % plot_every == 0 or ep == epochs):\n",
    "            model.eval()\n",
    "            x_plot = XvaN[plot_idx].to(device)\n",
    "            y_true = yvaN[plot_idx].to(device)\n",
    "            y_pred = _predict_in_chunks(model, x_plot, chunk=plot_chunk)\n",
    "\n",
    "            fig, ax = plt.subplots(1, 2, figsize=(10, 4))\n",
    "            ax[0].plot(history[\"train_mse_norm\"], label=\"train\")\n",
    "            ax[0].plot(history[\"val_mse_norm\"], label=\"val\")\n",
    "            ax[0].set_yscale(\"log\")\n",
    "            ax[0].set_title(f\"Epoch {ep}/{epochs} | val MSE={va_mse:.3e} (norm)\")\n",
    "            ax[0].legend()\n",
    "\n",
    "            yt = y_true.detach().cpu().numpy()\n",
    "            yp = y_pred.detach().cpu().numpy()\n",
    "            ax[1].scatter(yt, yp, s=10)\n",
    "            lo = float(min(yt.min(), yp.min()))\n",
    "            hi = float(max(yt.max(), yp.max()))\n",
    "            ax[1].plot([lo, hi], [lo, hi])\n",
    "            ax[1].set_xlabel(\"y_true (norm)\")\n",
    "            ax[1].set_ylabel(\"y_pred (norm)\")\n",
    "            ax[1].set_title(f\"Val scatter (n={len(yt)})\")\n",
    "\n",
    "            plt.tight_layout()\n",
    "            live.update(fig)\n",
    "            plt.close(fig)\n",
    "\n",
    "    if best_state is not None:\n",
    "        model.load_state_dict(best_state)\n",
    "\n",
    "    if print_stats:\n",
    "        print(f\"[DONE] best val MSE (norm) = {best_val:.3e} @ epoch {best_ep}\\n\")\n",
    "\n",
    "    data = {\n",
    "        \"raw\":  {\"Xtr\": Xtr,  \"ytr\": ytr,  \"Xva\": Xva,  \"yva\": yva,  \"Xte\": Xte,  \"yte\": yte},\n",
    "        \"norm\": {\"Xtr\": XtrN, \"ytr\": ytrN, \"Xva\": XvaN, \"yva\": yvaN, \"Xte\": XteN, \"yte\": yteN},\n",
    "        \"scaler\": scaler,\n",
    "        \"best_val_mse_norm\": float(best_val),\n",
    "        \"best_epoch\": int(best_ep),\n",
    "        \"stats\": {\"n_params\": int(n_params), \"n_trainable\": int(n_trainable)},\n",
    "        \"true\": gt,\n",
    "        \"device\": device,\n",
    "    }\n",
    "\n",
    "    spec = {\n",
    "        \"model\": model_type,\n",
    "        \"extra\": extra,\n",
    "        \"n_params\": int(n_params),\n",
    "        \"n_trainable\": int(n_trainable),\n",
    "        \"model_params\": dict(model_params),\n",
    "        \"train_params\": {\n",
    "            \"epochs\": int(epochs),\n",
    "            \"batch_size\": int(batch_sz),\n",
    "            \"lr\": float(lr),\n",
    "            \"weight_decay\": float(wd),\n",
    "            \"seed\": int(seed),\n",
    "            \"device\": str(device),\n",
    "            \"val_frac\": float(val_frac),\n",
    "            \"test_frac\": float(test_frac),\n",
    "        },\n",
    "    }\n",
    "\n",
    "\n",
    "    # ---- save artifacts for this run ----\n",
    "\n",
    "    _save_run(run_dir, signature, model, data, history, spec)\n",
    "\n",
    "\n",
    "    return model, data, history, spec\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ace7f63f",
   "metadata": {},
   "source": [
    "## Optimization Helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "38416530",
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_true_obj(gt, x):\n",
    "    if not isinstance(gt, dict):\n",
    "        return np.nan\n",
    "\n",
    "    x = np.asarray(x).reshape(-1)\n",
    "    t = gt.get(\"type\", None)\n",
    "\n",
    "    if t == \"quadratic\":\n",
    "        Q = np.asarray(gt[\"Q\"], float)\n",
    "        xs = np.asarray(gt[\"x_star\"], float).reshape(-1)\n",
    "        d = x.astype(float) - xs\n",
    "        return float(d @ Q @ d)\n",
    "\n",
    "    if t == \"assignment\":\n",
    "        C = np.asarray(gt[\"C\"], float)\n",
    "        idx = np.flatnonzero(x > 0.5)\n",
    "        if idx.size == 0:\n",
    "            return 0.0\n",
    "        r, c = linear_sum_assignment(C[idx, :])\n",
    "        return float(C[idx, :][r, c].sum())\n",
    "\n",
    "    if t == \"mdvsp\":\n",
    "        N, SS, TT = int(gt[\"N\"]), int(gt[\"SS\"]), int(gt[\"TT\"])\n",
    "        src = np.asarray(gt[\"src\"], np.int64)\n",
    "        dst = np.asarray(gt[\"dst\"], np.int64)\n",
    "        cost = np.asarray(gt[\"cost\"], float)\n",
    "        cap0 = np.asarray(gt[\"cap0\"], float)\n",
    "        idxSS = np.asarray(gt[\"idxSS\"], np.int64)\n",
    "        idxTT = np.asarray(gt[\"idxTT\"], np.int64)\n",
    "\n",
    "        cap = cap0.copy()\n",
    "        cap[idxSS] = x.astype(float)\n",
    "        cap[idxTT] = x.astype(float)\n",
    "\n",
    "        Fmax = lemon_mcf.max_flow(N, src, dst, cap, SS, TT)[\"value\"]\n",
    "        supply = np.zeros(N, float)\n",
    "        supply[SS] = Fmax\n",
    "        supply[TT] = -Fmax\n",
    "        return float(lemon_mcf.solve_mcf(N, src, dst, cost, cap, supply)[\"total_cost\"])\n",
    "\n",
    "    return np.nan\n",
    "\n",
    "\n",
    "def local_search_l1_int(\n",
    "    f,\n",
    "    x0,\n",
    "    x_min,\n",
    "    x_max,\n",
    "    delta: int,\n",
    "    sum_eq=None,\n",
    "    max_iters: int = 10_000,\n",
    "    print_every: int = 1,\n",
    "):\n",
    "    x  = np.asarray(x0,    int).ravel()\n",
    "    lo = np.asarray(x_min, int).ravel()\n",
    "    hi = np.asarray(x_max, int).ravel()\n",
    "    n = int(x.size)\n",
    "    delta = int(delta)\n",
    "\n",
    "    assert lo.size == n and hi.size == n\n",
    "    assert np.all(x >= lo) and np.all(x <= hi)\n",
    "    if sum_eq is not None:\n",
    "        sum_eq = int(sum_eq)\n",
    "        assert int(x.sum()) == sum_eq\n",
    "\n",
    "    def eval_batch(X):\n",
    "        y = np.asarray(f(X), float).reshape(-1)\n",
    "        return y\n",
    "\n",
    "    def ok(z):\n",
    "        if np.any(z < lo) or np.any(z > hi):\n",
    "            return False\n",
    "        if sum_eq is not None and int(z.sum()) != sum_eq:\n",
    "            return False\n",
    "        return True\n",
    "\n",
    "    # integer deltas with exact L1 = k\n",
    "    def deltas_exact(k):\n",
    "        d = np.zeros(n, int)\n",
    "\n",
    "        def rec(i, rem):\n",
    "            if i == n:\n",
    "                if rem == 0:\n",
    "                    yield d.copy()\n",
    "                return\n",
    "            for t in range(rem + 1):\n",
    "                if t == 0:\n",
    "                    d[i] = 0\n",
    "                    yield from rec(i + 1, rem)\n",
    "                else:\n",
    "                    d[i] = +t\n",
    "                    yield from rec(i + 1, rem - t)\n",
    "                    d[i] = -t\n",
    "                    yield from rec(i + 1, rem - t)\n",
    "            d[i] = 0\n",
    "\n",
    "        yield from rec(0, k)\n",
    "\n",
    "    t0 = time.perf_counter()\n",
    "    y = float(eval_batch(x[None, :])[0])\n",
    "    hist = [{\"iter\": 0, \"t\": 0.0, \"best_y\": y, \"x\": x.copy()}]\n",
    "    print(f\"iter=0  t=0.00s  best_y={y:.6g}  x={x.tolist()}\")\n",
    "\n",
    "    for it in range(1, int(max_iters) + 1):\n",
    "        cand = []\n",
    "        for k in range(1, delta + 1):\n",
    "            for dlt in deltas_exact(k):\n",
    "                z = x + dlt\n",
    "                if ok(z):\n",
    "                    cand.append(z)\n",
    "\n",
    "        if not cand:\n",
    "            print(f\"STOP: no feasible neighbors. best_y={y:.6g} x={x.tolist()}\")\n",
    "            break\n",
    "\n",
    "        Y = eval_batch(np.stack(cand, 0))\n",
    "        j = int(np.argmin(Y))\n",
    "        if float(Y[j]) >= y:\n",
    "            print(f\"STOP: local minimum. best_y={y:.6g} x={x.tolist()}\")\n",
    "            break\n",
    "\n",
    "        x, y = cand[j], float(Y[j])\n",
    "        hist.append({\"iter\": it, \"t\": time.perf_counter() - t0, \"best_y\": y, \"x\": x.copy()})\n",
    "        if it % max(1, int(print_every)) == 0:\n",
    "            print(f\"iter={it}  t={hist[-1]['t']:.2f}s  best_y={y:.6g}  x={x.tolist()}\")\n",
    "\n",
    "    return x, y, hist\n",
    "\n",
    "\n",
    "def _scaler_np(scaler):\n",
    "    xm = scaler[\"x_mean\"].detach().cpu().numpy().reshape(-1)\n",
    "    xs = scaler[\"x_std\"].detach().cpu().numpy().reshape(-1)\n",
    "    ym = float(scaler[\"y_mean\"].detach().cpu())\n",
    "    ys = float(scaler[\"y_std\"].detach().cpu())\n",
    "    return xm, xs, ym, ys\n",
    "\n",
    "\n",
    "def solve_dfn_ip_gurobi(dfn, scaler, x_min, x_max, sum_eq, *, integer_x=True, verbose=False, time_limit=None, seed=None):\n",
    "    import gurobipy as gp\n",
    "    from gurobipy import GRB\n",
    "\n",
    "    x_min = np.asarray(x_min, float).ravel()\n",
    "    x_max = np.asarray(x_max, float).ravel()\n",
    "    d = int(x_min.size)\n",
    "    sum_eq = float(sum_eq)\n",
    "\n",
    "    x_mean, x_std, y_mean, y_std = _scaler_np(scaler)\n",
    "\n",
    "    cost = np.r_[dfn.cost_raw.detach().cpu().numpy(), dfn.cost_fixed.detach().cpu().numpy()]\n",
    "    cap  = np.r_[torch.round(F.softplus(dfn.cap_raw.detach())).cpu().numpy(), dfn.cap_fixed.detach().cpu().numpy()]\n",
    "\n",
    "    A = dfn.A.detach().cpu().numpy()\n",
    "    if isinstance(dfn.A, torch.nn.Parameter):\n",
    "        A = np.round(A)\n",
    "    b = np.round(dfn.b_raw.detach().cpu().numpy())\n",
    "\n",
    "    src = dfn.src.detach().cpu().numpy().astype(int)\n",
    "    dst = dfn.dst.detach().cpu().numpy().astype(int)\n",
    "    boundary = dfn.boundary.detach().cpu().numpy().astype(int)\n",
    "    fix = int(dfn.fix_node)\n",
    "    n = int(dfn.n)\n",
    "    m = int(src.size)\n",
    "    alpha = float(dfn.alpha)\n",
    "    beta  = float(dfn.beta)\n",
    "\n",
    "    out = [[] for _ in range(n)]\n",
    "    inn = [[] for _ in range(n)]\n",
    "    for e in range(m):\n",
    "        out[src[e]].append(e)\n",
    "        inn[dst[e]].append(e)\n",
    "\n",
    "    M = gp.Model(\"DFN_IP\")\n",
    "    M.Params.OutputFlag = 1 if verbose else 0\n",
    "    # Make the IP solve reproducible across runs\n",
    "    try:\n",
    "        M.Params.Threads = 1\n",
    "    except Exception:\n",
    "        pass\n",
    "    if seed is not None:\n",
    "        try:\n",
    "            M.Params.Seed = int(seed)\n",
    "        except Exception:\n",
    "            pass\n",
    "    if time_limit is not None:\n",
    "        M.Params.TimeLimit = float(time_limit)\n",
    "\n",
    "    xt = GRB.INTEGER if integer_x else GRB.CONTINUOUS\n",
    "    x = M.addVars(d, lb=x_min.tolist(), ub=x_max.tolist(), vtype=xt, name=\"x\")\n",
    "    f = M.addVars(m, lb=0.0, ub=cap.tolist(), vtype=GRB.CONTINUOUS, name=\"f\")\n",
    "    M.addConstr(gp.quicksum(x[i] for i in range(d)) == sum_eq)\n",
    "\n",
    "    xm_over_xs = x_mean / x_std\n",
    "    s = [0] * n\n",
    "    s_boundary = []\n",
    "    for r, v in enumerate(boundary):\n",
    "        const = float(b[r] - (A[r] * xm_over_xs).sum())\n",
    "        expr = const + gp.quicksum((A[r, j] / x_std[j]) * x[j] for j in range(d) if A[r, j] != 0)\n",
    "        s[v] = expr\n",
    "        s_boundary.append(expr)\n",
    "    s[fix] = -gp.quicksum(s_boundary)\n",
    "\n",
    "    for v in range(n):\n",
    "        M.addConstr(gp.quicksum(f[e] for e in out[v]) - gp.quicksum(f[e] for e in inn[v]) == s[v])\n",
    "\n",
    "    flow_cost = gp.quicksum(cost[e] * f[e] for e in range(m))\n",
    "    M.setObjective((alpha * flow_cost + beta) * y_std + y_mean, GRB.MINIMIZE)\n",
    "\n",
    "    M.optimize()\n",
    "    if M.SolCount == 0:\n",
    "        raise RuntimeError(f\"No solution (Gurobi status {M.Status})\")\n",
    "\n",
    "    x_star = np.array([x[i].X for i in range(d)], float)\n",
    "    info = {\"status\": M.Status, \"runtime\": M.Runtime, \"gap\": getattr(M, \"MIPGap\", None)}\n",
    "    return x_star, float(M.ObjVal), info\n",
    "\n",
    "\n",
    "def solve_mlp_ip_gurobi(model, scaler, x_min, x_max, sum_eq, *, integer_x=True, verbose=False, time_limit=None, seed=None):\n",
    "    import gurobipy as gp\n",
    "    from gurobipy import GRB\n",
    "\n",
    "    x_min = np.asarray(x_min, float).ravel()\n",
    "    x_max = np.asarray(x_max, float).ravel()\n",
    "    d = int(x_min.size)\n",
    "    sum_eq = float(sum_eq)\n",
    "\n",
    "    xm, xs, ym, ys = _scaler_np(scaler)\n",
    "\n",
    "    base, a_out, b_out = model, 1.0, 0.0\n",
    "    if hasattr(model, \"base\") and hasattr(model, \"a\") and hasattr(model, \"b\"):\n",
    "        base, a_out, b_out = model.base, float(model.a), float(model.b)\n",
    "\n",
    "    if not hasattr(base, \"net\"):\n",
    "        raise ValueError(\"Expected an MLP with attribute .net (nn.Sequential).\")\n",
    "    linears = [L for L in base.net if isinstance(L, torch.nn.Linear)]\n",
    "    if not linears:\n",
    "        raise ValueError(\"No Linear layers found in base.net\")\n",
    "\n",
    "    W = [L.weight.detach().cpu().numpy().astype(float) for L in linears]\n",
    "    b = [L.bias.detach().cpu().numpy().astype(float) for L in linears]\n",
    "\n",
    "    W[0] = W[0] / xs[None, :]\n",
    "    b[0] = b[0] - W[0] @ xm\n",
    "\n",
    "    W[-1] *= a_out\n",
    "    b[-1] = a_out * b[-1] + b_out\n",
    "\n",
    "    u = np.maximum(np.abs(x_min), np.abs(x_max))\n",
    "    preLU = []\n",
    "    for k in range(len(W) - 1):\n",
    "        U = np.abs(W[k]) @ u + np.abs(b[k])\n",
    "        preLU.append((-U, U))\n",
    "        u = np.maximum(0.0, U)\n",
    "\n",
    "    M = gp.Model(\"MLP_IP\")\n",
    "    M.Params.OutputFlag = 1 if verbose else 0\n",
    "    # Make the IP solve reproducible across runs\n",
    "    try:\n",
    "        M.Params.Threads = 1\n",
    "    except Exception:\n",
    "        pass\n",
    "    if seed is not None:\n",
    "        try:\n",
    "            M.Params.Seed = int(seed)\n",
    "        except Exception:\n",
    "            pass\n",
    "    if time_limit is not None:\n",
    "        M.Params.TimeLimit = float(time_limit)\n",
    "\n",
    "    xt = GRB.INTEGER if integer_x else GRB.CONTINUOUS\n",
    "    x = M.addVars(d, lb=x_min.tolist(), ub=x_max.tolist(), vtype=xt, name=\"x\")\n",
    "    M.addConstr(gp.quicksum(x[i] for i in range(d)) == sum_eq, name=\"sum_eq\")\n",
    "\n",
    "    prev = [x[i] for i in range(d)]\n",
    "\n",
    "    for k in range(len(W) - 1):\n",
    "        Lk, Uk = preLU[k]\n",
    "        h = W[k].shape[0]\n",
    "\n",
    "        a = [M.addVar(lb=float(Lk[j]), ub=float(Uk[j]), name=f\"a{k}_{j}\") for j in range(h)]\n",
    "        z = [M.addVar(lb=0.0, ub=float(max(0.0, Uk[j])), name=f\"z{k}_{j}\") for j in range(h)]\n",
    "\n",
    "        for j in range(h):\n",
    "            M.addConstr(a[j] == b[k][j] + gp.quicksum(W[k][j, i] * prev[i] for i in range(len(prev))))\n",
    "\n",
    "            Lj, Uj = float(Lk[j]), float(Uk[j])\n",
    "            if Uj <= 0.0:\n",
    "                M.addConstr(z[j] == 0.0)\n",
    "            elif Lj >= 0.0:\n",
    "                M.addConstr(z[j] == a[j])\n",
    "            else:\n",
    "                s = M.addVar(vtype=GRB.BINARY, name=f\"s{k}_{j}\")\n",
    "                M.addConstr(z[j] >= a[j])\n",
    "                M.addConstr(z[j] >= 0.0)\n",
    "                M.addConstr(z[j] <= Uj * s)\n",
    "                M.addConstr(z[j] <= a[j] - Lj * (1 - s))\n",
    "\n",
    "        prev = z\n",
    "\n",
    "    if W[-1].shape[0] != 1:\n",
    "        raise ValueError(f\"Expected scalar output, got out_dim={W[-1].shape[0]}\")\n",
    "\n",
    "    y_norm = M.addVar(lb=-GRB.INFINITY, vtype=GRB.CONTINUOUS, name=\"y_norm\")\n",
    "    M.addConstr(y_norm == b[-1][0] + gp.quicksum(W[-1][0, i] * prev[i] for i in range(len(prev))))\n",
    "\n",
    "    M.setObjective(ys * y_norm + ym, GRB.MINIMIZE)\n",
    "\n",
    "    M.optimize()\n",
    "    if M.SolCount == 0:\n",
    "        raise RuntimeError(f\"No feasible solution. Gurobi status {M.Status}\")\n",
    "\n",
    "    x_star = np.array([x[i].X for i in range(d)], float)\n",
    "    info = {\"status\": M.Status, \"runtime\": M.Runtime, \"gap\": getattr(M, \"MIPGap\", None), \"sol_count\": M.SolCount}\n",
    "    return x_star, float(M.ObjVal), info\n",
    "\n",
    "\n",
    "def solve_maxaffine_ip_gurobi(model, scaler, x_min, x_max, sum_eq, *, integer_x=True, verbose=False, time_limit=None, seed=None):\n",
    "    import gurobipy as gp\n",
    "    from gurobipy import GRB\n",
    "\n",
    "    x_min = np.asarray(x_min, float).ravel()\n",
    "    x_max = np.asarray(x_max, float).ravel()\n",
    "    d = int(x_min.size)\n",
    "    sum_eq = float(sum_eq)\n",
    "\n",
    "    xm, xs, ym, ys = _scaler_np(scaler)\n",
    "\n",
    "    base, a_out, b_out = model, 1.0, 0.0\n",
    "    if hasattr(model, \"base\") and hasattr(model, \"a\") and hasattr(model, \"b\"):\n",
    "        base, a_out, b_out = model.base, float(model.a), float(model.b)\n",
    "\n",
    "    W = base.W.detach().cpu().numpy().astype(float)\n",
    "    b = base.b.detach().cpu().numpy().astype(float)\n",
    "\n",
    "    Weff = W / xs[None, :]\n",
    "    beff = b - (Weff @ xm)\n",
    "\n",
    "    Weff *= a_out\n",
    "    beff  = a_out * beff + b_out\n",
    "\n",
    "    K = int(Weff.shape[0])\n",
    "\n",
    "    M = gp.Model(\"MaxAffine_IP\")\n",
    "    M.Params.OutputFlag = 1 if verbose else 0\n",
    "    # Make the IP solve reproducible across runs\n",
    "    try:\n",
    "        M.Params.Threads = 1\n",
    "    except Exception:\n",
    "        pass\n",
    "    if seed is not None:\n",
    "        try:\n",
    "            M.Params.Seed = int(seed)\n",
    "        except Exception:\n",
    "            pass\n",
    "    if time_limit is not None:\n",
    "        M.Params.TimeLimit = float(time_limit)\n",
    "\n",
    "    xt = GRB.INTEGER if integer_x else GRB.CONTINUOUS\n",
    "    x = M.addVars(d, lb=x_min.tolist(), ub=x_max.tolist(), vtype=xt, name=\"x\")\n",
    "    M.addConstr(gp.quicksum(x[i] for i in range(d)) == sum_eq, name=\"sum_eq\")\n",
    "\n",
    "    t = M.addVar(lb=-GRB.INFINITY, vtype=GRB.CONTINUOUS, name=\"t_norm\")\n",
    "    for k in range(K):\n",
    "        M.addConstr(t >= beff[k] + gp.quicksum(Weff[k, j] * x[j] for j in range(d) if Weff[k, j] != 0.0))\n",
    "\n",
    "    M.setObjective(ys * t + ym, GRB.MINIMIZE)\n",
    "    M.optimize()\n",
    "\n",
    "    if M.SolCount == 0:\n",
    "        raise RuntimeError(f\"No feasible solution. Gurobi status {M.Status}\")\n",
    "\n",
    "    x_star = np.array([x[i].X for i in range(d)], float)\n",
    "    info = {\"status\": M.Status, \"runtime\": M.Runtime, \"gap\": getattr(M, \"MIPGap\", None), \"sol_count\": M.SolCount}\n",
    "    return x_star, float(M.ObjVal), info\n",
    "\n",
    "\n",
    "def solve_lset_ip_gurobi(model, scaler, x_min, x_max, sum_eq, *, integer_x=True, verbose=False, time_limit=None, seed=None):\n",
    "    import gurobipy as gp\n",
    "    from gurobipy import GRB\n",
    "\n",
    "    x_min = np.asarray(x_min, float).ravel()\n",
    "    x_max = np.asarray(x_max, float).ravel()\n",
    "    d = int(x_min.size)\n",
    "    sum_eq = float(sum_eq)\n",
    "\n",
    "    xm, xs, ym, ys = _scaler_np(scaler)\n",
    "\n",
    "    A = model.A.detach().cpu().numpy().astype(float)\n",
    "    b = model.b.detach().cpu().numpy().astype(float)\n",
    "    T = float(model.T)\n",
    "    if T == 0.0:\n",
    "        raise ValueError(\"T must be nonzero\")\n",
    "\n",
    "    Aeff = A / xs[None, :]\n",
    "    beff = b - (Aeff @ xm)\n",
    "    K = int(Aeff.shape[0])\n",
    "\n",
    "    lin_lo = np.empty(K, dtype=float)\n",
    "    lin_hi = np.empty(K, dtype=float)\n",
    "    for k in range(K):\n",
    "        a = Aeff[k]\n",
    "        lo = beff[k]\n",
    "        hi = beff[k]\n",
    "        pos = a >= 0\n",
    "        lo += (a[pos] * x_min[pos]).sum() + (a[~pos] * x_max[~pos]).sum()\n",
    "        hi += (a[pos] * x_max[pos]).sum() + (a[~pos] * x_min[~pos]).sum()\n",
    "        lin_lo[k], lin_hi[k] = lo, hi\n",
    "\n",
    "    z_lo = lin_lo / T\n",
    "    z_hi = lin_hi / T\n",
    "    m_lo = float(np.max(z_lo))\n",
    "    m_hi = float(np.max(z_hi))\n",
    "\n",
    "    w_lo = float(np.min(z_lo - m_hi))  # <= 0\n",
    "    u_lo = float(np.exp(max(-700.0, w_lo)))\n",
    "    s_lo = max(1e-12, K * u_lo)\n",
    "    s_hi = float(K * 1.0)\n",
    "    v_lo = float(np.log(s_lo))\n",
    "    v_hi = float(np.log(s_hi))\n",
    "\n",
    "    yN_lo = float(T * (m_lo + v_lo))\n",
    "    yN_hi = float(T * (m_hi + v_hi))\n",
    "    y_lo  = float(ys * yN_lo + ym)\n",
    "    y_hi  = float(ys * yN_hi + ym)\n",
    "\n",
    "    M = gp.Model(\"lset_ip_stable_bounded\")\n",
    "    M.Params.OutputFlag = 1 if verbose else 0\n",
    "    # Make the IP solve reproducible across runs\n",
    "    try:\n",
    "        M.Params.Threads = 1\n",
    "    except Exception:\n",
    "        pass\n",
    "    if seed is not None:\n",
    "        try:\n",
    "            M.Params.Seed = int(seed)\n",
    "        except Exception:\n",
    "            pass\n",
    "    if time_limit is not None:\n",
    "        M.Params.TimeLimit = float(time_limit)\n",
    "\n",
    "    M.Params.FuncNonlinear = 1\n",
    "    M.Params.FeasibilityTol = 1e-9\n",
    "    M.Params.OptimalityTol  = 1e-9\n",
    "    M.Params.IntFeasTol     = 1e-9\n",
    "    M.Params.NumericFocus   = 3\n",
    "\n",
    "    xt = GRB.INTEGER if integer_x else GRB.CONTINUOUS\n",
    "    x = M.addVars(d, lb=x_min.tolist(), ub=x_max.tolist(), vtype=xt, name=\"x\")\n",
    "    M.addConstr(gp.quicksum(x[i] for i in range(d)) == sum_eq, name=\"sum_eq\")\n",
    "\n",
    "    z = M.addVars(K, lb=z_lo.tolist(), ub=z_hi.tolist(), vtype=GRB.CONTINUOUS, name=\"z\")\n",
    "    for k in range(K):\n",
    "        lin = beff[k] + gp.quicksum(Aeff[k, j] * x[j] for j in range(d) if Aeff[k, j] != 0.0)\n",
    "        M.addConstr(z[k] == lin / T, name=f\"zdef_{k}\")\n",
    "\n",
    "    m = M.addVar(lb=m_lo, ub=m_hi, vtype=GRB.CONTINUOUS, name=\"m\")\n",
    "    w = M.addVars(K, lb=w_lo, ub=0.0, vtype=GRB.CONTINUOUS, name=\"w\")\n",
    "    for k in range(K):\n",
    "        M.addConstr(m >= z[k], name=f\"m_ge_z_{k}\")\n",
    "        M.addConstr(w[k] == z[k] - m, name=f\"wdef_{k}\")\n",
    "\n",
    "    u = M.addVars(K, lb=0.0, ub=1.0, vtype=GRB.CONTINUOUS, name=\"u\")\n",
    "    for k in range(K):\n",
    "        M.addGenConstrExp(w[k], u[k], name=f\"exp_{k}\")\n",
    "\n",
    "    s = M.addVar(lb=s_lo, ub=s_hi, vtype=GRB.CONTINUOUS, name=\"s\")\n",
    "    M.addConstr(s == gp.quicksum(u[k] for k in range(K)), name=\"sumexp_shifted\")\n",
    "\n",
    "    v = M.addVar(lb=v_lo, ub=v_hi, vtype=GRB.CONTINUOUS, name=\"v\")\n",
    "    M.addGenConstrLog(s, v, name=\"log_shifted\")\n",
    "\n",
    "    y_norm = M.addVar(lb=yN_lo, ub=yN_hi, vtype=GRB.CONTINUOUS, name=\"y_norm\")\n",
    "    M.addConstr(y_norm == T * (m + v), name=\"y_norm_def\")\n",
    "\n",
    "    y_raw = M.addVar(lb=y_lo, ub=y_hi, vtype=GRB.CONTINUOUS, name=\"y_raw\")\n",
    "    M.addConstr(y_raw == ys * y_norm + ym, name=\"y_raw_def\")\n",
    "\n",
    "    M.setObjective(y_raw, GRB.MINIMIZE)\n",
    "    M.optimize()\n",
    "\n",
    "    if M.SolCount == 0:\n",
    "        raise RuntimeError(f\"No feasible solution found. Gurobi status {M.Status}\")\n",
    "\n",
    "    x_star = np.array([x[i].X for i in range(d)], dtype=float)\n",
    "    y_star = float(y_raw.X)\n",
    "\n",
    "    info = {\n",
    "        \"status\": M.Status,\n",
    "        \"runtime\": M.Runtime,\n",
    "        \"gap\": getattr(M, \"MIPGap\", None),\n",
    "        \"sol_count\": M.SolCount,\n",
    "        \"obj_gurobi\": float(M.ObjVal),\n",
    "        \"obj_bound\": float(getattr(M, \"ObjBound\", float(\"nan\"))),\n",
    "    }\n",
    "    return x_star, y_star, info\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4858979",
   "metadata": {},
   "source": [
    "## Test Helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "569f2165",
   "metadata": {},
   "outputs": [],
   "source": [
    "def solve_ip(model_type, model, scaler, xmin, xmax, sum_eq, *, time_limit=None, verbose=False, seed=None):\n",
    "    if model_type == \"DFN\":\n",
    "        return solve_dfn_ip_gurobi(model, scaler, xmin, xmax, sum_eq, time_limit=time_limit, verbose=verbose, seed=seed)\n",
    "    if model_type == \"MLP\":\n",
    "        return solve_mlp_ip_gurobi(model, scaler, xmin, xmax, sum_eq, time_limit=time_limit, verbose=verbose, seed=seed)\n",
    "    if model_type == \"MaxAffine\":\n",
    "        return solve_maxaffine_ip_gurobi(model, scaler, xmin, xmax, sum_eq, time_limit=time_limit, verbose=verbose, seed=seed)\n",
    "    if model_type == \"LSET\":\n",
    "        return solve_lset_ip_gurobi(model, scaler, xmin, xmax, sum_eq, time_limit=time_limit, verbose=verbose, seed=seed)\n",
    "    raise ValueError(model_type)\n",
    "\n",
    "\n",
    "def suppress_stdout(fn, *args, silence: bool = False, **kwargs):\n",
    "    if not silence:\n",
    "        return fn(*args, **kwargs), \"\"\n",
    "    buf = io.StringIO()\n",
    "    with contextlib.redirect_stdout(buf):\n",
    "        out = fn(*args, **kwargs)\n",
    "    return out, buf.getvalue()\n",
    "\n",
    "\n",
    "def make_obj(model, scaler, device, chunk: int = 4096):\n",
    "    xm = scaler[\"x_mean\"].to(device)\n",
    "    xs = scaler[\"x_std\"].to(device)\n",
    "    ym = scaler[\"y_mean\"].to(device)\n",
    "    ys = scaler[\"y_std\"].to(device)\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def obj(Xraw):\n",
    "        Xraw = torch.as_tensor(Xraw, dtype=torch.float32, device=device)\n",
    "        if Xraw.dim() == 1:\n",
    "            Xraw = Xraw.unsqueeze(0)\n",
    "\n",
    "        outs = []\n",
    "        B = int(Xraw.shape[0])\n",
    "        for i in range(0, B, int(chunk)):\n",
    "            Xb = Xraw[i:i+chunk]\n",
    "            try:\n",
    "                Xn = (Xb - xm) / xs\n",
    "                yn = model(Xn)\n",
    "                yn = torch.as_tensor(yn).reshape(-1)  \n",
    "                y  = (yn * ys + ym).reshape(-1).detach().cpu() \n",
    "                if y.numel() != Xb.shape[0]:\n",
    "                    y = y.expand(Xb.shape[0]).contiguous()\n",
    "                outs.append(y)\n",
    "            except Exception as e:\n",
    "                obj._err_count += int(Xb.shape[0])\n",
    "                if obj._err_first is None:\n",
    "                    obj._err_first = repr(e)\n",
    "                outs.append(torch.full((int(Xb.shape[0]),), float(\"inf\"), device=\"cpu\"))\n",
    "\n",
    "        return torch.cat(outs, 0).numpy()\n",
    "\n",
    "    obj._err_count = 0\n",
    "    obj._err_first = None\n",
    "    return obj\n",
    "\n",
    "\n",
    "def safe_raw_mse(obj, Xraw, yraw):\n",
    "    yp = np.asarray(obj(Xraw), dtype=float).reshape(-1)\n",
    "    yt = yraw.detach().cpu().numpy().reshape(-1) if torch.is_tensor(yraw) else np.asarray(yraw, float).reshape(-1)\n",
    "    if yp.shape[0] != yt.shape[0]:\n",
    "        return np.nan, f\"shape mismatch: pred {yp.shape} vs true {yt.shape}\"\n",
    "    mask = np.isfinite(yp)\n",
    "    if mask.sum() == 0:\n",
    "        return np.nan, \"all predictions were non-finite (inf/nan)\"\n",
    "    return float(np.mean((yp[mask] - yt[mask])**2)), None\n",
    "\n",
    "\n",
    "def safe_norm_mse(model, scaler, device, Xraw, yraw, *, chunk=4096):\n",
    "    xm = scaler[\"x_mean\"].to(device)\n",
    "    xs = scaler[\"x_std\"].to(device)\n",
    "    ym = scaler[\"y_mean\"].to(device)\n",
    "    ys = scaler[\"y_std\"].to(device)\n",
    "\n",
    "    Xraw = torch.as_tensor(Xraw, dtype=torch.float32, device=device)\n",
    "    if Xraw.dim() == 1:\n",
    "        Xraw = Xraw.unsqueeze(0)\n",
    "\n",
    "    yraw_t = yraw\n",
    "    if torch.is_tensor(yraw_t):\n",
    "        yraw_t = yraw_t.to(device=device, dtype=torch.float32)\n",
    "    else:\n",
    "        yraw_t = torch.as_tensor(yraw_t, dtype=torch.float32, device=device)\n",
    "    yraw_t = yraw_t.reshape(-1)\n",
    "\n",
    "    preds = []\n",
    "    B = int(Xraw.shape[0])\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            for i in range(0, B, int(chunk)):\n",
    "                Xb = Xraw[i:i+chunk]\n",
    "                Xn = (Xb - xm) / xs\n",
    "                yn = model(Xn)\n",
    "                yn = torch.as_tensor(yn).reshape(-1)\n",
    "                if yn.numel() != Xb.shape[0]:\n",
    "                    yn = yn.expand(Xb.shape[0]).contiguous()\n",
    "                preds.append(yn.detach().cpu())\n",
    "    except Exception as e:\n",
    "        return np.nan, f\"model forward failed in safe_norm_mse: {repr(e)}\"\n",
    "\n",
    "    yp = torch.cat(preds, 0).numpy().reshape(-1)\n",
    "    yt = ((yraw_t - ym) / ys).detach().cpu().numpy().reshape(-1)\n",
    "\n",
    "    if yp.shape[0] != yt.shape[0]:\n",
    "        return np.nan, f\"shape mismatch: pred {yp.shape} vs true {yt.shape}\"\n",
    "\n",
    "    mask = np.isfinite(yp) & np.isfinite(yt)\n",
    "    if mask.sum() == 0:\n",
    "        return np.nan, \"all predictions or targets were non-finite (inf/nan)\"\n",
    "\n",
    "    return float(np.mean((yp[mask] - yt[mask])**2)), None\n",
    "\n",
    "\n",
    "def t_to_best(hist, y_best):\n",
    "    for r in hist:\n",
    "        if abs(float(r.get(\"best_y\", np.inf)) - float(y_best)) < 1e-12:\n",
    "            return float(r.get(\"t\", float(\"nan\")))\n",
    "    return float(\"nan\")\n",
    "\n",
    "\n",
    "def check_ip_matches_obj(name, obj, x_ip, y_ip, *, strict: bool, tol: float):\n",
    "    y_obj = float(np.asarray(obj(np.asarray(x_ip)), dtype=float).reshape(-1)[0])\n",
    "    y_ip  = float(y_ip)\n",
    "    rel = abs(y_obj - y_ip) / (abs(y_obj) + 1e-12)\n",
    "    print(f\"[CHECK {name}] obj(x_ip)={y_obj:.6g}  ip_y={y_ip:.6g}  rel_err={rel:.3e}\")\n",
    "    if strict and rel > tol:\n",
    "        raise RuntimeError(f\"{name}: IP objective != obj() (rel_err={rel:.3e})\")\n",
    "    return rel\n",
    "\n",
    "\n",
    "def mean_se(x):\n",
    "    x = pd.to_numeric(pd.Series(x), errors=\"coerce\").to_numpy()\n",
    "    x = x[np.isfinite(x)]\n",
    "    n = int(x.shape[0])\n",
    "    if n == 0:\n",
    "        return np.nan, np.nan\n",
    "    m = float(x.mean())\n",
    "    se = float(x.std(ddof=1) / np.sqrt(n)) if n > 1 else 0.0\n",
    "    return m, se\n",
    "\n",
    "\n",
    "def fmt_mean_se(m, se):\n",
    "    if not np.isfinite(m):\n",
    "        return \"nan\"\n",
    "    if not np.isfinite(se):\n",
    "        return f\"{m:.6g}\"\n",
    "    return f\"{m:.6g} ± {se:.3g}\"\n",
    "\n",
    "\n",
    "def repr_solution(xs: pd.Series, seeds=None):\n",
    "    xs = xs.dropna().astype(str)\n",
    "    xs = xs[xs != \"None\"]\n",
    "    if xs.empty:\n",
    "        return None\n",
    "\n",
    "    # With seed information: show per-seed if solutions differ\n",
    "    if seeds is not None:\n",
    "        df = pd.DataFrame({\"seed\": pd.Series(seeds), \"x\": xs})\n",
    "        df = df.dropna()\n",
    "        df[\"x\"] = df[\"x\"].astype(str)\n",
    "        if df.empty:\n",
    "            return None\n",
    "        if df[\"x\"].nunique() == 1:\n",
    "            return df[\"x\"].iloc[0]\n",
    "        df = df.sort_values(\"seed\")\n",
    "        return \"\\n\".join([f\"seed={int(r.seed)}: {r.x}\" for r in df.itertuples(index=False)])\n",
    "\n",
    "    # No seed info: keep compact\n",
    "    if xs.nunique() == 1:\n",
    "        return xs.iloc[0]\n",
    "    vc = xs.value_counts()\n",
    "    top = vc.index[0]\n",
    "    n_unique = int(vc.shape[0])\n",
    "    return f\"{top} (+{n_unique-1} other)\"\n",
    "    \n",
    "\n",
    "# -----------------------------\n",
    "# Ground-truth optimum solvers\n",
    "# -----------------------------\n",
    "\n",
    "def solve_true_opt_quadratic(gt, xmin, xmax, sum_eq, *, time_limit=None, verbose=False, seed=None):\n",
    "    Q  = np.asarray(gt[\"Q\"], float)\n",
    "    xs = np.asarray(gt[\"x_star\"], float).reshape(-1)\n",
    "    n  = int(xs.shape[0])\n",
    "    assert Q.shape == (n, n)\n",
    "\n",
    "    m = gp.Model(\"gt_quadratic\")\n",
    "    m.Params.OutputFlag = 1 if verbose else 0\n",
    "    # Make the (ground-truth) solve reproducible across runs\n",
    "    try:\n",
    "        m.Params.Threads = 1\n",
    "    except Exception:\n",
    "        pass\n",
    "    if seed is not None:\n",
    "        try:\n",
    "            m.Params.Seed = int(seed)\n",
    "        except Exception:\n",
    "            pass\n",
    "    if time_limit is not None:\n",
    "        m.Params.TimeLimit = float(time_limit)\n",
    "\n",
    "    x = m.addVars(n, vtype=GRB.INTEGER, name=\"x\")\n",
    "    for i in range(n):\n",
    "        x[i].LB = int(xmin[i])\n",
    "        x[i].UB = int(xmax[i])\n",
    "    m.addConstr(gp.quicksum(x[i] for i in range(n)) == int(sum_eq), name=\"sum_eq\")\n",
    "\n",
    "    expr = gp.quicksum(float(Q[i, j]) * (x[i] - float(xs[i])) * (x[j] - float(xs[j])) for i in range(n) for j in range(n))\n",
    "    m.setObjective(expr, GRB.MINIMIZE)\n",
    "    m.optimize()\n",
    "\n",
    "    if m.Status not in (GRB.OPTIMAL, GRB.TIME_LIMIT, GRB.SUBOPTIMAL, GRB.INTERRUPTED):\n",
    "        raise RuntimeError(f\"GT quadratic solve failed, status={m.Status}\")\n",
    "\n",
    "    xsol = np.array([int(round(x[i].X)) for i in range(n)], dtype=int)\n",
    "    return xsol, float(m.ObjVal), {\"status\": int(m.Status)}\n",
    "\n",
    "\n",
    "def solve_true_opt_assignment(gt, xmin, xmax, sum_eq, *, time_limit=None, verbose=False, seed=None):\n",
    "    C = np.asarray(gt[\"C\"], float)\n",
    "    n_rows, n_cols = C.shape\n",
    "    k = int(sum_eq)\n",
    "    if k > n_cols:\n",
    "        raise RuntimeError(f\"sum_eq={k} exceeds number of columns={n_cols} (infeasible)\")\n",
    "\n",
    "    m = gp.Model(\"gt_assignment\")\n",
    "    m.Params.OutputFlag = 1 if verbose else 0\n",
    "    # Make the (ground-truth) solve reproducible across runs\n",
    "    try:\n",
    "        m.Params.Threads = 1\n",
    "    except Exception:\n",
    "        pass\n",
    "    if seed is not None:\n",
    "        try:\n",
    "            m.Params.Seed = int(seed)\n",
    "        except Exception:\n",
    "            pass\n",
    "    if time_limit is not None:\n",
    "        m.Params.TimeLimit = float(time_limit)\n",
    "\n",
    "    z = m.addVars(n_rows, vtype=GRB.BINARY, name=\"z\")  # select row i\n",
    "    y = m.addVars(n_rows, n_cols, vtype=GRB.BINARY, name=\"y\")  # assign row i to col j\n",
    "\n",
    "    m.addConstr(gp.quicksum(z[i] for i in range(n_rows)) == k, name=\"k_rows\")\n",
    "    for i in range(n_rows):\n",
    "        m.addConstr(gp.quicksum(y[i, j] for j in range(n_cols)) == z[i], name=f\"assign_row_{i}\")\n",
    "    for j in range(n_cols):\n",
    "        m.addConstr(gp.quicksum(y[i, j] for i in range(n_rows)) <= 1, name=f\"assign_col_{j}\")\n",
    "\n",
    "    m.setObjective(gp.quicksum(float(C[i, j]) * y[i, j] for i in range(n_rows) for j in range(n_cols)), GRB.MINIMIZE)\n",
    "    m.optimize()\n",
    "\n",
    "    if m.Status not in (GRB.OPTIMAL, GRB.TIME_LIMIT, GRB.SUBOPTIMAL, GRB.INTERRUPTED):\n",
    "        raise RuntimeError(f\"GT assignment solve failed, status={m.Status}\")\n",
    "\n",
    "    xsol = np.array([int(round(z[i].X)) for i in range(n_rows)], dtype=int)\n",
    "    return xsol, float(m.ObjVal), {\"status\": int(m.Status)}\n",
    "\n",
    "\n",
    "def solve_true_opt_mdvsp(gt, xmin, xmax, sum_eq, *, time_limit=None, verbose=False, seed=None):\n",
    "    N  = int(gt[\"N\"])\n",
    "    SS = int(gt[\"SS\"])\n",
    "    TT = int(gt[\"TT\"])\n",
    "    src  = np.asarray(gt[\"src\"], np.int64)\n",
    "    dst  = np.asarray(gt[\"dst\"], np.int64)\n",
    "    cost = np.asarray(gt[\"cost\"], float)\n",
    "    cap0 = np.asarray(gt[\"cap0\"], float)\n",
    "    idxSS = np.asarray(gt[\"idxSS\"], np.int64)\n",
    "    idxTT = np.asarray(gt[\"idxTT\"], np.int64)\n",
    "\n",
    "    E = int(src.shape[0])\n",
    "    k = int(idxSS.shape[0])\n",
    "    assert idxTT.shape[0] == k, \"Expect idxSS and idxTT to be same length\"\n",
    "\n",
    "    out_edges = [[] for _ in range(N)]\n",
    "    in_edges  = [[] for _ in range(N)]\n",
    "    for e in range(E):\n",
    "        out_edges[int(src[e])].append(e)\n",
    "        in_edges[int(dst[e])].append(e)\n",
    "\n",
    "    idxSS = idxSS.astype(int)\n",
    "    idxTT = idxTT.astype(int)\n",
    "    var_arc_to_i = {int(idxSS[i]): i for i in range(k)}\n",
    "    var_arc_to_i.update({int(idxTT[i]): i for i in range(k)})\n",
    "    var_arcs = sorted(var_arc_to_i.keys())\n",
    "\n",
    "    m = gp.Model(\"gt_mdvsp_true_opt\")\n",
    "    m.Params.OutputFlag = 1 if verbose else 0\n",
    "    # Make the (ground-truth) solve reproducible across runs\n",
    "    try:\n",
    "        m.Params.Threads = 1\n",
    "    except Exception:\n",
    "        pass\n",
    "    if seed is not None:\n",
    "        try:\n",
    "            m.Params.Seed = int(seed)\n",
    "        except Exception:\n",
    "            pass\n",
    "    if time_limit is not None:\n",
    "        m.Params.TimeLimit = float(time_limit)\n",
    "\n",
    "    m.Params.NonConvex = 2\n",
    "\n",
    "    x = m.addVars(k, vtype=GRB.INTEGER, name=\"x\")\n",
    "    for i in range(k):\n",
    "        x[i].LB = int(xmin[i])\n",
    "        x[i].UB = int(xmax[i])\n",
    "    m.addConstr(gp.quicksum(x[i] for i in range(k)) == int(sum_eq), name=\"sum_eq\")\n",
    "\n",
    "    f = m.addVars(E, vtype=GRB.CONTINUOUS, lb=0.0, name=\"f\")\n",
    "\n",
    "    for e in range(E):\n",
    "        if e in var_arc_to_i:\n",
    "            i = var_arc_to_i[e]\n",
    "            m.addConstr(f[e] <= x[i], name=f\"cap_var_{e}\")\n",
    "        else:\n",
    "            m.addConstr(f[e] <= float(cap0[e]), name=f\"cap_fix_{e}\")\n",
    "\n",
    "    F = m.addVar(vtype=GRB.CONTINUOUS, lb=0.0, name=\"F\")\n",
    "\n",
    "    # flow conservation: out - in = F at SS, = -F at TT, else 0\n",
    "    for v in range(N):\n",
    "        out_sum = gp.quicksum(f[e] for e in out_edges[v])\n",
    "        in_sum  = gp.quicksum(f[e] for e in in_edges[v])\n",
    "        rhs = F if v == SS else (-F if v == TT else 0.0)\n",
    "        m.addConstr(out_sum - in_sum == rhs, name=f\"flow_{v}\")\n",
    "\n",
    "    y = m.addVars(N, vtype=GRB.CONTINUOUS, lb=0.0, ub=1.0, name=\"y\")\n",
    "    m.addConstr(y[SS] == 1.0, name=\"ySS\")\n",
    "    m.addConstr(y[TT] == 0.0, name=\"yTT\")\n",
    "\n",
    "    z = m.addVars(E, vtype=GRB.CONTINUOUS, lb=0.0, ub=1.0, name=\"z\")\n",
    "    for e in range(E):\n",
    "        m.addConstr(z[e] >= y[int(src[e])] - y[int(dst[e])], name=f\"dual_{e}\")\n",
    "\n",
    "    dual_obj = gp.QuadExpr()\n",
    "    for e in range(E):\n",
    "        if e in var_arc_to_i:\n",
    "            i = var_arc_to_i[e]\n",
    "            # bilinear term: x_i * z_e\n",
    "            dual_obj += x[i] * z[e]\n",
    "        else:\n",
    "            dual_obj += float(cap0[e]) * z[e]\n",
    "    m.addConstr(F == dual_obj, name=\"strong_duality\")\n",
    "\n",
    "    m.setObjective(gp.quicksum(float(cost[e]) * f[e] for e in range(E)), GRB.MINIMIZE)\n",
    "    m.optimize()\n",
    "\n",
    "    if m.Status not in (GRB.OPTIMAL, GRB.TIME_LIMIT, GRB.SUBOPTIMAL, GRB.INTERRUPTED):\n",
    "        raise RuntimeError(f\"GT mdvsp true-opt failed, status={m.Status}\")\n",
    "\n",
    "    xsol = np.array([int(round(x[i].X)) for i in range(k)], dtype=int)\n",
    "    return xsol, float(m.ObjVal), {\"status\": int(m.Status), \"F\": float(F.X)}\n",
    "    \n",
    "\n",
    "\n",
    "def solve_true_opt(gt, xmin, xmax, sum_eq, *, time_limit=None, verbose=False, seed=None):\n",
    "    \"\"\"Top-level dispatcher for ground-truth optimum.\"\"\"\n",
    "    if not isinstance(gt, dict):\n",
    "        raise RuntimeError(\"No ground-truth structure found (gt must be a dict).\")\n",
    "    t = gt.get(\"type\", None)\n",
    "    if t == \"quadratic\":\n",
    "        return solve_true_opt_quadratic(gt, xmin, xmax, sum_eq, time_limit=time_limit, verbose=verbose, seed=seed)\n",
    "    if t == \"assignment\":\n",
    "        return solve_true_opt_assignment(gt, xmin, xmax, sum_eq, time_limit=time_limit, verbose=verbose, seed=seed)\n",
    "    if t == \"mdvsp\":\n",
    "        return solve_true_opt_mdvsp(gt, xmin, xmax, sum_eq, time_limit=time_limit, verbose=verbose, seed=seed)\n",
    "    raise RuntimeError(f\"Unknown gt type: {t}\")\n",
    "\n",
    "\n",
    "# -----------------------------\n",
    "# Benchmark runner\n",
    "# -----------------------------\n",
    "def run_benchmark(\n",
    "    *,\n",
    "    dataset_type: str,\n",
    "    dataset_params: dict,\n",
    "    runs: list[tuple[str, str, dict]],\n",
    "    train_base: dict,\n",
    "    lr_map: dict,\n",
    "    x0: np.ndarray,\n",
    "    xmin: np.ndarray,\n",
    "    xmax: np.ndarray,\n",
    "    delta: int,\n",
    "    sum_eq: int,\n",
    "    n_seeds: int = 1,\n",
    "    vary_dataset_seed: bool = False,\n",
    "    vary_model_init_seed: bool = True,\n",
    "    strict_ip_check: bool = False,\n",
    "    ip_check_tol: float = 1e-4,\n",
    "    silence_local_search: bool = False,\n",
    "    allow_plots_multi_seed: bool = True,\n",
    "    time_limit=None,\n",
    "):\n",
    "    seeds = list(range(int(n_seeds)))\n",
    "\n",
    "    learn_rows, opt_rows, spec_rows, fail_rows = [], [], [], []\n",
    "    gt_rows = []\n",
    "\n",
    "    gt_cache_by_seed = {}\n",
    "\n",
    "    for seed in seeds:\n",
    "        print(f\"\\n\\n===================== SEED {seed} =====================\")\n",
    "        # Per-seed starting point + constraint (so the run seed controls x0 / sum_eq too)\n",
    "        x0_seed = x0\n",
    "        if isinstance(x0, (list, tuple, np.ndarray)) and np.asarray(x0).ndim == 2:\n",
    "            x0_seed = np.asarray(x0)[int(seed)]\n",
    "        x0_seed = np.asarray(x0_seed, int).ravel()\n",
    "\n",
    "        sum_eq_seed = sum_eq\n",
    "        if isinstance(sum_eq, (list, tuple, np.ndarray)) and np.asarray(sum_eq).ndim > 0:\n",
    "            sum_eq_seed = int(np.asarray(sum_eq)[int(seed)])\n",
    "        sum_eq_seed = int(sum_eq_seed)\n",
    "\n",
    "\n",
    "        for name, model_type, model_params_base in runs:\n",
    "            # ---- per-run params ----\n",
    "            dp = dict(dataset_params)\n",
    "            if vary_dataset_seed and \"seed\" in dp:\n",
    "                dp[\"seed\"] = int(seed)\n",
    "\n",
    "            mp = dict(model_params_base)\n",
    "            if vary_model_init_seed and \"seed\" in mp:\n",
    "                mp[\"seed\"] = int(seed)\n",
    "\n",
    "            tp = dict(train_base)\n",
    "            tp[\"seed\"] = int(seed)\n",
    "            tp[\"lr\"] = float(lr_map[model_type])\n",
    "\n",
    "            # avoid plot spam unless explicitly allowed\n",
    "            if (n_seeds > 1) and (tp.get(\"plot_every\", 0) not in (0, None)) and (not allow_plots_multi_seed):\n",
    "                tp[\"plot_every\"] = 0\n",
    "\n",
    "            # ---- TRAIN ----\n",
    "            t0 = time.perf_counter()\n",
    "            try:\n",
    "                out = generate_and_train_simple(dataset_type, dp, model_type, mp, tp)\n",
    "            except Exception as e:\n",
    "                fail_rows.append(dict(seed=seed, model=name, stage=\"TRAIN\", error=repr(e)))\n",
    "                learn_rows.append(dict(seed=seed, model=name, train_time=np.nan, best_epoch=np.nan,\n",
    "                                       best_val=np.nan, test=np.nan, train_err=repr(e)))\n",
    "                continue\n",
    "            train_time = time.perf_counter() - t0\n",
    "\n",
    "            if isinstance(out, tuple) and len(out) == 4:\n",
    "                model, data, hist, spec = out\n",
    "            elif isinstance(out, tuple) and len(out) == 3:\n",
    "                model, data, hist = out\n",
    "                n_params = sum(p.numel() for p in model.parameters())\n",
    "                spec = dict(n_params=int(n_params), extra=\"\", train_params=tp)\n",
    "            else:\n",
    "                fail_rows.append(dict(seed=seed, model=name, stage=\"TRAIN\", error=f\"Unexpected return: {type(out)}\"))\n",
    "                continue\n",
    "\n",
    "            # ---- model spec (seed 0 only) ----\n",
    "            if seed == seeds[0]:\n",
    "                spec_rows.append(dict(\n",
    "                    model=name,\n",
    "                    n_params=int(spec.get(\"n_params\", np.nan)),\n",
    "                    details=str(spec.get(\"extra\", \"\")),\n",
    "                    lr=float(spec.get(\"train_params\", {}).get(\"lr\", tp[\"lr\"])),\n",
    "                    batch_size=int(spec.get(\"train_params\", {}).get(\"batch_size\", tp[\"batch_size\"])),\n",
    "                    epochs=int(spec.get(\"train_params\", {}).get(\"epochs\", tp[\"epochs\"])),\n",
    "                ))\n",
    "\n",
    "            device = data[\"device\"]\n",
    "            scaler = data[\"scaler\"]\n",
    "            obj = make_obj(model, scaler, device, chunk=int(tp.get(\"plot_chunk\", 4096)))\n",
    "            gt = data.get(\"true\", None)\n",
    "\n",
    "            # ---- compute GT optimum (once per seed) ----\n",
    "            if seed not in gt_cache_by_seed:\n",
    "                t_gt0 = time.perf_counter()\n",
    "                try:\n",
    "                    x_gt, y_gt, gt_info = solve_true_opt(gt, xmin, xmax, sum_eq_seed, time_limit=time_limit, verbose=False, seed=seed)\n",
    "                    gt_time = time.perf_counter() - t_gt0\n",
    "                    # evaluate with the existing helper to be extra sure\n",
    "                    y_gt_check = float(eval_true_obj(gt, x_gt))\n",
    "                    if np.isfinite(y_gt_check) and abs(y_gt_check - float(y_gt)) / (abs(float(y_gt_check)) + 1e-12) > 1e-6:\n",
    "                        fail_rows.append(dict(seed=seed, model=name, stage=\"GT_OPT_CHECK\",\n",
    "                                              error=f\"gt solver mismatch: solver={y_gt:.6g} eval_true_obj={y_gt_check:.6g}\"))\n",
    "                    gt_cache_by_seed[seed] = dict(x=str(np.asarray(x_gt, int).tolist()),\n",
    "                                                  true_y=float(y_gt_check if np.isfinite(y_gt_check) else y_gt),\n",
    "                                                  runtime=float(gt_time),\n",
    "                                                  err=None)\n",
    "                except Exception as e:\n",
    "                    gt_time = time.perf_counter() - t_gt0\n",
    "                    gt_cache_by_seed[seed] = dict(x=None, true_y=np.nan, runtime=float(gt_time), err=repr(e))\n",
    "                    fail_rows.append(dict(seed=seed, model=name, stage=\"GT_OPT\", error=repr(e)))\n",
    "\n",
    "            # ---- learning metrics: best val (normalized) + test loss (normalized) ----\n",
    "            test_norm, err_te = safe_norm_mse(model, scaler, device, data['raw']['Xte'], data['raw']['yte'], chunk=int(tp.get('plot_chunk', 4096)))\n",
    "            if err_te:\n",
    "                fail_rows.append(dict(seed=seed, model=name, stage=\"EVAL_TEST\", error=err_te))\n",
    "\n",
    "            if obj._err_count > 0:\n",
    "                fail_rows.append(dict(seed=seed, model=name, stage=\"OBJ\",\n",
    "                                      error=f\"obj() had {obj._err_count} forward failures; first={obj._err_first}\"))\n",
    "\n",
    "            val_curve = np.asarray(hist.get(\"val_mse_norm\", []), dtype=float)\n",
    "            best_ep = int(np.argmin(val_curve) + 1) if val_curve.size else np.nan\n",
    "            best_val = float(np.min(val_curve)) if val_curve.size else np.nan\n",
    "\n",
    "            learn_rows.append(dict(\n",
    "                seed=seed, model=name,\n",
    "                train_time=float(train_time),\n",
    "                best_epoch=best_ep,\n",
    "                best_val=float(best_val),\n",
    "                test=float(test_norm) if np.isfinite(test_norm) else np.nan,\n",
    "                train_err=(err_te),\n",
    "            ))\n",
    "\n",
    "            # ---- LOCAL SEARCH ----\n",
    "            t0 = time.perf_counter()\n",
    "            try:\n",
    "                obj_ls = obj\n",
    "                (ls_out, _ls_log) = suppress_stdout(\n",
    "                    local_search_l1_int, obj_ls, x0_seed, xmin, xmax,\n",
    "                    delta=delta, sum_eq=sum_eq_seed, print_every=0,\n",
    "                    silence=silence_local_search\n",
    "                )\n",
    "                x_best, y_best, ls_hist = ls_out\n",
    "                ls_time = time.perf_counter() - t0\n",
    "                opt_rows.append(dict(\n",
    "                    seed=seed, model=name, method=\"LS\",\n",
    "                    x=str(np.asarray(x_best, int).tolist()),\n",
    "                    y=float(y_best),\n",
    "                    true_y=float(eval_true_obj(gt, x_best)),\n",
    "                    runtime=float(ls_time),\n",
    "                    t_best=float(t_to_best(ls_hist, y_best)),\n",
    "                    iters=int(len(ls_hist) - 1),\n",
    "                    err=None,\n",
    "                ))\n",
    "            except Exception as e:\n",
    "                ls_time = time.perf_counter() - t0\n",
    "                opt_rows.append(dict(\n",
    "                    seed=seed, model=name, method=\"LS\",\n",
    "                    x=None, y=np.nan, true_y=np.nan,\n",
    "                    runtime=float(ls_time),\n",
    "                    t_best=np.nan, iters=np.nan,\n",
    "                    err=repr(e),\n",
    "                ))\n",
    "                fail_rows.append(dict(seed=seed, model=name, stage=\"LS\", error=repr(e)))\n",
    "\n",
    "            # ---- IP SOLVE (learned objective) ----\n",
    "            t0 = time.perf_counter()\n",
    "            try:\n",
    "                x_ip, y_ip, info = solve_ip(\n",
    "                    model_type, model, scaler, xmin, xmax, sum_eq_seed,\n",
    "                    time_limit=time_limit, verbose=True, seed=seed\n",
    "                )\n",
    "                ip_time = time.perf_counter() - t0\n",
    "\n",
    "                rel_err = check_ip_matches_obj(name, obj, x_ip, y_ip, strict=strict_ip_check, tol=ip_check_tol)\n",
    "                if rel_err > ip_check_tol:\n",
    "                    fail_rows.append(dict(seed=seed, model=name, stage=\"IP_CHECK\",\n",
    "                                          error=f\"rel_err={rel_err:.3e} (tol={ip_check_tol})\"))\n",
    "\n",
    "                opt_rows.append(dict(\n",
    "                    seed=seed, model=name, method=\"IP\",\n",
    "                    x=str(np.asarray(x_ip, int).tolist()),\n",
    "                    y=float(y_ip),\n",
    "                    true_y=float(eval_true_obj(gt, x_ip)),\n",
    "                    runtime=float(ip_time),\n",
    "                    status=(info.get(\"status\") if isinstance(info, dict) else None),\n",
    "                    gap=(info.get(\"gap\") if isinstance(info, dict) else None),\n",
    "                    ip_rel_err=float(rel_err),\n",
    "                    err=None,\n",
    "                ))\n",
    "            except Exception as e:\n",
    "                ip_time = time.perf_counter() - t0\n",
    "                opt_rows.append(dict(\n",
    "                    seed=seed, model=name, method=\"IP\",\n",
    "                    x=None, y=np.nan, true_y=np.nan,\n",
    "                    runtime=float(ip_time),\n",
    "                    status=None, gap=None, ip_rel_err=np.nan,\n",
    "                    err=repr(e),\n",
    "                ))\n",
    "                fail_rows.append(dict(seed=seed, model=name, stage=\"IP\", error=repr(e)))\n",
    "\n",
    "    # ---- build DFs ----\n",
    "    spec_df  = pd.DataFrame(spec_rows).drop_duplicates(\"model\").sort_values(\"model\").reset_index(drop=True)\n",
    "    learn_df = pd.DataFrame(learn_rows).sort_values([\"model\", \"seed\"]).reset_index(drop=True)\n",
    "    opt_df   = pd.DataFrame(opt_rows).sort_values([\"model\", \"seed\", \"method\"]).reset_index(drop=True)\n",
    "\n",
    "    fail_df = pd.DataFrame(fail_rows)\n",
    "    if fail_df.empty:\n",
    "        fail_df = pd.DataFrame(columns=[\"seed\", \"model\", \"stage\", \"error\"])\n",
    "    else:\n",
    "        for c in [\"stage\", \"model\", \"seed\"]:\n",
    "            if c not in fail_df.columns:\n",
    "                fail_df[c] = np.nan\n",
    "        fail_df = fail_df.sort_values([\"stage\", \"model\", \"seed\"]).reset_index(drop=True)\n",
    "\n",
    "    gt_df = pd.DataFrame([dict(seed=s, **d) for s, d in sorted(gt_cache_by_seed.items())]).sort_values(\"seed\")\n",
    "\n",
    "    # ---- per-seed gaps ----\n",
    "    gap_seed_df = pd.DataFrame(columns=[\"seed\", \"model\", \"ls_vs_ip_pct\", \"ip_true_vs_gt_pct\"])\n",
    "    if not opt_df.empty:\n",
    "        pivot_y = opt_df.pivot_table(index=[\"seed\", \"model\"], columns=\"method\", values=\"y\", aggfunc=\"first\")\n",
    "        pivot_true = opt_df.pivot_table(index=[\"seed\", \"model\"], columns=\"method\", values=\"true_y\", aggfunc=\"first\")\n",
    "\n",
    "        if (\"LS\" in pivot_y.columns) and (\"IP\" in pivot_y.columns):\n",
    "            ls_vs_ip = 100.0 * (pivot_y[\"LS\"] - pivot_y[\"IP\"]) / (np.abs(pivot_y[\"IP\"]) + 1e-12)\n",
    "        else:\n",
    "            ls_vs_ip = pd.Series(index=pivot_y.index, dtype=float)\n",
    "\n",
    "        # IP true vs GT optimum true\n",
    "        gt_true = gt_df.set_index(\"seed\")[\"true_y\"] if not gt_df.empty else pd.Series(dtype=float)\n",
    "        ip_true_vs_gt = []\n",
    "        for (seed, model), row in pivot_true.iterrows():\n",
    "            ipt = float(row.get(\"IP\", np.nan))\n",
    "            gtt = float(gt_true.get(seed, np.nan))\n",
    "            if np.isfinite(ipt) and np.isfinite(gtt):\n",
    "                ip_true_vs_gt.append(((seed, model), 100.0 * (ipt - gtt) / (abs(gtt) + 1e-12)))\n",
    "            else:\n",
    "                ip_true_vs_gt.append(((seed, model), np.nan))\n",
    "        ip_true_vs_gt = pd.Series({k: v for k, v in ip_true_vs_gt})\n",
    "\n",
    "        gap_seed_df = pd.DataFrame({\n",
    "            \"seed\": [k[0] for k in ip_true_vs_gt.index],\n",
    "            \"model\": [k[1] for k in ip_true_vs_gt.index],\n",
    "            \"ls_vs_ip_pct\": [float(ls_vs_ip.get(k, np.nan)) for k in ip_true_vs_gt.index],\n",
    "            \"ip_true_vs_gt_pct\": [float(ip_true_vs_gt.get(k, np.nan)) for k in ip_true_vs_gt.index],\n",
    "        })\n",
    "\n",
    "    # ---- LEARNING SUMMARY (mean ± SE over seeds) ----\n",
    "    learn_sum_rows = []\n",
    "    for model in sorted(learn_df[\"model\"].unique()):\n",
    "        sub = learn_df[learn_df[\"model\"] == model]\n",
    "        m, se = mean_se(sub[\"train_time\"]); train_time_s = fmt_mean_se(m, se)\n",
    "        m, se = mean_se(sub[\"best_val\"]);   best_val_s   = fmt_mean_se(m, se)\n",
    "        m, se = mean_se(sub[\"test\"]);       test_s       = fmt_mean_se(m, se)\n",
    "        learn_sum_rows.append(dict(model=model, train_time=train_time_s, best_val=best_val_s, test=test_s))\n",
    "    learn_summary_df = pd.DataFrame(learn_sum_rows).sort_values(\"model\").reset_index(drop=True)\n",
    "\n",
    "    # ---- OPTIMIZATION SUMMARY ----\n",
    "    opt_sum_rows = []\n",
    "    gt_x_repr = (repr_solution(gt_df.get(\"x\", pd.Series(dtype=str)), gt_df.get(\"seed\", None))\n",
    "                if not gt_df.empty else None)\n",
    "    m, se = mean_se(gt_df.get(\"true_y\", pd.Series(dtype=float))); gt_true_s = fmt_mean_se(m, se)\n",
    "    m, se = mean_se(gt_df.get(\"runtime\", pd.Series(dtype=float))); gt_time_s = fmt_mean_se(m, se)\n",
    "\n",
    "    for model in sorted(opt_df[\"model\"].unique()):\n",
    "        sub = opt_df[opt_df[\"model\"] == model]\n",
    "        row = {\"model\": model}\n",
    "\n",
    "        ls = sub[sub[\"method\"] == \"LS\"]\n",
    "        ip = sub[sub[\"method\"] == \"IP\"]\n",
    "\n",
    "        row[\"LS_x\"] = repr_solution(ls[\"x\"], ls.get(\"seed\", None))\n",
    "        m, se = mean_se(ls[\"y\"]);       row[\"LS_y\"] = fmt_mean_se(m, se)\n",
    "        m, se = mean_se(ls[\"true_y\"]);  row[\"LS_true_y\"] = fmt_mean_se(m, se)\n",
    "        m, se = mean_se(ls[\"runtime\"]); row[\"LS_time\"] = fmt_mean_se(m, se)\n",
    "\n",
    "        row[\"IP_x\"] = repr_solution(ip[\"x\"], ip.get(\"seed\", None))\n",
    "        m, se = mean_se(ip[\"y\"]);       row[\"IP_y\"] = fmt_mean_se(m, se)\n",
    "        m, se = mean_se(ip[\"true_y\"]);  row[\"IP_true_y\"] = fmt_mean_se(m, se)\n",
    "        m, se = mean_se(ip[\"runtime\"]); row[\"IP_time\"] = fmt_mean_se(m, se)\n",
    "\n",
    "        row[\"GT_x\"] = gt_x_repr\n",
    "        row[\"GT_true_y\"] = gt_true_s\n",
    "        row[\"GT_time\"] = gt_time_s\n",
    "\n",
    "        gsub = gap_seed_df[gap_seed_df[\"model\"] == model] if not gap_seed_df.empty else pd.DataFrame()\n",
    "        m, se = mean_se(gsub.get(\"ls_vs_ip_pct\", pd.Series(dtype=float))); row[\"LS_vs_IP_%\"] = fmt_mean_se(m, se)\n",
    "        m, se = mean_se(gsub.get(\"ip_true_vs_gt_pct\", pd.Series(dtype=float))); row[\"IP_true_vs_GT_%\"] = fmt_mean_se(m, se)\n",
    "\n",
    "        opt_sum_rows.append(row)\n",
    "\n",
    "    opt_summary_df = pd.DataFrame(opt_sum_rows).sort_values(\"model\").reset_index(drop=True)\n",
    "\n",
    "    # ---- Print tables ----\n",
    "    print(\"\\n=== MODEL SPECS (from seed 0 run) ===\")\n",
    "    if not spec_df.empty:\n",
    "        print(spec_df[[\"model\", \"n_params\", \"details\", \"lr\", \"batch_size\", \"epochs\"]].to_string(index=False))\n",
    "    else:\n",
    "        print(\"None\")\n",
    "\n",
    "    print(\"\\n=== LEARNING SUMMARY (mean ± SE over seeds) ===\")\n",
    "    if not learn_summary_df.empty:\n",
    "        print(learn_summary_df.to_string(index=False))\n",
    "    else:\n",
    "        print(\"None\")\n",
    "\n",
    "    print(\"\\n=== OPTIMIZATION SUMMARY (mean ± SE over seeds) ===\")\n",
    "    if not opt_summary_df.empty:\n",
    "        cols = [\n",
    "            \"model\",\n",
    "            \"LS_x\", \"LS_y\", \"LS_true_y\", \"LS_time\",\n",
    "            \"IP_x\", \"IP_y\", \"IP_true_y\", \"IP_time\",\n",
    "            \"GT_x\", \"GT_true_y\", \"GT_time\",\n",
    "            \"LS_vs_IP_%\", \"IP_true_vs_GT_%\",\n",
    "        ]\n",
    "        print(opt_summary_df[cols].to_string(index=False))\n",
    "    else:\n",
    "        print(\"None\")\n",
    "\n",
    "    print(\"\\n=== FAILURES / WARNINGS (if any) ===\")\n",
    "    if fail_df.shape[0] == 0:\n",
    "        print(\"None\")\n",
    "    else:\n",
    "        print(fail_df.to_string(index=False))\n",
    "\n",
    "    return spec_df, learn_df, opt_df, fail_df, learn_summary_df, opt_summary_df, gt_df\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e88e7a1f",
   "metadata": {},
   "source": [
    "## Tests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "3ac79bcb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ASSIGNMENT] inferred in_dim=20 from Xtmp shape=(2048, 20)\n",
      "\n",
      "\n",
      "===================== SEED 0 =====================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/27/6hvffxjj6mnf6nw61tsrkldm0000gp/T/ipykernel_11465/1051077697.py:159: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  return torch.load(str(artifact_path), map_location=\"cpu\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iter=0  t=0.00s  best_y=711.443  x=[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n",
      "iter=1  t=0.07s  best_y=469.136  x=[1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1]\n",
      "iter=2  t=0.14s  best_y=341.784  x=[1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1]\n",
      "iter=3  t=0.20s  best_y=333.129  x=[1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1]\n",
      "STOP: local minimum. best_y=333.129 x=[1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1]\n",
      "Set parameter OutputFlag to value 1\n",
      "Set parameter TimeLimit to value 500\n",
      "Gurobi Optimizer version 12.0.3 build v12.0.3rc0 (mac64[arm] - Darwin 25.2.0 25C56)\n",
      "\n",
      "CPU model: Apple M4 Max\n",
      "Thread count: 14 physical cores, 14 logical processors, using up to 14 threads\n",
      "\n",
      "Non-default parameters:\n",
      "TimeLimit  500\n",
      "\n",
      "Optimize a model with 129 rows, 10260 columns and 21329 nonzeros\n",
      "Model fingerprint: 0x30d4787e\n",
      "Variable types: 10240 continuous, 20 integer (0 binary)\n",
      "Coefficient statistics:\n",
      "  Matrix range     [1e+00, 4e+01]\n",
      "  Objective range  [1e-03, 1e+06]\n",
      "  Bounds range     [1e+00, 1e+06]\n",
      "  RHS range        [6e-02, 1e+02]\n",
      "Found heuristic solution: objective 6.867358e+08\n",
      "Presolve removed 0 rows and 839 columns\n",
      "Presolve time: 0.01s\n",
      "Presolved: 129 rows, 9421 columns, 19651 nonzeros\n",
      "Variable types: 9401 continuous, 20 integer (20 binary)\n",
      "\n",
      "Root relaxation: objective 3.205109e+02, 197 iterations, 0.00 seconds (0.01 work units)\n",
      "\n",
      "    Nodes    |    Current Node    |     Objective Bounds      |     Work\n",
      " Expl Unexpl |  Obj  Depth IntInf | Incumbent    BestBd   Gap | It/Node Time\n",
      "\n",
      "     0     0  320.51088    0    6 6.8674e+08  320.51088   100%     -    0s\n",
      "H    0     0                     339.4861718  326.97152  3.69%     -    0s\n",
      "     0     0  326.97152    0    7  339.48617  326.97152  3.69%     -    0s\n",
      "     0     0  330.72539    0    6  339.48617  330.72539  2.58%     -    0s\n",
      "H    0     0                     333.1293619  330.72539  0.72%     -    0s\n",
      "     0     0  333.12936    0    6  333.12936  333.12936  0.00%     -    0s\n",
      "\n",
      "Cutting planes:\n",
      "  MIR: 21\n",
      "  RLT: 6\n",
      "\n",
      "Explored 1 nodes (236 simplex iterations) in 0.04 seconds (0.08 work units)\n",
      "Thread count was 14 (of 14 available processors)\n",
      "\n",
      "Solution count 3: 333.129 339.486 6.86736e+08 \n",
      "\n",
      "Optimal solution found (tolerance 1.00e-04)\n",
      "Best objective 3.331293619092e+02, best bound 3.331293619092e+02, gap 0.0000%\n",
      "[CHECK DFN] obj(x_ip)=333.129  ip_y=333.129  rel_err=8.529e-08\n",
      "iter=0  t=0.00s  best_y=662.891  x=[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n",
      "iter=1  t=0.00s  best_y=450.725  x=[1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1]\n",
      "iter=2  t=0.01s  best_y=345.693  x=[1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1]\n",
      "STOP: local minimum. best_y=345.693 x=[1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1]\n",
      "Set parameter OutputFlag to value 1\n",
      "Set parameter TimeLimit to value 500\n",
      "Gurobi Optimizer version 12.0.3 build v12.0.3rc0 (mac64[arm] - Darwin 25.2.0 25C56)\n",
      "\n",
      "CPU model: Apple M4 Max\n",
      "Thread count: 14 physical cores, 14 logical processors, using up to 14 threads\n",
      "\n",
      "Non-default parameters:\n",
      "TimeLimit  500\n",
      "\n",
      "Optimize a model with 1282 rows, 789 columns and 21397 nonzeros\n",
      "Model fingerprint: 0x293a6751\n",
      "Variable types: 513 continuous, 276 integer (256 binary)\n",
      "Coefficient statistics:\n",
      "  Matrix range     [4e-06, 1e+02]\n",
      "  Objective range  [2e+02, 2e+02]\n",
      "  Bounds range     [1e+00, 1e+02]\n",
      "  RHS range        [5e-05, 1e+02]\n",
      "Presolve removed 257 rows and 1 columns\n",
      "Presolve time: 0.04s\n",
      "Presolved: 1025 rows, 788 columns, 20891 nonzeros\n",
      "Variable types: 512 continuous, 276 integer (276 binary)\n",
      "Found heuristic solution: objective 702.0605117\n",
      "\n",
      "Root relaxation: objective -1.027513e+03, 378 iterations, 0.00 seconds (0.01 work units)\n",
      "\n",
      "    Nodes    |    Current Node    |     Objective Bounds      |     Work\n",
      " Expl Unexpl |  Obj  Depth IntInf | Incumbent    BestBd   Gap | It/Node Time\n",
      "\n",
      "     0     0 -1027.5134    0  137  702.06051 -1027.5134   246%     -    0s\n",
      "     0     0 -704.42381    0  172  702.06051 -704.42381   200%     -    0s\n",
      "     0     0 -653.32439    0  170  702.06051 -653.32439   193%     -    0s\n",
      "     0     0 -603.37647    0  177  702.06051 -603.37647   186%     -    0s\n",
      "     0     0 -603.15541    0  179  702.06051 -603.15541   186%     -    0s\n",
      "     0     0 -494.26170    0  169  702.06051 -494.26170   170%     -    0s\n",
      "     0     0 -462.33213    0  171  702.06051 -462.33213   166%     -    0s\n",
      "     0     0 -452.24222    0  167  702.06051 -452.24222   164%     -    0s\n",
      "     0     0 -449.96480    0  166  702.06051 -449.96480   164%     -    0s\n",
      "     0     0 -449.16999    0  166  702.06051 -449.16999   164%     -    0s\n",
      "     0     0 -448.78954    0  167  702.06051 -448.78954   164%     -    0s\n",
      "     0     0 -448.16914    0  166  702.06051 -448.16914   164%     -    0s\n",
      "     0     0 -448.14461    0  166  702.06051 -448.14461   164%     -    0s\n",
      "     0     0 -292.56651    0  175  702.06051 -292.56651   142%     -    0s\n",
      "     0     0 -260.45333    0  176  702.06051 -260.45333   137%     -    1s\n",
      "     0     0 -258.67701    0  176  702.06051 -258.67701   137%     -    1s\n",
      "     0     0 -257.90038    0  176  702.06051 -257.90038   137%     -    1s\n",
      "     0     0 -257.49683    0  175  702.06051 -257.49683   137%     -    1s\n",
      "     0     0 -214.58242    0  172  702.06051 -214.58242   131%     -    1s\n",
      "     0     0 -210.71713    0  173  702.06051 -210.71713   130%     -    1s\n",
      "     0     0 -209.53224    0  174  702.06051 -209.53224   130%     -    1s\n",
      "     0     0 -209.09258    0  172  702.06051 -209.09258   130%     -    1s\n",
      "     0     0 -195.84713    0  173  702.06051 -195.84713   128%     -    1s\n",
      "     0     0 -193.34270    0  173  702.06051 -193.34270   128%     -    1s\n",
      "H    0     0                     647.0590537 -192.42774   130%     -    1s\n",
      "     0     0 -192.42774    0  174  647.05905 -192.42774   130%     -    1s\n",
      "     0     0 -192.03597    0  173  647.05905 -192.03597   130%     -    1s\n",
      "     0     0 -190.99831    0  175  647.05905 -190.99831   130%     -    1s\n",
      "     0     0 -190.93750    0  174  647.05905 -190.93750   130%     -    2s\n",
      "     0     2 -190.80579    0  174  647.05905 -190.80579   129%     -    2s\n",
      "* 1437   809             109     592.2009672  -73.32066   112%  41.1    3s\n",
      "* 1533   803              99     442.5630021  -51.56164   112%  40.0    3s\n",
      "* 1598   799              99     386.8048301  -44.88264   112%  40.1    3s\n",
      "H 2609  1276                     385.7071561  -15.70896   104%  40.3    4s\n",
      "  2710  1339   19.45753    9  174  385.70716  -14.49159   104%  40.6    5s\n",
      "H 2719  1277                     372.1160447  -14.49159   104%  40.4    8s\n",
      "  2732  1286   32.86824   10  181  372.11604  -14.49159   104%  40.2   10s\n",
      "H 5422  2013                     371.0049734   -1.35378   100%  49.6   13s\n",
      "  8246  3070 infeasible   27       371.00497   29.74967  92.0%  50.2   15s\n",
      " 14427  5448  107.86744   23  153  371.00497   67.05590  81.9%  50.3   20s\n",
      "*14670  5513              27     363.8873487   67.52367  81.4%  50.7   24s\n",
      " 14853  5615  263.31827   28  120  363.88735   67.52367  81.4%  50.9   25s\n",
      "H15425  5854                     357.0189974   72.72019  79.6%  51.4   26s\n",
      "H15770  5790                     345.6927800   73.41816  78.8%  51.2   26s\n",
      " 20015  7213  191.48056   28  141  345.69278   89.99052  74.0%  50.6   30s\n",
      " 28753  8982  165.21612   29  149  345.69278  117.28699  66.1%  51.2   35s\n",
      " 33746  9687  200.01254   28  141  345.69278  130.98539  62.1%  52.0   40s\n",
      " 39307 10427     cutoff   36       345.69278  145.90983  57.8%  52.7   61s\n",
      " 42304 10957  292.14317   25  137  345.69278  154.76208  55.2%  54.0   65s\n",
      " 49779 10824  234.59267   35  125  345.69278  172.07065  50.2%  55.0   70s\n",
      " 57524 10717  248.91768   26  145  345.69278  188.46884  45.5%  55.6   75s\n",
      " 62624 10118  246.06662   32  130  345.69278  203.21524  41.2%  56.3   80s\n",
      " 70910  9069  295.89669   43  135  345.69278  223.98858  35.2%  56.7   85s\n",
      " 79395  7646 infeasible   29       345.69278  243.29807  29.6%  56.6   90s\n",
      " 87925  5543     cutoff   29       345.69278  266.58116  22.9%  56.2   95s\n",
      " 97655  1598     cutoff   37       345.69278  296.19583  14.3%  55.5  100s\n",
      "\n",
      "Cutting planes:\n",
      "  Gomory: 32\n",
      "  Cover: 511\n",
      "  Implied bound: 5\n",
      "  MIR: 1460\n",
      "  Flow cover: 486\n",
      "  Inf proof: 29\n",
      "  RLT: 1620\n",
      "  Relax-and-lift: 77\n",
      "\n",
      "Explored 101655 nodes (5605798 simplex iterations) in 101.05 seconds (295.36 work units)\n",
      "Thread count was 14 (of 14 available processors)\n",
      "\n",
      "Solution count 10: 345.693 357.019 363.887 ... 647.059\n",
      "\n",
      "Optimal solution found (tolerance 1.00e-04)\n",
      "Best objective 3.456927800035e+02, best bound 3.456927800035e+02, gap 0.0000%\n",
      "[CHECK MLP] obj(x_ip)=345.693  ip_y=345.693  rel_err=1.338e-09\n",
      "iter=0  t=0.00s  best_y=696.178  x=[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n",
      "iter=1  t=0.11s  best_y=505.658  x=[1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1]\n",
      "iter=2  t=0.23s  best_y=413.966  x=[1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1]\n",
      "iter=3  t=0.34s  best_y=385.394  x=[1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1]\n",
      "STOP: local minimum. best_y=385.394 x=[1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1]\n",
      "Set parameter OutputFlag to value 1\n",
      "Set parameter TimeLimit to value 500\n",
      "Gurobi Optimizer version 12.0.3 build v12.0.3rc0 (mac64[arm] - Darwin 25.2.0 25C56)\n",
      "\n",
      "CPU model: Apple M4 Max\n",
      "Thread count: 14 physical cores, 14 logical processors, using up to 14 threads\n",
      "\n",
      "Non-default parameters:\n",
      "TimeLimit  500\n",
      "\n",
      "Optimize a model with 278 rows, 10992 columns and 22004 nonzeros\n",
      "Model fingerprint: 0xd1c508c1\n",
      "Variable types: 10972 continuous, 20 integer (0 binary)\n",
      "Coefficient statistics:\n",
      "  Matrix range     [1e+00, 2e+00]\n",
      "  Objective range  [3e-03, 1e+06]\n",
      "  Bounds range     [1e+00, 1e+06]\n",
      "  RHS range        [1e+00, 5e+01]\n",
      "Found heuristic solution: objective 1.563794e+08\n",
      "Presolve removed 0 rows and 3449 columns\n",
      "Presolve time: 0.01s\n",
      "Presolved: 278 rows, 7543 columns, 15105 nonzeros\n",
      "Variable types: 7523 continuous, 20 integer (20 binary)\n",
      "\n",
      "Root relaxation: objective 3.840260e+02, 319 iterations, 0.00 seconds (0.01 work units)\n",
      "\n",
      "    Nodes    |    Current Node    |     Objective Bounds      |     Work\n",
      " Expl Unexpl |  Obj  Depth IntInf | Incumbent    BestBd   Gap | It/Node Time\n",
      "\n",
      "     0     0  384.02605    0    3 1.5638e+08  384.02605   100%     -    0s\n",
      "H    0     0                     441.2187507  384.02605  13.0%     -    0s\n",
      "     0     0  384.87097    0    2  441.21875  384.87097  12.8%     -    0s\n",
      "H    0     0                     440.2938934  384.88446  12.6%     -    0s\n",
      "H    0     0                     385.3943509  384.88446  0.13%     -    0s\n",
      "\n",
      "Cutting planes:\n",
      "  MIR: 3\n",
      "\n",
      "Explored 1 nodes (344 simplex iterations) in 0.06 seconds (0.06 work units)\n",
      "Thread count was 14 (of 14 available processors)\n",
      "\n",
      "Solution count 3: 385.394 440.294 1.56379e+08 \n",
      "\n",
      "Optimal solution found (tolerance 1.00e-04)\n",
      "Best objective 3.853943509283e+02, best bound 3.853943509283e+02, gap 0.0000%\n",
      "[CHECK DFN_AfixI] obj(x_ip)=385.394  ip_y=385.394  rel_err=7.223e-09\n",
      "iter=0  t=0.00s  best_y=713.712  x=[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n",
      "iter=1  t=0.00s  best_y=423.313  x=[1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1]\n",
      "iter=2  t=0.01s  best_y=352.515  x=[1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1]\n",
      "iter=3  t=0.02s  best_y=343.986  x=[1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1]\n",
      "STOP: local minimum. best_y=343.986 x=[1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1]\n",
      "Set parameter OutputFlag to value 1\n",
      "Set parameter TimeLimit to value 500\n",
      "Gurobi Optimizer version 12.0.3 build v12.0.3rc0 (mac64[arm] - Darwin 25.2.0 25C56)\n",
      "\n",
      "CPU model: Apple M4 Max\n",
      "Thread count: 14 physical cores, 14 logical processors, using up to 14 threads\n",
      "\n",
      "Non-default parameters:\n",
      "TimeLimit  500\n",
      "\n",
      "Optimize a model with 1001 rows, 21 columns and 21020 nonzeros\n",
      "Model fingerprint: 0xfa6de42f\n",
      "Variable types: 1 continuous, 20 integer (0 binary)\n",
      "Coefficient statistics:\n",
      "  Matrix range     [4e-06, 2e+00]\n",
      "  Objective range  [2e+02, 2e+02]\n",
      "  Bounds range     [1e+00, 1e+00]\n",
      "  RHS range        [2e+00, 1e+01]\n",
      "Found heuristic solution: objective 976.8576599\n",
      "Presolve time: 0.01s\n",
      "Presolved: 1001 rows, 21 columns, 21020 nonzeros\n",
      "Variable types: 1 continuous, 20 integer (20 binary)\n",
      "\n",
      "Root relaxation: objective 3.238178e+02, 25 iterations, 0.00 seconds (0.01 work units)\n",
      "\n",
      "    Nodes    |    Current Node    |     Objective Bounds      |     Work\n",
      " Expl Unexpl |  Obj  Depth IntInf | Incumbent    BestBd   Gap | It/Node Time\n",
      "\n",
      "     0     0  323.81780    0    2  976.85766  323.81780  66.9%     -    0s\n",
      "H    0     0                     502.0955191  323.81780  35.5%     -    0s\n",
      "H    0     0                     354.7466260  323.81780  8.72%     -    0s\n",
      "H    0     0                     352.5149490  323.81780  8.14%     -    0s\n",
      "H    0     0                     346.0514656  323.81780  6.42%     -    0s\n",
      "H    0     0                     343.9864078  323.81780  5.86%     -    0s\n",
      "\n",
      "Cutting planes:\n",
      "  Cover: 1\n",
      "\n",
      "Explored 1 nodes (25 simplex iterations) in 0.07 seconds (0.11 work units)\n",
      "Thread count was 14 (of 14 available processors)\n",
      "\n",
      "Solution count 6: 343.986 346.051 352.515 ... 976.858\n",
      "\n",
      "Optimal solution found (tolerance 1.00e-04)\n",
      "Best objective 3.439864078003e+02, best bound 3.439864078003e+02, gap 0.0000%\n",
      "[CHECK MaxAffine] obj(x_ip)=343.986  ip_y=343.986  rel_err=5.419e-08\n",
      "iter=0  t=0.00s  best_y=664.82  x=[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n",
      "iter=1  t=0.01s  best_y=438.247  x=[1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1]\n",
      "iter=2  t=0.01s  best_y=353.701  x=[1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1]\n",
      "STOP: local minimum. best_y=353.701 x=[1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1]\n",
      "Set parameter OutputFlag to value 1\n",
      "Set parameter TimeLimit to value 500\n",
      "Set parameter FuncNonlinear to value 1\n",
      "Set parameter FeasibilityTol to value 1e-09\n",
      "Set parameter OptimalityTol to value 1e-09\n",
      "Set parameter IntFeasTol to value 1e-09\n",
      "Set parameter NumericFocus to value 3\n",
      "Gurobi Optimizer version 12.0.3 build v12.0.3rc0 (mac64[arm] - Darwin 25.2.0 25C56)\n",
      "\n",
      "CPU model: Apple M4 Max\n",
      "Thread count: 14 physical cores, 14 logical processors, using up to 14 threads\n",
      "\n",
      "Non-default parameters:\n",
      "TimeLimit  500\n",
      "FeasibilityTol  1e-09\n",
      "IntFeasTol  1e-09\n",
      "OptimalityTol  1e-09\n",
      "NumericFocus  3\n",
      "\n",
      "Optimize a model with 3004 rows, 3025 columns and 27026 nonzeros\n",
      "Model fingerprint: 0xac378dc6\n",
      "Model has 1001 function constraints treated as nonlinear\n",
      "  1000 EXP, 1 LOG\n",
      "Variable types: 3005 continuous, 20 integer (0 binary)\n",
      "Coefficient statistics:\n",
      "  Matrix range     [6e-05, 2e+02]\n",
      "  Objective range  [1e+00, 1e+00]\n",
      "  Bounds range     [1e-12, 2e+03]\n",
      "  RHS range        [1e+01, 5e+02]\n",
      "Warning: x < 1e-10 in domain of log(x).\n",
      "         Setting lower bound to 1e-10.\n",
      "Presolve removed 2002 rows and 1002 columns\n",
      "Presolve time: 0.02s\n",
      "Presolved: 6007 rows, 2024 columns, 33031 nonzeros\n",
      "Presolved model has 1001 nonlinear constraint(s)\n",
      "\n",
      "Solving non-convex MINLP\n",
      "\n",
      "Variable types: 2004 continuous, 20 integer (20 binary)\n",
      "\n",
      "Root relaxation: objective 9.747240e+01, 1398 iterations, 0.02 seconds (0.08 work units)\n",
      "\n",
      "    Nodes    |    Current Node    |     Objective Bounds      |     Work\n",
      " Expl Unexpl |  Obj  Depth IntInf | Incumbent    BestBd   Gap | It/Node Time\n",
      "\n",
      "     0     0   97.47240    0 1001          -   97.47240      -     -    0s\n",
      "     0     2   98.76601    0 1001          -   98.76601      -     -    0s\n",
      "*  789   587              29     437.6052039  110.14015  74.8%  20.6    4s\n",
      "H 1189   792                     370.2970011  116.92663  68.4%  30.1    4s\n",
      "H 1191   792                     351.1806245  116.92663  66.7%  30.0    4s\n",
      "* 1339   785              37     332.8149989  137.21618  58.8%  30.5    4s\n",
      "  1375   769  226.20242   13  412  332.81500  137.21618  58.8%  30.4    5s\n",
      "* 1397   769              38     329.8873600  137.21618  58.4%  30.3    5s\n",
      "* 3650   559              29     320.1427421  264.74543  17.3%  37.7    7s\n",
      "* 3857   455              28     319.8650991  264.74543  17.2%  36.5    7s\n",
      "* 4221   262              29     316.5946919  287.25161  9.27%  33.9    7s\n",
      "\n",
      "Explored 5053 nodes (148170 simplex iterations) in 8.18 seconds (18.40 work units)\n",
      "Thread count was 14 (of 14 available processors)\n",
      "\n",
      "Solution count 8: 316.595 319.865 320.143 ... 437.605\n",
      "\n",
      "Optimal solution found (tolerance 1.00e-04)\n",
      "Best objective 3.165946918774e+02, best bound 3.165946918774e+02, gap 0.0000%\n",
      "[CHECK LSET] obj(x_ip)=353.701  ip_y=316.595  rel_err=1.049e-01\n",
      "\n",
      "=== MODEL SPECS (from seed 0 run) ===\n",
      "    model  n_params                                                  details    lr  batch_size  epochs\n",
      "      DFN     17707  layers=[32, 64, 32] p_list=[1, 1] alpha=0.005 beta=-2.0 0.100           8     500\n",
      "DFN_AfixI     21524 layers=[10, 256, 11] p_list=[1, 1] alpha=0.005 beta=-2.0 0.100           8     500\n",
      "     LSET     21000                                     n_pieces=1000 T=0.05 0.001           8     500\n",
      "      MLP     19329                                        hidden=[128, 128] 0.001           8     500\n",
      "MaxAffine     21000                                            n_pieces=1000 0.001           8     500\n",
      "\n",
      "=== LEARNING SUMMARY (mean ± SE over seeds) ===\n",
      "    model     train_time       best_val           test\n",
      "      DFN 0.00743608 ± 0 0.00804592 ± 0 0.00764216 ± 0\n",
      "DFN_AfixI 0.00854958 ± 0  0.0186284 ± 0   0.017786 ± 0\n",
      "     LSET 0.00350825 ± 0 0.00530773 ± 0  0.0063577 ± 0\n",
      "      MLP 0.00319529 ± 0 0.00605459 ± 0 0.00772086 ± 0\n",
      "MaxAffine 0.00396737 ± 0 0.00421545 ± 0 0.00461269 ± 0\n",
      "\n",
      "=== OPTIMIZATION SUMMARY (mean ± SE over seeds) ===\n",
      "    model                                                         LS_x        LS_y LS_true_y       LS_time                                                         IP_x        IP_y IP_true_y      IP_time                                                         GT_x GT_true_y        GT_time       LS_vs_IP_% IP_true_vs_GT_%\n",
      "      DFN [1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1] 333.129 ± 0   350 ± 0  0.262906 ± 0 [1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1] 333.129 ± 0   350 ± 0 0.105354 ± 0 [1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1]   350 ± 0 0.00347883 ± 0 -8.52915e-06 ± 0           0 ± 0\n",
      "DFN_AfixI [1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1] 385.394 ± 0   350 ± 0  0.451983 ± 0 [1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1] 385.394 ± 0   350 ± 0 0.123885 ± 0 [1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1]   350 ± 0 0.00347883 ± 0 -7.22313e-07 ± 0           0 ± 0\n",
      "     LSET [1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1] 353.701 ± 0   359 ± 0 0.0158945 ± 0 [1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1] 316.595 ± 0   359 ± 0  8.29231 ± 0 [1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1]   350 ± 0 0.00347883 ± 0      11.7205 ± 0     2.57143 ± 0\n",
      "      MLP [1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1] 345.693 ± 0   365 ± 0 0.0123989 ± 0 [1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1] 345.693 ± 0   365 ± 0  101.133 ± 0 [1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1]   350 ± 0 0.00347883 ± 0  -1.3378e-07 ± 0     4.28571 ± 0\n",
      "MaxAffine [1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1] 343.986 ± 0   350 ± 0   0.01991 ± 0 [1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1] 343.986 ± 0   350 ± 0 0.147768 ± 0 [1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1]   350 ± 0 0.00347883 ± 0 -5.41887e-06 ± 0           0 ± 0\n",
      "\n",
      "=== FAILURES / WARNINGS (if any) ===\n",
      " seed model    stage                          error\n",
      "    0  LSET IP_CHECK rel_err=1.049e-01 (tol=0.0001)\n"
     ]
    }
   ],
   "source": [
    "# ===================== ASSIGNMENT DATASET =====================\n",
    "N_SEEDS = 1\n",
    "VARY_DATASET_SEED = True\n",
    "VARY_MODEL_INIT_SEED = True\n",
    "STRICT_IP_CHECK = False\n",
    "IP_CHECK_TOL = 1e-4\n",
    "SILENCE_LOCAL_SEARCH = False\n",
    "ALLOW_PLOTS_MULTI_SEED = True\n",
    "\n",
    "dataset_type = \"assignment\"\n",
    "dataset_params = dict(\n",
    "    K=2048,\n",
    "    num_nodes=20,\n",
    "    c_min=1,\n",
    "    c_max=1000,\n",
    "    noise_std=0.0,\n",
    "    seed=0,\n",
    ")\n",
    "\n",
    "Xtmp, _, _ = generate_bipartite_subset_matching_dataset(**dataset_params)\n",
    "in_dim = int(np.asarray(Xtmp).shape[1])\n",
    "print(f\"[ASSIGNMENT] inferred in_dim={in_dim} from Xtmp shape={np.asarray(Xtmp).shape}\")\n",
    "\n",
    "train_base = dict(\n",
    "    epochs=500,\n",
    "    batch_size=8,\n",
    "    val_frac=0.15,\n",
    "    test_frac=0.15,\n",
    "    seed=0,\n",
    "    device=\"cpu\",\n",
    "    eps=1e-8,\n",
    "    weight_decay=0.0,\n",
    "    plot_every=10,\n",
    "    plot_points=128,\n",
    "    plot_chunk=128,\n",
    ")\n",
    "\n",
    "# ---- DFN (learnable A) ----\n",
    "dfn_params = dict(\n",
    "    input_dim=in_dim, layer_sizes=[32, 64, 32], p_list=[1, 1],\n",
    "    seed=0, alpha=5e-3, beta=-2.0\n",
    ")\n",
    "\n",
    "# ---- DFN (fixed A = I) ----\n",
    "dfn_Afix_params = dict(\n",
    "    input_dim=in_dim,\n",
    "    layer_sizes=[10, 256, 11],\n",
    "    p_list=[1, 1],\n",
    "    seed=0,\n",
    "    alpha=5e-3,\n",
    "    beta=-2.0,\n",
    "    A_fixed=np.eye(in_dim, dtype=np.float32),\n",
    ")\n",
    "\n",
    "# ---- other models ----\n",
    "mlp_params  = dict(in_dim=in_dim, hidden_dims=[128, 128], out_dim=1)\n",
    "maff_params = dict(in_dim=in_dim, n_pieces=1000)\n",
    "lset_params = dict(in_dim=in_dim, n_pieces=1000, T=0.05)\n",
    "\n",
    "lr_map = dict(DFN=1e-1, MLP=1e-3, MaxAffine=1e-3, LSET=1e-3)\n",
    "time_limit = 500\n",
    "\n",
    "# feasible integer box + feasible x0 / sum_eq\n",
    "xmin  = np.full(in_dim, 0, dtype=int)\n",
    "xmax  = np.full(in_dim, 1, dtype=int)\n",
    "\n",
    "# x0 + sum_eq are generated per seed so the run seed controls them too\n",
    "x0_list = []\n",
    "sum_eq_list = []\n",
    "for seed in range(int(N_SEEDS)):\n",
    "    rng = np.random.default_rng(int(seed))\n",
    "    x0_s = rng.integers(0, 2, size=in_dim, dtype=int)\n",
    "    x0_list.append(x0_s)\n",
    "    sum_eq_list.append(int(x0_s.sum()))\n",
    "x0 = x0_list[0] if int(N_SEEDS) == 1 else np.stack(x0_list, axis=0)\n",
    "sum_eq = sum_eq_list[0] if int(N_SEEDS) == 1 else sum_eq_list\n",
    "delta = 2\n",
    "\n",
    "\n",
    "runs = [\n",
    "    (\"DFN\",        \"DFN\", dfn_params),\n",
    "    (\"MLP\",        \"MLP\", mlp_params),\n",
    "    (\"DFN_AfixI\",  \"DFN\", dfn_Afix_params),\n",
    "    (\"MaxAffine\",  \"MaxAffine\", maff_params),\n",
    "    (\"LSET\",       \"LSET\", lset_params),\n",
    "]\n",
    "\n",
    "_ = run_benchmark(\n",
    "    dataset_type=dataset_type,\n",
    "    dataset_params=dataset_params,\n",
    "    runs=runs,\n",
    "    train_base=train_base,\n",
    "    lr_map=lr_map,\n",
    "    x0=x0, xmin=xmin, xmax=xmax,\n",
    "    delta=delta, sum_eq=sum_eq,\n",
    "    n_seeds=N_SEEDS,\n",
    "    vary_dataset_seed=VARY_DATASET_SEED,\n",
    "    vary_model_init_seed=VARY_MODEL_INIT_SEED,\n",
    "    strict_ip_check=STRICT_IP_CHECK,\n",
    "    ip_check_tol=IP_CHECK_TOL,\n",
    "    silence_local_search=SILENCE_LOCAL_SEARCH,\n",
    "    allow_plots_multi_seed=ALLOW_PLOTS_MULTI_SEED,\n",
    "    time_limit=time_limit,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92dbc0ee-bbe1-4fdd-8782-44ece09dda4b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:dfn]",
   "language": "python",
   "name": "conda-env-dfn-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}