{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f5b14808",
   "metadata": {},
   "source": [
    "# DFN"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2efb5921",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ce8269a4",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "from __future__ import annotations\n",
    "\n",
    "import cppimport\n",
    "\n",
    "import sys, time, math, io, contextlib\n",
    "from pathlib import Path\n",
    "from typing import Optional, List, Tuple, Dict, Any\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.optimize import linear_sum_assignment\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from IPython.display import display\n",
    "\n",
    "import gurobipy as gp\n",
    "from gurobipy import GRB\n",
    "\n",
    "_TOL = 1e-9"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "417583e0",
   "metadata": {},
   "source": [
    "## LEMON"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "62767112",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# Make the parent of the current working directory importable -- presumably that is\n",
    "# where the `lemon_mcf` extension sources live (confirm against the repo layout).\n",
    "repo = Path().resolve().parent\n",
    "if str(repo) not in sys.path:\n",
    "    sys.path.insert(0, str(repo))\n",
    "\n",
    "# Load the `lemon_mcf` module through cppimport; later cells call its\n",
    "# `solve_mcf` and `max_flow` functions.\n",
    "lemon_mcf = cppimport.imp(\"lemon_mcf\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15e6bd1b",
   "metadata": {},
   "source": [
    "### Quick sanity checks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "88381d28",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'status': 1, 'flow': array([1., 2., 0.]), 'potential': array([-2.,  0., -1.]), 'reduced_cost': array([0., 0., 4.]), 'at_capacity': array([False,  True, False]), 'total_cost': 4.0}\n",
      "{'value': 4.0, 'flow': array([2., 2., 2., 2.])}\n"
     ]
    }
   ],
   "source": [
    "# Smoke test 1: min-cost flow on a 3-node network with arcs 0->1, 0->2, 1->2.\n",
    "# `supply` sums to zero (+3 at node 0, -1 at node 1, -2 at node 2).\n",
    "n = 3\n",
    "src    = np.array([0, 0, 1], dtype=np.int64)\n",
    "dst    = np.array([1, 2, 2], dtype=np.int64)\n",
    "cost   = np.array([2.0, 1.0, 3.0], dtype=np.float64)\n",
    "cap    = np.array([5.0, 2.0, 4.0], dtype=np.float64)\n",
    "supply = np.array([3.0, -1.0, -2.0], dtype=np.float64)\n",
    "\n",
    "out_min_cost_flow = lemon_mcf.solve_mcf(n, src, dst, cost, cap, supply)\n",
    "print(out_min_cost_flow)\n",
    "\n",
    "# Smoke test 2: max flow from node 0 to node 3 on a 4-node network made of the\n",
    "# two paths 0->1->3 and 0->2->3. Note that n, src, dst and cap are rebound here.\n",
    "n = 4\n",
    "src = np.array([0,0,1,2], dtype=np.int64)\n",
    "dst = np.array([1,2,3,3], dtype=np.int64)\n",
    "cap = np.array([3.0,2.0,2.0,4.0], dtype=np.float64)\n",
    "out_max_flow = lemon_mcf.max_flow(n, src, dst, cap, 0, 3)\n",
    "print(out_max_flow)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75235271",
   "metadata": {},
   "source": [
    "## Dataset generators"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "67763cc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_mdvsp_dataset(\n",
    "    K: int,\n",
    "    filename: str,\n",
    "    x_min,\n",
    "    x_max,\n",
    "    noise_std: float = 0.0,\n",
    "    seed: int = 0,\n",
    "    max_trips=None,\n",
    "    max_succ=None,\n",
    "):\n",
    "    \"\"\"Multiple-Depot Vehicle Scheduling (MDVSP) dataset.\n",
    "\n",
    "    One fixed flow network is built from the instance file. For every sample,\n",
    "    per-depot capacities are drawn at random, the maximum flow from the super\n",
    "    source to the super sink is computed, and the cost of a cheapest flow of\n",
    "    that value becomes the target.\n",
    "\n",
    "    Args:\n",
    "      K: number of samples.\n",
    "      filename: instance file. The header line holds three integers `m n l`;\n",
    "        after one skipped line come `n` trip rows (four integers each) and `l`\n",
    "        rows of the matrix `D`.\n",
    "      x_min, x_max: inclusive bounds of the sampled depot capacities.\n",
    "      noise_std: scale of the Gaussian noise added to each target (0 = none).\n",
    "      seed: seed of the NumPy generator used for capacities and noise.\n",
    "      max_trips: keep only the first `max_trips` trips (None = all).\n",
    "      max_succ: maximum number of successor arcs per trip, earliest-starting\n",
    "        feasible trips first (None = every feasible successor).\n",
    "\n",
    "    Returns:\n",
    "      X: (K, m) integer-valued capacities (float32)\n",
    "      y: (K,) min-cost-flow objective values (float32)\n",
    "      gt: dict containing the fixed network pieces (for later evaluation)\n",
    "\n",
    "    Raises:\n",
    "      RuntimeError: if the min-cost-flow solve of a sample does not return\n",
    "        status 1 (the same success test `_MCFValue.forward` applies).\n",
    "    \"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "\n",
    "    with open(filename) as f:\n",
    "        m, n, l = map(int, f.readline().split())\n",
    "        f.readline()  # blank line\n",
    "        # ndmin=2 keeps a one-row block two-dimensional, so the column unpacking\n",
    "        # and the matrix indexing below also work for single-row inputs.\n",
    "        trips = np.loadtxt(f, max_rows=n, dtype=np.int64, ndmin=2)[:max_trips]\n",
    "        D = np.loadtxt(f, max_rows=l, dtype=np.int64, ndmin=2)\n",
    "\n",
    "    # trip columns: p and q index D (start / end location), s and e are compared\n",
    "    # as start / end times in the successor test below\n",
    "    p, s, q, e = trips.T\n",
    "    ntr = len(trips)\n",
    "\n",
    "    # node ids\n",
    "    SS = 0\n",
    "    depS = 1 + np.arange(m)\n",
    "    depT = 1 + m + np.arange(m)\n",
    "    trS  = 1 + 2*m + np.arange(ntr)\n",
    "    trT  = 1 + 2*m + ntr + np.arange(ntr)\n",
    "    TT = 1 + 2*m + 2*ntr\n",
    "    N = TT + 1\n",
    "\n",
    "    src, dst, cost, cap = [], [], [], []\n",
    "\n",
    "    # depot boundary arcs (capacity is what we learn/provide per sample)\n",
    "    for d in range(m):\n",
    "        src += [SS, int(depT[d])]\n",
    "        dst += [int(depS[d]), TT]\n",
    "        cost += [0.0, 0.0]\n",
    "        cap  += [0.0, 0.0]\n",
    "    # the two arcs of depot d sit at positions 2d and 2d+1 of the arc list\n",
    "    idxSS = np.arange(0, 2*m, 2)  # arcs SS->depS\n",
    "    idxTT = np.arange(1, 2*m, 2)  # arcs depT->TT\n",
    "\n",
    "    # trip arcs (always cap=1)\n",
    "    src += trS.tolist()\n",
    "    dst += trT.tolist()\n",
    "    cost += [0.0] * ntr\n",
    "    cap  += [1.0] * ntr\n",
    "\n",
    "    # depot-to-trip and trip-to-depot arcs\n",
    "    for d in range(m):\n",
    "        # depS -> trS\n",
    "        src += [int(depS[d])] * ntr\n",
    "        dst += trS.tolist()\n",
    "        cost += (5000 + 10 * D[d, p]).astype(float).tolist()\n",
    "        cap  += [1.0] * ntr\n",
    "\n",
    "        # trT -> depT\n",
    "        src += trT.tolist()\n",
    "        dst += [int(depT[d])] * ntr\n",
    "        cost += (5000 + 10 * D[q, d]).astype(float).tolist()\n",
    "        cap  += [1.0] * ntr\n",
    "\n",
    "    # feasible trip successor arcs (trT -> next trS)\n",
    "    order = np.argsort(s)\n",
    "    p2, s2 = p[order], s[order]\n",
    "    max_succ_eff = ntr if max_succ is None else int(max_succ)\n",
    "\n",
    "    for i in range(ntr):\n",
    "        travel = D[q[i], p2]                          # time from trip i end depot -> next trip start depot\n",
    "        # candidates are scanned in start-time order, so the cap keeps the earliest ones\n",
    "        feas = np.flatnonzero(s2 >= e[i] + travel)[:max_succ_eff]\n",
    "        j = order[feas]\n",
    "        if j.size:\n",
    "            src += [int(trT[i])] * j.size\n",
    "            dst += trS[j].tolist()\n",
    "            cost += (8 * travel[feas] + 2 * (s[j] - e[i])).astype(float).tolist()\n",
    "            cap  += [1.0] * j.size\n",
    "\n",
    "    src = np.asarray(src, dtype=np.int64)\n",
    "    dst = np.asarray(dst, dtype=np.int64)\n",
    "    cost = np.asarray(cost, dtype=np.float64)\n",
    "    cap0 = np.asarray(cap, dtype=np.float64)\n",
    "\n",
    "    # sample capacities X and compute y via max-flow + min-cost-flow\n",
    "    X = rng.integers(x_min, np.asarray(x_max) + 1, size=(K, m)).astype(np.float64)\n",
    "    y = np.empty(K, dtype=np.float64)\n",
    "\n",
    "    for k in range(K):\n",
    "        cap_k = cap0.copy()\n",
    "        cap_k[idxSS] = X[k]\n",
    "        cap_k[idxTT] = X[k]\n",
    "\n",
    "        # route exactly the max-flow value, so the min-cost problem is feasible\n",
    "        Fmax = lemon_mcf.max_flow(N, src, dst, cap_k, SS, TT)[\"value\"]\n",
    "        supply = np.zeros(N, dtype=np.float64)\n",
    "        supply[SS] = Fmax\n",
    "        supply[TT] = -Fmax\n",
    "\n",
    "        sol = lemon_mcf.solve_mcf(N, src, dst, cost, cap_k, supply)\n",
    "        # fail loudly instead of storing the cost of an unsolved problem\n",
    "        if sol[\"status\"] != 1:\n",
    "            raise RuntimeError(f\"LEMON failed on sample {k} (status={sol['status']})\")\n",
    "        y[k] = sol[\"total_cost\"]\n",
    "        if noise_std:\n",
    "            y[k] += noise_std * rng.normal()\n",
    "\n",
    "    gt = dict(\n",
    "        type=\"mdvsp\", N=int(N), SS=int(SS), TT=int(TT),\n",
    "        src=src, dst=dst, cost=cost, cap0=cap0, idxSS=idxSS, idxTT=idxTT\n",
    "    )\n",
    "    return X.astype(np.float32), y.astype(np.float32), gt\n",
    "\n",
    "\n",
    "def generate_bipartite_subset_matching_dataset(\n",
    "    K: int, num_nodes: int, c_min: int, c_max: int, noise_std: float = 0.0, seed: int = 0\n",
    "):\n",
    "    \"\"\"Assignment-style dataset: choose a subset of left nodes, match to right nodes with min cost.\"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "    C = rng.integers(c_min, c_max + 1, size=(num_nodes, num_nodes)).astype(np.float32)\n",
    "\n",
    "    X = np.zeros((K, num_nodes), dtype=np.float32)\n",
    "    y = np.zeros((K,), dtype=np.float32)\n",
    "\n",
    "    for k in range(K):\n",
    "        mask = rng.integers(0, 2, size=num_nodes, dtype=np.int8)\n",
    "        while not mask.any():\n",
    "            mask = rng.integers(0, 2, size=num_nodes, dtype=np.int8)\n",
    "        idx = np.flatnonzero(mask)\n",
    "        X[k, idx] = 1.0\n",
    "\n",
    "        r, c = linear_sum_assignment(C[idx, :])\n",
    "        y[k] = C[idx, :][r, c].sum()\n",
    "        if noise_std:\n",
    "            y[k] += noise_std * rng.normal()\n",
    "\n",
    "    gt = {\"type\": \"assignment\", \"C\": C.astype(np.float32)}\n",
    "    return X, y, gt\n",
    "\n",
    "\n",
    "def generate_convex_quadratic_dataset(\n",
    "    K: int,\n",
    "    dim: int,\n",
    "    eigen_min: float,\n",
    "    eigen_max: float,\n",
    "    x_min,\n",
    "    x_max,\n",
    "    noise_std: float = 0.0,\n",
    "    seed: int = 0,\n",
    "    x_star_zero: bool = False,\n",
    "):\n",
    "    \"\"\"y = (x - x*)^T Q (x - x*) + noise, with Q symmetric PSD.\"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "    U, R = np.linalg.qr(rng.standard_normal((dim, dim)))\n",
    "    U *= np.sign(np.diag(R) + 1e-12)\n",
    "    Q = U @ np.diag(rng.uniform(eigen_min, eigen_max, dim)) @ U.T\n",
    "\n",
    "    x_star = np.zeros(dim, dtype=np.int64) if x_star_zero else rng.integers(x_min, x_max + 1, size=dim)\n",
    "\n",
    "    X = rng.integers(x_min, x_max + 1, size=(K, dim)).astype(np.float32)\n",
    "    d = X - x_star\n",
    "    y = np.einsum(\"bi,ij,bj->b\", d, Q, d) + noise_std * rng.normal(size=K)\n",
    "\n",
    "    gt = {\"type\": \"quadratic\", \"Q\": Q.astype(np.float32), \"x_star\": x_star.astype(np.int64)}\n",
    "    return X.astype(np.float32), y.astype(np.float32), gt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1dfc62e4",
   "metadata": {},
   "source": [
    "### Quick sanity checks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "89201b4a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10, 16) (10,) \n",
      " [[284. 307. 453. 571.  20.  86. 494. 570. 149. 187. 522. 254. 164. 497.\n",
      "  154. 245.]\n",
      " [386. 330.  51.  16. 520. 452. 503. 323. 491. 198. 272. 473.  74. 182.\n",
      "   74. 272.]\n",
      " [587.  80. 230. 242. 543. 122. 301. 157.  11. 450.  37. 168. 299. 291.\n",
      "   70. 589.]\n",
      " [450. 577.  55. 435. 176. 325. 555. 166. 436.  96. 193. 582. 253. 310.\n",
      "  176.  69.]\n",
      " [255. 374. 273. 466. 218. 368. 464. 551. 256.  23. 431. 317. 524. 276.\n",
      "  221.  37.]] \n",
      " [3.059159e+07 3.057458e+07 3.069260e+07 3.056290e+07 3.054872e+07] \n",
      " 45057.67 \n",
      "\n",
      "(10, 10) (10,) \n",
      " [[0. 1. 1. 0. 0. 0. 0. 1. 0. 1.]\n",
      " [1. 0. 0. 0. 0. 0. 1. 1. 0. 0.]\n",
      " [0. 1. 0. 0. 1. 0. 1. 0. 1. 1.]\n",
      " [1. 1. 0. 0. 1. 1. 1. 1. 0. 1.]\n",
      " [0. 1. 0. 1. 1. 0. 0. 0. 1. 0.]] \n",
      " [ 665.  226.  993. 1351.  851.] \n",
      " 296.35287 \n",
      "\n",
      "(10, 10) (10,) \n",
      " [[-5. -9.  0.  7.  3. -9.  3. -3. -6. -1.]\n",
      " [ 8. 10. -8.  1.  5. -5. -5. -5. -6.  8.]\n",
      " [-6. -6. -8. -8.  6. -4.  6.  2.  8.  1.]\n",
      " [ 6.  7. -9.  1. -1. -4. -1. -2.  0.  7.]\n",
      " [ 7.  3.  4. 10.  3. -3. -9.  1. -6.  2.]] \n",
      " [16987.797  9943.988  8211.554  5990.925 11936.206] \n",
      " 2931.8052 \n",
      "\n"
     ]
    }
   ],
   "source": [
    "# NOTE: MDVSP requires that `filename` exists on disk.\n",
    "# Each check prints the shapes, the first five inputs and targets, and the\n",
    "# standard deviation of the targets. X and y are rebound by every generator call.\n",
    "X, y, _ = make_mdvsp_dataset(\n",
    "    K=10, filename=\"RN-16-3000-05.dat\", x_min=0, x_max=600, noise_std=0.0, seed=1, max_trips=5000, max_succ=50\n",
    ")\n",
    "print(X.shape, y.shape, \"\\n\", X[:5], \"\\n\", y[:5], \"\\n\", y.std(), \"\\n\")\n",
    "\n",
    "# Bipartite subset matching: 10 samples on a 10x10 cost matrix.\n",
    "X, y, _ = generate_bipartite_subset_matching_dataset(\n",
    "    K=10, num_nodes=10, c_min=1, c_max=1000, noise_std=0.0, seed=0\n",
    ")\n",
    "print(X.shape, y.shape, \"\\n\", X[:5], \"\\n\", y[:5], \"\\n\", y.std(), \"\\n\")\n",
    "\n",
    "# Convex quadratic: 10 samples in 10 dimensions, eigenvalues drawn from [1, 20].\n",
    "X, y, _ = generate_convex_quadratic_dataset(\n",
    "    K=10, dim=10, eigen_min=1.0, eigen_max=20.0, x_min=-10, x_max=10, noise_std=0.0, seed=0\n",
    ")\n",
    "print(X.shape, y.shape, \"\\n\", X[:5], \"\\n\", y[:5], \"\\n\", y.std(), \"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70071f5f",
   "metadata": {},
   "source": [
    "## Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "07278db8",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "def _ste_round(x: torch.Tensor) -> torch.Tensor:\n",
    "    return x + (torch.round(x) - x).detach()\n",
    "\n",
    "class _MCFValue(torch.autograd.Function):\n",
    "    \"\"\"Autograd function returning the total cost reported by `lemon_mcf.solve_mcf`.\n",
    "\n",
    "    forward(n_nodes, src, dst, cost, cap, supply) solves one min-cost-flow problem\n",
    "    and returns its `total_cost` as a 0-d tensor created from `cost`.\n",
    "    backward builds gradients for `cost`, `cap` and `supply` from the solver's\n",
    "    flow, reduced costs and node potentials; `n_nodes`, `src` and `dst` get None.\n",
    "    \"\"\"\n",
    "    \n",
    "    @staticmethod\n",
    "    def forward(ctx, n_nodes, src, dst, cost, cap, supply):\n",
    "        n = int(n_nodes)\n",
    "\n",
    "        src = src.to(dtype=torch.int64).contiguous()\n",
    "        dst = dst.to(dtype=torch.int64).contiguous()\n",
    "\n",
    "        # one entry per arc in dst/cost/cap, one entry per node in supply\n",
    "        m = int(src.numel())\n",
    "        if (dst.numel() != m) or (cost.numel() != m) or (cap.numel() != m) or (supply.numel() != n):\n",
    "            raise ValueError(\"Bad shapes for MCF inputs.\")\n",
    "        if torch.abs(supply.double().sum()) > _TOL:\n",
    "            raise ValueError(\"Require sum(supply)=0\")\n",
    "\n",
    "        # flatten to a detached CPU NumPy array of the requested dtype\n",
    "        def as_np(t: torch.Tensor, dtype):\n",
    "            return t.detach().cpu().contiguous().view(-1).numpy().astype(dtype, copy=False)\n",
    "\n",
    "        out = lemon_mcf.solve_mcf(\n",
    "            n,\n",
    "            as_np(src, np.int64),\n",
    "            as_np(dst, np.int64),\n",
    "            as_np(cost, np.float64),\n",
    "            as_np(cap, np.float64),\n",
    "            as_np(supply, np.float64),\n",
    "            tol=_TOL,\n",
    "        )\n",
    "        # any status other than 1 is treated as failure\n",
    "        if out[\"status\"] != 1:\n",
    "            raise RuntimeError(f\"LEMON failed (status={out['status']})\")\n",
    "\n",
    "        flow = out[\"flow\"]\n",
    "        pot = out[\"potential\"]\n",
    "        red = out[\"reduced_cost\"]\n",
    "        # accept either key name; if neither is present, flag arcs whose flow\n",
    "        # equals the capacity within _TOL\n",
    "        at = out.get(\"at_cap\", out.get(\"at_capacity\", None))\n",
    "        if at is None:\n",
    "            at = np.abs(flow - as_np(cap, np.float64)) <= _TOL\n",
    "\n",
    "        # solver outputs are kept on ctx for backward\n",
    "        ctx.flow, ctx.pot, ctx.red, ctx.at = flow, pot, red, at\n",
    "        return cost.new_tensor(float(out[\"total_cost\"]))\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, g):\n",
    "        # gradients are created with the dtype/device of the incoming gradient\n",
    "        dev, dt = g.device, g.dtype\n",
    "        flow = torch.as_tensor(ctx.flow, device=dev, dtype=dt)\n",
    "        pot  = torch.as_tensor(ctx.pot,  device=dev, dtype=dt)\n",
    "        red  = torch.as_tensor(ctx.red,  device=dev, dtype=dt)\n",
    "        at   = torch.as_tensor(ctx.at,   device=dev, dtype=torch.bool)\n",
    "\n",
    "        # d(value)/d(cost): the arc flow\n",
    "        grad_cost = flow\n",
    "        # d(value)/d(cap): the reduced cost on arcs flagged at capacity, zero elsewhere\n",
    "        grad_cap  = torch.where(at, red, torch.zeros_like(red))\n",
    "        # d(value)/d(supply): potentials, negated and shifted so the entries sum to zero\n",
    "        # NOTE(review): the sign depends on the solver's potential convention -- confirm\n",
    "        grad_sup  = pot.mean() - pot\n",
    "\n",
    "        return None, None, None, grad_cost * g, grad_cap * g, grad_sup * g\n",
    "\n",
    "\n",
    "class DFN(nn.Module):\n",
    "    \"\"\"Model whose output is `alpha * (min-cost-flow value) + beta` on a layered graph.\n",
    "\n",
    "    Nodes are grouped into consecutive layers of sizes `layer_sizes`. Between each\n",
    "    pair of adjacent layers, arcs in both directions are sampled (each candidate is\n",
    "    kept with probability `p_list[i]`); their capacities and costs are learnable.\n",
    "    Every first-layer/last-layer node pair is additionally joined in both\n",
    "    directions by fixed arcs with capacity `big_cap` and cost `big_cost`.\n",
    "\n",
    "    Node supplies come from the input: the boundary nodes (all first-layer nodes\n",
    "    plus all last-layer nodes except the final one) receive `A @ w + b`, and the\n",
    "    final node receives minus their sum so that supplies add up to zero.\n",
    "\n",
    "    Args:\n",
    "      input_dim: length of one input vector `w`.\n",
    "      layer_sizes: number of nodes per layer (at least two layers).\n",
    "      p_list: keep-probability per adjacent layer pair, each in [0, 1].\n",
    "      big_cost, big_cap: cost and capacity of the fixed first<->last arcs.\n",
    "      seed: seeds the generator used for arc sampling only.\n",
    "      A_fixed: optional (n_boundary, input_dim) matrix; when given, `A` is a\n",
    "        non-trainable buffer, otherwise `A` is a learnable parameter.\n",
    "      alpha, beta: scale and offset applied to the flow value.\n",
    "    \"\"\"\n",
    "    def __init__(\n",
    "        self,\n",
    "        input_dim: int,\n",
    "        layer_sizes,\n",
    "        p_list,\n",
    "        big_cost: float = 1e6,\n",
    "        big_cap: float = 1e6,\n",
    "        seed: int = 0,\n",
    "        A_fixed=None,\n",
    "        alpha: float = 1e-6,\n",
    "        beta: float = -0.0,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.alpha = float(alpha)\n",
    "        self.beta  = float(beta)\n",
    "\n",
    "        layer_sizes = list(map(int, layer_sizes))\n",
    "        if len(layer_sizes) < 2 or len(p_list) != len(layer_sizes) - 1:\n",
    "            raise ValueError(\"Need len(layer_sizes)>=2 and len(p_list)=len(layer_sizes)-1\")\n",
    "\n",
    "        self.n = int(sum(layer_sizes))\n",
    "        if self.n <= 0:\n",
    "            raise ValueError(\"sum(layer_sizes) must be > 0\")\n",
    "\n",
    "        # node indices per layer\n",
    "        layers, off = [], 0\n",
    "        for s in layer_sizes:\n",
    "            layers.append(torch.arange(off, off + s, dtype=torch.long))\n",
    "            off += s\n",
    "\n",
    "        L1, LK = layers[0], layers[-1]\n",
    "        if L1.numel() == 0 or LK.numel() == 0:\n",
    "            raise ValueError(\"First/last layer must be non-empty.\")\n",
    "\n",
    "        # the very last node balances the supplies; all other first/last-layer\n",
    "        # nodes are boundary nodes driven by the input\n",
    "        self.fix_node = int(LK[-1].item())\n",
    "        boundary = torch.cat([L1, LK[:-1]], 0)\n",
    "        self.register_buffer(\"boundary\", boundary)\n",
    "\n",
    "        # private generator: arc sampling does not consume the global torch RNG\n",
    "        gen = torch.Generator().manual_seed(int(seed))\n",
    "\n",
    "        # all (u, v) pairs with u in U and v in V, as two aligned index tensors\n",
    "        def bipartite(U: torch.Tensor, V: torch.Tensor):\n",
    "            su, sv = int(U.numel()), int(V.numel())\n",
    "            return U.repeat_interleave(sv), V.repeat(su)\n",
    "\n",
    "        # keep each candidate arc independently with probability p (all of them if p >= 1)\n",
    "        def sample_edges(U, V, p: float):\n",
    "            s, t = bipartite(U, V)\n",
    "            if p < 1.0:\n",
    "                keep = torch.rand(s.numel(), generator=gen) < float(p)\n",
    "                s, t = s[keep], t[keep]\n",
    "            return s, t\n",
    "\n",
    "        # learnable arcs between consecutive layers (both directions)\n",
    "        sf, tf, sb, tb = [], [], [], []\n",
    "        for i, p in enumerate(map(float, p_list)):\n",
    "            if not (0.0 <= p <= 1.0):\n",
    "                raise ValueError(\"p_list entries must be in [0,1]\")\n",
    "\n",
    "            s, t = sample_edges(layers[i], layers[i + 1], p)  # forward\n",
    "            sf.append(s); tf.append(t)\n",
    "\n",
    "            s, t = sample_edges(layers[i + 1], layers[i], p)  # backward\n",
    "            sb.append(s); tb.append(t)\n",
    "\n",
    "        # arc order: all forward arcs first, then all backward arcs\n",
    "        src_param = torch.cat([torch.cat(sf, 0), torch.cat(sb, 0)], 0)\n",
    "        dst_param = torch.cat([torch.cat(tf, 0), torch.cat(tb, 0)], 0)\n",
    "        if src_param.numel() == 0:\n",
    "            raise ValueError(\"No learnable arcs (increase p_list / layer sizes).\")\n",
    "\n",
    "        # fixed arcs between every first-layer/last-layer node pair, both directions\n",
    "        s1, t1 = bipartite(L1, LK)\n",
    "        s2, t2 = bipartite(LK, L1)\n",
    "        src_fixed = torch.cat([s1, s2], 0)\n",
    "        dst_fixed = torch.cat([t1, t2], 0)\n",
    "        m_fixed = int(src_fixed.numel())\n",
    "\n",
    "        # full arc list = learnable arcs followed by fixed arcs (forward() relies on this order)\n",
    "        self.register_buffer(\"src\", torch.cat([src_param, src_fixed], 0))\n",
    "        self.register_buffer(\"dst\", torch.cat([dst_param, dst_fixed], 0))\n",
    "        self.register_buffer(\"cap_fixed\",  torch.full((m_fixed,), float(big_cap),  dtype=torch.float32))\n",
    "        self.register_buffer(\"cost_fixed\", torch.full((m_fixed,), float(big_cost), dtype=torch.float32))\n",
    "\n",
    "        nb = int(boundary.numel())\n",
    "        input_dim = int(input_dim)\n",
    "\n",
    "        # raw parameters; cost_raw is drawn from the global torch RNG, not from `gen`\n",
    "        m_param = int(src_param.numel())\n",
    "        self.cap_raw  = nn.Parameter(torch.zeros(m_param) + 0.542)\n",
    "        self.cost_raw = nn.Parameter(torch.randn(m_param) + 1.0)\n",
    "        self.b_raw    = nn.Parameter(torch.zeros(nb))\n",
    "\n",
    "        if A_fixed is None:\n",
    "            # default A: boundary row r reads input coordinate r % input_dim\n",
    "            A = torch.zeros(nb, input_dim)\n",
    "            rows = torch.arange(nb)\n",
    "            A[rows, rows % input_dim] = 1.0\n",
    "            self.A = nn.Parameter(A)\n",
    "        else:\n",
    "            A_fixed = torch.as_tensor(A_fixed, dtype=torch.float32)\n",
    "            if A_fixed.shape != (nb, input_dim):\n",
    "                raise ValueError(f\"A_fixed must have shape {(nb, input_dim)}, got {tuple(A_fixed.shape)}\")\n",
    "            self.register_buffer(\"A\", A_fixed)\n",
    "\n",
    "    def forward(self, w: torch.Tensor) -> torch.Tensor:\n",
    "        \"\"\"Evaluate the model on one input vector (1-D `w`) or on each row of a batch.\n",
    "\n",
    "        Returns a 0-d tensor for 1-D input, otherwise one value per row of `w`.\n",
    "        \"\"\"\n",
    "        # learnable quantities are rounded with a straight-through gradient;\n",
    "        # capacities go through softplus first, costs are used as stored\n",
    "        capP  = _ste_round(F.softplus(self.cap_raw))\n",
    "        costP = self.cost_raw\n",
    "        b     = _ste_round(self.b_raw)\n",
    "        A     = _ste_round(self.A) if isinstance(self.A, nn.Parameter) else self.A\n",
    "\n",
    "        cap  = torch.cat([capP,  self.cap_fixed.to(w.device, w.dtype)], 0)\n",
    "        cost = torch.cat([costP, self.cost_fixed.to(w.device, w.dtype)], 0)\n",
    "\n",
    "        # solve one flow problem for a single input vector\n",
    "        def one(w1: torch.Tensor) -> torch.Tensor:\n",
    "            supply = torch.zeros(self.n, device=w1.device, dtype=torch.float64)\n",
    "            supply[self.boundary] = (A.double() @ w1.double()) + b.double()\n",
    "            # fix_node is still zero here, so this is minus the boundary total\n",
    "            supply[self.fix_node] = -supply.sum()\n",
    "            return _MCFValue.apply(self.n, self.src, self.dst, cost, cap, supply)\n",
    "\n",
    "        # batches are processed one row at a time (one solver call per row)\n",
    "        out = one(w) if w.dim() == 1 else torch.stack([one(wi) for wi in w], 0)\n",
    "        return self.alpha * out + self.beta\n",
    "\n",
    "\n",
    "class MLP(nn.Module):\n",
    "    def __init__(self, in_dim, hidden_dims, out_dim):\n",
    "        super().__init__()\n",
    "        dims = [in_dim] + list(hidden_dims) + [out_dim]\n",
    "        layers = []\n",
    "        for a, b in zip(dims[:-2], dims[1:-1]):\n",
    "            layers += [nn.Linear(a, b), nn.ReLU()]\n",
    "        layers += [nn.Linear(dims[-2], dims[-1])]\n",
    "        self.net = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.net(x)\n",
    "\n",
    "\n",
    "class MaxAffine(nn.Module):\n",
    "    def __init__(self, in_dim: int, n_pieces: int):\n",
    "        super().__init__()\n",
    "        self.W = nn.Parameter(torch.randn(n_pieces, in_dim) / (in_dim**0.5))\n",
    "        self.b = nn.Parameter(torch.zeros(n_pieces))\n",
    "\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        return (x @ self.W.T + self.b).max(dim=1).values\n",
    "\n",
    "\n",
    "class LSET(nn.Module):\n",
    "    def __init__(self, in_dim: int, n_pieces: int, T: float = 0.01):\n",
    "        super().__init__()\n",
    "        self.T = float(T)\n",
    "        if self.T == 0.0:\n",
    "            raise ValueError(\"T must be nonzero\")\n",
    "        self.A = nn.Parameter(torch.randn(n_pieces, in_dim) / (in_dim**0.5))\n",
    "        self.b = nn.Parameter(torch.zeros(n_pieces))\n",
    "\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        z = (x @ self.A.t() + self.b) / self.T\n",
    "        return self.T * torch.logsumexp(z, dim=-1)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7cca0120",
   "metadata": {},
   "source": [
    "## Training Helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0a7db872",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "def _split_train_val_test(X: torch.Tensor, y: torch.Tensor, *, val_frac: float, test_frac: float, seed: int):\n",
    "    N = int(X.shape[0])\n",
    "    n_test = int(round(test_frac * N))\n",
    "    n_val  = int(round(val_frac  * N))\n",
    "    n_train = N - n_val - n_test\n",
    "    if n_train <= 0:\n",
    "        raise ValueError(\"splits too large; train set would be empty\")\n",
    "\n",
    "    g = torch.Generator().manual_seed(int(seed))\n",
    "    perm = torch.randperm(N, generator=g)\n",
    "    i_tr = perm[:n_train]\n",
    "    i_va = perm[n_train:n_train + n_val]\n",
    "    i_te = perm[n_train + n_val:]\n",
    "\n",
    "    return (X[i_tr], y[i_tr]), (X[i_va], y[i_va]), (X[i_te], y[i_te])\n",
    "\n",
    "\n",
    "def _fit_standardizer(Xtr: torch.Tensor, ytr: torch.Tensor, *, eps: float):\n",
    "    x_mean = Xtr.mean(0, keepdim=True)\n",
    "    x_std  = Xtr.std(0, unbiased=False, keepdim=True)\n",
    "    x_std  = torch.where(x_std < eps, torch.ones_like(x_std), x_std)\n",
    "\n",
    "    y_mean = ytr.mean()\n",
    "    y_std  = ytr.std(unbiased=False).clamp_min(eps)\n",
    "\n",
    "    return {\"x_mean\": x_mean, \"x_std\": x_std, \"y_mean\": y_mean, \"y_std\": y_std}\n",
    "\n",
    "\n",
    "def _apply_standardizer(X: torch.Tensor, y: torch.Tensor, scaler):\n",
    "    Xn = (X - scaler[\"x_mean\"]) / scaler[\"x_std\"]\n",
    "    yn = (y - scaler[\"y_mean\"]) / scaler[\"y_std\"]\n",
    "    return Xn, yn\n",
    "\n",
    "@torch.no_grad()\n",
    "def _mse_norm(model: nn.Module, loader: DataLoader, device: str):\n",
    "    model.eval()\n",
    "    tot, n = 0.0, 0\n",
    "    for xb, yb in loader:\n",
    "        xb, yb = xb.to(device), yb.to(device)\n",
    "        pred = model(xb).squeeze(-1)\n",
    "        tot += F.mse_loss(pred, yb, reduction=\"sum\").item()\n",
    "        n += yb.numel()\n",
    "    return tot / max(n, 1)\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def _predict_in_chunks(model: nn.Module, X: torch.Tensor, *, chunk: int):\n",
    "    out = []\n",
    "    for i in range(0, int(X.shape[0]), int(chunk)):\n",
    "        out.append(model(X[i:i + chunk]).squeeze(-1))\n",
    "    return torch.cat(out, 0)\n",
    "\n",
    "\n",
    "import json, hashlib\n",
    "\n",
    "# ---- Caching / persistence helpers ----\n",
    "# Root directory for cached runs; used when train_params has no \"cache_root\" entry.\n",
    "_CACHE_ROOT_DEFAULT = Path(\"saved_runs\")\n",
    "\n",
    "def _to_hashable(x):\n",
    "    \"\"\"Convert nested objects to a deterministic, hashable (and mostly JSON-friendly) form.\"\"\"\n",
    "    import numpy as _np\n",
    "    import torch as _torch\n",
    "    from pathlib import Path as _Path\n",
    "\n",
    "    if x is None or isinstance(x, (bool, int, str)):\n",
    "        return x\n",
    "    if isinstance(x, float):\n",
    "        return float(x)\n",
    "    if isinstance(x, _Path):\n",
    "        return str(x)\n",
    "    if isinstance(x, (list, tuple)):\n",
    "        return [_to_hashable(v) for v in x]\n",
    "    if isinstance(x, dict):\n",
    "        return {str(k): _to_hashable(x[k]) for k in sorted(x.keys(), key=lambda z: str(z))}\n",
    "    if isinstance(x, _np.ndarray):\n",
    "        h = hashlib.sha256(x.tobytes(order=\"C\")).hexdigest()\n",
    "        return {\"__ndarray__\": True, \"dtype\": str(x.dtype), \"shape\": list(x.shape), \"sha256\": h}\n",
    "    if isinstance(x, _torch.Tensor):\n",
    "        t = x.detach().cpu()\n",
    "        h = hashlib.sha256(t.numpy().tobytes(order=\"C\")).hexdigest()\n",
    "        return {\"__tensor__\": True, \"dtype\": str(t.dtype), \"shape\": list(t.shape), \"sha256\": h}\n",
    "    return {\"__repr__\": repr(x)}\n",
    "\n",
    "def _run_signature(dataset_type, dataset_params, model_type, model_params, train_sig):\n",
    "    \"\"\"Assemble the dict that identifies a run (dataset, model and training settings).\"\"\"\n",
    "    signature = {\"dataset_type\": str(dataset_type).lower()}\n",
    "    signature[\"dataset_params\"] = _to_hashable(dict(dataset_params))\n",
    "    signature[\"model_type\"] = str(model_type)\n",
    "    signature[\"model_params\"] = _to_hashable(dict(model_params))\n",
    "    signature[\"train_params\"] = _to_hashable(dict(train_sig))\n",
    "    return signature\n",
    "\n",
    "def _run_id_from_signature(sig) -> str:\n",
    "    blob = json.dumps(sig, sort_keys=True, separators=(\",\", \":\"), ensure_ascii=True).encode(\"utf-8\")\n",
    "    return hashlib.sha256(blob).hexdigest()[:16]\n",
    "\n",
    "def _get_run_dir(cache_root: Path, dataset_type, model_type, run_id: str) -> Path:\n",
    "    return Path(cache_root) / str(dataset_type).lower() / str(model_type) / run_id\n",
    "\n",
    "def _cpuify(obj):\n",
    "    import torch as _torch\n",
    "    if isinstance(obj, _torch.Tensor):\n",
    "        return obj.detach().cpu()\n",
    "    if isinstance(obj, dict):\n",
    "        return {k: _cpuify(v) for k, v in obj.items()}\n",
    "    if isinstance(obj, (list, tuple)):\n",
    "        return [_cpuify(v) for v in obj]\n",
    "    return obj\n",
    "\n",
    "def _set_full_determinism(seed: int):\n",
    "    \"\"\"CPU-only determinism: seed controls Python/NumPy/Torch RNGs.\n",
    "    (CUDA is intentionally not used in these notebooks.)\n",
    "    \"\"\"\n",
    "    import os, random\n",
    "    import numpy as np\n",
    "    import torch\n",
    "    seed = int(seed)\n",
    "\n",
    "    os.environ[\"PYTHONHASHSEED\"] = str(seed)\n",
    "\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "\n",
    "    # Determinism flags (will error if a non-deterministic op is used)\n",
    "    try:\n",
    "        torch.use_deterministic_algorithms(True)\n",
    "    except Exception:\n",
    "        pass\n",
    "\n",
    "def _seed_worker(worker_id: int):\n",
    "    import random\n",
    "    import numpy as np\n",
    "    import torch\n",
    "    worker_seed = torch.initial_seed() % 2**32\n",
    "    np.random.seed(worker_seed)\n",
    "    random.seed(worker_seed)\n",
    "\n",
    "def _try_load_cached_run(run_dir: Path):\n",
    "    artifact_path = Path(run_dir) / \"artifact.pt\"\n",
    "    if not artifact_path.exists():\n",
    "        return None\n",
    "    try:\n",
    "        return torch.load(str(artifact_path), map_location=\"cpu\")\n",
    "    except Exception:\n",
    "        return None\n",
    "\n",
    "def _build_model_from_params(model_type, model_params):\n",
    "    mt = str(model_type)\n",
    "    mp = dict(model_params)\n",
    "    if mt == \"DFN\":\n",
    "        model = DFN(**mp)\n",
    "    else:\n",
    "        mp.pop(\"alpha\", None)\n",
    "        mp.pop(\"beta\", None)\n",
    "        if mt == \"MLP\":\n",
    "            model = MLP(**mp)\n",
    "        elif mt == \"MaxAffine\":\n",
    "            model = MaxAffine(**mp)\n",
    "        elif mt == \"LSET\":\n",
    "            model = LSET(**mp)\n",
    "        else:\n",
    "            raise ValueError(\"model_type must be: DFN | MLP | MaxAffine | LSET\")\n",
    "    return model\n",
    "\n",
    "def _save_run(run_dir: Path, signature: dict, model: nn.Module, data: dict, history: dict, spec: dict):\n",
    "    \"\"\"Persist a finished run as `<run_dir>/artifact.pt` plus a readable `signature.json`.\n",
    "\n",
    "    Does nothing when `artifact.pt` already exists. The artifact stores the\n",
    "    signature, the CPU state dict, the CPU copy of `data`, `history` and `spec`.\n",
    "    It is written to a temporary file and renamed into place, so an interrupted\n",
    "    save never leaves a truncated `artifact.pt` behind.\n",
    "    \"\"\"\n",
    "    import os\n",
    "\n",
    "    run_dir = Path(run_dir)\n",
    "    run_dir.mkdir(parents=True, exist_ok=True)\n",
    "    artifact_path = run_dir / \"artifact.pt\"\n",
    "    if artifact_path.exists():\n",
    "        return\n",
    "    state_dict = {k: v.detach().cpu() for k, v in model.state_dict().items()}\n",
    "    artifact = {\n",
    "        \"signature\": signature,\n",
    "        \"state_dict\": state_dict,\n",
    "        \"data\": _cpuify(data),\n",
    "        \"history\": history,\n",
    "        \"spec\": spec,\n",
    "    }\n",
    "    # A partially written artifact.pt would fail to load and, because of the\n",
    "    # existence check above, would never be replaced. Write next to the target\n",
    "    # and rename in one step instead.\n",
    "    tmp_path = run_dir / f\"artifact.pt.tmp-{os.getpid()}\"\n",
    "    try:\n",
    "        torch.save(artifact, str(tmp_path))\n",
    "        os.replace(str(tmp_path), str(artifact_path))\n",
    "    finally:\n",
    "        # only still present if saving or renaming failed\n",
    "        if tmp_path.exists():\n",
    "            tmp_path.unlink()\n",
    "    # signature.json is a convenience copy (the artifact already contains the\n",
    "    # signature), so a failure here is deliberately ignored\n",
    "    try:\n",
    "        with open(run_dir / \"signature.json\", \"w\", encoding=\"utf-8\") as f:\n",
    "            json.dump(signature, f, indent=2, sort_keys=True)\n",
    "    except Exception:\n",
    "        pass\n",
    "\n",
    "def generate_and_train_simple(dataset_type, dataset_params, model_type, model_params, train_params=None):\n",
    "    \"\"\"Build a dataset, train one model on it, and return the model with its data.\n",
    "\n",
    "    Steps, in order: call ``_set_full_determinism(seed)``; look up a cached run\n",
    "    with the same signature (returned immediately if complete); generate the\n",
    "    dataset; split it into train/val/test; fit a scaler on the training split and\n",
    "    apply it to all three splits; train with Adam on the MSE of the normalized\n",
    "    target; reload the weights of the epoch with the lowest validation MSE; save\n",
    "    the run with ``_save_run``.\n",
    "\n",
    "    Args:\n",
    "        dataset_type: ``mdvsp``, ``assignment`` or ``quadratic`` (lower-cased before\n",
    "            comparison).\n",
    "        dataset_params: Keyword arguments for the dataset generator; copied, and\n",
    "            ``seed`` defaults to the training seed when absent.\n",
    "        model_type: ``DFN``, ``MLP``, ``MaxAffine`` or ``LSET``.\n",
    "        model_params: Keyword arguments for the model constructor; ``alpha`` and\n",
    "            ``beta`` are dropped for every model type other than ``DFN``.\n",
    "        train_params: Optional dict. Keys read here, with defaults: epochs=200,\n",
    "            batch_size=32, lr=1e-3, weight_decay=0.0, val_frac=0.15,\n",
    "            test_frac=0.15, eps=1e-8, seed=0, device=cpu, plot_every=0 (0 disables\n",
    "            the live plot), plot_points=2048, plot_chunk=4096, print_stats=True,\n",
    "            cache_root=_CACHE_ROOT_DEFAULT.\n",
    "\n",
    "    Returns:\n",
    "        ``(model, data, history, spec)``. On a cache hit the model is rebuilt with\n",
    "        ``_build_model_from_params``, loaded from the stored state dict and moved\n",
    "        to ``device``; the other three values come from the stored artifact.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``dataset_type`` or ``model_type`` is not one of the values\n",
    "            listed above.\n",
    "    \"\"\"\n",
    "    \n",
    "    tp = train_params or {}\n",
    "\n",
    "    epochs     = int(tp.get(\"epochs\", 200))\n",
    "    batch_sz   = int(tp.get(\"batch_size\", 32))\n",
    "    lr         = float(tp.get(\"lr\", 1e-3))\n",
    "    wd         = float(tp.get(\"weight_decay\", 0.0))\n",
    "    val_frac   = float(tp.get(\"val_frac\", 0.15))\n",
    "    test_frac  = float(tp.get(\"test_frac\", 0.15))\n",
    "    eps        = float(tp.get(\"eps\", 1e-8))\n",
    "    seed       = int(tp.get(\"seed\", 0))\n",
    "    device     = tp.get(\"device\", \"cpu\")\n",
    "    plot_every = int(tp.get(\"plot_every\", 0) or 0)\n",
    "    plot_points= int(tp.get(\"plot_points\", 2048))\n",
    "    plot_chunk = int(tp.get(\"plot_chunk\", 4096))\n",
    "    print_stats= bool(tp.get(\"print_stats\", True))\n",
    "\n",
    "\n",
    "    # ---- determinism (seed controls data, model init, dataloader order, etc.) ----\n",
    "    _set_full_determinism(seed)\n",
    "    # dataset seed: caller may set it; otherwise default to the training seed\n",
    "    dataset_params = dict(dataset_params)\n",
    "    dataset_params.setdefault(\"seed\", int(seed))\n",
    "\n",
    "\n",
    "    # ---- cache check (skip training if identical run already saved) ----\n",
    "    # Only settings listed here enter the signature; device and the plot/print\n",
    "    # options are not part of it.\n",
    "    train_sig = {\n",
    "        \"epochs\": int(epochs),\n",
    "        \"batch_size\": int(batch_sz),\n",
    "        \"lr\": float(lr),\n",
    "        \"weight_decay\": float(wd),\n",
    "        \"val_frac\": float(val_frac),\n",
    "        \"test_frac\": float(test_frac),\n",
    "        \"eps\": float(eps),\n",
    "        \"seed\": int(seed),\n",
    "        \"deterministic\": True,\n",
    "    }\n",
    "    cache_root = Path(tp.get(\"cache_root\", _CACHE_ROOT_DEFAULT))\n",
    "    signature = _run_signature(dataset_type, dataset_params, model_type, model_params, train_sig)\n",
    "    run_id = _run_id_from_signature(signature)\n",
    "    run_dir = _get_run_dir(cache_root, dataset_type, model_type, run_id)\n",
    "\n",
    "    cached = _try_load_cached_run(run_dir)\n",
    "    # A cached artifact is used only if it carries all four expected entries.\n",
    "    if cached is not None and isinstance(cached, dict) and \"state_dict\" in cached and \"data\" in cached and \"history\" in cached and \"spec\" in cached:\n",
    "        model = _build_model_from_params(model_type, model_params)\n",
    "        model.load_state_dict(cached[\"state_dict\"])\n",
    "        model = model.to(device)\n",
    "        return model, cached[\"data\"], cached[\"history\"], cached[\"spec\"]\n",
    "\n",
    "    # ---- data ----\n",
    "    dt = str(dataset_type).lower()\n",
    "    if dt == \"mdvsp\":\n",
    "        X, y, gt = make_mdvsp_dataset(**dataset_params)\n",
    "    elif dt == \"assignment\":\n",
    "        X, y, gt = generate_bipartite_subset_matching_dataset(**dataset_params)\n",
    "    elif dt == \"quadratic\":\n",
    "        X, y, gt = generate_convex_quadratic_dataset(**dataset_params)\n",
    "    else:\n",
    "        raise ValueError(\"dataset_type must be: mdvsp | assignment | quadratic\")\n",
    "\n",
    "    X = torch.as_tensor(X, dtype=torch.float32)\n",
    "    y = torch.as_tensor(y, dtype=torch.float32).view(-1)\n",
    "\n",
    "    # ---- print a quick dataset stats ----\n",
    "    if print_stats:\n",
    "        with torch.no_grad():\n",
    "            xmn = X.mean(0)\n",
    "            xsd = X.std(0, unbiased=False)\n",
    "            print(\n",
    "                f\"\\n--- Dataset stats ({dataset_type}) ---\\n\"\n",
    "                f\"  X: shape={tuple(X.shape)}  mean(mean)={xmn.mean():.3g}  std(mean)={xsd.mean():.3g}  \"\n",
    "                f\"min={float(X.min()):.3g}  max={float(X.max()):.3g}\\n\"\n",
    "                f\"  y: shape={tuple(y.shape)}  mean={float(y.mean()):.3g}  std={float(y.std(unbiased=False)):.3g}  \"\n",
    "                f\"min={float(y.min()):.3g}  max={float(y.max()):.3g}\\n\"\n",
    "            )\n",
    "\n",
    "    (Xtr, ytr), (Xva, yva), (Xte, yte) = _split_train_val_test(X, y, val_frac=val_frac, test_frac=test_frac, seed=seed)\n",
    "\n",
    "    # The scaler is fitted on the training split only and reused for val/test.\n",
    "    scaler = _fit_standardizer(Xtr, ytr, eps=eps)\n",
    "    XtrN, ytrN = _apply_standardizer(Xtr, ytr, scaler)\n",
    "    XvaN, yvaN = _apply_standardizer(Xva, yva, scaler)\n",
    "    XteN, yteN = _apply_standardizer(Xte, yte, scaler)\n",
    "\n",
    "    # Dedicated generator so the shuffle order depends only on the seed.\n",
    "    g_dl = torch.Generator()\n",
    "    g_dl.manual_seed(seed)\n",
    "    train_loader = DataLoader(\n",
    "        TensorDataset(XtrN, ytrN),\n",
    "        batch_size=batch_sz,\n",
    "        shuffle=True,\n",
    "        generator=g_dl,\n",
    "        worker_init_fn=_seed_worker,\n",
    "    )\n",
    "    val_loader   = DataLoader(TensorDataset(XvaN, yvaN), batch_size=batch_sz, shuffle=False)\n",
    "\n",
    "    # subset for plotting\n",
    "    Nv = int(XvaN.shape[0])\n",
    "    if plot_points <= 0 or plot_points >= Nv:\n",
    "        plot_idx = torch.arange(Nv)\n",
    "    else:\n",
    "        g_plot = torch.Generator().manual_seed(seed + 12345)\n",
    "        plot_idx = torch.randperm(Nv, generator=g_plot)[:plot_points]\n",
    "\n",
    "    # ---- model ----\n",
    "    mt = str(model_type)\n",
    "    mp = dict(model_params)\n",
    "\n",
    "    if mt == \"DFN\":\n",
    "        model = DFN(**mp)\n",
    "        extra = f\"layers={mp.get('layer_sizes')} p_list={mp.get('p_list')} alpha={getattr(model,'alpha',None)} beta={getattr(model,'beta',None)}\"\n",
    "    else:\n",
    "        # alpha/beta are passed only to DFN; strip them for the other models\n",
    "        mp.pop(\"alpha\", None)\n",
    "        mp.pop(\"beta\", None)\n",
    "        if mt == \"MLP\":\n",
    "            model = MLP(**mp)\n",
    "            extra = f\"hidden={mp.get('hidden_dims')}\"\n",
    "        elif mt == \"MaxAffine\":\n",
    "            model = MaxAffine(**mp)\n",
    "            extra = f\"n_pieces={mp.get('n_pieces')}\"\n",
    "        elif mt == \"LSET\":\n",
    "            model = LSET(**mp)\n",
    "            extra = f\"n_pieces={mp.get('n_pieces')} T={mp.get('T')}\"\n",
    "        else:\n",
    "            raise ValueError(\"model_type must be: DFN | MLP | MaxAffine | LSET\")\n",
    "\n",
    "    model = model.to(device)\n",
    "    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)\n",
    "\n",
    "    n_params = sum(p.numel() for p in model.parameters())\n",
    "    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "    print(\n",
    "        f\"\\n=== Run: {dataset_type} | {model_type} ===\\n\"\n",
    "        f\"  data: N={len(X)}  train/val/test={len(Xtr)}/{len(Xva)}/{len(Xte)}  dim={X.shape[1]}\\n\"\n",
    "        f\"  model: params={n_params:,} {extra}\\n\"\n",
    "        f\"  train: device={device}  epochs={epochs}  batch={batch_sz}  lr={lr:g}  wd={wd:g}  seed={seed}\\n\"\n",
    "    )\n",
    "\n",
    "    history = {\"train_mse_norm\": [], \"val_mse_norm\": []}\n",
    "    best_val, best_ep, best_state = float(\"inf\"), 0, None\n",
    "\n",
    "    # Display handle that is updated in place by the live plot below.\n",
    "    live = display(None, display_id=True) if plot_every > 0 else None\n",
    "\n",
    "    for ep in range(1, epochs + 1):\n",
    "        model.train()\n",
    "        for xb, yb in train_loader:\n",
    "            xb, yb = xb.to(device), yb.to(device)\n",
    "            loss = F.mse_loss(model(xb).squeeze(-1), yb)\n",
    "            opt.zero_grad(set_to_none=True)\n",
    "            loss.backward()\n",
    "            opt.step()\n",
    "\n",
    "        # End-of-epoch metrics over the full train and validation loaders.\n",
    "        tr_mse = _mse_norm(model, train_loader, device)\n",
    "        va_mse = _mse_norm(model, val_loader, device)\n",
    "        history[\"train_mse_norm\"].append(tr_mse)\n",
    "        history[\"val_mse_norm\"].append(va_mse)\n",
    "\n",
    "        # Keep a CPU copy of the weights with the lowest validation MSE so far.\n",
    "        if va_mse < best_val:\n",
    "            best_val, best_ep = va_mse, ep\n",
    "            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}\n",
    "\n",
    "        # Live plot on the first epoch, every plot_every epochs, and the last epoch.\n",
    "        if plot_every > 0 and (ep == 1 or ep % plot_every == 0 or ep == epochs):\n",
    "            model.eval()\n",
    "            x_plot = XvaN[plot_idx].to(device)\n",
    "            y_true = yvaN[plot_idx].to(device)\n",
    "            y_pred = _predict_in_chunks(model, x_plot, chunk=plot_chunk)\n",
    "\n",
    "            fig, ax = plt.subplots(1, 2, figsize=(10, 4))\n",
    "            ax[0].plot(history[\"train_mse_norm\"], label=\"train\")\n",
    "            ax[0].plot(history[\"val_mse_norm\"], label=\"val\")\n",
    "            ax[0].set_yscale(\"log\")\n",
    "            ax[0].set_title(f\"Epoch {ep}/{epochs} | val MSE={va_mse:.3e} (norm)\")\n",
    "            ax[0].legend()\n",
    "\n",
    "            yt = y_true.detach().cpu().numpy()\n",
    "            yp = y_pred.detach().cpu().numpy()\n",
    "            ax[1].scatter(yt, yp, s=10)\n",
    "            lo = float(min(yt.min(), yp.min()))\n",
    "            hi = float(max(yt.max(), yp.max()))\n",
    "            ax[1].plot([lo, hi], [lo, hi])\n",
    "            ax[1].set_xlabel(\"y_true (norm)\")\n",
    "            ax[1].set_ylabel(\"y_pred (norm)\")\n",
    "            ax[1].set_title(f\"Val scatter (n={len(yt)})\")\n",
    "\n",
    "            plt.tight_layout()\n",
    "            live.update(fig)\n",
    "            plt.close(fig)\n",
    "\n",
    "    # best_state stays None if no epoch ever improved on the initial inf.\n",
    "    if best_state is not None:\n",
    "        model.load_state_dict(best_state)\n",
    "\n",
    "    if print_stats:\n",
    "        print(f\"[DONE] best val MSE (norm) = {best_val:.3e} @ epoch {best_ep}\\n\")\n",
    "\n",
    "    data = {\n",
    "        \"raw\":  {\"Xtr\": Xtr,  \"ytr\": ytr,  \"Xva\": Xva,  \"yva\": yva,  \"Xte\": Xte,  \"yte\": yte},\n",
    "        \"norm\": {\"Xtr\": XtrN, \"ytr\": ytrN, \"Xva\": XvaN, \"yva\": yvaN, \"Xte\": XteN, \"yte\": yteN},\n",
    "        \"scaler\": scaler,\n",
    "        \"best_val_mse_norm\": float(best_val),\n",
    "        \"best_epoch\": int(best_ep),\n",
    "        \"stats\": {\"n_params\": int(n_params), \"n_trainable\": int(n_trainable)},\n",
    "        \"true\": gt,\n",
    "        \"device\": device,\n",
    "    }\n",
    "\n",
    "    spec = {\n",
    "        \"model\": model_type,\n",
    "        \"extra\": extra,\n",
    "        \"n_params\": int(n_params),\n",
    "        \"n_trainable\": int(n_trainable),\n",
    "        \"model_params\": dict(model_params),\n",
    "        \"train_params\": {\n",
    "            \"epochs\": int(epochs),\n",
    "            \"batch_size\": int(batch_sz),\n",
    "            \"lr\": float(lr),\n",
    "            \"weight_decay\": float(wd),\n",
    "            \"seed\": int(seed),\n",
    "            \"device\": str(device),\n",
    "            \"val_frac\": float(val_frac),\n",
    "            \"test_frac\": float(test_frac),\n",
    "        },\n",
    "    }\n",
    "\n",
    "\n",
    "    # ---- save artifacts for this run ----\n",
    "\n",
    "    _save_run(run_dir, signature, model, data, history, spec)\n",
    "\n",
    "\n",
    "    return model, data, history, spec\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ace7f63f",
   "metadata": {},
   "source": [
    "## Optimization Helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "38416530",
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_true_obj(gt, x):\n",
    "    \"\"\"Evaluate the ground-truth objective described by ``gt`` at the point ``x``.\n",
    "\n",
    "    ``gt`` is a dict whose ``type`` entry selects the evaluator:\n",
    "\n",
    "    * ``quadratic``: ``(x - x_star)^T Q (x - x_star)``.\n",
    "    * ``assignment``: cost of ``linear_sum_assignment`` on the rows of ``C`` whose\n",
    "      entry in ``x`` exceeds 0.5; ``0.0`` when no row is selected.\n",
    "    * ``mdvsp``: ``x`` overwrites the capacities at ``idxSS`` and ``idxTT``; the max\n",
    "      flow value from ``SS`` to ``TT`` is then routed at minimum cost with\n",
    "      ``lemon_mcf`` and the total cost is returned.\n",
    "\n",
    "    Returns ``np.nan`` when ``gt`` is not a dict or its type is not recognised.\n",
    "    \"\"\"\n",
    "    if not isinstance(gt, dict):\n",
    "        return np.nan\n",
    "\n",
    "    point = np.asarray(x).reshape(-1)\n",
    "    kind = gt.get(\"type\", None)\n",
    "\n",
    "    if kind == \"quadratic\":\n",
    "        quad = np.asarray(gt[\"Q\"], float)\n",
    "        center = np.asarray(gt[\"x_star\"], float).reshape(-1)\n",
    "        offset = point.astype(float) - center\n",
    "        return float(offset @ quad @ offset)\n",
    "\n",
    "    elif kind == \"assignment\":\n",
    "        costs = np.asarray(gt[\"C\"], float)\n",
    "        chosen = np.flatnonzero(point > 0.5)\n",
    "        if chosen.size == 0:\n",
    "            return 0.0\n",
    "        # restrict the cost matrix to the selected rows, then match them\n",
    "        sub_costs = costs[chosen, :]\n",
    "        rows, cols = linear_sum_assignment(sub_costs)\n",
    "        return float(sub_costs[rows, cols].sum())\n",
    "\n",
    "    elif kind == \"mdvsp\":\n",
    "        n_nodes, super_src, super_dst = int(gt[\"N\"]), int(gt[\"SS\"]), int(gt[\"TT\"])\n",
    "        tails = np.asarray(gt[\"src\"], np.int64)\n",
    "        heads = np.asarray(gt[\"dst\"], np.int64)\n",
    "        arc_cost = np.asarray(gt[\"cost\"], float)\n",
    "        base_cap = np.asarray(gt[\"cap0\"], float)\n",
    "        arcs_ss = np.asarray(gt[\"idxSS\"], np.int64)\n",
    "        arcs_tt = np.asarray(gt[\"idxTT\"], np.int64)\n",
    "\n",
    "        # x sets the capacities at both the idxSS and the idxTT positions\n",
    "        arc_cap = base_cap.copy()\n",
    "        x_float = point.astype(float)\n",
    "        arc_cap[arcs_ss] = x_float\n",
    "        arc_cap[arcs_tt] = x_float\n",
    "\n",
    "        max_flow = lemon_mcf.max_flow(n_nodes, tails, heads, arc_cap, super_src, super_dst)[\"value\"]\n",
    "        node_supply = np.zeros(n_nodes, float)\n",
    "        node_supply[super_src] = max_flow\n",
    "        node_supply[super_dst] = -max_flow\n",
    "        result = lemon_mcf.solve_mcf(n_nodes, tails, heads, arc_cost, arc_cap, node_supply)\n",
    "        return float(result[\"total_cost\"])\n",
    "\n",
    "    return np.nan\n",
    "\n",
    "\n",
    "def local_search_l1_int(\n",
    "    f,\n",
    "    x0,\n",
    "    x_min,\n",
    "    x_max,\n",
    "    delta: int,\n",
    "    sum_eq=None,\n",
    "    max_iters: int = 10_000,\n",
    "    print_every: int = 1,\n",
    "):\n",
    "    \"\"\"Greedy best-improvement local search over integer points.\n",
    "\n",
    "    Starting from ``x0``, every iteration enumerates all integer points ``z`` with\n",
    "    ``1 <= ||z - x||_1 <= delta`` that satisfy the box (and optional sum)\n",
    "    constraint, evaluates them in one batched call to ``f`` and moves to the best\n",
    "    one if it is strictly better than the incumbent. The search stops at a local\n",
    "    minimum, when no feasible neighbor exists, or after ``max_iters`` iterations.\n",
    "\n",
    "    NaN objective values are treated as ``+inf``: NaN compares false against\n",
    "    everything, so otherwise a NaN neighbor would be picked by ``argmin`` and then\n",
    "    accepted as an improvement.\n",
    "\n",
    "    Args:\n",
    "        f: Batched objective; called with an ``(m, n)`` integer array and expected\n",
    "            to return ``m`` values.\n",
    "        x0: Feasible integer start point of length ``n``.\n",
    "        x_min, x_max: Inclusive per-coordinate bounds.\n",
    "        delta: Radius of the L1 neighborhood.\n",
    "        sum_eq: If given, every visited point must sum to this value.\n",
    "        max_iters: Maximum number of iterations.\n",
    "        print_every: Progress is printed every this many iterations (at least 1).\n",
    "\n",
    "    Returns:\n",
    "        ``(x, y, hist)``: final point, its objective value, and a list of dicts\n",
    "        with keys ``iter``, ``t``, ``best_y`` and ``x`` (one per accepted point).\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: If the bounds have the wrong length or ``x0`` is infeasible.\n",
    "    \"\"\"\n",
    "    x  = np.asarray(x0,    int).ravel()\n",
    "    lo = np.asarray(x_min, int).ravel()\n",
    "    hi = np.asarray(x_max, int).ravel()\n",
    "    n = int(x.size)\n",
    "    delta = int(delta)\n",
    "\n",
    "    assert lo.size == n and hi.size == n\n",
    "    assert np.all(x >= lo) and np.all(x <= hi)\n",
    "    if sum_eq is not None:\n",
    "        sum_eq = int(sum_eq)\n",
    "        assert int(x.sum()) == sum_eq\n",
    "\n",
    "    def eval_batch(X):\n",
    "        y = np.asarray(f(X), float).reshape(-1)\n",
    "        # NaN -> +inf so that a failed evaluation can never win the argmin below.\n",
    "        return np.where(np.isnan(y), np.inf, y)\n",
    "\n",
    "    def ok(z):\n",
    "        if np.any(z < lo) or np.any(z > hi):\n",
    "            return False\n",
    "        if sum_eq is not None and int(z.sum()) != sum_eq:\n",
    "            return False\n",
    "        return True\n",
    "\n",
    "    # integer deltas with exact L1 = k\n",
    "    def deltas_exact(k):\n",
    "        d = np.zeros(n, int)\n",
    "\n",
    "        def rec(i, rem):\n",
    "            if i == n:\n",
    "                if rem == 0:\n",
    "                    yield d.copy()\n",
    "                return\n",
    "            for t in range(rem + 1):\n",
    "                if t == 0:\n",
    "                    d[i] = 0\n",
    "                    yield from rec(i + 1, rem)\n",
    "                else:\n",
    "                    d[i] = +t\n",
    "                    yield from rec(i + 1, rem - t)\n",
    "                    d[i] = -t\n",
    "                    yield from rec(i + 1, rem - t)\n",
    "            d[i] = 0\n",
    "\n",
    "        yield from rec(0, k)\n",
    "\n",
    "    # With a sum constraint every feasible step sums to zero, so its L1 norm is\n",
    "    # even; odd radii cannot produce candidates and are skipped (same candidate\n",
    "    # set and order, less enumeration).\n",
    "    if sum_eq is not None:\n",
    "        radii = range(2, delta + 1, 2)\n",
    "    else:\n",
    "        radii = range(1, delta + 1)\n",
    "\n",
    "    t0 = time.perf_counter()\n",
    "    y = float(eval_batch(x[None, :])[0])\n",
    "    hist = [{\"iter\": 0, \"t\": 0.0, \"best_y\": y, \"x\": x.copy()}]\n",
    "    print(f\"iter=0  t=0.00s  best_y={y:.6g}  x={x.tolist()}\")\n",
    "\n",
    "    for it in range(1, int(max_iters) + 1):\n",
    "        cand = []\n",
    "        for k in radii:\n",
    "            for dlt in deltas_exact(k):\n",
    "                z = x + dlt\n",
    "                if ok(z):\n",
    "                    cand.append(z)\n",
    "\n",
    "        if not cand:\n",
    "            print(f\"STOP: no feasible neighbors. best_y={y:.6g} x={x.tolist()}\")\n",
    "            break\n",
    "\n",
    "        Y = eval_batch(np.stack(cand, 0))\n",
    "        j = int(np.argmin(Y))\n",
    "        # strict improvement required; ties keep the incumbent\n",
    "        if float(Y[j]) >= y:\n",
    "            print(f\"STOP: local minimum. best_y={y:.6g} x={x.tolist()}\")\n",
    "            break\n",
    "\n",
    "        x, y = cand[j], float(Y[j])\n",
    "        hist.append({\"iter\": it, \"t\": time.perf_counter() - t0, \"best_y\": y, \"x\": x.copy()})\n",
    "        if it % max(1, int(print_every)) == 0:\n",
    "            print(f\"iter={it}  t={hist[-1]['t']:.2f}s  best_y={y:.6g}  x={x.tolist()}\")\n",
    "\n",
    "    return x, y, hist\n",
    "\n",
    "\n",
    "def _scaler_np(scaler):\n",
    "    \"\"\"Convert a scaler dict of torch tensors to NumPy arrays and Python floats.\n",
    "\n",
    "    Returns ``(x_mean, x_std, y_mean, y_std)``: the two ``x`` entries as flattened\n",
    "    NumPy arrays, the two ``y`` entries as floats.\n",
    "    \"\"\"\n",
    "    def _flat(key):\n",
    "        return scaler[key].detach().cpu().numpy().reshape(-1)\n",
    "\n",
    "    def _scalar(key):\n",
    "        return float(scaler[key].detach().cpu())\n",
    "\n",
    "    x_mean_np = _flat(\"x_mean\")\n",
    "    x_std_np = _flat(\"x_std\")\n",
    "    y_mean_f = _scalar(\"y_mean\")\n",
    "    y_std_f = _scalar(\"y_std\")\n",
    "    return x_mean_np, x_std_np, y_mean_f, y_std_f\n",
    "\n",
    "\n",
    "def solve_dfn_ip_gurobi(dfn, scaler, x_min, x_max, sum_eq, *, integer_x=True, verbose=False, time_limit=None, seed=None):\n",
    "    \"\"\"Minimize a trained DFN surrogate over box- and sum-constrained inputs with Gurobi.\n",
    "\n",
    "    The model contains the decision vector ``x``, one continuous flow variable per\n",
    "    arc bounded by its capacity, and a flow-conservation constraint per node whose\n",
    "    right-hand side is affine in the raw ``x`` (the input standardization is folded\n",
    "    into the coefficients). The objective is\n",
    "    ``(alpha * sum(cost * f) + beta) * y_std + y_mean``.\n",
    "\n",
    "    Args:\n",
    "        dfn: Object exposing ``cost_raw``, ``cost_fixed``, ``cap_raw``, ``cap_fixed``,\n",
    "            ``A``, ``b_raw``, ``src``, ``dst``, ``boundary``, ``fix_node``, ``n``,\n",
    "            ``alpha`` and ``beta``.\n",
    "        scaler: Dict with tensors ``x_mean``, ``x_std``, ``y_mean``, ``y_std``.\n",
    "        x_min, x_max: Per-coordinate bounds on ``x``.\n",
    "        sum_eq: Required value of ``sum(x)``.\n",
    "        integer_x: Integer (True) or continuous (False) ``x`` variables.\n",
    "        verbose: Enables Gurobi output.\n",
    "        time_limit: Optional value for Gurobi's ``TimeLimit`` parameter.\n",
    "        seed: Optional value for Gurobi's ``Seed`` parameter.\n",
    "\n",
    "    Returns:\n",
    "        ``(x_star, objective_value, info)`` where ``info`` holds ``status``,\n",
    "        ``runtime`` and ``gap``.\n",
    "\n",
    "    Raises:\n",
    "        RuntimeError: If Gurobi finishes without any solution.\n",
    "    \"\"\"\n",
    "    import gurobipy as gp\n",
    "    from gurobipy import GRB\n",
    "\n",
    "    x_min = np.asarray(x_min, float).ravel()\n",
    "    x_max = np.asarray(x_max, float).ravel()\n",
    "    d = int(x_min.size)\n",
    "    sum_eq = float(sum_eq)\n",
    "\n",
    "    x_mean, x_std, y_mean, y_std = _scaler_np(scaler)\n",
    "\n",
    "    # Arc data: learned entries first, fixed entries after them.\n",
    "    # Learned capacities are softplus-transformed and rounded.\n",
    "    cost = np.r_[dfn.cost_raw.detach().cpu().numpy(), dfn.cost_fixed.detach().cpu().numpy()]\n",
    "    cap  = np.r_[torch.round(F.softplus(dfn.cap_raw.detach())).cpu().numpy(), dfn.cap_fixed.detach().cpu().numpy()]\n",
    "\n",
    "    # A is rounded only when it is a trainable Parameter; b_raw is always rounded.\n",
    "    A = dfn.A.detach().cpu().numpy()\n",
    "    if isinstance(dfn.A, torch.nn.Parameter):\n",
    "        A = np.round(A)\n",
    "    b = np.round(dfn.b_raw.detach().cpu().numpy())\n",
    "\n",
    "    src = dfn.src.detach().cpu().numpy().astype(int)\n",
    "    dst = dfn.dst.detach().cpu().numpy().astype(int)\n",
    "    boundary = dfn.boundary.detach().cpu().numpy().astype(int)\n",
    "    fix = int(dfn.fix_node)\n",
    "    n = int(dfn.n)\n",
    "    m = int(src.size)\n",
    "    alpha = float(dfn.alpha)\n",
    "    beta  = float(dfn.beta)\n",
    "\n",
    "    # Outgoing / incoming arc indices per node.\n",
    "    out = [[] for _ in range(n)]\n",
    "    inn = [[] for _ in range(n)]\n",
    "    for e in range(m):\n",
    "        out[src[e]].append(e)\n",
    "        inn[dst[e]].append(e)\n",
    "\n",
    "    M = gp.Model(\"DFN_IP\")\n",
    "    M.Params.OutputFlag = 1 if verbose else 0\n",
    "    # Make the IP solve reproducible across runs\n",
    "    try:\n",
    "        M.Params.Threads = 1\n",
    "    except Exception:\n",
    "        pass\n",
    "    if seed is not None:\n",
    "        try:\n",
    "            M.Params.Seed = int(seed)\n",
    "        except Exception:\n",
    "            pass\n",
    "    if time_limit is not None:\n",
    "        M.Params.TimeLimit = float(time_limit)\n",
    "\n",
    "    xt = GRB.INTEGER if integer_x else GRB.CONTINUOUS\n",
    "    x = M.addVars(d, lb=x_min.tolist(), ub=x_max.tolist(), vtype=xt, name=\"x\")\n",
    "    f = M.addVars(m, lb=0.0, ub=cap.tolist(), vtype=GRB.CONTINUOUS, name=\"f\")\n",
    "    M.addConstr(gp.quicksum(x[i] for i in range(d)) == sum_eq)\n",
    "\n",
    "    # Supply at boundary node v (row r): b[r] + A[r] @ ((x - x_mean) / x_std),\n",
    "    # expanded into a constant plus a linear expression in the raw x.\n",
    "    xm_over_xs = x_mean / x_std\n",
    "    s = [0] * n\n",
    "    s_boundary = []\n",
    "    for r, v in enumerate(boundary):\n",
    "        const = float(b[r] - (A[r] * xm_over_xs).sum())\n",
    "        expr = const + gp.quicksum((A[r, j] / x_std[j]) * x[j] for j in range(d) if A[r, j] != 0)\n",
    "        s[v] = expr\n",
    "        s_boundary.append(expr)\n",
    "    # The fix node takes the negated total so that all supplies sum to zero.\n",
    "    s[fix] = -gp.quicksum(s_boundary)\n",
    "\n",
    "    # Flow conservation: outflow minus inflow equals the node supply.\n",
    "    for v in range(n):\n",
    "        M.addConstr(gp.quicksum(f[e] for e in out[v]) - gp.quicksum(f[e] for e in inn[v]) == s[v])\n",
    "\n",
    "    flow_cost = gp.quicksum(cost[e] * f[e] for e in range(m))\n",
    "    # Objective expressed on the raw (de-standardized) target scale.\n",
    "    M.setObjective((alpha * flow_cost + beta) * y_std + y_mean, GRB.MINIMIZE)\n",
    "\n",
    "    M.optimize()\n",
    "    if M.SolCount == 0:\n",
    "        raise RuntimeError(f\"No solution (Gurobi status {M.Status})\")\n",
    "\n",
    "    x_star = np.array([x[i].X for i in range(d)], float)\n",
    "    info = {\"status\": M.Status, \"runtime\": M.Runtime, \"gap\": getattr(M, \"MIPGap\", None)}\n",
    "    return x_star, float(M.ObjVal), info\n",
    "\n",
    "\n",
    "def solve_mlp_ip_gurobi(model, scaler, x_min, x_max, sum_eq, *, integer_x=True, verbose=False, time_limit=None, seed=None):\n",
    "    \"\"\"Minimize a trained MLP surrogate over box- and sum-constrained inputs with Gurobi.\n",
    "\n",
    "    Every ``nn.Linear`` found in ``base.net`` becomes an affine equality. The output\n",
    "    of each Linear except the last is passed through a big-M ReLU encoding, with one\n",
    "    binary variable per unit whose pre-activation bounds straddle zero. The input\n",
    "    standardization is folded into the first layer and the optional output wrapper\n",
    "    ``(a, b)`` into the last one, so the MIP works directly on the raw ``x``.\n",
    "\n",
    "    NOTE(review): only the Linear layers are extracted, so this encoding assumes a\n",
    "    ReLU follows every Linear except the last and that every Linear has a bias -\n",
    "    confirm against the MLP definition.\n",
    "\n",
    "    Args:\n",
    "        model: Either an object with a ``net`` attribute, or a wrapper exposing\n",
    "            ``base``, ``a`` and ``b``.\n",
    "        scaler: Dict with tensors ``x_mean``, ``x_std``, ``y_mean``, ``y_std``.\n",
    "        x_min, x_max: Per-coordinate bounds on ``x``.\n",
    "        sum_eq: Required value of ``sum(x)``.\n",
    "        integer_x: Integer (True) or continuous (False) ``x`` variables.\n",
    "        verbose: Enables Gurobi output.\n",
    "        time_limit: Optional value for Gurobi's ``TimeLimit`` parameter.\n",
    "        seed: Optional value for Gurobi's ``Seed`` parameter.\n",
    "\n",
    "    Returns:\n",
    "        ``(x_star, objective_value, info)`` where ``info`` holds ``status``,\n",
    "        ``runtime``, ``gap`` and ``sol_count``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``base`` has no ``net``, ``net`` has no Linear layer, or the\n",
    "            last Linear does not have exactly one output.\n",
    "        RuntimeError: If Gurobi finishes without any solution.\n",
    "    \"\"\"\n",
    "    import gurobipy as gp\n",
    "    from gurobipy import GRB\n",
    "\n",
    "    x_min = np.asarray(x_min, float).ravel()\n",
    "    x_max = np.asarray(x_max, float).ravel()\n",
    "    d = int(x_min.size)\n",
    "    sum_eq = float(sum_eq)\n",
    "\n",
    "    xm, xs, ym, ys = _scaler_np(scaler)\n",
    "\n",
    "    # Unwrap an optional (base, a, b) wrapper; a and b are applied to the last layer below.\n",
    "    base, a_out, b_out = model, 1.0, 0.0\n",
    "    if hasattr(model, \"base\") and hasattr(model, \"a\") and hasattr(model, \"b\"):\n",
    "        base, a_out, b_out = model.base, float(model.a), float(model.b)\n",
    "\n",
    "    if not hasattr(base, \"net\"):\n",
    "        raise ValueError(\"Expected an MLP with attribute .net (nn.Sequential).\")\n",
    "    linears = [L for L in base.net if isinstance(L, torch.nn.Linear)]\n",
    "    if not linears:\n",
    "        raise ValueError(\"No Linear layers found in base.net\")\n",
    "\n",
    "    W = [L.weight.detach().cpu().numpy().astype(float) for L in linears]\n",
    "    b = [L.bias.detach().cpu().numpy().astype(float) for L in linears]\n",
    "\n",
    "    # Fold (x - xm) / xs into the first layer; b[0] uses the already rescaled W[0].\n",
    "    W[0] = W[0] / xs[None, :]\n",
    "    b[0] = b[0] - W[0] @ xm\n",
    "\n",
    "    # Fold the output wrapper a * y + b into the last layer.\n",
    "    W[-1] *= a_out\n",
    "    b[-1] = a_out * b[-1] + b_out\n",
    "\n",
    "    # Symmetric pre-activation bounds [-U, U] per hidden layer: U = |W| @ u + |b|,\n",
    "    # where u is an elementwise magnitude bound on the layer input.\n",
    "    u = np.maximum(np.abs(x_min), np.abs(x_max))\n",
    "    preLU = []\n",
    "    for k in range(len(W) - 1):\n",
    "        U = np.abs(W[k]) @ u + np.abs(b[k])\n",
    "        preLU.append((-U, U))\n",
    "        u = np.maximum(0.0, U)\n",
    "\n",
    "    M = gp.Model(\"MLP_IP\")\n",
    "    M.Params.OutputFlag = 1 if verbose else 0\n",
    "    # Make the IP solve reproducible across runs\n",
    "    try:\n",
    "        M.Params.Threads = 1\n",
    "    except Exception:\n",
    "        pass\n",
    "    if seed is not None:\n",
    "        try:\n",
    "            M.Params.Seed = int(seed)\n",
    "        except Exception:\n",
    "            pass\n",
    "    if time_limit is not None:\n",
    "        M.Params.TimeLimit = float(time_limit)\n",
    "\n",
    "    xt = GRB.INTEGER if integer_x else GRB.CONTINUOUS\n",
    "    x = M.addVars(d, lb=x_min.tolist(), ub=x_max.tolist(), vtype=xt, name=\"x\")\n",
    "    M.addConstr(gp.quicksum(x[i] for i in range(d)) == sum_eq, name=\"sum_eq\")\n",
    "\n",
    "    # prev holds the variables feeding the current layer (x first, then ReLU outputs).\n",
    "    prev = [x[i] for i in range(d)]\n",
    "\n",
    "    for k in range(len(W) - 1):\n",
    "        Lk, Uk = preLU[k]\n",
    "        h = W[k].shape[0]\n",
    "\n",
    "        # a = pre-activation, z = ReLU output\n",
    "        a = [M.addVar(lb=float(Lk[j]), ub=float(Uk[j]), name=f\"a{k}_{j}\") for j in range(h)]\n",
    "        z = [M.addVar(lb=0.0, ub=float(max(0.0, Uk[j])), name=f\"z{k}_{j}\") for j in range(h)]\n",
    "\n",
    "        for j in range(h):\n",
    "            M.addConstr(a[j] == b[k][j] + gp.quicksum(W[k][j, i] * prev[i] for i in range(len(prev))))\n",
    "\n",
    "            Lj, Uj = float(Lk[j]), float(Uk[j])\n",
    "            if Uj <= 0.0:\n",
    "                # unit can never be active\n",
    "                M.addConstr(z[j] == 0.0)\n",
    "            elif Lj >= 0.0:\n",
    "                # unit is always active\n",
    "                M.addConstr(z[j] == a[j])\n",
    "            else:\n",
    "                # big-M ReLU: s = 1 forces z = a, s = 0 forces z = 0 (and a <= 0)\n",
    "                s = M.addVar(vtype=GRB.BINARY, name=f\"s{k}_{j}\")\n",
    "                M.addConstr(z[j] >= a[j])\n",
    "                M.addConstr(z[j] >= 0.0)\n",
    "                M.addConstr(z[j] <= Uj * s)\n",
    "                M.addConstr(z[j] <= a[j] - Lj * (1 - s))\n",
    "\n",
    "        prev = z\n",
    "\n",
    "    if W[-1].shape[0] != 1:\n",
    "        raise ValueError(f\"Expected scalar output, got out_dim={W[-1].shape[0]}\")\n",
    "\n",
    "    y_norm = M.addVar(lb=-GRB.INFINITY, vtype=GRB.CONTINUOUS, name=\"y_norm\")\n",
    "    M.addConstr(y_norm == b[-1][0] + gp.quicksum(W[-1][0, i] * prev[i] for i in range(len(prev))))\n",
    "\n",
    "    # Objective expressed on the raw (de-standardized) target scale.\n",
    "    M.setObjective(ys * y_norm + ym, GRB.MINIMIZE)\n",
    "\n",
    "    M.optimize()\n",
    "    if M.SolCount == 0:\n",
    "        raise RuntimeError(f\"No feasible solution. Gurobi status {M.Status}\")\n",
    "\n",
    "    x_star = np.array([x[i].X for i in range(d)], float)\n",
    "    info = {\"status\": M.Status, \"runtime\": M.Runtime, \"gap\": getattr(M, \"MIPGap\", None), \"sol_count\": M.SolCount}\n",
    "    return x_star, float(M.ObjVal), info\n",
    "\n",
    "\n",
    "def solve_maxaffine_ip_gurobi(model, scaler, x_min, x_max, sum_eq, *, integer_x=True, verbose=False, time_limit=None, seed=None):\n",
    "    \"\"\"Minimize a trained max-affine surrogate over box- and sum-constrained inputs.\n",
    "\n",
    "    Uses the epigraph form: a free variable ``t`` is constrained to lie above every\n",
    "    affine piece and ``y_std * t + y_mean`` is minimized. The input standardization\n",
    "    and the optional output wrapper ``(a, b)`` are folded into the pieces, so the\n",
    "    model works directly on the raw ``x``.\n",
    "\n",
    "    Args:\n",
    "        model: Object exposing ``W`` and ``b``, or a wrapper exposing ``base``,\n",
    "            ``a`` and ``b`` whose ``base`` exposes ``W`` and ``b``.\n",
    "        scaler: Dict with tensors ``x_mean``, ``x_std``, ``y_mean``, ``y_std``.\n",
    "        x_min, x_max: Per-coordinate bounds on ``x``.\n",
    "        sum_eq: Required value of ``sum(x)``.\n",
    "        integer_x: Integer (True) or continuous (False) ``x`` variables.\n",
    "        verbose: Enables Gurobi output.\n",
    "        time_limit: Optional value for Gurobi's ``TimeLimit`` parameter.\n",
    "        seed: Optional value for Gurobi's ``Seed`` parameter.\n",
    "\n",
    "    Returns:\n",
    "        ``(x_star, objective_value, info)`` where ``info`` holds ``status``,\n",
    "        ``runtime``, ``gap`` and ``sol_count``.\n",
    "\n",
    "    Raises:\n",
    "        RuntimeError: If Gurobi finishes without any solution.\n",
    "    \"\"\"\n",
    "    import gurobipy as gp\n",
    "    from gurobipy import GRB\n",
    "\n",
    "    x_min = np.asarray(x_min, float).ravel()\n",
    "    x_max = np.asarray(x_max, float).ravel()\n",
    "    dim = int(x_min.size)\n",
    "    sum_eq = float(sum_eq)\n",
    "\n",
    "    x_mean, x_std, y_mean, y_std = _scaler_np(scaler)\n",
    "\n",
    "    # Optional (base, a, b) wrapper: a scales and b shifts every piece below.\n",
    "    is_wrapped = hasattr(model, \"base\") and hasattr(model, \"a\") and hasattr(model, \"b\")\n",
    "    if is_wrapped:\n",
    "        core, out_scale, out_shift = model.base, float(model.a), float(model.b)\n",
    "    else:\n",
    "        core, out_scale, out_shift = model, 1.0, 0.0\n",
    "\n",
    "    raw_slopes = core.W.detach().cpu().numpy().astype(float)\n",
    "    raw_offsets = core.b.detach().cpu().numpy().astype(float)\n",
    "\n",
    "    # Fold (x - x_mean) / x_std into each piece, then apply the output wrapper.\n",
    "    slopes = raw_slopes / x_std[None, :]\n",
    "    offsets = raw_offsets - (slopes @ x_mean)\n",
    "    slopes *= out_scale\n",
    "    offsets = out_scale * offsets + out_shift\n",
    "\n",
    "    n_pieces = int(slopes.shape[0])\n",
    "\n",
    "    mip = gp.Model(\"MaxAffine_IP\")\n",
    "    mip.Params.OutputFlag = 1 if verbose else 0\n",
    "    # Single thread (and optional seed) so repeated solves are reproducible.\n",
    "    try:\n",
    "        mip.Params.Threads = 1\n",
    "    except Exception:\n",
    "        pass\n",
    "    if seed is not None:\n",
    "        try:\n",
    "            mip.Params.Seed = int(seed)\n",
    "        except Exception:\n",
    "            pass\n",
    "    if time_limit is not None:\n",
    "        mip.Params.TimeLimit = float(time_limit)\n",
    "\n",
    "    var_type = GRB.INTEGER if integer_x else GRB.CONTINUOUS\n",
    "    xv = mip.addVars(dim, lb=x_min.tolist(), ub=x_max.tolist(), vtype=var_type, name=\"x\")\n",
    "    mip.addConstr(gp.quicksum(xv[i] for i in range(dim)) == sum_eq, name=\"sum_eq\")\n",
    "\n",
    "    # Epigraph variable: must dominate every affine piece.\n",
    "    epi = mip.addVar(lb=-GRB.INFINITY, vtype=GRB.CONTINUOUS, name=\"t_norm\")\n",
    "    for k in range(n_pieces):\n",
    "        piece = offsets[k] + gp.quicksum(slopes[k, j] * xv[j] for j in range(dim) if slopes[k, j] != 0.0)\n",
    "        mip.addConstr(epi >= piece)\n",
    "\n",
    "    mip.setObjective(y_std * epi + y_mean, GRB.MINIMIZE)\n",
    "    mip.optimize()\n",
    "\n",
    "    if mip.SolCount == 0:\n",
    "        raise RuntimeError(f\"No feasible solution. Gurobi status {mip.Status}\")\n",
    "\n",
    "    x_star = np.array([xv[i].X for i in range(dim)], float)\n",
    "    info = {\n",
    "        \"status\": mip.Status,\n",
    "        \"runtime\": mip.Runtime,\n",
    "        \"gap\": getattr(mip, \"MIPGap\", None),\n",
    "        \"sol_count\": mip.SolCount,\n",
    "    }\n",
    "    return x_star, float(mip.ObjVal), info\n",
    "\n",
    "\n",
    "def solve_lset_ip_gurobi(model, scaler, x_min, x_max, sum_eq, *, integer_x=True, verbose=False, time_limit=None, seed=None):\n",
    "    \"\"\"Minimize a log-sum-exp surrogate over box- and sum-constrained inputs with Gurobi.\n",
    "\n",
    "    With ``z_k = (A @ x_norm + b)_k / T`` the encoded normalized prediction is\n",
    "    ``T * (m + log(sum_k exp(z_k - m)))`` with ``m >= z_k`` (the shifted form of\n",
    "    ``T * log(sum_k exp(z_k))``), built from Gurobi's exp/log general constraints.\n",
    "    Every auxiliary variable gets finite bounds derived by interval arithmetic from\n",
    "    the box on ``x``. The input standardization is folded into ``A`` and ``b``.\n",
    "\n",
    "    Negative ``T`` is supported: dividing by ``T < 0`` flips the interval end\n",
    "    points, so all ``T``-dependent bounds are ordered explicitly instead of\n",
    "    assuming ``T > 0`` (which previously produced lower bounds above upper bounds).\n",
    "\n",
    "    Args:\n",
    "        model: Object exposing ``A``, ``b`` (tensors) and ``T``.\n",
    "        scaler: Dict with tensors ``x_mean``, ``x_std``, ``y_mean``, ``y_std``.\n",
    "        x_min, x_max: Per-coordinate bounds on ``x``.\n",
    "        sum_eq: Required value of ``sum(x)``.\n",
    "        integer_x: Integer (True) or continuous (False) ``x`` variables.\n",
    "        verbose: Enables Gurobi output.\n",
    "        time_limit: Optional value for Gurobi's ``TimeLimit`` parameter.\n",
    "        seed: Optional value for Gurobi's ``Seed`` parameter.\n",
    "\n",
    "    Returns:\n",
    "        ``(x_star, y_star, info)``: ``y_star`` is the value of the raw-scale output\n",
    "        variable; ``info`` holds ``status``, ``runtime``, ``gap``, ``sol_count``,\n",
    "        ``obj_gurobi`` and ``obj_bound``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``model.T`` is zero.\n",
    "        RuntimeError: If Gurobi finishes without any solution.\n",
    "    \"\"\"\n",
    "    import gurobipy as gp\n",
    "    from gurobipy import GRB\n",
    "\n",
    "    x_min = np.asarray(x_min, float).ravel()\n",
    "    x_max = np.asarray(x_max, float).ravel()\n",
    "    d = int(x_min.size)\n",
    "    sum_eq = float(sum_eq)\n",
    "\n",
    "    xm, xs, ym, ys = _scaler_np(scaler)\n",
    "\n",
    "    A = model.A.detach().cpu().numpy().astype(float)\n",
    "    b = model.b.detach().cpu().numpy().astype(float)\n",
    "    T = float(model.T)\n",
    "    if T == 0.0:\n",
    "        raise ValueError(\"T must be nonzero\")\n",
    "\n",
    "    # Fold (x - xm) / xs into the affine pieces so the MIP works on raw x.\n",
    "    Aeff = A / xs[None, :]\n",
    "    beff = b - (Aeff @ xm)\n",
    "    K = int(Aeff.shape[0])\n",
    "\n",
    "    # Interval bounds of every affine piece over the box [x_min, x_max].\n",
    "    lin_lo = np.empty(K, dtype=float)\n",
    "    lin_hi = np.empty(K, dtype=float)\n",
    "    for k in range(K):\n",
    "        a = Aeff[k]\n",
    "        lo = beff[k]\n",
    "        hi = beff[k]\n",
    "        pos = a >= 0\n",
    "        lo += (a[pos] * x_min[pos]).sum() + (a[~pos] * x_max[~pos]).sum()\n",
    "        hi += (a[pos] * x_max[pos]).sum() + (a[~pos] * x_min[~pos]).sum()\n",
    "        lin_lo[k], lin_hi[k] = lo, hi\n",
    "\n",
    "    # Dividing by a negative T swaps the end points, hence the explicit ordering.\n",
    "    z_a = lin_lo / T\n",
    "    z_b = lin_hi / T\n",
    "    z_lo = np.minimum(z_a, z_b)\n",
    "    z_hi = np.maximum(z_a, z_b)\n",
    "    m_lo = float(np.max(z_lo))\n",
    "    m_hi = float(np.max(z_hi))\n",
    "\n",
    "    w_lo = float(np.min(z_lo - m_hi))  # <= 0\n",
    "    # exp argument is clamped at -700 and the sum floored at 1e-12 so that the\n",
    "    # log bound below stays finite.\n",
    "    u_lo = float(np.exp(max(-700.0, w_lo)))\n",
    "    s_lo = max(1e-12, K * u_lo)\n",
    "    s_hi = float(K * 1.0)\n",
    "    v_lo = float(np.log(s_lo))\n",
    "    v_hi = float(np.log(s_hi))\n",
    "\n",
    "    # Bounds on T * (m + v) and on the raw output, ordered for either sign of the\n",
    "    # multiplier.\n",
    "    yN_a = float(T * (m_lo + v_lo))\n",
    "    yN_b = float(T * (m_hi + v_hi))\n",
    "    yN_lo, yN_hi = min(yN_a, yN_b), max(yN_a, yN_b)\n",
    "    y_a = float(ys * yN_lo + ym)\n",
    "    y_b = float(ys * yN_hi + ym)\n",
    "    y_lo, y_hi = min(y_a, y_b), max(y_a, y_b)\n",
    "\n",
    "    M = gp.Model(\"lset_ip_stable_bounded\")\n",
    "    M.Params.OutputFlag = 1 if verbose else 0\n",
    "    # Make the IP solve reproducible across runs\n",
    "    try:\n",
    "        M.Params.Threads = 1\n",
    "    except Exception:\n",
    "        pass\n",
    "    if seed is not None:\n",
    "        try:\n",
    "            M.Params.Seed = int(seed)\n",
    "        except Exception:\n",
    "            pass\n",
    "    if time_limit is not None:\n",
    "        M.Params.TimeLimit = float(time_limit)\n",
    "\n",
    "    # Nonlinear-function handling and tight tolerances for the exp/log constraints.\n",
    "    M.Params.FuncNonlinear = 1\n",
    "    M.Params.FeasibilityTol = 1e-9\n",
    "    M.Params.OptimalityTol  = 1e-9\n",
    "    M.Params.IntFeasTol     = 1e-9\n",
    "    M.Params.NumericFocus   = 3\n",
    "\n",
    "    xt = GRB.INTEGER if integer_x else GRB.CONTINUOUS\n",
    "    x = M.addVars(d, lb=x_min.tolist(), ub=x_max.tolist(), vtype=xt, name=\"x\")\n",
    "    M.addConstr(gp.quicksum(x[i] for i in range(d)) == sum_eq, name=\"sum_eq\")\n",
    "\n",
    "    # z_k = (affine piece k) / T\n",
    "    z = M.addVars(K, lb=z_lo.tolist(), ub=z_hi.tolist(), vtype=GRB.CONTINUOUS, name=\"z\")\n",
    "    for k in range(K):\n",
    "        lin = beff[k] + gp.quicksum(Aeff[k, j] * x[j] for j in range(d) if Aeff[k, j] != 0.0)\n",
    "        M.addConstr(z[k] == lin / T, name=f\"zdef_{k}\")\n",
    "\n",
    "    # Shift m dominates every z_k, so all shifted exponents w_k are <= 0.\n",
    "    m = M.addVar(lb=m_lo, ub=m_hi, vtype=GRB.CONTINUOUS, name=\"m\")\n",
    "    w = M.addVars(K, lb=w_lo, ub=0.0, vtype=GRB.CONTINUOUS, name=\"w\")\n",
    "    for k in range(K):\n",
    "        M.addConstr(m >= z[k], name=f\"m_ge_z_{k}\")\n",
    "        M.addConstr(w[k] == z[k] - m, name=f\"wdef_{k}\")\n",
    "\n",
    "    # u_k = exp(w_k)\n",
    "    u = M.addVars(K, lb=0.0, ub=1.0, vtype=GRB.CONTINUOUS, name=\"u\")\n",
    "    for k in range(K):\n",
    "        M.addGenConstrExp(w[k], u[k], name=f\"exp_{k}\")\n",
    "\n",
    "    s = M.addVar(lb=s_lo, ub=s_hi, vtype=GRB.CONTINUOUS, name=\"s\")\n",
    "    M.addConstr(s == gp.quicksum(u[k] for k in range(K)), name=\"sumexp_shifted\")\n",
    "\n",
    "    # v = log(s)\n",
    "    v = M.addVar(lb=v_lo, ub=v_hi, vtype=GRB.CONTINUOUS, name=\"v\")\n",
    "    M.addGenConstrLog(s, v, name=\"log_shifted\")\n",
    "\n",
    "    y_norm = M.addVar(lb=yN_lo, ub=yN_hi, vtype=GRB.CONTINUOUS, name=\"y_norm\")\n",
    "    M.addConstr(y_norm == T * (m + v), name=\"y_norm_def\")\n",
    "\n",
    "    y_raw = M.addVar(lb=y_lo, ub=y_hi, vtype=GRB.CONTINUOUS, name=\"y_raw\")\n",
    "    M.addConstr(y_raw == ys * y_norm + ym, name=\"y_raw_def\")\n",
    "\n",
    "    M.setObjective(y_raw, GRB.MINIMIZE)\n",
    "    M.optimize()\n",
    "\n",
    "    if M.SolCount == 0:\n",
    "        raise RuntimeError(f\"No feasible solution found. Gurobi status {M.Status}\")\n",
    "\n",
    "    x_star = np.array([x[i].X for i in range(d)], dtype=float)\n",
    "    y_star = float(y_raw.X)\n",
    "\n",
    "    info = {\n",
    "        \"status\": M.Status,\n",
    "        \"runtime\": M.Runtime,\n",
    "        \"gap\": getattr(M, \"MIPGap\", None),\n",
    "        \"sol_count\": M.SolCount,\n",
    "        \"obj_gurobi\": float(M.ObjVal),\n",
    "        \"obj_bound\": float(getattr(M, \"ObjBound\", float(\"nan\"))),\n",
    "    }\n",
    "    return x_star, y_star, info\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4858979",
   "metadata": {},
   "source": [
    "## Test Helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "569f2165",
   "metadata": {},
   "outputs": [],
   "source": [
    "def solve_ip(model_type, model, scaler, xmin, xmax, sum_eq, *, time_limit=None, verbose=False, seed=None, integer_x=True):\n",
    "    \"\"\"Dispatch to the Gurobi solver that matches ``model_type``.\n",
    "\n",
    "    Args:\n",
    "        model_type: ``DFN``, ``MLP``, ``MaxAffine`` or ``LSET``.\n",
    "        model: Trained model handed to the selected solver.\n",
    "        scaler: Scaler dict handed to the selected solver.\n",
    "        xmin, xmax: Per-coordinate bounds on the decision vector.\n",
    "        sum_eq: Required value of the sum of the decision vector.\n",
    "        time_limit, verbose, seed: Forwarded unchanged to the solver.\n",
    "        integer_x: Forwarded to the solver; True (the solvers' own default) asks\n",
    "            for integer decision variables, False for continuous ones.\n",
    "\n",
    "    Returns:\n",
    "        Whatever the selected solver returns: ``(x_star, objective, info)``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``model_type`` is not one of the four supported names.\n",
    "    \"\"\"\n",
    "    if model_type == \"DFN\":\n",
    "        solver = solve_dfn_ip_gurobi\n",
    "    elif model_type == \"MLP\":\n",
    "        solver = solve_mlp_ip_gurobi\n",
    "    elif model_type == \"MaxAffine\":\n",
    "        solver = solve_maxaffine_ip_gurobi\n",
    "    elif model_type == \"LSET\":\n",
    "        solver = solve_lset_ip_gurobi\n",
    "    else:\n",
    "        # Same wording as the other model_type errors in this notebook.\n",
    "        raise ValueError(f\"model_type must be: DFN | MLP | MaxAffine | LSET (got {model_type!r})\")\n",
    "    return solver(model, scaler, xmin, xmax, sum_eq, integer_x=integer_x, time_limit=time_limit, verbose=verbose, seed=seed)\n",
    "\n",
    "\n",
    "def suppress_stdout(fn, *args, silence: bool = False, **kwargs):\n",
    "    if not silence:\n",
    "        return fn(*args, **kwargs), \"\"\n",
    "    buf = io.StringIO()\n",
    "    with contextlib.redirect_stdout(buf):\n",
    "        out = fn(*args, **kwargs)\n",
    "    return out, buf.getvalue()\n",
    "\n",
    "\n",
    "def make_obj(model, scaler, device, chunk: int = 4096):\n",
    "    xm = scaler[\"x_mean\"].to(device)\n",
    "    xs = scaler[\"x_std\"].to(device)\n",
    "    ym = scaler[\"y_mean\"].to(device)\n",
    "    ys = scaler[\"y_std\"].to(device)\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def obj(Xraw):\n",
    "        Xraw = torch.as_tensor(Xraw, dtype=torch.float32, device=device)\n",
    "        if Xraw.dim() == 1:\n",
    "            Xraw = Xraw.unsqueeze(0)\n",
    "\n",
    "        outs = []\n",
    "        B = int(Xraw.shape[0])\n",
    "        for i in range(0, B, int(chunk)):\n",
    "            Xb = Xraw[i:i+chunk]\n",
    "            try:\n",
    "                Xn = (Xb - xm) / xs\n",
    "                yn = model(Xn)\n",
    "                yn = torch.as_tensor(yn).reshape(-1)  \n",
    "                y  = (yn * ys + ym).reshape(-1).detach().cpu() \n",
    "                if y.numel() != Xb.shape[0]:\n",
    "                    y = y.expand(Xb.shape[0]).contiguous()\n",
    "                outs.append(y)\n",
    "            except Exception as e:\n",
    "                obj._err_count += int(Xb.shape[0])\n",
    "                if obj._err_first is None:\n",
    "                    obj._err_first = repr(e)\n",
    "                outs.append(torch.full((int(Xb.shape[0]),), float(\"inf\"), device=\"cpu\"))\n",
    "\n",
    "        return torch.cat(outs, 0).numpy()\n",
    "\n",
    "    obj._err_count = 0\n",
    "    obj._err_first = None\n",
    "    return obj\n",
    "\n",
    "\n",
    "def safe_raw_mse(obj, Xraw, yraw):\n",
    "    yp = np.asarray(obj(Xraw), dtype=float).reshape(-1)\n",
    "    yt = yraw.detach().cpu().numpy().reshape(-1) if torch.is_tensor(yraw) else np.asarray(yraw, float).reshape(-1)\n",
    "    if yp.shape[0] != yt.shape[0]:\n",
    "        return np.nan, f\"shape mismatch: pred {yp.shape} vs true {yt.shape}\"\n",
    "    mask = np.isfinite(yp)\n",
    "    if mask.sum() == 0:\n",
    "        return np.nan, \"all predictions were non-finite (inf/nan)\"\n",
    "    return float(np.mean((yp[mask] - yt[mask])**2)), None\n",
    "\n",
    "\n",
    "def safe_norm_mse(model, scaler, device, Xraw, yraw, *, chunk=4096):\n",
    "    xm = scaler[\"x_mean\"].to(device)\n",
    "    xs = scaler[\"x_std\"].to(device)\n",
    "    ym = scaler[\"y_mean\"].to(device)\n",
    "    ys = scaler[\"y_std\"].to(device)\n",
    "\n",
    "    Xraw = torch.as_tensor(Xraw, dtype=torch.float32, device=device)\n",
    "    if Xraw.dim() == 1:\n",
    "        Xraw = Xraw.unsqueeze(0)\n",
    "\n",
    "    yraw_t = yraw\n",
    "    if torch.is_tensor(yraw_t):\n",
    "        yraw_t = yraw_t.to(device=device, dtype=torch.float32)\n",
    "    else:\n",
    "        yraw_t = torch.as_tensor(yraw_t, dtype=torch.float32, device=device)\n",
    "    yraw_t = yraw_t.reshape(-1)\n",
    "\n",
    "    preds = []\n",
    "    B = int(Xraw.shape[0])\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            for i in range(0, B, int(chunk)):\n",
    "                Xb = Xraw[i:i+chunk]\n",
    "                Xn = (Xb - xm) / xs\n",
    "                yn = model(Xn)\n",
    "                yn = torch.as_tensor(yn).reshape(-1)\n",
    "                if yn.numel() != Xb.shape[0]:\n",
    "                    yn = yn.expand(Xb.shape[0]).contiguous()\n",
    "                preds.append(yn.detach().cpu())\n",
    "    except Exception as e:\n",
    "        return np.nan, f\"model forward failed in safe_norm_mse: {repr(e)}\"\n",
    "\n",
    "    yp = torch.cat(preds, 0).numpy().reshape(-1)\n",
    "    yt = ((yraw_t - ym) / ys).detach().cpu().numpy().reshape(-1)\n",
    "\n",
    "    if yp.shape[0] != yt.shape[0]:\n",
    "        return np.nan, f\"shape mismatch: pred {yp.shape} vs true {yt.shape}\"\n",
    "\n",
    "    mask = np.isfinite(yp) & np.isfinite(yt)\n",
    "    if mask.sum() == 0:\n",
    "        return np.nan, \"all predictions or targets were non-finite (inf/nan)\"\n",
    "\n",
    "    return float(np.mean((yp[mask] - yt[mask])**2)), None\n",
    "\n",
    "\n",
    "def t_to_best(hist, y_best):\n",
    "    for r in hist:\n",
    "        if abs(float(r.get(\"best_y\", np.inf)) - float(y_best)) < 1e-12:\n",
    "            return float(r.get(\"t\", float(\"nan\")))\n",
    "    return float(\"nan\")\n",
    "\n",
    "\n",
    "def check_ip_matches_obj(name, obj, x_ip, y_ip, *, strict: bool, tol: float):\n",
    "    y_obj = float(np.asarray(obj(np.asarray(x_ip)), dtype=float).reshape(-1)[0])\n",
    "    y_ip  = float(y_ip)\n",
    "    rel = abs(y_obj - y_ip) / (abs(y_obj) + 1e-12)\n",
    "    print(f\"[CHECK {name}] obj(x_ip)={y_obj:.6g}  ip_y={y_ip:.6g}  rel_err={rel:.3e}\")\n",
    "    if strict and rel > tol:\n",
    "        raise RuntimeError(f\"{name}: IP objective != obj() (rel_err={rel:.3e})\")\n",
    "    return rel\n",
    "\n",
    "\n",
    "def mean_se(x):\n",
    "    x = pd.to_numeric(pd.Series(x), errors=\"coerce\").to_numpy()\n",
    "    x = x[np.isfinite(x)]\n",
    "    n = int(x.shape[0])\n",
    "    if n == 0:\n",
    "        return np.nan, np.nan\n",
    "    m = float(x.mean())\n",
    "    se = float(x.std(ddof=1) / np.sqrt(n)) if n > 1 else 0.0\n",
    "    return m, se\n",
    "\n",
    "\n",
    "def fmt_mean_se(m, se):\n",
    "    if not np.isfinite(m):\n",
    "        return \"nan\"\n",
    "    if not np.isfinite(se):\n",
    "        return f\"{m:.6g}\"\n",
    "    return f\"{m:.6g} ± {se:.3g}\"\n",
    "\n",
    "\n",
    "def repr_solution(xs: pd.Series, seeds=None):\n",
    "    xs = xs.dropna().astype(str)\n",
    "    xs = xs[xs != \"None\"]\n",
    "    if xs.empty:\n",
    "        return None\n",
    "\n",
    "    # With seed information: show per-seed if solutions differ\n",
    "    if seeds is not None:\n",
    "        df = pd.DataFrame({\"seed\": pd.Series(seeds), \"x\": xs})\n",
    "        df = df.dropna()\n",
    "        df[\"x\"] = df[\"x\"].astype(str)\n",
    "        if df.empty:\n",
    "            return None\n",
    "        if df[\"x\"].nunique() == 1:\n",
    "            return df[\"x\"].iloc[0]\n",
    "        df = df.sort_values(\"seed\")\n",
    "        return \"\\n\".join([f\"seed={int(r.seed)}: {r.x}\" for r in df.itertuples(index=False)])\n",
    "\n",
    "    # No seed info: keep compact\n",
    "    if xs.nunique() == 1:\n",
    "        return xs.iloc[0]\n",
    "    vc = xs.value_counts()\n",
    "    top = vc.index[0]\n",
    "    n_unique = int(vc.shape[0])\n",
    "    return f\"{top} (+{n_unique-1} other)\"\n",
    "    \n",
    "\n",
    "# -----------------------------\n",
    "# Ground-truth optimum solvers\n",
    "# -----------------------------\n",
    "\n",
    "def solve_true_opt_quadratic(gt, xmin, xmax, sum_eq, *, time_limit=None, verbose=False, seed=None):\n",
    "    \"\"\"Minimise the ground-truth quadratic (x - x*)^T Q (x - x*) over integer x with Gurobi.\n",
    "\n",
    "    Reads `gt[\"Q\"]` (n x n) and `gt[\"x_star\"]` (length n). Each x[i] is an integer\n",
    "    with int(xmin[i]) <= x[i] <= int(xmax[i]), and sum(x) == int(sum_eq).\n",
    "    `verbose` toggles Gurobi's log, `seed` sets Params.Seed and `time_limit`\n",
    "    sets Params.TimeLimit (seconds) when given.\n",
    "\n",
    "    Returns `(x, objective, info)`: x as an int array, the solver's ObjVal, and\n",
    "    `{\"status\": <Gurobi status code>}`.\n",
    "\n",
    "    Raises RuntimeError if the final status is not OPTIMAL, TIME_LIMIT,\n",
    "    SUBOPTIMAL or INTERRUPTED, or if the solve ended without any feasible\n",
    "    solution (e.g. the time limit expired before an incumbent was found).\n",
    "    \"\"\"\n",
    "    Q  = np.asarray(gt[\"Q\"], float)\n",
    "    xs = np.asarray(gt[\"x_star\"], float).reshape(-1)\n",
    "    n  = int(xs.shape[0])\n",
    "    assert Q.shape == (n, n)\n",
    "\n",
    "    m = gp.Model(\"gt_quadratic\")\n",
    "    m.Params.OutputFlag = 1 if verbose else 0\n",
    "    # Make the (ground-truth) solve reproducible across runs (best-effort settings).\n",
    "    try:\n",
    "        m.Params.Threads = 1\n",
    "    except Exception:\n",
    "        pass\n",
    "    if seed is not None:\n",
    "        try:\n",
    "            m.Params.Seed = int(seed)\n",
    "        except Exception:\n",
    "            pass\n",
    "    if time_limit is not None:\n",
    "        m.Params.TimeLimit = float(time_limit)\n",
    "\n",
    "    x = m.addVars(n, vtype=GRB.INTEGER, name=\"x\")\n",
    "    for i in range(n):\n",
    "        x[i].LB = int(xmin[i])\n",
    "        x[i].UB = int(xmax[i])\n",
    "    m.addConstr(gp.quicksum(x[i] for i in range(n)) == int(sum_eq), name=\"sum_eq\")\n",
    "\n",
    "    expr = gp.quicksum(float(Q[i, j]) * (x[i] - float(xs[i])) * (x[j] - float(xs[j])) for i in range(n) for j in range(n))\n",
    "    m.setObjective(expr, GRB.MINIMIZE)\n",
    "    m.optimize()\n",
    "\n",
    "    if m.Status not in (GRB.OPTIMAL, GRB.TIME_LIMIT, GRB.SUBOPTIMAL, GRB.INTERRUPTED):\n",
    "        raise RuntimeError(f\"GT quadratic solve failed, status={m.Status}\")\n",
    "    if m.SolCount == 0:\n",
    "        # A time limit or interrupt can end the solve with no incumbent; X and ObjVal\n",
    "        # are then unavailable, so fail with a clear message instead.\n",
    "        raise RuntimeError(f\"GT quadratic solve found no feasible solution, status={m.Status}\")\n",
    "\n",
    "    xsol = np.array([int(round(x[i].X)) for i in range(n)], dtype=int)\n",
    "    return xsol, float(m.ObjVal), {\"status\": int(m.Status)}\n",
    "\n",
    "\n",
    "def solve_true_opt_assignment(gt, xmin, xmax, sum_eq, *, time_limit=None, verbose=False, seed=None):\n",
    "    \"\"\"Solve the ground-truth assignment problem with Gurobi.\n",
    "\n",
    "    `gt[\"C\"]` is an (n_rows x n_cols) cost matrix. Exactly k = int(sum_eq) rows\n",
    "    are selected (binary z), every selected row is assigned to exactly one\n",
    "    column (binary y), each column is used at most once, and the total cost\n",
    "    of the chosen (row, column) pairs is minimised. `xmin` and `xmax` are not\n",
    "    used by this model. `verbose` toggles Gurobi's log, `seed` sets\n",
    "    Params.Seed and `time_limit` sets Params.TimeLimit when given.\n",
    "\n",
    "    Returns `(x, objective, info)`: x is the 0/1 row-selection vector as an int\n",
    "    array, objective is the solver's ObjVal, info is `{\"status\": <status code>}`.\n",
    "\n",
    "    Raises RuntimeError if k exceeds the number of columns, if the final status\n",
    "    is not OPTIMAL, TIME_LIMIT, SUBOPTIMAL or INTERRUPTED, or if the solve\n",
    "    ended without any feasible solution.\n",
    "    \"\"\"\n",
    "    C = np.asarray(gt[\"C\"], float)\n",
    "    n_rows, n_cols = C.shape\n",
    "    k = int(sum_eq)\n",
    "    if k > n_cols:\n",
    "        raise RuntimeError(f\"sum_eq={k} exceeds number of columns={n_cols} (infeasible)\")\n",
    "\n",
    "    m = gp.Model(\"gt_assignment\")\n",
    "    m.Params.OutputFlag = 1 if verbose else 0\n",
    "    # Make the (ground-truth) solve reproducible across runs (best-effort settings).\n",
    "    try:\n",
    "        m.Params.Threads = 1\n",
    "    except Exception:\n",
    "        pass\n",
    "    if seed is not None:\n",
    "        try:\n",
    "            m.Params.Seed = int(seed)\n",
    "        except Exception:\n",
    "            pass\n",
    "    if time_limit is not None:\n",
    "        m.Params.TimeLimit = float(time_limit)\n",
    "\n",
    "    z = m.addVars(n_rows, vtype=GRB.BINARY, name=\"z\")  # select row i\n",
    "    y = m.addVars(n_rows, n_cols, vtype=GRB.BINARY, name=\"y\")  # assign row i to col j\n",
    "\n",
    "    m.addConstr(gp.quicksum(z[i] for i in range(n_rows)) == k, name=\"k_rows\")\n",
    "    for i in range(n_rows):\n",
    "        m.addConstr(gp.quicksum(y[i, j] for j in range(n_cols)) == z[i], name=f\"assign_row_{i}\")\n",
    "    for j in range(n_cols):\n",
    "        m.addConstr(gp.quicksum(y[i, j] for i in range(n_rows)) <= 1, name=f\"assign_col_{j}\")\n",
    "\n",
    "    m.setObjective(gp.quicksum(float(C[i, j]) * y[i, j] for i in range(n_rows) for j in range(n_cols)), GRB.MINIMIZE)\n",
    "    m.optimize()\n",
    "\n",
    "    if m.Status not in (GRB.OPTIMAL, GRB.TIME_LIMIT, GRB.SUBOPTIMAL, GRB.INTERRUPTED):\n",
    "        raise RuntimeError(f\"GT assignment solve failed, status={m.Status}\")\n",
    "    if m.SolCount == 0:\n",
    "        # A time limit or interrupt can end the solve with no incumbent; X and ObjVal\n",
    "        # are then unavailable, so fail with a clear message instead.\n",
    "        raise RuntimeError(f\"GT assignment solve found no feasible solution, status={m.Status}\")\n",
    "\n",
    "    xsol = np.array([int(round(z[i].X)) for i in range(n_rows)], dtype=int)\n",
    "    return xsol, float(m.ObjVal), {\"status\": int(m.Status)}\n",
    "\n",
    "\n",
    "def solve_true_opt_mdvsp(gt, xmin, xmax, sum_eq, *, time_limit=None, verbose=False, seed=None):\n",
    "    \"\"\"Solve the ground-truth \"mdvsp\" network problem over integer capacities x with Gurobi.\n",
    "\n",
    "    Reads from `gt`: N (node count), SS / TT (the two nodes that carry the net\n",
    "    outflow / inflow in the balance constraints), src / dst / cost / cap0\n",
    "    (per-arc arrays) and idxSS / idxTT (equal-length arrays of arc indices).\n",
    "    Arcs idxSS[i] and idxTT[i] both receive capacity x[i]; every other arc e\n",
    "    keeps the fixed capacity cap0[e].\n",
    "\n",
    "    Model as built below:\n",
    "      * x[i] integer in [int(xmin[i]), int(xmax[i])], with sum(x) == int(sum_eq);\n",
    "      * f[e] >= 0 bounded by the arc capacity, flow conservation with net\n",
    "        outflow F at SS, -F at TT and 0 elsewhere;\n",
    "      * y[v] in [0, 1] with y[SS] = 1 and y[TT] = 0, z[e] in [0, 1] with\n",
    "        z[e] >= y[src[e]] - y[dst[e]];\n",
    "      * F == sum_e capacity_e * z[e] (bilinear in x and z, hence NonConvex = 2);\n",
    "      * objective: minimise sum_e cost[e] * f[e].\n",
    "    NOTE(review): the F == ... constraint is named \"strong_duality\"; presumably it\n",
    "    pins F to the max-flow value for the chosen x via max-flow/min-cut duality.\n",
    "    Confirm against the problem definition.\n",
    "\n",
    "    Returns `(x, objective, info)`: x as an int array, the solver's ObjVal, and\n",
    "    `{\"status\": <Gurobi status code>, \"F\": <value of F>}`.\n",
    "\n",
    "    Raises RuntimeError if the final status is not OPTIMAL, TIME_LIMIT,\n",
    "    SUBOPTIMAL or INTERRUPTED, or if the solve ended without any feasible\n",
    "    solution.\n",
    "    \"\"\"\n",
    "    N  = int(gt[\"N\"])\n",
    "    SS = int(gt[\"SS\"])\n",
    "    TT = int(gt[\"TT\"])\n",
    "    src  = np.asarray(gt[\"src\"], np.int64)\n",
    "    dst  = np.asarray(gt[\"dst\"], np.int64)\n",
    "    cost = np.asarray(gt[\"cost\"], float)\n",
    "    cap0 = np.asarray(gt[\"cap0\"], float)\n",
    "    idxSS = np.asarray(gt[\"idxSS\"], np.int64)\n",
    "    idxTT = np.asarray(gt[\"idxTT\"], np.int64)\n",
    "\n",
    "    E = int(src.shape[0])\n",
    "    k = int(idxSS.shape[0])\n",
    "    assert idxTT.shape[0] == k, \"Expect idxSS and idxTT to be same length\"\n",
    "\n",
    "    # Adjacency lists of arc indices, used for the conservation constraints.\n",
    "    out_edges = [[] for _ in range(N)]\n",
    "    in_edges  = [[] for _ in range(N)]\n",
    "    for e in range(E):\n",
    "        out_edges[int(src[e])].append(e)\n",
    "        in_edges[int(dst[e])].append(e)\n",
    "\n",
    "    # Map each variable-capacity arc to the index i of the x[i] that bounds it.\n",
    "    idxSS = idxSS.astype(int)\n",
    "    idxTT = idxTT.astype(int)\n",
    "    var_arc_to_i = {int(idxSS[i]): i for i in range(k)}\n",
    "    var_arc_to_i.update({int(idxTT[i]): i for i in range(k)})\n",
    "\n",
    "    m = gp.Model(\"gt_mdvsp_true_opt\")\n",
    "    m.Params.OutputFlag = 1 if verbose else 0\n",
    "    # Make the (ground-truth) solve reproducible across runs (best-effort settings).\n",
    "    try:\n",
    "        m.Params.Threads = 1\n",
    "    except Exception:\n",
    "        pass\n",
    "    if seed is not None:\n",
    "        try:\n",
    "            m.Params.Seed = int(seed)\n",
    "        except Exception:\n",
    "            pass\n",
    "    if time_limit is not None:\n",
    "        m.Params.TimeLimit = float(time_limit)\n",
    "\n",
    "    # Needed for the bilinear x[i] * z[e] terms in the equality constraint below.\n",
    "    m.Params.NonConvex = 2\n",
    "\n",
    "    x = m.addVars(k, vtype=GRB.INTEGER, name=\"x\")\n",
    "    for i in range(k):\n",
    "        x[i].LB = int(xmin[i])\n",
    "        x[i].UB = int(xmax[i])\n",
    "    m.addConstr(gp.quicksum(x[i] for i in range(k)) == int(sum_eq), name=\"sum_eq\")\n",
    "\n",
    "    f = m.addVars(E, vtype=GRB.CONTINUOUS, lb=0.0, name=\"f\")\n",
    "\n",
    "    for e in range(E):\n",
    "        if e in var_arc_to_i:\n",
    "            i = var_arc_to_i[e]\n",
    "            m.addConstr(f[e] <= x[i], name=f\"cap_var_{e}\")\n",
    "        else:\n",
    "            m.addConstr(f[e] <= float(cap0[e]), name=f\"cap_fix_{e}\")\n",
    "\n",
    "    # Local name avoids shadowing the notebook-level `F` alias; the Gurobi name stays \"F\".\n",
    "    total_flow = m.addVar(vtype=GRB.CONTINUOUS, lb=0.0, name=\"F\")\n",
    "\n",
    "    # flow conservation: out - in = F at SS, = -F at TT, else 0\n",
    "    for v in range(N):\n",
    "        out_sum = gp.quicksum(f[e] for e in out_edges[v])\n",
    "        in_sum  = gp.quicksum(f[e] for e in in_edges[v])\n",
    "        rhs = total_flow if v == SS else (-total_flow if v == TT else 0.0)\n",
    "        m.addConstr(out_sum - in_sum == rhs, name=f\"flow_{v}\")\n",
    "\n",
    "    y = m.addVars(N, vtype=GRB.CONTINUOUS, lb=0.0, ub=1.0, name=\"y\")\n",
    "    m.addConstr(y[SS] == 1.0, name=\"ySS\")\n",
    "    m.addConstr(y[TT] == 0.0, name=\"yTT\")\n",
    "\n",
    "    z = m.addVars(E, vtype=GRB.CONTINUOUS, lb=0.0, ub=1.0, name=\"z\")\n",
    "    for e in range(E):\n",
    "        m.addConstr(z[e] >= y[int(src[e])] - y[int(dst[e])], name=f\"dual_{e}\")\n",
    "\n",
    "    dual_obj = gp.QuadExpr()\n",
    "    for e in range(E):\n",
    "        if e in var_arc_to_i:\n",
    "            i = var_arc_to_i[e]\n",
    "            # bilinear term: x_i * z_e\n",
    "            dual_obj += x[i] * z[e]\n",
    "        else:\n",
    "            dual_obj += float(cap0[e]) * z[e]\n",
    "    m.addConstr(total_flow == dual_obj, name=\"strong_duality\")\n",
    "\n",
    "    m.setObjective(gp.quicksum(float(cost[e]) * f[e] for e in range(E)), GRB.MINIMIZE)\n",
    "    m.optimize()\n",
    "\n",
    "    if m.Status not in (GRB.OPTIMAL, GRB.TIME_LIMIT, GRB.SUBOPTIMAL, GRB.INTERRUPTED):\n",
    "        raise RuntimeError(f\"GT mdvsp true-opt failed, status={m.Status}\")\n",
    "    if m.SolCount == 0:\n",
    "        # A time limit or interrupt can end the solve with no incumbent; X and ObjVal\n",
    "        # are then unavailable, so fail with a clear message instead.\n",
    "        raise RuntimeError(f\"GT mdvsp true-opt found no feasible solution, status={m.Status}\")\n",
    "\n",
    "    xsol = np.array([int(round(x[i].X)) for i in range(k)], dtype=int)\n",
    "    return xsol, float(m.ObjVal), {\"status\": int(m.Status), \"F\": float(total_flow.X)}\n",
    "    \n",
    "\n",
    "\n",
    "def solve_true_opt(gt, xmin, xmax, sum_eq, *, time_limit=None, verbose=False, seed=None):\n",
    "    \"\"\"Top-level dispatcher for ground-truth optimum.\"\"\"\n",
    "    if not isinstance(gt, dict):\n",
    "        raise RuntimeError(\"No ground-truth structure found (gt must be a dict).\")\n",
    "    t = gt.get(\"type\", None)\n",
    "    if t == \"quadratic\":\n",
    "        return solve_true_opt_quadratic(gt, xmin, xmax, sum_eq, time_limit=time_limit, verbose=verbose, seed=seed)\n",
    "    if t == \"assignment\":\n",
    "        return solve_true_opt_assignment(gt, xmin, xmax, sum_eq, time_limit=time_limit, verbose=verbose, seed=seed)\n",
    "    if t == \"mdvsp\":\n",
    "        return solve_true_opt_mdvsp(gt, xmin, xmax, sum_eq, time_limit=time_limit, verbose=verbose, seed=seed)\n",
    "    raise RuntimeError(f\"Unknown gt type: {t}\")\n",
    "\n",
    "\n",
    "# -----------------------------\n",
    "# Benchmark runner\n",
    "# -----------------------------\n",
    "def run_benchmark(\n",
    "    *,\n",
    "    dataset_type: str,\n",
    "    dataset_params: dict,\n",
    "    runs: list[tuple[str, str, dict]],\n",
    "    train_base: dict,\n",
    "    lr_map: dict,\n",
    "    x0: np.ndarray,\n",
    "    xmin: np.ndarray,\n",
    "    xmax: np.ndarray,\n",
    "    delta: int,\n",
    "    sum_eq: int,\n",
    "    n_seeds: int = 1,\n",
    "    vary_dataset_seed: bool = False,\n",
    "    vary_model_init_seed: bool = True,\n",
    "    strict_ip_check: bool = False,\n",
    "    ip_check_tol: float = 1e-4,\n",
    "    silence_local_search: bool = False,\n",
    "    allow_plots_multi_seed: bool = True,\n",
    "    time_limit=None,\n",
    "):\n",
    "    seeds = list(range(int(n_seeds)))\n",
    "\n",
    "    learn_rows, opt_rows, spec_rows, fail_rows = [], [], [], []\n",
    "    gt_rows = []\n",
    "\n",
    "    gt_cache_by_seed = {}\n",
    "\n",
    "    for seed in seeds:\n",
    "        print(f\"\\n\\n===================== SEED {seed} =====================\")\n",
    "        # Per-seed starting point + constraint (so the run seed controls x0 / sum_eq too)\n",
    "        x0_seed = x0\n",
    "        if isinstance(x0, (list, tuple, np.ndarray)) and np.asarray(x0).ndim == 2:\n",
    "            x0_seed = np.asarray(x0)[int(seed)]\n",
    "        x0_seed = np.asarray(x0_seed, int).ravel()\n",
    "\n",
    "        sum_eq_seed = sum_eq\n",
    "        if isinstance(sum_eq, (list, tuple, np.ndarray)) and np.asarray(sum_eq).ndim > 0:\n",
    "            sum_eq_seed = int(np.asarray(sum_eq)[int(seed)])\n",
    "        sum_eq_seed = int(sum_eq_seed)\n",
    "\n",
    "\n",
    "        for name, model_type, model_params_base in runs:\n",
    "            # ---- per-run params ----\n",
    "            dp = dict(dataset_params)\n",
    "            if vary_dataset_seed and \"seed\" in dp:\n",
    "                dp[\"seed\"] = int(seed)\n",
    "\n",
    "            mp = dict(model_params_base)\n",
    "            if vary_model_init_seed and \"seed\" in mp:\n",
    "                mp[\"seed\"] = int(seed)\n",
    "\n",
    "            tp = dict(train_base)\n",
    "            tp[\"seed\"] = int(seed)\n",
    "            tp[\"lr\"] = float(lr_map[model_type])\n",
    "\n",
    "            # avoid plot spam unless explicitly allowed\n",
    "            if (n_seeds > 1) and (tp.get(\"plot_every\", 0) not in (0, None)) and (not allow_plots_multi_seed):\n",
    "                tp[\"plot_every\"] = 0\n",
    "\n",
    "            # ---- TRAIN ----\n",
    "            t0 = time.perf_counter()\n",
    "            try:\n",
    "                out = generate_and_train_simple(dataset_type, dp, model_type, mp, tp)\n",
    "            except Exception as e:\n",
    "                fail_rows.append(dict(seed=seed, model=name, stage=\"TRAIN\", error=repr(e)))\n",
    "                learn_rows.append(dict(seed=seed, model=name, train_time=np.nan, best_epoch=np.nan,\n",
    "                                       best_val=np.nan, test=np.nan, train_err=repr(e)))\n",
    "                continue\n",
    "            train_time = time.perf_counter() - t0\n",
    "\n",
    "            if isinstance(out, tuple) and len(out) == 4:\n",
    "                model, data, hist, spec = out\n",
    "            elif isinstance(out, tuple) and len(out) == 3:\n",
    "                model, data, hist = out\n",
    "                n_params = sum(p.numel() for p in model.parameters())\n",
    "                spec = dict(n_params=int(n_params), extra=\"\", train_params=tp)\n",
    "            else:\n",
    "                fail_rows.append(dict(seed=seed, model=name, stage=\"TRAIN\", error=f\"Unexpected return: {type(out)}\"))\n",
    "                continue\n",
    "\n",
    "            # ---- model spec (seed 0 only) ----\n",
    "            if seed == seeds[0]:\n",
    "                spec_rows.append(dict(\n",
    "                    model=name,\n",
    "                    n_params=int(spec.get(\"n_params\", np.nan)),\n",
    "                    details=str(spec.get(\"extra\", \"\")),\n",
    "                    lr=float(spec.get(\"train_params\", {}).get(\"lr\", tp[\"lr\"])),\n",
    "                    batch_size=int(spec.get(\"train_params\", {}).get(\"batch_size\", tp[\"batch_size\"])),\n",
    "                    epochs=int(spec.get(\"train_params\", {}).get(\"epochs\", tp[\"epochs\"])),\n",
    "                ))\n",
    "\n",
    "            device = data[\"device\"]\n",
    "            scaler = data[\"scaler\"]\n",
    "            obj = make_obj(model, scaler, device, chunk=int(tp.get(\"plot_chunk\", 4096)))\n",
    "            gt = data.get(\"true\", None)\n",
    "\n",
    "            # ---- compute GT optimum (once per seed) ----\n",
    "            if seed not in gt_cache_by_seed:\n",
    "                t_gt0 = time.perf_counter()\n",
    "                try:\n",
    "                    x_gt, y_gt, gt_info = solve_true_opt(gt, xmin, xmax, sum_eq_seed, time_limit=time_limit, verbose=False, seed=seed)\n",
    "                    gt_time = time.perf_counter() - t_gt0\n",
    "                    # evaluate with the existing helper to be extra sure\n",
    "                    y_gt_check = float(eval_true_obj(gt, x_gt))\n",
    "                    if np.isfinite(y_gt_check) and abs(y_gt_check - float(y_gt)) / (abs(float(y_gt_check)) + 1e-12) > 1e-6:\n",
    "                        fail_rows.append(dict(seed=seed, model=name, stage=\"GT_OPT_CHECK\",\n",
    "                                              error=f\"gt solver mismatch: solver={y_gt:.6g} eval_true_obj={y_gt_check:.6g}\"))\n",
    "                    gt_cache_by_seed[seed] = dict(x=str(np.asarray(x_gt, int).tolist()),\n",
    "                                                  true_y=float(y_gt_check if np.isfinite(y_gt_check) else y_gt),\n",
    "                                                  runtime=float(gt_time),\n",
    "                                                  err=None)\n",
    "                except Exception as e:\n",
    "                    gt_time = time.perf_counter() - t_gt0\n",
    "                    gt_cache_by_seed[seed] = dict(x=None, true_y=np.nan, runtime=float(gt_time), err=repr(e))\n",
    "                    fail_rows.append(dict(seed=seed, model=name, stage=\"GT_OPT\", error=repr(e)))\n",
    "\n",
    "            # ---- learning metrics: best val (normalized) + test loss (normalized) ----\n",
    "            test_norm, err_te = safe_norm_mse(model, scaler, device, data['raw']['Xte'], data['raw']['yte'], chunk=int(tp.get('plot_chunk', 4096)))\n",
    "            if err_te:\n",
    "                fail_rows.append(dict(seed=seed, model=name, stage=\"EVAL_TEST\", error=err_te))\n",
    "\n",
    "            if obj._err_count > 0:\n",
    "                fail_rows.append(dict(seed=seed, model=name, stage=\"OBJ\",\n",
    "                                      error=f\"obj() had {obj._err_count} forward failures; first={obj._err_first}\"))\n",
    "\n",
    "            val_curve = np.asarray(hist.get(\"val_mse_norm\", []), dtype=float)\n",
    "            best_ep = int(np.argmin(val_curve) + 1) if val_curve.size else np.nan\n",
    "            best_val = float(np.min(val_curve)) if val_curve.size else np.nan\n",
    "\n",
    "            learn_rows.append(dict(\n",
    "                seed=seed, model=name,\n",
    "                train_time=float(train_time),\n",
    "                best_epoch=best_ep,\n",
    "                best_val=float(best_val),\n",
    "                test=float(test_norm) if np.isfinite(test_norm) else np.nan,\n",
    "                train_err=(err_te),\n",
    "            ))\n",
    "\n",
    "            # ---- LOCAL SEARCH ----\n",
    "            t0 = time.perf_counter()\n",
    "            try:\n",
    "                obj_ls = obj\n",
    "                (ls_out, _ls_log) = suppress_stdout(\n",
    "                    local_search_l1_int, obj_ls, x0_seed, xmin, xmax,\n",
    "                    delta=delta, sum_eq=sum_eq_seed, print_every=0,\n",
    "                    silence=silence_local_search\n",
    "                )\n",
    "                x_best, y_best, ls_hist = ls_out\n",
    "                ls_time = time.perf_counter() - t0\n",
    "                opt_rows.append(dict(\n",
    "                    seed=seed, model=name, method=\"LS\",\n",
    "                    x=str(np.asarray(x_best, int).tolist()),\n",
    "                    y=float(y_best),\n",
    "                    true_y=float(eval_true_obj(gt, x_best)),\n",
    "                    runtime=float(ls_time),\n",
    "                    t_best=float(t_to_best(ls_hist, y_best)),\n",
    "                    iters=int(len(ls_hist) - 1),\n",
    "                    err=None,\n",
    "                ))\n",
    "            except Exception as e:\n",
    "                ls_time = time.perf_counter() - t0\n",
    "                opt_rows.append(dict(\n",
    "                    seed=seed, model=name, method=\"LS\",\n",
    "                    x=None, y=np.nan, true_y=np.nan,\n",
    "                    runtime=float(ls_time),\n",
    "                    t_best=np.nan, iters=np.nan,\n",
    "                    err=repr(e),\n",
    "                ))\n",
    "                fail_rows.append(dict(seed=seed, model=name, stage=\"LS\", error=repr(e)))\n",
    "\n",
    "            # ---- IP SOLVE (learned objective) ----\n",
    "            t0 = time.perf_counter()\n",
    "            try:\n",
    "                x_ip, y_ip, info = solve_ip(\n",
    "                    model_type, model, scaler, xmin, xmax, sum_eq_seed,\n",
    "                    time_limit=time_limit, verbose=True, seed=seed\n",
    "                )\n",
    "                ip_time = time.perf_counter() - t0\n",
    "\n",
    "                rel_err = check_ip_matches_obj(name, obj, x_ip, y_ip, strict=strict_ip_check, tol=ip_check_tol)\n",
    "                if rel_err > ip_check_tol:\n",
    "                    fail_rows.append(dict(seed=seed, model=name, stage=\"IP_CHECK\",\n",
    "                                          error=f\"rel_err={rel_err:.3e} (tol={ip_check_tol})\"))\n",
    "\n",
    "                opt_rows.append(dict(\n",
    "                    seed=seed, model=name, method=\"IP\",\n",
    "                    x=str(np.asarray(x_ip, int).tolist()),\n",
    "                    y=float(y_ip),\n",
    "                    true_y=float(eval_true_obj(gt, x_ip)),\n",
    "                    runtime=float(ip_time),\n",
    "                    status=(info.get(\"status\") if isinstance(info, dict) else None),\n",
    "                    gap=(info.get(\"gap\") if isinstance(info, dict) else None),\n",
    "                    ip_rel_err=float(rel_err),\n",
    "                    err=None,\n",
    "                ))\n",
    "            except Exception as e:\n",
    "                ip_time = time.perf_counter() - t0\n",
    "                opt_rows.append(dict(\n",
    "                    seed=seed, model=name, method=\"IP\",\n",
    "                    x=None, y=np.nan, true_y=np.nan,\n",
    "                    runtime=float(ip_time),\n",
    "                    status=None, gap=None, ip_rel_err=np.nan,\n",
    "                    err=repr(e),\n",
    "                ))\n",
    "                fail_rows.append(dict(seed=seed, model=name, stage=\"IP\", error=repr(e)))\n",
    "\n",
    "    # ---- build DFs ----\n",
    "    spec_df  = pd.DataFrame(spec_rows).drop_duplicates(\"model\").sort_values(\"model\").reset_index(drop=True)\n",
    "    learn_df = pd.DataFrame(learn_rows).sort_values([\"model\", \"seed\"]).reset_index(drop=True)\n",
    "    opt_df   = pd.DataFrame(opt_rows).sort_values([\"model\", \"seed\", \"method\"]).reset_index(drop=True)\n",
    "\n",
    "    fail_df = pd.DataFrame(fail_rows)\n",
    "    if fail_df.empty:\n",
    "        fail_df = pd.DataFrame(columns=[\"seed\", \"model\", \"stage\", \"error\"])\n",
    "    else:\n",
    "        for c in [\"stage\", \"model\", \"seed\"]:\n",
    "            if c not in fail_df.columns:\n",
    "                fail_df[c] = np.nan\n",
    "        fail_df = fail_df.sort_values([\"stage\", \"model\", \"seed\"]).reset_index(drop=True)\n",
    "\n",
    "    gt_df = pd.DataFrame([dict(seed=s, **d) for s, d in sorted(gt_cache_by_seed.items())]).sort_values(\"seed\")\n",
    "\n",
    "    # ---- per-seed gaps ----\n",
    "    gap_seed_df = pd.DataFrame(columns=[\"seed\", \"model\", \"ls_vs_ip_pct\", \"ip_true_vs_gt_pct\"])\n",
    "    if not opt_df.empty:\n",
    "        pivot_y = opt_df.pivot_table(index=[\"seed\", \"model\"], columns=\"method\", values=\"y\", aggfunc=\"first\")\n",
    "        pivot_true = opt_df.pivot_table(index=[\"seed\", \"model\"], columns=\"method\", values=\"true_y\", aggfunc=\"first\")\n",
    "\n",
    "        if (\"LS\" in pivot_y.columns) and (\"IP\" in pivot_y.columns):\n",
    "            ls_vs_ip = 100.0 * (pivot_y[\"LS\"] - pivot_y[\"IP\"]) / (np.abs(pivot_y[\"IP\"]) + 1e-12)\n",
    "        else:\n",
    "            ls_vs_ip = pd.Series(index=pivot_y.index, dtype=float)\n",
    "\n",
    "        # IP true vs GT optimum true\n",
    "        gt_true = gt_df.set_index(\"seed\")[\"true_y\"] if not gt_df.empty else pd.Series(dtype=float)\n",
    "        ip_true_vs_gt = []\n",
    "        for (seed, model), row in pivot_true.iterrows():\n",
    "            ipt = float(row.get(\"IP\", np.nan))\n",
    "            gtt = float(gt_true.get(seed, np.nan))\n",
    "            if np.isfinite(ipt) and np.isfinite(gtt):\n",
    "                ip_true_vs_gt.append(((seed, model), 100.0 * (ipt - gtt) / (abs(gtt) + 1e-12)))\n",
    "            else:\n",
    "                ip_true_vs_gt.append(((seed, model), np.nan))\n",
    "        ip_true_vs_gt = pd.Series({k: v for k, v in ip_true_vs_gt})\n",
    "\n",
    "        gap_seed_df = pd.DataFrame({\n",
    "            \"seed\": [k[0] for k in ip_true_vs_gt.index],\n",
    "            \"model\": [k[1] for k in ip_true_vs_gt.index],\n",
    "            \"ls_vs_ip_pct\": [float(ls_vs_ip.get(k, np.nan)) for k in ip_true_vs_gt.index],\n",
    "            \"ip_true_vs_gt_pct\": [float(ip_true_vs_gt.get(k, np.nan)) for k in ip_true_vs_gt.index],\n",
    "        })\n",
    "\n",
    "    # ---- LEARNING SUMMARY (mean ± SE over seeds) ----\n",
    "    learn_sum_rows = []\n",
    "    for model in sorted(learn_df[\"model\"].unique()):\n",
    "        sub = learn_df[learn_df[\"model\"] == model]\n",
    "        m, se = mean_se(sub[\"train_time\"]); train_time_s = fmt_mean_se(m, se)\n",
    "        m, se = mean_se(sub[\"best_val\"]);   best_val_s   = fmt_mean_se(m, se)\n",
    "        m, se = mean_se(sub[\"test\"]);       test_s       = fmt_mean_se(m, se)\n",
    "        learn_sum_rows.append(dict(model=model, train_time=train_time_s, best_val=best_val_s, test=test_s))\n",
    "    learn_summary_df = pd.DataFrame(learn_sum_rows).sort_values(\"model\").reset_index(drop=True)\n",
    "\n",
    "    # ---- OPTIMIZATION SUMMARY ----\n",
    "    opt_sum_rows = []\n",
    "    gt_x_repr = (repr_solution(gt_df.get(\"x\", pd.Series(dtype=str)), gt_df.get(\"seed\", None))\n",
    "                if not gt_df.empty else None)\n",
    "    m, se = mean_se(gt_df.get(\"true_y\", pd.Series(dtype=float))); gt_true_s = fmt_mean_se(m, se)\n",
    "    m, se = mean_se(gt_df.get(\"runtime\", pd.Series(dtype=float))); gt_time_s = fmt_mean_se(m, se)\n",
    "\n",
    "    for model in sorted(opt_df[\"model\"].unique()):\n",
    "        sub = opt_df[opt_df[\"model\"] == model]\n",
    "        row = {\"model\": model}\n",
    "\n",
    "        ls = sub[sub[\"method\"] == \"LS\"]\n",
    "        ip = sub[sub[\"method\"] == \"IP\"]\n",
    "\n",
    "        row[\"LS_x\"] = repr_solution(ls[\"x\"], ls.get(\"seed\", None))\n",
    "        m, se = mean_se(ls[\"y\"]);       row[\"LS_y\"] = fmt_mean_se(m, se)\n",
    "        m, se = mean_se(ls[\"true_y\"]);  row[\"LS_true_y\"] = fmt_mean_se(m, se)\n",
    "        m, se = mean_se(ls[\"runtime\"]); row[\"LS_time\"] = fmt_mean_se(m, se)\n",
    "\n",
    "        row[\"IP_x\"] = repr_solution(ip[\"x\"], ip.get(\"seed\", None))\n",
    "        m, se = mean_se(ip[\"y\"]);       row[\"IP_y\"] = fmt_mean_se(m, se)\n",
    "        m, se = mean_se(ip[\"true_y\"]);  row[\"IP_true_y\"] = fmt_mean_se(m, se)\n",
    "        m, se = mean_se(ip[\"runtime\"]); row[\"IP_time\"] = fmt_mean_se(m, se)\n",
    "\n",
    "        row[\"GT_x\"] = gt_x_repr\n",
    "        row[\"GT_true_y\"] = gt_true_s\n",
    "        row[\"GT_time\"] = gt_time_s\n",
    "\n",
    "        gsub = gap_seed_df[gap_seed_df[\"model\"] == model] if not gap_seed_df.empty else pd.DataFrame()\n",
    "        m, se = mean_se(gsub.get(\"ls_vs_ip_pct\", pd.Series(dtype=float))); row[\"LS_vs_IP_%\"] = fmt_mean_se(m, se)\n",
    "        m, se = mean_se(gsub.get(\"ip_true_vs_gt_pct\", pd.Series(dtype=float))); row[\"IP_true_vs_GT_%\"] = fmt_mean_se(m, se)\n",
    "\n",
    "        opt_sum_rows.append(row)\n",
    "\n",
    "    opt_summary_df = pd.DataFrame(opt_sum_rows).sort_values(\"model\").reset_index(drop=True)\n",
    "\n",
    "    # ---- Print tables ----\n",
    "    print(\"\\n=== MODEL SPECS (from seed 0 run) ===\")\n",
    "    if not spec_df.empty:\n",
    "        print(spec_df[[\"model\", \"n_params\", \"details\", \"lr\", \"batch_size\", \"epochs\"]].to_string(index=False))\n",
    "    else:\n",
    "        print(\"None\")\n",
    "\n",
    "    print(\"\\n=== LEARNING SUMMARY (mean ± SE over seeds) ===\")\n",
    "    if not learn_summary_df.empty:\n",
    "        print(learn_summary_df.to_string(index=False))\n",
    "    else:\n",
    "        print(\"None\")\n",
    "\n",
    "    print(\"\\n=== OPTIMIZATION SUMMARY (mean ± SE over seeds) ===\")\n",
    "    if not opt_summary_df.empty:\n",
    "        cols = [\n",
    "            \"model\",\n",
    "            \"LS_x\", \"LS_y\", \"LS_true_y\", \"LS_time\",\n",
    "            \"IP_x\", \"IP_y\", \"IP_true_y\", \"IP_time\",\n",
    "            \"GT_x\", \"GT_true_y\", \"GT_time\",\n",
    "            \"LS_vs_IP_%\", \"IP_true_vs_GT_%\",\n",
    "        ]\n",
    "        print(opt_summary_df[cols].to_string(index=False))\n",
    "    else:\n",
    "        print(\"None\")\n",
    "\n",
    "    print(\"\\n=== FAILURES / WARNINGS (if any) ===\")\n",
    "    if fail_df.shape[0] == 0:\n",
    "        print(\"None\")\n",
    "    else:\n",
    "        print(fail_df.to_string(index=False))\n",
    "\n",
    "    return spec_df, learn_df, opt_df, fail_df, learn_summary_df, opt_summary_df, gt_df\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e88e7a1f",
   "metadata": {},
   "source": [
    "## Tests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "8edf3f97",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "===================== SEED 0 =====================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/27/6hvffxjj6mnf6nw61tsrkldm0000gp/T/ipykernel_9226/1051077697.py:159: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  return torch.load(str(artifact_path), map_location=\"cpu\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Set parameter OutputFlag to value 1\n",
      "Set parameter TimeLimit to value 1800\n",
      "Gurobi Optimizer version 12.0.3 build v12.0.3rc0 (mac64[arm] - Darwin 25.2.0 25C56)\n",
      "\n",
      "CPU model: Apple M4 Max\n",
      "Thread count: 14 physical cores, 14 logical processors, using up to 14 threads\n",
      "\n",
      "Non-default parameters:\n",
      "TimeLimit  1800\n",
      "\n",
      "Optimize a model with 129 rows, 10260 columns and 21537 nonzeros\n",
      "Model fingerprint: 0x2dd75dd0\n",
      "Variable types: 10240 continuous, 20 integer (0 binary)\n",
      "Coefficient statistics:\n",
      "  Matrix range     [1e-08, 1e+00]\n",
      "  Objective range  [1e-01, 4e+08]\n",
      "  Bounds range     [1e+00, 1e+06]\n",
      "  RHS range        [9e-03, 3e+01]\n",
      "Found heuristic solution: objective 4.701742e+11\n",
      "Presolve removed 0 rows and 2627 columns\n",
      "Presolve time: 0.01s\n",
      "Presolved: 129 rows, 7633 columns, 16278 nonzeros\n",
      "Variable types: 7613 continuous, 20 integer (0 binary)\n",
      "\n",
      "Root relaxation: objective 4.219607e+03, 311 iterations, 0.01 seconds (0.03 work units)\n",
      "\n",
      "    Nodes    |    Current Node    |     Objective Bounds      |     Work\n",
      " Expl Unexpl |  Obj  Depth IntInf | Incumbent    BestBd   Gap | It/Node Time\n",
      "\n",
      "     0     0 4219.60678    0   20 4.7017e+11 4219.60678   100%     -    0s\n",
      "     0     0 4224.93256    0   20 4.7017e+11 4224.93256   100%     -    0s\n",
      "H    0     0                    4682.8533151 4224.93256  9.78%     -    0s\n",
      "     0     2 4224.93256    0   20 4682.85332 4224.93256  9.78%     -    0s\n",
      "*  393   447              29    4444.3900834 4247.68560  4.43%   4.9    0s\n",
      "*  515   454              28    4444.0231211 4275.28715  3.80%   5.1    0s\n",
      "*  520   454              30    4436.3316264 4275.28715  3.63%   5.1    0s\n",
      "*  823   634              24    4418.5679328 4280.23644  3.13%   5.4    0s\n",
      "* 1124   741              21    4415.8099733 4288.32835  2.89%   5.6    0s\n",
      "* 1488   758              22    4408.7149614 4320.42840  2.00%   5.6    0s\n",
      "\n",
      "Explored 6175 nodes (37761 simplex iterations) in 0.97 seconds (2.02 work units)\n",
      "Thread count was 14 (of 14 available processors)\n",
      "\n",
      "Solution count 8: 4408.71 4415.81 4418.57 ... 4.70174e+11\n",
      "\n",
      "Optimal solution found (tolerance 1.00e-04)\n",
      "Best objective 4.408714961420e+03, best bound 4.408714961420e+03, gap 0.0000%\n",
      "[CHECK DFN] obj(x_ip)=4408.72  ip_y=4408.71  rel_err=8.593e-07\n",
      "Set parameter OutputFlag to value 1\n",
      "Set parameter TimeLimit to value 1800\n",
      "Gurobi Optimizer version 12.0.3 build v12.0.3rc0 (mac64[arm] - Darwin 25.2.0 25C56)\n",
      "\n",
      "CPU model: Apple M4 Max\n",
      "Thread count: 14 physical cores, 14 logical processors, using up to 14 threads\n",
      "\n",
      "Non-default parameters:\n",
      "TimeLimit  1800\n",
      "\n",
      "Optimize a model with 278 rows, 10992 columns and 22004 nonzeros\n",
      "Model fingerprint: 0xf77aa9bf\n",
      "Variable types: 10972 continuous, 20 integer (0 binary)\n",
      "Coefficient statistics:\n",
      "  Matrix range     [3e-02, 1e+00]\n",
      "  Objective range  [2e-01, 4e+08]\n",
      "  Bounds range     [1e+00, 1e+06]\n",
      "  RHS range        [9e-03, 3e+01]\n",
      "Found heuristic solution: objective 7.302132e+09\n",
      "Presolve removed 0 rows and 3975 columns\n",
      "Presolve time: 0.01s\n",
      "Presolved: 278 rows, 7017 columns, 14053 nonzeros\n",
      "Variable types: 6997 continuous, 20 integer (0 binary)\n",
      "\n",
      "Root relaxation: objective -1.040709e+04, 95 iterations, 0.00 seconds (0.00 work units)\n",
      "\n",
      "    Nodes    |    Current Node    |     Objective Bounds      |     Work\n",
      " Expl Unexpl |  Obj  Depth IntInf | Incumbent    BestBd   Gap | It/Node Time\n",
      "\n",
      "     0     0 -10407.090    0    9 7.3021e+09 -10407.090   100%     -    0s\n",
      "H    0     0                    -9989.505102 -10407.090  4.18%     -    0s\n",
      "     0     0 -10397.708    0    9 -9989.5051 -10397.708  4.09%     -    0s\n",
      "H    0     0                    -10066.15804 -10397.708  3.29%     -    0s\n",
      "     0     2 -10397.708    0    9 -10066.158 -10397.708  3.29%     -    0s\n",
      "*   34     7               8    -10072.62095 -10231.426  1.58%  27.4    0s\n",
      "*   39     7               8    -10167.34833 -10224.075  0.56%  27.6    0s\n",
      "\n",
      "Explored 81 nodes (2133 simplex iterations) in 0.07 seconds (0.13 work units)\n",
      "Thread count was 14 (of 14 available processors)\n",
      "\n",
      "Solution count 5: -10167.3 -10072.6 -10066.2 ... 7.30213e+09\n",
      "No other solutions better than -10167.3\n",
      "\n",
      "Optimal solution found (tolerance 1.00e-04)\n",
      "Best objective -1.016734833008e+04, best bound -1.016734833008e+04, gap 0.0000%\n",
      "[CHECK DFN_AfixI] obj(x_ip)=-10167.4  ip_y=-10167.3  rel_err=2.623e-06\n",
      "Set parameter OutputFlag to value 1\n",
      "Set parameter TimeLimit to value 1800\n",
      "Gurobi Optimizer version 12.0.3 build v12.0.3rc0 (mac64[arm] - Darwin 25.2.0 25C56)\n",
      "\n",
      "CPU model: Apple M4 Max\n",
      "Thread count: 14 physical cores, 14 logical processors, using up to 14 threads\n",
      "\n",
      "Non-default parameters:\n",
      "TimeLimit  1800\n",
      "\n",
      "Optimize a model with 1282 rows, 789 columns and 21397 nonzeros\n",
      "Model fingerprint: 0x54e79dc1\n",
      "Variable types: 513 continuous, 276 integer (256 binary)\n",
      "Coefficient statistics:\n",
      "  Matrix range     [2e-06, 6e+01]\n",
      "  Objective range  [7e+04, 7e+04]\n",
      "  Bounds range     [1e+00, 6e+01]\n",
      "  RHS range        [2e-02, 6e+01]\n",
      "Presolve removed 257 rows and 1 columns\n",
      "Presolve time: 0.04s\n",
      "Presolved: 1025 rows, 788 columns, 21005 nonzeros\n",
      "Variable types: 512 continuous, 276 integer (256 binary)\n",
      "\n",
      "Root relaxation: objective -1.473207e+06, 356 iterations, 0.01 seconds (0.02 work units)\n",
      "\n",
      "    Nodes    |    Current Node    |     Objective Bounds      |     Work\n",
      " Expl Unexpl |  Obj  Depth IntInf | Incumbent    BestBd   Gap | It/Node Time\n",
      "\n",
      "     0     0 -1473206.9    0  110          - -1473206.9      -     -    0s\n",
      "     0     0 -577600.45    0  135          - -577600.45      -     -    0s\n",
      "     0     0 -497232.06    0  136          - -497232.06      -     -    0s\n",
      "     0     0 -495930.52    0  139          - -495930.52      -     -    0s\n",
      "     0     0 -495930.49    0  139          - -495930.49      -     -    0s\n",
      "     0     0 -442971.55    0  147          - -442971.55      -     -    0s\n",
      "     0     0 -442043.21    0  148          - -442043.21      -     -    0s\n",
      "     0     0 -441349.19    0  148          - -441349.19      -     -    0s\n",
      "     0     0 -441113.38    0  152          - -441113.38      -     -    0s\n",
      "     0     0 -433920.30    0  153          - -433920.30      -     -    0s\n",
      "     0     0 -432455.66    0  152          - -432455.66      -     -    0s\n",
      "     0     0 -431648.54    0  153          - -431648.54      -     -    0s\n",
      "     0     0 -431374.38    0  153          - -431374.38      -     -    0s\n",
      "     0     0 -428775.04    0  153          - -428775.04      -     -    1s\n",
      "     0     0 -428465.41    0  150          - -428465.41      -     -    1s\n",
      "     0     2 -428465.41    0  150          - -428465.41      -     -    2s\n",
      "  4857  4577 134950.870  158  136          - -360741.00      -  44.0    5s\n",
      "H 4865  4352                    444330.33575 -360741.00   181%  43.9    7s\n",
      "  4882  4367 -345590.11   17  149 444330.336 -360741.00   181%  45.1   10s\n",
      "H 7671  5456                    89716.813744 -326813.58   464%  55.2   14s\n",
      "* 9722  5097             151    54936.380807 -311057.65   666%  51.5   14s\n",
      "  9977  4852     cutoff   61      54936.3808 -311057.65   666%  50.9   15s\n",
      " 23211 12827 27115.5013   65   67 54936.3808 -216329.58   494%  48.4   20s\n",
      " 36136 23679 3679.19150   50   93 54936.3808 -194507.79   454%  48.2   25s\n",
      " 38793 24764 -78109.150   34  150 54936.3808 -191564.58   449%  48.1   43s\n",
      " 38800 24769 -16447.052   54  159 54936.3808 -191564.58   449%  48.0   45s\n",
      "H38800 23530                    54784.562336 -191564.58   450%  48.0   46s\n",
      " 38805 23533 -30881.664   65  158 54784.5623 -191564.58   450%  48.0   50s\n",
      " 38810 23536 -41941.471   46  162 54784.5623 -191564.58   450%  48.0   55s\n",
      " 38815 23540 46278.8746   84  161 54784.5623 -191564.58   450%  48.0   60s\n",
      "H38819 22364                    26131.795071 -191564.58   833%  48.0   65s\n",
      " 38824 22368 21641.6221  158  163 26131.7951 -191564.58   833%  48.0   71s\n",
      "H38825 21249                    26122.066612 -191564.58   833%  48.0   73s\n",
      " 38827 21251 -51460.974   39  165 26122.0666 -191564.58   833%  48.0   76s\n",
      " 38830 21253 -56378.745   45  164 26122.0666 -191564.58   833%  48.0   80s\n",
      " 38833 21255 -139847.41   28  166 26122.0666 -191564.58   833%  48.0   86s\n",
      " 38835 21256 -13177.701   52  164 26122.0666 -191564.58   833%  48.0   90s\n",
      " 38838 21258 11352.5771   40  163 26122.0666 -191564.58   833%  48.0   96s\n",
      " 38841 21260 -125086.34   28  165 26122.0666 -191564.58   833%  48.0  101s\n",
      " 38844 21262 26122.0666   52  166 26122.0666 -191564.58   833%  48.0  106s\n",
      " 38847 21264 -20729.294   41  161 26122.0666 -191564.58   833%  48.0  111s\n",
      " 38850 21266 25228.6129   50  170 26122.0666 -191564.58   833%  48.0  115s\n",
      " 38852 21267 -179270.60   27  167 26122.0666 -191564.58   833%  48.0  120s\n",
      " 38855 21269 -20642.395   51  162 26122.0666 -191564.58   833%  48.0  125s\n",
      " 38858 21271 -3363.5509   42  164 26122.0666 -191564.58   833%  48.0  130s\n",
      " 38861 21273 -68354.567   36  162 26122.0666 -191564.58   833%  48.0  135s\n",
      " 38864 21275 22706.9555   59  164 26122.0666 -191564.58   833%  48.0  140s\n",
      " 38867 21277 26122.0666   40  162 26122.0666 -191564.58   833%  48.0  145s\n",
      " 38871 21280 -142004.85   28  164 26122.0666 -191564.58   833%  48.0  150s\n",
      " 38875 21283 -57590.721   42  165 26122.0666 -191564.58   833%  47.9  155s\n",
      " 38880 21286 -124911.19   36  166 26122.0666 -191564.58   833%  47.9  160s\n",
      " 38883 21288 -184651.85   24  161 26122.0666 -191564.58   833%  47.9  165s\n",
      " 38887 21291 22878.0633   61  161 26122.0666 -191564.58   833%  47.9  171s\n",
      " 38889 21292 -125726.86   39  163 26122.0666 -191564.58   833%  47.9  175s\n",
      " 38892 21294 -79853.190   39  163 26122.0666 -191564.58   833%  47.9  180s\n",
      " 38895 21296 -76696.523   35  162 26122.0666 -191564.58   833%  47.9  185s\n",
      " 38900 21299 -16447.052   54  166 26122.0666 -191564.58   833%  47.9  190s\n",
      " 38904 21302 -20720.881   33  170 26122.0666 -191564.58   833%  47.9  195s\n",
      " 38909 21305 -147556.05   28  170 26122.0666 -191564.58   833%  47.9  200s\n",
      " 38913 21308 19405.9020   56  161 26122.0666 -191564.58   833%  47.9  205s\n",
      " 38918 21311 -45691.042   58  168 26122.0666 -191564.58   833%  47.9  210s\n",
      " 38922 21314 10496.0408   70  170 26122.0666 -191564.58   833%  47.9  215s\n",
      " 38926 21317 26122.0666   66  170 26122.0666 -191564.58   833%  47.9  220s\n",
      " 38929 21319 -135095.95   30  166 26122.0666 -191564.58   833%  47.9  225s\n",
      " 38933 21321 -139847.41   28  164 26122.0666 -191564.58   833%  47.9  230s\n",
      " 38937 21324 5664.91561   53  164 26122.0666 -191564.58   833%  47.9  235s\n",
      " 38941 21327 -125086.34   28  167 26122.0666 -191564.58   833%  47.9  240s\n",
      " 38946 21330 -57984.174   46  167 26122.0666 -191564.58   833%  47.9  246s\n",
      " 38950 21333 25228.6129   50  172 26122.0666 -191564.58   833%  47.9  251s\n",
      " 38953 21335 26122.0666   72  164 26122.0666 -191564.58   833%  47.9  255s\n",
      " 38957 21337 -28739.008   56  163 26122.0666 -191564.58   833%  47.8  260s\n",
      " 38963 21346 -191564.58   30  154 26122.0666 -191564.58   833%  48.2  265s\n",
      " 39862 21910 -191564.58   36  143 26122.0666 -191564.58   833%  49.8  270s\n",
      " 44445 23938 -191564.58   37  145 26122.0666 -191564.58   833%  51.4  275s\n",
      " 50164 26749 -17115.809   61   99 26122.0666 -191564.58   833%  54.0  280s\n",
      " 56140 29316 -112720.89   37  122 26122.0666 -191564.58   833%  54.7  285s\n",
      " 62593 31802     cutoff   54      26122.0666 -191564.58   833%  55.3  290s\n",
      " 72212 35586 -156056.36   39  137 26122.0666 -191564.58   833%  55.9  295s\n",
      " 78329 38383 -124407.92   40  120 26122.0666 -191564.58   833%  56.6  300s\n",
      "H80881 37144                    26028.487984 -191564.58   836%  56.7  301s\n",
      "H83109 37324                    25535.504643 -191564.58   850%  56.8  303s\n",
      "H85470 37565                    25472.246367 -191564.58   852%  57.2  306s\n",
      "H86137 36426                    25453.265797 -191564.58   853%  57.3  306s\n",
      "H86241 35505                    25205.125719 -191564.58   860%  57.3  306s\n",
      "H88473 36872                    25142.179206 -191564.58   862%  57.3  309s\n",
      "H88574 36871                    25102.142220 -191564.58   863%  57.3  309s\n",
      "H89439 36845                    24761.166771 -191564.58   874%  57.4  309s\n",
      " 89695 38262 -158390.31   41  130 24761.1668 -191564.58   874%  57.5  310s\n",
      "H91727 39265                    24626.428736 -191564.58   878%  57.4  311s\n",
      "H92954 40074                    24619.475611 -191564.58   878%  57.4  312s\n",
      "H93446 40074                    24619.456320 -191564.58   878%  57.4  312s\n",
      "H96249 42173                    24484.718285 -191564.58   882%  57.3  314s\n",
      " 97263 42575 -91271.926   47  121 24484.7183 -191564.58   882%  57.3  315s\n",
      "H97645 42575                    24472.112333 -191564.58   883%  57.3  315s\n",
      "H99740 44650                    24469.431683 -191564.58   883%  57.3  317s\n",
      "H99800 44650                    24465.526001 -191564.58   883%  57.3  317s\n",
      " 105057 48778 -156623.04   42  135 24465.5260 -190587.21   879%  57.5  320s\n",
      " 113843 55079 -137183.27   41  134 24465.5260 -186339.37   862%  58.3  325s\n",
      " 121665 60276 -138414.17   43  124 24465.5260 -183491.24   850%  58.6  330s\n",
      " 131555 67279 -32484.491   62   98 24465.5260 -179530.67   834%  58.6  335s\n",
      " 138715 71464 -113085.97   45  124 24465.5260 -176797.20   823%  58.6  340s\n",
      "H138927 71448                    24382.980125 -176735.62   825%  58.6  340s\n",
      " 145652 76238 -91500.545   44  126 24382.9801 -174535.53   816%  58.7  346s\n",
      " 153312 81550 -46951.178   46  129 24382.9801 -172141.77   806%  58.7  350s\n",
      " 162164 87460 -122059.12   46  126 24382.9801 -170163.83   798%  59.0  356s\n",
      " 169358 92303 -25842.243   58  102 24382.9801 -168505.77   791%  59.1  360s\n",
      " 176905 97369 -6978.4114   66   92 24382.9801 -166470.34   783%  59.0  366s\n",
      " 181951 100877 -37429.715   54  115 24382.9801 -165264.75   778%  59.1  370s\n",
      " 191240 107225 -33416.785   53  113 24382.9801 -163609.82   771%  59.3  375s\n",
      " 197255 111352 -129807.26   43  134 24382.9801 -162350.03   766%  59.4  380s\n",
      " 205765 117146 -125254.27   40  132 24382.9801 -160810.41   760%  59.8  385s\n",
      " 213132 121940 -112020.55   46  121 24382.9801 -159443.83   754%  60.0  390s\n",
      " 219098 125440 -110710.10   43  129 24382.9801 -158390.71   750%  60.1  395s\n",
      " 223219 128686     cutoff   57      24382.9801 -157560.81   746%  60.1  400s\n",
      " 231161 133807 -122520.99   44  133 24382.9801 -156305.44   741%  60.3  405s\n",
      " 238813 139005 -62357.479   44  118 24382.9801 -155063.64   736%  60.4  410s\n",
      " 243973 142296 -61486.256   45  117 24382.9801 -154450.83   733%  60.5  415s\n",
      " 252967 148438 14310.6824   59  106 24382.9801 -153378.15   729%  60.5  420s\n",
      " 263232 154997 12713.1739   57  100 24382.9801 -152023.85   723%  60.6  425s\n",
      " 274746 162321 -48004.804   50  117 24382.9801 -150722.12   718%  60.7  430s\n",
      " 284414 168487     cutoff   59      24382.9801 -149789.13   714%  60.9  435s\n",
      " 295885 176352 11808.5320   55  109 24382.9801 -148427.90   709%  60.9  440s\n",
      "H300569 178577                    23258.166314 -147878.18   736%  61.0  442s\n",
      "H302259 179763                    23131.506401 -147713.17   739%  61.0  444s\n",
      " 303445 180480 -65859.775   49  124 23131.5064 -147552.74   738%  60.9  445s\n",
      " 314278 187495 -26525.253   60   96 23131.5064 -146418.33   733%  61.1  450s\n",
      " 324875 193868 -84769.193   52  112 23131.5064 -145399.44   729%  61.2  455s\n",
      " 334891 199935 17457.4087   57  109 23131.5064 -144476.50   725%  61.2  460s\n",
      " 345010 206026 8170.41900   60  103 23131.5064 -143587.25   721%  61.3  465s\n",
      " 356873 213119 -113383.07   45  126 23131.5064 -142418.13   716%  61.3  470s\n",
      "H365254 218228                    23057.652225 -141683.25   714%  61.5  475s\n",
      "H365330 218184                    22973.647679 -141683.25   717%  61.5  475s\n",
      "H365914 218169                    22937.626090 -141669.75   718%  61.5  475s\n",
      " 366262 218207 -68354.732   49  100 22937.6261 -141624.50   717%  61.5  480s\n",
      " 366439 218378 -43407.304   54   94 22937.6261 -141624.50   717%  61.5  485s\n",
      "H367006 218803                    22799.691444 -141624.50   721%  61.5  488s\n",
      " 369342 221030     cutoff   64      22799.6914 -141370.70   720%  61.5  490s\n",
      "H369751 220826                    22348.978387 -141370.70   733%  61.5  490s\n",
      "H371477 221719                    22258.086385 -141248.86   735%  61.5  491s\n",
      "H375261 223369                    22093.565887 -140862.95   738%  61.5  493s\n",
      " 376505 225594 -24931.107   55  106 22093.5659 -140701.72   737%  61.5  496s\n",
      "H377561 225586                    22073.685151 -140694.47   737%  61.5  496s\n",
      "H378757 225581                    22064.512753 -140646.21   737%  61.5  496s\n",
      "H381237 227291                    21973.051636 -140308.12   739%  61.5  497s\n",
      "H381927 227273                    21942.364541 -140278.92   739%  61.5  497s\n",
      "H384053 229128                    21888.598134 -140109.52   740%  61.6  499s\n",
      "H384135 228891                    21421.315692 -140107.47   754%  61.6  499s\n",
      " 385067 229916 -43650.955   47  120 21421.3157 -140023.26   754%  61.6  500s\n",
      "H386772 230497                    21292.963693 -139899.25   757%  61.6  501s\n",
      "H387890 231275                    21278.924014 -139825.22   757%  61.7  502s\n",
      "H389625 232236                    21245.404231 -139711.86   758%  61.7  502s\n",
      "H389946 232235                    21242.729154 -139711.86   758%  61.7  502s\n",
      "H392611 233986                    21210.083423 -139440.20   757%  61.8  504s\n",
      "H393043 233914                    21058.736384 -139429.61   762%  61.8  504s\n",
      " 393430 234864 -69693.465   50  116 21058.7364 -139369.64   762%  61.8  505s\n",
      " 402258 239794 20450.8797   61  106 21058.7364 -138592.67   758%  61.9  510s\n",
      " 412738 246812 -107041.24   42  128 21058.7364 -137751.85   754%  61.9  515s\n",
      " 421843 252524 -93343.780   52  111 21058.7364 -137108.06   751%  62.1  520s\n",
      " 433663 259018 19391.5790   61   91 21058.7364 -136310.47   747%  62.1  525s\n",
      " 441947 264077 -11036.899   60  103 21058.7364 -135787.41   745%  62.1  530s\n",
      " 450746 269308 16863.5457   41  105 21058.7364 -135194.78   742%  62.1  536s\n",
      " 456606 272923 -52498.017   46  121 21058.7364 -134807.24   740%  62.2  540s\n",
      " 463371 277282 -14872.139   44  112 21058.7364 -134398.68   738%  62.3  545s\n",
      "H466761 278808                    21037.711808 -134125.09   738%  62.3  547s\n",
      "H469038 280051                    21035.278878 -134009.81   737%  62.3  549s\n",
      " 469542 281139 -29612.089   45  111 21035.2789 -133925.52   737%  62.3  550s\n",
      "H471766 281914                    21029.045279 -133800.71   736%  62.3  551s\n",
      " 475027 284664 18497.9476   64   99 21029.0453 -133577.89   735%  62.4  556s\n",
      " 480225 287299 -81266.487   46  122 21029.0453 -133224.53   734%  62.5  560s\n",
      " 487930 292080 -26637.681   56  106 21029.0453 -132705.98   731%  62.5  565s\n",
      " 496825 297530 -91743.600   45  123 21029.0453 -132230.74   729%  62.6  570s\n",
      " 508369 303322 -103099.21   46  128 21029.0453 -131585.40   726%  62.7  575s\n",
      " 517681 309582 -14414.777   58  111 21029.0453 -131085.11   723%  62.7  580s\n",
      " 527797 315611 -46721.084   50  112 21029.0453 -130476.95   720%  62.7  585s\n",
      " 537874 321668 19937.7192   59  107 21029.0453 -129934.06   718%  62.8  590s\n",
      " 547399 327180 -93086.245   45  128 21029.0453 -129381.97   715%  62.9  595s\n",
      "H550810 328419                    21000.372813 -129225.46   715%  62.9  596s\n",
      "H550963 328407                    20985.838363 -129221.89   716%  62.9  596s\n",
      " 555802 331700 -19440.845   61  100 20985.8384 -128931.63   714%  62.9  600s\n",
      " 566741 338384 7779.84025   60  104 20985.8384 -128390.80   712%  63.0  605s\n",
      " 576678 344218 -79570.809   51  116 20985.8384 -127833.04   709%  63.0  610s\n",
      " 586573 350250 12925.6793   65   89 20985.8384 -127329.89   707%  63.1  615s\n",
      " 596810 356006 -79769.501   45  126 20985.8384 -126905.81   705%  63.1  620s\n",
      " 608213 362010 -75315.613   48  120 20985.8384 -126421.96   702%  63.1  625s\n",
      " 617462 368002 -78840.563   46  120 20985.8384 -126025.81   701%  63.2  630s\n",
      " 627750 373585 -47584.434   44  127 20985.8384 -125556.35   698%  63.2  635s\n",
      "H634720 377038                    20852.340468 -125303.25   701%  63.3  638s\n",
      "H634999 377025                    20844.146637 -125303.25   701%  63.3  638s\n",
      " 636643 378446 -3249.6964   59  100 20844.1466 -125220.84   701%  63.3  640s\n",
      " 648538 385590 -82209.196   48  125 20844.1466 -124691.08   698%  63.3  645s\n",
      " 656950 390201 -59601.815   51  121 20844.1466 -124297.49   696%  63.3  650s\n",
      " 667711 396468 -30890.470   53  110 20844.1466 -123832.97   694%  63.3  655s\n",
      " 679067 402581 -59051.690   50  122 20844.1466 -123412.51   692%  63.2  660s\n",
      " 688610 408223 -74973.043   54  111 20844.1466 -122963.22   690%  63.3  665s\n",
      " 698894 414034 -52435.190   48  130 20844.1466 -122535.67   688%  63.3  670s\n",
      " 709101 420017 -96307.289   44  135 20844.1466 -122190.41   686%  63.4  675s\n",
      " 719810 426244 -69362.481   46  122 20844.1466 -121786.13   684%  63.5  680s\n",
      " 730494 432352 15880.1897   51  116 20844.1466 -121421.54   683%  63.5  685s\n",
      " 740119 438117 -105803.78   46  130 20844.1466 -121030.48   681%  63.6  690s\n",
      " 750475 443533 -93266.191   46  125 20844.1466 -120667.32   679%  63.6  695s\n",
      " 760433 449550 16271.4607   57  110 20844.1466 -120362.31   677%  63.6  700s\n",
      " 770559 455369 2917.01759   55  111 20844.1466 -119968.67   676%  63.7  705s\n",
      " 780060 460966 -77503.787   45  126 20844.1466 -119652.32   674%  63.7  710s\n",
      "H790484 466330                    20840.096000 -119284.98   672%  63.7  715s\n",
      " 790502 474495 -2947.7944   58  104 20840.0960 -119284.98   672%  63.7  722s\n",
      "H804556 474730                    20831.432358 -118850.79   671%  63.7  723s\n",
      " 807480 477017 17168.1884   53  115 20831.4324 -118713.68   670%  63.8  725s\n",
      " 817921 482518 16819.0551   70   84 20831.4324 -118394.56   668%  63.8  730s\n",
      " 828957 488689 -78340.929   47  119 20831.4324 -118063.16   667%  63.9  735s\n",
      " 839190 494372 3804.11580   56  112 20831.4324 -117718.40   665%  63.8  740s\n",
      " 849642 500266 1648.67678   49  111 20831.4324 -117369.95   663%  63.9  745s\n",
      " 859737 506259     cutoff   56      20831.4324 -117046.65   662%  63.9  750s\n",
      "H862780 507488                    20816.333797 -116943.07   662%  63.9  751s\n",
      "H863165 507465                    20799.856972 -116938.95   662%  63.9  751s\n",
      "H865365 508724                    20766.610443 -116870.53   663%  63.9  753s\n",
      " 868819 511152 -72890.321   53  109 20766.6104 -116773.26   662%  63.9  755s\n",
      " 879121 516873 13358.9380   55  104 20766.6104 -116444.18   661%  63.9  760s\n",
      " 889389 522162 -40792.613   48  114 20766.6104 -116141.56   659%  64.0  765s\n",
      " 899820 528571 -26203.202   51  109 20766.6104 -115798.21   658%  64.0  770s\n",
      " 909923 534339 -52554.299   53  110 20766.6104 -115462.37   656%  64.0  775s\n",
      " 919686 540011 -72204.125   50  124 20766.6104 -115147.24   654%  64.0  780s\n",
      " 929779 545054 15883.2632   61   92 20766.6104 -114878.61   653%  64.0  785s\n",
      " 940843 551416 15115.6885   65   96 20766.6104 -114596.55   652%  64.1  790s\n",
      " 950725 556591 -57426.793   49  125 20766.6104 -114316.19   650%  64.1  795s\n",
      " 960630 562131 -72205.387   48  124 20766.6104 -114049.46   649%  64.1  800s\n",
      " 969848 567201 11795.2378   55  108 20766.6104 -113790.51   648%  64.1  805s\n",
      " 980546 573439 -64564.543   48  122 20766.6104 -113493.08   647%  64.2  810s\n",
      " 991325 579338 -56733.776   57  106 20766.6104 -113225.08   645%  64.2  815s\n",
      " 1000368 584224 -9460.0803   50  106 20766.6104 -113021.63   644%  64.2  820s\n",
      " 1010746 590155 -79395.487   52  111 20766.6104 -112763.76   643%  64.3  825s\n",
      " 1021100 595807 10724.1918   55  111 20766.6104 -112485.14   642%  64.3  830s\n",
      " 1031356 601181     cutoff   59      20766.6104 -112220.21   640%  64.3  835s\n",
      " 1041223 606782 -7450.4371   57  102 20766.6104 -111961.46   639%  64.3  840s\n",
      " 1049757 611400 -19356.789   51  120 20766.6104 -111739.71   638%  64.3  845s\n",
      " 1059406 616866 -58331.192   50  120 20766.6104 -111488.06   637%  64.3  850s\n",
      " 1069793 622812 -51390.723   58  107 20766.6104 -111248.16   636%  64.3  855s\n",
      " 1080248 628662 -79695.782   47  114 20766.6104 -110977.92   634%  64.4  860s\n",
      " 1090512 634169 -64247.888   48  120 20766.6104 -110756.55   633%  64.4  865s\n",
      " 1100663 639693 -78486.475   46  125 20766.6104 -110541.70   632%  64.4  870s\n",
      "H1108334 643598                    20748.223891 -110349.04   632%  64.4  874s\n",
      " 1108910 644960 -38456.452   50  124 20748.2239 -110326.37   632%  64.4  875s\n",
      " 1118765 649707 13900.1355   54  105 20748.2239 -110084.08   631%  64.5  880s\n",
      " 1127642 654495 15525.1457   66  101 20748.2239 -109866.15   630%  64.5  885s\n",
      " 1137654 660010 -70250.830   45  123 20748.2239 -109661.61   629%  64.5  890s\n",
      " 1146867 664969 -66611.358   51  116 20748.2239 -109448.76   628%  64.6  895s\n",
      " 1156975 670036 5274.26252   57   93 20748.2239 -109256.28   627%  64.6  900s\n",
      " 1168203 676062 -36637.915   47  117 20748.2239 -109001.02   625%  64.6  905s\n",
      " 1175410 679850     cutoff   56      20748.2239 -108830.94   625%  64.6  910s\n",
      " 1185915 685403 17193.6960   56  105 20748.2239 -108615.58   623%  64.6  915s\n",
      " 1195981 690505 -3245.0945   54  100 20748.2239 -108392.42   622%  64.6  920s\n",
      " 1205668 696137 -67205.689   47  116 20748.2239 -108149.26   621%  64.6  925s\n",
      " 1215606 701573 20311.4485   62   95 20748.2239 -107940.88   620%  64.6  930s\n",
      " 1224600 706239 -52849.098   50  122 20748.2239 -107774.59   619%  64.6  935s\n",
      " 1226141 713767 -74487.625   50  116 20748.2239 -107741.19   619%  64.6  946s\n",
      " 1243804 716447 -10942.264   53  115 20748.2239 -107414.04   618%  64.7  950s\n",
      " 1251314 719780 -78386.406   54  104 20748.2239 -107250.00   617%  64.7  955s\n",
      " 1257576 723738 -63662.670   45  122 20748.2239 -107157.16   616%  64.7  960s\n",
      " 1268563 729392 -78020.373   49  111 20748.2239 -106944.57   615%  64.8  965s\n",
      " 1276988 733954 -22655.518   51  123 20748.2239 -106771.62   615%  64.8  970s\n",
      " 1286691 739382 -90197.040   49  118 20748.2239 -106569.00   614%  64.8  975s\n",
      "H1286807 739303                    20701.896832 -106569.00   615%  64.8  975s\n",
      " 1293694 743244 19655.4074   75   79 20701.8968 -106439.58   614%  64.8  980s\n",
      " 1303828 748491 2169.80075   52  117 20701.8968 -106251.66   613%  64.8  985s\n",
      " 1313918 753973 -54117.817   49  126 20701.8968 -106034.64   612%  64.8  990s\n",
      " 1322755 758166 -77554.924   47  122 20701.8968 -105857.74   611%  64.8  995s\n",
      " 1332362 763892 -21467.588   50  112 20701.8968 -105689.67   611%  64.8 1000s\n",
      " 1342206 768874 -52477.951   49  120 20701.8968 -105485.27   610%  64.8 1005s\n",
      " 1350926 773244 -68852.088   46  131 20701.8968 -105327.26   609%  64.9 1010s\n",
      " 1359555 778112     cutoff   53      20701.8968 -105170.98   608%  64.9 1015s\n",
      " 1369250 783132 7476.75975   55  113 20701.8968 -104994.43   607%  64.9 1020s\n",
      " 1377703 787666 5105.11790   54  109 20701.8968 -104819.18   606%  64.9 1025s\n",
      " 1388010 793246 -54098.192   49  112 20701.8968 -104635.76   605%  64.9 1030s\n",
      " 1396189 797782     cutoff   73      20701.8968 -104466.42   605%  64.9 1035s\n",
      " 1403801 801669 17382.3741   63  101 20701.8968 -104316.35   604%  64.9 1040s\n",
      " 1412722 806454 -57123.440   45  123 20701.8968 -104152.58   603%  64.9 1045s\n",
      " 1421403 810593 -10682.910   49  112 20701.8968 -104000.02   602%  64.9 1050s\n",
      " 1431681 815633     cutoff   44      20701.8968 -103829.31   602%  64.9 1055s\n",
      " 1440514 820476 -6610.4233   54  114 20701.8968 -103673.39   601%  64.9 1060s\n",
      " 1449846 825047 14842.1574   51  112 20701.8968 -103494.59   600%  64.9 1065s\n",
      " 1459917 829682 11150.5522   59  105 20701.8968 -103336.14   599%  64.9 1070s\n",
      " 1468338 834107 -70965.437   46  128 20701.8968 -103190.08   598%  65.0 1075s\n",
      " 1478583 839203 11490.2283   62   79 20701.8968 -103018.80   598%  65.0 1080s\n",
      " 1485565 842867 -33637.508   57  113 20701.8968 -102920.68   597%  65.0 1085s\n",
      " 1494272 847484 -78939.807   47  117 20701.8968 -102768.33   596%  65.0 1090s\n",
      " 1502300 852042 -43452.701   53  108 20701.8968 -102633.10   596%  65.0 1095s\n",
      " 1511024 856750 -58790.294   50  121 20701.8968 -102484.59   595%  65.0 1100s\n",
      " 1521358 862444 13105.7675   55  115 20701.8968 -102318.56   594%  65.0 1105s\n",
      " 1531707 867463 -43548.678   52  118 20701.8968 -102162.49   593%  65.1 1110s\n",
      " 1540161 872468 5219.53896   52   93 20701.8968 -102035.47   593%  65.1 1115s\n",
      " 1550362 877557 16773.0217   54  108 20701.8968 -101880.98   592%  65.1 1120s\n",
      " 1551962 877588     cutoff   61      20701.8968 -101859.16   592%  65.1 1125s\n",
      " 1552263 877900 -53978.270   49  131 20701.8968 -101858.69   592%  65.1 1130s\n",
      " 1554723 880126 -32612.764   45  126 20701.8968 -101839.94   592%  65.1 1135s\n",
      " 1563806 884673 -72345.514   46  125 20701.8968 -101666.00   591%  65.1 1140s\n",
      " 1571768 889062 -53671.727   52  117 20701.8968 -101541.37   590%  65.1 1145s\n",
      " 1582433 894403     cutoff   61      20701.8968 -101383.87   590%  65.1 1150s\n",
      " 1591050 898631 -48562.335   49  125 20701.8968 -101244.90   589%  65.1 1155s\n",
      " 1599406 903068 -605.17379   56  104 20701.8968 -101098.33   588%  65.2 1160s\n",
      " 1609665 908270     cutoff   60      20701.8968 -100942.07   588%  65.2 1165s\n",
      " 1620065 913797 -38882.564   61   99 20701.8968 -100789.48   587%  65.2 1170s\n",
      " 1628769 918506 -53638.267   51  121 20701.8968 -100665.53   586%  65.2 1175s\n",
      " 1639565 924492 -71483.882   46  123 20701.8968 -100491.54   585%  65.2 1180s\n",
      " 1647791 928357 -74708.365   48  120 20701.8968 -100366.79   585%  65.2 1185s\n",
      " 1657073 933457 -67080.029   47  123 20701.8968 -100248.92   584%  65.2 1190s\n",
      " 1667758 938645 -51343.749   49  120 20701.8968 -100102.15   584%  65.2 1195s\n",
      " 1677565 943902     cutoff   65      20701.8968 -99955.334   583%  65.2 1200s\n",
      " 1687440 948949 -63157.060   47  123 20701.8968 -99816.516   582%  65.2 1205s\n",
      " 1697580 954278 -16417.749   52  113 20701.8968 -99662.340   581%  65.2 1210s\n",
      " 1706321 958378 -76904.331   40  119 20701.8968 -99542.552   581%  65.2 1215s\n",
      " 1716005 963367 -72126.013   47  118 20701.8968 -99419.284   580%  65.3 1220s\n",
      " 1725273 968225 -45635.794   51  117 20701.8968 -99294.292   580%  65.3 1225s\n",
      " 1735598 973389 8652.22145   63  102 20701.8968 -99157.780   579%  65.3 1230s\n",
      " 1743495 977829 -13023.206   60  104 20701.8968 -99040.291   578%  65.3 1235s\n",
      " 1753689 982633 6886.86328   52  106 20701.8968 -98909.569   578%  65.3 1240s\n",
      " 1763042 987578     cutoff   66      20701.8968 -98777.762   577%  65.3 1245s\n",
      " 1772978 992923 18884.9986   53  112 20701.8968 -98638.083   576%  65.3 1250s\n",
      " 1783316 998029 6434.03169   57  114 20701.8968 -98491.236   576%  65.3 1255s\n",
      " 1792075 1002010 19170.5022   60   99 20701.8968 -98366.616   575%  65.3 1260s\n",
      " 1801338 1006679 1076.78012   50  113 20701.8968 -98254.156   575%  65.4 1265s\n",
      " 1810186 1011286 -59297.559   51  114 20701.8968 -98125.569   574%  65.4 1270s\n",
      " 1820297 1015842 -895.46476   51  119 20701.8968 -97984.649   573%  65.4 1275s\n",
      " 1829899 1021375 3569.78852   54  109 20701.8968 -97852.164   573%  65.4 1280s\n",
      " 1839960 1026592 -46955.655   47  120 20701.8968 -97705.401   572%  65.4 1285s\n",
      " 1850139 1031411 17583.2876   60   97 20701.8968 -97562.942   571%  65.4 1290s\n",
      " 1859851 1036261 -21127.534   54  114 20701.8968 -97432.889   571%  65.4 1295s\n",
      " 1868266 1041006 -36769.791   64   94 20701.8968 -97319.795   570%  65.4 1300s\n",
      " 1877553 1045694 -2720.5451   52  109 20701.8968 -97184.221   569%  65.4 1305s\n",
      " 1887630 1050383 -48764.366   50  120 20701.8968 -97075.721   569%  65.4 1310s\n",
      " 1895074 1054800 -3751.4581   53  109 20701.8968 -96968.139   568%  65.4 1315s\n",
      " 1906219 1060002     cutoff   60      20701.8968 -96820.694   568%  65.4 1320s\n",
      " 1915719 1064647 -20889.141   52  107 20701.8968 -96701.495   567%  65.4 1325s\n",
      " 1924675 1069261 -52713.127   47  127 20701.8968 -96591.209   567%  65.4 1330s\n",
      " 1934697 1074164 6708.06128   64   92 20701.8968 -96463.095   566%  65.5 1335s\n",
      " 1944568 1079119 -38161.084   57  106 20701.8968 -96336.442   565%  65.5 1340s\n",
      " 1954602 1084153 -52465.102   49  119 20701.8968 -96210.599   565%  65.5 1345s\n",
      " 1964839 1089234 -70177.905   54  110 20701.8968 -96095.403   564%  65.5 1350s\n",
      " 1974754 1094084 -29339.151   55  113 20701.8968 -95969.711   564%  65.5 1355s\n",
      " 1983222 1098488 -18066.369   65   98 20701.8968 -95871.210   563%  65.5 1360s\n",
      " 1993355 1103347 3319.89560   59  106 20701.8968 -95732.496   562%  65.5 1365s\n",
      " 2000989 1107108 -33153.965   46  113 20701.8968 -95631.304   562%  65.5 1370s\n",
      " 2009054 1111306     cutoff   70      20701.8968 -95527.423   561%  65.5 1375s\n",
      " 2018419 1116070     cutoff   56      20701.8968 -95409.468   561%  65.5 1380s\n",
      " 2026947 1120341 -75998.632   48  123 20701.8968 -95297.771   560%  65.5 1385s\n",
      " 2035275 1124445 -49344.083   48  119 20701.8968 -95195.375   560%  65.5 1390s\n",
      " 2045600 1129569 6200.66280   48  117 20701.8968 -95059.692   559%  65.6 1395s\n",
      " 2054371 1134070 -65738.307   49  119 20701.8968 -94944.976   559%  65.6 1400s\n",
      " 2062690 1138125 -65480.113   45  129 20701.8968 -94827.804   558%  65.6 1405s\n",
      " 2069230 1141372 15863.0512   60  100 20701.8968 -94740.118   558%  65.6 1410s\n",
      " 2078651 1145987 -57523.075   47  123 20701.8968 -94619.049   557%  65.6 1415s\n",
      " 2086587 1149846 14423.9181   64  104 20701.8968 -94524.449   557%  65.7 1420s\n",
      " 2094053 1153299 -68054.133   54  106 20701.8968 -94431.065   556%  65.7 1425s\n",
      " 2103164 1157498 12240.9841   59  103 20701.8968 -94334.797   556%  65.7 1430s\n",
      " 2111607 1161486     cutoff   56      20701.8968 -94233.020   555%  65.7 1435s\n",
      " 2117959 1164605 -70936.596   46  112 20701.8968 -94150.269   555%  65.7 1440s\n",
      " 2126097 1168208 -34192.860   52  107 20701.8968 -94059.481   554%  65.8 1445s\n",
      " 2134864 1172231 -48138.910   45  125 20701.8968 -93966.187   554%  65.8 1450s\n",
      " 2143431 1176498 -71507.716   50  114 20701.8968 -93866.447   553%  65.8 1455s\n",
      " 2151887 1180308 -12581.306   51  116 20701.8968 -93774.380   553%  65.8 1460s\n",
      " 2157535 1183204 19668.5819   54  104 20701.8968 -93710.217   553%  65.8 1465s\n",
      " 2166326 1187081 -64214.676   54  115 20701.8968 -93595.600   552%  65.9 1470s\n",
      " 2173539 1190632 -23919.954   55  107 20701.8968 -93511.987   552%  65.9 1475s\n",
      " 2181901 1194546 -61752.586   46  123 20701.8968 -93427.806   551%  65.9 1480s\n",
      " 2190139 1198482 3458.71424   47  106 20701.8968 -93318.002   551%  65.9 1485s\n",
      " 2197493 1202148 -59421.058   50  105 20701.8968 -93227.429   550%  65.9 1490s\n",
      " 2206027 1205995     cutoff   73      20701.8968 -93128.320   550%  66.0 1495s\n",
      " 2213463 1209920 -327.70429   51  104 20701.8968 -93044.223   549%  66.0 1500s\n",
      " 2221377 1213402 -67161.099   50  120 20701.8968 -92957.845   549%  66.0 1505s\n",
      " 2229730 1217506 -49273.708   49  122 20701.8968 -92853.008   549%  66.0 1510s\n",
      " 2236624 1220533 -54890.233   45  126 20701.8968 -92775.151   548%  66.0 1515s\n",
      " 2245321 1224662 -47411.048   56   94 20701.8968 -92680.022   548%  66.1 1520s\n",
      " 2252501 1228338 -26402.986   49  116 20701.8968 -92613.640   547%  66.1 1525s\n",
      " 2260902 1232247 1295.65112   58  105 20701.8968 -92514.912   547%  66.1 1530s\n",
      " 2268112 1235365 -45494.550   47  127 20701.8968 -92435.151   547%  66.2 1535s\n",
      " 2275271 1238949 -58531.476   46  122 20701.8968 -92359.263   546%  66.2 1540s\n",
      " 2282665 1242449 -59418.076   48  112 20701.8968 -92274.221   546%  66.2 1545s\n",
      " 2291238 1246740     cutoff   59      20701.8968 -92175.984   545%  66.2 1550s\n",
      " 2297060 1248666 -54907.383   45  122 20701.8968 -92126.152   545%  66.2 1555s\n",
      "H2299090 1241650                    18498.265007 -92103.786   598%  66.2 1557s\n",
      " 2301536 1243055 -61935.925   46  121 18498.2650 -92079.337   598%  66.3 1560s\n",
      " 2308337 1246431 -61840.586   47  124 18498.2650 -92003.406   597%  66.3 1565s\n",
      " 2315804 1249828 -53977.292   50  119 18498.2650 -91921.213   597%  66.3 1570s\n",
      " 2322875 1252743 -17775.338   47  115 18498.2650 -91842.939   596%  66.3 1575s\n",
      " 2330992 1256614     cutoff   53      18498.2650 -91748.468   596%  66.4 1580s\n",
      " 2338453 1259354 -35937.710   50  118 18498.2650 -91653.850   595%  66.4 1585s\n",
      " 2344736 1262523 -57738.614   46  129 18498.2650 -91582.669   595%  66.4 1590s\n",
      " 2351812 1265842 -31803.917   53  112 18498.2650 -91508.444   595%  66.4 1595s\n",
      " 2360804 1269531 -54079.269   48  125 18498.2650 -91406.932   594%  66.5 1600s\n",
      " 2367504 1272448 1374.63410   58  101 18498.2650 -91344.750   594%  66.5 1605s\n",
      " 2374665 1275154 -43286.295   50  116 18498.2650 -91264.960   593%  66.5 1610s\n",
      "H2374676 1274365                    18295.229063 -91264.960   599%  66.5 1610s\n",
      "H2374683 1274094                    18158.383452 -91264.960   603%  66.5 1612s\n",
      " 2378053 1276090 -14400.747   51  113 18158.3835 -91232.925   602%  66.6 1615s\n",
      " 2385436 1279066 -56750.371   44  120 18158.3835 -91149.738   602%  66.6 1620s\n",
      " 2392495 1282642 10322.4775   52  105 18158.3835 -91069.916   602%  66.6 1625s\n",
      " 2400008 1285820 -28459.016   52  110 18158.3835 -90988.589   601%  66.6 1630s\n",
      " 2406189 1288585 -38218.958   50  118 18158.3835 -90923.652   601%  66.7 1635s\n",
      " 2414001 1291939 -52542.029   47  128 18158.3835 -90836.486   600%  66.7 1640s\n",
      " 2420646 1294666 -64809.572   49  117 18158.3835 -90760.009   600%  66.7 1645s\n",
      " 2427684 1298036 -48541.704   54  110 18158.3835 -90681.952   599%  66.7 1650s\n",
      " 2434963 1301286 -62321.332   48  121 18158.3835 -90598.090   599%  66.8 1655s\n",
      " 2441904 1303949 -58413.534   48  121 18158.3835 -90518.113   598%  66.8 1660s\n",
      " 2448514 1305987     cutoff   60      18158.3835 -90450.003   598%  66.8 1667s\n",
      " 2448553 1306013     cutoff   60      18158.3835 -90449.366   598%  66.8 1670s\n",
      " 2448621 1306058 -30930.458   53  103 18158.3835 -90446.619   598%  66.8 1676s\n",
      "H2448650 1306096                    18154.161749 -90446.619   598%  66.8 1677s\n",
      "H2448651 1305996                    18132.258158 -90445.780   599%  66.8 1677s\n",
      "H2448731 1306081                    18098.494924 -90445.780   600%  66.8 1678s\n",
      "H2449087 1305955                    18069.638236 -90442.894   601%  66.8 1678s\n",
      " 2451016 1307366 9471.57267   57  106 18069.6382 -90423.319   600%  66.8 1681s\n",
      " 2457015 1309923 -33425.992   52  116 18069.6382 -90376.873   600%  66.8 1685s\n",
      " 2464177 1312608 -41314.614   46  128 18069.6382 -90299.189   600%  66.9 1690s\n",
      " 2470643 1315516 -46827.275   51  113 18069.6382 -90246.232   599%  66.9 1695s\n",
      " 2477435 1318724 -8423.3291   55  104 18069.6382 -90162.407   599%  66.9 1700s\n",
      " 2482973 1320874 -16544.061   74   85 18069.6382 -90097.815   599%  66.9 1705s\n",
      " 2490657 1323949 -23270.509   54  110 18069.6382 -90029.153   598%  67.0 1710s\n",
      " 2497979 1326922 -7170.1075   49  121 18069.6382 -89955.199   598%  67.0 1715s\n",
      " 2504923 1329959 -11380.264   48  117 18069.6382 -89878.583   597%  67.0 1720s\n",
      " 2511333 1332975 -42643.474   54  105 18069.6382 -89798.141   597%  67.0 1725s\n",
      " 2519130 1336229 -12464.366   50  113 18069.6382 -89712.206   596%  67.1 1730s\n",
      " 2526207 1339053 -5523.4295   62   95 18069.6382 -89642.096   596%  67.1 1735s\n",
      " 2533577 1342361 1539.29611   58   98 18069.6382 -89560.534   596%  67.1 1740s\n",
      " 2539401 1344598     cutoff   60      18069.6382 -89502.907   595%  67.1 1745s\n",
      " 2545949 1347489 -52156.157   47  123 18069.6382 -89440.789   595%  67.2 1750s\n",
      " 2553033 1350637 -1464.4950   62   95 18069.6382 -89366.981   595%  67.2 1755s\n",
      " 2560000 1353718 -9751.0414   54  104 18069.6382 -89285.816   594%  67.2 1760s\n",
      " 2566887 1356640     cutoff   53      18069.6382 -89222.886   594%  67.2 1765s\n",
      " 2573026 1358849  715.56332   55  113 18069.6382 -89156.950   593%  67.2 1770s\n",
      " 2579552 1361419 -52827.642   47  125 18069.6382 -89092.411   593%  67.3 1775s\n",
      " 2586799 1364349 15874.3606   53  105 18069.6382 -89022.453   593%  67.3 1780s\n",
      " 2593705 1367257 -7989.3706   61   94 18069.6382 -88956.602   592%  67.3 1785s\n",
      " 2599416 1369598 2654.24963   52  115 18069.6382 -88904.188   592%  67.3 1790s\n",
      " 2606353 1372883 -38968.452   47  124 18069.6382 -88835.414   592%  67.4 1796s\n",
      " 2612724 1374749     cutoff   62      18069.6382 -88790.054   591%  67.4 1800s\n",
      "\n",
      "Cutting planes:\n",
      "  Gomory: 65\n",
      "  Implied bound: 2\n",
      "  MIR: 2518\n",
      "  Mixing: 4\n",
      "  Flow cover: 724\n",
      "  RLT: 1479\n",
      "  Relax-and-lift: 44\n",
      "\n",
      "Explored 2613122 nodes (176076820 simplex iterations) in 1800.02 seconds (4896.38 work units)\n",
      "Thread count was 14 (of 14 available processors)\n",
      "\n",
      "Solution count 10: 18069.6 18098.5 18132.3 ... 20766.6\n",
      "\n",
      "Time limit reached\n",
      "Best objective 1.806963823565e+04, best bound -8.878685237619e+04, gap 591.3593%\n",
      "[CHECK MLP] obj(x_ip)=18069.6  ip_y=18069.6  rel_err=1.597e-06\n",
      "Set parameter OutputFlag to value 1\n",
      "Set parameter TimeLimit to value 1800\n",
      "Gurobi Optimizer version 12.0.3 build v12.0.3rc0 (mac64[arm] - Darwin 25.2.0 25C56)\n",
      "\n",
      "CPU model: Apple M4 Max\n",
      "Thread count: 14 physical cores, 14 logical processors, using up to 14 threads\n",
      "\n",
      "Non-default parameters:\n",
      "TimeLimit  1800\n",
      "\n",
      "Optimize a model with 1001 rows, 21 columns and 21020 nonzeros\n",
      "Model fingerprint: 0x5e766f1b\n",
      "Variable types: 1 continuous, 20 integer (0 binary)\n",
      "Coefficient statistics:\n",
      "  Matrix range     [7e-07, 1e+00]\n",
      "  Objective range  [7e+04, 7e+04]\n",
      "  Bounds range     [5e+01, 5e+01]\n",
      "  RHS range        [1e+00, 3e+01]\n",
      "Found heuristic solution: objective 223920.23254\n",
      "Presolve time: 0.00s\n",
      "Presolved: 1001 rows, 21 columns, 21020 nonzeros\n",
      "Variable types: 1 continuous, 20 integer (0 binary)\n",
      "\n",
      "Root relaxation: objective -2.612233e+04, 14 iterations, 0.00 seconds (0.01 work units)\n",
      "\n",
      "    Nodes    |    Current Node    |     Objective Bounds      |     Work\n",
      " Expl Unexpl |  Obj  Depth IntInf | Incumbent    BestBd   Gap | It/Node Time\n",
      "\n",
      "     0     0 -26122.331    0    7 223920.233 -26122.331   112%     -    0s\n",
      "H    0     0                    -26005.88445 -26122.331  0.45%     -    0s\n",
      "H    0     0                    -26005.88445 -26122.331  0.45%     -    0s\n",
      "     0     0 -26092.520    0    7 -26005.884 -26092.520  0.33%     -    0s\n",
      "     0     2 -26092.520    0    7 -26005.884 -26092.520  0.33%     -    0s\n",
      "\n",
      "Explored 152 nodes (215 simplex iterations) in 0.11 seconds (0.14 work units)\n",
      "Thread count was 14 (of 14 available processors)\n",
      "\n",
      "Solution count 2: -26005.9 223920 \n",
      "No other solutions better than -26005.9\n",
      "\n",
      "Optimal solution found (tolerance 1.00e-04)\n",
      "Best objective -2.600588445079e+04, best bound -2.600588445079e+04, gap 0.0000%\n",
      "[CHECK MaxAffine] obj(x_ip)=-26005.9  ip_y=-26005.9  rel_err=8.382e-07\n",
      "Set parameter OutputFlag to value 1\n",
      "Set parameter TimeLimit to value 1800\n",
      "Set parameter FuncNonlinear to value 1\n",
      "Set parameter FeasibilityTol to value 1e-09\n",
      "Set parameter OptimalityTol to value 1e-09\n",
      "Set parameter IntFeasTol to value 1e-09\n",
      "Set parameter NumericFocus to value 3\n",
      "Gurobi Optimizer version 12.0.3 build v12.0.3rc0 (mac64[arm] - Darwin 25.2.0 25C56)\n",
      "\n",
      "CPU model: Apple M4 Max\n",
      "Thread count: 14 physical cores, 14 logical processors, using up to 14 threads\n",
      "\n",
      "Non-default parameters:\n",
      "TimeLimit  1800\n",
      "FeasibilityTol  1e-09\n",
      "IntFeasTol  1e-09\n",
      "OptimalityTol  1e-09\n",
      "NumericFocus  3\n",
      "\n",
      "Optimize a model with 3004 rows, 3025 columns and 27026 nonzeros\n",
      "Model fingerprint: 0xc2ddd565\n",
      "Model has 1001 function constraints treated as nonlinear\n",
      "  1000 EXP, 1 LOG\n",
      "Variable types: 3005 continuous, 20 integer (0 binary)\n",
      "Coefficient statistics:\n",
      "  Matrix range     [1e-05, 7e+04]\n",
      "  Objective range  [1e+00, 1e+00]\n",
      "  Bounds range     [1e-12, 9e+05]\n",
      "  RHS range        [3e+01, 3e+05]\n",
      "Warning: x < 1e-10 in domain of log(x).\n",
      "         Setting lower bound to 1e-10.\n",
      "Presolve removed 2002 rows and 1002 columns\n",
      "Presolve time: 0.01s\n",
      "Presolved: 6007 rows, 2024 columns, 33031 nonzeros\n",
      "Presolved model has 1001 nonlinear constraint(s)\n",
      "\n",
      "Solving non-convex MINLP\n",
      "\n",
      "Variable types: 2004 continuous, 20 integer (0 binary)\n",
      "\n",
      "Root relaxation: objective -1.204721e+05, 1094 iterations, 0.02 seconds (0.08 work units)\n",
      "\n",
      "    Nodes    |    Current Node    |     Objective Bounds      |     Work\n",
      " Expl Unexpl |  Obj  Depth IntInf | Incumbent    BestBd   Gap | It/Node Time\n",
      "\n",
      "     0     0 -120472.10    0   18          - -120472.10      -     -    0s\n",
      "     0     2 -120453.80    0   18          - -120453.80      -     -    0s\n",
      "  2917  2517 -33424.538   36    8          - -120381.54      -   8.2    5s\n",
      "  9184  4216 -31310.518   36   19          - -78283.507      -   5.9   10s\n",
      " 22219  8122 -61517.712   27   17          - -74400.583      -   3.1   15s\n",
      " 34186 12584 -35109.754   70   17          - -70928.829      -   2.6   20s\n",
      " 46584 17978 -40686.452   59   19          - -69944.560      -   2.3   25s\n",
      " 56286 21364 35191.6038  128   12          - -69638.921      -   2.2   30s\n",
      " 66689 25642 infeasible   92               - -69430.760      -   2.1   35s\n",
      " 78340 29642 infeasible   88               - -69346.318      -   1.9   40s\n",
      " 89169 33341 -64431.562   58   13          - -69305.664      -   1.8   45s\n",
      " 107816 39143 -20040.118  100   14          - -69240.858      -   1.6   50s\n",
      " 127478 45412 -26648.203  161   11          - -69219.481      -   1.5   55s\n",
      " 142783 50477 infeasible  108               - -69161.761      -   1.4   60s\n",
      " 164382 56820 -30196.844   87   12          - -69120.589      -   1.3   65s\n",
      " 182579 62357 -27427.658  157   10          - -69095.251      -   1.2   70s\n",
      " 200691 67703 -26468.257   86   13          - -69061.547      -   1.2   75s\n",
      " 219994 73837 -16599.517  153   15          - -69033.441      -   1.1   80s\n",
      " 234482 78741 -46149.670   59   14          - -68986.191      -   1.1   85s\n",
      " 249253 84163 infeasible   57               - -68945.665      -   1.1   90s\n",
      " 266731 91192 -16910.935  201   16          - -68904.740      -   1.1   95s\n",
      " 283060 97781 -16998.198  166   13          - -68854.687      -   1.0  100s\n",
      " 299385 104203 -32424.115   76   13          - -68817.245      -   1.0  105s\n",
      " 315289 110484 infeasible  113               - -68762.390      -   1.0  110s\n",
      " 330694 116583 -26188.629  137   12          - -68691.922      -   1.0  115s\n",
      " 345518 122560 -65218.953   54   13          - -68643.160      -   1.0  120s\n",
      " 359583 128077 -28736.951  100   15          - -68548.132      -   1.0  125s\n",
      " 374483 134336 -25491.209   78   12          - -68470.620      -   1.0  130s\n",
      " 390858 140635 -38402.812   73   12          - -68278.797      -   1.0  135s\n",
      " 403278 146214 infeasible  150               - -68087.080      -   0.9  140s\n",
      " 416918 152780 -21622.543  186   14          - -67696.444      -   1.0  145s\n",
      " 428814 158642 -952.94471   76   14          - -67411.400      -   1.0  150s\n",
      " 441214 164524 infeasible  126               - -67228.674      -   1.0  155s\n",
      " 453675 169986 1336.11913   62   22          - -67132.113      -   1.0  160s\n",
      " 467461 177613 infeasible  127               - -66987.140      -   1.0  165s\n",
      " 482118 183830 -26684.075   95   15          - -66888.185      -   1.0  170s\n",
      " 498421 190801 infeasible   92               - -66803.696      -   1.0  175s\n",
      " 513943 199812 infeasible  100               - -66774.829      -   1.0  180s\n",
      " 534864 213500 1264.85985   73   16          - -66633.980      -   1.0  185s\n",
      " 551662 224909 -21752.894  145   10          - -66556.250      -   1.0  190s\n",
      " 570421 238330 -32217.543   60   17          - -66516.239      -   0.9  195s\n",
      " 591145 254423 infeasible   69               - -66441.529      -   0.9  200s\n",
      " 614881 272026 -27259.697  174   12          - -66417.856      -   0.9  205s\n",
      " 637667 288148 infeasible   76               - -66359.627      -   0.9  210s\n",
      " 655098 301010 -20368.219  178   14          - -66326.262      -   0.9  215s\n",
      " 674290 315644 infeasible   68               - -66277.791      -   0.9  220s\n",
      " 696705 334382 -7499.6140  914    8          - -66224.667      -   0.9  225s\n",
      " 717373 351297 -7499.6140 1872    8          - -66176.207      -   0.8  230s\n",
      " 744310 374190 -7499.6140 3068    8          - -66125.365      -   0.8  235s\n",
      " 769370 396397 -7499.6140 4123    8          - -66108.668      -   0.8  240s\n",
      " 796585 419741 -7499.6140 5253    8          - -66085.080      -   0.8  245s\n",
      " 824501 444927 -7499.6140 6415    8          - -66047.274      -   0.8  250s\n",
      " 850240 468152 -7499.6140 7494    8          - -66025.098      -   0.8  255s\n",
      " 878642 493043 -7499.6140 8641    8          - -65994.813      -   0.7  260s\n",
      " 907192 521253 -7499.6140 9775    8          - -65987.590      -   0.7  265s\n",
      " 941546 554389 -7499.6140 11103    8          - -65964.980      -   0.7  270s\n",
      " 980340 591268 -7499.6140 12558    8          - -65962.835      -   0.7  275s\n",
      " 1008496 618772 -7499.6140 13627    8          - -65957.949      -   0.7  280s\n",
      " 1037064 645326 -7499.6140 14740    8          - -65943.758      -   0.6  285s\n",
      " 1063716 669631 -7499.6140 15813    8          - -65930.701      -   0.6  290s\n",
      " 1088835 690864 -7499.6140 16888    8          - -65903.633      -   0.6  295s\n",
      " 1111446 709444 -7499.6140 17894    8          - -65871.249      -   0.6  300s\n",
      " 1133265 727397 -7499.6140 18886    8          - -65840.987      -   0.6  305s\n",
      " 1156402 746628 -7499.6140 19931    8          - -65821.268      -   0.6  310s\n",
      " 1172758 757632 infeasible  101               - -65818.138      -   0.6  315s\n",
      " 1191455 772199 1133.55406   69   27          - -65756.327      -   0.6  320s\n",
      " 1212465 787645 -18779.556  101   16          - -65730.209      -   0.6  325s\n",
      " 1232512 802931 -27290.887   57   19          - -65669.228      -   0.6  330s\n",
      " 1256966 823015 -22439.544   57   16          - -65630.939      -   0.6  335s\n",
      " 1281033 843039 -5555.6534  120   20          - -65586.474      -   0.6  340s\n",
      " 1305660 863019 infeasible   57               - -65572.457      -   0.6  345s\n",
      " 1330317 882830 infeasible  146               - -65516.059      -   0.6  350s\n",
      " 1355091 903609 16170.4073   93   15          - -65505.103      -   0.6  355s\n",
      " 1381833 926473 -28553.595  208   11          - -65464.992      -   0.6  360s\n",
      " 1406078 945573 -22527.047   44   16          - -65420.250      -   0.6  365s\n",
      " 1435003 969306 3883.69351   42   10          - -65404.817      -   0.6  370s\n",
      " 1459482 987833 infeasible   64               - -65350.250      -   0.6  375s\n",
      " 1480487 1004285 20432.0831   69   16          - -65324.843      -   0.6  380s\n",
      " 1502947 1020756 -22457.294  152   17          - -65276.243      -   0.6  385s\n",
      " 1525732 1036629 -26090.293   82   13          - -65202.570      -   0.6  390s\n",
      " 1547406 1052788 -30012.163   57   12          - -65172.977      -   0.6  395s\n",
      " 1569786 1069526 -23809.225  190   12          - -65147.102      -   0.6  400s\n",
      " 1591702 1085897 3059.11012   81   17          - -65116.028      -   0.6  405s\n",
      " 1616230 1105111 -24355.679  130   16          - -65074.165      -   0.6  410s\n",
      " 1640830 1124671 -20376.258  234   13          - -65053.665      -   0.6  415s\n",
      " 1668997 1147505 -12273.918   49   18          - -64995.357      -   0.6  420s\n",
      " 1691783 1165551 -19038.876  152   10          - -64975.924      -   0.6  425s\n",
      " 1716572 1184633 -23168.450  105   11          - -64907.043      -   0.6  430s\n",
      " 1739378 1200207 8288.95608   93    8          - -64881.722      -   0.6  435s\n",
      " 1760643 1214180 -20771.570  167   15          - -64857.993      -   0.6  440s\n",
      " 1778855 1225877 -28693.743   78   15          - -64791.323      -   0.6  445s\n",
      " 1800862 1240196 -19241.919  166   11          - -64740.887      -   0.6  450s\n",
      " 1823438 1255164 -29648.371   51   20          - -64691.238      -   0.6  455s\n",
      " 1841511 1265885 -33323.630   84   14          - -64672.799      -   0.6  460s\n",
      " 1859169 1276255 -20530.163  215   12          - -64631.761      -   0.6  465s\n",
      " 1877620 1286329 infeasible  148               - -64566.622      -   0.6  470s\n",
      " 1895623 1296289 -23223.777  106   14          - -64549.399      -   0.6  475s\n",
      " 1913446 1306406 -20320.080  175    9          - -64501.550      -   0.6  480s\n",
      " 1929505 1314594 infeasible   44               - -64468.176      -   0.6  485s\n",
      " 1946746 1324777 -24138.414  144   10          - -64409.280      -   0.6  490s\n",
      " 1967923 1337040 -18372.498  191    5          - -64380.031      -   0.6  495s\n",
      " 1983839 1347369 infeasible  184               - -64352.012      -   0.6  500s\n",
      " 2002675 1357626 infeasible  174               - -64322.396      -   0.6  505s\n",
      " 2021200 1367592 infeasible   95               - -64271.598      -   0.6  510s\n",
      " 2037496 1376650 infeasible  134               - -64194.911      -   0.6  515s\n",
      " 2053670 1384771 infeasible   99               - -64134.758      -   0.6  520s\n",
      " 2069786 1392923 infeasible  185               - -64048.238      -   0.6  525s\n",
      " 2085172 1400864 infeasible   78               - -64012.182      -   0.6  530s\n",
      " 2101992 1409532 2751.72631   59   26          - -63981.753      -   0.6  535s\n",
      " 2120031 1418878 -20630.807   90   16          - -63925.593      -   0.6  540s\n",
      " 2139419 1429842 -20001.857  194   13          - -63898.846      -   0.6  545s\n",
      " 2160190 1440895 -17774.997  204   14          - -63862.743      -   0.6  550s\n",
      " 2179529 1451242 infeasible  178               - -63833.980      -   0.6  555s\n",
      " 2198798 1460729 -25091.144   60   14          - -63776.770      -   0.6  560s\n",
      " 2214840 1468498 -33928.732   73   11          - -63750.933      -   0.6  565s\n",
      " 2227946 1475496 -27124.592  156    9          - -63712.807      -   0.6  570s\n",
      " 2244467 1484007 -22122.990  126   12          - -63690.722      -   0.6  575s\n",
      " 2260108 1492161 -5036.6328  113   11          - -63656.931      -   0.6  580s\n",
      " 2276518 1500074 infeasible   80               - -63618.415      -   0.6  585s\n",
      " 2291101 1507333 -13996.297   47   19          - -63566.214      -   0.6  590s\n",
      " 2305772 1514801 -15452.349  157   13          - -63538.230      -   0.6  595s\n",
      " 2319955 1521669 -26405.051   98   19          - -63511.863      -   0.6  600s\n",
      " 2335497 1528307 infeasible   71               - -63475.076      -   0.6  605s\n",
      " 2350494 1534253 -17178.341  158   15          - -63430.640      -   0.6  610s\n",
      " 2363786 1540787 -2403.7968   53   18          - -63418.450      -   0.6  615s\n",
      " 2374923 1545734 infeasible  109               - -63359.907      -   0.6  620s\n",
      " 2387351 1552042 infeasible   41               - -63330.838      -   0.6  625s\n",
      " 2400907 1558599 -18149.743   70   17          - -63296.669      -   0.6  630s\n",
      " 2414725 1566031 -19530.541  140   14          - -63255.541      -   0.7  635s\n",
      " 2429919 1573263 -21699.206  173   11          - -63228.345      -   0.7  640s\n",
      " 2448562 1582150 -27802.324   75   15          - -63219.807      -   0.7  645s\n",
      " 2466422 1590429 11151.4779   52   15          - -63162.789      -   0.7  650s\n",
      " 2483021 1598450 -32880.465   54   17          - -63130.586      -   0.7  655s\n",
      " 2497102 1605300 infeasible  118               - -63112.239      -   0.7  660s\n",
      " 2511731 1612946 5471.74668  101   14          - -63086.866      -   0.7  665s\n",
      " 2527649 1620297 -21740.507  106   11          - -63048.939      -   0.7  670s\n",
      " 2543183 1626278 infeasible  122               - -63014.247      -   0.7  675s\n",
      " 2557617 1632686 infeasible  128               - -62993.033      -   0.7  680s\n",
      " 2572495 1639065 17971.0678  105   14          - -62959.090      -   0.7  685s\n",
      " 2589783 1647617 -34551.100   72   12          - -62939.000      -   0.7  690s\n",
      " 2605778 1654943 infeasible  122               - -62913.282      -   0.7  695s\n",
      " 2621382 1662833 infeasible  136               - -62881.022      -   0.7  700s\n",
      " 2640074 1673534 infeasible  259               - -62854.769      -   0.7  705s\n",
      " 2657376 1683805 infeasible   98               - -62836.796      -   0.7  710s\n",
      " 2677713 1696075 -2961.3995   55   21          - -62817.702      -   0.7  715s\n",
      " 2700348 1710195 -28539.205   90   18          - -62796.636      -   0.7  720s\n",
      " 2721119 1723299 -32624.052   87   13          - -62780.309      -   0.7  725s\n",
      " 2737531 1734430 infeasible  113               - -62770.051      -   0.7  730s\n",
      " 2754430 1746563 -21266.172  884    8          - -62749.709      -   0.7  735s\n",
      " 2775145 1761686 -21266.172 1864    8          - -62735.495      -   0.7  740s\n",
      " 2793348 1775089 -21266.172 2660    8          - -62701.214      -   0.7  745s\n",
      " 2815543 1791390 -21266.172 3621    8          - -62690.084      -   0.7  750s\n",
      " 2835386 1805786 -21266.172 4524    8          - -62671.397      -   0.7  755s\n",
      " 2857476 1821492 -21266.172 5508    8          - -62660.917      -   0.6  760s\n",
      " 2872832 1832665 -21266.172 6201    8          - -62648.264      -   0.6  765s\n",
      " 2889736 1844882 -21266.172 6974    8          - -62632.347      -   0.6  770s\n",
      " 2904245 1855510 -21266.172 7648    8          - -62607.831      -   0.6  775s\n",
      " 2928550 1873506 -21266.172 8770    8          - -62596.420      -   0.6  780s\n",
      " 2950889 1891247 -21266.172 9755    8          - -62582.242      -   0.6  785s\n",
      " 2975459 1909262 -21266.172 10799    8          - -62577.135      -   0.6  790s\n",
      " 2999129 1928395 -21266.172 11753    8          - -62551.074      -   0.6  795s\n",
      " 3022185 1946230 -21266.172 12724    8          - -62529.736      -   0.6  800s\n",
      " 3046281 1964711 -21266.172 13791    8          - -62517.596      -   0.6  805s\n",
      " 3072459 1984593 -21266.172 14876    8          - -62503.431      -   0.6  810s\n",
      " 3098389 2003067 -21266.172 16038    8          - -62485.259      -   0.6  815s\n",
      " 3121751 2018975 -21266.172 17128    8          - -62458.427      -   0.6  820s\n",
      " 3143769 2033705 -21266.172 18191    8          - -62455.991      -   0.6  825s\n",
      " 3166045 2048204 -21266.172 19292    8          - -62427.128      -   0.6  830s\n",
      " 3189268 2063884 -21266.172 20294    8          - -62406.597      -   0.6  835s\n",
      " 3213445 2079803 -15977.492  120   14          - -62388.981      -   0.6  840s\n",
      " 3232111 2090574 -30158.926   95   15          - -62377.935      -   0.6  845s\n",
      " 3249849 2100894 infeasible  191               - -62346.918      -   0.6  850s\n",
      " 3270050 2112980 -30788.828  109   16          - -62317.301      -   0.6  855s\n",
      " 3288127 2124141 -32812.063   88   13          - -62296.554      -   0.6  860s\n",
      " 3307632 2135764 -3030.1098   63   32          - -62260.067      -   0.6  865s\n",
      " 3326786 2147680 -11702.154   67   20          - -62246.937      -   0.6  870s\n",
      " 3345166 2158672 infeasible  186               - -62225.939      -   0.6  875s\n",
      " 3362631 2169285 infeasible  118               - -62218.681      -   0.6  880s\n",
      " 3381238 2180235 -58328.099   46   17          - -62169.747      -   0.6  885s\n",
      " 3398330 2189818 3532.02098  137   10          - -62151.585      -   0.6  890s\n",
      " 3415635 2199247 -29398.991   65   17          - -62121.340      -   0.6  895s\n",
      " 3430213 2206179 -17822.923  134   12          - -62107.611      -   0.6  900s\n",
      " 3447390 2214389 -20324.035   87   15          - -62086.320      -   0.6  905s\n",
      " 3465502 2224009 infeasible  183               - -62068.761      -   0.6  910s\n",
      " 3486370 2235301 infeasible  155               - -62056.762      -   0.6  915s\n",
      " 3504161 2244851 infeasible   76               - -62028.105      -   0.6  920s\n",
      " 3522592 2254720 infeasible  102               - -62019.467      -   0.6  925s\n",
      " 3539821 2265495 -22239.420  423    9          - -61991.511      -   0.6  930s\n",
      " 3559957 2279430 -22239.420 1413    9          - -61979.612      -   0.6  935s\n",
      " 3578977 2292627 -22239.420 2354    9          - -61971.397      -   0.6  940s\n",
      " 3600537 2306446 -22239.420 3340    9          - -61958.855      -   0.6  945s\n",
      " 3619189 2318926 -22239.420 4191    9          - -61946.608      -   0.6  950s\n",
      " 3639949 2332777 -22239.420 5140    9          - -61931.300      -   0.6  955s\n",
      " 3657363 2344274 -22239.420 5958    9          - -61919.258      -   0.6  960s\n",
      " 3676579 2357462 -22239.420 6849    9          - -61899.868      -   0.6  965s\n",
      " 3693795 2368134 -22239.420 7674    9          - -61882.316      -   0.6  970s\n",
      " 3712670 2379527 -22239.420 8583    9          - -61870.489      -   0.6  975s\n",
      " 3729799 2390331 -22239.420 9448    9          - -61852.231      -   0.6  980s\n",
      " 3747631 2401296 -22239.420 10340    9          - -61835.472      -   0.6  985s\n",
      " 3764192 2411735 -22239.420 11157    9          - -61807.429      -   0.6  990s\n",
      " 3775455 2418266 -22239.420 11685    9          - -61801.580      -   0.6  995s\n",
      " 3788474 2425601 -22239.420 12333    9          - -61790.507      -   0.6 1000s\n",
      " 3802162 2434922 -22239.420 13037    9          - -61774.104      -   0.6 1005s\n",
      " 3813597 2441789 -22239.420 13588    9          - -61765.465      -   0.6 1010s\n",
      " 3825035 2448557 -22239.420 14112    9          - -61764.260      -   0.6 1015s\n",
      " 3839606 2458445 -22239.420 14832    9          - -61743.439      -   0.6 1020s\n",
      " 3850873 2466009 -22239.420 15367    9          - -61739.903      -   0.6 1025s\n",
      " 3863057 2473534 -22239.420 15976    9          - -61730.885      -   0.6 1030s\n",
      " 3874391 2480717 -22239.420 16562    9          - -61728.189      -   0.6 1035s\n",
      " 3885110 2487166 -22239.420 17066    9          - -61716.880      -   0.6 1040s\n",
      " 3894133 2492716 -22239.420 17528    9          - -61712.588      -   0.6 1045s\n",
      " 3903731 2498313 -22239.420 17989    9          - -61706.950      -   0.6 1050s\n",
      " 3914331 2504738 -22239.420 18549    9          - -61706.950      -   0.6 1055s\n",
      " 3925871 2511790 -22239.420 19146    9          - -61697.284      -   0.6 1060s\n",
      " 3938923 2519401 -22239.420 19778    9          - -61697.284      -   0.6 1065s\n",
      " 3949691 2525850 -22239.420 20285    9          - -61665.628      -   0.6 1070s\n",
      " 3960987 2532685 -22239.420 20281    9          - -61665.628      -   0.6 1075s\n",
      " 3973389 2539888 -18614.414  176   14          - -61651.075      -   0.6 1080s\n",
      " 3982957 2544638 -17865.401  175   11          - -61644.814      -   0.6 1085s\n",
      " 3999248 2553815 9176.93132  210   12          - -61625.493      -   0.6 1090s\n",
      " 4016298 2563555 -19684.983   99   16          - -61615.128      -   0.6 1095s\n",
      " 4026668 2569748 -14249.178   75   12          - -61602.619      -   0.6 1100s\n",
      " 4042677 2580285 -32535.456   68   14          - -61590.343      -   0.6 1105s\n",
      " 4058211 2590279 infeasible  164               - -61579.866      -   0.6 1110s\n",
      " 4074379 2600627 -12685.987  120    9          - -61559.314      -   0.6 1115s\n",
      " 4088889 2610241 -18757.036  218   15          - -61550.114      -   0.6 1120s\n",
      " 4100782 2618933 infeasible   94               - -61543.726      -   0.6 1125s\n",
      " 4115113 2629301 10670.2635  108   11          - -61538.850      -   0.6 1130s\n",
      " 4131331 2640535 -13213.102  160   12          - -61529.749      -   0.6 1135s\n",
      " 4143356 2649164 -19355.015  123   14          - -61519.291      -   0.6 1140s\n",
      " 4158904 2660601 -24094.318   98   16          - -61511.290      -   0.6 1145s\n",
      " 4173344 2669971 -18573.407   68   18          - -61494.482      -   0.6 1150s\n",
      " 4189419 2681970 8204.42186   49   14          - -61492.600      -   0.6 1155s\n",
      " 4206537 2693985 -9190.9823  163   10          - -61483.741      -   0.6 1160s\n",
      " 4223270 2705647 -29262.758   66   15          - -61474.685      -   0.6 1165s\n",
      " 4234086 2713607 -22863.610   64   18          - -61474.685      -   0.6 1170s\n",
      " 4246401 2722390 3901.43439   64   11          - -61464.951      -   0.6 1175s\n",
      " 4257241 2729957 5828.26584  104   10          - -61457.662      -   0.6 1180s\n",
      " 4267610 2736840 -27420.156  146   11          - -61452.305      -   0.6 1185s\n",
      " 4278970 2744534 9094.65716   75  100          - -61440.770      -   0.6 1190s\n",
      " 4297644 2757135 34016.9230   90   17          - -61429.234      -   0.6 1195s\n",
      " 4314876 2769315 -32227.129   98    9          - -61425.355      -   0.6 1200s\n",
      " 4332227 2782471 -35277.872   78   16          - -61413.584      -   0.6 1205s\n",
      " 4344109 2791215 -22115.678  156   10          - -61411.158      -   0.6 1210s\n",
      " 4359339 2802723 -33216.682   66   14          - -61407.486      -   0.6 1215s\n",
      " 4371601 2812000 -24119.845   67   17          - -61400.253      -   0.6 1220s\n",
      " 4383107 2819913 -33457.674   69   14          - -61397.084      -   0.6 1225s\n",
      " 4396844 2830862 -23691.396  138    9          - -61395.433      -   0.6 1230s\n",
      " 4413459 2843448 infeasible  160               - -61389.248      -   0.6 1235s\n",
      " 4427365 2854312 -24657.959  178   11          - -61383.349      -   0.6 1240s\n",
      " 4444771 2867624 -10436.760  187   13          - -61374.994      -   0.6 1245s\n",
      " 4457509 2876965 -60017.844   63   18          - -61370.242      -   0.6 1250s\n",
      " 4469871 2886137 12234.7678  142   13          - -61364.706      -   0.6 1255s\n",
      " 4483877 2896059 -27668.448   98   11          - -61361.528      -   0.6 1260s\n",
      " 4492723 2901852 infeasible  152               - -61356.855      -   0.6 1265s\n",
      " 4504410 2910081 infeasible  144               - -61346.393      -   0.6 1270s\n",
      " 4520875 2920209 infeasible   67               - -61343.019      -   0.6 1275s\n",
      " 4537989 2930456 -27301.325  142   13          - -61329.896      -   0.6 1280s\n",
      " 4556056 2940253 -16346.702  369   10          - -61320.184      -   0.6 1285s\n",
      " 4573493 2949429 -16346.702 1236   10          - -61310.739      -   0.6 1290s\n",
      " 4589011 2957999 -16346.702 1955   10          - -61297.634      -   0.6 1295s\n",
      " 4610678 2969183 -16346.702 2949   10          - -61292.644      -   0.6 1300s\n",
      " 4628382 2978845 -16346.702 3797   10          - -61276.155      -   0.6 1305s\n",
      " 4646104 2988391 -16346.702 4684   10          - -61261.892      -   0.6 1310s\n",
      " 4663412 2998216 -16346.702 5556   10          - -61247.782      -   0.6 1315s\n",
      " 4681081 3008289 -16346.702 6433   10          - -61240.339      -   0.6 1320s\n",
      " 4699084 3018146 -16346.702 7259   10          - -61227.198      -   0.6 1325s\n",
      " 4719916 3030567 -16346.702 8152   10          - -61209.081      -   0.6 1330s\n",
      " 4737067 3040283 -16346.702 8967   10          - -61199.142      -   0.6 1335s\n",
      " 4755702 3051000 -16346.702 9895   10          - -61175.569      -   0.6 1340s\n",
      " 4772091 3060121 -16346.702 10714   10          - -61164.689      -   0.6 1345s\n",
      " 4787980 3068888 -16346.702 11536   10          - -61159.908      -   0.6 1350s\n",
      " 4804218 3077159 -16346.702 12438   10          - -61133.565      -   0.6 1355s\n",
      " 4819466 3084565 -16346.702 13262   10          - -61122.678      -   0.6 1360s\n",
      " 4828879 3089255 -16346.702 13789   10          - -61117.978      -   0.6 1365s\n",
      " 4840130 3094565 -16346.702 14402   10          - -61111.409      -   0.6 1370s\n",
      " 4853662 3101018 -16346.702 15093   10          - -61102.796      -   0.6 1375s\n",
      " 4863882 3107065 -16346.702 15633   10          - -61097.831      -   0.6 1380s\n",
      " 4873332 3111699 -16346.702 16099   10          - -61097.831      -   0.6 1385s\n",
      " 4882015 3117779 -16346.702 16496   10          - -61075.690      -   0.6 1390s\n",
      " 4893518 3124952 -16346.702 17070   10          - -61065.071      -   0.6 1395s\n",
      " 4905276 3132849 -16346.702 17620   10          - -61059.196      -   0.6 1400s\n",
      " 4918512 3141903 -16346.702 18256   10          - -61051.659      -   0.6 1405s\n",
      " 4930811 3150726 -16346.702 18829   10          - -61043.449      -   0.6 1410s\n",
      " 4940762 3157050 -16346.702 19266   10          - -61043.449      -   0.6 1415s\n",
      " 4952061 3165912 -16346.702 19751   10          - -61034.035      -   0.6 1420s\n",
      " 4962924 3173723 -16346.702 20268   10          - -61030.390      -   0.6 1425s\n",
      " 4976844 3183620 -16346.702 20284   10          - -61022.904      -   0.6 1430s\n",
      " 4990937 3193067 -21790.580  145    8          - -61017.560      -   0.6 1435s\n",
      " 5003310 3200980 -7912.7736  140   13          - -61014.646      -   0.6 1440s\n",
      " 5016621 3210602 infeasible  126               - -61001.353      -   0.6 1445s\n",
      " 5028818 3218141 infeasible  148               - -60989.804      -   0.6 1450s\n",
      "*5035093 2352559             181    -16154.46457 -60989.804   278%   0.6 1467s\n",
      " 5038428 2355234 -60899.423   67   17 -16154.465 -60989.804   278%   0.6 1470s\n",
      " 5043281 2357484 infeasible  119      -16154.465 -60976.839   277%   0.6 1475s\n",
      " 5052981 2364918 -26678.993   38   14 -16154.465 -60969.376   277%   0.6 1480s\n",
      " 5057961 2368067 -29040.744   84   11 -16154.465 -60956.265   277%   0.6 1485s\n",
      " 5062803 2371332 -57007.745   51   11 -16154.465 -60947.590   277%   0.6 1490s\n",
      "*5063152 22273              84    -39277.85390 -60947.590  55.2%   0.6 1493s\n",
      " 5066004 23177 -52299.402   78   13 -39277.854 -60918.154  55.1%   0.6 1495s\n",
      " 5078462 27930 -59374.261   61   13 -39277.854 -60765.993  54.7%   0.6 1500s\n",
      " 5089689 31639 infeasible   67      -39277.854 -60641.667  54.4%   0.6 1505s\n",
      " 5101153 35784 -59774.240   67   17 -39277.854 -60531.270  54.1%   0.6 1510s\n",
      " 5112242 39483 -59719.210   42   18 -39277.854 -60431.479  53.9%   0.6 1515s\n",
      " 5123776 43346 -59950.448   59   16 -39277.854 -60329.509  53.6%   0.6 1520s\n",
      " 5135142 46816 -47773.700   55   16 -39277.854 -60252.728  53.4%   0.6 1525s\n",
      " 5146808 50669 -60174.855   60   19 -39277.854 -60175.789  53.2%   0.6 1530s\n",
      " 5157371 54006     cutoff   49      -39277.854 -60091.641  53.0%   0.6 1535s\n",
      " 5169038 57600 -59963.524   71   12 -39277.854 -60018.612  52.8%   0.6 1540s\n",
      " 5181518 61461 -59641.047   52   13 -39277.854 -59933.553  52.6%   0.6 1545s\n",
      " 5192896 65188 -59869.318   60   16 -39277.854 -59869.827  52.4%   0.6 1550s\n",
      " 5204482 68228 -56742.796   40   15 -39277.854 -59799.989  52.2%   0.6 1555s\n",
      " 5216119 71578 -58414.450   73   16 -39277.854 -59733.448  52.1%   0.6 1560s\n",
      " 5227693 74869 -58491.517   56   18 -39277.854 -59664.203  51.9%   0.6 1565s\n",
      " 5239164 77840 -59552.238   49   17 -39277.854 -59592.493  51.7%   0.6 1570s\n",
      " 5250726 81012 -56391.626   61   18 -39277.854 -59524.004  51.5%   0.6 1575s\n",
      " 5262326 84409 -56221.152   81   11 -39277.854 -59468.531  51.4%   0.6 1580s\n",
      " 5273695 88033 -49848.883   74   16 -39277.854 -59417.578  51.3%   0.6 1585s\n",
      " 5285546 91315     cutoff   58      -39277.854 -59368.634  51.2%   0.6 1590s\n",
      " 5295950 93904 -59314.734   63   17 -39277.854 -59314.734  51.0%   0.6 1595s\n",
      " 5307663 96738 -58373.825   59   16 -39277.854 -59259.694  50.9%   0.6 1600s\n",
      " 5319216 99795 -59207.810   59   15 -39277.854 -59210.140  50.7%   0.6 1605s\n",
      " 5331002 102599 -57822.357   59   16 -39277.854 -59159.000  50.6%   0.7 1610s\n",
      " 5341399 105119 -59066.535   60   16 -39277.854 -59118.276  50.5%   0.7 1615s\n",
      " 5353120 107787 -59065.236   60   17 -39277.854 -59065.236  50.4%   0.7 1620s\n",
      " 5364101 109867 -59012.947   53   18 -39277.854 -59016.372  50.3%   0.7 1625s\n",
      " 5374859 112049     cutoff   51      -39277.854 -58965.739  50.1%   0.7 1630s\n",
      " 5385145 114076 -58912.813   49   16 -39277.854 -58924.872  50.0%   0.7 1635s\n",
      " 5396967 116234 -58461.584   57   16 -39277.854 -58873.947  49.9%   0.7 1640s\n",
      " 5407897 118076 -55740.341   84   14 -39277.854 -58828.493  49.8%   0.7 1645s\n",
      " 5418972 119793 -58774.694   56   19 -39277.854 -58780.928  49.7%   0.7 1650s\n",
      " 5429998 121674 -58615.199   46   17 -39277.854 -58734.550  49.5%   0.7 1655s\n",
      " 5441742 123873 infeasible   81      -39277.854 -58689.471  49.4%   0.7 1660s\n",
      " 5453686 125733     cutoff   60      -39277.854 -58640.469  49.3%   0.7 1665s\n",
      " 5465222 127592     cutoff   50      -39277.854 -58593.478  49.2%   0.7 1670s\n",
      " 5477042 129465     cutoff   59      -39277.854 -58547.295  49.1%   0.7 1675s\n",
      " 5488594 130816 -56929.444   42   16 -39277.854 -58500.317  48.9%   0.7 1680s\n",
      " 5499530 132490 -58060.148   49   18 -39277.854 -58455.679  48.8%   0.7 1685s\n",
      " 5511659 134513 infeasible   60      -39277.854 -58412.648  48.7%   0.7 1690s\n",
      " 5522519 135867 -58254.438   47   15 -39277.854 -58371.851  48.6%   0.7 1695s\n",
      " 5534415 137437     cutoff   53      -39277.854 -58330.525  48.5%   0.7 1700s\n",
      " 5545512 138904 -58156.752   58   13 -39277.854 -58290.274  48.4%   0.7 1705s\n",
      " 5557341 140443 -58207.734   44   15 -39277.854 -58250.056  48.3%   0.7 1710s\n",
      " 5567330 141577 -56615.432   67   18 -39277.854 -58213.589  48.2%   0.7 1715s\n",
      " 5576408 142431 -58160.346   49   18 -39277.854 -58180.379  48.1%   0.7 1720s\n",
      " 5585506 143883 infeasible   69      -39277.854 -58146.450  48.0%   0.7 1725s\n",
      " 5595610 145472 -56673.676   61   20 -39277.854 -58112.161  48.0%   0.7 1730s\n",
      " 5605621 146633 -57634.664   64   19 -39277.854 -58078.007  47.9%   0.7 1735s\n",
      " 5615853 147506 -57993.385   61   18 -39277.854 -58039.232  47.8%   0.7 1740s\n",
      " 5627616 149346     cutoff   44      -39277.854 -58005.066  47.7%   0.7 1745s\n",
      " 5638055 150753 -51313.111   69   16 -39277.854 -57969.521  47.6%   0.7 1750s\n",
      " 5649605 151921 -56147.182   52   19 -39277.854 -57932.433  47.5%   0.7 1755s\n",
      " 5659685 152933 infeasible   83      -39277.854 -57897.457  47.4%   0.7 1760s\n",
      " 5669915 154133     cutoff   56      -39277.854 -57865.782  47.3%   0.7 1765s\n",
      " 5681023 155648 -56318.288   50   15 -39277.854 -57831.332  47.2%   0.7 1770s\n",
      " 5691824 155896 infeasible   62      -39277.854 -57794.168  47.1%   0.7 1775s\n",
      " 5701522 156397     cutoff   67      -39277.854 -57757.439  47.0%   0.7 1780s\n",
      " 5711920 157068 -54292.114   71   16 -39277.854 -57724.669  47.0%   0.7 1785s\n",
      " 5720803 157724     cutoff   53      -39277.854 -57691.949  46.9%   0.7 1790s\n",
      " 5731335 157939 -55971.421   68   19 -39277.854 -57652.656  46.8%   0.7 1795s\n",
      " 5742221 158829 infeasible   62      -39277.854 -57614.421  46.7%   0.7 1800s\n",
      "\n",
      "Explored 5743098 nodes (4166576 simplex iterations) in 1800.01 seconds (2105.06 work units)\n",
      "Thread count was 14 (of 14 available processors)\n",
      "\n",
      "Solution count 2: -39277.9 -16154.5 \n",
      "\n",
      "Time limit reached\n",
      "Best objective -3.927785389648e+04, best bound -5.761237119382e+04, gap 46.6790%\n",
      "[CHECK LSET] obj(x_ip)=-26509.7  ip_y=-39277.9  rel_err=4.816e-01\n",
      "\n",
      "=== MODEL SPECS (from seed 0 run) ===\n",
      "    model  n_params                                                  details    lr  batch_size  epochs\n",
      "      DFN     17707  layers=[32, 64, 32] p_list=[1, 1] alpha=0.005 beta=-2.0 0.100           8     500\n",
      "DFN_AfixI     21524 layers=[10, 256, 11] p_list=[1, 1] alpha=0.005 beta=-2.0 0.100           8     500\n",
      "     LSET     21000                                     n_pieces=1000 T=0.05 0.001           8     500\n",
      "      MLP     19329                                        hidden=[128, 128] 0.001           8     500\n",
      "MaxAffine     21000                                            n_pieces=1000 0.001           8     500\n",
      "\n",
      "=== LEARNING SUMMARY (mean ± SE over seeds) ===\n",
      "    model     train_time      best_val          test\n",
      "      DFN  0.0868274 ± 0 0.0310469 ± 0 0.0333343 ± 0\n",
      "DFN_AfixI 0.00422013 ± 0   0.18933 ± 0  0.205894 ± 0\n",
      "     LSET 0.00572162 ± 0  0.182747 ± 0  0.169422 ± 0\n",
      "      MLP 0.00734021 ± 0 0.0460312 ± 0 0.0494116 ± 0\n",
      "MaxAffine 0.00755729 ± 0  0.167872 ± 0  0.175368 ± 0\n",
      "\n",
      "=== OPTIMIZATION SUMMARY (mean ± SE over seeds) ===\n",
      "    model                                                                                      LS_x         LS_y   LS_true_y     LS_time                                                                                      IP_x         IP_y   IP_true_y      IP_time                                                                                 GT_x   GT_true_y       GT_time      LS_vs_IP_% IP_true_vs_GT_%\n",
      "      DFN         [27, -32, -31, 18, -22, -12, -18, 5, 9, 13, -5, 6, 30, 29, 26, -1, -2, 2, 4, -12]  5655.91 ± 0    9826 ± 0 66.5071 ± 0          [24, -33, -39, 16, -15, -7, -15, 9, 14, 8, -9, 7, 32, 33, 28, 2, -6, 7, -4, -18]  4408.71 ± 0   14343 ± 0  1.03493 ± 0 [33, -40, -23, 12, -34, -10, -17, 11, 24, 19, -19, 7, 30, 23, 28, 1, 7, -5, 10, -23] 227.976 ± 0 0.0481046 ± 0     28.2892 ± 0     6191.46 ± 0\n",
      "DFN_AfixI         [50, -50, -50, 2, -50, -1, -50, 0, 50, 29, -50, 29, 50, 50, 50, 1, 0, 1, 23, -50] -10167.4 ± 0 78512.2 ± 0 92.4111 ± 0         [50, -50, -50, 2, -50, -1, -50, 0, 50, 29, -50, 29, 50, 50, 50, 1, 0, 1, 23, -50] -10167.3 ± 0 78512.2 ± 0 0.138021 ± 0 [33, -40, -23, 12, -34, -10, -17, 11, 24, 19, -19, 7, 30, 23, 28, 1, 7, -5, 10, -23] 227.976 ± 0 0.0481046 ± 0 -0.00026231 ± 0     34338.8 ± 0\n",
      "     LSET   [50, -50, -50, 50, -50, -50, -38, -3, 31, 50, -50, 4, -5, 50, 50, -31, 26, 50, 50, -50] -30749.8 ± 0  165832 ± 0 2.75452 ± 0  [49, -50, -50, 40, -50, -50, -29, -1, 26, 50, -50, 19, -8, 50, 50, -36, 28, 46, 49, -49] -39277.9 ± 0  158811 ± 0   1800.8 ± 0 [33, -40, -23, 12, -34, -10, -17, 11, 24, 19, -19, 7, 30, 23, 28, 1, 7, -5, 10, -23] 227.976 ± 0 0.0481046 ± 0     21.7122 ± 0     69561.1 ± 0\n",
      "      MLP     [31, -50, -29, 42, -50, -13, -14, 2, 50, 5, -25, 31, 27, 28, 18, -38, 0, -9, 50, -22]  27390.8 ± 0 55650.3 ± 0 1.58796 ± 0     [35, -50, -40, 12, -50, -7, -13, 29, 50, -6, -50, 49, 29, 50, 20, -28, 5, 8, 41, -50]  18069.6 ± 0 77826.6 ± 0  1801.07 ± 0 [33, -40, -23, 12, -34, -10, -17, 11, 24, 19, -19, 7, 30, 23, 28, 1, 7, -5, 10, -23] 227.976 ± 0 0.0481046 ± 0     51.5847 ± 0       34038 ± 0\n",
      "MaxAffine [50, -50, -2, -10, -50, -50, -50, 50, 33, 23, -50, -25, 22, 50, 50, 21, -10, 50, 14, -32]   -12701 ± 0  142857 ± 0  2.1288 ± 0 [50, -50, -46, 38, -50, -50, -50, 50, 10, 50, -50, -50, 50, 50, 50, 37, 24, 50, -45, -34] -26005.9 ± 0  192313 ± 0 0.195166 ± 0 [33, -40, -23, 12, -34, -10, -17, 11, 24, 19, -19, 7, 30, 23, 28, 1, 7, -5, 10, -23] 227.976 ± 0 0.0481046 ± 0     51.1612 ± 0     84256.4 ± 0\n",
      "\n",
      "=== FAILURES / WARNINGS (if any) ===\n",
      " seed model    stage                          error\n",
      "    0  LSET IP_CHECK rel_err=4.816e-01 (tol=0.0001)\n"
     ]
    }
   ],
   "source": [
    "# ===================== Quadratic DATASET =====================\n",
    "# Configures and launches run_benchmark on the quadratic dataset.\n",
    "\n",
    "# ---- run control (all forwarded to run_benchmark below) ----\n",
    "N_SEEDS = 1\n",
    "\n",
    "VARY_DATASET_SEED = True\n",
    "VARY_MODEL_INIT_SEED = True\n",
    "STRICT_IP_CHECK = False\n",
    "IP_CHECK_TOL = 1e-4\n",
    "\n",
    "SILENCE_LOCAL_SEARCH = True\n",
    "ALLOW_PLOTS_MULTI_SEED = True\n",
    "\n",
    "# ---- dataset ----\n",
    "dataset_type = \"quadratic\"\n",
    "dataset_params = {\n",
    "    \"K\": 2048,\n",
    "    \"dim\": 20,\n",
    "    \"eigen_min\": 1,\n",
    "    \"eigen_max\": 15.0,\n",
    "    \"x_min\": -50,\n",
    "    \"x_max\": 50,\n",
    "    \"noise_std\": 0.1,\n",
    "    \"seed\": 0,\n",
    "}\n",
    "# Every model below takes its input width from the dataset dimension.\n",
    "in_dim = int(dataset_params[\"dim\"])\n",
    "\n",
    "# ---- training settings shared by all runs ----\n",
    "train_base = {\n",
    "    \"epochs\": 500,\n",
    "    \"batch_size\": 8,\n",
    "    \"val_frac\": 0.15,\n",
    "    \"test_frac\": 0.15,\n",
    "    \"seed\": 0,\n",
    "    \"device\": \"cpu\",\n",
    "    \"eps\": 1e-8,\n",
    "    \"weight_decay\": 0.0,\n",
    "    \"plot_every\": 10,\n",
    "    \"plot_points\": 128,\n",
    "    \"plot_chunk\": 128,\n",
    "}\n",
    "\n",
    "# ---- DFN (learnable A) ----\n",
    "dfn_params = {\n",
    "    \"input_dim\": in_dim,\n",
    "    \"layer_sizes\": [32, 64, 32],\n",
    "    \"p_list\": [1, 1],\n",
    "    \"seed\": 0,\n",
    "    \"alpha\": 5e-3,\n",
    "    \"beta\": -2.0,\n",
    "}\n",
    "\n",
    "# ---- DFN (fixed A = I) ----\n",
    "dfn_Afix_params = {\n",
    "    \"input_dim\": in_dim,\n",
    "    \"layer_sizes\": [10, 256, 11],\n",
    "    \"p_list\": [1, 1],\n",
    "    \"seed\": 0,\n",
    "    \"alpha\": 5e-3,\n",
    "    \"beta\": -2.0,\n",
    "    \"A_fixed\": np.eye(in_dim, dtype=np.float32),\n",
    "}\n",
    "\n",
    "# ---- other models ----\n",
    "mlp_params = {\"in_dim\": in_dim, \"hidden_dims\": [128, 128], \"out_dim\": 1}\n",
    "maff_params = {\"in_dim\": in_dim, \"n_pieces\": 1000}\n",
    "lset_params = {\"in_dim\": in_dim, \"n_pieces\": 1000, \"T\": 0.05}\n",
    "\n",
    "lr_map = {\"DFN\": 1e-1, \"MLP\": 1e-3, \"MaxAffine\": 1e-3, \"LSET\": 1e-3}\n",
    "time_limit = 1800\n",
    "\n",
    "# ---- starting points ----\n",
    "# One x0 (and its coordinate sum) is drawn per seed, so the run seed\n",
    "# determines them as well. rng.integers excludes the upper bound, so the\n",
    "# entries lie in [-30, 30).\n",
    "x0_list, sum_eq_list = [], []\n",
    "for seed in range(int(N_SEEDS)):\n",
    "    rng = np.random.default_rng(int(seed))\n",
    "    x0_s = rng.integers(-30, 30, size=in_dim, dtype=int)\n",
    "    sum_eq_list.append(int(x0_s.sum()))\n",
    "    x0_list.append(x0_s)\n",
    "\n",
    "# A single seed passes a plain vector / scalar; several seeds pass a stacked\n",
    "# array and the list of per-seed sums.\n",
    "if int(N_SEEDS) == 1:\n",
    "    x0, sum_eq = x0_list[0], sum_eq_list[0]\n",
    "else:\n",
    "    x0, sum_eq = np.stack(x0_list, axis=0), sum_eq_list\n",
    "delta = 2\n",
    "\n",
    "\n",
    "# ---- (label, model type, model params) for every benchmarked model ----\n",
    "runs = [\n",
    "    (\"DFN\", \"DFN\", dfn_params),\n",
    "    (\"DFN_AfixI\", \"DFN\", dfn_Afix_params),\n",
    "    (\"MLP\", \"MLP\", mlp_params),\n",
    "    (\"MaxAffine\", \"MaxAffine\", maff_params),\n",
    "    (\"LSET\", \"LSET\", lset_params),\n",
    "]\n",
    "\n",
    "# NOTE(review): xmin / xmax are not assigned in this cell; presumably they come\n",
    "# from an earlier cell - confirm they agree with dataset_params x_min / x_max.\n",
    "# The result is bound to `_` so the cell does not echo it.\n",
    "_ = run_benchmark(\n",
    "    dataset_type=dataset_type,\n",
    "    dataset_params=dataset_params,\n",
    "    runs=runs,\n",
    "    train_base=train_base,\n",
    "    lr_map=lr_map,\n",
    "    x0=x0,\n",
    "    xmin=xmin,\n",
    "    xmax=xmax,\n",
    "    delta=delta,\n",
    "    sum_eq=sum_eq,\n",
    "    n_seeds=N_SEEDS,\n",
    "    vary_dataset_seed=VARY_DATASET_SEED,\n",
    "    vary_model_init_seed=VARY_MODEL_INIT_SEED,\n",
    "    strict_ip_check=STRICT_IP_CHECK,\n",
    "    ip_check_tol=IP_CHECK_TOL,\n",
    "    silence_local_search=SILENCE_LOCAL_SEARCH,\n",
    "    allow_plots_multi_seed=ALLOW_PLOTS_MULTI_SEED,\n",
    "    time_limit=time_limit,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6af7e440-4bd3-410e-ace5-6f336ab98fa6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:dfn]",
   "language": "python",
   "name": "conda-env-dfn-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}