{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ea8429b2",
   "metadata": {},
   "source": [
    "# Gradient Descent Recovery\n",
    "\n",
    "We consider fixed-step gradient descent on a convex $L$-smooth function $f$. The reference point $x_\\star$ is stationary, $\\nabla f(x_\\star)=0$, and the initialization satisfies $\\|x_0-x_\\star\\| \\le R$. The performance metric is $f(x_N)-f(x_\\star)$.\n",
    "\n",
    "$$\n",
    "x_{k+1}=x_k-\\frac{1}{L}\\nabla f(x_k), \\qquad k=0,1,\\ldots,N-1.\n",
    "$$\n",
    "\n",
    "The Block 1 numerical PEP sweep suggests the tight rate\n",
    "$$\n",
    "f(x_N)-f(x_\\star) \\le \\frac{L R^2}{4N+2}.\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c3449bc",
   "metadata": {},
   "source": [
    "## Proof Statement"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c8939000",
   "metadata": {},
   "source": [
    "### Theorem\n",
    "\n",
    "Let $f$ be convex and $1$-smooth, and let $x_{\\star}$ satisfy $\\nabla f(x_{\\star})=0$. Consider gradient descent\n",
    "\n",
    "$$\n",
    "x_{k+1}=x_k-\\nabla f(x_k),\n",
    "$$\n",
    "\n",
    "with $\\|x_0-x_{\\star}\\|\\le R$. For $1\\le k<N$, define\n",
    "\n",
    "$$\n",
    "V_k=\\frac{k}{2N-k+1}\\bigl(f(x_k)-f(x_{\\star})\\bigr)-\\frac{1}{4N+2}\\|x_0-x_{\\star}\\|^2+\\frac{k}{2(2N-k+1)}\\|\\nabla f(x_k)\\|^2+\\frac{2N-2k+1}{2(2N-k+1)^2}\\|x_k-x_{\\star}\\|^2.\n",
    "$$\n",
    "\n",
    "At the terminal index,\n",
    "\n",
    "$$\n",
    "V_N=f(x_N)-f(x_{\\star})-\\frac{1}{4N+2}\\|x_0-x_{\\star}\\|^2.\n",
    "$$\n",
    "\n",
    "Then $V_N\\le0$, and therefore\n",
    "\n",
    "$$\n",
    "f(x_N)-f(x_{\\star})\\le \\frac{R^2}{4N+2}.\n",
    "$$\n",
    "\n",
    "For a general smoothness parameter $L$, rescaling gives\n",
    "\n",
    "$$\n",
    "f(x_N)-f(x_{\\star})\\le \\frac{L R^2}{4N+2}.\n",
    "$$"
   ]
  },
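  {
   "cell_type": "markdown",
   "id": "rescale-sketch",
   "metadata": {},
   "source": [
    "The rescaling step can be spelled out as follows. This is a short sketch that uses only the theorem above; the auxiliary function $g$ is notation introduced here for illustration.\n",
    "\n",
    "Let $f$ be convex and $L$-smooth and set $g=f/L$. Then $g$ is convex and $1$-smooth, $\\nabla g(x_{\\star})=0$, and\n",
    "\n",
    "$$\n",
    "x_{k+1}=x_k-\\frac{1}{L}\\nabla f(x_k)=x_k-\\nabla g(x_k),\n",
    "$$\n",
    "\n",
    "so the iterates of gradient descent on $f$ with step $1/L$ coincide with the unit-step iterates on $g$. Applying the theorem to $g$ and multiplying by $L$,\n",
    "\n",
    "$$\n",
    "f(x_N)-f(x_{\\star})=L\\bigl(g(x_N)-g(x_{\\star})\\bigr)\\le \\frac{L R^2}{4N+2}.\n",
    "$$"
   ]
  },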
  {
   "cell_type": "markdown",
   "id": "b46c605f",
   "metadata": {},
   "source": [
    "### Proof outline\n",
    "\n",
    "Write the smooth-convex interpolation residual as\n",
    "\n",
    "$$\n",
    "I_f(u,v)=f(v)-f(u)+\\langle \\nabla f(v),u-v\\rangle+\\frac12\\|\\nabla f(u)-\\nabla f(v)\\|^2\\le0.\n",
    "$$\n",
    "\n",
    "Let\n",
    "\n",
    "$$\n",
    "\\alpha_k=\\frac{k}{2N-k+1},\\qquad\n",
    "\\delta_k=\\frac{4Nk+2N-2k^2+1}{2(2N-k)^2}.\n",
    "$$\n",
    "\n",
    "The base identity is\n",
    "\n",
    "$$\n",
    "V_1=\\frac{1}{2N}I_f(x_0,x_1)+\\frac{1}{2N}I_f(x_{\\star},x_0)-\\frac{2N+1}{8N^2}\\left\\|\\nabla f(x_0)-\\frac{x_0-x_{\\star}}{2N+1}\\right\\|^2.\n",
    "$$\n",
    "\n",
    "For $1\\le k<N-1$, the step identity is\n",
    "\n",
    "$$\n",
    "V_{k+1}-V_k=\\alpha_{k+1}I_f(x_k,x_{k+1})+(\\alpha_{k+1}-\\alpha_k)I_f(x_{\\star},x_k)-\\delta_k\\left\\|\\nabla f(x_k)-\\frac{x_k-x_{\\star}}{2N-k+1}\\right\\|^2.\n",
    "$$\n",
    "\n",
    "Every interpolation residual is nonpositive, every multiplier is nonnegative, and each square term is nonnegative. Hence $V_1\\le0$ and $V_{k+1}\\le V_k$ for $1\\le k<N-1$. The terminal step $k=N-1$ adds two boundary contributions to this increment, the term $(1-\\alpha_N)I_f(x_{\\star},x_N)$ and one more subtracted square, both nonpositive, so $V_N\\le V_{N-1}\\le0$. The boundary identity is exactly\n",
    "\n",
    "$$\n",
    "V_N=f(x_N)-f(x_{\\star})-\\frac{1}{4N+2}\\|x_0-x_{\\star}\\|^2,\n",
    "$$\n",
    "\n",
    "which gives the stated bound."
   ]
  },
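  {
   "cell_type": "markdown",
   "id": "base-identity-check",
   "metadata": {},
   "source": [
    "As a sanity check, the base identity can be verified by matching coefficients by hand. This is a worked sketch; the shorthand $g_0=\\nabla f(x_0)$, $g_1=\\nabla f(x_1)$, $d_0=x_0-x_{\\star}$ is introduced here for brevity and is not used elsewhere in the notebook.\n",
    "\n",
    "With unit step, $x_0-x_1=g_0$ and $x_1-x_{\\star}=d_0-g_0$. The function values on the right-hand side combine to $\\frac{1}{2N}\\bigl(f(x_1)-f(x_{\\star})\\bigr)$, and the two residuals contribute the quadratic part\n",
    "\n",
    "$$\n",
    "\\frac{1}{2N}\\|g_0\\|^2+\\frac{1}{4N}\\|g_1\\|^2-\\frac{1}{2N}\\langle g_0,d_0\\rangle.\n",
    "$$\n",
    "\n",
    "Subtracting the expanded square leaves\n",
    "\n",
    "$$\n",
    "\\frac{2N-1}{8N^2}\\|g_0\\|^2-\\frac{2N-1}{4N^2}\\langle g_0,d_0\\rangle-\\frac{1}{8N^2(2N+1)}\\|d_0\\|^2+\\frac{1}{4N}\\|g_1\\|^2.\n",
    "$$\n",
    "\n",
    "On the other side, the quadratic part of $V_1$ from the theorem is\n",
    "\n",
    "$$\n",
    "-\\frac{1}{4N+2}\\|d_0\\|^2+\\frac{1}{4N}\\|g_1\\|^2+\\frac{2N-1}{8N^2}\\|d_0-g_0\\|^2,\n",
    "$$\n",
    "\n",
    "and the two expressions agree because\n",
    "\n",
    "$$\n",
    "-\\frac{1}{4N+2}+\\frac{2N-1}{8N^2}=\\frac{-4N^2+(2N-1)(2N+1)}{8N^2(2N+1)}=-\\frac{1}{8N^2(2N+1)}.\n",
    "$$"
   ]
  },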
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d59fe397",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-13T15:46:36.350175Z",
     "iopub.status.busy": "2026-05-13T15:46:36.349930Z",
     "iopub.status.idle": "2026-05-13T15:46:41.279379Z",
     "shell.execute_reply": "2026-05-13T15:46:41.278490Z"
    }
   },
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import sys\n",
    "\n",
    "ROOT = Path.cwd().resolve()\n",
    "while not (ROOT / \"pyproject.toml\").exists() and ROOT != ROOT.parent:\n",
    "    ROOT = ROOT.parent\n",
    "if str(ROOT) not in sys.path:\n",
    "    sys.path.insert(0, str(ROOT))\n",
    "ALGO_DIR = ROOT / \"examples\" / \"gd_recover\"\n",
    "STATE_DIR = ALGO_DIR / \"state\"\n",
    "\n",
    "import matplotlib  # noqa: E402\n",
    "\n",
    "matplotlib.use(\"Agg\")\n",
    "import matplotlib.pyplot as plt  # noqa: E402\n",
    "import pepflow as pf  # noqa: E402"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20943a2d",
   "metadata": {},
   "source": [
    "## PEP Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "880be6fc",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-13T15:46:41.283076Z",
     "iopub.status.busy": "2026-05-13T15:46:41.282232Z",
     "iopub.status.idle": "2026-05-13T15:46:41.289982Z",
     "shell.execute_reply": "2026-05-13T15:46:41.289121Z"
    }
   },
   "outputs": [],
   "source": [
    "L = pf.Parameter(\"L\")\n",
    "R = pf.Parameter(\"R\")\n",
    "f = pf.SmoothConvexFunction(is_basis=True, tags=[\"f\"], L=L)\n",
    "\n",
    "\n",
    "def make_ctx_gd_recover(ctx_name: str, N) -> pf.PEPContext:\n",
    "    ctx = pf.PEPContext(ctx_name).set_as_current()\n",
    "    x = pf.Vector(is_basis=True, tags=[\"x_0\"])\n",
    "    f.set_stationary_point(\"x_star\")\n",
    "\n",
    "    for i in range(int(N)):\n",
    "        x = x - (1 / L) * f.grad(x)\n",
    "        x.add_tag(f\"x_{i + 1}\")\n",
    "\n",
    "    return ctx\n",
    "\n",
    "\n",
    "def get_pep_setup(N, params):\n",
    "    ctx = make_ctx_gd_recover(f\"ctx_{N}\", N)\n",
    "    pb = pf.PEPBuilder(ctx)\n",
    "    pb.add_initial_constraint(\n",
    "        ((ctx[\"x_0\"] - ctx[\"x_star\"]) ** 2).le(R**2, name=\"initial_condition\")\n",
    "    )\n",
    "    pb.set_performance_metric(f(ctx[f\"x_{N}\"]) - f(ctx[\"x_star\"]))\n",
    "    return ctx, pb, f"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "900794fa",
   "metadata": {},
   "source": [
    "## Numerical Evidence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "eca994ac",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-13T15:46:41.292644Z",
     "iopub.status.busy": "2026-05-13T15:46:41.292424Z",
     "iopub.status.idle": "2026-05-13T15:46:41.316952Z",
     "shell.execute_reply": "2026-05-13T15:46:41.315464Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "N=1: PEP=0.16666647, candidate=0.16666667\n",
      "N=2: PEP=0.09999779, candidate=0.10000000\n",
      "N=3: PEP=0.07142958, candidate=0.07142857\n",
      "N=4: PEP=0.05555563, candidate=0.05555556\n",
      "N=5: PEP=0.04545414, candidate=0.04545455\n",
      "N=6: PEP=0.03846168, candidate=0.03846154\n",
      "N=7: PEP=0.03333328, candidate=0.03333333\n"
     ]
    }
   ],
   "source": [
    "sweep_results = [\n",
    "    {\"N\": 1, \"opt_value\": 0.16666647215955557},\n",
    "    {\"N\": 2, \"opt_value\": 0.09999778825848792},\n",
    "    {\"N\": 3, \"opt_value\": 0.07142958122711231},\n",
    "    {\"N\": 4, \"opt_value\": 0.05555562600367351},\n",
    "    {\"N\": 5, \"opt_value\": 0.04545413917119227},\n",
    "    {\"N\": 6, \"opt_value\": 0.03846167786261405},\n",
    "    {\"N\": 7, \"opt_value\": 0.03333327804360977},\n",
    "]\n",
    "\n",
    "Ns = [row[\"N\"] for row in sweep_results]\n",
    "pep_values = [row[\"opt_value\"] for row in sweep_results]\n",
    "candidate = [1 / (4 * N + 2) for N in Ns]\n",
    "\n",
    "for N, value, rate in zip(Ns, pep_values, candidate):\n",
    "    print(f\"N={N}: PEP={value:.8f}, candidate={rate:.8f}\")\n",
    "\n",
    "plt.figure(figsize=(6, 4))\n",
    "plt.plot(Ns, pep_values, \"o\", label=\"PEP value\")\n",
    "plt.plot(Ns, candidate, \"-\", label=r\"$1/(4N+2)$ for $L=R=1$\")\n",
    "plt.xlabel(\"N\")\n",
    "plt.ylabel(r\"$f(x_N)-f(x_\\star)$\")\n",
    "plt.legend()\n",
    "plt.grid(True, alpha=0.3)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "974212bc",
   "metadata": {},
   "source": [
    "## Dense And Relaxed Proof Solves\n",
    "\n",
    "At `N=4`, the dense solve and the sparse relaxed solve both match the candidate value `1/18` to solver tolerance (absolute error below `1e-7`). The relaxation keeps only the forward consecutive interpolation inequalities and the `x_star` row."
   ]
  },
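  {
   "cell_type": "markdown",
   "id": "dropped-count-check",
   "metadata": {},
   "source": [
    "The number of dropped constraints can be cross-checked by hand. The count below assumes the dense problem carries one interpolation inequality per ordered pair of distinct points, which is the usual PEP formulation but is not stated explicitly in this notebook.\n",
    "\n",
    "There are $N+2$ points ($x_0,\\ldots,x_N$ and $x_\\star$), hence $(N+2)(N+1)$ ordered pairs. The relaxation keeps $N$ forward consecutive pairs and the $N+1$ pairs in the `x_star` row, so the number of dropped inequalities is\n",
    "\n",
    "$$\n",
    "(N+2)(N+1)-(2N+1)=N^2+N+1,\n",
    "$$\n",
    "\n",
    "which equals $30-9=21$ at $N=4$, in agreement with the count stored in `b2['relaxed_constraints']`."
   ]
  },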
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6db0ef62",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-13T15:46:41.320020Z",
     "iopub.status.busy": "2026-05-13T15:46:41.319777Z",
     "iopub.status.idle": "2026-05-13T15:46:41.326345Z",
     "shell.execute_reply": "2026-05-13T15:46:41.325202Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dense opt:   0.0555556260\n",
      "relaxed opt: 0.0555555483\n",
      "target:      0.0555555556\n",
      "relaxed constraints dropped: 21\n",
      "basis vectors: ['x_0', 'x_star', 'grad_f(x_0)', 'grad_f(x_1)', 'grad_f(x_2)', 'grad_f(x_3)', 'grad_f(x_4)']\n"
     ]
    }
   ],
   "source": [
    "import itertools\n",
    "import json\n",
    "import numpy as np\n",
    "import sympy as sp\n",
    "\n",
    "# Load stored solver state; read_text() avoids leaving file handles open.\n",
    "b2 = json.loads((STATE_DIR / \"gd_recover_b2.json\").read_text())\n",
    "dense = json.loads((STATE_DIR / \"gd_recover_dense.json\").read_text())\n",
    "relaxed = json.loads((STATE_DIR / \"gd_recover_relaxed.json\").read_text())\n",
    "\n",
    "target = 1 / (4 * b2[\"N_verify\"] + 2)\n",
    "print(f\"dense opt:   {dense['opt_value']:.10f}\")\n",
    "print(f\"relaxed opt: {relaxed['opt_value']:.10f}\")\n",
    "print(f\"target:      {target:.10f}\")\n",
    "print(f\"relaxed constraints dropped: {len(b2['relaxed_constraints'])}\")\n",
    "print(\"basis vectors:\", b2[\"basis_vectors\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ade15cde",
   "metadata": {},
   "source": [
    "## Closed-Form Lambda And S Verification\n",
    "\n",
    "Let $a_i=(i+1)/(2N-i)$ for $0\\le i<N$, so that $a_i=\\alpha_{i+1}$ in the notation of the proof outline. The active multipliers are\n",
    "\n",
    "$$\n",
    "\\lambda_{x_i,x_{i+1}}=a_i,\\qquad\n",
    "\\lambda_{x_\\star,x_j}=a_j-a_{j-1},\\qquad\n",
    "\\lambda_{x_\\star,x_N}=1-a_{N-1},\n",
    "$$\n",
    "\n",
    "for $0\\le i<N$ and $0\\le j<N$, with the convention $a_{-1}=0$; all other entries are relaxed away or zero."
   ]
  },
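  {
   "cell_type": "markdown",
   "id": "lambda-n4-values",
   "metadata": {},
   "source": [
    "For orientation, here are the closed-form multipliers evaluated by hand at $N=4$, the value used for verification in this notebook. This is only a worked instance of the formulas above.\n",
    "\n",
    "$$\n",
    "a_0=\\tfrac18,\\qquad a_1=\\tfrac27,\\qquad a_2=\\tfrac12,\\qquad a_3=\\tfrac45,\n",
    "$$\n",
    "\n",
    "so the forward chain carries $0.125,\\ 0.286,\\ 0.5,\\ 0.8$ and the $x_\\star$ row is\n",
    "\n",
    "$$\n",
    "\\tfrac18,\\qquad \\tfrac27-\\tfrac18=\\tfrac{9}{56}\\approx0.161,\\qquad \\tfrac12-\\tfrac27=\\tfrac{3}{14}\\approx0.214,\\qquad \\tfrac45-\\tfrac12=\\tfrac{3}{10},\\qquad 1-\\tfrac45=\\tfrac15.\n",
    "$$\n",
    "\n",
    "The $x_\\star$ row telescopes, $\\sum_{j=0}^{N}\\lambda_{x_\\star,x_j}=a_{N-1}+(1-a_{N-1})=1$, for every $N$."
   ]
  },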
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c65588a3",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-13T15:46:41.328622Z",
     "iopub.status.busy": "2026-05-13T15:46:41.328418Z",
     "iopub.status.idle": "2026-05-13T15:46:41.339646Z",
     "shell.execute_reply": "2026-05-13T15:46:41.338573Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "lambda max residual: 6.257644294405118e-08\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{c|cccccc}\n",
       "         & x_0 & x_1 & x_2 & x_3 & x_4 & x_\\star \\\\\n",
       "        \\hline\n",
       "        x_0 & 0.0 & 0.125 & 0.0 & 0.0 & 0.0 & 0.0 \\\\x_1 & 0.0 & 0.0 & 0.286 & 0.0 & 0.0 & 0.0 \\\\x_2 & 0.0 & 0.0 & 0.0 & 0.5 & 0.0 & 0.0 \\\\x_3 & 0.0 & 0.0 & 0.0 & 0.0 & 0.8 & 0.0 \\\\x_4 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 \\\\x_\\star & 0.125 & 0.161 & 0.214 & 0.3 & 0.2 & 0.0 \\\\\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "N_int = b2[\"N_verify\"]\n",
    "row_names = b2[\"lambda_row_names\"]\n",
    "col_names = b2[\"lambda_col_names\"]\n",
    "\n",
    "\n",
    "def idx(tag, N=N_int):\n",
    "    return N + 1 if tag == \"x_star\" else int(tag.split(\"_\")[1])\n",
    "\n",
    "\n",
    "def chain_weight(i, N=N_int):\n",
    "    return sp.Rational(i + 1, 2 * N - i)\n",
    "\n",
    "\n",
    "def lamb(ri, ci, N=N_int):\n",
    "    i, j = idx(ri, N), idx(ci, N)\n",
    "    if 0 <= i < N and j == i + 1:\n",
    "        return chain_weight(i, N)\n",
    "    if ri == \"x_star\" and 0 <= j < N:\n",
    "        prev = chain_weight(j - 1, N) if j > 0 else sp.S(0)\n",
    "        return chain_weight(j, N) - prev\n",
    "    if ri == \"x_star\" and j == N:\n",
    "        return 1 - chain_weight(N - 1, N)\n",
    "    return sp.S(0)\n",
    "\n",
    "\n",
    "lambda_candidate = np.array(\n",
    "    [[float(lamb(ri, ci)) for ci in col_names] for ri in row_names]\n",
    ")\n",
    "lambda_solver = np.array(b2[\"lambda_matrix\"])\n",
    "print(\"lambda max residual:\", np.max(np.abs(lambda_candidate - lambda_solver)))\n",
    "pf.pprint_labeled_matrix(lambda_candidate, row_names, col_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "eaed3441",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-13T15:46:41.341897Z",
     "iopub.status.busy": "2026-05-13T15:46:41.341561Z",
     "iopub.status.idle": "2026-05-13T15:46:41.443451Z",
     "shell.execute_reply": "2026-05-13T15:46:41.442096Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "S max residual: 6.275168507041684e-08\n",
      "S eigenvalues: [-7.10502237e-17  7.89848048e-17  6.60743053e-02  2.17642777e-01\n",
      "  3.95075951e-01  5.18851503e-01  1.12418086e+00]\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{c|ccccccc}\n",
       "         & x_0 & x_\\star & \\nabla f(x_0) & \\nabla f(x_1) & \\nabla f(x_2) & \\nabla f(x_3) & \\nabla f(x_4) \\\\\n",
       "        \\hline\n",
       "        x_0 & 0.056 & -0.056 & -0.062 & -0.08 & -0.107 & -0.15 & -0.1 \\\\x_\\star & -0.056 & 0.056 & 0.062 & 0.08 & 0.107 & 0.15 & 0.1 \\\\\\nabla f(x_0) & -0.062 & 0.062 & 0.125 & 0.08 & 0.107 & 0.15 & 0.1 \\\\\\nabla f(x_1) & -0.08 & 0.08 & 0.08 & 0.286 & 0.107 & 0.15 & 0.1 \\\\\\nabla f(x_2) & -0.107 & 0.107 & 0.107 & 0.107 & 0.5 & 0.15 & 0.1 \\\\\\nabla f(x_3) & -0.15 & 0.15 & 0.15 & 0.15 & 0.15 & 0.8 & 0.1 \\\\\\nabla f(x_4) & -0.1 & 0.1 & 0.1 & 0.1 & 0.1 & 0.1 & 0.5 \\\\\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "ctx, pb, obj = get_pep_setup(sp.S(N_int), {\"L\": sp.S(1), \"R\": sp.S(1)})\n",
    "pm = pf.ExpressionManager(ctx, resolve_parameters={\"L\": sp.S(1), \"R\": sp.S(1)})\n",
    "\n",
    "interp_sum = pf.Scalar.zero()\n",
    "for ri, ci in itertools.product(row_names, col_names):\n",
    "    c = lamb(ri, ci, N_int)\n",
    "    if c != 0:\n",
    "        interp_sum += c * obj.interp_ineq(ri, ci)\n",
    "\n",
    "tau = sp.Rational(1, 4 * N_int + 2)\n",
    "x_N, x_0, x_star = ctx[f\"x_{N_int}\"], ctx[\"x_0\"], ctx[\"x_star\"]\n",
    "LHS = obj(x_N) - obj(x_star) - tau * (x_0 - x_star) ** 2\n",
    "S_guess = interp_sum - LHS\n",
    "S_candidate = np.array(pm.eval_scalar(S_guess).inner_prod_coords, dtype=float)\n",
    "S_solver = np.array(b2[\"S_matrix\"])\n",
    "\n",
    "print(\"S max residual:\", np.max(np.abs(S_candidate - S_solver)))\n",
    "print(\"S eigenvalues:\", np.linalg.eigvalsh(S_candidate))\n",
    "pf.pprint_labeled_matrix(S_candidate, b2[\"S_row_names\"], b2[\"S_col_names\"])"
   ]
  },
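  {
   "cell_type": "markdown",
   "id": "c9eef029",
   "metadata": {},
   "source": [
    "## Fixed-N Full Proof Identity\n",
    "\n",
    "The identity checked below is `LHS - interp_sum + S_guess = 0`. Because `S_guess` was defined above as `interp_sum - LHS`, the identity holds by construction and serves as a consistency check on the symbolic expressions. The substantive conditions are that `S_guess` is positive semidefinite (its eigenvalues above are nonnegative up to rounding) and that the multipliers are nonnegative."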
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4e07328d",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-13T15:46:41.445696Z",
     "iopub.status.busy": "2026-05-13T15:46:41.445477Z",
     "iopub.status.idle": "2026-05-13T15:46:41.455673Z",
     "shell.execute_reply": "2026-05-13T15:46:41.454783Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Proof valid: True\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{c|ccccccc}\n",
       "         & x_0 & x_\\star & \\nabla f(x_0) & \\nabla f(x_1) & \\nabla f(x_2) & \\nabla f(x_3) & \\nabla f(x_4) \\\\\n",
       "        \\hline\n",
       "        x_0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 \\\\x_\\star & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 \\\\\\nabla f(x_0) & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 \\\\\\nabla f(x_1) & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 \\\\\\nabla f(x_2) & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 \\\\\\nabla f(x_3) & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 \\\\\\nabla f(x_4) & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 \\\\\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "diff = LHS - interp_sum + S_guess\n",
    "proof_residual = np.array(pm.eval_scalar(diff).inner_prod_coords, dtype=float)\n",
    "print(\"Proof valid:\", np.allclose(proof_residual, 0, atol=1e-8))\n",
    "pf.pprint_labeled_matrix(proof_residual, b2[\"S_row_names\"], b2[\"S_col_names\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9db4444d",
   "metadata": {},
   "source": [
    "## Partial-Sum Lyapunov Construction And Rank Profile\n",
    "\n",
    "Block 3 forms partial sums from the relaxed dual certificate. At step $k$, the increment uses the forward smooth-convex interpolation inequality, the star-row interpolation inequality, and the corresponding LDL square from the PSD residual:\n",
    "\n",
    "$$\n",
    "V_{k+1}-V_k = \\lambda_{x_k,x_{k+1}} I_f(x_k,x_{k+1}) + \\lambda_{x_\\star,x_k} I_f(x_\\star,x_k) - d_{N-k}\\,\\|\\ell_{N-k}\\|^2,\\qquad 0\\le k<N,\\qquad V_0=0.\n",
    "$$\n",
    "\n",
    "At the terminal step $k=N-1$, the remaining $\\lambda_{x_\\star,x_N} I_f(x_\\star,x_N)$ term and the last LDL square $d_0\\|\\ell_0\\|^2$ are included as boundary terms. The interior partial sums $V_1,\\ldots,V_{N-1}$ have constant rank."
   ]
  },
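  {
   "cell_type": "markdown",
   "id": "ldl-pivot-check",
   "metadata": {},
   "source": [
    "The LDL pivots can be compared with the closed form $\\delta_k$ from the proof outline. The identification $d_{N-k}=\\delta_k$ is an observation supported by the stored $N=4$ values, not something proved in this notebook. Evaluating $\\delta_k=\\frac{4Nk+2N-2k^2+1}{2(2N-k)^2}$ at $N=4$ gives\n",
    "\n",
    "$$\n",
    "\\delta_0=\\tfrac{9}{128}=0.0703125,\\qquad\n",
    "\\delta_1=\\tfrac{23}{98}\\approx0.234694,\\qquad\n",
    "\\delta_2=\\tfrac{33}{72}\\approx0.458333,\\qquad\n",
    "\\delta_3=\\tfrac{39}{50}=0.78,\n",
    "$$\n",
    "\n",
    "which match the entries $d_4,d_3,d_2,d_1$ of the LDL diagonal printed in this section to about $10^{-7}$."
   ]
  },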
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "fede4242",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-13T15:46:41.458326Z",
     "iopub.status.busy": "2026-05-13T15:46:41.457714Z",
     "iopub.status.idle": "2026-05-13T15:46:41.674449Z",
     "shell.execute_reply": "2026-05-13T15:46:41.673750Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "extra duals: {}\n",
      "LDL diagonal: [0.500000001, 0.780000075, 0.458333377, 0.234693892, 0.070312497, -0.0, 0.0]\n"
     ]
    }
   ],
   "source": [
    "from pepflow.lyapunov_utils import ldl_deompose_with_reversed_basis\n",
    "\n",
    "b3 = json.loads((STATE_DIR / \"gd_recover_b3.json\").read_text())\n",
    "N_int = b3[\"N_verify\"]\n",
    "rank_tolerance = b3[\"rank_tolerance\"]\n",
    "row_names = b3[\"lambda_row_names\"]\n",
    "col_names = b3[\"lambda_col_names\"]\n",
    "\n",
    "ctx, pb, obj = get_pep_setup(sp.S(N_int), {\"L\": sp.S(1), \"R\": sp.S(1)})\n",
    "pb.set_relaxed_constraints(b3[\"relaxed_constraints\"])\n",
    "result = pb.solve(resolve_parameters={\"L\": sp.S(1), \"R\": sp.S(1)})\n",
    "pm = pf.ExpressionManager(ctx, resolve_parameters={\"L\": sp.S(1), \"R\": sp.S(1)})\n",
    "LT, d, ell = ldl_deompose_with_reversed_basis(\n",
    "    result.get_gram_dual_matrix(), ctx.basis_vectors(), print_output=False\n",
    ")\n",
    "\n",
    "extra_duals = {}\n",
    "try:\n",
    "    for name in result.dual_var_manager.names():\n",
    "        if name == \"initial_condition\" or name.startswith(f\"{b3['obj_tag']}:\"):\n",
    "            continue\n",
    "        value = result.dual_var_manager.dual_value(name)\n",
    "        if value is not None and abs(float(value)) > 1e-6:\n",
    "            extra_duals[name] = float(value)\n",
    "except AttributeError:\n",
    "    pass\n",
    "\n",
    "print(\"extra duals:\", extra_duals)\n",
    "print(\"LDL diagonal:\", [round(float(d[i, i]), 9) for i in range(d.shape[0])])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "dbc2a89d",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-13T15:46:41.677115Z",
     "iopub.status.busy": "2026-05-13T15:46:41.676671Z",
     "iopub.status.idle": "2026-05-13T15:46:41.717175Z",
     "shell.execute_reply": "2026-05-13T15:46:41.715534Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "rank V_0: 0\n",
      "\n",
      "rank V_1: 3\n",
      "rank V_2: 3\n",
      "rank V_3: 3\n",
      "rank V_4: 1\n",
      "Interior rank is constant: True\n",
      "Boundary rank: 1\n",
      "Stored rank profile matches: True\n"
     ]
    }
   ],
   "source": [
    "lyap = [pf.Scalar.zero()]\n",
    "partial_sum = pf.Scalar.zero()\n",
    "\n",
    "for step in range(N_int):\n",
    "    s_idx = N_int - step\n",
    "    if 0 <= s_idx < len(ell) and abs(float(d[s_idx, s_idx])) > 1e-8:\n",
    "        partial_sum = partial_sum - float(d[s_idx, s_idx]) * ell[s_idx] ** 2\n",
    "\n",
    "    for ri, ci in itertools.product(row_names, col_names):\n",
    "        coeff = lamb(ri, ci, N_int)\n",
    "        if coeff == 0:\n",
    "            continue\n",
    "        include = ri == f\"x_{step}\" and ci == f\"x_{step + 1}\"\n",
    "        include = include or (ri == \"x_star\" and ci == f\"x_{step}\")\n",
    "        include = include or (\n",
    "            step == N_int - 1 and ri == \"x_star\" and ci == f\"x_{N_int}\"\n",
    "        )\n",
    "        if include:\n",
    "            partial_sum = partial_sum + coeff * obj.interp_ineq(ri, ci)\n",
    "\n",
    "    if step == N_int - 1 and abs(float(d[0, 0])) > 1e-8:\n",
    "        partial_sum = partial_sum - float(d[0, 0]) * ell[0] ** 2\n",
    "\n",
    "    lyap.append(partial_sum)\n",
    "\n",
    "ranks = []\n",
    "for k, Vk in enumerate(lyap):\n",
    "    matrix = np.array(pm.eval_scalar(Vk).inner_prod_coords, dtype=float)\n",
    "    rank = int(np.linalg.matrix_rank(matrix, tol=rank_tolerance))\n",
    "    ranks.append(rank)\n",
    "    print(f\"rank V_{k}: {rank}\")\n",
    "    if k == 0:\n",
    "        print()\n",
    "\n",
    "print(\"Interior rank is constant:\", len(set(ranks[1:N_int])) == 1)\n",
    "print(\"Boundary rank:\", ranks[-1])\n",
    "print(\"Stored rank profile matches:\", ranks == b3[\"rank_profile\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "069e5029",
   "metadata": {},
   "source": [
    "## Identify the vectors composing the Lyapunov function\n",
    "\n",
    "Block 4 starts from the Block 3 partial sums and searches for simple rank-spanning vectors. For this GD certificate, the interior partial sums are rank 3 and are spanned by the anchor gap $x_0-x_\\star$, the current gradient $\\nabla f(x_k)$, and the current point-to-solution gap $x_k-x_\\star$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "248b93c8",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-13T15:46:41.719920Z",
     "iopub.status.busy": "2026-05-13T15:46:41.719668Z",
     "iopub.status.idle": "2026-05-13T15:46:41.724773Z",
     "shell.execute_reply": "2026-05-13T15:46:41.724054Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "stored rank profile: [0, 3, 3, 3, 1]\n",
      "interior rank constant: True\n"
     ]
    }
   ],
   "source": [
    "from pepflow.lyapunov_utils import (\n",
    "    vectors_in_column_space,\n",
    "    decompose_rankr_symmetric,\n",
    ")\n",
    "\n",
    "b4 = json.loads((STATE_DIR / \"gd_recover_b4.json\").read_text())\n",
    "print(\"stored rank profile:\", b4[\"rank_profile\"])\n",
    "print(\"interior rank constant:\", b4[\"interior_rank_constant\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb33fe75",
   "metadata": {},
   "source": [
    "### Candidate-vector scan\n",
    "\n",
    "The candidate families are basis vectors, gradients, point-to-solution gaps, anchor-to-iterate gaps, and the LDL vectors exposed by Block 3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "36288c57",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-13T15:46:41.727225Z",
     "iopub.status.busy": "2026-05-13T15:46:41.726986Z",
     "iopub.status.idle": "2026-05-13T15:46:41.742179Z",
     "shell.execute_reply": "2026-05-13T15:46:41.740749Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "candidate count: 18\n"
     ]
    }
   ],
   "source": [
    "candidate_pairs = []\n",
    "\n",
    "\n",
    "def add_candidate(label, vector):\n",
    "    coords = np.array(pm.eval_vector(vector).coords, dtype=float).ravel()\n",
    "    if np.linalg.norm(coords) < 1e-9:\n",
    "        return\n",
    "    for _, _, old_coords in candidate_pairs:\n",
    "        if np.linalg.norm(coords - old_coords) < 1e-9:\n",
    "            return\n",
    "    candidate_pairs.append((label, vector, coords))\n",
    "\n",
    "\n",
    "for vector in ctx.basis_vectors():\n",
    "    add_candidate(str(vector), vector)\n",
    "for i in range(N_int + 1):\n",
    "    add_candidate(f\"x_{i}-x_star\", ctx[f\"x_{i}\"] - ctx[\"x_star\"])\n",
    "    add_candidate(f\"grad_f(x_{i})\", obj.grad(ctx[f\"x_{i}\"]))\n",
    "    add_candidate(f\"x_0-x_{i}\", ctx[\"x_0\"] - ctx[f\"x_{i}\"])\n",
    "for k in range(1, N_int):\n",
    "    add_candidate(f\"ell_for_V_{k}\", ell[N_int - k])\n",
    "\n",
    "candidates = [vector for _, vector, _ in candidate_pairs]\n",
    "label_by_id = {id(vector): label for label, vector, _ in candidate_pairs}\n",
    "print(f\"candidate count: {len(candidates)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "f8a91668",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-13T15:46:41.744328Z",
     "iopub.status.busy": "2026-05-13T15:46:41.744106Z",
     "iopub.status.idle": "2026-05-13T15:46:41.786377Z",
     "shell.execute_reply": "2026-05-13T15:46:41.785624Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "V_1 column-space candidates:\n",
      "   grad_f(x_0)\n",
      "   grad_f(x_1)\n",
      "   x_0-x_star\n",
      "   x_1-x_star\n",
      "   x_2-x_star\n",
      "   x_0-x_2\n",
      "   ell_for_V_1\n",
      "V_2 column-space candidates:\n",
      "   grad_f(x_2)\n",
      "   x_0-x_star\n",
      "   x_2-x_star\n",
      "   x_0-x_2\n",
      "   x_3-x_star\n",
      "   x_0-x_3\n",
      "   ell_for_V_2\n",
      "V_3 column-space candidates:\n",
      "   grad_f(x_3)\n",
      "   x_0-x_star\n",
      "   x_3-x_star\n",
      "   x_0-x_3\n",
      "   x_4-x_star\n",
      "   x_0-x_4\n",
      "   ell_for_V_3\n"
     ]
    }
   ],
   "source": [
    "for k in range(1, N_int):\n",
    "    in_col = vectors_in_column_space(\n",
    "        lyap[k],\n",
    "        candidates,\n",
    "        pep_context=ctx,\n",
    "        resolve_parameters={\"L\": sp.S(1), \"R\": sp.S(1)},\n",
    "        rtol=1e-4,\n",
    "        atol=1e-4,\n",
    "    )\n",
    "    print(f\"V_{k} column-space candidates:\")\n",
    "    for vector in in_col:\n",
    "        print(\"  \", label_by_id.get(id(vector), str(vector)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "911003bb",
   "metadata": {},
   "source": [
    "### Selected basis pattern\n",
    "\n",
    "For $1\\le k<N$, use\n",
    "\n",
    "$$\n",
    "\\mathcal B_k = \\{x_0-x_\\star,\\ \\nabla f(x_k),\\ x_k-x_\\star\\}.\n",
    "$$\n",
    "\n",
    "The terminal boundary term uses the rank-1 basis $\\{x_0-x_\\star\\}$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "aa9187c1",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-13T15:46:41.789126Z",
     "iopub.status.busy": "2026-05-13T15:46:41.788919Z",
     "iopub.status.idle": "2026-05-13T15:46:41.796582Z",
     "shell.execute_reply": "2026-05-13T15:46:41.795410Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "k=1: rank 3 basis ['x_0-x_star', 'grad_f(x_1)', 'x_1-x_star']\n",
      "k=2: rank 3 basis ['x_0-x_star', 'grad_f(x_2)', 'x_2-x_star']\n",
      "k=3: rank 3 basis ['x_0-x_star', 'grad_f(x_3)', 'x_3-x_star']\n",
      "k=4: rank 1 basis ['x_0-x_star']\n"
     ]
    }
   ],
   "source": [
    "def V_k_basis(k):\n",
    "    if 1 <= k < N_int:\n",
    "        return [\n",
    "            ctx[\"x_0\"] - ctx[\"x_star\"],\n",
    "            obj.grad(ctx[f\"x_{k}\"]),\n",
    "            ctx[f\"x_{k}\"] - ctx[\"x_star\"],\n",
    "        ]\n",
    "    if k == N_int:\n",
    "        return [ctx[\"x_0\"] - ctx[\"x_star\"]]\n",
    "    return []\n",
    "\n",
    "\n",
    "def V_k_basis_labels(k):\n",
    "    if 1 <= k < N_int:\n",
    "        return [\"x_0-x_star\", f\"grad_f(x_{k})\", f\"x_{k}-x_star\"]\n",
    "    if k == N_int:\n",
    "        return [\"x_0-x_star\"]\n",
    "    return []\n",
    "\n",
    "\n",
    "for k in range(1, N_int + 1):\n",
    "    basis = V_k_basis(k)\n",
    "    matrix = np.column_stack(\n",
    "        [np.array(pm.eval_vector(v).coords, dtype=float).ravel() for v in basis]\n",
    "    )\n",
    "    print(\n",
    "        f\"k={k}: rank {np.linalg.matrix_rank(matrix, tol=1e-7)} basis {V_k_basis_labels(k)}\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ef80d70",
   "metadata": {},
   "source": [
    "### Coefficient matrices\n",
    "\n",
    "The interior basis order is $(x_0-x_\\star,\\ \\nabla f(x_k),\\ x_k-x_\\star)$. The coefficient matrix is diagonal with entries\n",
    "\n",
    "$$\n",
    "-\\frac{1}{4N+2},\\qquad \\frac{k}{2(2N-k+1)},\\qquad \\frac{2N-2k+1}{2(2N-k+1)^2}.\n",
    "$$"
   ]
  },
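  {
   "cell_type": "markdown",
   "id": "coeff-n4-values",
   "metadata": {},
   "source": [
    "As a worked instance, the three diagonal entries can be evaluated by hand at $N=4$, the verification size used in this notebook. The table is only an evaluation of the formulas above and can be compared against the matrices displayed by the next cell.\n",
    "\n",
    "$$\n",
    "\\begin{array}{c|ccc}\n",
    "k & -\\frac{1}{4N+2} & \\frac{k}{2(2N-k+1)} & \\frac{2N-2k+1}{2(2N-k+1)^2} \\\\\n",
    "\\hline\n",
    "1 & -\\frac{1}{18}\\approx-0.055556 & \\frac{1}{16}=0.0625 & \\frac{7}{128}\\approx0.054688 \\\\\n",
    "2 & -\\frac{1}{18}\\approx-0.055556 & \\frac{1}{7}\\approx0.142857 & \\frac{5}{98}\\approx0.051020 \\\\\n",
    "3 & -\\frac{1}{18}\\approx-0.055556 & \\frac{1}{4}=0.25 & \\frac{1}{24}\\approx0.041667\n",
    "\\end{array}\n",
    "$$"
   ]
  },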
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "d802cb1e",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-13T15:46:41.799026Z",
     "iopub.status.busy": "2026-05-13T15:46:41.798602Z",
     "iopub.status.idle": "2026-05-13T15:46:41.898378Z",
     "shell.execute_reply": "2026-05-13T15:46:41.897578Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "k=1 formula residual: 5.36e-09\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{c|ccc}\n",
       "         & x_0-x_\\star & \\nabla f(x_1) & x_1-x_\\star \\\\\n",
       "        \\hline\n",
       "        x_0-x_\\star & -0.055556 & 0.0 & -0.0 \\\\\\nabla f(x_1) & 0.0 & 0.0625 & -0.0 \\\\x_1-x_\\star & -0.0 & -0.0 & 0.054688 \\\\\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "k=2 formula residual: 7.82e-09\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{c|ccc}\n",
       "         & x_0-x_\\star & \\nabla f(x_2) & x_2-x_\\star \\\\\n",
       "        \\hline\n",
       "        x_0-x_\\star & -0.055556 & -0.0 & 0.0 \\\\\\nabla f(x_2) & -0.0 & 0.142857 & 0.0 \\\\x_2-x_\\star & 0.0 & 0.0 & 0.05102 \\\\\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "k=3 formula residual: 1.84e-08\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{c|ccc}\n",
       "         & x_0-x_\\star & \\nabla f(x_3) & x_3-x_\\star \\\\\n",
       "        \\hline\n",
       "        x_0-x_\\star & -0.055556 & 0.0 & 0.0 \\\\\\nabla f(x_3) & 0.0 & 0.25 & -0.0 \\\\x_3-x_\\star & 0.0 & -0.0 & 0.041667 \\\\\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "k=4 formula residual: 4.22e-10\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{c|c}\n",
       "         & x_0-x_\\star \\\\\n",
       "        \\hline\n",
       "        x_0-x_\\star & -0.055556 \\\\\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def coeff_pattern(k, N):\n",
    "    if 1 <= k < N:\n",
    "        return np.array(\n",
    "            [\n",
    "                [float(sp.Rational(-1, 4 * N + 2)), 0.0, 0.0],\n",
    "                [0.0, float(sp.Rational(k, 2 * (2 * N - k + 1))), 0.0],\n",
    "                [\n",
    "                    0.0,\n",
    "                    0.0,\n",
    "                    float(sp.Rational(2 * N - 2 * k + 1, 2 * (2 * N - k + 1) ** 2)),\n",
    "                ],\n",
    "            ]\n",
    "        )\n",
    "    if k == N:\n",
    "        return np.array([[float(sp.Rational(-1, 4 * N + 2))]])\n",
    "    return np.zeros((0, 0))\n",
    "\n",
    "\n",
    "for k in range(1, N_int + 1):\n",
    "    basis = V_k_basis(k)\n",
    "    labels = V_k_basis_labels(k)\n",
    "    C = decompose_rankr_symmetric(\n",
    "        lyap[k], basis, pep_context=ctx, resolve_parameters={\"L\": sp.S(1), \"R\": sp.S(1)}\n",
    "    )\n",
    "    formula = coeff_pattern(k, N_int)\n",
    "    print(f\"k={k} formula residual: {np.max(np.abs(C - formula)):.2e}\")\n",
    "    pf.pprint_labeled_matrix(C, labels, labels, precision=6)"
   ]
  },
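  {
   "cell_type": "markdown",
   "id": "a3f19c07",
   "metadata": {},
   "source": [
    "As a quick arithmetic cross-check of the tables above, the closed-form entries can be evaluated by hand. This assumes $N=4$, which is inferred from the printed value $-0.055556=-\\tfrac{1}{18}=-\\tfrac{1}{4N+2}$ rather than read from `N_int` directly; $\\beta_k$ and $\\gamma_k$ are labels for the second and third diagonal entries of `coeff_pattern`.\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "k=1:&\\quad \\beta_1=\\frac{1}{2\\cdot 8}=0.0625,\\qquad \\gamma_1=\\frac{7}{2\\cdot 8^2}=\\frac{7}{128}\\approx 0.054688,\\\\\n",
    "k=2:&\\quad \\beta_2=\\frac{2}{2\\cdot 7}=\\frac{1}{7}\\approx 0.142857,\\qquad \\gamma_2=\\frac{5}{2\\cdot 7^2}=\\frac{5}{98}\\approx 0.051020,\\\\\n",
    "k=3:&\\quad \\beta_3=\\frac{3}{2\\cdot 6}=0.25,\\qquad \\gamma_3=\\frac{3}{2\\cdot 6^2}=\\frac{1}{24}\\approx 0.041667.\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "These agree with the displayed diagonals to the printed precision, consistent with the reported residuals of order $10^{-8}$ or smaller."
   ]
  },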
  {
   "cell_type": "markdown",
   "id": "e313a8d9",
   "metadata": {},
   "source": [
    "### Block 4 conclusion\n",
    "\n",
    "Combining the function-coordinate pattern with the quadratic coefficient matrices gives, for $1\\le k<N$,\n",
    "\n",
    "$$\n",
    "V_k=\\frac{k}{2N-k+1}\\bigl(f(x_k)-f(x_{\\star})\\bigr)-\\frac{1}{4N+2}\\|x_0-x_{\\star}\\|^2+\\frac{k}{2(2N-k+1)}\\|\\nabla f(x_k)\\|^2+\\frac{2N-2k+1}{2(2N-k+1)^2}\\|x_k-x_{\\star}\\|^2.\n",
    "$$\n",
    "\n",
    "At $k=N$,\n",
    "\n",
    "$$\n",
    "V_N=f(x_N)-f(x_{\\star})-\\frac{1}{4N+2}\\|x_0-x_{\\star}\\|^2.\n",
    "$$\n",
    "\n",
    "The next cells symbolically verify the one-step recursion, the base identity, and the boundary identity for this closed form."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dd713122",
   "metadata": {},
   "source": [
    "## Symbolic Step Recursion Verification"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b30ad59",
   "metadata": {},
   "source": [
    "For $1\\le k<N-1$, verify\n",
    "\n",
    "$$\n",
    "V_{k+1}-V_k=\\alpha_{k+1}I_f(x_k,x_{k+1})+(\\alpha_{k+1}-\\alpha_k)I_f(x_{\\star},x_k)-\\delta_k\\left\\|\\nabla f(x_k)-\\frac{x_k-x_{\\star}}{2N-k+1}\\right\\|^2.\n",
    "$$\n",
    "\n",
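    "Here $I_f(x,y)$ denotes the interpolation expression returned by `interp_ineq(x, y)`, and the multipliers match `alpha_param` and `delta_param` in the cell below:\n",
    "\n",
    "$$\n",
    "\\alpha_k=\\frac{k}{2N-k+1},\\qquad \\delta_k=\\frac{4Nk+2N-2k^2+1}{2(2N-k)^2}.\n",
    "$$\n",
    "\n",
    "The residual $\\mathrm{LHS}-\\mathrm{RHS}$ should simplify to zero."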
   ]
  },
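  {
   "cell_type": "markdown",
   "id": "b7e1f3a2",
   "metadata": {},
   "source": [
    "As a hedged side calculation (not part of the original derivation), the multipliers in this recursion can be rewritten so that their signs are visible. The starting formulas are read off `alpha_param` and `delta_param`; that each $I_f$ term is nonpositive for a convex $1$-smooth $f$ is an assumption about the sign convention of `interp_ineq`.\n",
    "\n",
    "$$\n",
    "\\alpha_{k+1}-\\alpha_k=\\frac{k+1}{2N-k}-\\frac{k}{2N-k+1}=\\frac{2N+1}{(2N-k)(2N-k+1)},\\qquad \\delta_k=\\frac{2k(2N-k)+2N+1}{2(2N-k)^2}.\n",
    "$$\n",
    "\n",
    "Both are positive for $0\\le k<N$, as is $\\alpha_{k+1}$, so under that sign assumption the identity would give $V_{k+1}\\le V_k$."
   ]
  },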
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "81cb3247",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-13T15:46:41.901724Z",
     "iopub.status.busy": "2026-05-13T15:46:41.901375Z",
     "iopub.status.idle": "2026-05-13T15:46:43.182005Z",
     "shell.execute_reply": "2026-05-13T15:46:43.180494Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step inner residual zero: True\n",
      "Step function residual zero: True\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{c|ccccc}\n",
       "         & x_{0} & x_{k} & x_\\star & \\nabla f_{step}(x_{k}) & \\nabla f_{step}(x_{k+1}) \\\\\n",
       "        \\hline\n",
       "        x_{0} & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 \\\\x_{k} & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 \\\\x_\\star & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 \\\\\\nabla f_{step}(x_{k}) & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 \\\\\\nabla f_{step}(x_{k+1}) & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 \\\\\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
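    "def simplify_residual(scalar, context, resolve):\n",
    "    # Evaluate `scalar` in `context`, then simplify its inner-product and function-value coordinates entrywise.\n",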
    "    manager = pf.ExpressionManager(context, resolve_parameters=resolve)\n",
    "    evaluated = manager.eval_scalar(scalar)\n",
    "    inner = sp.Matrix(evaluated.inner_prod_coords).applyfunc(sp.simplify)\n",
    "    func = sp.Matrix(evaluated.func_coords).applyfunc(sp.simplify)\n",
    "    return inner, func\n",
    "\n",
    "\n",
    "k_symbol, N_symbol = sp.symbols(\"k N\", positive=True, integer=True)\n",
    "k_param = pf.Parameter(\"k\")\n",
    "N_param = pf.Parameter(\"N\")\n",
    "resolve_symbols = {\"k\": k_symbol, \"N\": N_symbol}\n",
    "\n",
    "\n",
    "def alpha_param(t):\n",
    "    return t / (2 * N_param - t + 1)\n",
    "\n",
    "\n",
    "def beta_param(t):\n",
    "    return t / (2 * (2 * N_param - t + 1))\n",
    "\n",
    "\n",
    "def gamma_param(t):\n",
    "    return (2 * N_param - 2 * t + 1) / (2 * (2 * N_param - t + 1) ** 2)\n",
    "\n",
    "\n",
    "def tau_param():\n",
    "    return 1 / (4 * N_param + 2)\n",
    "\n",
    "\n",
    "def delta_param(t):\n",
    "    return (4 * N_param * t + 2 * N_param - 2 * t**2 + 1) / (2 * (2 * N_param - t) ** 2)\n",
    "\n",
    "\n",
    "ctx_step = pf.PEPContext(\"symbolic_step\").set_as_current()\n",
    "f_step = pf.SmoothConvexFunction(is_basis=True, tags=[\"f_{step}\"], L=1)\n",
    "x_0_step = pf.Vector(is_basis=True, tags=[\"x_{0}\"])\n",
    "x_k_step = pf.Vector(is_basis=True, tags=[\"x_{k}\"])\n",
    "x_star_step = f_step.set_stationary_point(\"x_star\")\n",
    "g_k_step = f_step.grad(x_k_step)\n",
    "x_k1_step = x_k_step - g_k_step\n",
    "x_k1_step.add_tag(\"x_{k+1}\")\n",
    "g_k1_step = f_step.grad(x_k1_step)\n",
    "\n",
    "V_k_step = (f_step(x_k_step) - f_step(x_star_step)) * alpha_param(k_param)\n",
    "V_k_step += -((x_0_step - x_star_step) ** 2) * tau_param()\n",
    "V_k_step += (g_k_step**2) * beta_param(k_param)\n",
    "V_k_step += ((x_k_step - x_star_step) ** 2) * gamma_param(k_param)\n",
    "\n",
    "V_k1_step = (f_step(x_k1_step) - f_step(x_star_step)) * alpha_param(k_param + 1)\n",
    "V_k1_step += -((x_0_step - x_star_step) ** 2) * tau_param()\n",
    "V_k1_step += (g_k1_step**2) * beta_param(k_param + 1)\n",
    "V_k1_step += ((x_k1_step - x_star_step) ** 2) * gamma_param(k_param + 1)\n",
    "\n",
    "ell_step = g_k_step - (x_k_step - x_star_step) / (2 * N_param - k_param + 1)\n",
    "RHS_step = f_step.interp_ineq(x_k_step, x_k1_step) * alpha_param(k_param + 1)\n",
    "RHS_step += f_step.interp_ineq(x_star_step, x_k_step) * (\n",
    "    alpha_param(k_param + 1) - alpha_param(k_param)\n",
    ")\n",
    "RHS_step += -(ell_step**2) * delta_param(k_param)\n",
    "\n",
    "step_inner, step_func = simplify_residual(\n",
    "    V_k1_step - V_k_step - RHS_step, ctx_step, resolve_symbols\n",
    ")\n",
    "print(\"Step inner residual zero:\", step_inner == sp.zeros(*step_inner.shape))\n",
    "print(\"Step function residual zero:\", step_func == sp.zeros(*step_func.shape))\n",
    "pf.pprint_labeled_matrix(\n",
    "    np.array(step_inner.tolist(), dtype=object),\n",
    "    ctx_step.basis_vectors_math_exprs(),\n",
    "    ctx_step.basis_vectors_math_exprs(),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c86d6910",
   "metadata": {},
   "source": [
    "## Base Case and Boundary Symbolic Verification"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "037756e4",
   "metadata": {},
   "source": [
    "Verify\n",
    "\n",
    "$$\n",
    "V_1=\\frac{1}{2N}I_f(x_0,x_1)+\\frac{1}{2N}I_f(x_{\\star},x_0)-\\frac{2N+1}{8N^2}\\left\\|\\nabla f(x_0)-\\frac{x_0-x_{\\star}}{2N+1}\\right\\|^2.\n",
    "$$\n",
    "\n",
    "The residual $\\mathrm{LHS}-\\mathrm{RHS}$ should simplify to zero."
   ]
  },
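  {
   "cell_type": "markdown",
   "id": "c52d80e4",
   "metadata": {},
   "source": [
    "A short consistency check, offered as a sketch: the base identity coincides with the $k=0$ instance of the step recursion if the interior formula for $V_k$ is extended to $k=0$ (an extension the theorem does not state). Writing $\\alpha_k,\\beta_k,\\gamma_k,\\tau,\\delta_k$ for the coefficients produced by `alpha_param`, `beta_param`, `gamma_param`, `tau_param` and `delta_param`,\n",
    "\n",
    "$$\n",
    "\\alpha_0=\\beta_0=0,\\qquad \\gamma_0-\\tau=\\frac{2N+1}{2(2N+1)^2}-\\frac{1}{4N+2}=0,\\qquad \\alpha_1=\\frac{1}{2N},\\qquad \\delta_0=\\frac{2N+1}{2(2N)^2}=\\frac{2N+1}{8N^2}.\n",
    "$$\n",
    "\n",
    "Under that extension $V_0$ vanishes, and the three multipliers reduce to $\\tfrac{1}{2N}$, $\\tfrac{1}{2N}$ and $\\tfrac{2N+1}{8N^2}$, matching the identity above."
   ]
  },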
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "4d63cb36",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-13T15:46:43.186345Z",
     "iopub.status.busy": "2026-05-13T15:46:43.185783Z",
     "iopub.status.idle": "2026-05-13T15:46:43.412035Z",
     "shell.execute_reply": "2026-05-13T15:46:43.410628Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Base inner residual zero: True\n",
      "Base function residual zero: True\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{c|cccc}\n",
       "         & x_{0} & x_\\star & \\nabla f_{base}(x_{0}) & \\nabla f_{base}(x_{1}) \\\\\n",
       "        \\hline\n",
       "        x_{0} & 0.0 & 0.0 & 0.0 & 0.0 \\\\x_\\star & 0.0 & 0.0 & 0.0 & 0.0 \\\\\\nabla f_{base}(x_{0}) & 0.0 & 0.0 & 0.0 & 0.0 \\\\\\nabla f_{base}(x_{1}) & 0.0 & 0.0 & 0.0 & 0.0 \\\\\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "ctx_base = pf.PEPContext(\"symbolic_base\").set_as_current()\n",
    "f_base = pf.SmoothConvexFunction(is_basis=True, tags=[\"f_{base}\"], L=1)\n",
    "x_0_base = pf.Vector(is_basis=True, tags=[\"x_{0}\"])\n",
    "x_star_base = f_base.set_stationary_point(\"x_star\")\n",
    "g_0_base = f_base.grad(x_0_base)\n",
    "x_1_base = x_0_base - g_0_base\n",
    "x_1_base.add_tag(\"x_{1}\")\n",
    "g_1_base = f_base.grad(x_1_base)\n",
    "\n",
    "V_1_base = (f_base(x_1_base) - f_base(x_star_base)) * alpha_param(1)\n",
    "V_1_base += -((x_0_base - x_star_base) ** 2) * tau_param()\n",
    "V_1_base += (g_1_base**2) * beta_param(1)\n",
    "V_1_base += ((x_1_base - x_star_base) ** 2) * gamma_param(1)\n",
    "ell_base = g_0_base - (x_0_base - x_star_base) / (2 * N_param + 1)\n",
    "RHS_base = f_base.interp_ineq(x_0_base, x_1_base) * (1 / (2 * N_param))\n",
    "RHS_base += f_base.interp_ineq(x_star_base, x_0_base) * (1 / (2 * N_param))\n",
    "RHS_base += -(ell_base**2) * ((2 * N_param + 1) / (8 * N_param**2))\n",
    "\n",
    "base_inner, base_func = simplify_residual(\n",
    "    V_1_base - RHS_base, ctx_base, {\"N\": N_symbol}\n",
    ")\n",
    "print(\"Base inner residual zero:\", base_inner == sp.zeros(*base_inner.shape))\n",
    "print(\"Base function residual zero:\", base_func == sp.zeros(*base_func.shape))\n",
    "pf.pprint_labeled_matrix(\n",
    "    np.array(base_inner.tolist(), dtype=object),\n",
    "    ctx_base.basis_vectors_math_exprs(),\n",
    "    ctx_base.basis_vectors_math_exprs(),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3adcdecc",
   "metadata": {},
   "source": [
    "### Boundary Identity Symbolic Verification"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "596793ad",
   "metadata": {},
   "source": [
    "Verify\n",
    "\n",
    "$$\n",
    "V_N=f(x_N)-f(x_{\\star})-\\frac{1}{4N+2}\\|x_0-x_{\\star}\\|^2.\n",
    "$$\n",
    "\n",
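    "Because $V_N$ is defined by this expression, the residual $\\mathrm{LHS}-\\mathrm{RHS}$ is zero by construction; the cell below serves as a sanity check of the symbolic pipeline."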
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "4fa9a1cb",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-13T15:46:43.414992Z",
     "iopub.status.busy": "2026-05-13T15:46:43.414623Z",
     "iopub.status.idle": "2026-05-13T15:46:43.426719Z",
     "shell.execute_reply": "2026-05-13T15:46:43.425604Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Boundary inner residual zero: True\n",
      "Boundary function residual zero: True\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{c|cccc}\n",
       "         & x_{0} & x_{N} & x_\\star & \\nabla f_{boundary}(x_{N}) \\\\\n",
       "        \\hline\n",
       "        x_{0} & 0.0 & 0.0 & 0.0 & 0.0 \\\\x_{N} & 0.0 & 0.0 & 0.0 & 0.0 \\\\x_\\star & 0.0 & 0.0 & 0.0 & 0.0 \\\\\\nabla f_{boundary}(x_{N}) & 0.0 & 0.0 & 0.0 & 0.0 \\\\\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "ctx_boundary = pf.PEPContext(\"symbolic_boundary\").set_as_current()\n",
    "f_boundary = pf.SmoothConvexFunction(is_basis=True, tags=[\"f_{boundary}\"], L=1)\n",
    "x_0_boundary = pf.Vector(is_basis=True, tags=[\"x_{0}\"])\n",
    "x_N_boundary = pf.Vector(is_basis=True, tags=[\"x_{N}\"])\n",
    "x_star_boundary = f_boundary.set_stationary_point(\"x_star\")\n",
    "V_N_boundary = f_boundary(x_N_boundary) - f_boundary(x_star_boundary)\n",
    "V_N_boundary += -((x_0_boundary - x_star_boundary) ** 2) * tau_param()\n",
    "RHS_boundary = f_boundary(x_N_boundary) - f_boundary(x_star_boundary)\n",
    "RHS_boundary += -((x_0_boundary - x_star_boundary) ** 2) * tau_param()\n",
    "boundary_inner, boundary_func = simplify_residual(\n",
    "    V_N_boundary - RHS_boundary, ctx_boundary, {\"N\": N_symbol}\n",
    ")\n",
    "print(\n",
    "    \"Boundary inner residual zero:\", boundary_inner == sp.zeros(*boundary_inner.shape)\n",
    ")\n",
    "print(\n",
    "    \"Boundary function residual zero:\", boundary_func == sp.zeros(*boundary_func.shape)\n",
    ")\n",
    "pf.pprint_labeled_matrix(\n",
    "    np.array(boundary_inner.tolist(), dtype=object),\n",
    "    ctx_boundary.basis_vectors_math_exprs(),\n",
    "    ctx_boundary.basis_vectors_math_exprs(),\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
