{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "96398989-280e-41b3-bc11-99b4b51e9439",
   "metadata": {},
   "source": [
    "# General Helpers"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "767d34c9-66bd-4986-a324-4daeec80aee1",
   "metadata": {},
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bef302bd-4ca0-4269-9ba1-ac6d05ef5819",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "from __future__ import annotations\n",
    "import sys, json, time, math, random, tempfile\n",
    "from pathlib import Path\n",
    "from typing import Optional, Tuple, Dict, List\n",
    "import numpy as np\n",
    "import cppimport\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from IPython.display import Image, display\n",
    "from dataclasses import dataclass, field\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3cd5a10c-7768-4318-8244-beaf6c4c84d1",
   "metadata": {},
   "source": [
    "### Python Implementation of LEMON \n",
    "The following defines an importable Python module `lemon_mcf` which runs LEMON solvers (written in C++)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b4658685-aab5-4c45-aa6d-36e862bbdc9a",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "repo = Path().resolve().parent\n",
    "sys.path.insert(0, str(repo))\n",
    "lemon_mcf = cppimport.imp(\"lemon_mcf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "800144cd-efc4-47e8-95d6-e75cad74419d",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "source": [
    "The module has the function `lemon_mcf.solve_mcf` which solves the minimum cost flow problem. The function can be described as follows:\n",
    "\n",
    "**Signature**\n",
    "    out = lemon_mcf.solve_mcf(n, src, dst, cost, cap, supply, tol=1e-9)\n",
    "    \n",
    "**Inputs (NumPy arrays)**\n",
    "- `n` : `int` — number of nodes (indexed `0..n-1`)\n",
    "- `src`, `dst` : `np.ndarray` shape `(m,)`, dtype `int64` — directed edges `src[i] -> dst[i]`\n",
    "- `cost`, `cap` : `np.ndarray` shape `(m,)`, dtype `float64` — per-edge cost and capacities\n",
    "- `supply` : `np.ndarray` shape `(n,)`, dtype `float64` — node supplies/demands (`>0` supply, `<0` demand)\n",
    "- `tol` : `float` — tolerance for capacity-status flags\n",
    "\n",
    "**Output (`dict`)**\n",
    "- `out[\"status\"]` : `int`\n",
    "- `out[\"flow\"]` : `np.ndarray` shape `(m,)`, dtype `float64`\n",
    "- `out[\"potential\"]` : `np.ndarray` shape `(n,)`, dtype `float64` (node potentials; defined up to an additive constant)\n",
    "- `out[\"reduced_cost\"]` : `np.ndarray` shape `(m,)`, dtype `float64`, computed as `cost[i] + potential[src[i]] - potential[dst[i]]`\n",
    "- (optional) capacity-status flag: boolean arrays indicating whether each arc is at its capacity\n",
    "- `out[\"total_cost\"]` : `float` (objective value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1ff805be-adf0-44c9-a12c-d4d2fb077251",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'status': 1, 'flow': array([1., 2., 0.]), 'potential': array([-2.,  0., -1.]), 'reduced_cost': array([0., 0., 4.]), 'at_capacity': array([False,  True, False]), 'total_cost': 4.0}\n"
     ]
    }
   ],
   "source": [
    "# ----- Example of Usage ------ #\n",
    "\n",
    "n = 3\n",
    "src    = np.array([0, 0, 1], dtype=np.int64)\n",
    "dst    = np.array([1, 2, 2], dtype=np.int64)\n",
    "cost   = np.array([2.0, 1.0, 3.0], dtype=np.float64)\n",
    "cap    = np.array([5.0, 2.0, 4.0], dtype=np.float64)\n",
    "supply = np.array([3.0, -1.0, -2.0], dtype=np.float64)\n",
    "\n",
    "out_min_cost_flow = lemon_mcf.solve_mcf(n, src, dst, cost, cap, supply)\n",
    "print(out_min_cost_flow)"
   ]
  },
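  {
   "cell_type": "markdown",
   "id": "sketch-mcf-optimality-md",
   "metadata": {},
   "source": [
    "The reduced-cost formula and the printed output above can be cross-checked by hand. The cell below is a minimal sketch, not part of `lemon_mcf`: the helper `check_mcf_example` is a hypothetical name, the arrays are copied from the example, and the complementary-slackness test assumes the standard optimality conditions for the sign convention `cost[i] + potential[src[i]] - potential[dst[i]]`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-mcf-optimality-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ----- Sketch: checking the example output by hand ------ #\n",
    "import numpy as np\n",
    "\n",
    "def check_mcf_example(tol=1e-9):\n",
    "    # Arrays copied from the example input and its printed output.\n",
    "    src = np.array([0, 0, 1])\n",
    "    dst = np.array([1, 2, 2])\n",
    "    cost = np.array([2.0, 1.0, 3.0])\n",
    "    cap = np.array([5.0, 2.0, 4.0])\n",
    "    supply = np.array([3.0, -1.0, -2.0])\n",
    "    flow = np.array([1.0, 2.0, 0.0])\n",
    "    potential = np.array([-2.0, 0.0, -1.0])\n",
    "\n",
    "    # Reduced cost per arc, using the documented sign convention.\n",
    "    reduced_cost = cost + potential[src] - potential[dst]\n",
    "\n",
    "    # Flow conservation: outflow minus inflow at each node equals its supply.\n",
    "    net_out = np.zeros(len(supply))\n",
    "    np.add.at(net_out, src, flow)\n",
    "    np.subtract.at(net_out, dst, flow)\n",
    "\n",
    "    # Complementary slackness (assumed standard conditions for this convention).\n",
    "    at_lower = flow <= tol\n",
    "    at_upper = flow >= cap - tol\n",
    "    inside = ~at_lower & ~at_upper\n",
    "    slack_ok = bool(\n",
    "        np.all(np.abs(reduced_cost[inside]) <= tol)\n",
    "        and np.all(reduced_cost[at_lower] >= -tol)\n",
    "        and np.all(reduced_cost[at_upper] <= tol)\n",
    "    )\n",
    "    return {\n",
    "        'reduced_cost': reduced_cost,\n",
    "        'conservation_ok': bool(np.allclose(net_out, supply)),\n",
    "        'slackness_ok': slack_ok,\n",
    "        'total_cost': float(cost @ flow),\n",
    "    }\n",
    "\n",
    "check = check_mcf_example()\n",
    "print(check)"
   ]
  },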
  {
   "cell_type": "markdown",
   "id": "c5edb874-0d6f-4c0a-a1c0-8707a04bd0c7",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "source": [
    "It also has the function `lemon_mcf.max_flow` which solves the maximum s-t flow problem. The function can be described as follows:\n",
    "\n",
    "**Signature**\n",
    "    out = lemon_mcf.max_flow(n, src, dst, cap, s, t)\n",
    "\n",
    "**Inputs (NumPy arrays)**\n",
    "- `n` : `int` — number of nodes (indexed `0..n-1`)\n",
    "- `src`, `dst` : `np.ndarray` shape `(m,)`, dtype `int64` — directed edges `src[i] -> dst[i]`\n",
    "- `cap` : `np.ndarray` shape `(m,)`, dtype `float64` — per-edge capacities (nonnegative)\n",
    "- `s` : `int` — source node index\n",
    "- `t` : `int` — sink node index\n",
    "\n",
    "**Output (`dict`)**\n",
    "- `out[\"value\"]` : `float` — maximum flow value from `s` to `t`\n",
    "- `out[\"flow\"]` : `np.ndarray` shape `(m,)`, dtype `float64` — per-edge flow values (in the same order as the input edges)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "84d7091d-b4dd-46db-84b2-e580dd826951",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true,
     "source_hidden": true
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'value': 4.0, 'flow': array([2., 2., 2., 2.])}\n"
     ]
    }
   ],
   "source": [
    "# ----- Example of Usage ------ #\n",
    "\n",
    "n = 4\n",
    "src = np.array([0,0,1,2], dtype=np.int64)\n",
    "dst = np.array([1,2,3,3], dtype=np.int64)\n",
    "cap = np.array([3.0,2.0,2.0,4.0], dtype=np.float64)\n",
    "out_max_flow = lemon_mcf.max_flow(n, src, dst, cap, 0, 3)\n",
    "print(out_max_flow)"
   ]
  },
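  {
   "cell_type": "markdown",
   "id": "sketch-min-cut-md",
   "metadata": {},
   "source": [
    "As a sanity check on the example above, the max-flow value should equal the capacity of a minimum `s`-`t` cut (max-flow/min-cut theorem). The sketch below enumerates every cut of the 4-node example by brute force; `brute_force_min_cut` is a hypothetical helper written for this illustration and is only practical for toy graphs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-min-cut-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ----- Sketch: brute-force min cut on the toy graph ------ #\n",
    "from itertools import combinations\n",
    "import numpy as np\n",
    "\n",
    "def brute_force_min_cut(n, src, dst, cap, s, t):\n",
    "    # Enumerate every s-t cut; exponential in n, so toy graphs only.\n",
    "    others = [node for node in range(n) if node not in (s, t)]\n",
    "    best_value, best_side = float('inf'), None\n",
    "    for r in range(len(others) + 1):\n",
    "        for extra in combinations(others, r):\n",
    "            side = {s, *extra}  # nodes on the source side of the cut\n",
    "            # Total capacity of the arcs that leave the source side.\n",
    "            value = sum(\n",
    "                float(c)\n",
    "                for a, b, c in zip(src, dst, cap)\n",
    "                if int(a) in side and int(b) not in side\n",
    "            )\n",
    "            if value < best_value:\n",
    "                best_value, best_side = value, sorted(side)\n",
    "    return best_value, best_side\n",
    "\n",
    "cut_value, cut_side = brute_force_min_cut(\n",
    "    4,\n",
    "    np.array([0, 0, 1, 2]),\n",
    "    np.array([1, 2, 3, 3]),\n",
    "    np.array([3.0, 2.0, 2.0, 4.0]),\n",
    "    0,\n",
    "    3,\n",
    ")\n",
    "print('min cut capacity:', cut_value, 'source side:', cut_side)"
   ]
  },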
  {
   "cell_type": "markdown",
   "id": "b2d023ed-4551-4a12-895b-5499538ccfe3",
   "metadata": {},
   "source": [
    "# Generating Datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c19ed1a-276c-4f7b-930b-15ee5dd1a087",
   "metadata": {},
   "source": [
    "## Synthetic Quadratic "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e668b358-356c-4b0f-8f52-3fa79709f650",
   "metadata": {},
   "source": [
    "#### Quadratic Instance Generator"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d51f4f1d-d52d-4167-8546-975f7056157a",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "source": [
    "The main function can be described as follows:\n",
    "\n",
    "**Signature**\n",
    "    data = generate_synthetic_convex_quadratic_dataset(d, m, M, X, sigma2, K, seed=None)\n",
    "\n",
    "**Inputs**\n",
    "- `d` : `int` — dimension (`d > 0`)\n",
    "- `m`, `M` : `float` — eigenvalue bounds (`0 < m <= M`)\n",
    "- `X` : `int` — integer box bound (samples are uniform over `[-X, X]^d`, coordinate-wise)\n",
    "- `sigma2` : `float` — noise variance (`sigma2 >= 0`)\n",
    "- `K` : `int` — number of samples (`K > 0`)\n",
    "- `seed` : `int` or `None` — RNG seed\n",
    "\n",
    "**Process**\n",
    "1. Sample eigenvalues `λ_1,...,λ_d ~ Uniform([m, M])` i.i.d.\n",
    "2. Sample an (approximately) Haar-uniform orthonormal matrix `U ∈ R^{d×d}` via QR of a Gaussian matrix.\n",
    "3. Form `Q = U diag(λ) U^T` (positive definite since `m > 0`).\n",
    "4. Sample an integer optimizer `x*` uniformly from the integer box `[-X, X]^d`.\n",
    "5. Sample `K` integer covariates `x^1,...,x^K` uniformly from `[-X, X]^d`.\n",
    "6. Sample i.i.d. noise `ε^k ~ N(0, sigma2)` for `k=1,...,K`.\n",
    "7. Set targets `y^k = (x^k - x*)^T Q (x^k - x*) + ε^k`.\n",
    "\n",
    "**Output (`dict`)**\n",
    "- `data[\"U\"]` : `np.ndarray` shape `(d, d)` — sampled orthonormal matrix\n",
    "- `data[\"lambdas\"]` : `np.ndarray` shape `(d,)` — sampled eigenvalues\n",
    "- `data[\"Q\"]` : `np.ndarray` shape `(d, d)` — sampled quadratic matrix `Q = U diag(lambdas) U^T`\n",
    "- `data[\"x_star\"]` : `np.ndarray` shape `(d,)`, dtype `int` — sampled optimizer `x*`\n",
    "- `data[\"X_samples\"]` : `np.ndarray` shape `(K, d)`, dtype `int` — sampled inputs `x^k`\n",
    "- `data[\"y\"]` : `np.ndarray` shape `(K,)`, dtype `float` — targets `y^k`\n"
   ]
  },
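  {
   "cell_type": "markdown",
   "id": "sketch-quadratic-pieces-md",
   "metadata": {},
   "source": [
    "Steps 1 to 3 and the noise-free part of step 7 can be sketched in isolation. The helper below is hypothetical and does not reproduce the generator's random draws (its sampling order differs); it only checks properties the construction should have: `U` is orthonormal, the eigenvalues of `Q` are the sampled `λ_i`, and `λ_min ‖x - x*‖² <= (x - x*)^T Q (x - x*) <= λ_max ‖x - x*‖²`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-quadratic-pieces-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ----- Sketch: properties of Q = U diag(lambdas) U^T ------ #\n",
    "import numpy as np\n",
    "\n",
    "def sketch_quadratic_pieces(d=4, m=0.5, M=3.0, X=3, seed=0):\n",
    "    rng = np.random.default_rng(seed)\n",
    "\n",
    "    # Step 2: QR of a Gaussian matrix, then flip the columns of U by sign(diag(R)).\n",
    "    A = rng.standard_normal((d, d))\n",
    "    U, R = np.linalg.qr(A)\n",
    "    signs = np.sign(np.diag(R))\n",
    "    signs[signs == 0] = 1.0\n",
    "    U = U * signs\n",
    "\n",
    "    # Steps 1 and 3: eigenvalues in [m, M] and Q = U diag(lambdas) U^T.\n",
    "    lambdas = rng.uniform(m, M, size=d)\n",
    "    Q = U @ np.diag(lambdas) @ U.T\n",
    "\n",
    "    # Steps 4, 5 and 7 without noise, for a single covariate x.\n",
    "    x_star = rng.integers(-X, X + 1, size=d)\n",
    "    x = rng.integers(-X, X + 1, size=d)\n",
    "    diff = x - x_star\n",
    "    value = float(diff @ Q @ diff)\n",
    "    sq_norm = float(diff @ diff)\n",
    "\n",
    "    return {\n",
    "        'orthonormal': bool(np.allclose(U.T @ U, np.eye(d))),\n",
    "        'eigs_match': bool(np.allclose(np.sort(np.linalg.eigvalsh(Q)), np.sort(lambdas))),\n",
    "        'bounded_below': bool(value >= lambdas.min() * sq_norm - 1e-9),\n",
    "        'bounded_above': bool(value <= lambdas.max() * sq_norm + 1e-9),\n",
    "    }\n",
    "\n",
    "quad_check = sketch_quadratic_pieces()\n",
    "print(quad_check)"
   ]
  },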
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "edd30065-8973-4a0f-8315-99a96ac42355",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _haar_orthonormal_matrix(d: int, rng: np.random.Generator) -> np.ndarray:\n",
    "    \"\"\"\n",
    "    Sample an (approximately) Haar-uniform random orthonormal matrix U in R^{dxd}\n",
    "    using QR decomposition of a standard Gaussian matrix, with a sign correction.\n",
    "    \"\"\"\n",
    "    A = rng.standard_normal((d, d))\n",
    "    U, R = np.linalg.qr(A)\n",
    "    s = np.sign(np.diag(R))\n",
    "    s[s == 0] = 1.0\n",
    "    U = U * s  # multiply columns by sign\n",
    "    return U\n",
    "\n",
    "def generate_synthetic_convex_quadratic_dataset(\n",
    "    d: int,\n",
    "    m: float,\n",
    "    M: float,\n",
    "    X: int,\n",
    "    sigma2: float,\n",
    "    K: int,\n",
    "    seed: int | None = None\n",
    "):\n",
    "    if d <= 0:\n",
    "        raise ValueError(\"d must be positive.\")\n",
    "    if not (M >= m):\n",
    "        raise ValueError(\"Require M >= m.\")\n",
    "    if m <= 0:\n",
    "        raise ValueError(\"Require m > 0 so that Q is positive definite.\")\n",
    "    if K <= 0:\n",
    "        raise ValueError(\"K must be positive.\")\n",
    "    if sigma2 < 0:\n",
    "        raise ValueError(\"sigma2 must be nonnegative.\")\n",
    "    if not isinstance(X, (int, np.integer)):\n",
    "        raise ValueError(\"X must be an integer.\")\n",
    "    X = int(X)\n",
    "    if X < 0:\n",
    "        raise ValueError(\"X must be a nonnegative integer.\")\n",
    "\n",
    "    rng = np.random.default_rng(seed)\n",
    "\n",
    "    # (i) sample eigenvalues\n",
    "    lambdas = rng.uniform(m, M, size=d)\n",
    "\n",
    "    # (ii) sample orthonormal matrix U\n",
    "    U = _haar_orthonormal_matrix(d, rng)\n",
    "\n",
    "    # Build Q = U diag(lambdas) U^T\n",
    "    Q = U @ np.diag(lambdas) @ U.T\n",
    "\n",
    "    # Sample integer vectors uniformly from [-X, X]^d\n",
    "    def sample_int_box(num: int) -> np.ndarray:\n",
    "        return rng.integers(-X, X + 1, size=(num, d), dtype=int)\n",
    "\n",
    "    # (iii) sample x*\n",
    "    x_star = sample_int_box(1).reshape(-1)  # shape (d,)\n",
    "\n",
    "    # (iv) sample K vectors x^1,...,x^K\n",
    "    X_samples = sample_int_box(K)  # shape (K, d)\n",
    "\n",
    "    # (v) sample Gaussian noises\n",
    "    sigma = float(np.sqrt(sigma2))\n",
    "    eps = rng.normal(loc=0.0, scale=sigma, size=K)\n",
    "\n",
    "    # (vi) compute y^k\n",
    "    diffs = X_samples - x_star  # shape (K, d)\n",
    "    quad = np.einsum(\"bi,ij,bj->b\", diffs, Q, diffs)  # shape (K,)\n",
    "    y = quad + eps\n",
    "\n",
    "    return {\n",
    "        \"U\": U,\n",
    "        \"lambdas\": lambdas,\n",
    "        \"Q\": Q,\n",
    "        \"x_star\": x_star,\n",
    "        \"X_samples\": X_samples,\n",
    "        \"y\": y,\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "666e79e4-8daa-483c-995d-6a3b92367b88",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true,
     "source_hidden": true
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x_star: [ 7  4  4 -2  8]\n",
      "X_samples shape: (100, 5)\n",
      "y shape: (100,)\n",
      "min eigenvalue(Q) ~ 0.5413190888213227\n"
     ]
    }
   ],
   "source": [
    "# --- Example of Usage --- #\n",
    "data = generate_synthetic_convex_quadratic_dataset(\n",
    "    d=5, m=0.5, M=3.0, X=10, sigma2=0.25, K=100, seed=0\n",
    ")\n",
    "\n",
    "print(\"x_star:\", data[\"x_star\"])\n",
    "print(\"X_samples shape:\", data[\"X_samples\"].shape)\n",
    "print(\"y shape:\", data[\"y\"].shape)\n",
    "print(\"min eigenvalue(Q) ~\", np.min(np.linalg.eigvalsh(data[\"Q\"])))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c49e91f-881b-4020-8878-c22cf217c00c",
   "metadata": {},
   "source": [
    "## Synthetic Min Cost Flow"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f773417-3500-4b5a-b931-59eacf5b61c2",
   "metadata": {},
   "source": [
    "#### NETGEN Instance Generator"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6c6656d-e5ac-45fd-975d-398e4fbe472b",
   "metadata": {},
   "source": [
    "The main function can be described as follows:\n",
    "\n",
    "**Signature**\n",
    "    u, v, cap, cost, b = netgen_instance_arrays(nodes, arcs, sources, sinks, cost_bounds, cap_bounds, total_supply=1000, seed=None)\n",
    "\n",
    "**Inputs**\n",
    "- `nodes` : `int` — number of nodes (`nodes > 0`)\n",
    "- `arcs` : `int` — number of directed arcs/edges to generate (`arcs > 0`)\n",
    "- `sources` : `int` — number of supply (source) nodes (`sources >= 0`)\n",
    "- `sinks` : `int` — number of demand (sink) nodes (`sinks >= 0`)\n",
    "- `cost_bounds` : `tuple[int, int]` — arc cost bounds `(min_cost, max_cost)` with `min_cost <= max_cost`\n",
    "- `cap_bounds` : `tuple[int, int]` — arc capacity bounds `(min_cap, max_cap)` with `min_cap <= max_cap`\n",
    "- `total_supply` : `int` — total supply injected into the network (`total_supply >= 0`)\n",
    "- `seed` : `int` or `None` — RNG seed\n",
    "\n",
    "**Output (arrays)**\n",
    "Returns a 5-tuple:\n",
    "\n",
    "- `u` : `np.ndarray` shape `(m,)`, dtype `int` — tail (source) node index for each arc (0-based)\n",
    "- `v` : `np.ndarray` shape `(m,)`, dtype `int` — head (destination) node index for each arc (0-based)  \n",
    "  (arc `i` is `u[i] -> v[i]`)\n",
    "- `cap` : `np.ndarray` shape `(m,)`, dtype `int` — capacity for each arc\n",
    "- `cost` : `np.ndarray` shape `(m,)`, dtype `int` — per-unit cost for each arc\n",
    "- `b` : `np.ndarray` shape `(nodes,)`, dtype `int` — flow balance for each node"
   ]
  },
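  {
   "cell_type": "markdown",
   "id": "sketch-dimacs-arrays-md",
   "metadata": {},
   "source": [
    "`netgen_instance_arrays` writes the instance to a DIMACS min-cost flow file and converts it to 0-based arrays. The sketch below illustrates that conversion on a hand-written three-node file; the text and the helper `dimacs_text_to_arrays` are hypothetical, and the field order assumed for arc lines is `a tail head low cap cost`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-dimacs-arrays-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ----- Sketch: DIMACS text to 0-based arrays ------ #\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical 3-node instance (hand-written, not NETGEN output).\n",
    "DIMACS_TEXT = '''c tiny example\n",
    "p min 3 2\n",
    "n 1 4\n",
    "n 3 -4\n",
    "a 1 2 0 10 5\n",
    "a 2 3 0 10 7\n",
    "'''\n",
    "\n",
    "def dimacs_text_to_arrays(text):\n",
    "    n_nodes, supplies, arcs = None, {}, []\n",
    "    for raw in text.splitlines():\n",
    "        parts = raw.split()\n",
    "        if not parts or parts[0] == 'c':\n",
    "            continue  # blank line or comment\n",
    "        if parts[0] == 'p':        # problem line: p min N M\n",
    "            n_nodes = int(parts[2])\n",
    "        elif parts[0] == 'n':      # node line: n id supply\n",
    "            supplies[int(parts[1])] = int(parts[2])\n",
    "        elif parts[0] == 'a':      # arc line: a tail head low cap cost\n",
    "            tail, head, _low, cap, cost = map(int, parts[1:])\n",
    "            arcs.append((tail - 1, head - 1, cap, cost))  # shift to 0-based\n",
    "    u, v, cap, cost = (np.array(col, dtype=np.int64) for col in zip(*arcs))\n",
    "    b = np.zeros(n_nodes, dtype=np.int64)\n",
    "    for node, val in supplies.items():\n",
    "        b[node - 1] = val\n",
    "    return u, v, cap, cost, b\n",
    "\n",
    "toy_u, toy_v, toy_cap, toy_cost, toy_b = dimacs_text_to_arrays(DIMACS_TEXT)\n",
    "print(toy_u, toy_v, toy_cap, toy_cost, toy_b)"
   ]
  },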
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "7fb8ebaa-0e13-4c0b-bbf6-00c0b60c4547",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "def _parse_dimacs_mcf(path: Path) -> tuple[int, Dict[int, int], List[tuple[int, int, int, int]]]:\n",
    "    n = None\n",
    "    supply: Dict[int, int] = {}\n",
    "    arcs: List[tuple[int, int, int, int]] = []  # (u, v, cap, cost)\n",
    "\n",
    "    for raw in path.read_text(encoding=\"utf-8\").splitlines():\n",
    "        line = raw.strip()\n",
    "        if not line or line.startswith(\"c\"):\n",
    "            continue\n",
    "        parts = line.split()\n",
    "        if parts[0] == \"p\":\n",
    "            n = int(parts[2])\n",
    "        elif parts[0] == \"n\":\n",
    "            supply[int(parts[1])] = int(parts[2])\n",
    "        elif parts[0] == \"a\":\n",
    "            u, v, ignored, cap, cost = map(int, parts[1:])\n",
    "            if ignored != 0:\n",
    "                raise ValueError(f\"Unexpected nonzero ignored field in DIMACS arc line: value={ignored}\")\n",
    "            arcs.append((u, v, cap, cost))\n",
    "\n",
    "    if n is None:\n",
    "        raise ValueError(\"Missing 'p min N M' line.\")\n",
    "    return n, supply, arcs\n",
    "\n",
    "\n",
    "def netgen_instance_arrays(\n",
    "    *,\n",
    "    nodes: int,\n",
    "    arcs: int,\n",
    "    sources: int,\n",
    "    sinks: int,\n",
    "    cost_bounds: Tuple[int, int],\n",
    "    cap_bounds: Tuple[int, int],\n",
    "    total_supply: int = 1000,\n",
    "    seed: Optional[int] = None,\n",
    "):\n",
    "    \"\"\"\n",
    "    Returns:\n",
    "      u, v, cap, cost, b\n",
    "    where:\n",
    "      - u,v,cap,cost are arrays of length m\n",
    "      - b is an array of length n with supply(+)/demand(-), sum(b)=0\n",
    "      - nodes are 0-based indices\n",
    "    \"\"\"\n",
    "    if nodes <= 0 or arcs <= 0:\n",
    "        raise ValueError(\"nodes and arcs must be positive.\")\n",
    "    if sources < 0 or sinks < 0 or sources + sinks > nodes:\n",
    "        raise ValueError(\"Need sources+sinks <= nodes.\")\n",
    "    minc, maxc = cost_bounds\n",
    "    mincap, maxcap = cap_bounds\n",
    "    if minc > maxc or mincap > maxcap:\n",
    "        raise ValueError(\"Bad bounds: min must be <= max.\")\n",
    "    if seed is None:\n",
    "        seed = random.randint(1, 2_000_000_000)\n",
    "\n",
    "    try:\n",
    "        import pynetgen  # type: ignore\n",
    "    except ImportError as e:\n",
    "        raise RuntimeError(\n",
    "            \"pynetgen is not installed.\"\n",
    "        ) from e\n",
    "\n",
    "    tmpdir = Path(tempfile.mkdtemp(prefix=\"netgen_\"))\n",
    "    dimacs_path = tmpdir / \"instance.dimacs\"\n",
    "\n",
    "    pynetgen.netgen_generate(\n",
    "        seed=seed,\n",
    "        nodes=nodes,\n",
    "        sources=sources,\n",
    "        sinks=sinks,\n",
    "        density=arcs,\n",
    "        mincost=minc,\n",
    "        maxcost=maxc,\n",
    "        supply=total_supply,\n",
    "        tsources=0,\n",
    "        tsinks=0,\n",
    "        hicost=0,\n",
    "        capacitated=100,\n",
    "        mincap=mincap,\n",
    "        maxcap=maxcap,\n",
    "        rng=1,\n",
    "        type=None,\n",
    "        fname=str(dimacs_path),\n",
    "    )\n",
    "\n",
    "    n, supply_1b, arc_list_1b = _parse_dimacs_mcf(dimacs_path)\n",
    "    m = len(arc_list_1b)\n",
    "\n",
    "    u = np.empty(m, dtype=np.int64)\n",
    "    v = np.empty(m, dtype=np.int64)\n",
    "    cap = np.empty(m, dtype=np.int64)\n",
    "    cost = np.empty(m, dtype=np.int64)\n",
    "\n",
    "    for i, (uu, vv, c, w) in enumerate(arc_list_1b):\n",
    "        u[i] = uu - 1\n",
    "        v[i] = vv - 1\n",
    "        cap[i] = c\n",
    "        cost[i] = w\n",
    "\n",
    "    b = np.zeros(n, dtype=np.int64)\n",
    "    for node_1b, val in supply_1b.items():\n",
    "        b[node_1b - 1] = val\n",
    "\n",
    "    if b.sum() != 0:\n",
    "        raise ValueError(f\"Unbalanced instance: sum(b)={b.sum()} (should be 0).\")\n",
    "\n",
    "    return u, v, cap, cost, b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "51948156-ee64-4380-9cd9-f31f811dd585",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true,
     "source_hidden": true
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n=200, m=1500, balance sum(b)=0\n",
      "#sources=10, #sinks=10\n",
      "sources (node, supply): [(0, 1054), (1, 864), (2, 372), (3, 49), (4, 299), (5, 917), (6, 565), (7, 45), (8, 711), (9, 124)]\n",
      "sinks   (node, demand): [(190, -546), (191, -1003), (192, -42), (193, -49), (194, -1021), (195, -621), (196, -432), (197, -768), (198, -33), (199, -485)]\n",
      "\n",
      "First edges:\n",
      "   i      u      v     cap    cost\n",
      "-----------------------------------\n",
      "   0      0    118    1054      23\n",
      "   1      0    175      85       8\n",
      "   2      0    109      97      13\n",
      "   3      0    134      56      41\n",
      "   4      0     95      93      26\n",
      "   5      0    131      96       6\n",
      "   6      0     50      17       2\n",
      "   7      0     48      60      42\n",
      "   8      0     47      77      31\n",
      "   9      0    187      20      36\n",
      "\n",
      "cap stats: 1 108.12266666666666 1054\n",
      "cost stats: 1 24.924 50\n"
     ]
    }
   ],
   "source": [
    "# ------ Example of usage ------ #\n",
    "\n",
    "u, v, cap, cost, b = netgen_instance_arrays(\n",
    "    nodes=200,\n",
    "    arcs=1500,\n",
    "    sources=10,\n",
    "    sinks=10,\n",
    "    cost_bounds=(1, 50),\n",
    "    cap_bounds=(1, 100),\n",
    "    total_supply=5000,\n",
    "    seed=7,\n",
    ")\n",
    "\n",
    "n = len(b)\n",
    "m = len(u)\n",
    "print(f\"n={n}, m={m}, balance sum(b)={int(b.sum())}\")\n",
    "\n",
    "src = np.where(b > 0)[0]\n",
    "snk = np.where(b < 0)[0]\n",
    "print(f\"#sources={len(src)}, #sinks={len(snk)}\")\n",
    "print(\"sources (node, supply):\", [(int(i), int(b[i])) for i in src[:10]])\n",
    "print(\"sinks   (node, demand):\", [(int(i), int(b[i])) for i in snk[:10]])\n",
    "\n",
    "k = 10\n",
    "k = min(k, m)\n",
    "print(\"\\nFirst edges:\")\n",
    "print(f\"{'i':>4}  {'u':>5}  {'v':>5}  {'cap':>6}  {'cost':>6}\")\n",
    "print(\"-\" * 35)\n",
    "for i in range(k):\n",
    "    print(f\"{i:>4}  {int(u[i]):>5}  {int(v[i]):>5}  {int(cap[i]):>6}  {int(cost[i]):>6}\")\n",
    "\n",
    "print(\"\\ncap stats:\", int(cap.min()), float(cap.mean()), int(cap.max()))\n",
    "print(\"cost stats:\", int(cost.min()), float(cost.mean()), int(cost.max()))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "273eeb7d-38f9-4cdb-9c43-1c3e920c792d",
   "metadata": {},
   "source": [
    "#### NETGEN Dataset Generator"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "723aab3f-b920-44d6-a891-dc87f6a89b61",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "source": [
    "The main function can be described as follows:\n",
    "\n",
    "**Signature**\n",
    "    data = generate_netgen_mcf_dataset(n, family, K, tilde_C, sigma2, seed=None, cap_range=(1,1000), cost_range=(1,10000), total_supply=1000, verbose=False)\n",
    "\n",
    "**Inputs**\n",
    "- `n` : `int` — number of nodes in the NETGEN network (`n > 0`)\n",
    "- `family` : `str` — family: `\"sparse\"` or `\"dense\"`\n",
    "  - `\"sparse\"` uses `m = 8n` arcs\n",
    "  - `\"dense\"` uses `m = ceil(n * sqrt(n))` arcs\n",
    "- `K` : `int` — number of balance/label samples to generate (`K > 0`)\n",
    "- `tilde_C` : `int` — max auxiliary cost used when generating balances (`tilde_C > 0`)\n",
    "- `sigma2` : `float` — label noise variance (`sigma2 >= 0`)\n",
    "- `seed` : `int` or `None` — RNG seed\n",
    "- `cap_range` : `(int, int)` — integer capacity range for NETGEN arcs (default `(1,1000)`)\n",
    "- `cost_range` : `(int, int)` — integer cost range for NETGEN arcs (default `(1,10000)`)\n",
    "- `total_supply` : `int` — total supply in the feasible balance NETGEN produces (needed to generate network but balances are discarded)\n",
    "- `verbose` : `bool` — print `F_max` and basic diagnostics if `True`\n",
    "\n",
    "**Process**\n",
    "1. Set the number of supply and demand nodes to  \n",
    "   `|S| = |T| = floor(sqrt(n))`.\n",
    "2. Choose the number of arcs `m` based on `family`:\n",
    "   - If `\"sparse\"`: `m = 8n`\n",
    "   - If `\"dense\"`:  `m = ceil(n * sqrt(n))`\n",
    "3. Generate a NETGEN minimum-cost flow instance on `n` nodes with:\n",
    "   - arc capacities `u_e ~ Unif({1,...,1000})`\n",
    "   - arc costs `c_e ~ Unif({1,...,10000})`\n",
    "   NETGEN also returns a feasible balance vector `b0`; discard its magnitudes but use its sign pattern to\n",
    "   identify the node sets:\n",
    "   - `S = {v : b0(v) > 0}` (supplies)\n",
    "   - `T = {v : b0(v) < 0}` (demands)\n",
    "4. Augment the network by adding a super-source `s` and super-sink `t`:\n",
    "   - add arcs `(s,v)` for all `v ∈ S` with capacity `+∞`\n",
    "   - add arcs `(v,t)` for all `v ∈ T` with capacity `+∞`\n",
    "5. Compute `F_max`, the value of a maximum `s–t` flow in the augmented network.\n",
    "6. For each sample `k = 1,...,K`:\n",
    "   1. Sample a target flow value `F^(k) ~ Unif({0,1,...,F_max})`.\n",
    "   2. Sample auxiliary integer arc costs on the augmented network  \n",
    "      `tilde_c_e^(k) ~ Unif({1,...,tilde_C})`.\n",
    "   3. Solve a minimum-cost `s–t` flow of value `F^(k)` on the augmented network under costs `tilde_c^(k)`,\n",
    "      obtaining a flow `f^(k)`.\n",
    "   4. Define the balance vector `b^(k) ∈ Z^n` on the original node set by:\n",
    "      - `b^(k)(v) = f_sv^(k)` for `v ∈ S`\n",
    "      - `b^(k)(v) = -f_vt^(k)` for `v ∈ T`\n",
    "      - `b^(k)(v) = 0` otherwise\n",
    "   5. Solve the original minimum-cost flow problem on `G(V,E)` with capacities `u`, costs `c`,\n",
    "      and balance `b^(k)` to obtain the optimal cost `z^(k)`.\n",
    "   6. Sample noise `ε^(k) ~ N(0, sigma2)` and set the label `y^(k) = z^(k) + ε^(k)`.\n",
    "\n",
    "**Output (`dict`)**\n",
    "- `data[\"n\"]` : `int` — number of nodes\n",
    "- `data[\"m\"]` : `int` — number of arcs in the original NETGEN graph\n",
    "- `data[\"u\"]` : `np.ndarray` shape `(m,)`, dtype `int` — arc tail nodes (0-indexed)\n",
    "- `data[\"v\"]` : `np.ndarray` shape `(m,)`, dtype `int` — arc head nodes (0-indexed)\n",
    "- `data[\"cap\"]` : `np.ndarray` shape `(m,)`, dtype `int` — arc capacities\n",
    "- `data[\"cost\"]` : `np.ndarray` shape `(m,)`, dtype `int` — arc costs\n",
    "- `data[\"S_nodes\"]` : `np.ndarray` shape `(|S|,)`, dtype `int` — supply node indices\n",
    "- `data[\"T_nodes\"]` : `np.ndarray` shape `(|T|,)`, dtype `int` — demand node indices\n",
    "- `data[\"X_samples\"]` : `np.ndarray` shape `(K, n)`, dtype `int` — sampled balance vectors `b^(k)`\n",
    "- `data[\"y\"]` : `np.ndarray` shape `(K,)`, dtype `float` — noisy labels `y^(k)`\n",
    "- `data[\"z\"]` : `np.ndarray` shape `(K,)`, dtype `float` — clean optimal costs `z^(k)` (before noise)\n",
    "- `data[\"F_max\"]` : `int` — maximum `s–t` flow value in the augmented network\n",
    "- `data[\"F_targets\"]` : `np.ndarray` shape `(K,)`, dtype `int` — sampled target values `F^(k)`\n",
    "- `data[\"tilde_C\"]` : `int` — auxiliary cost bound used\n",
    "- `data[\"sigma2\"]` : `float` — noise variance used\n",
    "- `data[\"seed\"]` : `int` or `None` — RNG seed used\n",
    "- `data[\"family\"]` : `str` — `\"sparse\"` or `\"dense\"`"
   ]
  },
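  {
   "cell_type": "markdown",
   "id": "sketch-augmentation-md",
   "metadata": {},
   "source": [
    "Steps 1, 2, 4 and 6.4 involve only array bookkeeping and can be sketched without a solver. The helpers below are hypothetical (the `sketch_` prefix marks them as illustrations), the toy graph is made up, and the flow on the augmented network is hand-picked rather than computed by `lemon_mcf.solve_mcf`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-augmentation-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ----- Sketch: sizes, augmentation and balance extraction ------ #\n",
    "import math\n",
    "import numpy as np\n",
    "\n",
    "def sketch_netgen_sizes(n, family):\n",
    "    # Steps 1-2: |S| = |T| = floor(sqrt(n)); the arc count depends on the family.\n",
    "    k = math.isqrt(n)\n",
    "    m = 8 * n if family == 'sparse' else math.ceil(n * math.sqrt(n))\n",
    "    return k, m\n",
    "\n",
    "def sketch_augment(n, u, v, cap, S_nodes, T_nodes):\n",
    "    # Step 4: super-source s = n and super-sink t = n + 1, with a large finite capacity.\n",
    "    s, t = n, n + 1\n",
    "    big = float(cap.sum() + 1)\n",
    "    src = np.concatenate([u, np.full(len(S_nodes), s), T_nodes])\n",
    "    dst = np.concatenate([v, S_nodes, np.full(len(T_nodes), t)])\n",
    "    cap_aug = np.concatenate([cap.astype(float), np.full(len(S_nodes) + len(T_nodes), big)])\n",
    "    return src, dst, cap_aug\n",
    "\n",
    "def sketch_balance_from_flow(n, m, flow_aug, S_nodes, T_nodes):\n",
    "    # Step 6.4: read b^(k) off the flows on the super arcs (stored after the m original arcs).\n",
    "    b = np.zeros(n, dtype=np.int64)\n",
    "    b[S_nodes] = np.rint(flow_aug[m:m + len(S_nodes)]).astype(np.int64)\n",
    "    b[T_nodes] = -np.rint(flow_aug[m + len(S_nodes):]).astype(np.int64)\n",
    "    return b\n",
    "\n",
    "# Toy graph (made up): arcs 0->1, 1->2, 0->2 with S = {0} and T = {2}.\n",
    "toy_tail = np.array([0, 1, 0])\n",
    "toy_head = np.array([1, 2, 2])\n",
    "toy_capacity = np.array([2, 2, 1])\n",
    "toy_S, toy_T = np.array([0]), np.array([2])\n",
    "\n",
    "aug_src, aug_dst, aug_cap = sketch_augment(3, toy_tail, toy_head, toy_capacity, toy_S, toy_T)\n",
    "hand_flow = np.array([2.0, 2.0, 1.0, 3.0, 3.0])  # hand-picked feasible flow of value 3\n",
    "toy_balance = sketch_balance_from_flow(3, 3, hand_flow, toy_S, toy_T)\n",
    "\n",
    "print(sketch_netgen_sizes(200, 'sparse'), sketch_netgen_sizes(200, 'dense'))\n",
    "print(aug_src, aug_dst, aug_cap)\n",
    "print(toy_balance)"
   ]
  },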
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8d6820d3-b93a-4172-a39e-4c96595e5999",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "\n",
    "def generate_netgen_mcf_dataset(\n",
    "    n: int,\n",
    "    family: str,                 # \"sparse\" or \"dense\"\n",
    "    K: int,\n",
    "    tilde_C: int,\n",
    "    sigma2: float,\n",
    "    seed: int | None = None,\n",
    "    cap_range: tuple[int, int] = (1, 1000),\n",
    "    cost_range: tuple[int, int] = (1, 10000),\n",
    "    total_supply: int = 1000,\n",
    "    verbose: bool = False,\n",
    "):\n",
    "    if n <= 0:\n",
    "        raise ValueError(\"n must be positive.\")\n",
    "    if K <= 0:\n",
    "        raise ValueError(\"K must be positive.\")\n",
    "    if tilde_C <= 0:\n",
    "        raise ValueError(\"tilde_C must be positive.\")\n",
    "    if sigma2 < 0:\n",
    "        raise ValueError(\"sigma2 must be nonnegative.\")\n",
    "    if family not in {\"sparse\", \"dense\"}:\n",
    "        raise ValueError(\"family must be 'sparse' or 'dense'.\")\n",
    "\n",
    "    rng = np.random.default_rng(seed)\n",
    "\n",
    "    # --- NETGEN sizes ---\n",
    "    S_cnt = int(math.isqrt(n))                 # floor(sqrt(n))\n",
    "    T_cnt = int(math.isqrt(n))                 # floor(sqrt(n))\n",
    "    if family == \"sparse\":\n",
    "        m_target = 8 * n\n",
    "    else:\n",
    "        m_target = int(math.ceil(n * math.sqrt(n)))\n",
    "\n",
    "    # --- generate NETGEN instance (graph + a feasible b we will discard) ---\n",
    "    u, v, cap, cost, b0 = netgen_instance_arrays(\n",
    "        nodes=n,\n",
    "        arcs=m_target,\n",
    "        sources=S_cnt,\n",
    "        sinks=T_cnt,\n",
    "        cost_bounds=cost_range,\n",
    "        cap_bounds=cap_range,\n",
    "        total_supply=total_supply,\n",
    "        seed=None if seed is None else int(seed),\n",
    "    )\n",
    "\n",
    "    # Identify supply/demand node sets from NETGEN's b (we discard magnitudes, keep sets)\n",
    "    S_nodes = np.where(b0 > 0)[0]\n",
    "    T_nodes = np.where(b0 < 0)[0]\n",
    "\n",
    "    if len(S_nodes) != S_cnt or len(T_nodes) != T_cnt:\n",
    "        raise RuntimeError(\n",
    "            f\"NETGEN returned |S|={len(S_nodes)}, |T|={len(T_nodes)} but expected \"\n",
    "            f\"{S_cnt} and {T_cnt}. (Check generator settings.)\"\n",
    "        )\n",
    "\n",
    "    # --- build augmented network with super-source s and super-sink t ---\n",
    "    s = n\n",
    "    t = n + 1\n",
    "    N_aug = n + 2\n",
    "\n",
    "    # \"infinite\" capacity for super arcs: big enough not to bind\n",
    "    INF = float(np.sum(cap, dtype=np.int64) + 1)\n",
    "\n",
    "    # original arcs\n",
    "    src_aug = [u.astype(np.int64)]\n",
    "    dst_aug = [v.astype(np.int64)]\n",
    "    cap_aug = [cap.astype(np.float64)]\n",
    "\n",
    "    # s -> v for v in S\n",
    "    src_s = np.full(len(S_nodes), s, dtype=np.int64)\n",
    "    dst_s = S_nodes.astype(np.int64)\n",
    "    cap_s = np.full(len(S_nodes), INF, dtype=np.float64)\n",
    "\n",
    "    # v -> t for v in T\n",
    "    src_t = T_nodes.astype(np.int64)\n",
    "    dst_t = np.full(len(T_nodes), t, dtype=np.int64)\n",
    "    cap_t = np.full(len(T_nodes), INF, dtype=np.float64)\n",
    "\n",
    "    src_aug = np.concatenate(src_aug + [src_s, src_t])\n",
    "    dst_aug = np.concatenate(dst_aug + [dst_s, dst_t])\n",
    "    cap_aug = np.concatenate(cap_aug + [cap_s, cap_t])\n",
    "\n",
    "    # Remember arc indices for extracting b^(k)\n",
    "    m0 = len(u)\n",
    "    idx_s_to_v = {int(vv): m0 + i for i, vv in enumerate(S_nodes)}\n",
    "    idx_v_to_t = {int(vv): m0 + len(S_nodes) + i for i, vv in enumerate(T_nodes)}\n",
    "\n",
    "    # --- compute F_max via max s-t flow on augmented network ---\n",
    "    out_mf = lemon_mcf.max_flow(N_aug, src_aug, dst_aug, cap_aug, s, t)\n",
    "    F_max = int(out_mf[\"value\"])\n",
    "    if verbose:\n",
    "        print(f\"[augmented] F_max = {F_max}\")\n",
    "\n",
    "    # --- generate K balances and labels ---\n",
    "    B = np.zeros((K, len(S_nodes) + len(T_nodes)), dtype=np.int64)\n",
    "    z = np.zeros(K, dtype=np.float64)\n",
    "    y = np.zeros(K, dtype=np.float64)\n",
    "    F_targets = np.zeros(K, dtype=np.int64)\n",
    "\n",
    "    for k in range(K):\n",
    "        # sample target value F^(k)\n",
    "        Fk = int(rng.integers(0, F_max + 1))\n",
    "        F_targets[k] = Fk\n",
    "\n",
    "        # sample auxiliary integer costs on ALL augmented arcs\n",
    "        aux_cost = rng.integers(1, tilde_C + 1, size=len(src_aug), dtype=np.int64).astype(np.float64)\n",
    "\n",
    "        # min-cost s-t flow of value Fk via supplies\n",
    "        supply_aug = np.zeros(N_aug, dtype=np.float64)\n",
    "        supply_aug[s] = float(Fk)\n",
    "        supply_aug[t] = -float(Fk)\n",
    "\n",
    "        out_aux = lemon_mcf.solve_mcf(\n",
    "            N_aug,\n",
    "            src_aug,\n",
    "            dst_aug,\n",
    "            aux_cost,\n",
    "            cap_aug,\n",
    "            supply_aug,\n",
    "        )\n",
    "        if out_aux[\"status\"] != 1:\n",
    "            raise RuntimeError(f\"Aux min-cost flow failed at k={k} with status={out_aux['status']}.\")\n",
    "\n",
    "        flow_aug = out_aux[\"flow\"]  # float array per arc\n",
    "\n",
    "        # build balance b^(k) on original nodes\n",
    "        bk = np.zeros(n, dtype=np.int64)\n",
    "        for vv in S_nodes:\n",
    "            bk[int(vv)] = int(round(float(flow_aug[idx_s_to_v[int(vv)]])))\n",
    "        for vv in T_nodes:\n",
    "            bk[int(vv)] = -int(round(float(flow_aug[idx_v_to_t[int(vv)]])))\n",
    "\n",
    "        B[k] = bk[np.r_[S_nodes, T_nodes]]\n",
    "\n",
    "        # solve original min-cost flow to get z^(k)\n",
    "        out_orig = lemon_mcf.solve_mcf(\n",
    "            n,\n",
    "            u.astype(np.int64),\n",
    "            v.astype(np.int64),\n",
    "            cost.astype(np.float64),\n",
    "            cap.astype(np.float64),\n",
    "            bk.astype(np.float64),\n",
    "        )\n",
    "        if out_orig[\"status\"] != 1:\n",
    "            raise RuntimeError(f\"Original min-cost flow failed at k={k} with status={out_orig['status']}.\")\n",
    "\n",
    "        zk = float(out_orig[\"total_cost\"])\n",
    "        z[k] = zk\n",
    "\n",
    "        # noisy label\n",
    "        eps = float(rng.normal(0.0, math.sqrt(sigma2)))\n",
    "        y[k] = zk + eps\n",
    "\n",
    "    return {\n",
    "        # graph\n",
    "        \"n\": n,\n",
    "        \"m\": len(u),\n",
    "        \"u\": u.astype(np.int64),\n",
    "        \"v\": v.astype(np.int64),\n",
    "        \"cap\": cap.astype(np.int64),\n",
    "        \"cost\": cost.astype(np.int64),\n",
    "        \"S_nodes\": S_nodes.astype(np.int64),\n",
    "        \"T_nodes\": T_nodes.astype(np.int64),\n",
    "\n",
    "        # dataset (quadratic-like)\n",
    "        \"X_samples\": B,     # balances b^(k)\n",
    "        \"y\": y,             # noisy labels\n",
    "\n",
    "        # extras (often useful)\n",
    "        \"z\": z,             # clean optimal costs\n",
    "        \"F_max\": F_max,\n",
    "        \"F_targets\": F_targets,\n",
    "        \"tilde_C\": tilde_C,\n",
    "        \"sigma2\": float(sigma2),\n",
    "        \"seed\": seed,\n",
    "        \"family\": family,\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "891009e5-34d9-4d9e-a62e-ef7321e4ddc1",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[augmented] F_max = 44749\n",
      "X_samples shape: (100, 200)\n",
      "y shape: (100,)\n",
      "F_max: 44749\n"
     ]
    }
   ],
   "source": [
    "# ---- Example of Usage ---- #\n",
    "\n",
    "data = generate_netgen_mcf_dataset(\n",
    "    n=200,\n",
    "    family=\"sparse\",\n",
    "    K=100,\n",
    "    tilde_C=100,\n",
    "    sigma2=0.25,\n",
    "    seed=0,\n",
    "    verbose=True,\n",
    ")\n",
    "\n",
    "print(\"X_samples shape:\", data[\"X_samples\"].shape)  # (K, n)\n",
    "print(\"y shape:\", data[\"y\"].shape)                  # (K,)\n",
    "print(\"F_max:\", data[\"F_max\"])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8fd37ea3-560e-4d85-8907-455186f275b5",
   "metadata": {},
   "source": [
    "## DataSet Generator"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28f59382-e2e6-4350-acb4-38c12890f4f0",
   "metadata": {},
   "source": [
    "The main function can be described as follows:\n",
    "\n",
    "**Signature**\n",
    "    X, y = gen_dataset(kind, **params)\n",
    "\n",
    "**Inputs**\n",
    "- `kind` : `str` — dataset type, must be `\"quadratic\"` or `\"netgen\"`\n",
    "- `params` : keyword args — parameters required by the chosen dataset type (see below)\n",
    "\n",
    "---\n",
    "\n",
    "`kind=\"quadratic\"`\n",
    "\n",
    "**Required `params`**\n",
    "- `d` : `int` — input dimension (`d > 0`)\n",
    "- `m` : `float` — minimum eigenvalue lower bound for the quadratic form (`m > 0`)\n",
    "- `M` : `float` — maximum eigenvalue upper bound (`M >= m`)\n",
    "- `X` : `int` — feature sampling half-width (features sampled from `[-X, X]`, `X > 0`)\n",
    "- `K` : `int` — number of samples (`K > 0`)\n",
    "\n",
    "**Optional `params`**\n",
    "- `sigma2` : `float` — label noise variance (`sigma2 >= 0`, default `0.0`)\n",
    "- `seed` : `int` or `None` — RNG seed (default `None`)\n",
    "\n",
    "---\n",
    "\n",
    "`kind=\"netgen\"`\n",
    "\n",
    "**Required `params`**\n",
    "- `n` : `int` — number of nodes (`n > 0`)\n",
    "- `family` : `str` — graph family (e.g., `\"sparse\"` or `\"dense\"`)\n",
    "- `tilde_C` : `int` — NETGEN generation parameter (`tilde_C > 0`)\n",
    "- `K` : `int` — number of samples (`K > 0`)\n",
    "\n",
    "**Optional `params`**\n",
    "- `sigma2` : `float` — label noise variance (`sigma2 >= 0`, default `0.0`)\n",
    "- `seed` : `int` or `None` — RNG seed (default `None`)\n",
    "- `cap_range` : `tuple[int, int]` — capacity bounds `(min_cap, max_cap)` with `min_cap <= max_cap` (default `(1, 1000)`)\n",
    "- `cost_range` : `tuple[int, int]` — cost bounds `(min_cost, max_cost)` with `min_cost <= max_cost` (default `(1, 10000)`)\n",
    "- `total_supply` : `int` — total supply injected into the network (`total_supply >= 0`, default `1000`)\n",
    "- `verbose` : `bool` — print generation details (default `False`)\n",
    "\n",
    "---\n",
    "\n",
    "**Output**\n",
    "Returns a 2-tuple:\n",
    "\n",
    "- `X` : `np.ndarray` — feature matrix containing `K` samples  \n",
    "  (typically shape `(K, d)` for `\"quadratic\"`; for `\"netgen\"` it depends on the generator’s feature representation)\n",
    "- `y` : `np.ndarray` shape `(K,)` — labels/targets for each sample\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82fc65d8-218d-423c-96c4-e67580f32688",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_dataset(kind, **p):\n",
    "    if kind == \"quadratic\":\n",
    "        d = generate_synthetic_convex_quadratic_dataset(\n",
    "            d=p[\"d\"], m=p[\"m\"], M=p[\"M\"], X=p[\"X\"],\n",
    "            K=p[\"K\"], sigma2=p.get(\"sigma2\", 0.0), seed=p.get(\"seed\", None),\n",
    "        )\n",
    "    elif kind == \"netgen\":\n",
    "        d = generate_netgen_mcf_dataset(\n",
    "            n=p[\"n\"], family=p[\"family\"], tilde_C=p[\"tilde_C\"],\n",
    "            K=p[\"K\"], sigma2=p.get(\"sigma2\", 0.0), seed=p.get(\"seed\", None),\n",
    "            cap_range=p.get(\"cap_range\", (1, 1000)),\n",
    "            cost_range=p.get(\"cost_range\", (1, 10000)),\n",
    "            total_supply=p.get(\"total_supply\", 1000),\n",
    "            verbose=p.get(\"verbose\", False),\n",
    "        )\n",
    "    else:\n",
    "        raise ValueError(\"kind must be 'quadratic' or 'netgen'\")\n",
    "    return d[\"X_samples\"], d[\"y\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1cdd06e-81d1-4f27-a39b-81d540a21241",
   "metadata": {},
   "source": [
    "# Models Definition"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b131da93-331f-46ef-91e4-1ced802f0e27",
   "metadata": {},
   "source": [
    "### DFN"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0d617ce9-24ab-4015-9fd0-6bc983d61fab",
   "metadata": {},
   "source": [
    "#### Helper Functions and Modules"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "07be67a0-5e8e-4e5f-9e2d-5acda3ab83be",
   "metadata": {},
   "source": [
    "STE op:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "0d7ccba7-9593-421d-a6ef-778b9dd9c437",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "def round_param(raw: torch.Tensor, step: float = 1.0, nonneg: bool = False, eps: float = 0.0) -> torch.Tensor:\n",
    "    x = F.softplus(raw) + eps if nonneg else raw\n",
    "    y = torch.round(x / step) * step\n",
    "    return x + (y - x).detach()"
   ]
  },
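  {
   "cell_type": "markdown",
   "id": "ste-rounding-sketch",
   "metadata": {},
   "source": [
    "A minimal sketch of the straight-through behaviour of `round_param`, written out by hand on a toy tensor. The values in `raw` are hypothetical and chosen only for illustration; the arithmetic mirrors the `nonneg=False` branch above.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# Hypothetical raw parameter values, for illustration only.\n",
    "raw = torch.tensor([0.2, 1.4, 2.6], requires_grad=True)\n",
    "\n",
    "# Forward pass: snap to the nearest multiple of step.\n",
    "step = 1.0\n",
    "rounded = torch.round(raw / step) * step\n",
    "ste = raw + (rounded - raw).detach()\n",
    "\n",
    "# Backward pass: the detached term carries no gradient,\n",
    "# so the derivative of ste with respect to raw is 1 everywhere.\n",
    "ste.sum().backward()\n",
    "\n",
    "print(ste.detach())  # rounded values: 0, 1, 3\n",
    "print(raw.grad)      # all ones\n",
    "```\n",
    "\n",
    "Without the straight-through term, `torch.round` alone would give a zero gradient almost everywhere, and the raw parameters would not receive any training signal."
   ]
  },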
  {
   "cell_type": "markdown",
   "id": "ab29e386-ed1c-4bff-aae5-830ec56acfa2",
   "metadata": {},
   "source": [
    "Min Cost Flow op:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2da3d4c2-7972-4d9e-8167-722a22f7a409",
   "metadata": {},
   "source": [
    "**Signature**  \n",
    "    z = min_cost_flow_value(n_nodes, src, dst, cost, cap, supply, solver=\"lemon\")\n",
    "\n",
    "**Inputs**\n",
    "- `n_nodes` : `int` — number of nodes (indexed `0..n_nodes-1`)\n",
    "- `src`, `dst` : `torch.Tensor` shape `(m,)`, integer/long — arc endpoints `src[e] -> dst[e]`\n",
    "- `cost` : `torch.Tensor` shape `(m,)`, float — per-arc cost\n",
    "- `cap` : `torch.Tensor` shape `(m,)`, float — per-arc capacity (upper bound)\n",
    "- `supply` : `torch.Tensor` shape `(n_nodes,)`, float — node flow balance (`>0` supply, `<0` demand)\n",
    "- `solver` : `str` — `\"lemon\"` or `\"gurobi\"`\n",
    "\n",
    "**Output**\n",
    "- `z` : `torch.Tensor` scalar — optimal objective value of the min-cost flow\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "980806c5-11ab-4527-b164-a88981fca9ff",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "_TOL = 1e-9\n",
    "\n",
    "def _solve_lemon(n, src, dst, cost, cap, supply):\n",
    "    if lemon_mcf is None:\n",
    "        raise RuntimeError(\"LEMON requested but lemon_mcf is not available.\")\n",
    "    out = lemon_mcf.solve_mcf(n, src, dst, cost, cap, supply, tol=_TOL)\n",
    "    if out[\"status\"] != 1:\n",
    "        raise RuntimeError(f\"LEMON failed (status={out['status']})\")\n",
    "    flow = out[\"flow\"]\n",
    "    pot  = out[\"potential\"]\n",
    "    red  = out[\"reduced_cost\"]\n",
    "    at   = out.get(\"at_cap\", out.get(\"at_capacity\", None))\n",
    "    if at is None:\n",
    "        at = np.abs(flow - cap) <= _TOL\n",
    "    return float(out[\"total_cost\"]), flow, pot, red, at\n",
    "\n",
    "\n",
    "def _solve_gurobi(n, src, dst, cost, cap, supply):\n",
    "    try:\n",
    "        import gurobipy as gp\n",
    "        from gurobipy import GRB\n",
    "    except Exception as e:\n",
    "        raise ImportError(\"gurobipy is required (with a valid license).\") from e\n",
    "\n",
    "    m = int(src.size)\n",
    "    out_idx = [[] for _ in range(n)]\n",
    "    in_idx  = [[] for _ in range(n)]\n",
    "    for k in range(m):\n",
    "        out_idx[int(src[k])].append(k)\n",
    "        in_idx[int(dst[k])].append(k)\n",
    "\n",
    "    model = gp.Model()\n",
    "    model.Params.OutputFlag = 0\n",
    "    model.Params.LogToConsole = 0\n",
    "\n",
    "    x = model.addVars(m, lb=0.0, ub=cap.tolist(), obj=cost.tolist(), vtype=GRB.CONTINUOUS, name=\"x\")\n",
    "    model.ModelSense = GRB.MINIMIZE\n",
    "\n",
    "    bal = []\n",
    "    for i in range(n):\n",
    "        bal.append(model.addConstr(\n",
    "            gp.quicksum(x[k] for k in out_idx[i]) - gp.quicksum(x[k] for k in in_idx[i]) == float(supply[i])\n",
    "        ))\n",
    "\n",
    "    model.optimize()\n",
    "    if model.Status != GRB.OPTIMAL:\n",
    "        raise RuntimeError(f\"Gurobi failed (status={model.Status})\")\n",
    "\n",
    "    flow = np.fromiter((x[k].X  for k in range(m)), dtype=np.float64, count=m)\n",
    "    red  = np.fromiter((x[k].RC for k in range(m)), dtype=np.float64, count=m)\n",
    "    pot  = -np.fromiter((bal[i].Pi for i in range(n)), dtype=np.float64, count=n)  # match LEMON convention\n",
    "    at   = np.abs(flow - cap) <= _TOL\n",
    "    return float(model.ObjVal), flow, pot, red, at\n",
    "\n",
    "\n",
    "class _MCFValue(torch.autograd.Function):\n",
    "    @staticmethod\n",
    "    def forward(ctx, n_nodes, src, dst, cost, cap, supply, solver=\"lemon\"):\n",
    "        n = int(n_nodes)\n",
    "        src = src.long(); dst = dst.long()\n",
    "        m = int(src.numel())\n",
    "        if dst.numel() != m or cost.numel() != m or cap.numel() != m or supply.numel() != n:\n",
    "            raise ValueError(\"bad shapes\")\n",
    "\n",
    "        src_np = src.detach().cpu().contiguous().reshape(-1).numpy().astype(np.int64, copy=False)\n",
    "        dst_np = dst.detach().cpu().contiguous().reshape(-1).numpy().astype(np.int64, copy=False)\n",
    "        c_np   = cost.detach().cpu().contiguous().reshape(-1).numpy().astype(np.float64, copy=False)\n",
    "        u_np   = cap.detach().cpu().contiguous().reshape(-1).numpy().astype(np.float64, copy=False)\n",
    "        b_np   = supply.detach().cpu().contiguous().reshape(-1).numpy().astype(np.float64, copy=False)\n",
    "\n",
    "        if abs(float(b_np.sum())) > _TOL:\n",
    "            raise ValueError(f\"require sum(supply)=0 (got {float(b_np.sum()):.3e})\")\n",
    "        if m and ((src_np < 0).any() or (src_np >= n).any() or (dst_np < 0).any() or (dst_np >= n).any()):\n",
    "            raise ValueError(\"src/dst out of range\")\n",
    "        if m and (u_np < 0).any():\n",
    "            raise ValueError(\"cap must be nonnegative\")\n",
    "\n",
    "        s = str(solver).lower()\n",
    "        if s == \"lemon\":\n",
    "            obj, flow, pot, red, at = _solve_lemon(n, src_np, dst_np, c_np, u_np, b_np)\n",
    "        elif s == \"gurobi\":\n",
    "            obj, flow, pot, red, at = _solve_gurobi(n, src_np, dst_np, c_np, u_np, b_np)\n",
    "        else:\n",
    "            raise ValueError(f\"unknown solver: {solver}\")\n",
    "\n",
    "        ctx.flow, ctx.pot, ctx.red, ctx.at = flow, pot, red, at\n",
    "        return cost.new_tensor(obj)\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, g):\n",
    "        dev, dt = g.device, g.dtype\n",
    "        flow = torch.as_tensor(ctx.flow, device=dev, dtype=dt)\n",
    "        pot  = torch.as_tensor(ctx.pot,  device=dev, dtype=dt)\n",
    "        red  = torch.as_tensor(ctx.red,  device=dev, dtype=dt)\n",
    "        at   = torch.as_tensor(ctx.at,   device=dev, dtype=torch.bool)\n",
    "\n",
    "        grad_cost = flow\n",
    "        grad_cap  = torch.where(at, red, torch.zeros_like(red))\n",
    "        grad_sup  = pot.mean() - pot  # gauge-fixed (-dual) gradient\n",
    "\n",
    "        return None, None, None, grad_cost * g, grad_cap * g, grad_sup * g, None\n",
    "\n",
    "\n",
    "def min_cost_flow_value(n_nodes, src, dst, cost, cap, supply, solver=\"lemon\"):\n",
    "    return _MCFValue.apply(n_nodes, src, dst, cost, cap, supply, solver)\n"
   ]
  },
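  {
   "cell_type": "markdown",
   "id": "mcf-envelope-gradient-sketch",
   "metadata": {},
   "source": [
    "The backward pass above uses the optimal flow as the gradient of the objective with respect to the arc costs. The sketch below checks that relation by finite differences on a hypothetical three-node network. It assumes SciPy is installed and uses `scipy.optimize.linprog` as a stand-in solver; neither SciPy nor this toy instance is part of the notebook.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from scipy.optimize import linprog\n",
    "\n",
    "# Toy network (hypothetical): arcs 0->1, 1->2, 0->2.\n",
    "src = np.array([0, 1, 0])\n",
    "dst = np.array([1, 2, 2])\n",
    "cost = np.array([1.0, 1.0, 3.0])\n",
    "cap = np.array([2.0, 2.0, 5.0])\n",
    "supply = np.array([3.0, 0.0, -3.0])  # sums to zero\n",
    "\n",
    "def solve(cost_vec):\n",
    "    n, m = supply.size, src.size\n",
    "    # Node-arc incidence: +1 where an arc leaves a node, -1 where it enters.\n",
    "    A = np.zeros((n, m))\n",
    "    A[src, np.arange(m)] = 1.0\n",
    "    A[dst, np.arange(m)] = -1.0\n",
    "    res = linprog(cost_vec, A_eq=A, b_eq=supply,\n",
    "                  bounds=list(zip(np.zeros(m), cap)), method='highs')\n",
    "    assert res.status == 0\n",
    "    return res.fun, res.x\n",
    "\n",
    "z, flow = solve(cost)\n",
    "\n",
    "# Finite-difference estimate of dz/dcost[e] for each arc e.\n",
    "h = 1e-4\n",
    "fd = np.array([(solve(cost + h * np.eye(3)[e])[0] - z) / h for e in range(3)])\n",
    "\n",
    "print(z)     # 7.0 for this instance\n",
    "print(flow)  # optimal arc flows\n",
    "print(fd)    # should be close to flow\n",
    "```\n",
    "\n",
    "The check holds here because the optimal flow is unique. At degenerate instances the objective is only piecewise linear in the costs, and the flow returned by the solver is one valid subgradient."
   ]
  },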
  {
   "cell_type": "markdown",
   "id": "49e1537e-0608-496b-8cfd-7fd4ba19a910",
   "metadata": {},
   "source": [
    "Network Class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "90ebc084-4dde-4fb5-b606-f1190431ac95",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "@dataclass(frozen=True)\n",
    "class Network:\n",
    "    n: int\n",
    "    src_param: torch.Tensor\n",
    "    dst_param: torch.Tensor\n",
    "    ax_nodes: torch.Tensor\n",
    "    b_nodes: torch.Tensor\n",
    "    fix_node: int\n",
    "    src_fixed: torch.Tensor = field(default_factory=lambda: torch.empty(0, dtype=torch.long))\n",
    "    dst_fixed: torch.Tensor = field(default_factory=lambda: torch.empty(0, dtype=torch.long))\n",
    "    cap_fixed: torch.Tensor = field(default_factory=lambda: torch.empty(0, dtype=torch.float32))\n",
    "    cost_fixed: torch.Tensor = field(default_factory=lambda: torch.empty(0, dtype=torch.float32))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed2502cb-6f96-4e64-b75a-aa085e94a1dd",
   "metadata": {},
   "source": [
    "Multi-Layered Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "39851bc6-86be-45cd-8d14-f121dda84f9c",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def build_layered_network(\n",
    "    layer_sizes: list[int],\n",
    "    n_ax: int | None = None,\n",
    "    n_b: int | None = None,\n",
    "    big_cost: float = 1e3,\n",
    "    big_cap: float = 1e9,\n",
    "    p: float = 1.0,          # keep-prob for each internal (layered) arc\n",
    "    seed: int | None = None, # for reproducible sparsification\n",
    ") -> Network:\n",
    "    if not (0.0 <= float(p) <= 1.0):\n",
    "        raise ValueError(\"p must be in [0,1].\")\n",
    "    gen = None\n",
    "    if seed is not None:\n",
    "        gen = torch.Generator(device=\"cpu\").manual_seed(int(seed))\n",
    "\n",
    "    layer_sizes = list(map(int, layer_sizes))\n",
    "    K = len(layer_sizes)\n",
    "    if K < 2: raise ValueError(\"need at least 2 layers\")\n",
    "    n = sum(layer_sizes)\n",
    "    if n <= 0: raise ValueError(\"sum(layer_sizes) must be > 0\")\n",
    "\n",
    "    layers, off = [], 0\n",
    "    for s in layer_sizes:\n",
    "        if s < 0: raise ValueError(\"layer_sizes must be nonnegative\")\n",
    "        layers.append(torch.arange(off, off + s, dtype=torch.long))\n",
    "        off += s\n",
    "    L1, LK = layers[0], layers[-1]\n",
    "    s1, sK = int(L1.numel()), int(LK.numel())\n",
    "    if s1 == 0 or sK == 0: raise ValueError(\"L1 and LK must be non-empty\")\n",
    "\n",
    "    # boundary split: default half/half inside each of L1 and LK\n",
    "    boundary = s1 + sK\n",
    "    if n_ax is None and n_b is None:\n",
    "        v1, vK = s1 // 2, sK // 2\n",
    "        c1, cK = s1 - v1, sK - vK\n",
    "    else:\n",
    "        n_ax = boundary - int(n_b) if n_ax is None else int(n_ax)\n",
    "        n_b  = boundary - int(n_ax) if n_b  is None else int(n_b)\n",
    "        if n_ax < 0 or n_b < 0 or n_ax + n_b > boundary: raise ValueError(\"bad n_ax/n_b\")\n",
    "        v1 = min(n_ax, s1 // 2); vK = n_ax - v1\n",
    "        c1 = min(n_b,  s1 - v1); cK = n_b  - c1\n",
    "        if vK > sK or vK + cK > sK: raise ValueError(\"n_ax/n_b too large for LK\")\n",
    "\n",
    "    ax_nodes = torch.cat([L1[:v1], LK[:vK]], 0)\n",
    "    b_nodes  = torch.cat([L1[v1:v1+c1], LK[vK:vK+cK]], 0)\n",
    "    fix_node = int(LK[-1].item())\n",
    "\n",
    "    def bip(U, V):\n",
    "        su, sv = int(U.numel()), int(V.numel())\n",
    "        return U.repeat_interleave(sv), V.repeat(su)\n",
    "\n",
    "    # learnable layered arcs Li -> L(i+1), sparsified with keep-prob p\n",
    "    srcL, dstL = [], []\n",
    "    for i in range(K - 1):\n",
    "        s, t = bip(layers[i], layers[i + 1])\n",
    "        if p < 1.0:\n",
    "            keep = (torch.rand(s.numel(), generator=gen) < p)\n",
    "            s, t = s[keep], t[keep]\n",
    "        srcL.append(s); dstL.append(t)\n",
    "\n",
    "    src_param = torch.cat(srcL, 0)\n",
    "    dst_param = torch.cat(dstL, 0)\n",
    "    if src_param.numel() == 0:\n",
    "        raise ValueError(\"no layered arcs after sparsification (increase p or layer sizes)\")\n",
    "\n",
    "    # fixed big arcs L1 <-> LK (still full)\n",
    "    s1a, t1a = bip(L1, LK)\n",
    "    s2a, t2a = bip(LK, L1)\n",
    "    src_fixed = torch.cat([s1a, s2a], 0)\n",
    "    dst_fixed = torch.cat([t1a, t2a], 0)\n",
    "    m = int(src_fixed.numel())\n",
    "    cap_fixed  = torch.full((m,), float(big_cap),  dtype=torch.float32)\n",
    "    cost_fixed = torch.full((m,), float(big_cost), dtype=torch.float32)\n",
    "\n",
    "    return Network(\n",
    "        n=n, src_param=src_param, dst_param=dst_param,\n",
    "        ax_nodes=ax_nodes, b_nodes=b_nodes, fix_node=fix_node,\n",
    "        src_fixed=src_fixed, dst_fixed=dst_fixed,\n",
    "        cap_fixed=cap_fixed, cost_fixed=cost_fixed\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58780994-c27b-4dd9-ba2e-865cb4b95c60",
   "metadata": {},
   "source": [
    "Complete Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "19f2419e-da64-4739-b344-fc8eb115ad0a",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "\n",
    "def build_complete_doubled_network(\n",
    "    n: int,\n",
    "    n_ax: int | None = None,\n",
    "    n_b: int | None = None,\n",
    "    *,\n",
    "    fix_node: int | None = None,\n",
    "    big_cost: float = 1e3,\n",
    "    big_cap: float = 1e9,\n",
    ") -> Network:\n",
    "    n = int(n)\n",
    "    if n < 2: raise ValueError(\"n must be >= 2\")\n",
    "\n",
    "    k = int(math.isqrt(n))\n",
    "    n_ax = k if n_ax is None else int(n_ax)\n",
    "    n_b  = k if n_b  is None else int(n_b)\n",
    "\n",
    "    fix_node = (n - 1) if fix_node is None else int(fix_node)\n",
    "    if not (0 <= fix_node < n): raise ValueError(\"fix_node out of range\")\n",
    "    if n_ax < 0 or n_b < 0 or (n_ax + n_b) > (n - 1):\n",
    "        raise ValueError(\"need 0<=n_ax,n_b and n_ax+n_b <= n-1 (excluding fix_node)\")\n",
    "\n",
    "    # complete directed arcs i->j, i!=j\n",
    "    v = torch.arange(n, dtype=torch.long)\n",
    "    src = v.repeat_interleave(n)\n",
    "    dst = v.repeat(n)\n",
    "    mask = (src != dst)\n",
    "    src = src[mask]; dst = dst[mask]               # m = n*(n-1)\n",
    "\n",
    "    # learnable copy\n",
    "    src_param, dst_param = src, dst\n",
    "\n",
    "    # fixed big copy (parallel arcs)\n",
    "    src_fixed, dst_fixed = src.clone(), dst.clone()\n",
    "    m = int(src_fixed.numel())\n",
    "    cap_fixed  = torch.full((m,), float(big_cap),  dtype=torch.float32)\n",
    "    cost_fixed = torch.full((m,), float(big_cost), dtype=torch.float32)\n",
    "\n",
    "    # choose ax/b nodes deterministically in order (excluding fix_node)\n",
    "    nodes = torch.arange(n, dtype=torch.long)\n",
    "    avail = nodes[nodes != fix_node]\n",
    "    ax_nodes = avail[:n_ax].clone()\n",
    "    b_nodes  = avail[n_ax:n_ax + n_b].clone()\n",
    "\n",
    "    return Network(\n",
    "        n=n,\n",
    "        src_param=src_param, dst_param=dst_param,\n",
    "        ax_nodes=ax_nodes, b_nodes=b_nodes,\n",
    "        fix_node=fix_node,\n",
    "        src_fixed=src_fixed, dst_fixed=dst_fixed,\n",
    "        cap_fixed=cap_fixed, cost_fixed=cost_fixed,\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b8c13547-8c12-4413-a9d3-368d015cd778",
   "metadata": {},
   "source": [
    "#### The DFN Module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "2bd3916f-ee07-4ef5-8b4e-05ba9438ee93",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "class DFN(nn.Module):\n",
    "    def __init__(self, net: Network, input_dim: int, solver: str = \"lemon\"):\n",
    "        super().__init__()\n",
    "        self.n = int(net.n)\n",
    "        self.fix_node = int(net.fix_node)\n",
    "        self.solver = str(solver)\n",
    "\n",
    "        # store network data inside the module (so .to(device) works, and checkpoints keep the graph)\n",
    "        self.register_buffer(\"src_param\", net.src_param.long())\n",
    "        self.register_buffer(\"dst_param\", net.dst_param.long())\n",
    "        self.register_buffer(\"src_fixed\", net.src_fixed.long())\n",
    "        self.register_buffer(\"dst_fixed\", net.dst_fixed.long())\n",
    "        self.register_buffer(\"cap_fixed\", net.cap_fixed.float())\n",
    "        self.register_buffer(\"cost_fixed\", net.cost_fixed.float())\n",
    "        self.register_buffer(\"ax_nodes\",  net.ax_nodes.long())\n",
    "        self.register_buffer(\"b_nodes\",   net.b_nodes.long())\n",
    "\n",
    "        m_param = int(self.src_param.numel())\n",
    "        n_ax    = int(self.ax_nodes.numel())\n",
    "        n_b     = int(self.b_nodes.numel())\n",
    "\n",
    "        self.cap_raw  = nn.Parameter(torch.randn(m_param) + 1)\n",
    "        self.cost_raw = nn.Parameter(torch.randn(m_param) + 1)\n",
    "        self.A_raw    = nn.Parameter(torch.randn(n_ax, input_dim) + 1)\n",
    "        self.b_raw    = nn.Parameter(torch.randn(n_b) + 1)\n",
    "\n",
    "    def forward(self, x: torch.Tensor, alpha: float = 1.0, beta: float = 0.0) -> torch.Tensor:\n",
    "        capP  = round_param(self.cap_raw,  nonneg=True)\n",
    "        costP = round_param(self.cost_raw, nonneg=True)\n",
    "        A     = round_param(self.A_raw,    nonneg=False)\n",
    "        b0    = round_param(self.b_raw,    nonneg=False)\n",
    "\n",
    "        def solve_one(x1: torch.Tensor):\n",
    "            src  = torch.cat([self.src_param, self.src_fixed], 0)\n",
    "            dst  = torch.cat([self.dst_param, self.dst_fixed], 0)\n",
    "            cap  = torch.cat([capP,  self.cap_fixed.to(x1.device, x1.dtype)], 0)\n",
    "            cost = torch.cat([costP, self.cost_fixed.to(x1.device, x1.dtype)], 0)\n",
    "\n",
    "            supply = torch.zeros(self.n, device=x1.device, dtype=torch.float64)\n",
    "            supply[self.ax_nodes] = (A.double() @ x1.double())\n",
    "            supply[self.b_nodes] = b0.double()\n",
    "            supply[self.fix_node] -= supply.sum()\n",
    "\n",
    "            obj = min_cost_flow_value(self.n, src, dst, cost, cap, supply, solver=self.solver)\n",
    "            return alpha * obj + beta\n",
    "\n",
    "        return solve_one(x) if x.dim() == 1 else torch.stack([solve_one(xi) for xi in x], 0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72c356e0-d483-442b-a32b-49f0d7eab840",
   "metadata": {},
   "source": [
    "### MLP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "d547b665-8251-45cb-8d9a-ffef58911114",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(nn.Module):\n",
    "    def __init__(self, in_dim, hidden_dim, out_dim):\n",
    "        super().__init__()\n",
    "        self.net = nn.Sequential(\n",
    "            nn.Linear(in_dim, hidden_dim),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(hidden_dim, out_dim),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.net(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e74105b4-7d14-4b4d-9261-160fc4b44827",
   "metadata": {},
   "source": [
    "### Max-Affine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "0fbc2d7a-4792-47a8-ab0c-6ff991f5492a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MaxAffine(nn.Module):\n",
    "    def __init__(self, in_dim, out_dim, n_pieces):\n",
    "        super().__init__()\n",
    "        self.W = nn.Parameter(torch.randn(n_pieces, out_dim, in_dim) / (in_dim ** 0.5))\n",
    "        self.b = nn.Parameter(torch.zeros(n_pieces, out_dim))\n",
    "\n",
    "    def forward(self, x):\n",
    "        # x: (batch, in_dim)\n",
    "        y = torch.einsum(\"bi,koi->bko\", x, self.W) + self.b  # (batch, n_pieces, out_dim)\n",
    "        return y.max(dim=1).values  # (batch, out_dim)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4b30023b-4802-4f90-b074-026359cf9623",
   "metadata": {},
   "source": [
    "### LogSumExp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d2e2022-711d-4a9e-bdf7-4ef803bac1f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LSET(nn.Module):\n",
    "    def __init__(self, in_dim, n_pieces, T=0.01, eps=1e-8):\n",
    "        super().__init__()\n",
    "        self.A = nn.Parameter(torch.randn(n_pieces, in_dim) / (in_dim ** 0.5))\n",
    "        self.b = nn.Parameter(torch.zeros(n_pieces))\n",
    "        self.T = float(T)   # keep fixed for simplicity/stability\n",
    "        self.eps = eps\n",
    "\n",
    "    def forward(self, x):\n",
    "        T = self.T + self.eps\n",
    "        z = (x @ self.A.t() + self.b) / T          # (B, K)\n",
    "        return T * torch.logsumexp(z, dim=-1)      # (B,)"
   ]
  },
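  {
   "cell_type": "markdown",
   "id": "lse-vs-max-sketch",
   "metadata": {},
   "source": [
    "`LSET` can be read as a smoothed version of a scalar max-affine model: for `K` pieces, `T * logsumexp(z / T)` exceeds `max(z)` by at most `T * log(K)`. The sketch below illustrates that bound on random scores. The tensor `scores` is a hypothetical stand-in for `x @ A.T + b`, and the sizes are arbitrary.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "K = 8                        # number of affine pieces (hypothetical)\n",
    "scores = torch.randn(5, K)   # stand-in for x @ A.T + b, shape (B, K)\n",
    "\n",
    "hard_max = scores.max(dim=-1).values\n",
    "for T in (1.0, 0.1, 0.01):\n",
    "    soft = T * torch.logsumexp(scores / T, dim=-1)\n",
    "    gap = (soft - hard_max).max().item()\n",
    "    # In exact arithmetic, 0 <= gap <= T * log(K).\n",
    "    print(T, gap, T * math.log(K))\n",
    "```\n",
    "\n",
    "With the default `T=0.01`, the model is therefore very close to a hard maximum over its pieces, while still being differentiable everywhere."
   ]
  },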
  {
   "cell_type": "markdown",
   "id": "e8a3d912-fa0a-4acc-a12f-10a089f213de",
   "metadata": {},
   "source": [
    "# Training on Data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "223fa05c-6713-4c67-9625-1a4a95131c3a",
   "metadata": {},
   "source": [
    "### General Trainer wrapper"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "489b185f-6c77-4c9a-b49a-8c3664cc477e",
   "metadata": {},
   "source": [
    "**Signature**\n",
    "    model, hist = fit_best(model, X, y, epochs=50, bs=32, lr=1e-3, val_frac=0.2, seed=0, logdir=\"runs/exp\", device=None)\n",
    "\n",
    "**Inputs**\n",
    "- `model` : `torch.nn.Module` — any PyTorch model such that `model(x_batch)` returns predictions compatible with `y`\n",
    "- `X` : `np.ndarray` or `torch.Tensor` — feature matrix of shape `(N, d)`\n",
    "- `y` : `np.ndarray` or `torch.Tensor` — targets of shape `(N,)` or `(N, 1)`\n",
    "- `epochs` : `int` — number of training epochs (`epochs > 0`, default `50`)\n",
    "- `bs` : `int` — batch size (`bs > 0`, default `32`)\n",
    "- `lr` : `float` — Adam learning rate (`lr > 0`, default `1e-3`)\n",
    "- `val_frac` : `float` — validation split fraction in `(0, 1)` (default `0.2`)\n",
    "  - validation is assumed active; validation set size is `max(1, int(N * val_frac))`\n",
    "- `seed` : `int` — RNG seed used for the train/val split and shuffling (default `0`)\n",
    "- `logdir` : `str` — TensorBoard log directory (default `\"runs/exp\"`)\n",
    "  - logs live scalars/histograms under this run directory\n",
    "- `device` : `str` or `None` — device to train on (`\"cpu\"` or `\"cuda\"`)\n",
    "  - if `None`, uses `\"cuda\"` if available else `\"cpu\"`\n",
    "\n",
    "**Outputs**\n",
    "Returns a 2-tuple:\n",
    "\n",
    "1) `model` : `torch.nn.Module`\n",
    "- The trained model restored to the **best validation** epoch (lowest validation loss)\n",
    "\n",
    "2) `hist` : `list[dict]` — training history (one dict per epoch)\n",
    "Each element contains:\n",
    "- `train_loss` : `float` — mean train MSE for the epoch\n",
    "- `val_loss` : `float` — mean validation MSE for the epoch\n",
    "- `grad_norm` : `float` — average global gradient norm for the epoch\n",
    "- `param_norm` : `float` — global parameter L2 norm\n",
    "- `param_update_norm` : `float` — L2 norm of parameter change vs previous epoch\n",
    "- `samples_seen` : `int` — cumulative number of training samples processed\n",
    "- `time_s` : `float` — elapsed wall-clock time (seconds) since training start\n",
    "- `samples_per_s` : `float` — throughput, `samples_seen / time_s`\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "45b09aec-5365-4502-841a-153b49e49197",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time, copy, torch\n",
    "from torch.utils.data import TensorDataset, DataLoader, random_split\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "def fit_best(model, X, y, *, epochs=50, bs=32, lr=1e-3, val_frac=0.2, seed=0,\n",
    "             logdir=\"runs/exp\", device=None):\n",
    "    X = torch.as_tensor(X, dtype=torch.float32)\n",
    "    y = torch.as_tensor(y, dtype=torch.float32).view(-1)\n",
    "\n",
    "    g = torch.Generator().manual_seed(seed)\n",
    "    n_val = max(1, int(len(y) * val_frac))\n",
    "    tr_ds, va_ds = random_split(TensorDataset(X, y), [len(y) - n_val, n_val], generator=g)\n",
    "    tr = DataLoader(tr_ds, batch_size=bs, shuffle=True, generator=g)\n",
    "    va = DataLoader(va_ds, batch_size=bs, shuffle=False)\n",
    "\n",
    "    device = device or (\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    model = model.to(device)\n",
    "    opt = torch.optim.Adam(model.parameters(), lr=lr)\n",
    "    mse = torch.nn.MSELoss()\n",
    "    w = SummaryWriter(logdir)\n",
    "\n",
    "    hist, best_val, best_state = [], float(\"inf\"), None\n",
    "    prev = [p.detach().clone() for p in model.parameters()]\n",
    "    t0, seen = time.perf_counter(), 0\n",
    "\n",
    "    for ep in range(epochs):\n",
    "        model.train()\n",
    "        tot = cnt = gtot = steps = 0.0\n",
    "\n",
    "        for xb, yb in tr:\n",
    "            xb, yb = xb.to(device), yb.to(device)\n",
    "            opt.zero_grad(set_to_none=True)\n",
    "            loss = mse(model(xb).view(-1), yb)\n",
    "            loss.backward()\n",
    "\n",
    "            gn2 = 0.0\n",
    "            for p in model.parameters():\n",
    "                if p.grad is not None:\n",
    "                    gn2 += float(p.grad.detach().float().norm()) ** 2\n",
    "            gtot += gn2 ** 0.5\n",
    "\n",
    "            opt.step()\n",
    "            b = xb.size(0)\n",
    "            tot += float(loss) * b\n",
    "            cnt += b\n",
    "            seen += b\n",
    "            steps += 1\n",
    "\n",
    "        tr_loss = tot / cnt\n",
    "        gnorm = gtot / steps\n",
    "\n",
    "        model.eval()\n",
    "        vtot = vcnt = 0.0\n",
    "        with torch.no_grad():\n",
    "            for xb, yb in va:\n",
    "                xb, yb = xb.to(device), yb.to(device)\n",
    "                v = mse(model(xb).view(-1), yb)\n",
    "                vtot += float(v) * xb.size(0)\n",
    "                vcnt += xb.size(0)\n",
    "        va_loss = vtot / vcnt\n",
    "        if va_loss < best_val:\n",
    "            best_val, best_state = va_loss, copy.deepcopy(model.state_dict())\n",
    "\n",
    "        pn2 = 0.0\n",
    "        for p in model.parameters():\n",
    "            pn2 += float(p.detach().float().norm()) ** 2\n",
    "        pn = pn2 ** 0.5\n",
    "\n",
    "        upd2 = 0.0\n",
    "        for p, pp in zip(model.parameters(), prev):\n",
    "            upd2 += float((p.detach() - pp).float().norm()) ** 2\n",
    "        upd = upd2 ** 0.5\n",
    "        prev = [p.detach().clone() for p in model.parameters()]\n",
    "\n",
    "        t = time.perf_counter() - t0\n",
    "        sps = seen / max(1e-9, t)\n",
    "\n",
    "        w.add_scalar(\"loss/train\", tr_loss, ep)\n",
    "        w.add_scalar(\"loss/val\", va_loss, ep)\n",
    "        w.add_scalar(\"grad/norm\", gnorm, ep)\n",
    "        w.add_scalar(\"params/norm\", pn, ep)\n",
    "        w.add_scalar(\"params/update_norm\", upd, ep)\n",
    "        w.add_scalar(\"perf/samples_per_s\", sps, ep)\n",
    "        w.add_scalar(\"perf/time_s\", t, ep)\n",
    "\n",
    "        xb, yb = next(iter(va))\n",
    "        with torch.no_grad():\n",
    "            yp = model(xb.to(device)).view(-1).cpu()\n",
    "        w.add_histogram(\"pred/y_pred\", yp, ep)\n",
    "        w.add_histogram(\"pred/y_true\", yb.view(-1), ep)\n",
    "\n",
    "        hist.append({\"train_loss\": tr_loss, \"val_loss\": va_loss, \"grad_norm\": gnorm,\n",
    "                     \"param_norm\": pn, \"param_update_norm\": upd,\n",
    "                     \"samples_seen\": seen, \"time_s\": t, \"samples_per_s\": sps})\n",
    "\n",
    "    model.load_state_dict(best_state)\n",
    "    w.close()\n",
    "    return model, hist\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0f8d3eb-ed08-4f34-a1d2-ee1188982386",
   "metadata": {},
   "source": [
    "### Generating Synthetic Quadtratic Dataset and Training Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e9f8d7a-b27a-4530-bc02-4ac75d6b8265",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "bddf11da-e597-457b-af6e-487bf02f3b3a",
   "metadata": {},
   "source": [
    "### Generating Synthetic NetGen Dataset and Training Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0dafff36-0c34-407d-a20b-b5ac31ed8915",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "609e32f6-7edd-40dd-9764-e06a592d0e73",
   "metadata": {},
   "source": [
    "# Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88a70431-932d-40db-8ad1-d21fd24d60ec",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65e58d01-c214-42bb-9115-0fac21645d85",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9072ad1e-0a4b-4141-afdb-945c53fbccb0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4ea6931-b097-40a3-9215-4c796d541011",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32ec96ce-1e0a-4fbc-afd9-a03bb6a38f40",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
