{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ff9630a3",
   "metadata": {},
   "source": [
    "# DFN (modular) — pluggable solvers + integer-scaling for LEMON\n",
    "\n",
    "This notebook rewrites the **DFN-Gurobi** idea in the **concise, modular style** of `DFN.ipynb`:\n",
    "\n",
    "- **Graph = data** (`DigraphSpec`) so you can add new graph generators later.\n",
    "- **Solver choice**: `solver=\"gurobi\"` or `solver=\"lemon\"`.\n",
    "- **Integer-only solvers** (LEMON) receive **integers** by scaling to a fixed decimal precision `10^-p`.\n",
    "- **Objective/gradients are scaled back** so learning behaves as if everything stayed in the original units.\n",
    "- Uses **softplus** for nonnegative costs/caps.\n",
    "\n",
    "At the bottom there are **minimal smoke tests** (forward + backward) that run for any solvers available in your environment."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d934e4f7",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01bb42e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import annotations\n",
    "\n",
    "import sys\n",
    "from pathlib import Path\n",
    "from dataclasses import dataclass\n",
    "\n",
    "import numpy as np\n",
    "import cppimport\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3da0367",
   "metadata": {},
   "source": [
    "## (Optional) LEMON binding\n",
    "\n",
    "If you have the `lemon_mcf` cppimport module in your repo (same as your previous notebooks), this will compile/import it.\n",
    "If not, you can still use `solver=\"gurobi\"`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea024152",
   "metadata": {},
   "outputs": [],
   "source": [
    "# If this fails, you can ignore it and use solver=\"gurobi\".\n",
    "try:\n",
    "    repo = Path().resolve().parent\n",
    "    sys.path.insert(0, str(repo))\n",
    "    lemon_mcf = cppimport.imp(\"lemon_mcf\")\n",
    "    HAVE_LEMON = True\n",
    "except Exception:  # compile or import failed; fall back to solver=\"gurobi\"\n",
    "    lemon_mcf = None\n",
    "    HAVE_LEMON = False\n",
    "\n",
    "HAVE_LEMON\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff2f4e31",
   "metadata": {},
   "source": [
    "## Fractional precision → integerization (STE)\n",
    "\n",
    "We only do this to satisfy **integer-only solvers**.\n",
    "For a chosen `precision_digits = p`:\n",
    "\n",
    "- `scale = 10^p`, `step = 1/scale`\n",
    "- we pass solver inputs as integers: `round(x * scale)`\n",
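    "- we scale the returned objective by `1/scale^2`, because costs and flows (bounded by the scaled capacities and balances) are each multiplied by `scale`; the result is then back in the original units."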
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bff71858",
   "metadata": {},
   "outputs": [],
   "source": [
    "class _RoundSTE(torch.autograd.Function):\n",
    "    @staticmethod\n",
    "    def forward(ctx, x: torch.Tensor) -> torch.Tensor:  # noqa\n",
    "        return torch.round(x)\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, g: torch.Tensor) -> torch.Tensor:  # noqa\n",
    "        return g\n",
    "\n",
    "\n",
    "def round_ste(x: torch.Tensor) -> torch.Tensor:\n",
    "    return _RoundSTE.apply(x)\n",
    "\n",
    "\n",
    "def to_int_scaled(x: torch.Tensor, scale: float) -> torch.Tensor:\n",
    "    \"\"\"Return an integer-valued float tensor representing round(x * scale).\n",
    "\n",
    "    The rounding is a straight-through estimator (STE): the backward pass treats\n",
    "    round() as the identity, so the gradient w.r.t. x is the incoming gradient times scale.\n",
    "    \"\"\"\n",
    "    return round_ste(x * scale)\n"
   ]
  },
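  {
   "cell_type": "markdown",
   "id": "a7c1e0f2",
   "metadata": {},
   "source": [
    "The scaling rule can be checked by hand on a toy problem. The sketch below is illustrative only and is not part of the model: `parallel_arc_mcf` is a hypothetical helper that solves min-cost flow on parallel arcs between a single source and sink by filling the cheapest arcs first. It is an assumption made for this example, not the LEMON or Gurobi interface.\n",
    "\n",
    "```python\n",
    "def parallel_arc_mcf(costs, caps, demand):\n",
    "    # Hypothetical toy solver: on parallel arcs between one source and one\n",
    "    # sink, the optimal flow fills the cheapest arcs first.\n",
    "    total, remaining = 0, demand\n",
    "    for cost, cap in sorted(zip(costs, caps)):\n",
    "        f = min(cap, remaining)\n",
    "        total += cost * f\n",
    "        remaining -= f\n",
    "    if remaining > 1e-12:\n",
    "        raise ValueError('infeasible: demand exceeds total capacity')\n",
    "    return total\n",
    "\n",
    "\n",
    "p = 3\n",
    "scale = 10 ** p\n",
    "\n",
    "costs, caps, demand = [0.25, 1.5], [0.4, 2.0], 1.0\n",
    "obj_float = parallel_arc_mcf(costs, caps, demand)\n",
    "\n",
    "# Integer inputs, as an integer-only solver would receive them.\n",
    "costs_i = [round(c * scale) for c in costs]\n",
    "caps_i = [round(u * scale) for u in caps]\n",
    "demand_i = round(demand * scale)\n",
    "obj_int = parallel_arc_mcf(costs_i, caps_i, demand_i)\n",
    "\n",
    "# Undo the scaling: costs and flows were each multiplied by scale.\n",
    "obj_rescaled = obj_int / scale ** 2\n",
    "print(obj_float, obj_int, obj_rescaled)\n",
    "```\n",
    "\n",
    "Unit costs and flows are both multiplied by `scale`, so the integer objective is `scale^2` times larger; dividing by `scale^2` recovers the original units up to rounding error."
   ]
  },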
  {
   "cell_type": "markdown",
   "id": "e380e9d3",
   "metadata": {},
   "source": [
    "## Graph spec + a multilayer generator\n",
    "\n",
    "To add new graphs later, just write another function that returns a `DigraphSpec`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5a903f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass(frozen=True)\n",
    "class DigraphSpec:\n",
    "    n: int\n",
    "    src: torch.Tensor          # (m,) int64\n",
    "    dst: torch.Tensor          # (m,) int64\n",
    "    m_learn: int               # first m_learn arcs are learnable\n",
    "    fixed_cost: torch.Tensor   # (m-m_learn,) float\n",
    "    fixed_cap: torch.Tensor    # (m-m_learn,) float\n",
    "    b_nodes: torch.Tensor      # nodes using learnable b entries\n",
    "    A_nodes: torch.Tensor      # nodes using (A @ x) entries\n",
    "    slack: int                 # node index forced to -sum(other balances)\n",
    "\n",
    "\n",
    "def make_multilayer(layer_sizes, big_cost=1e6, big_cap=1e6, var_frac=0.5) -> DigraphSpec:\n",
    "    \"\"\"Layered graph (like DFN-Gurobi), returned as a generic DigraphSpec.\n",
    "\n",
    "    - Learnable arcs: fully connect Li -> L(i+1)\n",
    "    - Fixed feasibility arcs: fully connect L1 <-> LK with big cost/cap\n",
    "    - Nonzero balances only on boundary layers, with the final node as slack\n",
    "\n",
    "    var_frac controls boundary split into (A@x) vs (learnable b):\n",
    "      A_nodes = first ~var_frac; b_nodes = rest (excluding slack).\n",
    "    \"\"\"\n",
    "    sizes = list(map(int, layer_sizes))\n",
    "    if any(s <= 0 for s in sizes):\n",
    "        raise ValueError(\"all layer sizes must be positive\")\n",
    "\n",
    "    offs = np.cumsum([0] + sizes[:-1]).astype(np.int64)\n",
    "    layers = [np.arange(offs[i], offs[i] + sizes[i], dtype=np.int64) for i in range(len(sizes))]\n",
    "    n = int(sum(sizes))\n",
    "\n",
    "    # learnable arcs: Li -> Li+1 (full bipartite)\n",
    "    srcL, dstL = [], []\n",
    "    for i in range(len(layers) - 1):\n",
    "        U, V = layers[i], layers[i + 1]\n",
    "        srcL.append(np.repeat(U, len(V)))\n",
    "        dstL.append(np.tile(V, len(U)))\n",
    "    srcL = np.concatenate(srcL) if srcL else np.zeros((0,), np.int64)\n",
    "    dstL = np.concatenate(dstL) if dstL else np.zeros((0,), np.int64)\n",
    "    m_learn = int(srcL.size)\n",
    "\n",
    "    # fixed arcs: L1 <-> LK\n",
    "    L1, LK = layers[0], layers[-1]\n",
    "    srcF = np.concatenate([np.repeat(L1, len(LK)), np.repeat(LK, len(L1))])\n",
    "    dstF = np.concatenate([np.tile(LK, len(L1)),   np.tile(L1, len(LK))])\n",
    "\n",
    "    src = np.concatenate([srcL, srcF]).astype(np.int64, copy=False)\n",
    "    dst = np.concatenate([dstL, dstF]).astype(np.int64, copy=False)\n",
    "\n",
    "    slack = int(LK[-1])\n",
    "    LK_wo = LK[LK != slack]\n",
    "\n",
    "    def split(nodes):\n",
    "        k = int(round(float(len(nodes)) * float(var_frac)))\n",
    "        k = max(0, min(k, len(nodes)))\n",
    "        return nodes[:k], nodes[k:]\n",
    "\n",
    "    A1, B1 = split(L1)\n",
    "    Ak, Bk = split(LK_wo)\n",
    "\n",
    "    A_nodes = np.concatenate([A1, Ak]).astype(np.int64)\n",
    "    b_nodes = np.concatenate([B1, Bk]).astype(np.int64)\n",
    "\n",
    "    fixed_cost = torch.full((int(srcF.size),), float(big_cost), dtype=torch.float32)\n",
    "    fixed_cap  = torch.full((int(srcF.size),), float(big_cap),  dtype=torch.float32)\n",
    "\n",
    "    return DigraphSpec(\n",
    "        n=n,\n",
    "        src=torch.from_numpy(src).long(),\n",
    "        dst=torch.from_numpy(dst).long(),\n",
    "        m_learn=m_learn,\n",
    "        fixed_cost=fixed_cost,\n",
    "        fixed_cap=fixed_cap,\n",
    "        b_nodes=torch.from_numpy(b_nodes).long(),\n",
    "        A_nodes=torch.from_numpy(A_nodes).long(),\n",
    "        slack=slack,\n",
    "    )\n"
   ]
  },
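  {
   "cell_type": "markdown",
   "id": "b3d94c61",
   "metadata": {},
   "source": [
    "The arc ordering that `make_multilayer` gets from `np.repeat` and `np.tile` is easiest to see on a tiny pair of layers. This is a standalone sketch; the node ids in `U` and `V` are made up for illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Two hypothetical layers: U feeds into V.\n",
    "U = np.array([0, 1])\n",
    "V = np.array([2, 3, 4])\n",
    "\n",
    "# Full bipartite arcs U -> V: each source is repeated once per target,\n",
    "# and the target list is tiled once per source.\n",
    "src = np.repeat(U, len(V))\n",
    "dst = np.tile(V, len(U))\n",
    "\n",
    "arcs = list(zip(src.tolist(), dst.tolist()))\n",
    "print(arcs)\n",
    "```\n",
    "\n",
    "A new generator only has to produce `src` / `dst` arrays like these, with the learnable arcs first, and wrap them in a `DigraphSpec`."
   ]
  },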
  {
   "cell_type": "markdown",
   "id": "f9aa077a",
   "metadata": {},
   "source": [
    "## Differentiable min-cost-flow value (LEMON + Gurobi)\n",
    "\n",
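    "Both backward passes use the same sensitivity rules for the optimal value:\n",
    "- `d/dcost = flow`\n",
    "- `d/dcap = reduced_cost` when the arc is tight (flow equals capacity), else `0`\n",
    "- `d/db = mean(potential) - potential` (potentials are only determined up to an additive constant, so they are centered)"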
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddba969f",
   "metadata": {},
   "outputs": [],
   "source": [
    "class _MCFValueLEMON(torch.autograd.Function):\n",
    "    @staticmethod\n",
    "    def forward(ctx, n_nodes, src, dst, cost, cap, b):\n",
    "        if lemon_mcf is None:\n",
    "            raise RuntimeError(\"LEMON solver requested but lemon_mcf is not available.\")\n",
    "\n",
    "        n = int(n_nodes)\n",
    "        src = src.long(); dst = dst.long()\n",
    "        m = int(src.numel())\n",
    "        if dst.numel() != m or cost.numel() != m or cap.numel() != m or b.numel() != n:\n",
    "            raise ValueError(\"bad shapes\")\n",
    "\n",
    "        src_np  = src.detach().cpu().contiguous().numpy().astype(np.int64, copy=False)\n",
    "        dst_np  = dst.detach().cpu().contiguous().numpy().astype(np.int64, copy=False)\n",
    "        cost_np = cost.detach().cpu().contiguous().numpy().astype(np.float64, copy=False)\n",
    "        cap_np  = cap.detach().cpu().contiguous().numpy().astype(np.float64, copy=False)\n",
    "        b_np    = b.detach().cpu().contiguous().numpy().astype(np.float64, copy=False)\n",
    "\n",
    "        out = lemon_mcf.solve_mcf(n, src_np, dst_np, cost_np, cap_np, b_np)\n",
    "        if int(out.get(\"status\", -1)) != 1:\n",
    "            raise RuntimeError(f\"LEMON failed (status={out.get('status', -1)})\")\n",
    "\n",
    "        flow = out[\"flow\"].astype(np.float64, copy=False)\n",
    "        pot  = out[\"potential\"].astype(np.float64, copy=False)\n",
    "        red  = out[\"reduced_cost\"].astype(np.float64, copy=False)\n",
    "        at   = np.abs(flow - cap_np) <= 1e-9\n",
    "\n",
    "        ctx.flow, ctx.pot, ctx.red, ctx.at = flow, pot, red, at\n",
    "        return cost.new_tensor(float(out[\"total_cost\"]))\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, g):\n",
    "        dev, dt = g.device, g.dtype\n",
    "        flow = torch.as_tensor(ctx.flow, device=dev, dtype=dt)\n",
    "        pot  = torch.as_tensor(ctx.pot,  device=dev, dtype=dt)\n",
    "        red  = torch.as_tensor(ctx.red,  device=dev, dtype=dt)\n",
    "        at   = torch.as_tensor(ctx.at,   device=dev, dtype=torch.bool)\n",
    "\n",
    "        grad_cost = flow\n",
    "        grad_cap  = torch.where(at, red, torch.zeros_like(red))\n",
    "        grad_b    = pot.mean() - pot\n",
    "        return None, None, None, grad_cost * g, grad_cap * g, grad_b * g\n",
    "\n",
    "\n",
    "class _MCFValueGUROBI(torch.autograd.Function):\n",
    "    @staticmethod\n",
    "    def forward(ctx, n_nodes, src, dst, cost, cap, b):\n",
    "        try:\n",
    "            import gurobipy as gp\n",
    "            from gurobipy import GRB\n",
    "        except Exception as e:\n",
    "            raise ImportError(\"gurobipy is required (with a valid license).\") from e\n",
    "\n",
    "        n = int(n_nodes)\n",
    "        src = src.long(); dst = dst.long()\n",
    "        m = int(src.numel())\n",
    "        if dst.numel() != m or cost.numel() != m or cap.numel() != m or b.numel() != n:\n",
    "            raise ValueError(\"bad shapes\")\n",
    "\n",
    "        src_np  = src.detach().cpu().contiguous().numpy().astype(np.int64, copy=False)\n",
    "        dst_np  = dst.detach().cpu().contiguous().numpy().astype(np.int64, copy=False)\n",
    "        cost_np = cost.detach().cpu().contiguous().numpy().astype(np.float64, copy=False)\n",
    "        cap_np  = cap.detach().cpu().contiguous().numpy().astype(np.float64, copy=False)\n",
    "        b_np    = b.detach().cpu().contiguous().numpy().astype(np.float64, copy=False).copy()\n",
    "\n",
    "        # arc indices leaving / entering each node\n",
    "        out_idx = [[] for _ in range(n)]\n",
    "        in_idx  = [[] for _ in range(n)]\n",
    "        for k in range(m):\n",
    "            out_idx[int(src_np[k])].append(k)\n",
    "            in_idx[int(dst_np[k])].append(k)\n",
    "\n",
    "        model = gp.Model()\n",
    "        model.Params.OutputFlag = 0\n",
    "\n",
    "        x = model.addVars(m, lb=0.0, ub=cap_np.tolist(), obj=cost_np.tolist(), name=\"x\")\n",
    "        bal = []\n",
    "        for i in range(n):\n",
    "            bal.append(model.addConstr(\n",
    "                gp.quicksum(x[k] for k in out_idx[i]) - gp.quicksum(x[k] for k in in_idx[i]) == float(b_np[i])\n",
    "            ))\n",
    "\n",
    "        model.optimize()\n",
    "        if model.Status != GRB.OPTIMAL:\n",
    "            raise RuntimeError(f\"Gurobi failed (status={model.Status})\")\n",
    "\n",
    "        flow = np.fromiter((x[k].X  for k in range(m)), dtype=np.float64, count=m)\n",
    "        red  = np.fromiter((x[k].RC for k in range(m)), dtype=np.float64, count=m)\n",
    "        pot  = -np.fromiter((bal[i].Pi for i in range(n)), dtype=np.float64, count=n)\n",
    "        at   = np.abs(flow - cap_np) <= 1e-9\n",
    "\n",
    "        ctx.flow, ctx.pot, ctx.red, ctx.at = flow, pot, red, at\n",
    "        return cost.new_tensor(float(model.ObjVal))\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, g):\n",
    "        dev, dt = g.device, g.dtype\n",
    "        flow = torch.as_tensor(ctx.flow, device=dev, dtype=dt)\n",
    "        pot  = torch.as_tensor(ctx.pot,  device=dev, dtype=dt)\n",
    "        red  = torch.as_tensor(ctx.red,  device=dev, dtype=dt)\n",
    "        at   = torch.as_tensor(ctx.at,   device=dev, dtype=torch.bool)\n",
    "\n",
    "        grad_cost = flow\n",
    "        grad_cap  = torch.where(at, red, torch.zeros_like(red))\n",
    "        grad_b    = pot.mean() - pot\n",
    "        return None, None, None, grad_cost * g, grad_cap * g, grad_b * g\n",
    "\n",
    "\n",
    "def min_cost_flow_value(n_nodes, src, dst, cost, cap, b, solver=\"lemon\"):\n",
    "    s = str(solver).lower()\n",
    "    if s in (\"lemon\", \"lemon_mcf\"):\n",
    "        return _MCFValueLEMON.apply(n_nodes, src, dst, cost, cap, b)\n",
    "    if s in (\"gurobi\", \"grb\"):\n",
    "        return _MCFValueGUROBI.apply(n_nodes, src, dst, cost, cap, b)\n",
    "    raise ValueError(f\"unknown solver: {solver}\")\n"
   ]
  },
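  {
   "cell_type": "markdown",
   "id": "c58e2a17",
   "metadata": {},
   "source": [
    "The sensitivity rules used in the backward passes can be sanity-checked with finite differences on a toy instance. The sketch below uses a hypothetical `parallel_arc_mcf` helper (parallel arcs between one source and one sink, cheapest first); it is an assumption for illustration and does not call LEMON or Gurobi.\n",
    "\n",
    "```python\n",
    "def parallel_arc_mcf(costs, caps, demand):\n",
    "    # Hypothetical toy solver: fill the cheapest parallel arcs first.\n",
    "    # Returns the optimal value and the flow per arc (in input order).\n",
    "    order = sorted(range(len(costs)), key=lambda k: costs[k])\n",
    "    flow, remaining = [0.0] * len(costs), demand\n",
    "    for k in order:\n",
    "        flow[k] = min(caps[k], remaining)\n",
    "        remaining -= flow[k]\n",
    "    if remaining > 1e-12:\n",
    "        raise ValueError('infeasible: demand exceeds total capacity')\n",
    "    return sum(c * f for c, f in zip(costs, flow)), flow\n",
    "\n",
    "\n",
    "costs, caps, demand = [0.25, 1.5], [0.4, 2.0], 1.0\n",
    "obj, flow = parallel_arc_mcf(costs, caps, demand)\n",
    "\n",
    "# The marginal unit travels on the expensive arc, so the price of one\n",
    "# more unit of supply (the source-sink potential difference) is 1.5.\n",
    "price = 1.5\n",
    "reduced_cost = [c - price for c in costs]\n",
    "\n",
    "eps = 1e-6\n",
    "\n",
    "\n",
    "def fd(new_costs, new_caps, new_demand):\n",
    "    # Forward finite difference of the optimal value.\n",
    "    return (parallel_arc_mcf(new_costs, new_caps, new_demand)[0] - obj) / eps\n",
    "\n",
    "\n",
    "d_cost0 = fd([costs[0] + eps, costs[1]], caps, demand)  # compare with flow[0]\n",
    "d_cap0 = fd(costs, [caps[0] + eps, caps[1]], demand)    # compare with reduced_cost[0]\n",
    "d_cap1 = fd(costs, [caps[0], caps[1] + eps], demand)    # arc 1 is not tight\n",
    "d_dem = fd(costs, caps, demand + eps)                   # compare with price\n",
    "print(d_cost0, d_cap0, d_cap1, d_dem)\n",
    "```\n",
    "\n",
    "The tight arc has a negative reduced cost, so enlarging its capacity lowers the objective, while the slack arc contributes no capacity gradient."
   ]
  },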
  {
   "cell_type": "markdown",
   "id": "39f9493d",
   "metadata": {},
   "source": [
    "## DFN module (modular + solver choice)\n",
    "\n",
    "Key points vs your older `DFN.ipynb`:\n",
    "\n",
    "- We **do not quantize A itself** by default (only the solver inputs if integerization is enabled).\n",
    "- `integerize='auto'` means: use integers for LEMON, floats for Gurobi.\n",
    "- Optional **affine output** (`alpha * obj + beta`) like DFN-Gurobi."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "214a4fb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "class DFN(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        graph: DigraphSpec,\n",
    "        input_dim: int,\n",
    "        *,\n",
    "        solver: str = \"gurobi\",\n",
    "        precision_digits: int = 3,\n",
    "        eps_pos: float = 0.0,\n",
    "        integerize: str | bool = \"auto\",\n",
    "        affine_obj: bool = True,\n",
    "    ):\n",
    "        super().__init__()\n",
    "\n",
    "        self.n = int(graph.n)\n",
    "        self.m_learn = int(graph.m_learn)\n",
    "        self.slack = int(graph.slack)\n",
    "\n",
    "        self.solver = str(solver)\n",
    "        self.scale = float(10 ** int(precision_digits))\n",
    "        self.step  = 1.0 / self.scale\n",
    "        self.eps_pos = float(eps_pos)\n",
    "\n",
    "        if integerize == \"auto\":\n",
    "            self.integerize = self.solver.lower() in (\"lemon\", \"lemon_mcf\")\n",
    "        else:\n",
    "            self.integerize = bool(integerize)\n",
    "\n",
    "        # graph buffers\n",
    "        self.register_buffer(\"src\", graph.src.long(), persistent=False)\n",
    "        self.register_buffer(\"dst\", graph.dst.long(), persistent=False)\n",
    "        self.register_buffer(\"b_nodes\", graph.b_nodes.long(), persistent=False)\n",
    "        self.register_buffer(\"A_nodes\", graph.A_nodes.long(), persistent=False)\n",
    "        self.register_buffer(\"fixed_cost\", graph.fixed_cost.float(), persistent=False)\n",
    "        self.register_buffer(\"fixed_cap\",  graph.fixed_cap.float(),  persistent=False)\n",
    "\n",
    "        nb = int(self.b_nodes.numel())\n",
    "        na = int(self.A_nodes.numel())\n",
    "\n",
    "        # learnables\n",
    "        self.A = nn.Parameter(0.01 * torch.randn(na, int(input_dim)))\n",
    "        self.b = nn.Parameter(torch.zeros(nb))\n",
    "\n",
    "        self.c_raw = nn.Parameter(0.01 * torch.randn(self.m_learn))\n",
    "        self.u_raw = nn.Parameter(0.01 * torch.randn(self.m_learn))\n",
    "\n",
    "        self.affine_obj = bool(affine_obj)\n",
    "        if self.affine_obj:\n",
    "            self.alpha = nn.Parameter(torch.ones(()))\n",
    "            self.beta  = nn.Parameter(torch.zeros(()))\n",
    "\n",
    "    def _solve_one(self, x1: torch.Tensor) -> torch.Tensor:\n",
    "        dev = x1.device\n",
    "        dt = torch.float64\n",
    "\n",
    "        # nonnegative learnable arc params\n",
    "        c = (F.softplus(self.c_raw) + self.eps_pos).to(dev, dtype=dt)\n",
    "        u = (F.softplus(self.u_raw) + self.eps_pos).to(dev, dtype=dt)\n",
    "\n",
    "        fC = self.fixed_cost.to(dev, dtype=dt)\n",
    "        fU = self.fixed_cap.to(dev, dtype=dt)\n",
    "\n",
    "        # balances\n",
    "        bvec = torch.zeros(self.n, device=dev, dtype=dt)\n",
    "        if self.b_nodes.numel():\n",
    "            bvec[self.b_nodes.to(dev)] = self.b.to(dev, dtype=dt)\n",
    "        if self.A_nodes.numel():\n",
    "            bvec[self.A_nodes.to(dev)] = (self.A.to(dev, dtype=dt) @ x1.to(dtype=dt))\n",
    "        bvec[self.slack] -= bvec.sum()\n",
    "\n",
    "        if self.integerize:\n",
    "            # integer solver inputs (as float tensors holding integers)\n",
    "            cI = to_int_scaled(c, self.scale)\n",
    "            uI = to_int_scaled(u, self.scale)\n",
    "            fCI = torch.round(fC * self.scale)\n",
    "            fUI = torch.round(fU * self.scale)\n",
    "\n",
    "            costI = torch.cat([cI, fCI]) if fCI.numel() else cI\n",
    "            capI  = torch.cat([uI, fUI]) if fUI.numel() else uI\n",
    "\n",
    "            bI = to_int_scaled(bvec, self.scale)\n",
    "            bI[self.slack] -= bI.sum()\n",
    "\n",
    "            objI = min_cost_flow_value(self.n, self.src.to(dev), self.dst.to(dev), costI, capI, bI, solver=self.solver)\n",
    "            out = objI * (self.step ** 2)\n",
    "        else:\n",
    "            cost = torch.cat([c, fC]) if fC.numel() else c\n",
    "            cap  = torch.cat([u, fU]) if fU.numel() else u\n",
    "            out = min_cost_flow_value(self.n, self.src.to(dev), self.dst.to(dev), cost, cap, bvec, solver=self.solver)\n",
    "\n",
    "        if self.affine_obj:\n",
    "            out = self.alpha.to(out.dtype) * out + self.beta.to(out.dtype)\n",
    "\n",
    "        return out\n",
    "\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        if x.dim() == 1:\n",
    "            y = self._solve_one(x)\n",
    "            return y.to(x.dtype)\n",
    "\n",
    "        ys = [self._solve_one(xi) for xi in x]\n",
    "        return torch.stack(ys, 0).to(x.dtype)\n"
   ]
  },
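  {
   "cell_type": "markdown",
   "id": "d20f6b39",
   "metadata": {},
   "source": [
    "Rounding each balance independently can break the requirement that balances sum to zero, which is why the integer branch of `_solve_one` re-applies the slack correction after scaling. A small worked example with made-up balances (plain Python, no solver involved):\n",
    "\n",
    "```python\n",
    "scale = 10 ** 3\n",
    "slack = 2  # index of the slack node in this toy example\n",
    "\n",
    "# Balances in original units; the slack entry makes them sum to zero.\n",
    "b = [0.3336, 0.3336, 0.0]\n",
    "b[slack] -= sum(b)\n",
    "\n",
    "# Round every entry to the integer grid independently.\n",
    "b_int = [round(v * scale) for v in b]\n",
    "residual_before = sum(b_int)  # rounding error accumulated over the nodes\n",
    "\n",
    "# Push the rounding residual onto the slack node.\n",
    "b_int[slack] -= sum(b_int)\n",
    "residual_after = sum(b_int)\n",
    "print(b_int, residual_before, residual_after)\n",
    "```\n",
    "\n",
    "Without the second correction an integer solver would see total supply and demand that differ by the residual and could report the instance as infeasible."
   ]
  },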
  {
   "cell_type": "markdown",
   "id": "d11fa13a",
   "metadata": {},
   "source": [
    "## Minimal smoke tests (forward + backward)\n",
    "\n",
    "These tests are intentionally small and fast:\n",
    "\n",
    "- graph sanity checks\n",
    "- forward produces finite values\n",
    "- backward produces finite gradients\n",
    "\n",
    "They run for whichever solvers are available."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd755750",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _have_gurobi():\n",
    "    try:\n",
    "        import gurobipy  # noqa\n",
    "        return True\n",
    "    except Exception:\n",
    "        return False\n",
    "\n",
    "\n",
    "def smoke_test(device=\"cpu\"):\n",
    "    torch.manual_seed(0)\n",
    "\n",
    "    g = make_multilayer([3, 4, 3], var_frac=0.5)\n",
    "    assert g.n == 10\n",
    "    assert g.m_learn > 0\n",
    "    assert g.slack not in g.A_nodes.tolist()\n",
    "    assert g.slack not in g.b_nodes.tolist()\n",
    "\n",
    "    x = torch.randn(6, 5, device=device)\n",
    "\n",
    "    solvers = []\n",
    "    if _have_gurobi():\n",
    "        solvers.append(\"gurobi\")\n",
    "    if HAVE_LEMON:\n",
    "        solvers.append(\"lemon\")\n",
    "\n",
    "    if not solvers:\n",
    "        print(\"No solvers available in this environment.\")\n",
    "        return\n",
    "\n",
    "    for s in solvers:\n",
    "        model = DFN(\n",
    "            g,\n",
    "            input_dim=5,\n",
    "            solver=s,\n",
    "            integerize=\"auto\",\n",
    "            precision_digits=3,\n",
    "            eps_pos=1e-6,\n",
    "            affine_obj=True,\n",
    "        ).to(device)\n",
    "\n",
    "        y = model(x).mean()\n",
    "        assert torch.isfinite(y).all()\n",
    "\n",
    "        y.backward()\n",
    "\n",
    "        for name, p in model.named_parameters():\n",
    "            assert p.grad is not None, f\"missing grad: {name}\"\n",
    "            assert torch.isfinite(p.grad).all(), f\"bad grad: {name}\"\n",
    "\n",
    "    print(f\"Smoke test passed for solvers: {solvers}\")\n",
    "\n",
    "\n",
    "smoke_test()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "## Training tests (DFN vs MLP)\n\nThis mirrors the *training sanity checks* from the original notebooks: a small synthetic regression task (integer quadratic) and a side-by-side DFN vs MLP comparison with live plots.\n\n- Use `solver=\"gurobi\"` for fast iteration (continuous).\n- Use `solver=\"lemon\"` to test integer-scaling + LEMON (requires your `lemon_mcf` binding).\n"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": "import sys, math, random\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport matplotlib.pyplot as plt\n\ntorch.set_printoptions(precision=4, sci_mode=False)\n\ndef _is_notebook():\n    return \"ipykernel\" in sys.modules\n\n# -------------------- data: integer quadratic --------------------\ndef make_int_quadratic_data(\n    N=4096, d=10, x_min=-20, x_max=20, noise=0.0, seed=0, device=\"cpu\"\n):\n    g = torch.Generator(device=\"cpu\").manual_seed(seed)\n    X = torch.randint(x_min, x_max, (N, d), generator=g).float().to(device)\n\n    R = torch.randn(d, d, generator=g).to(device)\n    Q = R.T @ R + 0.2 * torch.eye(d, device=device)\n    lin = torch.randn(d, generator=g).to(device)\n\n    y = (X @ Q * X).sum(dim=1) + X @ lin\n    if noise > 0:\n        # randn_like does not accept a generator; sample on CPU, then move\n        y = y + noise * torch.randn(y.shape, generator=g).to(device)\n\n    # normalize (helps optimization)\n    y = (y - y.mean()) / (y.std() + 1e-8)\n    return X, y\n\n# -------------------- baseline --------------------\nclass MLP(nn.Module):\n    def __init__(self, d, h=128):\n        super().__init__()\n        self.net = nn.Sequential(\n            nn.Linear(d, h), nn.ReLU(),\n            nn.Linear(h, h), nn.ReLU(),\n            nn.Linear(h, 1),\n        )\n    def forward(self, x):\n        return self.net(x).squeeze(-1)\n\n# -------------------- plotting helpers --------------------\ndef setup_live_loss(title=\"Live MSE\"):\n    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))\n    ax1.set_title(title)\n    ax1.set_xlabel(\"step\"); ax1.set_ylabel(\"MSE\")\n    ax2.set_title(\"DFN - MLP gap\")\n    ax2.set_xlabel(\"step\"); ax2.set_ylabel(\"MSE gap\")\n\n    lines = (\n        ax1.plot([], [], label=\"DFN train\")[0],\n        ax1.plot([], [], label=\"DFN val\")[0],\n        ax1.plot([], [], label=\"MLP train\")[0],\n        ax1.plot([], [], label=\"MLP val\")[0],\n        ax2.plot([], [], label=\"gap train\")[0],\n        ax2.plot([], [], label=\"gap 
val\")[0],\n    )\n    ax1.legend(); ax2.legend()\n\n    disp = None\n    if _is_notebook():\n        from IPython.display import display\n        disp = display(fig, display_id=True)\n    else:\n        plt.ion(); plt.show(block=False)\n    return fig, (ax1, ax2), lines, disp\n\n@torch.no_grad()\ndef update_live_loss(fig, axes, lines, disp, steps, dtr, dva, mtr, mva):\n    ax1, ax2 = axes\n    l_dtr, l_dva, l_mtr, l_mva, l_gtr, l_gva = lines\n\n    l_dtr.set_data(steps, dtr); l_dva.set_data(steps, dva)\n    l_mtr.set_data(steps, mtr); l_mva.set_data(steps, mva)\n    gap_tr = [a - b for a, b in zip(dtr, mtr)]\n    gap_va = [a - b for a, b in zip(dva, mva)]\n    l_gtr.set_data(steps, gap_tr); l_gva.set_data(steps, gap_va)\n\n    ax1.relim(); ax1.autoscale_view()\n    ax2.relim(); ax2.autoscale_view()\n    fig.canvas.draw()\n    if disp is not None:\n        disp.update(fig)\n    else:\n        fig.canvas.flush_events()\n        plt.pause(0.001)\n\ndef setup_scatter(y_true_cpu, mn, mx, title):\n    fig, (axd, axm) = plt.subplots(1, 2, figsize=(10, 4), sharex=True, sharey=True)\n    scd = axd.scatter([], [], s=12)\n    scm = axm.scatter([], [], s=12)\n    for ax in (axd, axm):\n        ax.plot([mn, mx], [mn, mx])\n        ax.set_xlim(mn, mx); ax.set_ylim(mn, mx)\n        ax.set_xlabel(\"y_true\")\n    axd.set_ylabel(\"y_pred\")\n    axd.set_title(f\"{title} (DFN)\")\n    axm.set_title(f\"{title} (MLP)\")\n\n    disp = None\n    if _is_notebook():\n        from IPython.display import display\n        disp = display(fig, display_id=True)\n    else:\n        plt.ion(); plt.show(block=False)\n    return fig, (axd, axm), (scd, scm), disp\n\n@torch.no_grad()\ndef update_scatter(fig, axes, scs, disp, y_true_cpu, yp_dfn_cpu, yp_mlp_cpu, dfn_mse=None, mlp_mse=None):\n    scd, scm = scs\n    scd.set_offsets(np.stack([y_true_cpu, yp_dfn_cpu], 1))\n    scm.set_offsets(np.stack([y_true_cpu, yp_mlp_cpu], 1))\n    axd, axm = axes\n    if dfn_mse is not None:\n        
axd.set_title(f\"Val scatter (DFN) mse={dfn_mse:.4f}\")\n    if mlp_mse is not None:\n        axm.set_title(f\"Val scatter (MLP) mse={mlp_mse:.4f}\")\n    fig.canvas.draw()\n    if disp is not None:\n        disp.update(fig)\n    else:\n        fig.canvas.flush_events()\n        plt.pause(0.001)\n\n# -------------------- training test --------------------\ndef training_test(\n    graph_spec,\n    *,\n    device=\"cpu\",\n    d=10,\n    N=4096,\n    steps=2000,\n    B=32,\n    split=0.8,\n    solver=\"gurobi\",\n    precision_digits=3,\n    eps_pos=1e-6,\n    dfn_lrs=None,\n    mlp_lr=1e-3,\n    update_every=20,\n    scatter_n=512,\n    x_min=-20,\n    x_max=20,\n    seed=0,\n):\n    '''\n    Train DFN and an MLP baseline on the same synthetic regression target.\n    This is meant as a sanity test for learning dynamics (like in the original notebooks).\n    '''\n    device = torch.device(device)\n    X, y = make_int_quadratic_data(N=N, d=d, x_min=x_min, x_max=x_max, seed=seed, device=device)\n    perm = torch.randperm(N, device=device)\n    ntr = int(split * N)\n    tr, va = perm[:ntr], perm[ntr:]\n    Xtr, ytr = X[tr], y[tr]\n    Xva, yva = X[va], y[va]\n\n    dfn = DFN(graph_spec, input_dim=d, solver=solver, precision_digits=precision_digits, eps_pos=eps_pos, affine_obj=True).to(device)\n    mlp = MLP(d, h=128).to(device)\n\n    dfn_lrs = dfn_lrs or dict(cost=5e-2, cap=5e-2, A=5e-2, b=5e-2, alpha=1e-2, beta=1e-2)\n    dfn_params = [\n        {\"params\": [dfn.c_raw], \"lr\": dfn_lrs[\"cost\"]},\n        {\"params\": [dfn.u_raw], \"lr\": dfn_lrs[\"cap\"]},\n        {\"params\": [dfn.A],     \"lr\": dfn_lrs[\"A\"]},\n        {\"params\": [dfn.b],     \"lr\": dfn_lrs[\"b\"]},\n    ]\n    if getattr(dfn, \"affine_obj\", False):\n        dfn_params += [\n            {\"params\": [dfn.alpha], \"lr\": dfn_lrs.get(\"alpha\", 1e-2)},\n            {\"params\": [dfn.beta],  \"lr\": dfn_lrs.get(\"beta\",  1e-2)},\n        ]\n    opt_dfn = torch.optim.Adam(dfn_params, 
betas=(0.9, 0.999), eps=1e-8)\n    opt_mlp = torch.optim.Adam(mlp.parameters(), lr=mlp_lr)\n\n    fig_l, axes_l, lines_l, disp_l = setup_live_loss(f\"Live MSE (solver={solver}, p={precision_digits})\")\n\n    ns = min(int(scatter_n), Xva.size(0))\n    sc_idx = torch.randperm(Xva.size(0), device=device)[:ns]\n    Xsc, ysc = Xva[sc_idx], yva[sc_idx]\n    ysc_cpu = ysc.detach().cpu().numpy()\n    mn, mx = float(ysc_cpu.min()), float(ysc_cpu.max())\n    fig_s, axes_s, scs_s, disp_s = setup_scatter(ysc_cpu, mn, mx, title=\"Val scatter\")\n\n    mse = nn.MSELoss()\n    steps_x, dtr, dva, mtr, mva = [], [], [], [], []\n\n    for t in range(1, steps + 1):\n        idx = torch.randint(0, Xtr.size(0), (B,), device=device)\n        xb, yb = Xtr[idx], ytr[idx]\n\n        # DFN\n        opt_dfn.zero_grad(set_to_none=True)\n        yp_d = dfn(xb)\n        loss_d = mse(yp_d, yb)\n        loss_d.backward()\n        opt_dfn.step()\n\n        # MLP\n        opt_mlp.zero_grad(set_to_none=True)\n        yp_m = mlp(xb)\n        loss_m = mse(yp_m, yb)\n        loss_m.backward()\n        opt_mlp.step()\n\n        if (t % update_every) == 0 or t == 1:\n            with torch.no_grad():\n                n_eval = min(1024, Xtr.size(0))\n                yp_d_tr = dfn(Xtr[:n_eval])\n                yp_m_tr = mlp(Xtr[:n_eval])\n                dtr_mse = float(mse(yp_d_tr, ytr[:n_eval]).item())\n                mtr_mse = float(mse(yp_m_tr, ytr[:n_eval]).item())\n\n                yp_d_va = dfn(Xva)\n                yp_m_va = mlp(Xva)\n                dva_mse = float(mse(yp_d_va, yva).item())\n                mva_mse = float(mse(yp_m_va, yva).item())\n\n                steps_x.append(t)\n                dtr.append(dtr_mse); dva.append(dva_mse)\n                mtr.append(mtr_mse); mva.append(mva_mse)\n\n                update_live_loss(fig_l, axes_l, lines_l, disp_l, steps_x, dtr, dva, mtr, mva)\n\n                # scatter snapshot\n                yp_d_sc = dfn(Xsc).detach().cpu().numpy()\n    
            yp_m_sc = mlp(Xsc).detach().cpu().numpy()\n                update_scatter(fig_s, axes_s, scs_s, disp_s, ysc_cpu, yp_d_sc, yp_m_sc, dfn_mse=dva_mse, mlp_mse=mva_mse)\n\n    return dict(\n        steps=steps_x,\n        dfn_train=dtr, dfn_val=dva,\n        mlp_train=mtr, mlp_val=mva,\n        dfn=dfn, mlp=mlp,\n    )\n"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "### Example run\n\nTry a quick run first, then increase `steps`.\n"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\ng = make_multilayer([3, 4, 3], var_frac=0.5)\n\n# Gurobi path (continuous) – should train smoothly\nlogs = training_test(\n    g,\n    device=device,\n    d=10,\n    N=2000,\n    steps=400,      # increase to 2000+ for clearer curves\n    B=32,\n    solver=\"gurobi\",\n    precision_digits=3,\n    update_every=20,\n)\n"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}