{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fgm-title",
   "metadata": {},
   "source": [
    "# Fast Gradient Method\n",
    "\n",
    "This notebook reproduces the smooth-convex Fast Gradient Method (FGM) examples from the article. The CVXPY SDP definitions live in `paper_examples/fgm.py`; the experiment logic and plots are shown here. All SDPs are solved with MOSEK, so running the notebook requires a working (licensed) MOSEK installation.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fgm-imports",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-07T05:18:28.455423Z",
     "iopub.status.busy": "2026-05-07T05:18:28.455270Z",
     "iopub.status.idle": "2026-05-07T05:18:28.961588Z",
     "shell.execute_reply": "2026-05-07T05:18:28.961034Z"
    }
   },
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import json\n",
    "import os\n",
    "import warnings\n",
    "\n",
    "from IPython.display import Image, display\n",
    "import numpy as np\n",
    "\n",
    "repo_root = Path.cwd().resolve()\n",
    "os.environ.setdefault(\"MPLCONFIGDIR\", str(repo_root / \".cache\" / \"matplotlib\"))\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.patches import Rectangle\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "from paper_examples import fgm\n",
    "from paper_examples.plots import PAPER_TEXT_WIDTH_IN, apply_plot_style, save_pdf_and_png\n"
   ]
  },
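  {
   "cell_type": "markdown",
   "id": "fgm-first-md",
   "metadata": {},
   "source": [
    "The first experiment uses the base smooth-convex interpolation assumptions with $N=3$, $L=1$, and the relaxed target $\\bar\\rho=1/13+10^{-9}$. It solves the raw multiplier SDP, four sparsity-promoting heuristics (plain $\\ell_1$, log-sum, normalized log-sum, and capped $\\ell_1$), and the SDP restricted to the conjectured chain pattern, and then checks how many multipliers each solution leaves active."
   ]
  },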
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fgm-solves",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-07T05:18:28.963105Z",
     "iopub.status.busy": "2026-05-07T05:18:28.963015Z",
     "iopub.status.idle": "2026-05-07T05:18:29.534419Z",
     "shell.execute_reply": "2026-05-07T05:18:29.533937Z"
    }
   },
   "outputs": [],
   "source": [
    "N = 3\n",
    "L = 1.0\n",
    "solver = \"MOSEK\"\n",
    "# Largest amount by which a capped solve's rate may exceed `cap` before the check below fails.\n",
    "rho_tolerance = 5e-8\n",
    "\n",
    "sdp = fgm.build_fgm_arrays(n=N, L=L)\n",
    "labels, matrices, vectors, Q_init, q_const = sdp\n",
    "cap = fgm.rho_cap(N)\n",
    "# The heuristics are solved against a cap slightly below `cap`.\n",
    "solve_cap = cap - 1e-8\n",
    "\n",
    "heuristic_labels = {\n",
    "    \"plain\": \"Raw\",\n",
    "    \"l1\": r\"Plain $\\ell_1$\",\n",
    "    \"reweighted\": \"Log-sum\",\n",
    "    \"normalized_reweighted\": \"Norm. log-sum\",\n",
    "    \"capped\": r\"Capped $\\ell_1$\",\n",
    "    \"fixed_chain\": \"Conjecture\",\n",
    "}\n",
    "heuristic_order = (\"plain\", \"l1\", \"reweighted\", \"normalized_reweighted\", \"capped\", \"fixed_chain\")\n",
    "\n",
    "solves = {\n",
    "    \"plain\": fgm.solve_multiplier_sdp(labels, matrices, vectors, Q_init, q_const, solver=solver),\n",
    "    \"l1\": fgm.solve_l1_heuristic(sdp, solver=solver, objective=\"weighted_l1\", cap=solve_cap),\n",
    "    \"reweighted\": fgm.solve_l1_heuristic(sdp, solver=solver, objective=\"reweighted_l1\", cap=solve_cap),\n",
    "}\n",
    "# The normalized log-sum and capped-l1 heuristics share the same normalization scales.\n",
    "normalization_scales = fgm.normalization_bounds(sdp, solver=solver, cap=solve_cap)\n",
    "for name, objective in ((\"normalized_reweighted\", \"reweighted_l1\"), (\"capped\", \"capped_l1\")):\n",
    "    solves[name] = fgm.solve_l1_heuristic(\n",
    "        sdp,\n",
    "        solver=solver,\n",
    "        objective=objective,\n",
    "        cap=solve_cap,\n",
    "        scale=normalization_scales,\n",
    "    )\n",
    "\n",
    "# Restrict the SDP to the multipliers of the conjectured chain pattern.\n",
    "chain_active_indices = fgm.multiplier_indices(labels, fgm.chain_active_multiplier_labels(N))\n",
    "solves[\"fixed_chain\"] = fgm.solve_multiplier_sdp(\n",
    "    labels,\n",
    "    matrices,\n",
    "    vectors,\n",
    "    Q_init,\n",
    "    q_const,\n",
    "    solver=solver,\n",
    "    active_multipliers=chain_active_indices,\n",
    ")\n",
    "\n",
    "active_counts = {}\n",
    "rhos = {}\n",
    "multiplier_map = {}\n",
    "for name, (status, rho, lambdas) in solves.items():\n",
    "    if status not in fgm.GOOD_STATUSES or rho is None or lambdas is None:\n",
    "        raise RuntimeError(f\"{name} solve failed: {status}\")\n",
    "    active_counts[name] = len(fgm.active_indices(lambdas))\n",
    "    rhos[name] = float(rho)\n",
    "    multiplier_map[name] = np.asarray(lambdas, dtype=float)\n",
    "\n",
    "expected_active = {\n",
    "    \"plain\": 16,\n",
    "    \"l1\": 8,\n",
    "    \"reweighted\": 8,\n",
    "    \"normalized_reweighted\": 7,\n",
    "    \"capped\": 7,\n",
    "    \"fixed_chain\": 6,\n",
    "}\n",
    "assert active_counts == expected_active, f\"unexpected active-multiplier counts: {active_counts}\"\n",
    "# Every solve except the raw SDP must stay within `rho_tolerance` of the cap.\n",
    "rates_over_cap = {\n",
    "    name: rho\n",
    "    for name, rho in rhos.items()\n",
    "    if name != \"plain\" and not rho <= cap + rho_tolerance\n",
    "}\n",
    "assert not rates_over_cap, f\"rates above the cap {cap}: {rates_over_cap}\"\n",
    "\n",
    "active_tables_path = repo_root / \"generated\" / \"fgm_multiplier_active_tables.tex\"\n",
    "# This is the first cell that writes into generated/, so make sure it exists on a fresh checkout.\n",
    "active_tables_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "active_tables_path = fgm.write_active_multiplier_tables(\n",
    "    labels,\n",
    "    multiplier_map,\n",
    "    heuristic_labels,\n",
    "    heuristic_order,\n",
    "    active_tables_path=active_tables_path,\n",
    "    n=N,\n",
    ")\n",
    "active_tables_text = active_tables_path.read_text()\n",
    "assert \"tab:fgm-active-multiplier-patterns\" in active_tables_text\n",
    "assert r\"Capped $\\ell_1$ (total 7)\" in active_tables_text\n",
    "assert \"Conjecture (total 6)\" in active_tables_text\n",
    "\n",
    "{name: {\"active\": active_counts[name], \"rho\": rhos[name]} for name in heuristic_order}\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fgm-spy-tables-md",
   "metadata": {},
   "source": [
    "Next, we inspect the sparsity pattern of the multipliers. In each table, rows are source points and columns are target points; a filled dot marks an active multiplier, an empty cell an inactive one, and shaded cells lie on the diagonal.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fgm-spy-tables",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-07T05:18:29.535992Z",
     "iopub.status.busy": "2026-05-07T05:18:29.535912Z",
     "iopub.status.idle": "2026-05-07T05:18:32.808052Z",
     "shell.execute_reply": "2026-05-07T05:18:32.807488Z"
    }
   },
   "outputs": [],
   "source": [
    "point_names = fgm.point_names(N)\n",
    "point_labels = [rf\"${fgm.latex_point_label(point)}$\" for point in point_names]\n",
    "\n",
    "\n",
    "def draw_active_multiplier_spy_table(ax, active_labels, title, names=None, tick_labels=None):\n",
    "    \"\"\"Draw one multiplier sparsity table on `ax`.\n",
    "\n",
    "    Cell (row, column) gets a dot when the label `source->target` for that pair of\n",
    "    points is in `active_labels`; diagonal cells are shaded.\n",
    "\n",
    "    Args:\n",
    "        ax: Matplotlib axes to draw into.\n",
    "        active_labels: Collection of active multiplier labels.\n",
    "        title: Panel title.\n",
    "        names: Point names indexing rows and columns (default: `point_names`).\n",
    "        tick_labels: Tick labels for those points (default: `point_labels`).\n",
    "    \"\"\"\n",
    "    names = point_names if names is None else names\n",
    "    tick_labels = point_labels if tick_labels is None else tick_labels\n",
    "    point_count = len(names)\n",
    "    for idx in range(point_count):\n",
    "        ax.add_patch(Rectangle((idx - 0.5, idx - 0.5), 1, 1, facecolor=\"#F3F4F6\", edgecolor=\"none\"))\n",
    "\n",
    "    active_cells = [\n",
    "        (row, column)\n",
    "        for row, source in enumerate(names)\n",
    "        for column, target in enumerate(names)\n",
    "        if f\"{source}->{target}\" in active_labels\n",
    "    ]\n",
    "    if active_cells:\n",
    "        active_rows, active_columns = np.array(active_cells).T\n",
    "        ax.scatter(active_columns, active_rows, s=26, color=\"#111827\", zorder=3)\n",
    "\n",
    "    # Light grid lines on the cell boundaries (half-integer coordinates).\n",
    "    for edge in np.arange(-0.5, point_count + 0.5, 1.0):\n",
    "        ax.axhline(edge, color=\"#D1D5DB\", linewidth=0.45, zorder=1)\n",
    "        ax.axvline(edge, color=\"#D1D5DB\", linewidth=0.45, zorder=1)\n",
    "\n",
    "    ax.set_xlim(-0.5, point_count - 0.5)\n",
    "    # Inverted y-limits put the first point in the top row, matrix-style.\n",
    "    ax.set_ylim(point_count - 0.5, -0.5)\n",
    "    ax.set_aspect(\"equal\")\n",
    "    ax.set_xticks(range(point_count), tick_labels, fontsize=6)\n",
    "    ax.set_yticks(range(point_count), tick_labels, fontsize=6)\n",
    "    ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False, length=0, pad=1.0)\n",
    "    ax.set_title(title, fontsize=8, pad=14)\n",
    "    ax.grid(False)\n",
    "    for spine in ax.spines.values():\n",
    "        spine.set_color(\"#9CA3AF\")\n",
    "        spine.set_linewidth(0.55)\n",
    "\n",
    "\n",
    "spy_active_labels = {\n",
    "    name: {labels[idx] for idx in fgm.active_indices(multiplier_map[name])}\n",
    "    for name in heuristic_order\n",
    "}\n",
    "\n",
    "apply_plot_style()\n",
    "fig, axes = plt.subplots(2, 3, figsize=(PAPER_TEXT_WIDTH_IN, 3.25))\n",
    "panel_axes = axes.ravel()\n",
    "assert len(heuristic_order) <= panel_axes.size, \"not enough panels for all heuristics\"\n",
    "for ax, name in zip(panel_axes, heuristic_order):\n",
    "    title = f\"{heuristic_labels[name]} (total {active_counts[name]})\"\n",
    "    draw_active_multiplier_spy_table(ax, spy_active_labels[name], title)\n",
    "# Hide any panels left over when there are fewer heuristics than axes.\n",
    "for ax in panel_axes[len(heuristic_order):]:\n",
    "    ax.set_visible(False)\n",
    "\n",
    "fig.text(0.50, 0.015, r\"$\\bullet$ active multiplier; shaded cells are diagonal\", ha=\"center\", fontsize=8)\n",
    "fig.subplots_adjust(left=0.08, right=0.99, bottom=0.10, top=0.88, wspace=0.34, hspace=0.68)\n",
    "display(fig)\n",
    "plt.close(fig)\n"
   ]
  },
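  {
   "cell_type": "markdown",
   "id": "fgm-cardinality-plots",
   "metadata": {},
   "source": [
    "The paper-level cardinality plot compares active-multiplier counts across FGM horizons $N = 1, \\dots, 20$. This cell solves the SDP sweep, writes the generated data artifact (`generated/fgm_paper_cardinality.json`), and then saves the combined figure (`figures/fgm_cardinality_combined.pdf` and `.png`). Black stars mark exhaustive-search results (computed for $N \\le 4$), and the dashed line is the conjectured count $2N$.\n"
   ]
  },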
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "write-fgm-cardinality-plots",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-07T05:18:32.809684Z",
     "iopub.status.busy": "2026-05-07T05:18:32.809482Z",
     "iopub.status.idle": "2026-05-07T05:18:37.632395Z",
     "shell.execute_reply": "2026-05-07T05:18:37.632019Z"
    }
   },
   "outputs": [],
   "source": [
    "N_max = 20\n",
    "exact_N_max = 4\n",
    "cardinality_data_path = repo_root / \"generated\" / \"fgm_paper_cardinality.json\"\n",
    "cardinality_payload = fgm.generate_cardinality_payload(\n",
    "    n_max=N_max,\n",
    "    solver=solver,\n",
    "    exact_n_max=exact_N_max,\n",
    "    progress_factory=tqdm,\n",
    ")\n",
    "cardinality_data_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "cardinality_data_path.write_text(json.dumps(cardinality_payload, indent=2, sort_keys=True) + \"\\n\")\n",
    "rows = cardinality_payload[\"rows\"]\n",
    "\n",
    "# Exhaustive-search cardinalities come only from records present in the payload.\n",
    "# Missing records are reported rather than filled in with the conjectured 2N, so the\n",
    "# \"Exhaustive search\" markers never show unverified values.\n",
    "exact_payload = cardinality_payload.get(\"exact_verification\") or {}\n",
    "exact = {}\n",
    "for n_val in range(1, N_max + 1):\n",
    "    recorded = exact_payload.get(str(n_val), exact_payload.get(n_val)) or {}\n",
    "    exact[n_val] = recorded.get(\"exact_cardinality\")\n",
    "missing_exact = [n_val for n_val in range(1, exact_N_max + 1) if exact[n_val] is None]\n",
    "if missing_exact:\n",
    "    warnings.warn(f\"No exhaustive-search cardinality recorded for N in {missing_exact}; these horizons get no marker.\")\n",
    "\n",
    "method_order = (\"raw\", \"l1\", \"reweighted\", \"normalized\", \"capped\")\n",
    "competitive_order = (\"reweighted\", \"normalized\", \"capped\")\n",
    "method_labels = {\n",
    "    \"raw\": \"Raw\",\n",
    "    \"l1\": r\"Plain $\\ell_1$\",\n",
    "    \"reweighted\": r\"Log-sum\",\n",
    "    \"normalized\": \"Norm.\\nlog-sum\",\n",
    "    \"capped\": r\"Capped $\\ell_1$\",\n",
    "}\n",
    "method_colors = {\n",
    "    \"raw\": \"#1f77b4\",\n",
    "    \"l1\": \"#ff7f0e\",\n",
    "    \"reweighted\": \"#2ca02c\",\n",
    "    \"normalized\": \"#d62728\",\n",
    "    \"capped\": \"#9467bd\",\n",
    "}\n",
    "\n",
    "ns = np.array([int(row[\"n\"]) for row in rows], dtype=int)\n",
    "exact_ns = np.array([n_val for n_val in ns if exact.get(int(n_val)) is not None], dtype=int)\n",
    "exact_counts = np.array([exact[int(n_val)] for n_val in exact_ns], dtype=float)\n",
    "# The dashed 2N conjecture line starts at the largest exhaustively verified horizon (or N=1).\n",
    "conjecture_start = int(exact_ns.max()) if exact_ns.size else 1\n",
    "conjecture_ns = np.arange(conjecture_start, N_max + 1)\n",
    "\n",
    "\n",
    "def draw_cardinality_panel(ax, selected_methods, title):\n",
    "    \"\"\"Plot active-multiplier counts against the horizon length N on `ax`.\n",
    "\n",
    "    Draws one line per method in `selected_methods` (counts read from `rows`), the\n",
    "    recorded exhaustive-search points, and the dashed conjectured count 2N.\n",
    "    \"\"\"\n",
    "    for name in selected_methods:\n",
    "        counts = [row[\"methods\"][name][\"active_multipliers\"] for row in rows]\n",
    "        ax.plot(\n",
    "            ns,\n",
    "            counts,\n",
    "            marker=\"o\",\n",
    "            linewidth=1.8,\n",
    "            markersize=4.0,\n",
    "            color=method_colors[name],\n",
    "            label=method_labels[name],\n",
    "        )\n",
    "\n",
    "    if exact_ns.size:\n",
    "        ax.scatter(\n",
    "            exact_ns,\n",
    "            exact_counts,\n",
    "            marker=\"*\",\n",
    "            s=90,\n",
    "            color=\"black\",\n",
    "            label=\"Exhaustive search\",\n",
    "            zorder=10,\n",
    "        )\n",
    "        ax.plot(exact_ns, exact_counts, color=\"black\", linewidth=1.25, zorder=9)\n",
    "    ax.plot(\n",
    "        conjecture_ns,\n",
    "        2 * conjecture_ns,\n",
    "        color=\"black\",\n",
    "        linestyle=\"--\",\n",
    "        marker=\"x\",\n",
    "        linewidth=1.35,\n",
    "        markersize=4.0,\n",
    "        label=\"Conjecture\",\n",
    "        zorder=8,\n",
    "    )\n",
    "\n",
    "    ax.set_title(title, pad=3)\n",
    "    ax.set_xlabel(r\"Horizon length $N$\", labelpad=2)\n",
    "    ax.set_xlim(0.7, N_max + 0.3)\n",
    "    ax.set_xticks([1, 5, 10, 15, 20])\n",
    "    ax.grid(True, axis=\"y\", alpha=0.25)\n",
    "    ax.grid(True, axis=\"x\", alpha=0.10)\n",
    "    ax.legend(\n",
    "        loc=\"upper left\",\n",
    "        fontsize=7,\n",
    "        frameon=True,\n",
    "        framealpha=0.88,\n",
    "        ncol=2 if len(selected_methods) > 3 else 1,\n",
    "        borderpad=0.24,\n",
    "        labelspacing=0.24,\n",
    "        handlelength=1.25,\n",
    "        columnspacing=0.70,\n",
    "        handletextpad=0.40,\n",
    "    )\n",
    "\n",
    "\n",
    "figure_pdf = repo_root / \"figures\" / \"fgm_cardinality_combined.pdf\"\n",
    "figure_png = figure_pdf.with_suffix(\".png\")\n",
    "figure_pdf.parent.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "apply_plot_style()\n",
    "# Scope the font-size overrides to this figure instead of mutating the global rcParams.\n",
    "with plt.rc_context({\"font.size\": 9, \"axes.labelsize\": 9, \"axes.titlesize\": 10, \"legend.fontsize\": 7}):\n",
    "    fig, axes = plt.subplots(1, 2, figsize=(0.98 * PAPER_TEXT_WIDTH_IN, 2.25))\n",
    "    draw_cardinality_panel(axes[0], method_order, \"All methods\")\n",
    "    draw_cardinality_panel(axes[1], competitive_order, \"Competitive heuristics\")\n",
    "    axes[0].set_ylabel(\"Active multipliers\", labelpad=2)\n",
    "    fig.suptitle(r\"FGM multiplier cardinality\", fontsize=10, y=0.98)\n",
    "    fig.subplots_adjust(left=0.075, right=0.995, bottom=0.17, top=0.82, wspace=0.24)\n",
    "    save_pdf_and_png(fig, figure_pdf)\n",
    "    plt.close(fig)\n",
    "\n",
    "display(Image(filename=str(figure_png)))\n",
    "{\n",
    "    \"figure_pdf\": str(figure_pdf.relative_to(repo_root)),\n",
    "    \"figure_png\": str(figure_png.relative_to(repo_root)),\n",
    "}\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tmp",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
