{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "96398989-280e-41b3-bc11-99b4b51e9439",
   "metadata": {},
   "source": [
    "# General Helpers"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "767d34c9-66bd-4986-a324-4daeec80aee1",
   "metadata": {},
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bef302bd-4ca0-4269-9ba1-ac6d05ef5819",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import sys\n",
    "from pathlib import Path\n",
    "import cppimport\n",
    "from __future__ import annotations\n",
    "from pathlib import Path\n",
    "from typing import Optional, Tuple, Dict, List\n",
    "import random\n",
    "import tempfile"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3cd5a10c-7768-4318-8244-beaf6c4c84d1",
   "metadata": {},
   "source": [
    "### Python Implementation of LEMON \n",
    "The following defines an importable Python module `lemon_mcf` which runs LEMON solvers (written in C++)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "b4658685-aab5-4c45-aa6d-36e862bbdc9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "repo = Path().resolve().parent\n",
    "sys.path.insert(0, str(repo))\n",
    "lemon_mcf = cppimport.imp(\"lemon_mcf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "800144cd-efc4-47e8-95d6-e75cad74419d",
   "metadata": {},
   "source": [
    "The module has the function `lemon_mcf.solve_mcf` which solves the minimum cost flow problem. The function can be described as follows:\n",
    "\n",
    "**Signature**\n",
    "    out = lemon_mcf.solve_mcf(n, src, dst, cost, lower, cap, supply, tol=1e-9)\n",
    "    \n",
    "**Inputs (NumPy arrays)**\n",
    "- `n` : `int` — number of nodes (indexed `0..n-1`)\n",
    "- `src`, `dst` : `np.ndarray` shape `(m,)`, dtype `int64` — directed edges `src[i] -> dst[i]`\n",
    "- `cost`, `lower`, `cap` : `np.ndarray` shape `(m,)`, dtype `float64` — per-edge cost and bounds\n",
    "- `supply` : `np.ndarray` shape `(n,)`, dtype `float64` — node supplies/demands (`>0` supply, `<0` demand)\n",
    "- `tol` : `float` — tolerance for bound-status flags\n",
    "\n",
    "**Output (`dict`)**\n",
    "- `out[\"status\"]` : `int` (typically `0` means OPTIMAL)\n",
    "- `out[\"flow\"]` : `np.ndarray` shape `(m,)`, dtype `float64`\n",
    "- `out[\"potential\"]` : `np.ndarray` shape `(n,)`, dtype `float64` (node potentials; defined up to an additive constant)\n",
    "- `out[\"reduced_cost\"]` : `np.ndarray` shape `(m,)`, dtype `float64`, computed as `cost[i] + potential[src[i]] - potential[dst[i]]`\n",
    "- (optional) bound-status flags: boolean arrays indicating whether each arc is at its lower bound or at its capacity\n",
    "- `out[\"total_cost\"]` : `float` (objective value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "1ff805be-adf0-44c9-a12c-d4d2fb077251",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'status': 1, 'flow': array([1., 2., 0.]), 'potential': array([-2.,  0., -1.]), 'reduced_cost': array([0., 0., 4.]), 'at_lower': array([False, False,  True]), 'at_upper': array([False,  True, False]), 'total_cost': 4.0}\n"
     ]
    }
   ],
   "source": [
    "# ----- Example of Usage ------ #\n",
    "\n",
    "n = 3\n",
    "src    = np.array([0, 0, 1], dtype=np.int64)\n",
    "dst    = np.array([1, 2, 2], dtype=np.int64)\n",
    "cost   = np.array([2.0, 1.0, 3.0], dtype=np.float64)\n",
    "cap    = np.array([5.0, 2.0, 4.0], dtype=np.float64)\n",
    "lower  = np.zeros_like(cost)\n",
    "supply = np.array([3.0, -1.0, -2.0], dtype=np.float64)\n",
    "\n",
    "out_min_cost_flow = lemon_mcf.solve_mcf(n, src, dst, cost, lower, cap, supply)\n",
    "print(out_min_cost_flow)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5edb874-0d6f-4c0a-a1c0-8707a04bd0c7",
   "metadata": {},
   "source": [
    "It also has the function `lemon_mcf.max_flow` which solves the maximum s-t flow problem. The function can be described as follows:\n",
    "\n",
    "**Signature**\n",
    "    out = lemon_mcf.max_flow(n, src, dst, cap, s, t)\n",
    "\n",
    "**Inputs (NumPy arrays)**\n",
    "- `n` : `int` — number of nodes (indexed `0..n-1`)\n",
    "- `src`, `dst` : `np.ndarray` shape `(m,)`, dtype `int64` — directed edges `src[i] -> dst[i]`\n",
    "- `cap` : `np.ndarray` shape `(m,)`, dtype `float64` — per-edge capacities (nonnegative)\n",
    "- `s` : `int` — source node index\n",
    "- `t` : `int` — sink node index\n",
    "\n",
    "**Output (`dict`)**\n",
    "- `out[\"value\"]` : `float` — maximum flow value from `s` to `t`\n",
    "- `out[\"flow\"]` : `np.ndarray` shape `(m,)`, dtype `float64` — per-edge flow values (in the same order as the input edges)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "84d7091d-b4dd-46db-84b2-e580dd826951",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'value': 4.0, 'flow': array([2., 2., 2., 2.])}\n"
     ]
    }
   ],
   "source": [
    "# ----- Example of Usage ------ #\n",
    "\n",
    "n = 4\n",
    "src = np.array([0,0,1,2], dtype=np.int64)\n",
    "dst = np.array([1,2,3,3], dtype=np.int64)\n",
    "cap = np.array([3.0,2.0,2.0,4.0], dtype=np.float64)\n",
    "out_max_flow = lemon_mcf.max_flow(n, src, dst, cap, 0, 3)\n",
    "print(out_max_flow)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2d023ed-4551-4a12-895b-5499538ccfe3",
   "metadata": {},
   "source": [
    "# Generating Datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c19ed1a-276c-4f7b-930b-15ee5dd1a087",
   "metadata": {},
   "source": [
    "## Synthetic Quadratic "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e668b358-356c-4b0f-8f52-3fa79709f650",
   "metadata": {},
   "source": [
    "#### Quadratic Instance Generator"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d51f4f1d-d52d-4167-8546-975f7056157a",
   "metadata": {},
   "source": [
    "The main function can be described as follows:\n",
    "\n",
    "**Signature**\n",
    "\n",
    "    data = generate_synthetic_convex_quadratic_dataset(d, m, M, X, sigma2, K, seed=None)\n",
    "\n",
    "**Inputs**\n",
    "- `d` : `int` — dimension (`d > 0`)\n",
    "- `m`, `M` : `float` — eigenvalue bounds (`0 < m <= M`)\n",
    "- `X` : `int` — integer box bound (samples are uniform over `[-X, X]^d`, coordinate-wise)\n",
    "- `sigma2` : `float` — noise variance (`sigma2 >= 0`)\n",
    "- `K` : `int` — number of samples (`K > 0`)\n",
    "- `seed` : `int` or `None` — RNG seed\n",
    "\n",
    "**Process**\n",
    "1. Sample eigenvalues `λ_1,...,λ_d ~ Uniform([m, M])` i.i.d.\n",
    "2. Sample an (approximately) Haar-uniform orthonormal matrix `U ∈ R^{d×d}` via QR of a Gaussian matrix.\n",
    "3. Form `Q = U diag(λ) U^T` (positive definite since `m > 0`).\n",
    "4. Sample an integer optimizer `x*` uniformly from the integer box `[-X, X]^d`.\n",
    "5. Sample `K` integer covariates `x^1,...,x^K` uniformly from `[-X, X]^d`.\n",
    "6. Sample i.i.d. noise `ε^k ~ N(0, sigma2)` for `k=1,...,K`.\n",
    "7. Set targets  \n",
    "   `y^k = (x^k - x*)^T Q (x^k - x*) + ε^k`.\n",
    "\n",
    "**Output (`dict`)**\n",
    "- `data[\"U\"]` : `np.ndarray` shape `(d, d)` — sampled orthonormal matrix\n",
    "- `data[\"lambdas\"]` : `np.ndarray` shape `(d,)` — sampled eigenvalues\n",
    "- `data[\"Q\"]` : `np.ndarray` shape `(d, d)` — sampled quadratic matrix `Q = U diag(lambdas) U^T`\n",
    "- `data[\"x_star\"]` : `np.ndarray` shape `(d,)`, dtype `int` — sampled optimizer `x*`\n",
    "- `data[\"X_samples\"]` : `np.ndarray` shape `(K, d)`, dtype `int` — sampled inputs `x^k`\n",
    "- `data[\"y\"]` : `np.ndarray` shape `(K,)`, dtype `float` — targets `y^k`\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "edd30065-8973-4a0f-8315-99a96ac42355",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _haar_orthonormal_matrix(d: int, rng: np.random.Generator) -> np.ndarray:\n",
    "    \"\"\"\n",
    "    Sample an (approximately) Haar-uniform random orthonormal matrix U in R^{dxd}\n",
    "    using QR decomposition of a standard Gaussian matrix, with a sign correction.\n",
    "    \"\"\"\n",
    "    A = rng.standard_normal((d, d))\n",
    "    U, R = np.linalg.qr(A)\n",
    "    s = np.sign(np.diag(R))\n",
    "    s[s == 0] = 1.0\n",
    "    U = U * s  # multiply columns by sign\n",
    "    return U\n",
    "\n",
    "def generate_synthetic_convex_quadratic_dataset(\n",
    "    d: int,\n",
    "    m: float,\n",
    "    M: float,\n",
    "    X: int,\n",
    "    sigma2: float,\n",
    "    K: int,\n",
    "    seed: int | None = None\n",
    "):\n",
    "    if d <= 0:\n",
    "        raise ValueError(\"d must be positive.\")\n",
    "    if not (M >= m):\n",
    "        raise ValueError(\"Require M >= m.\")\n",
    "    if m <= 0:\n",
    "        raise ValueError(\"Require m > 0 so that Q is positive definite.\")\n",
    "    if K <= 0:\n",
    "        raise ValueError(\"K must be positive.\")\n",
    "    if sigma2 < 0:\n",
    "        raise ValueError(\"sigma2 must be nonnegative.\")\n",
    "    if not isinstance(X, (int, np.integer)):\n",
    "        raise ValueError(\"X must be an integer.\")\n",
    "    X = int(X)\n",
    "    if X < 0:\n",
    "        raise ValueError(\"X must be a nonnegative integer.\")\n",
    "\n",
    "    rng = np.random.default_rng(seed)\n",
    "\n",
    "    # (i) sample eigenvalues\n",
    "    lambdas = rng.uniform(m, M, size=d)\n",
    "\n",
    "    # (ii) sample orthonormal matrix U\n",
    "    U = _haar_orthonormal_matrix(d, rng)\n",
    "\n",
    "    # Build Q = U diag(lambdas) U^T\n",
    "    Q = U @ np.diag(lambdas) @ U.T\n",
    "\n",
    "    # Sample integer vectors uniformly from [-X, X]^d\n",
    "    def sample_int_box(num: int) -> np.ndarray:\n",
    "        return rng.integers(-X, X + 1, size=(num, d), dtype=int)\n",
    "\n",
    "    # (iii) sample x*\n",
    "    x_star = sample_int_box(1).reshape(-1)  # shape (d,)\n",
    "\n",
    "    # (iv) sample K vectors x^1,...,x^K\n",
    "    X_samples = sample_int_box(K)  # shape (K, d)\n",
    "\n",
    "    # (v) sample Gaussian noises\n",
    "    sigma = float(np.sqrt(sigma2))\n",
    "    eps = rng.normal(loc=0.0, scale=sigma, size=K)\n",
    "\n",
    "    # (vi) compute y^k\n",
    "    diffs = X_samples - x_star  # shape (K, d)\n",
    "    quad = np.einsum(\"bi,ij,bj->b\", diffs, Q, diffs)  # shape (K,)\n",
    "    y = quad + eps\n",
    "\n",
    "    return {\n",
    "        \"U\": U,\n",
    "        \"lambdas\": lambdas,\n",
    "        \"Q\": Q,\n",
    "        \"x_star\": x_star,\n",
    "        \"X_samples\": X_samples,\n",
    "        \"y\": y,\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "666e79e4-8daa-483c-995d-6a3b92367b88",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x_star: [ 7  4  4 -2  8]\n",
      "X_samples shape: (100, 5)\n",
      "y shape: (100,)\n",
      "min eigenvalue(Q) ~ 0.5413190888213227\n"
     ]
    }
   ],
   "source": [
    "# --- Example of Usage --- #\n",
    "data = generate_synthetic_convex_quadratic_dataset(\n",
    "    d=5, m=0.5, M=3.0, X=10, sigma2=0.25, K=100, seed=0\n",
    ")\n",
    "\n",
    "print(\"x_star:\", data[\"x_star\"])\n",
    "print(\"X_samples shape:\", data[\"X_samples\"].shape)\n",
    "print(\"y shape:\", data[\"y\"].shape)\n",
    "print(\"min eigenvalue(Q) ~\", np.min(np.linalg.eigvalsh(data[\"Q\"])))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f403d314-b3f6-4933-b466-a6b3363839c7",
   "metadata": {},
   "source": [
    "#### Full Dataset Generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "302dbd89-36d9-45e5-a1a4-d34846a4a6fc",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "4c49e91f-881b-4020-8878-c22cf217c00c",
   "metadata": {},
   "source": [
    "## Synthetic Min Cost Flow"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f773417-3500-4b5a-b931-59eacf5b61c2",
   "metadata": {},
   "source": [
    "#### NETGEN Instance Generator"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6c6656d-e5ac-45fd-975d-398e4fbe472b",
   "metadata": {},
   "source": [
    "The main function can be described as follows:\n",
    "\n",
    "**Signature**\n",
    "\n",
    "    u, v, lo, cap, cost, b = netgen_instance_arrays(\n",
    "        nodes, arcs, sources, sinks, cost_bounds, cap_bounds,\n",
    "        total_supply=1000, seed=None\n",
    "    )\n",
    "\n",
    "**Inputs**\n",
    "- `nodes` : `int` — number of nodes (`nodes > 0`)\n",
    "- `arcs` : `int` — number of directed arcs/edges to generate (`arcs > 0`)\n",
    "- `sources` : `int` — number of supply (source) nodes (`sources >= 0`)\n",
    "- `sinks` : `int` — number of demand (sink) nodes (`sinks >= 0`)\n",
    "- `cost_bounds` : `tuple[int, int]` — arc cost bounds `(min_cost, max_cost)` with `min_cost <= max_cost`\n",
    "- `cap_bounds` : `tuple[int, int]` — arc capacity bounds `(min_cap, max_cap)` with `min_cap <= max_cap`\n",
    "- `total_supply` : `int` — total supply injected into the network (`total_supply >= 0`); total demand equals `-total_supply`\n",
    "- `seed` : `int` or `None` — RNG seed\n",
    "\n",
    "**Output (arrays)**\n",
    "Returns a 6-tuple:\n",
    "\n",
    "- `u` : `np.ndarray` shape `(m,)`, dtype `int` — tail (source) node index for each arc (0-based)\n",
    "- `v` : `np.ndarray` shape `(m,)`, dtype `int` — head (destination) node index for each arc (0-based)  \n",
    "  (arc `i` is `u[i] -> v[i]`)\n",
    "- `lo` : `np.ndarray` shape `(m,)`, dtype `int` — lower bound for each arc (0 in our instances)\n",
    "- `cap` : `np.ndarray` shape `(m,)`, dtype `int` — capacity for each arc\n",
    "- `cost` : `np.ndarray` shape `(m,)`, dtype `int` — per-unit cost for each arc\n",
    "- `b` : `np.ndarray` shape `(nodes,)`, dtype `int` — node supplies/demands:\n",
    "  - `b[j] > 0` means node `j` is a supply node (source)\n",
    "  - `b[j] < 0` means node `j` is a demand node (sink)\n",
    "  - `sum(b) == 0` (balanced instance)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "7fb8ebaa-0e13-4c0b-bbf6-00c0b60c4547",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _parse_dimacs_mcf(path: Path) -> tuple[int, Dict[int, int], List[tuple[int, int, int, int, int]]]:\n",
    "    # DIMACS min-cost flow:\n",
    "    # p min N M\n",
    "    # n i b\n",
    "    # a u v lo cap cost\n",
    "    n = None\n",
    "    supply: Dict[int, int] = {}\n",
    "    arcs: List[tuple[int, int, int, int, int]] = []\n",
    "\n",
    "    for raw in path.read_text(encoding=\"utf-8\").splitlines():\n",
    "        line = raw.strip()\n",
    "        if not line or line.startswith(\"c\"):\n",
    "            continue\n",
    "        parts = line.split()\n",
    "        if parts[0] == \"p\":\n",
    "            n = int(parts[2])\n",
    "        elif parts[0] == \"n\":\n",
    "            supply[int(parts[1])] = int(parts[2])\n",
    "        elif parts[0] == \"a\":\n",
    "            u, v, lo, cap, cost = map(int, parts[1:])\n",
    "            arcs.append((u, v, lo, cap, cost))\n",
    "\n",
    "    if n is None:\n",
    "        raise ValueError(\"Missing 'p min N M' line.\")\n",
    "    return n, supply, arcs\n",
    "\n",
    "\n",
    "def netgen_instance_arrays(\n",
    "    *,\n",
    "    nodes: int,\n",
    "    arcs: int,\n",
    "    sources: int,\n",
    "    sinks: int,\n",
    "    cost_bounds: Tuple[int, int],\n",
    "    cap_bounds: Tuple[int, int],\n",
    "    total_supply: int = 1000,\n",
    "    seed: Optional[int] = None,\n",
    "):\n",
    "    \"\"\"\n",
    "    Returns:\n",
    "      u, v, lo, cap, cost, b\n",
    "    where:\n",
    "      - u,v,lo,cap,cost are arrays of length m\n",
    "      - b is an array of length n with supply(+)/demand(-), sum(b)=0\n",
    "      - nodes are 0-based indices\n",
    "    \"\"\"\n",
    "    if nodes <= 0 or arcs <= 0:\n",
    "        raise ValueError(\"nodes and arcs must be positive.\")\n",
    "    if sources < 0 or sinks < 0 or sources + sinks > nodes:\n",
    "        raise ValueError(\"Need sources+sinks <= nodes.\")\n",
    "    minc, maxc = cost_bounds\n",
    "    mincap, maxcap = cap_bounds\n",
    "    if minc > maxc or mincap > maxcap:\n",
    "        raise ValueError(\"Bad bounds: min must be <= max.\")\n",
    "    if seed is None:\n",
    "        seed = random.randint(1, 2_000_000_000)\n",
    "\n",
    "    try:\n",
    "        import pynetgen  # type: ignore\n",
    "    except ImportError as e:\n",
    "        raise RuntimeError(\n",
    "            \"pynetgen is not installed. Install it (pip install pynetgen), \"\n",
    "            \"or tell me which NETGEN executable you have and I’ll adapt a subprocess call.\"\n",
    "        ) from e\n",
    "\n",
    "    tmpdir = Path(tempfile.mkdtemp(prefix=\"netgen_\"))\n",
    "    dimacs_path = tmpdir / \"instance.dimacs\"\n",
    "\n",
    "    pynetgen.netgen_generate(\n",
    "        seed=seed,\n",
    "        nodes=nodes,\n",
    "        sources=sources,\n",
    "        sinks=sinks,\n",
    "        density=arcs,\n",
    "        mincost=minc,\n",
    "        maxcost=maxc,\n",
    "        supply=total_supply,\n",
    "        tsources=0,\n",
    "        tsinks=0,\n",
    "        hicost=0,\n",
    "        capacitated=100,\n",
    "        mincap=mincap,\n",
    "        maxcap=maxcap,\n",
    "        rng=1,\n",
    "        type=None,\n",
    "        fname=str(dimacs_path),\n",
    "    )\n",
    "\n",
    "    n, supply_1b, arc_list_1b = _parse_dimacs_mcf(dimacs_path)\n",
    "    m = len(arc_list_1b)\n",
    "\n",
    "    u = np.empty(m, dtype=np.int64)\n",
    "    v = np.empty(m, dtype=np.int64)\n",
    "    lo = np.empty(m, dtype=np.int64)\n",
    "    cap = np.empty(m, dtype=np.int64)\n",
    "    cost = np.empty(m, dtype=np.int64)\n",
    "\n",
    "    for i, (uu, vv, l, c, w) in enumerate(arc_list_1b):\n",
    "        u[i] = uu - 1\n",
    "        v[i] = vv - 1\n",
    "        lo[i] = l\n",
    "        cap[i] = c\n",
    "        cost[i] = w\n",
    "\n",
    "    b = np.zeros(n, dtype=np.int64)\n",
    "    for node_1b, val in supply_1b.items():\n",
    "        b[node_1b - 1] = val\n",
    "\n",
    "    if b.sum() != 0:\n",
    "        raise ValueError(f\"Unbalanced instance: sum(b)={b.sum()} (should be 0).\")\n",
    "\n",
    "    return u, v, lo, cap, cost, b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "51948156-ee64-4380-9cd9-f31f811dd585",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n=200, m=1500, balance sum(b)=0\n",
      "#sources=10, #sinks=10\n",
      "sources (node, supply): [(0, 1054), (1, 864), (2, 372), (3, 49), (4, 299), (5, 917), (6, 565), (7, 45), (8, 711), (9, 124)]\n",
      "sinks   (node, demand): [(190, -198), (191, -867), (192, -233), (193, -347), (194, -527), (195, -758), (196, -1408), (197, -600), (198, -31), (199, -31)]\n",
      "\n",
      "First edges:\n",
      "   i      u      v     lo     cap    cost\n",
      "--------------------------------------------\n",
      "   0      0    118      0    1054      23\n",
      "   1      0    175      0      84       8\n",
      "   2      0    109      0      96      13\n",
      "   3      0    134      0      55      41\n",
      "   4      0     95      0      92      26\n",
      "   5      0    131      0      95       6\n",
      "   6      0     50      0      16       2\n",
      "   7      0     48      0      59      42\n",
      "   8      0     47      0      76      31\n",
      "   9      0    187      0      19      36\n",
      "\n",
      "cap stats: 0 107.35933333333334 1054\n",
      "cost stats: 1 25.227333333333334 50\n"
     ]
    }
   ],
   "source": [
    "# ------ Example of usage ------ #\n",
    "\n",
    "u, v, lo, cap, cost, b = netgen_instance_arrays(\n",
    "    nodes=200,\n",
    "    arcs=1500,\n",
    "    sources=10,\n",
    "    sinks=10,\n",
    "    cost_bounds=(1, 50),\n",
    "    cap_bounds=(0, 100),\n",
    "    total_supply=5000,\n",
    "    seed=7,\n",
    ")\n",
    "\n",
    "n = len(b)\n",
    "m = len(u)\n",
    "print(f\"n={n}, m={m}, balance sum(b)={int(b.sum())}\")\n",
    "\n",
    "src = np.where(b > 0)[0]\n",
    "snk = np.where(b < 0)[0]\n",
    "print(f\"#sources={len(src)}, #sinks={len(snk)}\")\n",
    "print(\"sources (node, supply):\", [(int(i), int(b[i])) for i in src[:10]])\n",
    "print(\"sinks   (node, demand):\", [(int(i), int(b[i])) for i in snk[:10]])\n",
    "\n",
    "k = 10\n",
    "k = min(k, m)\n",
    "print(\"\\nFirst edges:\")\n",
    "print(f\"{'i':>4}  {'u':>5}  {'v':>5}  {'lo':>5}  {'cap':>6}  {'cost':>6}\")\n",
    "print(\"-\" * 44)\n",
    "for i in range(k):\n",
    "    print(f\"{i:>4}  {int(u[i]):>5}  {int(v[i]):>5}  {int(lo[i]):>5}  {int(cap[i]):>6}  {int(cost[i]):>6}\")\n",
    "\n",
    "print(\"\\ncap stats:\", int(cap.min()), float(cap.mean()), int(cap.max()))\n",
    "print(\"cost stats:\", int(cost.min()), float(cost.mean()), int(cost.max()))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "273eeb7d-38f9-4cdb-9c43-1c3e920c792d",
   "metadata": {},
   "source": [
    "#### NETGEN Dataset Generator"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "723aab3f-b920-44d6-a891-dc87f6a89b61",
   "metadata": {},
   "source": [
    "The main function can be described as follows:\n",
    "\n",
    "**Signature**\n",
    "\n",
    "    data = generate_netgen_mcf_dataset(n, family, K, tilde_C, sigma2, seed=None,\n",
    "                                       cap_range=(1,1000), cost_range=(1,10000),\n",
    "                                       total_supply=1000, verbose=False)\n",
    "\n",
    "**Inputs**\n",
    "- `n` : `int` — number of nodes in the NETGEN network (`n > 0`)\n",
    "- `family` : `str` — family: `\"sparse\"` or `\"dense\"`\n",
    "  - `\"sparse\"` uses `m = 8n` arcs\n",
    "  - `\"dense\"` uses `m = ceil(n * sqrt(n))` arcs\n",
    "- `K` : `int` — number of balance/label samples to generate (`K > 0`)\n",
    "- `tilde_C` : `int` — max auxiliary cost used when generating balances (`tilde_C > 0`)\n",
    "- `sigma2` : `float` — label noise variance (`sigma2 >= 0`)\n",
    "- `seed` : `int` or `None` — RNG seed\n",
    "- `cap_range` : `(int, int)` — integer capacity range for NETGEN arcs (default `(1,1000)`)\n",
    "- `cost_range` : `(int, int)` — integer cost range for NETGEN arcs (default `(1,10000)`)\n",
    "- `total_supply` : `int` — NETGEN parameter controlling total supply in the feasible balance it produces\n",
    "  (required by NETGEN to generate the graph but the resulting NETGEN balance is discarded)\n",
    "- `verbose` : `bool` — print `F_max` and basic diagnostics if `True`\n",
    "\n",
    "**Process**\n",
    "1. Set the number of supply and demand nodes to  \n",
    "   `|S| = |T| = floor(sqrt(n))`.\n",
    "2. Choose the number of arcs `m` based on `family`:\n",
    "   - If `\"sparse\"`: `m = 8n`\n",
    "   - If `\"dense\"`:  `m = ceil(n * sqrt(n))`\n",
    "3. Generate a NETGEN minimum-cost flow instance on `n` nodes with:\n",
    "   - arc capacities `u_e ~ Unif({1,...,1000})`\n",
    "   - arc costs `c_e ~ Unif({1,...,10000})`\n",
    "   NETGEN also returns a feasible balance vector `b0`; discard its magnitudes but use its sign pattern to\n",
    "   identify the node sets:\n",
    "   - `S = {v : b0(v) > 0}` (supplies)\n",
    "   - `T = {v : b0(v) < 0}` (demands)\n",
    "4. Augment the network by adding a super-source `s` and super-sink `t`:\n",
    "   - add arcs `(s,v)` for all `v ∈ S` with capacity `+∞`\n",
    "   - add arcs `(v,t)` for all `v ∈ T` with capacity `+∞`\n",
    "5. Compute `F_max`, the value of a maximum `s–t` flow in the augmented network.\n",
    "6. For each sample `k = 1,...,K`:\n",
    "   1. Sample a target flow value `F^(k) ~ Unif({0,1,...,F_max})`.\n",
    "   2. Sample auxiliary integer arc costs on the augmented network  \n",
    "      `tilde_c_e^(k) ~ Unif({1,...,tilde_C})`.\n",
    "   3. Solve a minimum-cost `s–t` flow of value `F^(k)` on the augmented network under costs `tilde_c^(k)`,\n",
    "      obtaining a flow `f^(k)`.\n",
    "   4. Define the balance vector `b^(k) ∈ Z^n` on the original node set by:\n",
    "      - `b^(k)(v) = f_sv^(k)` for `v ∈ S`\n",
    "      - `b^(k)(v) = -f_vt^(k)` for `v ∈ T`\n",
    "      - `b^(k)(v) = 0` otherwise\n",
    "   5. Solve the original minimum-cost flow problem on `G(V,E)` with capacities `u`, costs `c`,\n",
    "      and balance `b^(k)` to obtain the optimal cost `z^(k)`.\n",
    "   6. Sample noise `ε^(k) ~ N(0, sigma2)` and set the label `y^(k) = z^(k) + ε^(k)`.\n",
    "\n",
    "**Output (`dict`)**\n",
    "- `data[\"n\"]` : `int` — number of nodes\n",
    "- `data[\"m\"]` : `int` — number of arcs in the original NETGEN graph\n",
    "- `data[\"u\"]` : `np.ndarray` shape `(m,)`, dtype `int` — arc tail nodes (0-indexed)\n",
    "- `data[\"v\"]` : `np.ndarray` shape `(m,)`, dtype `int` — arc head nodes (0-indexed)\n",
    "- `data[\"lo\"]` : `np.ndarray` shape `(m,)`, dtype `int` — arc lower bounds\n",
    "- `data[\"cap\"]` : `np.ndarray` shape `(m,)`, dtype `int` — arc capacities\n",
    "- `data[\"cost\"]` : `np.ndarray` shape `(m,)`, dtype `int` — arc costs\n",
    "- `data[\"S_nodes\"]` : `np.ndarray` shape `(|S|,)`, dtype `int` — supply node indices\n",
    "- `data[\"T_nodes\"]` : `np.ndarray` shape `(|T|,)`, dtype `int` — demand node indices\n",
    "- `data[\"X_samples\"]` : `np.ndarray` shape `(K, n)`, dtype `int` — sampled balance vectors `b^(k)`\n",
    "- `data[\"y\"]` : `np.ndarray` shape `(K,)`, dtype `float` — noisy labels `y^(k)`\n",
    "- `data[\"z\"]` : `np.ndarray` shape `(K,)`, dtype `float` — clean optimal costs `z^(k)` (before noise)\n",
    "- `data[\"F_max\"]` : `int` — maximum `s–t` flow value in the augmented network\n",
    "- `data[\"F_targets\"]` : `np.ndarray` shape `(K,)`, dtype `int` — sampled target values `F^(k)`\n",
    "- `data[\"tilde_C\"]` : `int` — auxiliary cost bound used\n",
    "- `data[\"sigma2\"]` : `float` — noise variance used\n",
    "- `data[\"seed\"]` : `int` or `None` — RNG seed used\n",
    "- `data[\"family\"]` : `str` — `\"sparse\"` or `\"dense\"`\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "8d6820d3-b93a-4172-a39e-4c96595e5999",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "\n",
    "def generate_netgen_mcf_dataset(\n",
    "    n: int,\n",
    "    family: str,                 # \"sparse\" or \"dense\"\n",
    "    K: int,\n",
    "    tilde_C: int,\n",
    "    sigma2: float,\n",
    "    seed: int | None = None,\n",
    "    cap_range: tuple[int, int] = (1, 1000),\n",
    "    cost_range: tuple[int, int] = (1, 10000),\n",
    "    total_supply: int = 1000,\n",
    "    verbose: bool = False,\n",
    "):\n",
    "    if n <= 0:\n",
    "        raise ValueError(\"n must be positive.\")\n",
    "    if K <= 0:\n",
    "        raise ValueError(\"K must be positive.\")\n",
    "    if tilde_C <= 0:\n",
    "        raise ValueError(\"tilde_C must be positive.\")\n",
    "    if sigma2 < 0:\n",
    "        raise ValueError(\"sigma2 must be nonnegative.\")\n",
    "    if family not in {\"sparse\", \"dense\"}:\n",
    "        raise ValueError(\"family must be 'sparse' or 'dense'.\")\n",
    "\n",
    "    rng = np.random.default_rng(seed)\n",
    "\n",
    "    # --- NETGEN sizes ---\n",
    "    S_cnt = int(math.isqrt(n))                 # floor(sqrt(n))\n",
    "    T_cnt = int(math.isqrt(n))                 # floor(sqrt(n))\n",
    "    if family == \"sparse\":\n",
    "        m_target = 8 * n\n",
    "    else:\n",
    "        m_target = int(math.ceil(n * math.sqrt(n)))\n",
    "\n",
    "    # --- generate NETGEN instance (graph + a feasible b we will discard) ---\n",
    "    # Uses your notebook's netgen_instance_arrays\n",
    "    u, v, lo, cap, cost, b0 = netgen_instance_arrays(\n",
    "        nodes=n,\n",
    "        arcs=m_target,\n",
    "        sources=S_cnt,\n",
    "        sinks=T_cnt,\n",
    "        cost_bounds=cost_range,\n",
    "        cap_bounds=cap_range,\n",
    "        total_supply=total_supply,\n",
    "        seed=None if seed is None else int(seed),\n",
    "    )\n",
    "\n",
    "    # Identify supply/demand node sets from NETGEN's b (we discard magnitudes, keep sets)\n",
    "    S_nodes = np.where(b0 > 0)[0]\n",
    "    T_nodes = np.where(b0 < 0)[0]\n",
    "\n",
    "    if len(S_nodes) != S_cnt or len(T_nodes) != T_cnt:\n",
    "        raise RuntimeError(\n",
    "            f\"NETGEN returned |S|={len(S_nodes)}, |T|={len(T_nodes)} but expected \"\n",
    "            f\"{S_cnt} and {T_cnt}. (Check generator settings.)\"\n",
    "        )\n",
    "\n",
    "    # --- build augmented network with super-source s and super-sink t ---\n",
    "    s = n\n",
    "    t = n + 1\n",
    "    N_aug = n + 2\n",
    "\n",
    "    # \"infinite\" capacity for super arcs: big enough not to bind\n",
    "    INF = float(np.sum(cap, dtype=np.int64) + 1)\n",
    "\n",
    "    # original arcs\n",
    "    src_aug = [u.astype(np.int64)]\n",
    "    dst_aug = [v.astype(np.int64)]\n",
    "    cap_aug = [cap.astype(np.float64)]\n",
    "\n",
    "    # s -> v for v in S\n",
    "    src_s = np.full(len(S_nodes), s, dtype=np.int64)\n",
    "    dst_s = S_nodes.astype(np.int64)\n",
    "    cap_s = np.full(len(S_nodes), INF, dtype=np.float64)\n",
    "\n",
    "    # v -> t for v in T\n",
    "    src_t = T_nodes.astype(np.int64)\n",
    "    dst_t = np.full(len(T_nodes), t, dtype=np.int64)\n",
    "    cap_t = np.full(len(T_nodes), INF, dtype=np.float64)\n",
    "\n",
    "    src_aug = np.concatenate(src_aug + [src_s, src_t])\n",
    "    dst_aug = np.concatenate(dst_aug + [dst_s, dst_t])\n",
    "    cap_aug = np.concatenate(cap_aug + [cap_s, cap_t])\n",
    "\n",
    "    # Remember arc indices for extracting b^(k)\n",
    "    m0 = len(u)\n",
    "    idx_s_to_v = {int(vv): m0 + i for i, vv in enumerate(S_nodes)}\n",
    "    idx_v_to_t = {int(vv): m0 + len(S_nodes) + i for i, vv in enumerate(T_nodes)}\n",
    "\n",
    "    # --- compute F_max via max s-t flow on augmented network ---\n",
    "    out_mf = lemon_mcf.max_flow(N_aug, src_aug, dst_aug, cap_aug, s, t)\n",
    "    F_max = int(out_mf[\"value\"])\n",
    "    if verbose:\n",
    "        print(f\"[augmented] F_max = {F_max}\")\n",
    "\n",
    "    # --- generate K balances and labels ---\n",
    "    B = np.zeros((K, n), dtype=np.int64)\n",
    "    z = np.zeros(K, dtype=np.float64)\n",
    "    y = np.zeros(K, dtype=np.float64)\n",
    "    F_targets = np.zeros(K, dtype=np.int64)\n",
    "\n",
    "    lower_aug = np.zeros_like(cap_aug, dtype=np.float64)\n",
    "\n",
    "    for k in range(K):\n",
    "        # sample target value F^(k)\n",
    "        Fk = int(rng.integers(0, F_max + 1))\n",
    "        F_targets[k] = Fk\n",
    "\n",
    "        # sample auxiliary integer costs on ALL augmented arcs\n",
    "        aux_cost = rng.integers(1, tilde_C + 1, size=len(src_aug), dtype=np.int64).astype(np.float64)\n",
    "\n",
    "        # min-cost s-t flow of value Fk via supplies\n",
    "        supply_aug = np.zeros(N_aug, dtype=np.float64)\n",
    "        supply_aug[s] = float(Fk)\n",
    "        supply_aug[t] = -float(Fk)\n",
    "\n",
    "        out_aux = lemon_mcf.solve_mcf(\n",
    "            N_aug,\n",
    "            src_aug,\n",
    "            dst_aug,\n",
    "            aux_cost,\n",
    "            lower_aug,\n",
    "            cap_aug,\n",
    "            supply_aug,\n",
    "        )\n",
    "        if out_aux[\"status\"] != 1:\n",
    "            raise RuntimeError(f\"Aux min-cost flow failed at k={k} with status={out_aux['status']}.\")\n",
    "\n",
    "        flow_aug = out_aux[\"flow\"]  # float array per arc\n",
    "\n",
    "        # build balance b^(k) on original nodes\n",
    "        bk = np.zeros(n, dtype=np.int64)\n",
    "        for vv in S_nodes:\n",
    "            bk[int(vv)] = int(round(float(flow_aug[idx_s_to_v[int(vv)]])))\n",
    "        for vv in T_nodes:\n",
    "            bk[int(vv)] = -int(round(float(flow_aug[idx_v_to_t[int(vv)]])))\n",
    "\n",
    "        B[k] = bk\n",
    "\n",
    "        # solve original min-cost flow to get z^(k)\n",
    "        out_orig = lemon_mcf.solve_mcf(\n",
    "            n,\n",
    "            u.astype(np.int64),\n",
    "            v.astype(np.int64),\n",
    "            cost.astype(np.float64),\n",
    "            lo.astype(np.float64),\n",
    "            cap.astype(np.float64),\n",
    "            bk.astype(np.float64),\n",
    "        )\n",
    "        if out_orig[\"status\"] != 1:\n",
    "            raise RuntimeError(f\"Original min-cost flow failed at k={k} with status={out_orig['status']}.\")\n",
    "\n",
    "        zk = float(out_orig[\"total_cost\"])\n",
    "        z[k] = zk\n",
    "\n",
    "        # noisy label\n",
    "        eps = float(rng.normal(0.0, math.sqrt(sigma2)))\n",
    "        y[k] = zk + eps\n",
    "\n",
    "    return {\n",
    "        # graph\n",
    "        \"n\": n,\n",
    "        \"m\": len(u),\n",
    "        \"u\": u.astype(np.int64),\n",
    "        \"v\": v.astype(np.int64),\n",
    "        \"lo\": lo.astype(np.int64),\n",
    "        \"cap\": cap.astype(np.int64),\n",
    "        \"cost\": cost.astype(np.int64),\n",
    "        \"S_nodes\": S_nodes.astype(np.int64),\n",
    "        \"T_nodes\": T_nodes.astype(np.int64),\n",
    "\n",
    "        # dataset (quadratic-like)\n",
    "        \"X_samples\": B,     # balances b^(k)\n",
    "        \"y\": y,             # noisy labels\n",
    "\n",
    "        # extras (often useful)\n",
    "        \"z\": z,             # clean optimal costs\n",
    "        \"F_max\": F_max,\n",
    "        \"F_targets\": F_targets,\n",
    "        \"tilde_C\": tilde_C,\n",
    "        \"sigma2\": float(sigma2),\n",
    "        \"seed\": seed,\n",
    "        \"family\": family,\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "891009e5-34d9-4d9e-a62e-ef7321e4ddc1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[augmented] F_max = 53031\n",
      "X_samples shape: (100, 200)\n",
      "y shape: (100,)\n",
      "F_max: 53031\n"
     ]
    }
   ],
   "source": [
    "# ---- Example of Usage ---- #\n",
    "\n",
    "data = generate_netgen_mcf_dataset(\n",
    "    n=200,\n",
    "    family=\"sparse\",\n",
    "    K=100,\n",
    "    tilde_C=100,\n",
    "    sigma2=0.25,\n",
    "    seed=0,\n",
    "    verbose=True,\n",
    ")\n",
    "\n",
    "print(\"X_samples shape:\", data[\"X_samples\"].shape)  # (K, n)\n",
    "print(\"y shape:\", data[\"y\"].shape)                  # (K,)\n",
    "print(\"F_max:\", data[\"F_max\"])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da97f236-a558-46cc-9803-8ce740c9c2a0",
   "metadata": {},
   "source": [
    "#### Full Dataset Generator "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06e9b3dc-4978-4afe-a58b-b58f2c5c505a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "e1cdd06e-81d1-4f27-a39b-81d540a21241",
   "metadata": {},
   "source": [
    "# Models Definition"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b131da93-331f-46ef-91e4-1ced802f0e27",
   "metadata": {},
   "source": [
    "### DFN"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72c356e0-d483-442b-a32b-49f0d7eab840",
   "metadata": {},
   "source": [
    "### MLP"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb0c0c1d-e1d6-4259-ad3b-55cd5b258432",
   "metadata": {},
   "source": [
    "## Gradient Boosted Trees"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df43dee6-8978-44d8-a576-cceea5fa46a1",
   "metadata": {},
   "source": [
    "### ICNN"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e74105b4-7d14-4b4d-9261-160fc4b44827",
   "metadata": {},
   "source": [
    "### Max-Affine"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8a3d912-fa0a-4acc-a12f-10a089f213de",
   "metadata": {},
   "source": [
    "# Training on Data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "609e32f6-7edd-40dd-9764-e06a592d0e73",
   "metadata": {},
   "source": [
    "# Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ce95d1e-487b-48e2-b1b8-3db68491ba7c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0827ae1-dfd0-432f-98c2-caa1616d0aca",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
