{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "96398989-280e-41b3-bc11-99b4b51e9439",
   "metadata": {},
   "source": [
    "# General Helpers"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "767d34c9-66bd-4986-a324-4daeec80aee1",
   "metadata": {},
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bef302bd-4ca0-4269-9ba1-ac6d05ef5819",
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import annotations\n",
    "import sys, json, time, math, random, tempfile\n",
    "from pathlib import Path\n",
    "from typing import Optional, Tuple, Dict, List\n",
    "import numpy as np\n",
    "import cppimport\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from IPython.display import Image, display\n",
    "from dataclasses import dataclass, field\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3cd5a10c-7768-4318-8244-beaf6c4c84d1",
   "metadata": {},
   "source": [
    "### Python Bindings for LEMON\n",
    "The following cell compiles (via `cppimport`) and imports the Python extension module `lemon_mcf`, which exposes minimum-cost-flow and max-flow solvers from the C++ LEMON graph library."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b4658685-aab5-4c45-aa6d-36e862bbdc9a",
   "metadata": {},
   "outputs": [
    {
     "ename": "ImportError",
     "evalue": "Couldn't find a file matching the module name: lemon_mcf  (opt_in = False)",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mImportError\u001b[39m                               Traceback (most recent call last)",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[2]\u001b[39m\u001b[32m, line 3\u001b[39m\n\u001b[32m      1\u001b[39m repo = Path().resolve().parent\n\u001b[32m      2\u001b[39m sys.path.insert(\u001b[32m0\u001b[39m, \u001b[38;5;28mstr\u001b[39m(repo))\n\u001b[32m----> \u001b[39m\u001b[32m3\u001b[39m lemon_mcf = \u001b[43mcppimport\u001b[49m\u001b[43m.\u001b[49m\u001b[43mimp\u001b[49m\u001b[43m(\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mlemon_mcf\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
      "\u001b[36mFile \u001b[39m\u001b[32m/opt/anaconda3/envs/dfn/lib/python3.11/site-packages/cppimport/__init__.py:49\u001b[39m, in \u001b[36mimp\u001b[39m\u001b[34m(fullname, opt_in)\u001b[39m\n\u001b[32m     46\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mcppimport\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mfind\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m find_module_cpppath\n\u001b[32m     48\u001b[39m \u001b[38;5;66;03m# Search through sys.path to find a file that matches the module\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m49\u001b[39m filepath = \u001b[43mfind_module_cpppath\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfullname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mopt_in\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m     50\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m imp_from_filepath(filepath, fullname)\n",
      "\u001b[36mFile \u001b[39m\u001b[32m/opt/anaconda3/envs/dfn/lib/python3.11/site-packages/cppimport/find.py:14\u001b[39m, in \u001b[36mfind_module_cpppath\u001b[39m\u001b[34m(modulename, opt_in)\u001b[39m\n\u001b[32m     12\u001b[39m filepath = _find_module_cpppath(modulename, opt_in)\n\u001b[32m     13\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m filepath \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m---> \u001b[39m\u001b[32m14\u001b[39m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mImportError\u001b[39;00m(\n\u001b[32m     15\u001b[39m         \u001b[33m\"\u001b[39m\u001b[33mCouldn\u001b[39m\u001b[33m'\u001b[39m\u001b[33mt find a file matching the module name: \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m     16\u001b[39m         + \u001b[38;5;28mstr\u001b[39m(modulename)\n\u001b[32m     17\u001b[39m         + \u001b[33m\"\u001b[39m\u001b[33m  (opt_in = \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m     18\u001b[39m         + \u001b[38;5;28mstr\u001b[39m(opt_in)\n\u001b[32m     19\u001b[39m         + \u001b[33m\"\u001b[39m\u001b[33m)\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m     20\u001b[39m     )\n\u001b[32m     21\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m filepath\n",
      "\u001b[31mImportError\u001b[39m: Couldn't find a file matching the module name: lemon_mcf  (opt_in = False)"
     ]
    }
   ],
   "source": [
    "# cppimport searches the directories on sys.path for `lemon_mcf.cpp` (or `lemon_mcf.c`).\n",
    "# Look for that source in the notebook's directory and then in each ancestor directory, so the\n",
    "# cell works regardless of how deep the notebook sits in the repo; if it is not found anywhere,\n",
    "# fall back to the notebook's parent directory.\n",
    "_nb_dir = Path().resolve()\n",
    "repo = next(\n",
    "    (\n",
    "        cand\n",
    "        for cand in (_nb_dir, *_nb_dir.parents)\n",
    "        if any((cand / f\"lemon_mcf{ext}\").exists() for ext in (\".cpp\", \".c\"))\n",
    "    ),\n",
    "    _nb_dir.parent,\n",
    ")\n",
    "# Only prepend once, so re-running the cell does not keep growing sys.path.\n",
    "if str(repo) not in sys.path:\n",
    "    sys.path.insert(0, str(repo))\n",
    "lemon_mcf = cppimport.imp(\"lemon_mcf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "800144cd-efc4-47e8-95d6-e75cad74419d",
   "metadata": {},
   "source": [
    "The module has the function `lemon_mcf.solve_mcf` which solves the minimum cost flow problem. The function can be described as follows:\n",
    "\n",
    "**Signature**\n",
    "```python\n",
    "out = lemon_mcf.solve_mcf(n, src, dst, cost, cap, supply, tol=1e-9)\n",
    "```\n",
    "\n",
    "**Inputs (NumPy arrays)**\n",
    "- `n` : `int` — number of nodes (indexed `0..n-1`)\n",
    "- `src`, `dst` : `np.ndarray` shape `(m,)`, dtype `int64` — directed edges `src[i] -> dst[i]`\n",
    "- `cost`, `cap` : `np.ndarray` shape `(m,)`, dtype `float64` — per-edge cost and capacities\n",
    "- `supply` : `np.ndarray` shape `(n,)`, dtype `float64` — node supplies/demands (`>0` supply, `<0` demand)\n",
    "- `tol` : `float` — tolerance for capacity-status flags\n",
    "\n",
    "**Output (`dict`)**\n",
    "- `out[\"status\"]` : `int`\n",
    "- `out[\"flow\"]` : `np.ndarray` shape `(m,)`, dtype `float64`\n",
    "- `out[\"potential\"]` : `np.ndarray` shape `(n,)`, dtype `float64` (node potentials; defined up to an additive constant)\n",
    "- `out[\"reduced_cost\"]` : `np.ndarray` shape `(m,)`, dtype `float64`, computed as `cost[i] + potential[src[i]] - potential[dst[i]]`\n",
    "- (optional) capacity-status flag: boolean arrays indicating whether each arc is at its capacity\n",
    "- `out[\"total_cost\"]` : `float` (objective value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ff805be-adf0-44c9-a12c-d4d2fb077251",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ----- Example of Usage ------ #\n",
    "# Three nodes: node 0 supplies 3 units; nodes 1 and 2 demand 1 and 2 units.\n",
    "n = 3\n",
    "\n",
    "# Arc i goes src[i] -> dst[i] with per-unit cost cost[i] and capacity cap[i]:\n",
    "#   0 -> 1 (cost 2, cap 5),  0 -> 2 (cost 1, cap 2),  1 -> 2 (cost 3, cap 4)\n",
    "src = np.array([0, 0, 1], dtype=np.int64)\n",
    "dst = np.array([1, 2, 2], dtype=np.int64)\n",
    "cost = np.array([2.0, 1.0, 3.0], dtype=np.float64)\n",
    "cap = np.array([5.0, 2.0, 4.0], dtype=np.float64)\n",
    "\n",
    "# Balances: > 0 supply, < 0 demand (they sum to zero).\n",
    "supply = np.array([3.0, -1.0, -2.0], dtype=np.float64)\n",
    "\n",
    "# Optimal plan by inspection: 1 unit on 0->1 and 2 units on 0->2, total cost 4.\n",
    "out_min_cost_flow = lemon_mcf.solve_mcf(n, src, dst, cost, cap, supply)\n",
    "print(out_min_cost_flow)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5edb874-0d6f-4c0a-a1c0-8707a04bd0c7",
   "metadata": {},
   "source": [
    "It also has the function `lemon_mcf.max_flow` which solves the maximum s-t flow problem. The function can be described as follows:\n",
    "\n",
    "**Signature**\n",
    "```python\n",
    "out = lemon_mcf.max_flow(n, src, dst, cap, s, t)\n",
    "```\n",
    "\n",
    "**Inputs (NumPy arrays)**\n",
    "- `n` : `int` — number of nodes (indexed `0..n-1`)\n",
    "- `src`, `dst` : `np.ndarray` shape `(m,)`, dtype `int64` — directed edges `src[i] -> dst[i]`\n",
    "- `cap` : `np.ndarray` shape `(m,)`, dtype `float64` — per-edge capacities (nonnegative)\n",
    "- `s` : `int` — source node index\n",
    "- `t` : `int` — sink node index\n",
    "\n",
    "**Output (`dict`)**\n",
    "- `out[\"value\"]` : `float` — maximum flow value from `s` to `t`\n",
    "- `out[\"flow\"]` : `np.ndarray` shape `(m,)`, dtype `float64` — per-edge flow values (in the same order as the input edges)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84d7091d-b4dd-46db-84b2-e580dd826951",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# ----- Example of Usage ------ #\n",
    "# Diamond network 0 -> {1, 2} -> 3. Each path is limited by its narrower arc\n",
    "# (0->1->3: min(3, 2) = 2, 0->2->3: min(2, 4) = 2), so the max 0->3 flow is 4.\n",
    "n = 4\n",
    "src = np.array([0, 0, 1, 2], dtype=np.int64)\n",
    "dst = np.array([1, 2, 3, 3], dtype=np.int64)\n",
    "cap = np.array([3.0, 2.0, 2.0, 4.0], dtype=np.float64)\n",
    "\n",
    "out_max_flow = lemon_mcf.max_flow(n, src, dst, cap, 0, 3)  # s = 0, t = 3\n",
    "print(out_max_flow)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2d023ed-4551-4a12-895b-5499538ccfe3",
   "metadata": {},
   "source": [
    "# Generating Datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c19ed1a-276c-4f7b-930b-15ee5dd1a087",
   "metadata": {},
   "source": [
    "## Synthetic Quadratic "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e668b358-356c-4b0f-8f52-3fa79709f650",
   "metadata": {},
   "source": [
    "#### Quadratic Instance Generator"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d51f4f1d-d52d-4167-8546-975f7056157a",
   "metadata": {},
   "source": [
    "The main function can be described as follows:\n",
    "\n",
    "**Signature**\n",
    "```python\n",
    "data = generate_synthetic_convex_quadratic_dataset(d, m, M, X, sigma2, K, seed=None)\n",
    "```\n",
    "\n",
    "**Inputs**\n",
    "- `d` : `int` — dimension (`d > 0`)\n",
    "- `m`, `M` : `float` — eigenvalue bounds (`0 < m <= M`)\n",
    "- `X` : `int` — nonnegative integer box bound (`X >= 0`); points are drawn uniformly from the integer points of `[-X, X]^d`, independently per coordinate\n",
    "- `sigma2` : `float` — noise variance (`sigma2 >= 0`)\n",
    "- `K` : `int` — number of samples (`K > 0`)\n",
    "- `seed` : `int` or `None` — RNG seed\n",
    "\n",
    "**Process**\n",
    "1. Sample eigenvalues `λ_1,...,λ_d ~ Uniform([m, M])` i.i.d.\n",
    "2. Sample an (approximately) Haar-uniform orthonormal matrix `U ∈ R^{d×d}` via QR of a Gaussian matrix.\n",
    "3. Form `Q = U diag(λ) U^T` (positive definite since `m > 0`).\n",
    "4. Sample an integer optimizer `x*` uniformly from the integer box `[-X, X]^d`.\n",
    "5. Sample `K` integer covariates `x^1,...,x^K` uniformly from `[-X, X]^d`.\n",
    "6. Sample i.i.d. noise `ε^k ~ N(0, sigma2)` for `k=1,...,K`.\n",
    "7. Set targets `y^k = (x^k - x*)^T Q (x^k - x*) + ε^k`.\n",
    "\n",
    "**Output (`dict`)**\n",
    "- `data[\"U\"]` : `np.ndarray` shape `(d, d)` — sampled orthonormal matrix\n",
    "- `data[\"lambdas\"]` : `np.ndarray` shape `(d,)` — sampled eigenvalues\n",
    "- `data[\"Q\"]` : `np.ndarray` shape `(d, d)` — sampled quadratic matrix `Q = U diag(lambdas) U^T`\n",
    "- `data[\"x_star\"]` : `np.ndarray` shape `(d,)`, dtype `int` — sampled optimizer `x*`\n",
    "- `data[\"X_samples\"]` : `np.ndarray` shape `(K, d)`, dtype `int` — sampled inputs `x^k`\n",
    "- `data[\"y\"]` : `np.ndarray` shape `(K,)`, dtype `float` — targets `y^k`\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "edd30065-8973-4a0f-8315-99a96ac42355",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "def _haar_orthonormal_matrix(d: int, rng: np.random.Generator) -> np.ndarray:\n",
    "    \"\"\"\n",
    "    Draw a random d x d orthonormal matrix, Haar-distributed up to floating point.\n",
    "\n",
    "    A matrix of i.i.d. standard normals is QR-factorized and each column of Q is\n",
    "    multiplied by the sign of the matching diagonal entry of R (an exact zero\n",
    "    counts as +1). This removes the sign ambiguity that np.linalg.qr leaves open.\n",
    "\n",
    "    Args:\n",
    "      d: matrix dimension.\n",
    "      rng: NumPy Generator; consumes one (d, d) block of standard normals.\n",
    "\n",
    "    Returns:\n",
    "      (d, d) array with orthonormal columns.\n",
    "    \"\"\"\n",
    "    gaussian = rng.standard_normal((d, d))\n",
    "    q_factor, r_factor = np.linalg.qr(gaussian)\n",
    "    r_diag = np.diag(r_factor)\n",
    "    column_signs = np.where(r_diag == 0, 1.0, np.sign(r_diag))\n",
    "    # Broadcasting over rows scales column j by column_signs[j].\n",
    "    return q_factor * column_signs[np.newaxis, :]\n",
    "\n",
    "def generate_synthetic_convex_quadratic_dataset(\n",
    "    d: int,\n",
    "    m: float,\n",
    "    M: float,\n",
    "    X: int,\n",
    "    sigma2: float,\n",
    "    K: int,\n",
    "    seed: int | None = None\n",
    "):\n",
    "    \"\"\"\n",
    "    Sample a random strongly convex quadratic and K noisy evaluations of it.\n",
    "\n",
    "    Q = U diag(lambdas) U^T with lambdas ~ Uniform(m, M) i.i.d. and U a random\n",
    "    orthonormal matrix; x_star and the K inputs are uniform over the integer points\n",
    "    of [-X, X]^d; y^k = (x^k - x_star)^T Q (x^k - x_star) + eps^k, eps^k ~ N(0, sigma2).\n",
    "\n",
    "    Args:\n",
    "      d: dimension (> 0).\n",
    "      m, M: eigenvalue bounds with 0 < m <= M.\n",
    "      X: nonnegative integer box half-width (int or NumPy integer).\n",
    "      sigma2: noise variance (>= 0).\n",
    "      K: number of samples (> 0).\n",
    "      seed: seed for np.random.default_rng (None -> fresh entropy).\n",
    "\n",
    "    Returns:\n",
    "      dict with keys \"U\" (d, d), \"lambdas\" (d,), \"Q\" (d, d), \"x_star\" (d,),\n",
    "      \"X_samples\" (K, d) and \"y\" (K,).\n",
    "\n",
    "    Raises:\n",
    "      ValueError: if an argument violates the constraints above.\n",
    "    \"\"\"\n",
    "    # --- argument validation (checked in this order) ---\n",
    "    if d <= 0:\n",
    "        raise ValueError(\"d must be positive.\")\n",
    "    if not M >= m:  # written this way so NaN bounds are rejected too\n",
    "        raise ValueError(\"Require M >= m.\")\n",
    "    if m <= 0:\n",
    "        raise ValueError(\"Require m > 0 so that Q is positive definite.\")\n",
    "    if K <= 0:\n",
    "        raise ValueError(\"K must be positive.\")\n",
    "    if sigma2 < 0:\n",
    "        raise ValueError(\"sigma2 must be nonnegative.\")\n",
    "    if not isinstance(X, (int, np.integer)):\n",
    "        raise ValueError(\"X must be an integer.\")\n",
    "    X = int(X)\n",
    "    if X < 0:\n",
    "        raise ValueError(\"X must be a nonnegative integer.\")\n",
    "\n",
    "    rng = np.random.default_rng(seed)\n",
    "\n",
    "    # The random draws below happen in a fixed order; reordering them would change\n",
    "    # the dataset produced for a given seed.\n",
    "    lambdas = rng.uniform(m, M, size=d)\n",
    "    U = _haar_orthonormal_matrix(d, rng)\n",
    "    Q = U @ np.diag(lambdas) @ U.T\n",
    "\n",
    "    low, high = -X, X + 1  # rng.integers excludes `high`, so this covers [-X, X]\n",
    "    x_star = rng.integers(low, high, size=(1, d), dtype=int).reshape(-1)  # (d,)\n",
    "    X_samples = rng.integers(low, high, size=(K, d), dtype=int)  # (K, d)\n",
    "\n",
    "    noise = rng.normal(loc=0.0, scale=float(np.sqrt(sigma2)), size=K)\n",
    "\n",
    "    offsets = X_samples - x_star  # (K, d)\n",
    "    y = np.einsum(\"bi,ij,bj->b\", offsets, Q, offsets) + noise\n",
    "\n",
    "    return dict(U=U, lambdas=lambdas, Q=Q, x_star=x_star, X_samples=X_samples, y=y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "666e79e4-8daa-483c-995d-6a3b92367b88",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# --- Example of Usage --- #\n",
    "# d = 5, eigenvalues in [0.5, 3.0], integer box [-10, 10]^5, noise variance 0.25, 100 samples.\n",
    "data = generate_synthetic_convex_quadratic_dataset(\n",
    "    d=5,\n",
    "    m=0.5,\n",
    "    M=3.0,\n",
    "    X=10,\n",
    "    sigma2=0.25,\n",
    "    K=100,\n",
    "    seed=0,\n",
    ")\n",
    "\n",
    "print(\"x_star:\", data[\"x_star\"])\n",
    "print(\"X_samples shape:\", data[\"X_samples\"].shape)\n",
    "print(\"y shape:\", data[\"y\"].shape)\n",
    "# Q is positive definite by construction, so this should be at least m (up to rounding).\n",
    "print(\"min eigenvalue(Q) ~\", np.linalg.eigvalsh(data[\"Q\"]).min())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f403d314-b3f6-4933-b466-a6b3363839c7",
   "metadata": {},
   "source": [
    "#### Full Dataset Generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "302dbd89-36d9-45e5-a1a4-d34846a4a6fc",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "4c49e91f-881b-4020-8878-c22cf217c00c",
   "metadata": {},
   "source": [
    "## Synthetic Min Cost Flow"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f773417-3500-4b5a-b931-59eacf5b61c2",
   "metadata": {},
   "source": [
    "#### NETGEN Instance Generator"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6c6656d-e5ac-45fd-975d-398e4fbe472b",
   "metadata": {},
   "source": [
    "The main function can be described as follows:\n",
    "\n",
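    "**Signature** (all arguments are keyword-only)\n",
    "```python\n",
    "u, v, cap, cost, b = netgen_instance_arrays(\n",
    "    nodes=..., arcs=..., sources=..., sinks=...,\n",
    "    cost_bounds=..., cap_bounds=..., total_supply=1000, seed=None,\n",
    ")\n",
    "```\n",
    "\n",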
    "**Inputs**\n",
    "- `nodes` : `int` — number of nodes (`nodes > 0`)\n",
    "- `arcs` : `int` — number of directed arcs/edges to generate (`arcs > 0`)\n",
    "- `sources` : `int` — number of supply (source) nodes (`sources >= 0`)\n",
    "- `sinks` : `int` — number of demand (sink) nodes (`sinks >= 0`; also `sources + sinks <= nodes`)\n",
    "- `cost_bounds` : `tuple[int, int]` — arc cost bounds `(min_cost, max_cost)` with `min_cost <= max_cost`\n",
    "- `cap_bounds` : `tuple[int, int]` — arc capacity bounds `(min_cap, max_cap)` with `min_cap <= max_cap`\n",
    "- `total_supply` : `int` — total supply injected into the network (`total_supply >= 0`)\n",
    "- `seed` : `int` or `None` — RNG seed\n",
    "\n",
    "**Output (arrays)**\n",
    "Returns a 5-tuple:\n",
    "\n",
    "- `u` : `np.ndarray` shape `(m,)`, dtype `int` — tail (source) node index for each arc (0-based)\n",
    "- `v` : `np.ndarray` shape `(m,)`, dtype `int` — head (destination) node index for each arc (0-based)  \n",
    "  (arc `i` is `u[i] -> v[i]`)\n",
    "- `cap` : `np.ndarray` shape `(m,)`, dtype `int` — capacity for each arc\n",
    "- `cost` : `np.ndarray` shape `(m,)`, dtype `int` — per-unit cost for each arc\n",
    "- `b` : `np.ndarray` shape `(nodes,)`, dtype `int` — flow balance for each node (supply `> 0`, demand `< 0`, `sum(b) == 0`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fb8ebaa-0e13-4c0b-bbf6-00c0b60c4547",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "def _parse_dimacs_mcf(path: Path) -> tuple[int, Dict[int, int], List[tuple[int, int, int, int]]]:\n",
    "    \"\"\"\n",
    "    Parse a DIMACS minimum-cost-flow file (the format NETGEN writes).\n",
    "\n",
    "    Recognised lines (node ids are 1-based, exactly as in the file):\n",
    "      c ...                  comment (skipped; blank lines are skipped too)\n",
    "      p min N M              problem line: N nodes, M arcs\n",
    "      n ID FLOW              node balance (supply > 0, demand < 0)\n",
    "      a U V LOW CAP COST     arc U -> V with lower bound LOW, capacity CAP, unit cost COST\n",
    "    Any other line type is ignored.\n",
    "\n",
    "    Returns:\n",
    "      (n, supply, arcs) where n is the node count from the p-line, supply maps\n",
    "      1-based node id -> balance for nodes that have an n-line, and arcs is a\n",
    "      list of (u, v, cap, cost) tuples with 1-based endpoints, in file order.\n",
    "\n",
    "    Raises:\n",
    "      ValueError: if the p-line is missing, an arc has a nonzero lower bound\n",
    "        (lower bounds are not carried through to the returned arcs), or the\n",
    "        number of arc lines differs from the arc count declared on the p-line\n",
    "        (e.g. a truncated file).\n",
    "    \"\"\"\n",
    "    n = None\n",
    "    m_declared = None  # arc count from the p-line, checked against the arcs actually read\n",
    "    supply: Dict[int, int] = {}\n",
    "    arcs: List[tuple[int, int, int, int]] = []  # (u, v, cap, cost)\n",
    "\n",
    "    for raw in path.read_text(encoding=\"utf-8\").splitlines():\n",
    "        line = raw.strip()\n",
    "        if not line or line.startswith(\"c\"):\n",
    "            continue\n",
    "        parts = line.split()\n",
    "        if parts[0] == \"p\":\n",
    "            n = int(parts[2])\n",
    "            if len(parts) > 3:\n",
    "                m_declared = int(parts[3])\n",
    "        elif parts[0] == \"n\":\n",
    "            supply[int(parts[1])] = int(parts[2])\n",
    "        elif parts[0] == \"a\":\n",
    "            u, v, low, cap, cost = map(int, parts[1:])\n",
    "            if low != 0:\n",
    "                raise ValueError(f\"Unexpected nonzero ignored field in DIMACS arc line: value={low}\")\n",
    "            arcs.append((u, v, cap, cost))\n",
    "\n",
    "    if n is None:\n",
    "        raise ValueError(\"Missing 'p min N M' line.\")\n",
    "    if m_declared is not None and len(arcs) != m_declared:\n",
    "        raise ValueError(\n",
    "            f\"DIMACS problem line declares {m_declared} arcs but {len(arcs)} arc lines were read.\"\n",
    "        )\n",
    "    return n, supply, arcs\n",
    "\n",
    "\n",
    "def netgen_instance_arrays(\n",
    "    *,\n",
    "    nodes: int,\n",
    "    arcs: int,\n",
    "    sources: int,\n",
    "    sinks: int,\n",
    "    cost_bounds: Tuple[int, int],\n",
    "    cap_bounds: Tuple[int, int],\n",
    "    total_supply: int = 1000,\n",
    "    seed: Optional[int] = None,\n",
    "):\n",
    "    \"\"\"\n",
    "    Generate a NETGEN min-cost-flow instance with `pynetgen` and return it as NumPy arrays.\n",
    "\n",
    "    Args (keyword-only):\n",
    "      nodes: number of nodes (> 0).\n",
    "      arcs: number of arcs requested from NETGEN (passed as its `density`).\n",
    "      sources, sinks: number of supply / demand nodes (>= 0, sources + sinks <= nodes).\n",
    "      cost_bounds: (min_cost, max_cost) for arc unit costs.\n",
    "      cap_bounds: (min_cap, max_cap) for arc capacities.\n",
    "      total_supply: total supply passed to NETGEN.\n",
    "      seed: NETGEN seed; if None, one is drawn from Python's global `random` module.\n",
    "\n",
    "    Returns:\n",
    "      u, v, cap, cost, b\n",
    "    where:\n",
    "      - u, v, cap, cost are int64 arrays of length m (arc i is u[i] -> v[i])\n",
    "      - b is an int64 array of length n with supply(+)/demand(-), sum(b)=0\n",
    "      - nodes are 0-based indices\n",
    "\n",
    "    Raises:\n",
    "      ValueError: invalid arguments, a supply line naming a node outside 1..n,\n",
    "        or an unbalanced instance (also propagated from `_parse_dimacs_mcf`).\n",
    "      RuntimeError: if `pynetgen` is not installed.\n",
    "    \"\"\"\n",
    "    if nodes <= 0 or arcs <= 0:\n",
    "        raise ValueError(\"nodes and arcs must be positive.\")\n",
    "    if sources < 0 or sinks < 0:\n",
    "        raise ValueError(\"sources and sinks must be nonnegative.\")\n",
    "    if sources + sinks > nodes:\n",
    "        raise ValueError(\"Need sources+sinks <= nodes.\")\n",
    "    minc, maxc = cost_bounds\n",
    "    mincap, maxcap = cap_bounds\n",
    "    if minc > maxc or mincap > maxcap:\n",
    "        raise ValueError(\"Bad bounds: min must be <= max.\")\n",
    "    if seed is None:\n",
    "        seed = random.randint(1, 2_000_000_000)\n",
    "\n",
    "    try:\n",
    "        import pynetgen  # type: ignore\n",
    "    except ImportError as e:\n",
    "        raise RuntimeError(\n",
    "            \"pynetgen is not installed.\"\n",
    "        ) from e\n",
    "\n",
    "    # pynetgen writes the instance to a DIMACS file. Use a self-cleaning temporary\n",
    "    # directory (a bare mkdtemp() is never deleted) and parse the file before it goes away.\n",
    "    with tempfile.TemporaryDirectory(prefix=\"netgen_\") as tmp:\n",
    "        dimacs_path = Path(tmp) / \"instance.dimacs\"\n",
    "        pynetgen.netgen_generate(\n",
    "            seed=seed,\n",
    "            nodes=nodes,\n",
    "            sources=sources,\n",
    "            sinks=sinks,\n",
    "            density=arcs,\n",
    "            mincost=minc,\n",
    "            maxcost=maxc,\n",
    "            supply=total_supply,\n",
    "            tsources=0,\n",
    "            tsinks=0,\n",
    "            hicost=0,\n",
    "            capacitated=100,\n",
    "            mincap=mincap,\n",
    "            maxcap=maxcap,\n",
    "            rng=1,\n",
    "            type=None,\n",
    "            fname=str(dimacs_path),\n",
    "        )\n",
    "        n, supply_1b, arc_list_1b = _parse_dimacs_mcf(dimacs_path)\n",
    "\n",
    "    # One row per arc: (u, v, cap, cost). reshape keeps a (0, 4) shape when there are no arcs.\n",
    "    arc_table = np.asarray(arc_list_1b, dtype=np.int64).reshape(-1, 4)\n",
    "    u = arc_table[:, 0] - 1  # DIMACS node ids are 1-based\n",
    "    v = arc_table[:, 1] - 1\n",
    "    cap = arc_table[:, 2].copy()  # copy -> contiguous 1-D array, independent of arc_table\n",
    "    cost = arc_table[:, 3].copy()\n",
    "\n",
    "    b = np.zeros(n, dtype=np.int64)\n",
    "    for node_1b, val in supply_1b.items():\n",
    "        # Without this check an id of 0 would silently wrap around to b[-1].\n",
    "        if not 1 <= node_1b <= n:\n",
    "            raise ValueError(f\"DIMACS supply line refers to node {node_1b}, outside 1..{n}.\")\n",
    "        b[node_1b - 1] = val\n",
    "\n",
    "    if b.sum() != 0:\n",
    "        raise ValueError(f\"Unbalanced instance: sum(b)={b.sum()} (should be 0).\")\n",
    "\n",
    "    return u, v, cap, cost, b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51948156-ee64-4380-9cd9-f31f811dd585",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# ------ Example of usage ------ #\n",
    "# A 200-node NETGEN network with 10 supply and 10 demand nodes and 1500 requested arcs.\n",
    "u, v, cap, cost, b = netgen_instance_arrays(\n",
    "    nodes=200, arcs=1500, sources=10, sinks=10,\n",
    "    cost_bounds=(1, 50), cap_bounds=(1, 100),\n",
    "    total_supply=5000, seed=7,\n",
    ")\n",
    "\n",
    "n, m = len(b), len(u)\n",
    "print(f\"n={n}, m={m}, balance sum(b)={int(b.sum())}\")\n",
    "\n",
    "# Supply nodes have b > 0 and demand nodes b < 0; list at most the first 10 of each.\n",
    "src = np.flatnonzero(b > 0)\n",
    "snk = np.flatnonzero(b < 0)\n",
    "print(f\"#sources={len(src)}, #sinks={len(snk)}\")\n",
    "print(\"sources (node, supply):\", [(int(node), int(b[node])) for node in src[:10]])\n",
    "print(\"sinks   (node, demand):\", [(int(node), int(b[node])) for node in snk[:10]])\n",
    "\n",
    "# Table of the first (up to) 10 arcs.\n",
    "k = min(10, m)\n",
    "print(\"\\nFirst edges:\")\n",
    "print(f\"{'i':>4}  {'u':>5}  {'v':>5}  {'cap':>6}  {'cost':>6}\")\n",
    "print(\"-\" * 35)\n",
    "for i in range(k):\n",
    "    print(f\"{i:>4}  {int(u[i]):>5}  {int(v[i]):>5}  {int(cap[i]):>6}  {int(cost[i]):>6}\")\n",
    "\n",
    "# (min, mean, max) of arc capacities and costs.\n",
    "print(\"\\ncap stats:\", int(cap.min()), float(cap.mean()), int(cap.max()))\n",
    "print(\"cost stats:\", int(cost.min()), float(cost.mean()), int(cost.max()))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "273eeb7d-38f9-4cdb-9c43-1c3e920c792d",
   "metadata": {},
   "source": [
    "#### NETGEN Dataset Generator"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "723aab3f-b920-44d6-a891-dc87f6a89b61",
   "metadata": {},
   "source": [
    "The main function can be described as follows:\n",
    "\n",
    "**Signature**\n",
    "```python\n",
    "data = generate_netgen_mcf_dataset(n, family, K, tilde_C, sigma2, seed=None, cap_range=(1,1000), cost_range=(1,10000), total_supply=1000, verbose=False)\n",
    "```\n",
    "\n",
    "**Inputs**\n",
    "- `n` : `int` — number of nodes in the NETGEN network (`n > 0`)\n",
    "- `family` : `str` — family: `\"sparse\"` or `\"dense\"`\n",
    "  - `\"sparse\"` uses `m = 8n` arcs\n",
    "  - `\"dense\"` uses `m = ceil(n * sqrt(n))` arcs\n",
    "- `K` : `int` — number of balance/label samples to generate (`K > 0`)\n",
    "- `tilde_C` : `int` — max auxiliary cost used when generating balances (`tilde_C > 0`)\n",
    "- `sigma2` : `float` — label noise variance (`sigma2 >= 0`)\n",
    "- `seed` : `int` or `None` — RNG seed\n",
    "- `cap_range` : `(int, int)` — integer capacity range for NETGEN arcs (default `(1,1000)`)\n",
    "- `cost_range` : `(int, int)` — integer cost range for NETGEN arcs (default `(1,10000)`)\n",
    "- `total_supply` : `int` — total supply passed to NETGEN when generating the network; the resulting balance magnitudes are discarded (only which nodes are supplies/demands is kept)\n",
    "- `verbose` : `bool` — print `F_max` and basic diagnostics if `True`\n",
    "\n",
    "**Process**\n",
    "1. Set the number of supply and demand nodes to  \n",
    "   `|S| = |T| = floor(sqrt(n))`.\n",
    "2. Choose the number of arcs `m` based on `family`:\n",
    "   - If `\"sparse\"`: `m = 8n`\n",
    "   - If `\"dense\"`:  `m = ceil(n * sqrt(n))`\n",
    "3. Generate a NETGEN minimum-cost flow instance on `n` nodes with:\n",
    "   - arc capacities `u_e ~ Unif({cap_range[0],...,cap_range[1]})` (default `{1,...,1000}`)\n",
    "   - arc costs `c_e ~ Unif({cost_range[0],...,cost_range[1]})` (default `{1,...,10000}`)\n",
    "\n",
    "   NETGEN also returns a feasible balance vector `b0`; discard its magnitudes but use its sign pattern to\n",
    "   identify the node sets:\n",
    "   - `S = {v : b0(v) > 0}` (supplies)\n",
    "   - `T = {v : b0(v) < 0}` (demands)\n",
    "4. Augment the network by adding a super-source `s` and super-sink `t`:\n",
    "   - add arcs `(s,v)` for all `v ∈ S` with capacity `+∞`\n",
    "   - add arcs `(v,t)` for all `v ∈ T` with capacity `+∞`\n",
    "5. Compute `F_max`, the value of a maximum `s–t` flow in the augmented network.\n",
    "6. For each sample `k = 1,...,K`:\n",
    "   1. Sample a target flow value `F^(k) ~ Unif({0,1,...,F_max})`.\n",
    "   2. Sample auxiliary integer arc costs on the augmented network  \n",
    "      `tilde_c_e^(k) ~ Unif({1,...,tilde_C})`.\n",
    "   3. Solve a minimum-cost `s–t` flow of value `F^(k)` on the augmented network under costs `tilde_c^(k)`,\n",
    "      obtaining a flow `f^(k)`.\n",
    "   4. Define the balance vector `b^(k) ∈ Z^n` on the original node set by:\n",
    "      - `b^(k)(v) = f_sv^(k)` for `v ∈ S`\n",
    "      - `b^(k)(v) = -f_vt^(k)` for `v ∈ T`\n",
    "      - `b^(k)(v) = 0` otherwise\n",
    "   5. Solve the original minimum-cost flow problem on `G(V,E)` with capacities `u`, costs `c`,\n",
    "      and balance `b^(k)` to obtain the optimal cost `z^(k)`.\n",
    "   6. Sample noise `ε^(k) ~ N(0, sigma2)` and set the label `y^(k) = z^(k) + ε^(k)`.\n",
    "\n",
    "**Output (`dict`)**\n",
    "- `data[\"n\"]` : `int` — number of nodes\n",
    "- `data[\"m\"]` : `int` — number of arcs in the original NETGEN graph\n",
    "- `data[\"u\"]` : `np.ndarray` shape `(m,)`, dtype `int` — arc tail nodes (0-indexed)\n",
    "- `data[\"v\"]` : `np.ndarray` shape `(m,)`, dtype `int` — arc head nodes (0-indexed)\n",
    "- `data[\"cap\"]` : `np.ndarray` shape `(m,)`, dtype `int` — arc capacities\n",
    "- `data[\"cost\"]` : `np.ndarray` shape `(m,)`, dtype `int` — arc costs\n",
    "- `data[\"S_nodes\"]` : `np.ndarray` shape `(|S|,)`, dtype `int` — supply node indices\n",
    "- `data[\"T_nodes\"]` : `np.ndarray` shape `(|T|,)`, dtype `int` — demand node indices\n",
    "- `data[\"X_samples\"]` : `np.ndarray` shape `(K, n)`, dtype `int` — sampled balance vectors `b^(k)`\n",
    "- `data[\"y\"]` : `np.ndarray` shape `(K,)`, dtype `float` — noisy labels `y^(k)`\n",
    "- `data[\"z\"]` : `np.ndarray` shape `(K,)`, dtype `float` — clean optimal costs `z^(k)` (before noise)\n",
    "- `data[\"F_max\"]` : `int` — maximum `s–t` flow value in the augmented network\n",
    "- `data[\"F_targets\"]` : `np.ndarray` shape `(K,)`, dtype `int` — sampled target values `F^(k)`\n",
    "- `data[\"tilde_C\"]` : `int` — auxiliary cost bound used\n",
    "- `data[\"sigma2\"]` : `float` — noise variance used\n",
    "- `data[\"seed\"]` : `int` or `None` — RNG seed used\n",
    "- `data[\"family\"]` : `str` — `\"sparse\"` or `\"dense\"`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d6820d3-b93a-4172-a39e-4c96595e5999",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "\n",
    "def generate_netgen_mcf_dataset(\n",
    "    n: int,\n",
    "    family: str,                 # \"sparse\" or \"dense\"\n",
    "    K: int,\n",
    "    tilde_C: int,\n",
    "    sigma2: float,\n",
    "    seed: int | None = None,\n",
    "    cap_range: tuple[int, int] = (1, 1000),\n",
    "    cost_range: tuple[int, int] = (1, 10000),\n",
    "    total_supply: int = 1000,\n",
    "    verbose: bool = False,\n",
    "):\n",
    "    if n <= 0:\n",
    "        raise ValueError(\"n must be positive.\")\n",
    "    if K <= 0:\n",
    "        raise ValueError(\"K must be positive.\")\n",
    "    if tilde_C <= 0:\n",
    "        raise ValueError(\"tilde_C must be positive.\")\n",
    "    if sigma2 < 0:\n",
    "        raise ValueError(\"sigma2 must be nonnegative.\")\n",
    "    if family not in {\"sparse\", \"dense\"}:\n",
    "        raise ValueError(\"family must be 'sparse' or 'dense'.\")\n",
    "\n",
    "    rng = np.random.default_rng(seed)\n",
    "\n",
    "    # --- NETGEN sizes ---\n",
    "    S_cnt = int(math.isqrt(n))                 # floor(sqrt(n))\n",
    "    T_cnt = int(math.isqrt(n))                 # floor(sqrt(n))\n",
    "    if family == \"sparse\":\n",
    "        m_target = 8 * n\n",
    "    else:\n",
    "        m_target = int(math.ceil(n * math.sqrt(n)))\n",
    "\n",
    "    # --- generate NETGEN instance (graph + a feasible b we will discard) ---\n",
    "    u, v, cap, cost, b0 = netgen_instance_arrays(\n",
    "        nodes=n,\n",
    "        arcs=m_target,\n",
    "        sources=S_cnt,\n",
    "        sinks=T_cnt,\n",
    "        cost_bounds=cost_range,\n",
    "        cap_bounds=cap_range,\n",
    "        total_supply=total_supply,\n",
    "        seed=None if seed is None else int(seed),\n",
    "    )\n",
    "\n",
    "    # Identify supply/demand node sets from NETGEN's b (we discard magnitudes, keep sets)\n",
    "    S_nodes = np.where(b0 > 0)[0]\n",
    "    T_nodes = np.where(b0 < 0)[0]\n",
    "\n",
    "    if len(S_nodes) != S_cnt or len(T_nodes) != T_cnt:\n",
    "        raise RuntimeError(\n",
    "            f\"NETGEN returned |S|={len(S_nodes)}, |T|={len(T_nodes)} but expected \"\n",
    "            f\"{S_cnt} and {T_cnt}. (Check generator settings.)\"\n",
    "        )\n",
    "\n",
    "    # --- build augmented network with super-source s and super-sink t ---\n",
    "    s = n\n",
    "    t = n + 1\n",
    "    N_aug = n + 2\n",
    "\n",
    "    # \"infinite\" capacity for super arcs: big enough not to bind\n",
    "    INF = float(np.sum(cap, dtype=np.int64) + 1)\n",
    "\n",
    "    # original arcs\n",
    "    src_aug = [u.astype(np.int64)]\n",
    "    dst_aug = [v.astype(np.int64)]\n",
    "    cap_aug = [cap.astype(np.float64)]\n",
    "\n",
    "    # s -> v for v in S\n",
    "    src_s = np.full(len(S_nodes), s, dtype=np.int64)\n",
    "    dst_s = S_nodes.astype(np.int64)\n",
    "    cap_s = np.full(len(S_nodes), INF, dtype=np.float64)\n",
    "\n",
    "    # v -> t for v in T\n",
    "    src_t = T_nodes.astype(np.int64)\n",
    "    dst_t = np.full(len(T_nodes), t, dtype=np.int64)\n",
    "    cap_t = np.full(len(T_nodes), INF, dtype=np.float64)\n",
    "\n",
    "    src_aug = np.concatenate(src_aug + [src_s, src_t])\n",
    "    dst_aug = np.concatenate(dst_aug + [dst_s, dst_t])\n",
    "    cap_aug = np.concatenate(cap_aug + [cap_s, cap_t])\n",
    "\n",
    "    # Remember arc indices for extracting b^(k)\n",
    "    m0 = len(u)\n",
    "    idx_s_to_v = {int(vv): m0 + i for i, vv in enumerate(S_nodes)}\n",
    "    idx_v_to_t = {int(vv): m0 + len(S_nodes) + i for i, vv in enumerate(T_nodes)}\n",
    "\n",
    "    # --- compute F_max via max s-t flow on augmented network ---\n",
    "    out_mf = lemon_mcf.max_flow(N_aug, src_aug, dst_aug, cap_aug, s, t)\n",
    "    F_max = int(out_mf[\"value\"])\n",
    "    if verbose:\n",
    "        print(f\"[augmented] F_max = {F_max}\")\n",
    "\n",
    "    # --- generate K balances and labels ---\n",
    "    B = np.zeros((K, n), dtype=np.int64)\n",
    "    z = np.zeros(K, dtype=np.float64)\n",
    "    y = np.zeros(K, dtype=np.float64)\n",
    "    F_targets = np.zeros(K, dtype=np.int64)\n",
    "\n",
    "    for k in range(K):\n",
    "        # sample target value F^(k)\n",
    "        Fk = int(rng.integers(0, F_max + 1))\n",
    "        F_targets[k] = Fk\n",
    "\n",
    "        # sample auxiliary integer costs on ALL augmented arcs\n",
    "        aux_cost = rng.integers(1, tilde_C + 1, size=len(src_aug), dtype=np.int64).astype(np.float64)\n",
    "\n",
    "        # min-cost s-t flow of value Fk via supplies\n",
    "        supply_aug = np.zeros(N_aug, dtype=np.float64)\n",
    "        supply_aug[s] = float(Fk)\n",
    "        supply_aug[t] = -float(Fk)\n",
    "\n",
    "        out_aux = lemon_mcf.solve_mcf(\n",
    "            N_aug,\n",
    "            src_aug,\n",
    "            dst_aug,\n",
    "            aux_cost,\n",
    "            cap_aug,\n",
    "            supply_aug,\n",
    "        )\n",
    "        if out_aux[\"status\"] != 1:\n",
    "            raise RuntimeError(f\"Aux min-cost flow failed at k={k} with status={out_aux['status']}.\")\n",
    "\n",
    "        flow_aug = out_aux[\"flow\"]  # float array per arc\n",
    "\n",
    "        # build balance b^(k) on original nodes\n",
    "        bk = np.zeros(n, dtype=np.int64)\n",
    "        for vv in S_nodes:\n",
    "            bk[int(vv)] = int(round(float(flow_aug[idx_s_to_v[int(vv)]])))\n",
    "        for vv in T_nodes:\n",
    "            bk[int(vv)] = -int(round(float(flow_aug[idx_v_to_t[int(vv)]])))\n",
    "\n",
    "        B[k] = bk\n",
    "\n",
    "        # solve original min-cost flow to get z^(k)\n",
    "        out_orig = lemon_mcf.solve_mcf(\n",
    "            n,\n",
    "            u.astype(np.int64),\n",
    "            v.astype(np.int64),\n",
    "            cost.astype(np.float64),\n",
    "            cap.astype(np.float64),\n",
    "            bk.astype(np.float64),\n",
    "        )\n",
    "        if out_orig[\"status\"] != 1:\n",
    "            raise RuntimeError(f\"Original min-cost flow failed at k={k} with status={out_orig['status']}.\")\n",
    "\n",
    "        zk = float(out_orig[\"total_cost\"])\n",
    "        z[k] = zk\n",
    "\n",
    "        # noisy label\n",
    "        eps = float(rng.normal(0.0, math.sqrt(sigma2)))\n",
    "        y[k] = zk + eps\n",
    "\n",
    "    return {\n",
    "        # graph\n",
    "        \"n\": n,\n",
    "        \"m\": len(u),\n",
    "        \"u\": u.astype(np.int64),\n",
    "        \"v\": v.astype(np.int64),\n",
    "        \"cap\": cap.astype(np.int64),\n",
    "        \"cost\": cost.astype(np.int64),\n",
    "        \"S_nodes\": S_nodes.astype(np.int64),\n",
    "        \"T_nodes\": T_nodes.astype(np.int64),\n",
    "\n",
    "        # dataset (quadratic-like)\n",
    "        \"X_samples\": B,     # balances b^(k)\n",
    "        \"y\": y,             # noisy labels\n",
    "\n",
    "        # extras (often useful)\n",
    "        \"z\": z,             # clean optimal costs\n",
    "        \"F_max\": F_max,\n",
    "        \"F_targets\": F_targets,\n",
    "        \"tilde_C\": tilde_C,\n",
    "        \"sigma2\": float(sigma2),\n",
    "        \"seed\": seed,\n",
    "        \"family\": family,\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "891009e5-34d9-4d9e-a62e-ef7321e4ddc1",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# ---- Example of Usage ---- #\n",
    "# Sparse NETGEN instance on 200 nodes with K=100 sampled balance vectors b^(k);\n",
    "# auxiliary arc costs are drawn from 1..tilde_C and label noise has variance sigma2.\n",
    "data = generate_netgen_mcf_dataset(n=200, family=\"sparse\", K=100, tilde_C=100,\n",
    "                                   sigma2=0.25, seed=0, verbose=True)\n",
    "\n",
    "print(\"X_samples shape:\", data[\"X_samples\"].shape)  # one balance vector per row: (K, n)\n",
    "print(\"y shape:\", data[\"y\"].shape)                  # one noisy optimal cost per sample: (K,)\n",
    "print(\"F_max:\", data[\"F_max\"])                      # max s-t flow of the augmented network\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da97f236-a558-46cc-9803-8ce740c9c2a0",
   "metadata": {},
   "source": [
    "#### Full Dataset Generator (TODO: the code cell below is still empty)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06e9b3dc-4978-4afe-a58b-b58f2c5c505a",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "e1cdd06e-81d1-4f27-a39b-81d540a21241",
   "metadata": {},
   "source": [
    "# Models Definition"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b131da93-331f-46ef-91e4-1ced802f0e27",
   "metadata": {},
   "source": [
    "### DFN"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0d617ce9-24ab-4015-9fd0-6bc983d61fab",
   "metadata": {},
   "source": [
    "#### Helper Functions and Modules"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "07be67a0-5e8e-4e5f-9e2d-5acda3ab83be",
   "metadata": {},
   "source": [
    "Straight-through estimator (STE) rounding op:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d7ccba7-9593-421d-a6ef-778b9dd9c437",
   "metadata": {},
   "outputs": [],
   "source": [
    "def round_param(raw: torch.Tensor, step: float = 1.0, nonneg: bool = False, eps: float = 0.0) -> torch.Tensor:\n",
    "    x = F.softplus(raw) + eps if nonneg else raw\n",
    "    y = torch.round(x / step) * step\n",
    "    return x + (y - x).detach()\n",
    "\n",
    "# def round_param(raw: torch.Tensor, step: float = 1.0, nonneg: bool = False, eps: float = 0.0) -> torch.Tensor:\n",
    "#     x = F.softplus(raw) + eps if nonneg else raw\n",
    "#     return x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab29e386-ed1c-4bff-aae5-830ec56acfa2",
   "metadata": {},
   "source": [
    "Differentiable min-cost flow op:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2da3d4c2-7972-4d9e-8167-722a22f7a409",
   "metadata": {},
   "source": [
    "**Signature**  \n",
    "    z = min_cost_flow_value(n_nodes, src, dst, cost, cap, supply, solver=\"lemon\")\n",
    "\n",
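    "**Inputs**\n",
    "- `n_nodes` : `int` — number of nodes (indexed `0..n_nodes-1`)\n",
    "- `src`, `dst` : `torch.Tensor` shape `(m,)`, integer/long — arc endpoints `src[e] -> dst[e]`\n",
    "- `cost` : `torch.Tensor` shape `(m,)`, float — per-arc cost\n",
    "- `cap` : `torch.Tensor` shape `(m,)`, float — per-arc capacity (upper bound); must be nonnegative\n",
    "- `supply` : `torch.Tensor` shape `(n_nodes,)`, float — node flow balance (`>0` supply, `<0` demand); must sum to 0\n",
    "- `solver` : `str` — `\"lemon\"` or `\"gurobi\"`\n",
    "\n",
    "**Output**\n",
    "- `z` : `torch.Tensor` scalar — optimal objective value of the min-cost flow\n",
    "\n",
    "**Gradients** (backward pass; `n_nodes`, `src`, `dst` and `solver` receive none)\n",
    "- w.r.t. `cost`: the optimal flow on each arc\n",
    "- w.r.t. `cap`: based on the arc's reduced cost for arcs at capacity, 0 on all other arcs\n",
    "- w.r.t. `supply`: the negated node potentials, shifted to have zero mean\n"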
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "980806c5-11ab-4527-b164-a88981fca9ff",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "_TOL = 1e-9\n",
    "\n",
    "def _solve_lemon(n, src, dst, cost, cap, supply):\n",
    "    if lemon_mcf is None:\n",
    "        raise RuntimeError(\"LEMON requested but lemon_mcf is not available.\")\n",
    "    out = lemon_mcf.solve_mcf(n, src, dst, cost, cap, supply, tol=_TOL)\n",
    "    if out[\"status\"] != 1:\n",
    "        raise RuntimeError(f\"LEMON failed (status={out['status']})\")\n",
    "    flow = out[\"flow\"]\n",
    "    pot  = out[\"potential\"]\n",
    "    red  = out[\"reduced_cost\"]\n",
    "    at   = out.get(\"at_cap\", out.get(\"at_capacity\", None))\n",
    "    if at is None:\n",
    "        at = np.abs(flow - cap) <= _TOL\n",
    "    return float(out[\"total_cost\"]), flow, pot, red, at\n",
    "\n",
    "\n",
    "def _solve_gurobi(n, src, dst, cost, cap, supply):\n",
    "    try:\n",
    "        import gurobipy as gp\n",
    "        from gurobipy import GRB\n",
    "    except Exception as e:\n",
    "        raise ImportError(\"gurobipy is required (with a valid license).\") from e\n",
    "\n",
    "    m = int(src.size)\n",
    "    out_idx = [[] for _ in range(n)]\n",
    "    in_idx  = [[] for _ in range(n)]\n",
    "    for k in range(m):\n",
    "        out_idx[int(src[k])].append(k)\n",
    "        in_idx[int(dst[k])].append(k)\n",
    "\n",
    "    model = gp.Model()\n",
    "    model.Params.OutputFlag = 0\n",
    "    model.Params.LogToConsole = 0\n",
    "\n",
    "    x = model.addVars(m, lb=0.0, ub=cap.tolist(), obj=cost.tolist(), vtype=GRB.CONTINUOUS, name=\"x\")\n",
    "    model.ModelSense = GRB.MINIMIZE\n",
    "\n",
    "    bal = []\n",
    "    for i in range(n):\n",
    "        bal.append(model.addConstr(\n",
    "            gp.quicksum(x[k] for k in out_idx[i]) - gp.quicksum(x[k] for k in in_idx[i]) == float(supply[i])\n",
    "        ))\n",
    "\n",
    "    model.optimize()\n",
    "    if model.Status != GRB.OPTIMAL:\n",
    "        raise RuntimeError(f\"Gurobi failed (status={model.Status})\")\n",
    "\n",
    "    flow = np.fromiter((x[k].X  for k in range(m)), dtype=np.float64, count=m)\n",
    "    red  = np.fromiter((x[k].RC for k in range(m)), dtype=np.float64, count=m)\n",
    "    pot  = -np.fromiter((bal[i].Pi for i in range(n)), dtype=np.float64, count=n)  # match LEMON convention\n",
    "    at   = np.abs(flow - cap) <= _TOL\n",
    "    return float(model.ObjVal), flow, pot, red, at\n",
    "\n",
    "\n",
    "class _MCFValue(torch.autograd.Function):\n",
    "    @staticmethod\n",
    "    def forward(ctx, n_nodes, src, dst, cost, cap, supply, solver=\"lemon\"):\n",
    "        n = int(n_nodes)\n",
    "        src = src.long(); dst = dst.long()\n",
    "        m = int(src.numel())\n",
    "        if dst.numel() != m or cost.numel() != m or cap.numel() != m or supply.numel() != n:\n",
    "            raise ValueError(\"bad shapes\")\n",
    "\n",
    "        src_np = src.detach().cpu().contiguous().reshape(-1).numpy().astype(np.int64, copy=False)\n",
    "        dst_np = dst.detach().cpu().contiguous().reshape(-1).numpy().astype(np.int64, copy=False)\n",
    "        c_np   = cost.detach().cpu().contiguous().reshape(-1).numpy().astype(np.float64, copy=False)\n",
    "        u_np   = cap.detach().cpu().contiguous().reshape(-1).numpy().astype(np.float64, copy=False)\n",
    "        b_np   = supply.detach().cpu().contiguous().reshape(-1).numpy().astype(np.float64, copy=False)\n",
    "\n",
    "        if abs(float(b_np.sum())) > _TOL:\n",
    "            raise ValueError(f\"require sum(supply)=0 (got {float(b_np.sum()):.3e})\")\n",
    "        if m and ((src_np < 0).any() or (src_np >= n).any() or (dst_np < 0).any() or (dst_np >= n).any()):\n",
    "            raise ValueError(\"src/dst out of range\")\n",
    "        if m and (u_np < 0).any():\n",
    "            raise ValueError(\"cap must be nonnegative\")\n",
    "\n",
    "        s = str(solver).lower()\n",
    "        if s == \"lemon\":\n",
    "            obj, flow, pot, red, at = _solve_lemon(n, src_np, dst_np, c_np, u_np, b_np)\n",
    "        elif s == \"gurobi\":\n",
    "            obj, flow, pot, red, at = _solve_gurobi(n, src_np, dst_np, c_np, u_np, b_np)\n",
    "        else:\n",
    "            raise ValueError(f\"unknown solver: {solver}\")\n",
    "\n",
    "        ctx.flow, ctx.pot, ctx.red, ctx.at = flow, pot, red, at\n",
    "        return cost.new_tensor(obj)\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, g):\n",
    "        dev, dt = g.device, g.dtype\n",
    "        flow = torch.as_tensor(ctx.flow, device=dev, dtype=dt)\n",
    "        pot  = torch.as_tensor(ctx.pot,  device=dev, dtype=dt)\n",
    "        red  = torch.as_tensor(ctx.red,  device=dev, dtype=dt)\n",
    "        at   = torch.as_tensor(ctx.at,   device=dev, dtype=torch.bool)\n",
    "\n",
    "        grad_cost = flow\n",
    "        grad_cap  = torch.where(at, red, torch.zeros_like(red))\n",
    "        grad_sup  = pot.mean() - pot  # gauge-fixed (-dual) gradient\n",
    "\n",
    "        return None, None, None, grad_cost * g, grad_cap * g, grad_sup * g, None\n",
    "\n",
    "\n",
    "def min_cost_flow_value(n_nodes, src, dst, cost, cap, supply, solver=\"lemon\"):\n",
    "    return _MCFValue.apply(n_nodes, src, dst, cost, cap, supply, solver)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "49e1537e-0608-496b-8cfd-7fd4ba19a910",
   "metadata": {},
   "source": [
    "Network Class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90ebc084-4dde-4fb5-b606-f1190431ac95",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "@dataclass(frozen=True)\n",
    "class Network:\n",
    "    n: int\n",
    "    src_param: torch.Tensor\n",
    "    dst_param: torch.Tensor\n",
    "    ax_nodes: torch.Tensor\n",
    "    b_nodes: torch.Tensor\n",
    "    fix_node: int\n",
    "    src_fixed: torch.Tensor = field(default_factory=lambda: torch.empty(0, dtype=torch.long))\n",
    "    dst_fixed: torch.Tensor = field(default_factory=lambda: torch.empty(0, dtype=torch.long))\n",
    "    cap_fixed: torch.Tensor = field(default_factory=lambda: torch.empty(0, dtype=torch.float32))\n",
    "    cost_fixed: torch.Tensor = field(default_factory=lambda: torch.empty(0, dtype=torch.float32))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed2502cb-6f96-4e64-b75a-aa085e94a1dd",
   "metadata": {},
   "source": [
    "Multi-Layered Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39851bc6-86be-45cd-8d14-f121dda84f9c",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def build_layered_network(\n",
    "    layer_sizes: list[int],\n",
    "    n_ax: int | None = None,\n",
    "    n_b: int | None = None,\n",
    "    big_cost: float = 1e3,\n",
    "    big_cap: float = 1e9,\n",
    "    p: float = 1.0,          # keep-prob for each internal (layered) arc\n",
    "    seed: int | None = None, # for reproducible sparsification\n",
    ") -> Network:\n",
    "    if not (0.0 <= float(p) <= 1.0):\n",
    "        raise ValueError(\"p must be in [0,1].\")\n",
    "    gen = None\n",
    "    if seed is not None:\n",
    "        gen = torch.Generator(device=\"cpu\").manual_seed(int(seed))\n",
    "\n",
    "    layer_sizes = list(map(int, layer_sizes))\n",
    "    K = len(layer_sizes)\n",
    "    if K < 2: raise ValueError(\"need at least 2 layers\")\n",
    "    n = sum(layer_sizes)\n",
    "    if n <= 0: raise ValueError(\"sum(layer_sizes) must be > 0\")\n",
    "\n",
    "    layers, off = [], 0\n",
    "    for s in layer_sizes:\n",
    "        if s < 0: raise ValueError(\"layer_sizes must be nonnegative\")\n",
    "        layers.append(torch.arange(off, off + s, dtype=torch.long))\n",
    "        off += s\n",
    "    L1, LK = layers[0], layers[-1]\n",
    "    s1, sK = int(L1.numel()), int(LK.numel())\n",
    "    if s1 == 0 or sK == 0: raise ValueError(\"L1 and LK must be non-empty\")\n",
    "\n",
    "    # boundary split: default half/half inside each of L1 and LK\n",
    "    boundary = s1 + sK\n",
    "    if n_ax is None and n_b is None:\n",
    "        v1, vK = s1 // 2, sK // 2\n",
    "        c1, cK = s1 - v1, sK - vK\n",
    "    else:\n",
    "        n_ax = boundary - int(n_b) if n_ax is None else int(n_ax)\n",
    "        n_b  = boundary - int(n_ax) if n_b  is None else int(n_b)\n",
    "        if n_ax < 0 or n_b < 0 or n_ax + n_b > boundary: raise ValueError(\"bad n_ax/n_b\")\n",
    "        v1 = min(n_ax, s1 // 2); vK = n_ax - v1\n",
    "        c1 = min(n_b,  s1 - v1); cK = n_b  - c1\n",
    "        if vK > sK or vK + cK > sK: raise ValueError(\"n_ax/n_b too large for LK\")\n",
    "\n",
    "    ax_nodes = torch.cat([L1[:v1], LK[:vK]], 0)\n",
    "    b_nodes  = torch.cat([L1[v1:v1+c1], LK[vK:vK+cK]], 0)\n",
    "    fix_node = int(LK[-1].item())\n",
    "\n",
    "    def bip(U, V):\n",
    "        su, sv = int(U.numel()), int(V.numel())\n",
    "        return U.repeat_interleave(sv), V.repeat(su)\n",
    "\n",
    "    # learnable layered arcs Li -> L(i+1), sparsified with keep-prob p\n",
    "    srcL, dstL = [], []\n",
    "    for i in range(K - 1):\n",
    "        s, t = bip(layers[i], layers[i + 1])\n",
    "        if p < 1.0:\n",
    "            keep = (torch.rand(s.numel(), generator=gen) < p)\n",
    "            s, t = s[keep], t[keep]\n",
    "        srcL.append(s); dstL.append(t)\n",
    "\n",
    "    src_param = torch.cat(srcL, 0)\n",
    "    dst_param = torch.cat(dstL, 0)\n",
    "    if src_param.numel() == 0:\n",
    "        raise ValueError(\"no layered arcs after sparsification (increase p or layer sizes)\")\n",
    "\n",
    "    # fixed big arcs L1 <-> LK (still full)\n",
    "    s1a, t1a = bip(L1, LK)\n",
    "    s2a, t2a = bip(LK, L1)\n",
    "    src_fixed = torch.cat([s1a, s2a], 0)\n",
    "    dst_fixed = torch.cat([t1a, t2a], 0)\n",
    "    m = int(src_fixed.numel())\n",
    "    cap_fixed  = torch.full((m,), float(big_cap),  dtype=torch.float32)\n",
    "    cost_fixed = torch.full((m,), float(big_cost), dtype=torch.float32)\n",
    "\n",
    "    return Network(\n",
    "        n=n, src_param=src_param, dst_param=dst_param,\n",
    "        ax_nodes=ax_nodes, b_nodes=b_nodes, fix_node=fix_node,\n",
    "        src_fixed=src_fixed, dst_fixed=dst_fixed,\n",
    "        cap_fixed=cap_fixed, cost_fixed=cost_fixed\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58780994-c27b-4dd9-ba2e-865cb4b95c60",
   "metadata": {},
   "source": [
    "Complete Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19f2419e-da64-4739-b344-fc8eb115ad0a",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "\n",
    "def build_complete_doubled_network(\n",
    "    n: int,\n",
    "    n_ax: int | None = None,\n",
    "    n_b: int | None = None,\n",
    "    *,\n",
    "    fix_node: int | None = None,\n",
    "    big_cost: float = 1e3,\n",
    "    big_cap: float = 1e9,\n",
    ") -> Network:\n",
    "    n = int(n)\n",
    "    if n < 2: raise ValueError(\"n must be >= 2\")\n",
    "\n",
    "    k = int(math.isqrt(n))\n",
    "    n_ax = k if n_ax is None else int(n_ax)\n",
    "    n_b  = k if n_b  is None else int(n_b)\n",
    "\n",
    "    fix_node = (n - 1) if fix_node is None else int(fix_node)\n",
    "    if not (0 <= fix_node < n): raise ValueError(\"fix_node out of range\")\n",
    "    if n_ax < 0 or n_b < 0 or (n_ax + n_b) > (n - 1):\n",
    "        raise ValueError(\"need 0<=n_ax,n_b and n_ax+n_b <= n-1 (excluding fix_node)\")\n",
    "\n",
    "    # complete directed arcs i->j, i!=j\n",
    "    v = torch.arange(n, dtype=torch.long)\n",
    "    src = v.repeat_interleave(n)\n",
    "    dst = v.repeat(n)\n",
    "    mask = (src != dst)\n",
    "    src = src[mask]; dst = dst[mask]               # m = n*(n-1)\n",
    "\n",
    "    # learnable copy\n",
    "    src_param, dst_param = src, dst\n",
    "\n",
    "    # fixed big copy (parallel arcs)\n",
    "    src_fixed, dst_fixed = src.clone(), dst.clone()\n",
    "    m = int(src_fixed.numel())\n",
    "    cap_fixed  = torch.full((m,), float(big_cap),  dtype=torch.float32)\n",
    "    cost_fixed = torch.full((m,), float(big_cost), dtype=torch.float32)\n",
    "\n",
    "    # choose ax/b nodes deterministically in order (excluding fix_node)\n",
    "    nodes = torch.arange(n, dtype=torch.long)\n",
    "    avail = nodes[nodes != fix_node]\n",
    "    ax_nodes = avail[:n_ax].clone()\n",
    "    b_nodes  = avail[n_ax:n_ax + n_b].clone()\n",
    "\n",
    "    return Network(\n",
    "        n=n,\n",
    "        src_param=src_param, dst_param=dst_param,\n",
    "        ax_nodes=ax_nodes, b_nodes=b_nodes,\n",
    "        fix_node=fix_node,\n",
    "        src_fixed=src_fixed, dst_fixed=dst_fixed,\n",
    "        cap_fixed=cap_fixed, cost_fixed=cost_fixed,\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b8c13547-8c12-4413-a9d3-368d015cd778",
   "metadata": {},
   "source": [
    "#### The DFN Module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bd3916f-ee07-4ef5-8b4e-05ba9438ee93",
   "metadata": {},
   "outputs": [],
   "source": [
    "class DFN(nn.Module):\n",
    "    def __init__(self, net: Network, input_dim: int, solver: str = \"lemon\"):\n",
    "        super().__init__()\n",
    "        self.n = int(net.n)\n",
    "        self.fix_node = int(net.fix_node)\n",
    "        self.solver = str(solver)\n",
    "\n",
    "        # store network data inside the module (so .to(device) works, and checkpoints keep the graph)\n",
    "        self.register_buffer(\"src_param\", net.src_param.long())\n",
    "        self.register_buffer(\"dst_param\", net.dst_param.long())\n",
    "        self.register_buffer(\"src_fixed\", net.src_fixed.long())\n",
    "        self.register_buffer(\"dst_fixed\", net.dst_fixed.long())\n",
    "        self.register_buffer(\"cap_fixed\", net.cap_fixed.float())\n",
    "        self.register_buffer(\"cost_fixed\", net.cost_fixed.float())\n",
    "        self.register_buffer(\"ax_nodes\",  net.ax_nodes.long())\n",
    "        self.register_buffer(\"b_nodes\",   net.b_nodes.long())\n",
    "\n",
    "        m_param = int(self.src_param.numel())\n",
    "        n_ax    = int(self.ax_nodes.numel())\n",
    "        n_b     = int(self.b_nodes.numel())\n",
    "\n",
    "        self.cap_raw  = nn.Parameter(torch.randn(m_param) + 1)\n",
    "        self.cost_raw = nn.Parameter(torch.randn(m_param) + 1)\n",
    "        self.A_raw    = nn.Parameter(torch.randn(n_ax, input_dim) + 1)\n",
    "        self.b_raw    = nn.Parameter(torch.randn(n_b) + 1)\n",
    "\n",
    "    def forward(self, x: torch.Tensor, alpha: float = 1.0, beta: float = 0.0) -> torch.Tensor:\n",
    "        capP  = round_param(self.cap_raw,  nonneg=True)\n",
    "        costP = round_param(self.cost_raw, nonneg=True)\n",
    "        A     = round_param(self.A_raw,    nonneg=False)\n",
    "        b0    = round_param(self.b_raw,    nonneg=False)\n",
    "\n",
    "        def solve_one(x1: torch.Tensor):\n",
    "            src  = torch.cat([self.src_param, self.src_fixed], 0)\n",
    "            dst  = torch.cat([self.dst_param, self.dst_fixed], 0)\n",
    "            cap  = torch.cat([capP,  self.cap_fixed.to(x1.device, x1.dtype)], 0)\n",
    "            cost = torch.cat([costP, self.cost_fixed.to(x1.device, x1.dtype)], 0)\n",
    "\n",
    "            supply = torch.zeros(self.n, device=x1.device, dtype=torch.float64)\n",
    "            supply[self.ax_nodes] = (A.double() @ x1.double())\n",
    "            supply[self.b_nodes] = b0.double()\n",
    "            supply[self.fix_node] -= supply.sum()\n",
    "\n",
    "            obj = min_cost_flow_value(self.n, src, dst, cost, cap, supply, solver=self.solver)\n",
    "            return alpha * obj + beta\n",
    "\n",
    "        return solve_one(x) if x.dim() == 1 else torch.stack([solve_one(xi) for xi in x], 0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb0c0c1d-e1d6-4259-ad3b-55cd5b258432",
   "metadata": {},
   "source": [
    "### Gradient Boosted Trees"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72c356e0-d483-442b-a32b-49f0d7eab840",
   "metadata": {},
   "source": [
    "### MLP"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df43dee6-8978-44d8-a576-cceea5fa46a1",
   "metadata": {},
   "source": [
    "### ICNN"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e74105b4-7d14-4b4d-9261-160fc4b44827",
   "metadata": {},
   "source": [
    "### Max-Affine"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8a3d912-fa0a-4acc-a12f-10a089f213de",
   "metadata": {},
   "source": [
    "# Training on Data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "609e32f6-7edd-40dd-9764-e06a592d0e73",
   "metadata": {},
   "source": [
    "# Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88a70431-932d-40db-8ad1-d21fd24d60ec",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "4f8af971-80e0-4757-8971-f6020f3aa07b",
   "metadata": {},
   "source": [
    "# A TEST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f678282e-6e0f-4536-962b-e5d3836f562d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "@torch.no_grad()\n",
    "def r2_score(y_pred: torch.Tensor, y_true: torch.Tensor) -> float:\n",
    "    y_true = y_true.view(-1); y_pred = y_pred.view(-1)\n",
    "    sse = ((y_true - y_pred)**2).sum()\n",
    "    sst = ((y_true - y_true.mean())**2).sum()\n",
    "    return float((1.0 - sse/(sst + 1e-12)).item())\n",
    "\n",
    "@torch.no_grad()\n",
    "def mse_score(y_pred: torch.Tensor, y_true: torch.Tensor) -> float:\n",
    "    return float(F.mse_loss(y_pred, y_true).item())\n",
    "\n",
    "def make_int_quadratic_data(N=1000, d=10, x_min=-50, x_max=50, noise=0.05, seed=0, device=\"cpu\"):\n",
    "    g = torch.Generator(device=\"cpu\").manual_seed(int(seed))\n",
    "    X = torch.randint(int(x_min), int(x_max), (int(N), int(d)), generator=g).to(device=device, dtype=torch.float32)\n",
    "\n",
    "    R = torch.randn(d, d, generator=g).to(device)\n",
    "    Q = R.T @ R + 0.2 * torch.eye(d, device=device)\n",
    "    lin = torch.randn(d, generator=g).to(device)\n",
    "\n",
    "    y = (X @ Q * X).sum(dim=1) + X @ lin\n",
    "    if noise:\n",
    "        y = y + float(noise) * torch.randn_like(y)\n",
    "\n",
    "    y = (y - y.mean()) / (y.std() + 1e-8)\n",
    "    return X, y\n",
    "\n",
    "\n",
    "def compare_dfn_mlp_live(\n",
    "    *,\n",
    "    # data\n",
    "    N=1000, d=10, x_min=-50, x_max=50, noise=0.05, seed=1,\n",
    "    train_frac=0.7, val_frac=0.15, test_frac=0.15,\n",
    "    normalize_X=False,\n",
    "    device=None,\n",
    "\n",
    "    # choose graph\n",
    "    graph=\"layered\",\n",
    "    layer_sizes=(32,32,32),\n",
    "    n_nodes=None,\n",
    "    n_ax=None, n_b=None,\n",
    "    big_cost=1e3, big_cap=1e9,\n",
    "    p=1.0, p_seed=None,\n",
    "\n",
    "    # DFN output affine + solver\n",
    "    solver=\"lemon\", alpha=1e-5, beta=-5.0,\n",
    "\n",
    "    # training\n",
    "    steps=600, B=16, eval_every=10,\n",
    "    lr_cost=1e-1, lr_cap=1e-1, lr_A=1e-1, lr_b=1e-1,\n",
    "\n",
    "    # MLP lr controls\n",
    "    lr_mlp=1e-3,\n",
    "    lr_mlp_W=None,\n",
    "    lr_mlp_bias=None,\n",
    "    mlp_hidden=(256,256),\n",
    "\n",
    "    # init (mean,std)\n",
    "    init_cost=(1.0,0.3), init_cap=(3.0,0.3),\n",
    "    init_A=(1.0,0.3),    init_b=(1.0,0.3),\n",
    "\n",
    "    # plots\n",
    "    scatter_n=600,\n",
    "    track_time=True,\n",
    "    track_grads=True,  # NOTE: now tracks PARAM UPDATES (mean |Δparam|), not gradients\n",
    "):\n",
    "    import sys, time, numpy as np\n",
    "    import torch\n",
    "    import torch.nn as nn\n",
    "    import torch.nn.functional as F\n",
    "    import matplotlib.pyplot as plt\n",
    "\n",
    "    # ---------------- helpers: PARAM UPDATE MAGNITUDES ----------------\n",
    "    @torch.no_grad()\n",
    "    def mean_abs_delta(p_after: torch.Tensor, p_before: torch.Tensor) -> float:\n",
    "        return float((p_after - p_before).abs().mean().item())\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def weighted_mean_abs_delta_list(params, params_before) -> float:\n",
    "        \"\"\"Element-weighted mean |Δ| across a list of tensors (big layers count more).\"\"\"\n",
    "        if not params:\n",
    "            return 0.0\n",
    "        tot = torch.zeros((), device=params[0].device)\n",
    "        n = 0\n",
    "        for p, p0 in zip(params, params_before):\n",
    "            tot += (p - p0).abs().sum()\n",
    "            n += p.numel()\n",
    "        return float((tot / max(n, 1)).item())\n",
    "\n",
    "    # ---- device safety ----\n",
    "    if device is None:\n",
    "        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "    else:\n",
    "        device = \"cuda\" if (str(device) == \"cuda\" and torch.cuda.is_available()) else \"cpu\"\n",
    "\n",
    "    if abs((train_frac + val_frac + test_frac) - 1.0) > 1e-9:\n",
    "        raise ValueError(\"train_frac + val_frac + test_frac must sum to 1.0\")\n",
    "\n",
    "    torch.manual_seed(seed)\n",
    "\n",
    "    # ---- data ----\n",
    "    X, y = make_int_quadratic_data(N=N, d=d, x_min=x_min, x_max=x_max, noise=noise, seed=seed, device=device)\n",
    "    perm = torch.randperm(N, device=device)\n",
    "    ntr = int(train_frac * N)\n",
    "    nva = int(val_frac   * N)\n",
    "    tr, va, te = perm[:ntr], perm[ntr:ntr+nva], perm[ntr+nva:]\n",
    "\n",
    "    Xtr, ytr = X[tr], y[tr]\n",
    "    Xva, yva = X[va], y[va]\n",
    "    Xte, yte = X[te], y[te]\n",
    "\n",
    "    if normalize_X:\n",
    "        mu = Xtr.mean(0, keepdim=True)\n",
    "        sd = Xtr.std(0, keepdim=True) + 1e-8\n",
    "        Xtr = (Xtr - mu)/sd\n",
    "        Xva = (Xva - mu)/sd\n",
    "        Xte = (Xte - mu)/sd\n",
    "\n",
    "    # ---- graph ----\n",
    "    g = str(graph).lower()\n",
    "    if g == \"layered\":\n",
    "        net = build_layered_network(\n",
    "            layer_sizes=list(layer_sizes),\n",
    "            n_ax=n_ax, n_b=n_b,\n",
    "            big_cost=big_cost, big_cap=big_cap,\n",
    "            p=p, seed=p_seed\n",
    "        )\n",
    "    elif g == \"complete\":\n",
    "        n_nodes_ = int(n_nodes) if n_nodes is not None else int(sum(layer_sizes))\n",
    "        net = build_complete_doubled_network(\n",
    "            n=n_nodes_,\n",
    "            n_ax=n_ax, n_b=n_b,\n",
    "            big_cost=big_cost, big_cap=big_cap\n",
    "        )\n",
    "    else:\n",
    "        raise ValueError('graph must be \"layered\" or \"complete\"')\n",
    "\n",
    "    # ---- models ----\n",
    "    dfn = DFN(net, input_dim=d, solver=solver).to(device)\n",
    "\n",
    "    class _MLP(nn.Module):\n",
    "        def __init__(self, din, hidden):\n",
    "            super().__init__()\n",
    "            L, dd = [], din\n",
    "            for h in hidden:\n",
    "                L += [nn.Linear(dd, h), nn.ReLU()]\n",
    "                dd = h\n",
    "            L += [nn.Linear(dd, 1)]\n",
    "            self.net = nn.Sequential(*L)\n",
    "        def forward(self, x): return self.net(x).squeeze(-1)\n",
    "\n",
    "    mlp = _MLP(d, tuple(map(int, mlp_hidden))).to(device)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        dfn.cost_raw.normal_(*init_cost); dfn.cap_raw.normal_(*init_cap)\n",
    "        dfn.A_raw.normal_(*init_A);       dfn.b_raw.normal_(*init_b)\n",
    "\n",
    "    opt_dfn = torch.optim.Adam([\n",
    "        {\"params\":[dfn.cost_raw], \"lr\": lr_cost},\n",
    "        {\"params\":[dfn.cap_raw],  \"lr\": lr_cap},\n",
    "        {\"params\":[dfn.A_raw],    \"lr\": lr_A},\n",
    "        {\"params\":[dfn.b_raw],    \"lr\": lr_b},\n",
    "    ])\n",
    "\n",
    "    # ---- MLP optimizer with separate LR for W vs bias ----\n",
    "    lrW = float(lr_mlp if lr_mlp_W is None else lr_mlp_W)\n",
    "    lrB = float(lr_mlp if lr_mlp_bias is None else lr_mlp_bias)\n",
    "    W_params, b_params = [], []\n",
    "    for m in mlp.modules():\n",
    "        if isinstance(m, nn.Linear):\n",
    "            W_params.append(m.weight)\n",
    "            if m.bias is not None:\n",
    "                b_params.append(m.bias)\n",
    "\n",
    "    opt_mlp = torch.optim.Adam([\n",
    "        {\"params\": W_params, \"lr\": lrW},\n",
    "        {\"params\": b_params, \"lr\": lrB},\n",
    "    ])\n",
    "    print(f\"[MLP lrs] lrW={lrW:g}  lrB={lrB:g}  (fallback lr_mlp={lr_mlp:g})\")\n",
    "\n",
    "    def sync():\n",
    "        \"\"\"Wait for queued CUDA work so the perf_counter step timings are accurate.\n",
    "\n",
    "        Works whether `device` is \"cuda\", \"cuda:N\" or a torch.device; a no-op for\n",
    "        non-CUDA devices. (A plain `device == \"cuda\"` test misses \"cuda:0\" and\n",
    "        torch.device objects, which silently skipped the synchronization.)\n",
    "        \"\"\"\n",
    "        if str(device).startswith(\"cuda\"):\n",
    "            # sync the device actually in use, not just the current default one\n",
    "            torch.cuda.synchronize(device)\n",
    "\n",
    "    # ---- live plots (MSE + gap + time) ----\n",
    "    fig, axes = plt.subplots(3 if track_time else 2, 1, figsize=(7, 9 if track_time else 6), sharex=True)\n",
    "    axL, axG = axes[0], axes[1]\n",
    "    axT = axes[2] if track_time else None\n",
    "\n",
    "    (ldt,) = axL.plot([], [], label=\"DFN train MSE\")\n",
    "    (ldv,) = axL.plot([], [], label=\"DFN val MSE\")\n",
    "    (lmt,) = axL.plot([], [], label=\"MLP train MSE\")\n",
    "    (lmv,) = axL.plot([], [], label=\"MLP val MSE\")\n",
    "    axL.set_title(f\"MSE (train/val) | graph={g}\")\n",
    "    axL.set_ylabel(\"MSE\")\n",
    "    axL.legend()\n",
    "\n",
    "    (gt,) = axG.plot([], [], label=\"gap train (DFN-MLP)\")\n",
    "    (gv,) = axG.plot([], [], label=\"gap val (DFN-MLP)\")\n",
    "    axG.axhline(0.0, lw=1)\n",
    "    axG.set_yscale(\"symlog\", linthresh=1e-6, linscale=1.0)\n",
    "    axG.set_ylabel(\"gap\")\n",
    "    axG.legend()\n",
    "    axG.set_xlabel(\"step\" if not track_time else \"\")\n",
    "\n",
    "    if track_time:\n",
    "        (td,) = axT.plot([], [], label=\"DFN train time (cum s)\")\n",
    "        (tm,) = axT.plot([], [], label=\"MLP train time (cum s)\")\n",
    "        axT.set_xlabel(\"step\")\n",
    "        axT.set_ylabel(\"seconds\")\n",
    "        axT.legend()\n",
    "\n",
    "    disp = None\n",
    "    if \"ipykernel\" in sys.modules:\n",
    "        from IPython.display import display\n",
    "        disp = display(fig, display_id=True)\n",
    "    else:\n",
    "        plt.ion(); plt.show(block=False)\n",
    "\n",
    "    # ---- live scatter (VAL only) ----\n",
    "    sc_n = min(int(scatter_n), Xva.shape[0])\n",
    "    sc_idx = torch.randperm(Xva.shape[0], device=device)[:sc_n]\n",
    "    Xsc, ysc = Xva[sc_idx], yva[sc_idx]\n",
    "    y_true_cpu = ysc.detach().cpu().numpy()\n",
    "    mn, mx = float(ysc.min().item()), float(ysc.max().item())\n",
    "\n",
    "    fig2, (axD, axM) = plt.subplots(1, 2, figsize=(10,4), sharex=True, sharey=True)\n",
    "    sd_sc = axD.scatter([], [], s=12); sm_sc = axM.scatter([], [], s=12)\n",
    "    for ax in (axD, axM):\n",
    "        ax.plot([mn, mx], [mn, mx], lw=1)\n",
    "        ax.set_xlim(mn, mx); ax.set_ylim(mn, mx)\n",
    "        ax.set_xlabel(\"y_true\")\n",
    "    axD.set_ylabel(\"y_pred\")\n",
    "\n",
    "    disp2 = None\n",
    "    if \"ipykernel\" in sys.modules:\n",
    "        from IPython.display import display\n",
    "        disp2 = display(fig2, display_id=True)\n",
    "    else:\n",
    "        plt.ion(); plt.show(block=False)\n",
    "\n",
    "    # ---- PARAM UPDATE live plots: DFN and MLP SEPARATE ----\n",
    "    if track_grads:\n",
    "        # DFN updates figure\n",
    "        figUd, axUd = plt.subplots(1, 1, figsize=(8, 4), sharex=True)\n",
    "        (u_cost,) = axUd.plot([], [], label=\"mean|Δcost|\")\n",
    "        (u_cap,)  = axUd.plot([], [], label=\"mean|Δcap|\")\n",
    "        (u_A,)    = axUd.plot([], [], label=\"mean|ΔA|\")\n",
    "        (u_b,)    = axUd.plot([], [], label=\"mean|Δb|\")\n",
    "        axUd.set_yscale(\"symlog\", linthresh=1e-12)\n",
    "        axUd.set_title(\"DFN parameter update size (mean |Δparam|)\")\n",
    "        axUd.set_xlabel(\"step\")\n",
    "        axUd.set_ylabel(\"mean |Δ|\")\n",
    "        axUd.legend(loc=\"upper right\")\n",
    "\n",
    "        # MLP updates figure\n",
    "        figUm, axUm = plt.subplots(1, 1, figsize=(8, 4), sharex=True)\n",
    "        (u_W,)    = axUm.plot([], [], label=\"mean|ΔW|\")\n",
    "        (u_bias,) = axUm.plot([], [], label=\"mean|Δbias|\")\n",
    "        axUm.set_yscale(\"symlog\", linthresh=1e-12)\n",
    "        axUm.set_title(\"MLP parameter update size (mean |Δparam|)\")\n",
    "        axUm.set_xlabel(\"step\")\n",
    "        axUm.set_ylabel(\"mean |Δ|\")\n",
    "        axUm.legend(loc=\"upper right\")\n",
    "\n",
    "        dispUd = dispUm = None\n",
    "        if \"ipykernel\" in sys.modules:\n",
    "            from IPython.display import display\n",
    "            dispUd = display(figUd, display_id=True)\n",
    "            dispUm = display(figUm, display_id=True)\n",
    "        else:\n",
    "            plt.ion()\n",
    "            plt.show(block=False)\n",
    "    else:\n",
    "        figUd = axUd = dispUd = None\n",
    "        figUm = axUm = dispUm = None\n",
    "        u_cost = u_cap = u_A = u_b = None\n",
    "        u_W = u_bias = None\n",
    "\n",
    "    # ---- histories + best checkpoints ----\n",
    "    S = []\n",
    "    dtr, dva, mtr, mva = [], [], [], []\n",
    "    Td_hist, Tm_hist = [], []\n",
    "\n",
    "    # update histories (mean |Δ|)\n",
    "    U_cost, U_cap, U_A, U_b, U_W, U_bias = [], [], [], [], [], []\n",
    "\n",
    "    best_d_val = float(\"inf\"); best_m_val = float(\"inf\")\n",
    "    best_d_step = -1;          best_m_step = -1\n",
    "    best_d_state = None;       best_m_state = None\n",
    "\n",
    "    t_dfn = 0.0\n",
    "    t_mlp = 0.0\n",
    "\n",
    "    for t in range(int(steps)):\n",
    "        idx = torch.randint(0, Xtr.shape[0], (int(B),), device=device)\n",
    "        Xb, yb = Xtr[idx], ytr[idx]\n",
    "\n",
    "        track_this = bool(track_grads and (t % int(eval_every) == 0))\n",
    "\n",
    "        # ---- DFN step ----\n",
    "        t0 = time.perf_counter()\n",
    "        dfn.train()\n",
    "        yhat_d = dfn(Xb, alpha=alpha, beta=beta)\n",
    "        loss_d = F.mse_loss(yhat_d, yb)\n",
    "        opt_dfn.zero_grad(set_to_none=True)\n",
    "        loss_d.backward()\n",
    "\n",
    "        if track_this:\n",
    "            d_cost0 = dfn.cost_raw.detach().clone()\n",
    "            d_cap0  = dfn.cap_raw.detach().clone()\n",
    "            d_A0    = dfn.A_raw.detach().clone()\n",
    "            d_b0    = dfn.b_raw.detach().clone()\n",
    "\n",
    "        opt_dfn.step()\n",
    "\n",
    "        if track_this:\n",
    "            d_u_cost = mean_abs_delta(dfn.cost_raw, d_cost0)\n",
    "            d_u_cap  = mean_abs_delta(dfn.cap_raw,  d_cap0)\n",
    "            d_u_A    = mean_abs_delta(dfn.A_raw,    d_A0)\n",
    "            d_u_b    = mean_abs_delta(dfn.b_raw,    d_b0)\n",
    "\n",
    "        sync()\n",
    "        t_dfn += time.perf_counter() - t0\n",
    "\n",
    "        # ---- MLP step ----\n",
    "        t0 = time.perf_counter()\n",
    "        mlp.train()\n",
    "        yhat_m = mlp(Xb)\n",
    "        loss_m = F.mse_loss(yhat_m, yb)\n",
    "        opt_mlp.zero_grad(set_to_none=True)\n",
    "        loss_m.backward()\n",
    "\n",
    "        if track_this:\n",
    "            W0 = [p.detach().clone() for p in W_params]\n",
    "            b0 = [p.detach().clone() for p in b_params]\n",
    "\n",
    "        opt_mlp.step()\n",
    "\n",
    "        if track_this:\n",
    "            m_u_W    = weighted_mean_abs_delta_list(W_params, W0)\n",
    "            m_u_bias = weighted_mean_abs_delta_list(b_params, b0)\n",
    "\n",
    "        sync()\n",
    "        t_mlp += time.perf_counter() - t0\n",
    "\n",
    "        if t % int(eval_every) == 0:\n",
    "            dfn.eval(); mlp.eval()\n",
    "            with torch.no_grad():\n",
    "                dtr_mse = float(loss_d.item())\n",
    "                mtr_mse = float(loss_m.item())\n",
    "\n",
    "                yva_d = dfn(Xva, alpha=alpha, beta=beta)\n",
    "                yva_m = mlp(Xva)\n",
    "                dva_mse = mse_score(yva_d, yva)\n",
    "                mva_mse = mse_score(yva_m, yva)\n",
    "                dva_r2  = r2_score(yva_d, yva)\n",
    "                mva_r2  = r2_score(yva_m, yva)\n",
    "\n",
    "            if dva_mse < best_d_val:\n",
    "                best_d_val = dva_mse\n",
    "                best_d_step = t\n",
    "                best_d_state = {k: v.detach().cpu().clone() for k, v in dfn.state_dict().items()}\n",
    "            if mva_mse < best_m_val:\n",
    "                best_m_val = mva_mse\n",
    "                best_m_step = t\n",
    "                best_m_state = {k: v.detach().cpu().clone() for k, v in mlp.state_dict().items()}\n",
    "\n",
    "            S.append(t)\n",
    "            dtr.append(dtr_mse); dva.append(dva_mse)\n",
    "            mtr.append(mtr_mse); mva.append(mva_mse)\n",
    "            if track_time:\n",
    "                Td_hist.append(t_dfn); Tm_hist.append(t_mlp)\n",
    "\n",
    "            if track_grads:\n",
    "                # append update magnitudes from this tracked step\n",
    "                U_cost.append(d_u_cost); U_cap.append(d_u_cap); U_A.append(d_u_A); U_b.append(d_u_b)\n",
    "                U_W.append(m_u_W);       U_bias.append(m_u_bias)\n",
    "\n",
    "            # update curves\n",
    "            ldt.set_data(S, dtr); ldv.set_data(S, dva)\n",
    "            lmt.set_data(S, mtr); lmv.set_data(S, mva)\n",
    "            axL.relim(); axL.autoscale_view()\n",
    "\n",
    "            gt.set_data(S, [a-b for a,b in zip(dtr, mtr)])\n",
    "            gv.set_data(S, [a-b for a,b in zip(dva, mva)])\n",
    "            axG.relim(); axG.autoscale_view()\n",
    "\n",
    "            if track_time:\n",
    "                td.set_data(S, Td_hist); tm.set_data(S, Tm_hist)\n",
    "                axT.relim(); axT.autoscale_view()\n",
    "\n",
    "            # update scatter (VAL subset)\n",
    "            with torch.no_grad():\n",
    "                ypd = dfn(Xsc, alpha=alpha, beta=beta).detach().cpu().numpy()\n",
    "                ypm = mlp(Xsc).detach().cpu().numpy()\n",
    "            sd_sc.set_offsets(np.c_[y_true_cpu, ypd])\n",
    "            sm_sc.set_offsets(np.c_[y_true_cpu, ypm])\n",
    "            axD.set_title(f\"DFN VAL: MSE={dva_mse:.3g}  R²={dva_r2:.3f}\")\n",
    "            axM.set_title(f\"MLP VAL: MSE={mva_mse:.3g}  R²={mva_r2:.3f}\")\n",
    "\n",
    "            # update DFN update plot\n",
    "            if track_grads:\n",
    "                u_cost.set_data(S, U_cost); u_cap.set_data(S, U_cap)\n",
    "                u_A.set_data(S, U_A);       u_b.set_data(S, U_b)\n",
    "                axUd.relim(); axUd.autoscale_view()\n",
    "\n",
    "                # update MLP update plot\n",
    "                u_W.set_data(S, U_W); u_bias.set_data(S, U_bias)\n",
    "                axUm.relim(); axUm.autoscale_view()\n",
    "\n",
    "            fig.canvas.draw(); fig2.canvas.draw()\n",
    "            if disp is not None: disp.update(fig)\n",
    "            else: fig.canvas.flush_events(); plt.pause(0.001)\n",
    "            if disp2 is not None: disp2.update(fig2)\n",
    "            else: fig2.canvas.flush_events(); plt.pause(0.001)\n",
    "\n",
    "            if track_grads:\n",
    "                figUd.canvas.draw()\n",
    "                if dispUd is not None: dispUd.update(figUd)\n",
    "                else: figUd.canvas.flush_events(); plt.pause(0.001)\n",
    "\n",
    "                figUm.canvas.draw()\n",
    "                if dispUm is not None: dispUm.update(figUm)\n",
    "                else: figUm.canvas.flush_events(); plt.pause(0.001)\n",
    "\n",
    "    if best_d_state is not None:\n",
    "        dfn.load_state_dict(best_d_state)\n",
    "    if best_m_state is not None:\n",
    "        mlp.load_state_dict(best_m_state)\n",
    "\n",
    "    dfn.eval(); mlp.eval()\n",
    "    with torch.no_grad():\n",
    "        yte_d = dfn(Xte, alpha=alpha, beta=beta)\n",
    "        yte_m = mlp(Xte)\n",
    "        d_test_mse = mse_score(yte_d, yte)\n",
    "        m_test_mse = mse_score(yte_m, yte)\n",
    "        d_test_r2  = r2_score(yte_d, yte)\n",
    "        m_test_r2  = r2_score(yte_m, yte)\n",
    "\n",
    "    print(f\"[best] DFN step={best_d_step}  best DFN val MSE={best_d_val:.6g}\")\n",
    "    print(f\"[best] MLP step={best_m_step}  best MLP val MSE={best_m_val:.6g}\")\n",
    "    print(f\"[test] DFN: MSE={d_test_mse:.6g}  R²={d_test_r2:.4f}\")\n",
    "    print(f\"[test] MLP: MSE={m_test_mse:.6g}  R²={m_test_r2:.4f}\")\n",
    "    if track_time:\n",
    "        print(f\"[time] train-only total: DFN={t_dfn:.2f}s  MLP={t_mlp:.2f}s\")\n",
    "        print(f\"[time] per step:         DFN={t_dfn/steps:.4f}s  MLP={t_mlp/steps:.4f}s\")\n",
    "\n",
    "    updates = {\n",
    "        \"dfn_cost\": U_cost, \"dfn_cap\": U_cap, \"dfn_A\": U_A, \"dfn_b\": U_b,\n",
    "        \"mlp_W\": U_W, \"mlp_bias\": U_bias,\n",
    "    }\n",
    "\n",
    "    return {\n",
    "        \"dfn\": dfn,\n",
    "        \"mlp\": mlp,\n",
    "        \"best\": {\"dfn_step\": best_d_step, \"mlp_step\": best_m_step, \"dfn_val_mse\": best_d_val, \"mlp_val_mse\": best_m_val},\n",
    "        \"test\": {\"dfn_mse\": d_test_mse, \"dfn_r2\": d_test_r2, \"mlp_mse\": m_test_mse, \"mlp_r2\": m_test_r2},\n",
    "        \"hist\": {\"steps\": S, \"dfn_val_mse\": dva, \"mlp_val_mse\": mva, \"dfn_tr_mse\": dtr, \"mlp_tr_mse\": mtr},\n",
    "        \"updates\": updates,\n",
    "        \"grads\": updates,  # alias so old code that expects \"grads\" doesn't break (now contains updates)\n",
    "    }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35739a3d-a43e-4512-9a25-c35aeaaa1097",
   "metadata": {},
   "outputs": [],
   "source": [
    "# DFN vs. MLP comparison run. `out` is a dict holding the trained \"dfn\" and \"mlp\"\n",
    "# models plus \"best\", \"test\", \"hist\" and \"updates\" summaries.\n",
    "# NOTE: a full validation pass and figure redraw happen every `eval_every` steps,\n",
    "# so steps=50000 with eval_every=10 is a long run; lower `steps` for quick checks.\n",
    "out = compare_dfn_mlp_live(\n",
    "    # data settings (d is the input dimension of both models) and split fractions\n",
    "    N=1000,\n",
    "    d=10,\n",
    "    x_min=-50,\n",
    "    x_max=50,\n",
    "    noise=0.05,\n",
    "    seed=1,\n",
    "    train_frac=0.7,\n",
    "    val_frac=0.15,\n",
    "    test_frac=0.15,\n",
    "    # graph type (\"layered\" or \"complete\") and layer widths\n",
    "    graph=\"layered\",\n",
    "    layer_sizes=[32, 64, 64, 32],\n",
    "    # remaining DFN construction settings -- semantics defined inside compare_dfn_mlp_live\n",
    "    p=1,\n",
    "    big_cost=1e3,\n",
    "    big_cap=1e9,\n",
    "    # flow solver backend; alpha/beta are forwarded to every dfn(...) call\n",
    "    solver=\"lemon\",\n",
    "    alpha=1e-7,\n",
    "    beta=-5.0,\n",
    "    # training loop: total steps, minibatch size, evaluation period, per-parameter DFN learning rates\n",
    "    steps=50000,\n",
    "    B=8,\n",
    "    eval_every=10,\n",
    "    lr_cost=1e-2,\n",
    "    lr_cap=1e-3,\n",
    "    lr_A=1e-2,\n",
    "    lr_b=1e-1,\n",
    "    # MLP baseline: separate learning rates for weights and biases, hidden widths\n",
    "    lr_mlp_W=1e-3,\n",
    "    lr_mlp_bias=1e-2,\n",
    "    mlp_hidden=(256, 256),\n",
    "    # (mean, std) passed to normal_() when initialising the raw DFN parameters\n",
    "    init_cost=(1.0, 0.3),\n",
    "    init_cap=(3.0, 0.3),\n",
    "    init_A=(1.0, 0.3),\n",
    "    init_b=(1.0, 0.3),\n",
    "    # number of validation points drawn in the live scatter plots\n",
    "    scatter_n=600,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65e58d01-c214-42bb-9115-0fac21645d85",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9072ad1e-0a4b-4141-afdb-945c53fbccb0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4ea6931-b097-40a3-9215-4c796d541011",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32ec96ce-1e0a-4fbc-afd9-a03bb6a38f40",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
