{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3e23641c",
   "metadata": {},
   "source": [
    "# Gradient Descent Method Example"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8baf79c6",
   "metadata": {},
   "source": [
    "## Import the required libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1988d6fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pepflow as pf\n",
    "import numpy as np\n",
    "import sympy as sp\n",
    "import matplotlib.pyplot as plt\n",
    "import itertools\n",
    "from functools import lru_cache\n",
    "from IPython.display import display, Math\n",
    "\n",
    "from pepflow.lyapunov_utils import (\n",
    "    decompose_rankr_symmetric,\n",
    "    vectors_in_column_space,\n",
    "    find_sparsest_decompositions,\n",
    "    ldl_deompose_with_reversed_basis,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4737f98-b1d1-4866-a39e-04539b329da4",
   "metadata": {},
   "source": [
    "## Define the functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a8c5be8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "L = pf.Parameter(\"L\")\n",
    "f = pf.SmoothConvexFunction(is_basis=True, tags=[\"f\"], L=L)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d7b57bb",
   "metadata": {},
   "source": [
    "## Write a function to return the PEPContext associated with GD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ad6701aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_ctx_gd(\n",
    "    ctx_name: str, N: int | sp.Integer, stepsize: pf.Parameter\n",
    ") -> pf.PEPContext:\n",
    "    ctx_gd = pf.PEPContext(ctx_name).set_as_current()\n",
    "    x = pf.Vector(is_basis=True, tags=[\"x_0\"])\n",
    "    f.set_stationary_point(\"x_star\")\n",
    "    for i in range(N):\n",
    "        x = x - 1 / L * f.grad(x)\n",
    "        x.add_tag(f\"x_{i + 1}\")\n",
    "    return ctx_gd"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f1f0159",
   "metadata": {},
   "source": [
    "## Numerical evidence of convergence of GD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "425348ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 8\n",
    "R = pf.Parameter(\"R\")\n",
    "L_value = 1\n",
    "R_value = 1\n",
    "\n",
    "ctx_plt = make_ctx_gd(ctx_name=\"ctx_plt\", N=N, stepsize=1 / L)\n",
    "pb_plt = pf.PEPBuilder(ctx_plt)\n",
    "pb_plt.add_initial_constraint(\n",
    "    ((ctx_plt[\"x_0\"] - ctx_plt[\"x_star\"]) ** 2).le(R, name=\"initial_condition\")\n",
    ")\n",
    "\n",
    "opt_values = []\n",
    "for k in range(1, N):\n",
    "    x_k = ctx_plt[f\"x_{k}\"]\n",
    "    pb_plt.set_performance_metric(f(x_k) - f(ctx_plt[\"x_star\"]))\n",
    "    result = pb_plt.solve(resolve_parameters={\"L\": L_value, \"R\": R_value})\n",
    "    opt_values.append(result.opt_value)\n",
    "\n",
    "iters = np.arange(1, N)\n",
    "cont_iters = np.arange(1, N, 0.01)\n",
    "plt.plot(\n",
    "    cont_iters,\n",
    "    L_value / (4 * cont_iters + 2),\n",
    "    \"r-\",\n",
    "    label=\"Analytical bound $\\\\frac{L}{4k + 2}$\",\n",
    ")\n",
    "plt.scatter(iters, opt_values, color=\"blue\", marker=\"o\", label=\"Numerical values\")\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99362ee7",
   "metadata": {},
   "source": [
    "## Verification of convergence of GM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d0e252cb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.04545413917119227\n"
     ]
    }
   ],
   "source": [
    "N = sp.S(5)\n",
    "R = pf.Parameter(\"R\")\n",
    "L_value = sp.S(1)\n",
    "R_value = sp.S(1)\n",
    "\n",
    "ctx_prf = make_ctx_gd(ctx_name=\"ctx_prf\", N=N, stepsize=1 / L)\n",
    "pb_prf = pf.PEPBuilder(ctx_prf)\n",
    "pb_prf.add_initial_constraint(\n",
    "    ((ctx_prf[\"x_0\"] - ctx_prf[\"x_star\"]) ** 2).le(R, name=\"initial_condition\")\n",
    ")\n",
    "pb_prf.set_performance_metric(f(ctx_prf[f\"x_{N}\"]) - f(ctx_prf[\"x_star\"]))\n",
    "\n",
    "result = pb_prf.solve(resolve_parameters={\"L\": L_value, \"R\": R_value})\n",
    "print(result.opt_value)\n",
    "\n",
    "# Dual variables associated with the interpolations conditions of f with no relaxation\n",
    "lamb_dense = result.get_scalar_constraint_dual_value_in_numpy(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "45f2b9ef",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dash app running on http://127.0.0.1:8050/\n"
     ]
    }
   ],
   "source": [
    "pf.launch_primal_interactive(\n",
    "    pb_prf, ctx_prf, resolve_parameters={\"L\": L_value, \"R\": R_value}\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5530c49c",
   "metadata": {},
   "source": [
    "### Solve the problem again with the found relaxation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "24cf9cd5-0c2a-4d93-9d81-08071fe6fd4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def tag_to_index(tag, N=N):\n",
    "    \"\"\"This is a function that takes in a tag of an iterate and returns its index.\n",
    "    We index \"x_star\" as \"N+1 where N is the last iterate.\n",
    "    \"\"\"\n",
    "    # Split the string on \"_\" and get the index\n",
    "    if (idx := tag.split(\"_\")[1]).isdigit():\n",
    "        return int(idx)\n",
    "    elif idx == \"star\":\n",
    "        return N + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "1a22704f-f133-4903-b04b-939ec64ef680",
   "metadata": {},
   "outputs": [],
   "source": [
    "relaxed_constraints = []\n",
    "\n",
    "for tag_i in lamb_dense.row_names:\n",
    "    i = tag_to_index(tag_i)\n",
    "    if i == N + 1:\n",
    "        continue\n",
    "    for tag_j in lamb_dense.col_names:\n",
    "        j = tag_to_index(tag_j)\n",
    "        if i < N and i + 1 == j:\n",
    "            continue\n",
    "        relaxed_constraints.append(f\"f:{tag_i},{tag_j}\")\n",
    "\n",
    "pb_prf.set_relaxed_constraints(relaxed_constraints)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7b841a1",
   "metadata": {},
   "source": [
    "- Solve the PEP problem again with the relaxed constraints and store the results.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "412b66dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "result = pb_prf.solve(resolve_parameters={\"L\": L_value, \"R\": R_value})\n",
    "\n",
    "# Dual variable associated with the initial condition\n",
    "tau_sol = result.dual_var_manager.dual_value(\"initial_condition\")\n",
    "# Dual variable associated with the interpolations conditions of f\n",
    "lamb_sol = result.get_scalar_constraint_dual_value_in_numpy(f)\n",
    "# Dual variable associated with the Gram matrix G\n",
    "S_sol = result.get_gram_dual_matrix()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d1baf45",
   "metadata": {},
   "source": [
    "### Verify closed form expression of $\\lambda$"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6907f42",
   "metadata": {},
   "source": [
    "- Print the values of $\\lambda$ obtained from the solver"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "7326d189",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{c|ccccccc}\n",
       "         & x_0 & x_1 & x_2 & x_3 & x_4 & x_5 & x_\\star \\\\\n",
       "        \\hline\n",
       "        x_0 & 0.0 & 0.1 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 \\\\x_1 & 0.0 & 0.0 & 0.222 & 0.0 & 0.0 & 0.0 & 0.0 \\\\x_2 & 0.0 & 0.0 & 0.0 & 0.375 & 0.0 & 0.0 & 0.0 \\\\x_3 & 0.0 & 0.0 & 0.0 & 0.0 & 0.571 & 0.0 & 0.0 \\\\x_4 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.833 & 0.0 \\\\x_5 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 \\\\x_\\star & 0.1 & 0.122 & 0.153 & 0.196 & 0.262 & 0.167 & 0.0 \\\\\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "lamb_sol.pprint()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "3075a90c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{c|ccccccc}\n",
       "         & x_0 & x_1 & x_2 & x_3 & x_4 & x_5 & x_\\star \\\\\n",
       "        \\hline\n",
       "        x_0 & 0.0 & 0.1 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 \\\\x_1 & 0.0 & 0.0 & 0.222 & 0.0 & 0.0 & 0.0 & 0.0 \\\\x_2 & 0.0 & 0.0 & 0.0 & 0.375 & 0.0 & 0.0 & 0.0 \\\\x_3 & 0.0 & 0.0 & 0.0 & 0.0 & 0.571 & 0.0 & 0.0 \\\\x_4 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.833 & 0.0 \\\\x_5 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 \\\\x_\\star & 0.1 & 0.122 & 0.153 & 0.196 & 0.262 & 0.167 & 0.0 \\\\\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def lamb(tag_i, tag_j, N=N):\n",
    "    i = tag_to_index(tag_i)\n",
    "    j = tag_to_index(tag_j)\n",
    "    if i == N + 1:  # Additional constraint 1 (between x_★)\n",
    "        if j == 0:\n",
    "            return lamb(\"x_0\", \"x_1\")\n",
    "        elif j < N:\n",
    "            return lamb(f\"x_{j}\", f\"x_{j + 1}\") - lamb(f\"x_{j - 1}\", f\"x_{j}\")\n",
    "        elif j == N:\n",
    "            return 1 - lamb(f\"x_{N - 1}\", f\"x_{N}\")\n",
    "    if i < N and i + 1 == j:  # Additional constraint 2 (consecutive)\n",
    "        return j / (2 * N + 1 - j)\n",
    "    return 0\n",
    "\n",
    "\n",
    "lamb_cand = pf.pprint_labeled_matrix(\n",
    "    lamb, lamb_sol.row_names, lamb_sol.col_names, return_matrix=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74a2e3e0",
   "metadata": {},
   "source": [
    "- Check whether our candidate of $\\lambda$ matches with solution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "74d4368a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Did we guess the right closed form of lambda? True\n"
     ]
    }
   ],
   "source": [
    "print(\n",
    "    \"Did we guess the right closed form of lambda?\",\n",
    "    np.allclose(lamb_cand, lamb_sol.matrix, atol=1e-4),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db25c45d",
   "metadata": {},
   "source": [
    "### Verify closed form expression $S$"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb4c18c1",
   "metadata": {},
   "source": [
    "- Create an ExpressionManager to translate $x_i$, $f(x_i)$, and $\\nabla f(x_i)$ into a basis representation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "05ddf829",
   "metadata": {},
   "outputs": [],
   "source": [
    "pm = pf.ExpressionManager(ctx_prf, resolve_parameters={\"L\": L_value, \"R\": R_value})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "98ba117b",
   "metadata": {},
   "source": [
    "- Print the values of $S$ obtained from the solver"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "d17dee1d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{c|cccccccc}\n",
       "         & x_0 & x_\\star & \\nabla f(x_0) & \\nabla f(x_1) & \\nabla f(x_2) & \\nabla f(x_3) & \\nabla f(x_4) & \\nabla f(x_5) \\\\\n",
       "        \\hline\n",
       "        x_0 & 0.045 & -0.045 & -0.05 & -0.061 & -0.076 & -0.098 & -0.131 & -0.083 \\\\x_\\star & -0.045 & 0.045 & 0.05 & 0.061 & 0.076 & 0.098 & 0.131 & 0.083 \\\\\\nabla f(x_0) & -0.05 & 0.05 & 0.1 & 0.061 & 0.076 & 0.098 & 0.131 & 0.083 \\\\\\nabla f(x_1) & -0.061 & 0.061 & 0.061 & 0.222 & 0.076 & 0.098 & 0.131 & 0.083 \\\\\\nabla f(x_2) & -0.076 & 0.076 & 0.076 & 0.076 & 0.375 & 0.098 & 0.131 & 0.083 \\\\\\nabla f(x_3) & -0.098 & 0.098 & 0.098 & 0.098 & 0.098 & 0.571 & 0.131 & 0.083 \\\\\\nabla f(x_4) & -0.131 & 0.131 & 0.131 & 0.131 & 0.131 & 0.131 & 0.833 & 0.083 \\\\\\nabla f(x_5) & -0.083 & 0.083 & 0.083 & 0.083 & 0.083 & 0.083 & 0.083 & 0.5 \\\\\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "S_sol.pprint()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f43a1265",
   "metadata": {},
   "source": [
    "- Subtract the decomposed closed-form expressions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "c6fad11c",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = ctx_prf.tracked_point(f)\n",
    "x_0 = ctx_prf[\"x_0\"]\n",
    "x_star = ctx_prf[\"x_star\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "096d0ee1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{c|cccccccc}\n",
       "         & \\nabla f(x_5) & \\nabla f(x_4) & \\nabla f(x_3) & \\nabla f(x_2) & \\nabla f(x_1) & \\nabla f(x_0) & x_0 & x_\\star \\\\\n",
       "        \\hline\n",
       "        \\ell_{1} & 1.0 & 0.167 & 0.167 & 0.167 & 0.167 & 0.167 & -0.167 & 0.167 \\\\\\ell_{2} & 0.0 & 1.0 & 0.143 & 0.143 & 0.143 & 0.143 & -0.143 & 0.143 \\\\\\ell_{3} & 0.0 & 0.0 & 1.0 & 0.125 & 0.125 & 0.125 & -0.125 & 0.125 \\\\\\ell_{4} & 0.0 & 0.0 & 0.0 & 1.0 & 0.111 & 0.111 & -0.111 & 0.111 \\\\\\ell_{5} & 0.0 & 0.0 & 0.0 & 0.0 & 1.0 & 0.1 & -0.1 & 0.1 \\\\\\ell_{6} & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 1.0 & -0.091 & 0.091 \\\\\\ell_{7} & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 1.0 & -0.664 \\\\\\ell_{8} & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 1.0 \\\\\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\left[\\begin{matrix}0.5 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0\\\\0.0 & 0.819 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0\\\\0.0 & 0.0 & 0.541 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0\\\\0.0 & 0.0 & 0.0 & 0.336 & 0.0 & 0.0 & 0.0 & 0.0\\\\0.0 & 0.0 & 0.0 & 0.0 & 0.179 & 0.0 & 0.0 & 0.0\\\\0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.055 & 0.0 & 0.0\\\\0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0\\\\0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0\\end{matrix}\\right]$"
      ],
      "text/plain": [
       "Matrix([\n",
       "[0.5,   0.0,   0.0,   0.0,   0.0,   0.0, 0.0, 0.0],\n",
       "[0.0, 0.819,   0.0,   0.0,   0.0,   0.0, 0.0, 0.0],\n",
       "[0.0,   0.0, 0.541,   0.0,   0.0,   0.0, 0.0, 0.0],\n",
       "[0.0,   0.0,   0.0, 0.336,   0.0,   0.0, 0.0, 0.0],\n",
       "[0.0,   0.0,   0.0,   0.0, 0.179,   0.0, 0.0, 0.0],\n",
       "[0.0,   0.0,   0.0,   0.0,   0.0, 0.055, 0.0, 0.0],\n",
       "[0.0,   0.0,   0.0,   0.0,   0.0,   0.0, 0.0, 0.0],\n",
       "[0.0,   0.0,   0.0,   0.0,   0.0,   0.0, 0.0, 0.0]])"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "LT, d, ell = ldl_deompose_with_reversed_basis(\n",
    "    S_sol, ctx_prf.basis_vectors(), print_output=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c24db970",
   "metadata": {},
   "source": [
    "- Guess a closed form for $\\ell$ based on numerical observations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "1ce14cae",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{cccccccc}\n",
       "        x_0 & x_\\star & \\nabla f(x_0) & \\nabla f(x_1) & \\nabla f(x_2) & \\nabla f(x_3) & \\nabla f(x_4) & \\nabla f(x_5) \\\\\n",
       "        \\hline\n",
       "        -0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{cccccccc}\n",
       "        x_0 & x_\\star & \\nabla f(x_0) & \\nabla f(x_1) & \\nabla f(x_2) & \\nabla f(x_3) & \\nabla f(x_4) & \\nabla f(x_5) \\\\\n",
       "        \\hline\n",
       "        -0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{cccccccc}\n",
       "        x_0 & x_\\star & \\nabla f(x_0) & \\nabla f(x_1) & \\nabla f(x_2) & \\nabla f(x_3) & \\nabla f(x_4) & \\nabla f(x_5) \\\\\n",
       "        \\hline\n",
       "        -0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{cccccccc}\n",
       "        x_0 & x_\\star & \\nabla f(x_0) & \\nabla f(x_1) & \\nabla f(x_2) & \\nabla f(x_3) & \\nabla f(x_4) & \\nabla f(x_5) \\\\\n",
       "        \\hline\n",
       "        -0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{cccccccc}\n",
       "        x_0 & x_\\star & \\nabla f(x_0) & \\nabla f(x_1) & \\nabla f(x_2) & \\nabla f(x_3) & \\nabla f(x_4) & \\nabla f(x_5) \\\\\n",
       "        \\hline\n",
       "        -0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{cccccccc}\n",
       "        x_0 & x_\\star & \\nabla f(x_0) & \\nabla f(x_1) & \\nabla f(x_2) & \\nabla f(x_3) & \\nabla f(x_4) & \\nabla f(x_5) \\\\\n",
       "        \\hline\n",
       "        0.0 & -0.0 & -0.0 & -0.0 & -0.0 & -0.0 & -0.0 & 0.0\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "z = []\n",
    "for i in range(N + 1):\n",
    "    z_i_guess = (x[i] - x_star) / sp.S(2 * N + 1 - i) - 1 / L * f.grad(x[i])\n",
    "    remainder_1 = pm.eval_vector(z_i_guess + ell[N - i]).coords\n",
    "    pf.pprint_labeled_vector(remainder_1, S_sol.row_names)\n",
    "    z.append(z_i_guess)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a30a12ca",
   "metadata": {},
   "source": [
    "- Double check the calculation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "7c4c0e27",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{c|cccccccc}\n",
       "         & x_0 & x_\\star & \\nabla f(x_0) & \\nabla f(x_1) & \\nabla f(x_2) & \\nabla f(x_3) & \\nabla f(x_4) & \\nabla f(x_5) \\\\\n",
       "        \\hline\n",
       "        x_0 & -0.0 & 0.0 & 0.0 & -0.0 & -0.0 & 0.0 & 0.0 & 0.0 \\\\x_\\star & 0.0 & -0.0 & -0.0 & 0.0 & 0.0 & -0.0 & -0.0 & -0.0 \\\\\\nabla f(x_0) & 0.0 & -0.0 & -0.0 & -0.0 & 0.0 & -0.0 & -0.0 & -0.0 \\\\\\nabla f(x_1) & -0.0 & 0.0 & -0.0 & -0.0 & 0.0 & -0.0 & -0.0 & -0.0 \\\\\\nabla f(x_2) & -0.0 & 0.0 & 0.0 & 0.0 & -0.0 & -0.0 & -0.0 & -0.0 \\\\\\nabla f(x_3) & 0.0 & -0.0 & -0.0 & -0.0 & -0.0 & -0.0 & -0.0 & -0.0 \\\\\\nabla f(x_4) & 0.0 & -0.0 & -0.0 & -0.0 & -0.0 & -0.0 & -0.0 & -0.0 \\\\\\nabla f(x_5) & 0.0 & -0.0 & -0.0 & -0.0 & -0.0 & -0.0 & -0.0 & 0.0 \\\\\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "S_guess = pf.Scalar.zero()\n",
    "\n",
    "# Even though we don't find this closed form, we may find such decomposition numerically\n",
    "\n",
    "delta = []\n",
    "for i in range(N + 1):\n",
    "    coeff_i = d[N - i, N - i]\n",
    "    delta.append(coeff_i)\n",
    "    z[i].add_tag(f\"z_{i}\")\n",
    "    S_guess += L * coeff_i * z[i] ** 2\n",
    "\n",
    "remainder_1 = S_sol.matrix - pm.eval_scalar(S_guess).inner_prod_coords\n",
    "pf.pprint_labeled_matrix(remainder_1, S_sol.row_names, S_sol.col_names)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e35888e9",
   "metadata": {},
   "source": [
    "- Check whether our candidate of $S$ matches with solution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "60bf29ef",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Did we guess the right closed form of ell? True\n"
     ]
    }
   ],
   "source": [
    "print(\n",
    "    \"Did we guess the right closed form of ell?\",\n",
    "    np.allclose(pm.eval_scalar(S_guess).inner_prod_coords, S_sol.matrix, atol=1e-4),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54824c93",
   "metadata": {},
   "source": [
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "dd581a69",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[x_0,\n",
       " x_star,\n",
       " grad_f(x_0),\n",
       " grad_f(x_1),\n",
       " grad_f(x_2),\n",
       " grad_f(x_3),\n",
       " grad_f(x_4),\n",
       " grad_f(x_5)]"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ctx_prf.set_as_current()\n",
    "ctx_prf.basis_vectors()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "ccbc5542",
   "metadata": {},
   "outputs": [],
   "source": [
    "lyap_basis_candidate = ctx_prf.basis_vectors()\n",
    "lyap_basis_candidate += x\n",
    "\n",
    "# Add any two difference between special vectors as new special vectors\n",
    "for i, j in itertools.combinations(range(len(lyap_basis_candidate)), 2):\n",
    "    diff = lyap_basis_candidate[i] - lyap_basis_candidate[j]\n",
    "    lyap_basis_candidate.append(diff)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "787fb5b8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[x_0,\n",
       " x_star,\n",
       " grad_f(x_0),\n",
       " grad_f(x_1),\n",
       " grad_f(x_2),\n",
       " grad_f(x_3),\n",
       " grad_f(x_4),\n",
       " grad_f(x_5),\n",
       " x_0,\n",
       " x_1,\n",
       " x_2,\n",
       " x_3,\n",
       " x_4,\n",
       " x_5,\n",
       " x_star,\n",
       " x_0-x_star,\n",
       " x_0-grad_f(x_0),\n",
       " x_0-grad_f(x_1),\n",
       " x_0-grad_f(x_2),\n",
       " x_0-grad_f(x_3),\n",
       " x_0-grad_f(x_4),\n",
       " x_0-grad_f(x_5),\n",
       " x_0-x_0,\n",
       " x_0-(x_1),\n",
       " x_0-(x_2),\n",
       " x_0-(x_3),\n",
       " x_0-(x_4),\n",
       " x_0-(x_5),\n",
       " x_0-x_star,\n",
       " x_star-grad_f(x_0),\n",
       " x_star-grad_f(x_1),\n",
       " x_star-grad_f(x_2),\n",
       " x_star-grad_f(x_3),\n",
       " x_star-grad_f(x_4),\n",
       " x_star-grad_f(x_5),\n",
       " x_star-x_0,\n",
       " x_star-(x_1),\n",
       " x_star-(x_2),\n",
       " x_star-(x_3),\n",
       " x_star-(x_4),\n",
       " x_star-(x_5),\n",
       " x_star-x_star,\n",
       " grad_f(x_0)-grad_f(x_1),\n",
       " grad_f(x_0)-grad_f(x_2),\n",
       " grad_f(x_0)-grad_f(x_3),\n",
       " grad_f(x_0)-grad_f(x_4),\n",
       " grad_f(x_0)-grad_f(x_5),\n",
       " grad_f(x_0)-x_0,\n",
       " grad_f(x_0)-(x_1),\n",
       " grad_f(x_0)-(x_2),\n",
       " grad_f(x_0)-(x_3),\n",
       " grad_f(x_0)-(x_4),\n",
       " grad_f(x_0)-(x_5),\n",
       " grad_f(x_0)-x_star,\n",
       " grad_f(x_1)-grad_f(x_2),\n",
       " grad_f(x_1)-grad_f(x_3),\n",
       " grad_f(x_1)-grad_f(x_4),\n",
       " grad_f(x_1)-grad_f(x_5),\n",
       " grad_f(x_1)-x_0,\n",
       " grad_f(x_1)-(x_1),\n",
       " grad_f(x_1)-(x_2),\n",
       " grad_f(x_1)-(x_3),\n",
       " grad_f(x_1)-(x_4),\n",
       " grad_f(x_1)-(x_5),\n",
       " grad_f(x_1)-x_star,\n",
       " grad_f(x_2)-grad_f(x_3),\n",
       " grad_f(x_2)-grad_f(x_4),\n",
       " grad_f(x_2)-grad_f(x_5),\n",
       " grad_f(x_2)-x_0,\n",
       " grad_f(x_2)-(x_1),\n",
       " grad_f(x_2)-(x_2),\n",
       " grad_f(x_2)-(x_3),\n",
       " grad_f(x_2)-(x_4),\n",
       " grad_f(x_2)-(x_5),\n",
       " grad_f(x_2)-x_star,\n",
       " grad_f(x_3)-grad_f(x_4),\n",
       " grad_f(x_3)-grad_f(x_5),\n",
       " grad_f(x_3)-x_0,\n",
       " grad_f(x_3)-(x_1),\n",
       " grad_f(x_3)-(x_2),\n",
       " grad_f(x_3)-(x_3),\n",
       " grad_f(x_3)-(x_4),\n",
       " grad_f(x_3)-(x_5),\n",
       " grad_f(x_3)-x_star,\n",
       " grad_f(x_4)-grad_f(x_5),\n",
       " grad_f(x_4)-x_0,\n",
       " grad_f(x_4)-(x_1),\n",
       " grad_f(x_4)-(x_2),\n",
       " grad_f(x_4)-(x_3),\n",
       " grad_f(x_4)-(x_4),\n",
       " grad_f(x_4)-(x_5),\n",
       " grad_f(x_4)-x_star,\n",
       " grad_f(x_5)-x_0,\n",
       " grad_f(x_5)-(x_1),\n",
       " grad_f(x_5)-(x_2),\n",
       " grad_f(x_5)-(x_3),\n",
       " grad_f(x_5)-(x_4),\n",
       " grad_f(x_5)-(x_5),\n",
       " grad_f(x_5)-x_star,\n",
       " x_0-(x_1),\n",
       " x_0-(x_2),\n",
       " x_0-(x_3),\n",
       " x_0-(x_4),\n",
       " x_0-(x_5),\n",
       " x_0-x_star,\n",
       " x_1-(x_2),\n",
       " x_1-(x_3),\n",
       " x_1-(x_4),\n",
       " x_1-(x_5),\n",
       " x_1-x_star,\n",
       " x_2-(x_3),\n",
       " x_2-(x_4),\n",
       " x_2-(x_5),\n",
       " x_2-x_star,\n",
       " x_3-(x_4),\n",
       " x_3-(x_5),\n",
       " x_3-x_star,\n",
       " x_4-(x_5),\n",
       " x_4-x_star,\n",
       " x_5-x_star]"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lyap_basis_candidate"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e87af7ee",
   "metadata": {},
   "source": [
    "## Identify the vectors composing the Lyapunov function"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eab04735",
   "metadata": {},
   "source": [
    "Compute the sum of active inequalities up to k-th iteration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "0cf2d0c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "lyap = [pf.Scalar.zero()]\n",
    "partial_sum = 0\n",
    "partial_sum -= delta[0] * ell[N] ** 2\n",
    "partial_sum += lamb(\"x_star\", f\"x_{0}\") * f.interp_ineq(\"x_star\", f\"x_{0}\")\n",
    "for j in np.arange(N):\n",
    "    partial_sum += lamb(f\"x_{j}\", f\"x_{j + 1}\") * f.interp_ineq(f\"x_{j}\", f\"x_{j + 1}\")\n",
    "    partial_sum += lamb(\"x_star\", f\"x_{j + 1}\") * f.interp_ineq(\"x_star\", f\"x_{j + 1}\")\n",
    "    partial_sum -= delta[j + 1] * ell[N - j - 1] ** 2\n",
    "    lyap.append(partial_sum)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "7deb0d63",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Rank of lyap[0]: 0\n",
      "Rank of lyap[1]: 3\n",
      "Rank of lyap[2]: 3\n",
      "Rank of lyap[3]: 3\n",
      "Rank of lyap[4]: 3\n",
      "Rank of lyap[5]: 1\n"
     ]
    }
   ],
   "source": [
    "for k in range(len(lyap)):\n",
    "    lyap_numeric_k = pm.eval_scalar(lyap[k]).inner_prod_coords.astype(float)\n",
    "    print(f\"Rank of lyap[{k}]: {np.linalg.matrix_rank(lyap_numeric_k, tol=1e-4)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "665a55b4",
   "metadata": {},
   "source": [
    "#### Extract the special vectors representing lyap[k] = V_{k+1}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "c6e5f600",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "V_1: [grad_f(x_0), grad_f(x_1), x_0-x_star, x_0-x_0, x_0-(x_1), x_0-(x_2), x_0-x_star, x_star-x_0, x_star-(x_1), x_star-(x_2), x_star-x_star, grad_f(x_0)-grad_f(x_1), x_0-(x_1), x_0-(x_2), x_0-x_star, x_1-(x_2), x_1-x_star, x_2-x_star]\n",
      "V_2: [grad_f(x_2), x_0-x_star, x_0-x_0, x_0-(x_2), x_0-(x_3), x_0-x_star, x_star-x_0, x_star-(x_2), x_star-(x_3), x_star-x_star, x_0-(x_2), x_0-(x_3), x_0-x_star, x_2-(x_3), x_2-x_star, x_3-x_star]\n",
      "V_3: [grad_f(x_3), x_0-x_star, x_0-x_0, x_0-(x_3), x_0-(x_4), x_0-x_star, x_star-x_0, x_star-(x_3), x_star-(x_4), x_star-x_star, x_0-(x_3), x_0-(x_4), x_0-x_star, x_3-(x_4), x_3-x_star, x_4-x_star]\n",
      "V_4: [grad_f(x_4), x_0-x_star, x_0-x_0, x_0-(x_4), x_0-(x_5), x_0-x_star, x_star-x_0, x_star-(x_4), x_star-(x_5), x_star-x_star, x_0-(x_4), x_0-(x_5), x_0-x_star, x_4-(x_5), x_4-x_star, x_5-x_star]\n",
      "V_5: [x_0-x_star, x_0-x_0, x_0-x_star, x_star-x_0, x_star-x_star, x_0-x_star]\n"
     ]
    }
   ],
   "source": [
    "for k in range(1, len(lyap)):\n",
    "    print(\n",
    "        f\"V_{k}:\",\n",
    "        vectors_in_column_space(\n",
    "            lyap[k],\n",
    "            lyap_basis_candidate,\n",
    "            ctx_prf,\n",
    "            resolve_parameters=pm.resolve_parameters,\n",
    "            rtol=1e-4,\n",
    "            atol=1e-4,\n",
    "        ),\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a508576",
   "metadata": {},
   "source": [
    "#### Collecting good linearly independent vectors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "4be4aac6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "lyap[1]:\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{c|ccc}\n",
       "         & \\nabla f(x_1) & x_0-x_\\star & x_\\star-(x_2) \\\\\n",
       "        \\hline\n",
       "        \\nabla f(x_1) & -0.111 & 0.0 & -0.0 \\\\x_0-x_\\star & 0.0 & -0.045 & -0.0 \\\\x_\\star-(x_2) & -0.0 & -0.0 & 0.043 \\\\\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "lyap[2]:\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{c|ccc}\n",
       "         & \\nabla f(x_2) & x_0-x_\\star & x_\\star-(x_3) \\\\\n",
       "        \\hline\n",
       "        \\nabla f(x_2) & -0.188 & 0.0 & -0.0 \\\\x_0-x_\\star & 0.0 & -0.045 & -0.0 \\\\x_\\star-(x_3) & -0.0 & -0.0 & 0.039 \\\\\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "lyap[3]:\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{c|ccc}\n",
       "         & \\nabla f(x_3) & x_0-x_\\star & x_\\star-(x_4) \\\\\n",
       "        \\hline\n",
       "        \\nabla f(x_3) & -0.286 & 0.0 & 0.0 \\\\x_0-x_\\star & 0.0 & -0.045 & -0.0 \\\\x_\\star-(x_4) & 0.0 & -0.0 & 0.031 \\\\\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "lyap[4]:\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{c|ccc}\n",
       "         & \\nabla f(x_4) & x_0-x_\\star & x_\\star-(x_5) \\\\\n",
       "        \\hline\n",
       "        \\nabla f(x_4) & -0.417 & 0.0 & 0.0 \\\\x_0-x_\\star & 0.0 & -0.045 & -0.0 \\\\x_\\star-(x_5) & 0.0 & -0.0 & 0.014 \\\\\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "for k in range(1, len(lyap) - 1):\n",
    "    aligned_special_vectors_k = vectors_in_column_space(\n",
    "        lyap[k],\n",
    "        lyap_basis_candidate,\n",
    "        ctx_prf,\n",
    "        resolve_parameters=pm.resolve_parameters,\n",
    "        rtol=1e-4,\n",
    "        atol=1e-4,\n",
    "    )\n",
    "    best_vectors, best_coefficients = find_sparsest_decompositions(\n",
    "        lyap[k],\n",
    "        aligned_special_vectors_k,\n",
    "        pep_context=ctx_prf,\n",
    "        resolve_parameters=pm.resolve_parameters,\n",
    "    )\n",
    "    labels_k = [str(v) for v in best_vectors]\n",
    "    print(f\"lyap[{k}]:\")\n",
    "    pf.pprint_labeled_matrix(best_coefficients, labels_k, labels_k)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "24cf9bf8",
   "metadata": {},
   "source": [
    "#### Guess the generic basis vectors based on above observation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "8e530cd2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # TODO: this doesn't work\n",
    "\n",
    "# basis_by_k_auto, templates_auto = infer_k_dependent_basis_templates(\n",
    "#     lyap,\n",
    "#     lyap_basis_candidate,\n",
    "#     pep_context=ctx_prf,\n",
    "#     resolve_parameters=pm.resolve_parameters,\n",
    "#     k_start=2,\n",
    "# )\n",
    "# for k in range(2, len(lyap)-1):\n",
    "#     print(f\"V_{k}:\", basis_by_k_auto[k])\n",
    "# print(\"Inferred templates:\", templates_auto)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "7e17a48e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: This part could be automated later.\n",
    "@lru_cache(None)\n",
    "def V_k_basis(k: int, N=N):\n",
    "    if k == N:\n",
    "        return [x[0] - x_star]\n",
    "    return [x[0] - x_star, f.grad(x[k]), x[k + 1] - x_star]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "4a56e919",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "V_1: [x_0-x_star, grad_f(x_1), x_2-x_star]\n",
      "V_2: [x_0-x_star, grad_f(x_2), x_3-x_star]\n",
      "V_3: [x_0-x_star, grad_f(x_3), x_4-x_star]\n",
      "V_4: [x_0-x_star, grad_f(x_4), x_5-x_star]\n",
      "V_5: [x_0-x_star]\n"
     ]
    }
   ],
   "source": [
    "for k in range(1, len(lyap)):\n",
    "    print(f\"V_{k}:\", V_k_basis(k))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5f7bbe3",
   "metadata": {},
   "source": [
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "5a8ab9ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "@lru_cache(None)\n",
    "def V_k_basis_labels(k: int, N=N):\n",
    "    return [str(x) for x in V_k_basis(k)]\n",
    "\n",
    "\n",
    "def print_numeric_coefficients(k):\n",
    "    lyap_k_coeff = decompose_rankr_symmetric(\n",
    "        lyap[k],\n",
    "        V_k_basis(k),\n",
    "        pep_context=ctx_prf,\n",
    "        resolve_parameters=pm.resolve_parameters,\n",
    "    )\n",
    "    pf.pprint_labeled_matrix(lyap_k_coeff, V_k_basis_labels(k), V_k_basis_labels(k))\n",
    "    pf.pprint_labeled_vector(\n",
    "        pm.eval_scalar(lyap[k]).func_coords.astype(float),\n",
    "        [str(x) for x in ctx_prf.basis_scalars()],\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "96bd1431",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "V_1:\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{c|ccc}\n",
       "         & x_0-x_\\star & \\nabla f(x_1) & x_2-x_\\star \\\\\n",
       "        \\hline\n",
       "        x_0-x_\\star & -0.045 & 0.0 & 0.0 \\\\\\nabla f(x_1) & 0.0 & -0.111 & 0.0 \\\\x_2-x_\\star & 0.0 & 0.0 & 0.043 \\\\\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{ccccccc}\n",
       "        f(x_\\star) & f(x_0) & f(x_1) & f(x_2) & f(x_3) & f(x_4) & f(x_5) \\\\\n",
       "        \\hline\n",
       "        -0.222 & 0.0 & 0.222 & 0.0 & 0.0 & 0.0 & 0.0\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "V_2:\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{c|ccc}\n",
       "         & x_0-x_\\star & \\nabla f(x_2) & x_3-x_\\star \\\\\n",
       "        \\hline\n",
       "        x_0-x_\\star & -0.045 & 0.0 & 0.0 \\\\\\nabla f(x_2) & 0.0 & -0.188 & 0.0 \\\\x_3-x_\\star & 0.0 & 0.0 & 0.039 \\\\\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{ccccccc}\n",
       "        f(x_\\star) & f(x_0) & f(x_1) & f(x_2) & f(x_3) & f(x_4) & f(x_5) \\\\\n",
       "        \\hline\n",
       "        -0.375 & 0.0 & 0.0 & 0.375 & 0.0 & 0.0 & 0.0\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "V_3:\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{c|ccc}\n",
       "         & x_0-x_\\star & \\nabla f(x_3) & x_4-x_\\star \\\\\n",
       "        \\hline\n",
       "        x_0-x_\\star & -0.045 & 0.0 & 0.0 \\\\\\nabla f(x_3) & 0.0 & -0.286 & -0.0 \\\\x_4-x_\\star & 0.0 & -0.0 & 0.031 \\\\\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{ccccccc}\n",
       "        f(x_\\star) & f(x_0) & f(x_1) & f(x_2) & f(x_3) & f(x_4) & f(x_5) \\\\\n",
       "        \\hline\n",
       "        -0.571 & 0.0 & 0.0 & 0.0 & 0.571 & 0.0 & 0.0\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "V_4:\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{c|ccc}\n",
       "         & x_0-x_\\star & \\nabla f(x_4) & x_5-x_\\star \\\\\n",
       "        \\hline\n",
       "        x_0-x_\\star & -0.045 & 0.0 & 0.0 \\\\\\nabla f(x_4) & 0.0 & -0.417 & -0.0 \\\\x_5-x_\\star & 0.0 & -0.0 & 0.014 \\\\\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{ccccccc}\n",
       "        f(x_\\star) & f(x_0) & f(x_1) & f(x_2) & f(x_3) & f(x_4) & f(x_5) \\\\\n",
       "        \\hline\n",
       "        -0.833 & 0.0 & 0.0 & 0.0 & 0.0 & 0.833 & 0.0\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "V_5:\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{c|c}\n",
       "         & x_0-x_\\star \\\\\n",
       "        \\hline\n",
       "        x_0-x_\\star & -0.045 \\\\\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{ccccccc}\n",
       "        f(x_\\star) & f(x_0) & f(x_1) & f(x_2) & f(x_3) & f(x_4) & f(x_5) \\\\\n",
       "        \\hline\n",
       "        -1.0 & 0.0 & 0.0 & 0.0 & 0.0 & -0.0 & 1.0\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# TODO: better to have numeric -> abstract for function value\n",
    "for k in range(1, len(lyap)):\n",
    "    print(f\"V_{k}:\")\n",
    "    print_numeric_coefficients(k)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "172c9d03",
   "metadata": {},
   "source": [
    "### Finding the coefficient of Lyapunov function: symbolic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "ee0b73a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "ctx_gd_lyap = pf.PEPContext(\"gd_lyap_finder\").set_as_current()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd8ede53",
   "metadata": {},
   "outputs": [],
   "source": [
    "k = pf.Parameter(\"k\")\n",
    "N = pf.Parameter(\"N\")\n",
    "L = pf.Parameter(\"L\")\n",
    "alpha = 1 / L\n",
    "\n",
    "f.set_stationary_point(\"x_star\")\n",
    "x_star = ctx_gd_lyap[\"x_star\"]\n",
    "x_0 = pf.Vector(is_basis=True, tags=[\"x_0\"])\n",
    "x_k = pf.Vector(is_basis=True, tags=[\"x_k\"])\n",
    "\n",
    "x_k1 = x_k - alpha * f.grad(x_k)\n",
    "x_k1.add_tag(\"x_{k+1}\")\n",
    "\n",
    "x_k2 = x_k1 - alpha * f.grad(x_k1)\n",
    "x_k2.add_tag(\"x_{k+2}\")\n",
    "\n",
    "a_k = pf.Parameter(\"a_k\")\n",
    "a_k1 = pf.Parameter(\"a_{k+1}\")\n",
    "b_k = pf.Parameter(\"b_k\")\n",
    "b_k1 = pf.Parameter(\"b_{k+1}\")\n",
    "c_k = pf.Parameter(\"c_k\")\n",
    "c_k1 = pf.Parameter(\"c_{k+1}\")\n",
    "d_k = pf.Parameter(\"d_k\")\n",
    "d_k1 = pf.Parameter(\"d_{k+1}\")\n",
    "\n",
    "s_k1 = pf.Parameter(\"s_{N-k-1}\")\n",
    "lamb_k_k1 = pf.Parameter(r\"\\lambda_{k,k+1}\")\n",
    "# lamb_k_k1 = (k+1)/(2*N-k)\n",
    "lamb_star_k1 = pf.Parameter(r\"\\lambda_{star,k+1}\")\n",
    "# lamb_star_k1 = (k+2)/(2*N-k-1) - lamb_k_k1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "c8f8000e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle a_{k+1}*(f(x_{k+1})-f(x_\\star))+b_{k+1}*\\|\\nabla f(x_{k+1})\\|^2+c_{k+1}*\\|x_{k+2}-x_\\star\\|^2+d_{k+1}*\\|x_0-x_\\star\\|^2-(a_k*(f(x_k)-f(x_\\star))+b_k*\\|\\nabla f(x_k)\\|^2+c_k*\\|x_{k+1}-x_\\star\\|^2+d_k*\\|x_0-x_\\star\\|^2)$"
      ],
      "text/plain": [
       "a_{k+1}*(f(x_{k+1})-f(x_star))+b_{k+1}*|grad_f(x_{k+1})|^2+c_{k+1}*|x_{k+2}-x_star|^2+d_{k+1}*|x_0-x_star|^2-(a_k*(f(x_k)-f(x_star))+b_k*|grad_f(x_k)|^2+c_k*|x_{k+1}-x_star|^2+d_k*|x_0-x_star|^2)"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "V_k_prev = (\n",
    "    a_k * (f(x_k) - f(x_star))\n",
    "    + b_k * f.grad(x_k) ** 2\n",
    "    + c_k * (x_k1 - x_star) ** 2\n",
    "    + d_k * (x_0 - x_star) ** 2\n",
    ")\n",
    "V_k = (\n",
    "    a_k1 * (f(x_k1) - f(x_star))\n",
    "    + b_k1 * f.grad(x_k1) ** 2\n",
    "    + c_k1 * (x_k2 - x_star) ** 2\n",
    "    + d_k1 * (x_0 - x_star) ** 2\n",
    ")\n",
    "\n",
    "LHS = V_k - V_k_prev\n",
    "LHS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "4ca5b70f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\lambda_{k,k+1}*(f(x_{k+1})-f(x_k)+\\nabla f(x_{k+1})*(x_k-(x_{k+1}))+1/2*L*\\|\\nabla f(x_k)-\\nabla f(x_{k+1})\\|^2)+\\lambda_{\\\\star,k+1}*(f(x_{k+1})-f(x_\\star)+\\nabla f(x_{k+1})*(x_\\star-(x_{k+1}))+1/2*L*\\|\\nabla f(x_\\star)-\\nabla f(x_{k+1})\\|^2)-s_{N-k-1}*\\|1/2*N-k*(x_{k+1}-x_\\star)-\\nabla f(x_{k+1})\\|^2$"
      ],
      "text/plain": [
       "\\lambda_{k,k+1}*(f(x_{k+1})-f(x_k)+grad_f(x_{k+1})*(x_k-(x_{k+1}))+1/2*L*|grad_f(x_k)-grad_f(x_{k+1})|^2)+\\lambda_{\\star,k+1}*(f(x_{k+1})-f(x_star)+grad_f(x_{k+1})*(x_star-(x_{k+1}))+1/2*L*|grad_f(x_star)-grad_f(x_{k+1})|^2)-s_{N-k-1}*|1/2*N-k*(x_{k+1}-x_star)-grad_f(x_{k+1})|^2"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "RHS = lamb_k_k1 * f.interp_ineq(\"x_k\", \"x_{k+1}\")\n",
    "RHS += lamb_star_k1 * f.interp_ineq(\"x_star\", \"x_{k+1}\")\n",
    "RHS -= s_k1 * ((x_k1 - x_star) / (2 * N - k) - f.grad(x_k1)) ** 2\n",
    "\n",
    "RHS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "2347f8c8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle a_{k+1}*(f(x_{k+1})-f(x_\\star))+b_{k+1}*\\|\\nabla f(x_{k+1})\\|^2+c_{k+1}*\\|x_{k+2}-x_\\star\\|^2+d_{k+1}*\\|x_0-x_\\star\\|^2-(a_k*(f(x_k)-f(x_\\star))+b_k*\\|\\nabla f(x_k)\\|^2+c_k*\\|x_{k+1}-x_\\star\\|^2+d_k*\\|x_0-x_\\star\\|^2)-(\\lambda_{k,k+1}*(f(x_{k+1})-f(x_k)+\\nabla f(x_{k+1})*(x_k-(x_{k+1}))+1/2*L*\\|\\nabla f(x_k)-\\nabla f(x_{k+1})\\|^2)+\\lambda_{\\\\star,k+1}*(f(x_{k+1})-f(x_\\star)+\\nabla f(x_{k+1})*(x_\\star-(x_{k+1}))+1/2*L*\\|\\nabla f(x_\\star)-\\nabla f(x_{k+1})\\|^2)-s_{N-k-1}*\\|1/2*N-k*(x_{k+1}-x_\\star)-\\nabla f(x_{k+1})\\|^2)$"
      ],
      "text/plain": [
       "a_{k+1}*(f(x_{k+1})-f(x_star))+b_{k+1}*|grad_f(x_{k+1})|^2+c_{k+1}*|x_{k+2}-x_star|^2+d_{k+1}*|x_0-x_star|^2-(a_k*(f(x_k)-f(x_star))+b_k*|grad_f(x_k)|^2+c_k*|x_{k+1}-x_star|^2+d_k*|x_0-x_star|^2)-(\\lambda_{k,k+1}*(f(x_{k+1})-f(x_k)+grad_f(x_{k+1})*(x_k-(x_{k+1}))+1/2*L*|grad_f(x_k)-grad_f(x_{k+1})|^2)+\\lambda_{\\star,k+1}*(f(x_{k+1})-f(x_star)+grad_f(x_{k+1})*(x_star-(x_{k+1}))+1/2*L*|grad_f(x_star)-grad_f(x_{k+1})|^2)-s_{N-k-1}*|1/2*N-k*(x_{k+1}-x_star)-grad_f(x_{k+1})|^2)"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "diff = LHS - RHS\n",
    "\n",
    "diff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "090f9c5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "basis_vectors = ctx_gd_lyap.basis_vectors()\n",
    "vec_index = [str(v) for v in basis_vectors]\n",
    "\n",
    "basis_scalars = ctx_gd_lyap.basis_scalars()\n",
    "scal_index = [str(v) for v in basis_scalars]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "2d6138b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "a_k_sp, a_k1_sp, b_k_sp, b_k1_sp, c_k_sp, c_k1_sp, d_k_sp, d_k1_sp, N_sp, k_sp, L_sp = (\n",
    "    sp.symbols(\"a_k a_{k+1} b_k b_{k+1} c_k c_{k+1} d_k d_{k+1} N k L\")\n",
    ")\n",
    "s_k1_sp = sp.Symbol(\"s_{N-k-1}\")\n",
    "lamb_k_k1_sp = sp.Symbol(r\"\\lambda_{k,k+1}\")\n",
    "lamb_star_k1_sp = sp.Symbol(r\"\\lambda_{\\star,k+1}\")\n",
    "\n",
    "pm_lyap = pf.ExpressionManager(\n",
    "    ctx_gd_lyap,\n",
    "    resolve_parameters={\n",
    "        \"a_k\": a_k_sp,\n",
    "        \"a_{k+1}\": a_k1_sp,\n",
    "        \"b_k\": b_k_sp,\n",
    "        \"b_{k+1}\": b_k1_sp,\n",
    "        \"c_k\": c_k_sp,\n",
    "        \"c_{k+1}\": c_k1_sp,\n",
    "        \"d_k\": d_k_sp,\n",
    "        \"d_{k+1}\": d_k1_sp,\n",
    "        \"s_{N-k-1}\": s_k1_sp,\n",
    "        r\"\\lambda_{k,k+1}\": lamb_k_k1_sp,\n",
    "        r\"\\lambda_{\\star,k+1}\": lamb_star_k1_sp,\n",
    "        \"N\": N_sp,\n",
    "        \"k\": k_sp,\n",
    "        \"L\": L_sp,\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "6ce028eb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{c|ccccc}\n",
       "         & x_\\star & x_0 & x_k & \\nabla f(x_k) & \\nabla f(x_{k+1}) \\\\\n",
       "        \\hline\n",
       "        x_\\star & -1.0*c_k + 1.0*c_{k+1} - 1.0*d_k + 1.0*d_{k+1} + 1.0*s_{N-k-1}/(2*N - k)^2 & 1.0*d_k - 1.0*d_{k+1} & 1.0*c_k - 1.0*c_{k+1} - 1.0*s_{N-k-1}/(2*N - k)^2 & -1.0*c_k/L + 1.0*c_{k+1}/L + 1.0*s_{N-k-1}/(L*(2*N - k)^2) & -0.5*\\lambda_{\\\\star,k+1} + 1.0*s_{N-k-1}/(2*N - k) + 1.0*c_{k+1}/L \\\\x_0 & 1.0*d_k - 1.0*d_{k+1} & -1.0*d_k + 1.0*d_{k+1} & 0 & 0 & 0 \\\\x_k & 1.0*c_k - 1.0*c_{k+1} - 1.0*s_{N-k-1}/(2*N - k)^2 & 0 & -1.0*c_k + 1.0*c_{k+1} + 1.0*s_{N-k-1}/(2*N - k)^2 & 1.0*c_k/L - 1.0*c_{k+1}/L - 1.0*s_{N-k-1}/(L*(2*N - k)^2) & 0.5*\\lambda_{\\\\star,k+1} - 1.0*s_{N-k-1}/(2*N - k) - 1.0*c_{k+1}/L \\\\\\nabla f(x_k) & -1.0*c_k/L + 1.0*c_{k+1}/L + 1.0*s_{N-k-1}/(L*(2*N - k)^2) & 0 & 1.0*c_k/L - 1.0*c_{k+1}/L - 1.0*s_{N-k-1}/(L*(2*N - k)^2) & -1.0*b_k - 0.5*\\lambda_{k,k+1}/L - 1.0*c_k/L^2 + 1.0*c_{k+1}/L^2 + 1.0*s_{N-k-1}/(L^2*(2*N - k)^2) & -0.5*\\lambda_{\\\\star,k+1}/L + 1.0*s_{N-k-1}/(L*(2*N - k)) + 1.0*c_{k+1}/L^2 \\\\\\nabla f(x_{k+1}) & -0.5*\\lambda_{\\\\star,k+1} + 1.0*s_{N-k-1}/(2*N - k) + 1.0*c_{k+1}/L & 0 & 0.5*\\lambda_{\\\\star,k+1} - 1.0*s_{N-k-1}/(2*N - k) - 1.0*c_{k+1}/L & -0.5*\\lambda_{\\\\star,k+1}/L + 1.0*s_{N-k-1}/(L*(2*N - k)) + 1.0*c_{k+1}/L^2 & 1.0*b_{k+1} + 1.0*s_{N-k-1} - 0.5*\\lambda_{\\\\star,k+1}/L - 0.5*\\lambda_{k,k+1}/L + 1.0*c_{k+1}/L^2 \\\\\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "diff_matrix = pm_lyap.eval_scalar(diff).inner_prod_coords\n",
    "pf.pprint_labeled_matrix(diff_matrix, vec_index, vec_index, precision=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "8506ec45",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\displaystyle \n",
       "        \\begin{array}{ccc}\n",
       "        f(x_\\star) & f(x_k) & f(x_{k+1}) \\\\\n",
       "        \\hline\n",
       "        1.0*\\lambda_{\\\\star,k+1} + 1.0*a_k - 1.0*a_{k+1} & 1.0*\\lambda_{k,k+1} - 1.0*a_k & -1.0*\\lambda_{\\\\star,k+1} - 1.0*\\lambda_{k,k+1} + 1.0*a_{k+1}\n",
       "        \\end{array}\n",
       "        $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "diff_vec = pm_lyap.eval_scalar(diff).func_coords\n",
    "pf.pprint_labeled_vector(diff_vec, scal_index, precision=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "baa4cffa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Solutions:\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\left\\{ a_{k} : \\lambda_{k,k+1}, \\  a_{k+1} : \\lambda_{\\star,k+1} + \\lambda_{k,k+1}, \\  b_{k} : - \\frac{\\lambda_{k,k+1}}{2 L}, \\  b_{k+1} : - \\frac{- 4 L N s_{N-k-1} + 2 L k s_{N-k-1} + 2 N \\lambda_{k,k+1} - \\lambda_{k,k+1} k + 2 s_{N-k-1}}{2 L \\left(- 2 N + k\\right)}, \\  c_{k} : \\frac{4 L N^{2} \\lambda_{\\star,k+1} - 4 L N \\lambda_{\\star,k+1} k - 4 L N s_{N-k-1} + L \\lambda_{\\star,k+1} k^{2} + 2 L k s_{N-k-1} + 2 s_{N-k-1}}{2 \\left(- 2 N + k\\right)^{2}}, \\  c_{k+1} : \\frac{L \\left(- 2 N \\lambda_{\\star,k+1} + \\lambda_{\\star,k+1} k + 2 s_{N-k-1}\\right)}{2 \\left(- 2 N + k\\right)}, \\  d_{k} : d_{k+1}\\right\\}$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "diff_matrix_sympify = sp.Matrix(diff_matrix)\n",
    "diff_vec_sympify = sp.Matrix(diff_vec)\n",
    "\n",
    "# unknowns = (a_k_sp, a_k1_sp, b_k_sp, b_k1_sp, c_k_sp, c_k1_sp, d_k_sp, d_k1_sp, s_k1_sp, lamb_k_k1_sp, lamb_star_k1_sp)\n",
    "# unknowns = (a_k_sp, a_k1_sp, b_k_sp, b_k1_sp, c_k_sp, c_k1_sp, d_k_sp, d_k1_sp, s_k1_sp)\n",
    "unknowns = (a_k_sp, a_k1_sp, b_k_sp, b_k1_sp, c_k_sp, c_k1_sp, d_k_sp)\n",
    "\n",
    "eqs = list(diff_matrix_sympify)\n",
    "eqs += list(diff_vec_sympify)\n",
    "eqs = [e for e in eqs]\n",
    "\n",
    "sol = sp.linsolve(eqs, unknowns)\n",
    "sol_simplify = sp.factor(sp.factor(sp.nsimplify(sol)))\n",
    "# sol_simplify = sp.nsimplify(sol)\n",
    "\n",
    "print(\"Solutions:\")\n",
    "sol_dict = dict(zip(unknowns, next(iter(sol_simplify))))\n",
    "display(Math(sp.latex(sol_dict)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2122b554",
   "metadata": {},
   "source": [
    "Double check we found the right solution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "f4136334",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Verification:\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\left[\\begin{matrix}0 & 0 & 0 & 0 & 0\\\\0 & 0 & 0 & 0 & 0\\\\0 & 0 & 0 & 0 & 0\\\\0 & 0 & 0 & 0 & 0\\\\0 & 0 & 0 & 0 & 0\\end{matrix}\\right]$"
      ],
      "text/plain": [
       "Matrix([\n",
       "[0, 0, 0, 0, 0],\n",
       "[0, 0, 0, 0, 0],\n",
       "[0, 0, 0, 0, 0],\n",
       "[0, 0, 0, 0, 0],\n",
       "[0, 0, 0, 0, 0]])"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\left[\\begin{matrix}0\\\\0\\\\0\\end{matrix}\\right]$"
      ],
      "text/plain": [
       "Matrix([\n",
       "[0],\n",
       "[0],\n",
       "[0]])"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "if sol:\n",
    "    print(\"\\nVerification:\")\n",
    "    display(sp.simplify(diff_matrix_sympify.subs(sol_dict)))\n",
    "    display(sp.simplify(diff_vec_sympify.subs(sol_dict)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pepflow (3.11.13)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
