{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f578b3e9-ee1b-4442-8a9f-69e064cc11ad",
   "metadata": {},
   "source": [
    "# Experiments for the 2D diagonal model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8eb2ce6-229c-405c-955a-b3ae59c42594",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import numpy.linalg as la\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.legend_handler import HandlerTuple\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from jax import grad, jit, hessian\n",
    "from jax import random\n",
    "from jax.scipy.linalg import eigh"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "842b9593-b527-4a09-b28d-86902b34cd2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For interactive plots, replace the magic below with `%matplotlib widget`.\n",
    "# Default: static, non-interactive plots.\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48b04ca7-4ab6-4343-9c1b-849934a4ef87",
   "metadata": {},
   "outputs": [],
   "source": [
    "def f(x, N):\n",
    "    \"\"\"Element-wise N-th power of ``x``.\"\"\"\n",
    "    powered = np.power(x, N)\n",
    "    return powered\n",
    "\n",
    "\n",
    "def grad_f(x, N):\n",
    "    \"\"\"Element-wise derivative of ``f``: ``N * x**(N - 1)``.\"\"\"\n",
    "    exponent = N - 1\n",
    "    return N * np.power(x, exponent)\n",
    "\n",
    "\n",
    "def Hess_f(x, N):\n",
    "    \"\"\"Diagonal matrix holding the element-wise second derivatives of ``f``.\n",
    "\n",
    "    For ``N == 1`` the second derivative vanishes, so a zero matrix is\n",
    "    returned without evaluating ``x**(-1)``.\n",
    "    \"\"\"\n",
    "    if N != 1:\n",
    "        second = np.power(x, N - 2)\n",
    "        return N * (N - 1) * np.diag(second)\n",
    "    dim = len(x)\n",
    "    return np.zeros([dim, dim])\n",
    "\n",
    "\n",
    "def H1(x, N, ATA, ATy):\n",
    "    \"\"\"``Hess_f(x, N)`` times the diagonal matrix of ``ATA @ f(x, N) - ATy``.\"\"\"\n",
    "    weighted_residual = ATA @ f(x, N) - ATy\n",
    "    return Hess_f(x, N) @ np.diag(weighted_residual)\n",
    "\n",
    "\n",
    "def H2(x, N, ATA):\n",
    "    \"\"\"``diag(grad_f) @ ATA @ diag(grad_f)`` evaluated at ``x``.\"\"\"\n",
    "    jac = np.diag(grad_f(x, N))\n",
    "    return jac @ ATA @ jac"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88308f1c-4fd3-4de3-8d91-669a7f47216d",
   "metadata": {},
   "outputs": [],
   "source": [
    "### MODEL ###\n",
    "\n",
    "\n",
    "class DiagonalNetwork:\n",
    "    \"\"\"Model with loss ``0.5 * ||A @ u**L - y||**2`` trained by plain gradient descent.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    A, y : array-like\n",
    "        Measurement matrix and targets used by ``loss``.\n",
    "    L : int\n",
    "        Exponent applied element-wise to the parameters ``u``.\n",
    "    u0 : array-like\n",
    "        Initial parameters; also stored unchanged as ``alpha``.\n",
    "    custom_loss : callable, optional\n",
    "        Loss called as ``custom_loss(u, A, L, y)`` by ``custom_update``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, A, y, L, u0, custom_loss=None):\n",
    "        self.L = L\n",
    "        self.A = A\n",
    "        self.y = y\n",
    "        # Take the initialisation from the argument. Reading a notebook-level\n",
    "        # ``alpha`` here raised NameError (or picked up a stale value) whenever\n",
    "        # that global did not match what was actually passed in.\n",
    "        self.alpha = u0\n",
    "\n",
    "        self.u = np.array(u0)\n",
    "\n",
    "        # Only the loss is JIT-compiled. ``update`` mutates ``self.u`` and is\n",
    "        # exposed un-jitted under the name the training loop expects.\n",
    "        self.jit_loss = jax.jit(self.loss)\n",
    "        self.jit_update = self.update\n",
    "        # Always defined so ``custom_update`` can fail with a clear message.\n",
    "        self.loss_ = custom_loss\n",
    "\n",
    "    def set_u(self, u):\n",
    "        \"\"\"Overwrite the current parameters.\"\"\"\n",
    "        self.u = u\n",
    "\n",
    "    def get_u(self):\n",
    "        \"\"\"Return the current parameters.\"\"\"\n",
    "        return self.u\n",
    "\n",
    "    def loss(self, u):\n",
    "        \"\"\"Squared-error loss ``0.5 * ||A @ u**L - y||**2`` at ``u``.\"\"\"\n",
    "        return 1 / 2 * jnp.linalg.norm((self.A @ jnp.power(u, self.L) - self.y)) ** 2\n",
    "\n",
    "    def custom_update(self, eta=1e-3):\n",
    "        \"\"\"Take one gradient step on the custom loss.\n",
    "\n",
    "        Returns the spectral norm of the custom-loss Hessian at the updated ``u``.\n",
    "\n",
    "        Raises\n",
    "        ------\n",
    "        ValueError\n",
    "            If the model was constructed without ``custom_loss``.\n",
    "        \"\"\"\n",
    "        if self.loss_ is None:\n",
    "            raise ValueError(\"custom_update requires custom_loss to be passed to DiagonalNetwork\")\n",
    "\n",
    "        grads = grad(self.loss_)(self.u, self.A, self.L, self.y)\n",
    "\n",
    "        # Perform gradient descent update\n",
    "        self.u = self.u - eta * grads\n",
    "\n",
    "        hess = hessian(self.loss_)(self.u, self.A, self.L, self.y)\n",
    "        return la.norm(hess, ord=2)\n",
    "\n",
    "    def shaps(self, u):\n",
    "        \"\"\"Sharpness at ``u``: spectral norm of the Hessian of ``loss``.\"\"\"\n",
    "        hess = hessian(self.loss)(u)\n",
    "        return la.norm(hess, ord=2)\n",
    "\n",
    "    def update(self, eta=1e-3):\n",
    "        \"\"\"Take one gradient step on ``loss`` and return the sharpness afterwards.\"\"\"\n",
    "        grads = grad(self.loss)(self.u)\n",
    "\n",
    "        # Perform gradient descent update\n",
    "        self.u = self.u - eta * grads\n",
    "\n",
    "        # The sharpness is measured at the new iterate, not the old one.\n",
    "        return self.shaps(self.u)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43487c63-3fe0-47b2-a3d1-991434f7180c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, eta, steps=100, ls=None, s_H1=None, s_H2=None, s_all=None, us=None):\n",
    "    \"\"\"Run up to ``steps`` gradient-descent updates on ``model`` and record diagnostics.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    model : DiagonalNetwork\n",
    "        Trained in place; must provide ``jit_update``, ``jit_loss``, ``get_u``\n",
    "        and the attributes ``u``, ``A``, ``y`` and ``L``.\n",
    "    eta : float\n",
    "        Learning rate passed to every update.\n",
    "    steps : int\n",
    "        Maximum number of updates.\n",
    "    ls, s_H1, s_H2, s_all, us : list, optional\n",
    "        Lists that receive, per step, the loss, the spectral norms of ``H1``\n",
    "        and ``H2``, the value returned by the update, and the iterate.\n",
    "        A fresh list is created for each argument left as ``None``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple\n",
    "        ``(model, ls, s_H1, s_H2, s_all, us)``. Training stops early once the\n",
    "        loss drops below ``1e-7``.\n",
    "    \"\"\"\n",
    "    # Default lists in the signature would be shared between calls, so the\n",
    "    # history of one run would leak into the next; create them per call.\n",
    "    ls = [] if ls is None else ls\n",
    "    s_H1 = [] if s_H1 is None else s_H1\n",
    "    s_H2 = [] if s_H2 is None else s_H2\n",
    "    s_all = [] if s_all is None else s_all\n",
    "    us = [] if us is None else us\n",
    "\n",
    "    # Build the Gram terms from the model being trained instead of relying on\n",
    "    # notebook globals that may describe a different problem.\n",
    "    ATA = model.A.T @ model.A\n",
    "    ATy = model.A.T @ model.y\n",
    "\n",
    "    for step in range(steps):\n",
    "        ev = model.jit_update(eta=eta)\n",
    "        s_all.append(ev)\n",
    "        us.append(model.get_u())\n",
    "\n",
    "        s_H1.append(la.norm(H1(model.u, model.L, ATA, ATy), 2))\n",
    "        s_H2.append(la.norm(H2(model.u, model.L, ATA), 2))\n",
    "        loss = model.jit_loss(model.u)\n",
    "        ls.append(loss)\n",
    "        if (step + 1) % 100 == 0:\n",
    "            # Report against the number of steps of this call.\n",
    "            print(f\"Epoch [{step + 1}/{steps}], Loss: {loss:.6f}\")\n",
    "        if loss < 0.0000001:\n",
    "            print(\"Loss goal satisfied with loss \" + str(loss))\n",
    "            break\n",
    "\n",
    "    return model, ls, s_H1, s_H2, s_all, us"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66f41b58-9653-4a0a-a372-3af0be2a2b0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_custom(model, eta, steps=100, ls=None, s_all=None):\n",
    "    \"\"\"Run ``steps`` updates with the model's custom loss and record diagnostics.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    model\n",
    "        Object providing ``custom_update(eta=...)``, ``loss_(u, A, L, y)`` and\n",
    "        the attributes ``u``, ``A``, ``L`` and ``y``.\n",
    "    eta : float\n",
    "        Learning rate passed to every update.\n",
    "    steps : int\n",
    "        Number of updates.\n",
    "    ls, s_all : list, optional\n",
    "        Lists that receive the per-step loss and the value returned by\n",
    "        ``custom_update``. A fresh list is created for each one left as ``None``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple\n",
    "        ``(model, ls, s_all)``.\n",
    "    \"\"\"\n",
    "    # Default lists in the signature would be shared between calls.\n",
    "    ls = [] if ls is None else ls\n",
    "    s_all = [] if s_all is None else s_all\n",
    "\n",
    "    for step in range(steps):\n",
    "        ev = model.custom_update(eta=eta)\n",
    "        s_all.append(ev)\n",
    "\n",
    "        loss = model.loss_(model.u, model.A, model.L, model.y)\n",
    "        ls.append(loss)\n",
    "\n",
    "        if (step + 1) % 100 == 0:\n",
    "            # Report against the number of steps of this call.\n",
    "            print(f\"Epoch [{step + 1}/{steps}], Loss: {loss:.6f}\")\n",
    "    return model, ls, s_all"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07313e3a-b3ca-4780-89a6-2997fa720a12",
   "metadata": {},
   "outputs": [],
   "source": [
    "### PARAMS ###\n",
    "# Two JAX PRNG keys derived from a fixed seed.\n",
    "seed = 4\n",
    "keys = random.split(random.PRNGKey(seed), num=2)\n",
    "\n",
    "# Exponent of the model and problem dimensions.\n",
    "L = 2\n",
    "x_dim, y_dim = 2, 1\n",
    "print(\"(x,y) Dimension = ({},{})\".format(x_dim, y_dim))\n",
    "\n",
    "# Measurement matrix, target and initialisation shared by every run.\n",
    "A = np.array([[0.5, 3]])\n",
    "y = np.array([1])\n",
    "alpha = [0.01, 0.001]\n",
    "\n",
    "# Overridden by the sweep configuration further down.\n",
    "max_iter = 1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f860360c-d688-4010-a4f0-177e1d241980",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gram matrix A^T A and correlation A^T y, passed to H1/H2 during training.\n",
    "ATA = np.matmul(A.T, A)\n",
    "ATy = np.matmul(A.T, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da083461-d1a4-46c4-8066-6bb0a64cdb5a",
   "metadata": {},
   "source": [
    "## Compute experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "928dc9f2-5308-4a80-a7bd-b866467dd883",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### PARAMS ####\n",
    "\n",
    "# Learning-rate sweep: 12 evenly spaced values with spacing `gap` (one eighth of\n",
    "# upper - lower), starting 3 steps below `lower` and ending at `upper`.\n",
    "# NOTE(review): 11.990035057067871 and 0.3330094722598106 look like a measured\n",
    "# sharpness and an empirical stability limit -- confirm and give them names.\n",
    "lower = 2 / 11.990035057067871\n",
    "upper = 0.3330094722598106\n",
    "gap = (upper - lower) / 8\n",
    "sweep = np.linspace(lower - 3 * gap, upper, 12)\n",
    "# The tiny leading step size is the run later interpreted as the flow.\n",
    "etas = np.concatenate(([0.001], sweep))\n",
    "max_iter = 20000  # 500\n",
    "\n",
    "print(\"Learning rates to be tested: \", etas)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b104142a-327e-4208-8d68-a15c70e95f6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recompute the sweep (and save it) and/or load \"experiment_data.npz\".\n",
    "# At least one of the two must be True for the plotting cells to have data.\n",
    "run_experiments, load_experiments = True, False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e703aa89-34b6-4e8b-9f1e-efe61799e7cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "if run_experiments:\n",
    "    # Outer index: learning rate. Inner index: initialisation (only one here).\n",
    "    ls_s, s_H1_s, s_H2_s, s_all_s, us_s = [], [], [], [], []\n",
    "    for eta_ in etas:\n",
    "        ls_s_, s_H1_s_, s_H2_s_, s_all_s_, us_s_ = [], [], [], [], []\n",
    "        for alpha_ in [alpha]:\n",
    "            # run the optimization\n",
    "            print(f\"Alpha {alpha_}, Eta: {eta_}\")\n",
    "\n",
    "            model = DiagonalNetwork(A, y, L, alpha_)\n",
    "            # maybe max_iter/eta (so more steps for smaller eta)?\n",
    "            # `us` is seeded with the initial point, so it ends up one element\n",
    "            # longer than the other histories.\n",
    "            initial_point = model.get_u()\n",
    "            model, ls, s_H1, s_H2, s_all, us = train(\n",
    "                model,\n",
    "                eta_,\n",
    "                steps=max_iter,\n",
    "                ls=[],\n",
    "                s_H1=[],\n",
    "                s_H2=[],\n",
    "                s_all=[],\n",
    "                us=[initial_point],\n",
    "            )\n",
    "            print(f\"Sharpness: {s_all[-1]}\")\n",
    "            for per_eta, run in zip(\n",
    "                (ls_s_, s_H1_s_, s_H2_s_, s_all_s_, us_s_),\n",
    "                (ls, s_H1, s_H2, s_all, us),\n",
    "            ):\n",
    "                per_eta.append(run)\n",
    "        for per_sweep, per_eta in zip(\n",
    "            (ls_s, s_H1_s, s_H2_s, s_all_s, us_s),\n",
    "            (ls_s_, s_H1_s_, s_H2_s_, s_all_s_, us_s_),\n",
    "        ):\n",
    "            per_sweep.append(per_eta)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0092fe06-26af-4d1e-993d-2c478abea72d",
   "metadata": {},
   "outputs": [],
   "source": [
    "if run_experiments:\n",
    "    # Runs can stop early and therefore differ in length, hence dtype=object.\n",
    "    # The H1/H2 histories (s_H1_s, s_H2_s) are not written to disk.\n",
    "    to_save = {\"ls_s\": ls_s, \"s_all_s\": s_all_s, \"us_s\": us_s}\n",
    "    np.savez(\n",
    "        \"experiment_data.npz\",\n",
    "        **{key: np.array(history, dtype=object) for key, history in to_save.items()},\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b68a5d4f-00b2-46a4-8bb0-647bc38e7214",
   "metadata": {},
   "outputs": [],
   "source": [
    "if load_experiments:\n",
    "    # allow_pickle is needed because the histories are stored as object arrays.\n",
    "    # Unpickling can execute arbitrary code: only load files written by this notebook.\n",
    "    # The context manager closes the archive once the arrays have been read.\n",
    "    with np.load(\"experiment_data.npz\", allow_pickle=True) as data:\n",
    "        print(data)\n",
    "        ls_s = data[\"ls_s\"]\n",
    "        s_all_s = data[\"s_all_s\"]\n",
    "        us_s = data[\"us_s\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5318fa2a-dc37-4950-8ccc-45f6b011617c",
   "metadata": {},
   "source": [
    "## Plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2028aa3-3c61-4514-8d41-ac822fd5fff6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Warning: the smallest eta (interpreted as the flow) gets removed from the list of etas\n",
    "# Keep exactly one eta per remaining run by slicing from the end. Unlike\n",
    "# `etas = etas[1:]`, executing this cell again leaves etas unchanged, so it\n",
    "# cannot drift out of alignment with the stored results.\n",
    "n_gd_runs = len(s_all_s) - 1\n",
    "etas = etas[max(len(etas) - n_gd_runs, 0) :]\n",
    "\n",
    "# Run 0 is the flow reference: sharpness at its peak and at the end.\n",
    "flow_sharpness = s_all_s[0][0]\n",
    "max_idx = np.argmax([s for s in flow_sharpness]).item()\n",
    "flow_shap = flow_sharpness[max_idx]\n",
    "flow_shap_last = flow_sharpness[-1]\n",
    "\n",
    "# Per learning rate (runs 1..): final and maximal sharpness.\n",
    "shaps = np.array([s_all_s[i + 1][0][-1] for i in range(len(s_all_s) - 1)])\n",
    "maxshaps = np.array([np.max(s_all_s[i + 1][0]) for i in range(len(s_all_s) - 1)])\n",
    "\n",
    "# l1 norm of the final iterate, for the flow and per learning rate.\n",
    "flow_norm = np.linalg.norm(us_s[0][0][-1], ord=1)\n",
    "norms1 = np.array(\n",
    "    [np.linalg.norm(us_s[i + 1][0][-1], ord=1) for i in range(len(us_s) - 1)]\n",
    ")\n",
    "\n",
    "# Final loss, for the flow and per learning rate.\n",
    "flow_loss = ls_s[0][0][-1]\n",
    "losses = np.array([ls_s[i + 1][0][-1] for i in range(len(ls_s) - 1)])\n",
    "\n",
    "# l1 distance between each final iterate and the final flow iterate.\n",
    "dists = np.array(\n",
    "    [\n",
    "        np.linalg.norm(us_s[i + 1][0][-1] - us_s[0][0][-1], ord=1)\n",
    "        for i in range(len(us_s) - 1)\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Number of stored iterates per run (includes the initial point).\n",
    "iterations = np.array([len(us_s[i + 1][0]) for i in range(len(us_s) - 1)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d8ef672-fa9a-433b-8832-f829061622cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared look of all figures: fonts, colours, marker sizes and line width.\n",
    "plt.rcParams.update(\n",
    "    {\n",
    "        \"font.size\": 16,\n",
    "        \"mathtext.fontset\": \"stix\",\n",
    "        \"font.family\": \"STIXGeneral\",\n",
    "    }\n",
    ")\n",
    "colors = {\n",
    "    \"converged\": \"#00CC03\",\n",
    "    \"converged_last\": \"darkgreen\",\n",
    "    \"maximum\": \"black\",\n",
    "    \"bound\": \"#0081D1\",\n",
    "    \"flow\": \"#FF7F0F\",\n",
    "    \"heuristic\": \"#CC00F5\",\n",
    "    \"goal\": \"#70C8FF\",\n",
    "    \"a\": \"#F00A02\",\n",
    "    \"b\": \"#0C7EC3\",\n",
    "    \"c\": \"#1ABA1E\",\n",
    "    \"d\": \"#085E09\",\n",
    "}\n",
    "sizes = {\"converged\": 100, \"maximum\": 133, \"maxwidth\": 1.5, \"bound\": 160}\n",
    "linew = 3\n",
    "\n",
    "# Boolean mask over the learning rates whose final loss magnitude is at most 1e-4.\n",
    "tmp = np.abs(losses)\n",
    "mask = tmp <= 0.0001"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c373bcb2-ba0a-4372-8457-65f378912445",
   "metadata": {},
   "outputs": [],
   "source": [
    "general_captions = True\n",
    "gf_mode = \"max\"\n",
    "\n",
    "# The same text in every configuration, so it is set once up front. With\n",
    "# general captions and gf_mode != \"max\" it used to be left undefined.\n",
    "flabel_last = r\"final GF sharpness\"\n",
    "\n",
    "if general_captions:\n",
    "    glabel = r\"final value\"\n",
    "    mlabel = \"value at max sharpness\"\n",
    "    flabel = r\"GF value at max sharpness\" if gf_mode == \"max\" else r\"final GF value\"\n",
    "else:\n",
    "    glabel = r\"final sharpness\"\n",
    "    # `smooth_label` is not defined by any earlier cell; use it when present and\n",
    "    # fall back to no suffix instead of raising NameError.\n",
    "    mlabel = \"maximum value\" + globals().get(\"smooth_label\", \"\")\n",
    "    flabel = r\"max GF sharpness ($s_{GF}$)\" if gf_mode == \"max\" else r\"final GF sharpness\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d68ec8d-dea4-40a5-a9ed-7da2fec9f04e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generalization(w_):\n",
    "    \"\"\"Return ``0.5 * ||w_||_4**4 - ||w_||_2**2 + 1.5`` for a weight vector ``w_``.\"\"\"\n",
    "    quartic = np.linalg.norm(w_, ord=4) ** 4\n",
    "    quadratic = np.linalg.norm(w_, ord=2) ** 2\n",
    "    return 1 / 2 * quartic - quadratic + 1 / 2 * 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "864f0e80-ccdd-42f3-bafd-bc1812c9bf86",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Zero-filled (len(etas), 2, 2) array and a flat copy of the measurement matrix.\n",
    "n_etas = len(etas)\n",
    "w = np.zeros((n_etas, 2, 2))\n",
    "Aa = A.flatten()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0514ba97-be04-43d9-ab18-edd730aed43d",
   "metadata": {},
   "outputs": [],
   "source": [
    "fac = 1\n",
    "\n",
    "# Norm powers of the measurement row; none of them depends on eta.\n",
    "a_l1 = la.norm(Aa, ord=1)\n",
    "a_l2 = la.norm(Aa, ord=2)\n",
    "a_l3_cubed = la.norm(Aa, ord=3) ** 3\n",
    "a_l4_fourth = la.norm(Aa, ord=4) ** 4\n",
    "det = a_l2**2 * a_l4_fourth - la.norm(Aa, ord=3) ** 6\n",
    "\n",
    "gens_points = []\n",
    "gens_true = []\n",
    "l1_true = []\n",
    "for i, eta_i in enumerate(etas):\n",
    "    # Measured: risk of the final iterate of the run with this learning rate.\n",
    "    w_1 = us_s[i + 1][0][-1]\n",
    "    generalization_w_1 = generalization(w_1)\n",
    "    gens_points.append(generalization_w_1)\n",
    "\n",
    "    # Predicted: closed-form weights for the same learning rate.\n",
    "    lam2 = (\n",
    "        (y * a_l3_cubed + a_l2**4 - a_l1 * a_l3_cubed - 1 / (2 * eta_i) * a_l2**2)\n",
    "        / (2 * eta_i * det)\n",
    "    )[0]\n",
    "    nu2 = (\n",
    "        -(\n",
    "            y * a_l4_fourth\n",
    "            + a_l3_cubed * a_l2**2\n",
    "            - a_l1 * a_l4_fourth\n",
    "            - a_l3_cubed / (2 * eta_i)\n",
    "        )\n",
    "        / det\n",
    "    )\n",
    "    w0 = fac * np.sqrt(1 - 2 * lam2 * eta_i * Aa[0] ** 2 - nu2 * Aa[0])\n",
    "    w1 = fac * np.sqrt(1 - 2 * lam2 * eta_i * Aa[1] ** 2 - nu2 * Aa[1])\n",
    "    ww = np.array([w0, w1]).flatten()\n",
    "    gens_true.append(generalization(ww))\n",
    "    l1_true.append(np.linalg.norm(ww, ord=1))\n",
    "\n",
    "gens_true = np.array(gens_true)\n",
    "l1_true = np.array(l1_true)\n",
    "gens_points = np.array(gens_points)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c075cf3-28eb-4595-9a24-5342e7fcf79a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the grid starts at eta = 0, where the formulas divide by zero and\n",
    "# the first entry evaluates to nan; `[0:]` is a no-op slice -- confirm whether\n",
    "# `[1:]` was intended.\n",
    "etas_dense = np.linspace(0, 1 / np.sqrt(2), 500)[0:]\n",
    "\n",
    "# Norm powers of the measurement row; none of them depends on eta.\n",
    "a_l1 = la.norm(Aa, ord=1)\n",
    "a_l2 = la.norm(Aa, ord=2)\n",
    "a_l3_cubed = la.norm(Aa, ord=3) ** 3\n",
    "a_l4_fourth = la.norm(Aa, ord=4) ** 4\n",
    "det = a_l2**2 * a_l4_fourth - la.norm(Aa, ord=3) ** 6\n",
    "\n",
    "# Closed-form weights for every learning rate on the dense grid.\n",
    "ws = []\n",
    "for eta_i in etas_dense:\n",
    "    lam2 = (\n",
    "        (y * a_l3_cubed + a_l2**4 - a_l1 * a_l3_cubed - 1 / (2 * eta_i) * a_l2**2)\n",
    "        / (2 * eta_i * det)\n",
    "    )[0]\n",
    "    nu2 = (\n",
    "        -(\n",
    "            y * a_l4_fourth\n",
    "            + a_l3_cubed * a_l2**2\n",
    "            - a_l1 * a_l4_fourth\n",
    "            - a_l3_cubed / (2 * eta_i)\n",
    "        )\n",
    "        / det\n",
    "    )\n",
    "    w0 = fac * np.sqrt(1 - 2 * lam2 * eta_i * Aa[0] ** 2 - nu2 * Aa[0])\n",
    "    w1 = fac * np.sqrt(1 - 2 * lam2 * eta_i * Aa[1] ** 2 - nu2 * Aa[1])\n",
    "    ws.append(np.array([w0, w1]).flatten())\n",
    "\n",
    "# Risk and l1 norm of the predicted weights along the grid.\n",
    "gens_dense = np.array([generalization(ww) for ww in ws])\n",
    "l1_dense = np.array([np.linalg.norm(ww, ord=1) for ww in ws])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1bafda3-5eb7-43d1-bcd0-501612b383df",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(7, 4))\n",
    "mask2 = np.array(shaps <= 0)\n",
    "t1 = np.linspace(\n",
    "    np.max((0, np.min(etas[~mask2]) - 0.01 * (1 / np.sqrt(2) - np.min(etas[~mask2])))),\n",
    "    1 / np.sqrt(2),\n",
    "    500,\n",
    ")\n",
    "if t1[0] == 0:\n",
    "    t1 = t1[1:]\n",
    "plt.plot(t1, 2 / t1, color=colors[\"bound\"], zorder=0, linewidth=linew)  # , alpha=0.2)\n",
    "plt.scatter(\n",
    "    etas[~mask2],\n",
    "    2 / etas[~mask2],\n",
    "    color=colors[\"bound\"],\n",
    "    marker=\"x\",\n",
    "    zorder=3,\n",
    "    s=sizes[\"bound\"],\n",
    ")\n",
    "plt.plot(\n",
    "    [],\n",
    "    ls=\"-\",\n",
    "    marker=\"x\",\n",
    "    color=colors[\"bound\"],\n",
    "    label=\"theoretical prediction\",\n",
    "    zorder=3,\n",
    "    ms=8,\n",
    "    linewidth=linew,\n",
    ")\n",
    "plt.axhline(\n",
    "    flow_shap_last,\n",
    "    label=\"final GF value\",\n",
    "    color=colors[\"converged_last\"],\n",
    "    alpha=0.5,\n",
    "    zorder=2,\n",
    "    linewidth=linew,\n",
    "    linestyle=\":\",\n",
    ")\n",
    "plt.axhline(\n",
    "    flow_shap,\n",
    "    label=\"GF at max sharpness\",\n",
    "    color=colors[\"flow\"],\n",
    "    zorder=1,\n",
    "    linewidth=linew,\n",
    ")\n",
    "plt.axvline(\n",
    "    2 / flow_shap,\n",
    "    label=\"$2/s_{GF}$\",\n",
    "    color=colors[\"flow\"],\n",
    "    zorder=1,\n",
    "    linestyle=\"--\",\n",
    "    linewidth=linew,\n",
    ")\n",
    "mask3 = gens_dense > 0\n",
    "min_eta_ix = np.argmin(gens_dense[mask3])\n",
    "plt.axvline(\n",
    "    etas_dense[mask3][min_eta_ix],\n",
    "    color=colors[\"goal\"],\n",
    "    linestyle=\"--\",\n",
    "    zorder=1,\n",
    "    label=\"best generalization\",\n",
    "    linewidth=linew,\n",
    ")\n",
    "plt.axvline(\n",
    "    0.333510148849797,\n",
    "    label=\"empirical divergence\",\n",
    "    color=\"violet\",\n",
    "    zorder=1,\n",
    "    linestyle=\"--\",\n",
    "    linewidth=linew,\n",
    ")\n",
    "plt.axvline(\n",
    "    1 / np.sqrt(2),\n",
    "    label=\"theoretical divergence\",\n",
    "    color=\"pink\",\n",
    "    zorder=1,\n",
    "    linestyle=\"--\",\n",
    "    linewidth=linew,\n",
    ")\n",
    "plt.scatter(\n",
    "    etas[mask & ~mask2],\n",
    "    shaps[mask & ~mask2],\n",
    "    c=colors[\"converged\"],\n",
    "    label=glabel,\n",
    "    zorder=4,\n",
    "    s=sizes[\"converged\"],\n",
    ")\n",
    "plt.xlabel(r\"$\\eta$\")\n",
    "plt.ylabel(\"sharpness\")\n",
    "flow_shap_np = np.array([flow_shap])\n",
    "lims = np.concatenate((shaps[~mask2], flow_shap_np))\n",
    "plt.ylim(\n",
    "    np.min(lims) - 0.1 * (np.max(lims) - np.min(lims)),\n",
    "    np.max(lims) + 0.2 * (np.max(lims) - np.min(lims)),\n",
    ")\n",
    "plt.grid(True, linestyle=\"-\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(\n",
    "    \"diag_sharpness_no_legend.png\", dpi=300\n",
    ")  # Save the plot as a high-quality PNG file\n",
    "plt.legend(loc=\"upper right\")\n",
    "\n",
    "plt.savefig(\"diag_sharpness.png\", dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99c7a4e7-29bc-455f-a4ab-85c64fe49d12",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(7, 4))\n",
    "mask = tmp <= 0.0001\n",
    "\n",
    "mask2 = np.array(norms1 <= 0)\n",
    "plt.plot(\n",
    "    [],\n",
    "    ls=\"-\",\n",
    "    marker=\"x\",\n",
    "    color=colors[\"bound\"],\n",
    "    label=\"theoretical prediction\",\n",
    "    zorder=3,\n",
    "    ms=8,\n",
    "    linewidth=linew,\n",
    ")\n",
    "plt.axhline(\n",
    "    flow_norm,\n",
    "    label=\"final GF value\",\n",
    "    color=colors[\"converged_last\"],\n",
    "    alpha=0.5,\n",
    "    zorder=2,\n",
    "    linewidth=linew,\n",
    "    linestyle=\":\",\n",
    ")\n",
    "\n",
    "plt.axvline(\n",
    "    0.333510148849797,\n",
    "    label=\"empirical divergence\",\n",
    "    color=\"violet\",\n",
    "    zorder=1,\n",
    "    linestyle=\"--\",\n",
    "    linewidth=linew,\n",
    ")\n",
    "plt.axvline(\n",
    "    1 / np.sqrt(2),\n",
    "    label=\"theoretical divergence\",\n",
    "    color=\"pink\",\n",
    "    zorder=1,\n",
    "    linestyle=\"--\",\n",
    "    linewidth=linew,\n",
    ")\n",
    "plt.axvline(\n",
    "    2 / flow_shap,\n",
    "    label=\"$2/s_{GF}$\",\n",
    "    color=colors[\"flow\"],\n",
    "    zorder=0,\n",
    "    linestyle=\"--\",\n",
    "    linewidth=linew,\n",
    ")\n",
    "plt.scatter(\n",
    "    etas[mask & ~mask2],\n",
    "    norms1[mask & ~mask2],\n",
    "    c=colors[\"converged\"],\n",
    "    label=glabel,\n",
    "    zorder=3,\n",
    "    s=sizes[\"converged\"],\n",
    ")\n",
    "plt.xlabel(r\"$\\eta$\")\n",
    "plt.ylabel(r\"$\\ell 1$-norm\")\n",
    "plt.grid(True, linestyle=\"-\")\n",
    "\n",
    "plt.plot(etas_dense, l1_dense, c=colors[\"bound\"], zorder=1, linewidth=linew)\n",
    "plt.scatter(\n",
    "    etas[~mask2],\n",
    "    l1_true[~mask2],\n",
    "    color=colors[\"bound\"],\n",
    "    marker=\"x\",\n",
    "    zorder=2,\n",
    "    s=sizes[\"bound\"],\n",
    ")\n",
    "\n",
    "\n",
    "mask3 = gens_dense > 0\n",
    "min_eta_ix = np.argmin(gens_dense[mask3])\n",
    "plt.axvline(\n",
    "    etas_dense[mask3][min_eta_ix],\n",
    "    color=colors[\"goal\"],\n",
    "    linestyle=\"--\",\n",
    "    zorder=1,\n",
    "    label=\"best generalization\",\n",
    "    linewidth=linew,\n",
    ")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"diag_norm_l1_no_legend.png\", dpi=300)\n",
    "plt.legend(loc=\"upper right\")\n",
    "plt.savefig(\"diag_norm_l1.png\", dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba66a0fa-75cc-4bc8-bbf3-bcd072c068ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(7, 4))\n",
    "w_I1 = np.array([0, 1 / np.sqrt(3)])\n",
    "generalization_w_I1 = generalization(w_I1)\n",
    "generalization_flow = generalization(us_s[0][0][-1])\n",
    "plt.xlabel(r\"$\\eta$\")\n",
    "plt.ylabel(\"risk\")\n",
    "plt.plot(\n",
    "    [],\n",
    "    ls=\"-\",\n",
    "    marker=\"x\",\n",
    "    color=colors[\"bound\"],\n",
    "    label=\"theoretical prediction\",\n",
    "    zorder=3,\n",
    "    ms=8,\n",
    "    linewidth=linew,\n",
    ")\n",
    "plt.scatter(\n",
    "    etas,\n",
    "    gens_points,\n",
    "    c=colors[\"converged\"],\n",
    "    zorder=4,\n",
    "    s=sizes[\"converged\"],\n",
    "    label=\"final value\",\n",
    ")\n",
    "plt.axvline(\n",
    "    0.333510148849797,\n",
    "    label=\"empirical divergence\",\n",
    "    color=\"violet\",\n",
    "    zorder=1,\n",
    "    linestyle=\"--\",\n",
    "    linewidth=linew,\n",
    ")\n",
    "plt.axvline(\n",
    "    1 / np.sqrt(2),\n",
    "    label=\"theoretical divergence\",\n",
    "    color=\"pink\",\n",
    "    zorder=1,\n",
    "    linestyle=\"--\",\n",
    "    linewidth=linew,\n",
    ")\n",
    "plt.axvline(\n",
    "    2 / flow_shap,\n",
    "    label=\"$2/s_{GF}$\",\n",
    "    color=colors[\"flow\"],\n",
    "    zorder=2,\n",
    "    linestyle=\"--\",\n",
    "    linewidth=linew,\n",
    ")\n",
    "\n",
    "plt.plot(etas_dense, gens_dense, c=colors[\"bound\"], zorder=0, linewidth=linew)\n",
    "plt.scatter(\n",
    "    etas[~mask2],\n",
    "    gens_true[~mask2],\n",
    "    color=colors[\"bound\"],\n",
    "    marker=\"x\",\n",
    "    zorder=3,\n",
    "    s=sizes[\"bound\"],\n",
    ")\n",
    "\n",
    "\n",
    "plt.axhline(\n",
    "    generalization_w_I1,\n",
    "    label=\"final GF value\",\n",
    "    color=colors[\"converged_last\"],\n",
    "    alpha=0.5,\n",
    "    zorder=2,\n",
    "    linewidth=linew,\n",
    "    linestyle=\":\",\n",
    ")\n",
    "\n",
    "mask3 = gens_dense > 0\n",
    "min_eta_ix = np.argmin(gens_dense[mask3])\n",
    "plt.axvline(\n",
    "    etas_dense[mask3][min_eta_ix],\n",
    "    color=colors[\"goal\"],\n",
    "    linestyle=\"--\",\n",
    "    zorder=1,\n",
    "    label=\"best generalization\",\n",
    "    linewidth=linew,\n",
    ")\n",
    "\n",
    "plt.grid(True)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"diag_generalization_no_legend.png\", dpi=300)\n",
    "plt.legend(loc=\"upper right\")\n",
    "plt.savefig(f\"diag_generalization.png\", dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b890d863-b64b-43f0-864f-71a15213f74b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Solution set computation\n",
    "def _solution_curve(target, coef_free, coef_dep, power, ts, free_first):\n",
    "    \"\"\"Points on ``coef_free * t**power + coef_dep * s**power == target``.\n",
    "\n",
    "    ``ts`` holds the values of the free coordinate ``t``; ``s`` is solved for.\n",
    "    For even ``power`` only the ``t`` with a real solution are kept and the\n",
    "    mirrored branch ``-s`` is appended in reverse order of ``t``.\n",
    "    ``free_first`` selects the point layout ``(t, s)`` versus ``(s, t)``.\n",
    "    \"\"\"\n",
    "\n",
    "    def scaled_residual(t):\n",
    "        return (target - coef_free * t**power) / coef_dep\n",
    "\n",
    "    def as_point(t, s):\n",
    "        return (t, s) if free_first else (s, t)\n",
    "\n",
    "    even = power % 2 == 0\n",
    "    if even:\n",
    "        ts = [t for t in ts if scaled_residual(t) >= 0]\n",
    "    roots = [scaled_residual(t) ** (1 / power) for t in ts]\n",
    "\n",
    "    curve = [as_point(t, s) for t, s in zip(ts, roots)]\n",
    "    if even:\n",
    "        curve += [as_point(t, -s) for t, s in zip(reversed(ts), reversed(roots))]\n",
    "    return curve\n",
    "\n",
    "\n",
    "goal = []\n",
    "for i in range(y_dim):\n",
    "    t1 = np.arange(-2.8, 2.8, 0.001)\n",
    "    lab = \"Coordinate goal\"\n",
    "    if A[i, 1] != 0:\n",
    "        # Sweep the first coordinate and solve for the second.\n",
    "        goal += _solution_curve(y[i], A[i, 0], A[i, 1], L, t1, free_first=True)\n",
    "    elif A[i, 0] != 0:\n",
    "        # Second coefficient is zero: sweep the second coordinate instead.\n",
    "        goal += _solution_curve(y[i], A[i, 1], A[i, 0], L, t1, free_first=False)\n",
    "\n",
    "# Repeat the first point at the end to close the curve. Skipped when no row of A\n",
    "# contributed a point, where goal[0] would raise IndexError.\n",
    "if goal:\n",
    "    goal.append(goal[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4a91e73-b254-4320-9424-87f94830d54e",
   "metadata": {},
   "source": [
    "Warning: The following cell can take a LONG time (depending mainly on the point density ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e6e812e-a03b-4139-b3a8-2ef4c2ab6770",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sharpness map computation\n",
    "lrx = -1.7\n",
    "rrx = 1.8\n",
    "lry = -0.7\n",
    "rry = 0.8\n",
    "ds = 0.01  # 005\n",
    "xx = np.arange(lrx, rrx + 0.01, ds)\n",
    "yy = np.arange(lry, rry + 0.01, ds)\n",
    "\n",
    "# Build the 2x2 matrix diag(2 * r * a) + 4 * outer(a * w, a * w), with\n",
    "# r = a . w**2 - y, for every grid point at once instead of looping over the\n",
    "# grid in Python. The exponent 2 is hard-coded here.\n",
    "# Rows of shapmap follow yy, columns follow xx.\n",
    "grid_x, grid_y = np.meshgrid(xx, yy)\n",
    "points = np.stack([grid_x, grid_y], axis=-1)  # (len(yy), len(xx), 2)\n",
    "residual = points**2 @ Aa - y\n",
    "diag_part = (2 * residual)[..., None] * Aa\n",
    "scaled = Aa * points\n",
    "hessians = diag_part[..., None] * np.eye(2) + 4 * scaled[..., :, None] * scaled[..., None, :]\n",
    "\n",
    "# Spectral norm of each 2x2 matrix.\n",
    "shapmap = np.linalg.norm(hessians, ord=2, axis=(-2, -1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58b27486-3a18-4f2a-ae47-38c49656af0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each learning rate: 0 where the sharpness lies within `tolerance` of\n",
    "# 2 / eta, 1 everywhere else.\n",
    "masked_shapmaps = []\n",
    "tolerance = 0.8\n",
    "for eta in etas:\n",
    "    threshold_value = 2 / eta\n",
    "    lower_bound = threshold_value - tolerance\n",
    "    upper_bound = threshold_value + tolerance\n",
    "    # Deliberately not called `mask`: that name holds the per-learning-rate\n",
    "    # convergence mask the plotting cells index with, and overwriting it here\n",
    "    # breaks those cells when they are run again.\n",
    "    band = (shapmap >= lower_bound) & (shapmap <= upper_bound)\n",
    "    shapm2 = np.where(band, 0.0, 1.0)\n",
    "\n",
    "    masked_shapmaps.append(shapm2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89e2f04a-f5e8-4587-83e2-c17d2c514cd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quantise the sharpness map per learning rate, evaluated on the whole array at\n",
    "# once instead of pixel by pixel.\n",
    "shapquant = []\n",
    "for k, eta_ in enumerate(etas):\n",
    "    # NOTE(review): both the <= 1/eta and the <= 2/eta region map to level 2,\n",
    "    # so only two levels (2 and 3) appear. If the 1/eta region was meant to be\n",
    "    # a level of its own, split the condition here.\n",
    "    at_most_threshold = (shapmap <= 1 / etas[k]) | (shapmap <= 2 / etas[k])\n",
    "    shapq = np.where(at_most_threshold, 2.0, 3.0)\n",
    "    shapquant.append(shapq)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05009c12-33d4-43dc-b214-83ab60613f43",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sharpness bound ellipse computation\n",
    "def _bound_curve(level, coef_free, coef_dep, power, ts, free_first):\n",
    "    \"\"\"Points on ``coef_free * t**power + coef_dep * s**power == level``.\n",
    "\n",
    "    ``ts`` holds the values of the free coordinate ``t``; ``s`` is solved for.\n",
    "    For even ``power`` only the ``t`` with a real solution are kept and the\n",
    "    mirrored branch ``-s`` is appended in reverse order of ``t``.\n",
    "    ``free_first`` selects the point layout ``(t, s)`` versus ``(s, t)``.\n",
    "    \"\"\"\n",
    "\n",
    "    def scaled_residual(t):\n",
    "        return (level - coef_free * t**power) / coef_dep\n",
    "\n",
    "    def as_point(t, s):\n",
    "        return (t, s) if free_first else (s, t)\n",
    "\n",
    "    even = power % 2 == 0\n",
    "    if even:\n",
    "        ts = [t for t in ts if scaled_residual(t) >= 0]\n",
    "    roots = [scaled_residual(t) ** (1 / power) for t in ts]\n",
    "\n",
    "    curve = [as_point(t, s) for t, s in zip(ts, roots)]\n",
    "    if even:\n",
    "        curve += [as_point(t, -s) for t, s in zip(reversed(ts), reversed(roots))]\n",
    "    return curve\n",
    "\n",
    "\n",
    "# One closed curve per learning rate, at the level 2 / eta.\n",
    "bounds = []\n",
    "for eta in etas:\n",
    "    bound = []\n",
    "    for i in range(y_dim):\n",
    "        t1 = np.arange(-2.5, 2.5, 0.01)\n",
    "        if A[i, 1] != 0:\n",
    "            bound += _bound_curve(\n",
    "                (2 / eta), A[i, 0] ** 2, A[i, 1] ** 2, L, t1, free_first=True\n",
    "            )\n",
    "        elif A[i, 0] != 0:\n",
    "            bound += _bound_curve(\n",
    "                (2 / eta), A[i, 1] ** 2, A[i, 0] ** 2, L, t1, free_first=False\n",
    "            )\n",
    "    bounds.append(bound)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0976033-c275-46e3-8f8a-c3e46d17978d",
   "metadata": {},
   "outputs": [],
   "source": [
    "w = np.zeros((len(etas), 2, 2))\n",
    "Aa = A.flatten()\n",
    "\n",
    "nu1 = (la.norm(Aa, ord=1) - y) / la.norm(Aa, ord=2) ** 2\n",
    "for i in range(len(etas)):\n",
    "    w[i][0][0] = np.sqrt((1 - nu1 * A[0, 0])[0])\n",
    "    w[i][0][1] = np.sqrt((1 - nu1 * A[0, 1])[0])\n",
    "    if 2 * etas[i] * ((w[i][0][0] * Aa[0]) ** 2 + (w[i][0][1] * Aa[1]) ** 2) - 1 > 0:\n",
    "        w[i][0][0] = np.nan\n",
    "        w[i][0][1] = np.nan\n",
    "        print(\"not satisfied for \", etas[i])\n",
    "\n",
    "fac = 1\n",
    "for i in range(len(etas)):\n",
    "    det = la.norm(Aa, ord=2) ** 2 * la.norm(Aa, ord=4) ** 4 - la.norm(Aa, ord=3) ** 6\n",
    "    lam2 = (\n",
    "        (\n",
    "            y * la.norm(Aa, ord=3) ** 3\n",
    "            + la.norm(Aa, ord=2) ** 4\n",
    "            - la.norm(Aa, ord=1) * la.norm(Aa, ord=3) ** 3\n",
    "            - 1 / (2 * etas[i]) * la.norm(Aa, ord=2) ** 2\n",
    "        )\n",
    "        / (2 * etas[i] * det)\n",
    "    )[0]\n",
    "    nu2 = (\n",
    "        -(\n",
    "            y * la.norm(Aa, ord=4) ** 4\n",
    "            + la.norm(Aa, ord=3) ** 3 * la.norm(Aa, ord=2) ** 2\n",
    "            - la.norm(Aa, ord=1) * la.norm(Aa, ord=4) ** 4\n",
    "            - la.norm(Aa, ord=3) ** 3 / (2 * etas[i])\n",
    "        )\n",
    "        / det\n",
    "    )\n",
    "    w[i][1][0] = fac * np.sqrt(1 - 2 * lam2 * etas[i] * Aa[0] ** 2 - nu2 * Aa[0])\n",
    "    w[i][1][1] = fac * np.sqrt(1 - 2 * lam2 * etas[i] * Aa[1] ** 2 - nu2 * Aa[1])\n",
    "    print(\n",
    "        \"First eq 0 = \",\n",
    "        la.norm(Aa, ord=1)\n",
    "        - 2 * etas[i] * lam2 * la.norm(Aa, ord=3) ** 3\n",
    "        - nu2 * la.norm(Aa, ord=2) ** 2\n",
    "        - y,\n",
    "    )\n",
    "    print(\n",
    "        \"Second eq  0 = \",\n",
    "        la.norm(Aa, ord=2) ** 2\n",
    "        - 2 * etas[i] * lam2 * la.norm(Aa, ord=4) ** 4\n",
    "        - nu2 * la.norm(Aa, ord=3) ** 3\n",
    "        - 1 / (2 * etas[i]),\n",
    "    )\n",
    "\n",
    "    if not (np.isnan(w[i][1][0]) or np.isnan(w[i][1][1])):\n",
    "        print(\"\\n Checking for eta=\", etas[i])\n",
    "        print(w[i])\n",
    "        print(\"lam=\", lam2)\n",
    "        print(\n",
    "            \"Equality constraint 0 = \",\n",
    "            w[i][1][0] ** 2 * Aa[0] + w[i][1][1] ** 2 * Aa[1] - y,\n",
    "        )\n",
    "        print(\n",
    "            \"Inquality constraint 0 >= \",\n",
    "            2 * etas[i] * ((w[i][1][0] * Aa[0]) ** 2 + (w[i][1][1] * Aa[1]) ** 2) - 1,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77c0cced-3b03-428f-bee0-bea15c1e752c",
   "metadata": {},
   "outputs": [],
   "source": [
    "theta = np.linspace(0, 2 * np.pi, 400)\n",
    "k = 0\n",
    "for j in range(len(etas))[1:]:  # [3,7,-1]:\n",
    "    plt.figure(figsize=(7, 4))\n",
    "    plt.imshow(\n",
    "        shapmap,\n",
    "        cmap=\"Blues\",\n",
    "        interpolation=\"spline16\",\n",
    "        origin=\"lower\",\n",
    "        extent=(\n",
    "            lrx - 0.5 / len(xx),\n",
    "            rrx - 0.5 / len(xx),\n",
    "            lry - 0.5 / len(yy),\n",
    "            rry - 0.5 / len(yy),\n",
    "        ),\n",
    "    )\n",
    "    plt.imshow(\n",
    "        masked_shapmaps[j],\n",
    "        cmap=\"tab20c\",\n",
    "        interpolation=\"nearest\",\n",
    "        origin=\"lower\",\n",
    "        extent=(\n",
    "            lrx - 0.5 / len(xx),\n",
    "            rrx - 0.5 / len(xx),\n",
    "            lry - 0.5 / len(yy),\n",
    "            rry - 0.5 / len(yy),\n",
    "        ),\n",
    "        alpha=0.4,\n",
    "    )\n",
    "    plt.axhline(0, color=\"gray\", alpha=0.5)\n",
    "    plt.axvline(0, color=\"gray\", alpha=0.5)\n",
    "    w1_E1 = np.sqrt(y / Aa[0]) * np.cos(theta)\n",
    "    w2_E1 = np.sqrt(y / Aa[1]) * np.sin(theta)\n",
    "\n",
    "    w1_E2 = (1 / (np.sqrt(2 * etas[j]) * Aa[0])) * np.cos(theta)\n",
    "    w2_E2 = (1 / (np.sqrt(2 * etas[j]) * Aa[1])) * np.sin(theta)\n",
    "\n",
    "    (p1,) = plt.plot(\n",
    "        [g[0] for g in goal],\n",
    "        [g[1] for g in goal],\n",
    "        color=\"yellow\",\n",
    "        label=r\"solution manifold $\\mathcal{M}$\",\n",
    "        linewidth=linew,\n",
    "    )\n",
    "    plt.plot(\n",
    "        [u[0] for u in us_s[0][k]],\n",
    "        [u[1] for u in us_s[0][k]],\n",
    "        color=colors[\"flow\"],\n",
    "        alpha=0.8,\n",
    "        linewidth=4,\n",
    "        linestyle=(0, (1, 1)),\n",
    "    )\n",
    "    (p2,) = plt.plot(\n",
    "        [],\n",
    "        color=colors[\"flow\"],\n",
    "        label=\"GF\",\n",
    "        alpha=0.8,\n",
    "        linewidth=linew,\n",
    "        linestyle=(0, (1, 1)),\n",
    "    )\n",
    "    (p3,) = plt.plot(\n",
    "        [u[0] for u in us_s[j + 1][k]],\n",
    "        [u[1] for u in us_s[j + 1][k]],\n",
    "        color=colors[\"converged\"],\n",
    "        label=\"GD\",\n",
    "        alpha=0.8,\n",
    "        linewidth=linew,\n",
    "    )\n",
    "    p4 = plt.scatter(\n",
    "        [0, 0],\n",
    "        [1 / np.sqrt(3), -1 / np.sqrt(3)],\n",
    "        color=\"black\",\n",
    "        label=r\"$\\ell 1$-min. $\\mathcal{M}_{\\ell 1}$\",\n",
    "        marker=\"x\",\n",
    "        linewidth=1.5,\n",
    "        s=100,\n",
    "        zorder=3,\n",
    "    )\n",
    "    p5 = plt.scatter(\n",
    "        [np.sqrt(2), -np.sqrt(2)],\n",
    "        [0, 0],\n",
    "        label=r\"sharpness min. $\\mathcal{M}_{S_{\\mathcal{L}}}$\",\n",
    "        marker=\"o\",\n",
    "        facecolors=\"none\",\n",
    "        edgecolors=\"black\",\n",
    "        linewidth=1.5,\n",
    "        s=100,\n",
    "        zorder=4,\n",
    "    )\n",
    "    p6 = plt.scatter(\n",
    "        [-w[j][0][0], -w[j][0][0], w[j][0][0], w[j][0][0]],\n",
    "        [-w[j][0][1], w[j][0][1], -w[j][0][1], w[j][0][1]],\n",
    "        color=\"blueviolet\",\n",
    "        marker=\"x\",\n",
    "        zorder=5,\n",
    "        linewidth=1.5,\n",
    "        s=100,\n",
    "    )\n",
    "    p7 = plt.scatter(\n",
    "        [-w[j][1][0], -w[j][1][0], w[j][1][0], w[j][1][0]],\n",
    "        [-w[j][1][1], w[j][1][1], -w[j][1][1], w[j][1][1]],\n",
    "        facecolors=\"none\",\n",
    "        edgecolors=\"blueviolet\",\n",
    "        marker=\"o\",\n",
    "        zorder=5,\n",
    "        linewidth=1.5,\n",
    "        s=100,\n",
    "        label=\"other KKT points\",\n",
    "    )\n",
    "    (p8,) = plt.plot(\n",
    "        [], color=\"#85B5D9\", label=r\"sharpness bound $2/\\eta$\", linewidth=linew\n",
    "    )\n",
    "    plt.tight_layout()\n",
    "    plt.xlim(lrx, rrx)\n",
    "    plt.ylim(lry, rry)\n",
    "\n",
    "    plt.savefig(\"diag_iterates_\" + str(j) + \"_both_no_legend.png\", dpi=300)\n",
    "    plt.legend(\n",
    "        [p1, p2, p3, p4, p5, (p6, p7), p8],\n",
    "        [\n",
    "            r\"solution manifold $\\mathcal{M}$\",\n",
    "            \"GF\",\n",
    "            \"GD\",\n",
    "            r\"$\\ell 1$-min. $\\mathcal{M}_{\\ell 1}$\",\n",
    "            r\"sharpness min. $\\mathcal{M}_{S_{\\mathcal{L}}}$\",\n",
    "            \"other KKT points\",\n",
    "            r\"sharpness bound $2/\\eta$\",\n",
    "        ],\n",
    "        handler_map={tuple: HandlerTuple(ndivide=None)},\n",
    "        loc=\"upper right\",\n",
    "        bbox_to_anchor=(0.47, 0.967),\n",
    "    )  \n",
    "\n",
    "    plt.savefig(\"diag_iterates_\" + str(j) + \"_both.png\", dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d9124fb-b55d-4fe7-84d5-44f38c1a5f6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "k = 0\n",
    "for j in [0]:  # [3,7,-1]:\n",
    "    plt.figure(figsize=(7, 4))\n",
    "    plt.imshow(\n",
    "        shapmap,\n",
    "        cmap=\"Blues\",\n",
    "        interpolation=\"spline16\",\n",
    "        origin=\"lower\",\n",
    "        extent=(\n",
    "            lrx - 0.5 / len(xx),\n",
    "            rrx - 0.5 / len(xx),\n",
    "            lry - 0.5 / len(yy),\n",
    "            rry - 0.5 / len(yy),\n",
    "        ),\n",
    "    )\n",
    "    plt.axhline(0, color=\"gray\", alpha=0.5)\n",
    "    plt.axvline(0, color=\"gray\", alpha=0.5)\n",
    "    w1_E1 = np.sqrt(y / Aa[0]) * np.cos(theta)\n",
    "    w2_E1 = np.sqrt(y / Aa[1]) * np.sin(theta)\n",
    "\n",
    "    w1_E2 = (1 / (np.sqrt(2 * etas[j]) * Aa[0])) * np.cos(theta)\n",
    "    w2_E2 = (1 / (np.sqrt(2 * etas[j]) * Aa[1])) * np.sin(theta)\n",
    "\n",
    "    (p1,) = plt.plot(\n",
    "        [g[0] for g in goal],\n",
    "        [g[1] for g in goal],\n",
    "        color=\"yellow\",\n",
    "        label=r\"solution manifold $\\mathcal{M}$\",\n",
    "        linewidth=linew,\n",
    "    )\n",
    "    plt.plot(\n",
    "        [u[0] for u in us_s[0][k]],\n",
    "        [u[1] for u in us_s[0][k]],\n",
    "        color=colors[\"flow\"],\n",
    "        alpha=0.8,\n",
    "        linewidth=4,\n",
    "        linestyle=(0, (1, 1)),\n",
    "    )\n",
    "    (p2,) = plt.plot(\n",
    "        [],\n",
    "        color=colors[\"flow\"],\n",
    "        label=\"GF\",\n",
    "        alpha=0.8,\n",
    "        linewidth=linew,\n",
    "        linestyle=(0, (1, 1)),\n",
    "    )\n",
    "    (p3,) = plt.plot(\n",
    "        [u[0] for u in us_s[j + 1][k]],\n",
    "        [u[1] for u in us_s[j + 1][k]],\n",
    "        color=colors[\"converged\"],\n",
    "        label=\"GD\",\n",
    "        alpha=0.8,\n",
    "        linewidth=linew,\n",
    "    )\n",
    "    p4 = plt.scatter(\n",
    "        [0, 0],\n",
    "        [1 / np.sqrt(3), -1 / np.sqrt(3)],\n",
    "        color=\"black\",\n",
    "        label=r\"$\\ell 1$-min. $\\mathcal{M}_{\\ell 1}$\",\n",
    "        marker=\"x\",\n",
    "        linewidth=1.5,\n",
    "        s=100,\n",
    "        zorder=3,\n",
    "    )\n",
    "    p5 = plt.scatter(\n",
    "        [np.sqrt(2), -np.sqrt(2)],\n",
    "        [0, 0],\n",
    "        label=r\"sharpness min. $\\mathcal{M}_{S_{\\mathcal{L}}}$\",\n",
    "        marker=\"o\",\n",
    "        facecolors=\"none\",\n",
    "        edgecolors=\"black\",\n",
    "        linewidth=1.5,\n",
    "        s=100,\n",
    "        zorder=4,\n",
    "    )\n",
    "    p6 = plt.scatter(\n",
    "        [-w[j][0][0], -w[j][0][0], w[j][0][0], w[j][0][0]],\n",
    "        [-w[j][0][1], w[j][0][1], -w[j][0][1], w[j][0][1]],\n",
    "        color=\"blueviolet\",\n",
    "        marker=\"x\",\n",
    "        zorder=5,\n",
    "        linewidth=1.5,\n",
    "        s=100,\n",
    "    )\n",
    "    p7 = plt.scatter(\n",
    "        [-w[j][1][0], -w[j][1][0], w[j][1][0], w[j][1][0]],\n",
    "        [-w[j][1][1], w[j][1][1], -w[j][1][1], w[j][1][1]],\n",
    "        facecolors=\"none\",\n",
    "        edgecolors=\"blueviolet\",\n",
    "        marker=\"o\",\n",
    "        zorder=5,\n",
    "        linewidth=1.5,\n",
    "        s=100,\n",
    "        label=\"other KKT points\",\n",
    "    )\n",
    "\n",
    "    (p8,) = plt.plot(\n",
    "        [], color=\"#85B5D9\", label=r\"sharpness bound $2/\\eta$\", linewidth=linew\n",
    "    )\n",
    "    # ax.plot([u[0] for u in us_s[2][k]], [u[1] for u in us_s[2][k]], color=\"red\", label=\"{}\".format(etas[2]), alpha = 0.4)\n",
    "    plt.tight_layout()\n",
    "    plt.xlim(lrx, rrx)\n",
    "    plt.ylim(lry, rry)\n",
    "\n",
    "    plt.savefig(\"diag_iterates_\" + str(j) + \"_both_no_legend.png\", dpi=300)\n",
    "    plt.legend(\n",
    "        [p1, p2, p3, p4, p5, (p6, p7), p8],\n",
    "        [\n",
    "            r\"solution manifold $\\mathcal{M}$\",\n",
    "            \"GF\",\n",
    "            \"GD\",\n",
    "            r\"$\\ell 1$-min. $\\mathcal{M}_{\\ell 1}$\",\n",
    "            r\"sharpness min. $\\mathcal{M}_{S_{\\mathcal{L}}}$\",\n",
    "            \"other KKT points\",\n",
    "            r\"sharpness bound $2/\\eta$\",\n",
    "        ],\n",
    "        handler_map={tuple: HandlerTuple(ndivide=None)},\n",
    "        loc=\"upper right\",\n",
    "        bbox_to_anchor=(0.47, 0.967),\n",
    "    )  # plt.legend(loc=\"lower left\")\n",
    "\n",
    "    plt.savefig(\"diag_iterates_v2_\" + str(j) + \"_both.png\", dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db8e4d79-32f6-49fb-8970-8046d61aea53",
   "metadata": {},
   "source": [
    "## Figure 4 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e38858a4-4638-45cc-bf9b-97eb6f281b55",
   "metadata": {},
   "outputs": [],
   "source": [
    "cols = [colors[\"converged\"] for eta in etas]\n",
    "cols[1] = \"darkorchid\"\n",
    "cols[4] = \"deeppink\"\n",
    "cols[7] = \"indianred\"\n",
    "cols[10] = \"sandybrown\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5f76bce-ff1c-4918-a868-636ea9562fd8",
   "metadata": {},
   "outputs": [],
   "source": [
    "k = 0\n",
    "for j in [1, 4, 7, 10]:  # [3,7,-1]:\n",
    "    plt.figure(figsize=(7, 4))\n",
    "    plt.imshow(\n",
    "        shapmap,\n",
    "        cmap=\"Blues\",\n",
    "        interpolation=\"spline16\",\n",
    "        origin=\"lower\",\n",
    "        extent=(\n",
    "            lrx - 0.5 / len(xx),\n",
    "            rrx - 0.5 / len(xx),\n",
    "            lry - 0.5 / len(yy),\n",
    "            rry - 0.5 / len(yy),\n",
    "        ),\n",
    "    )\n",
    "    plt.imshow(\n",
    "        masked_shapmaps[j],\n",
    "        cmap=\"tab20c\",\n",
    "        interpolation=\"nearest\",\n",
    "        origin=\"lower\",\n",
    "        extent=(\n",
    "            lrx - 0.5 / len(xx),\n",
    "            rrx - 0.5 / len(xx),\n",
    "            lry - 0.5 / len(yy),\n",
    "            rry - 0.5 / len(yy),\n",
    "        ),\n",
    "        alpha=0.4,\n",
    "    )\n",
    "    plt.axhline(0, color=\"gray\", alpha=0.5)\n",
    "    plt.axvline(0, color=\"gray\", alpha=0.5)\n",
    "    w1_E1 = np.sqrt(y / Aa[0]) * np.cos(theta)\n",
    "    w2_E1 = np.sqrt(y / Aa[1]) * np.sin(theta)\n",
    "\n",
    "    w1_E2 = (1 / (np.sqrt(2 * etas[j]) * Aa[0])) * np.cos(theta)\n",
    "    w2_E2 = (1 / (np.sqrt(2 * etas[j]) * Aa[1])) * np.sin(theta)\n",
    "\n",
    "    (p1,) = plt.plot(\n",
    "        [g[0] for g in goal],\n",
    "        [g[1] for g in goal],\n",
    "        color=\"yellow\",\n",
    "        label=r\"solution manifold $\\mathcal{M}$\",\n",
    "        linewidth=linew,\n",
    "    )\n",
    "    plt.plot(\n",
    "        [u[0] for u in us_s[0][k]],\n",
    "        [u[1] for u in us_s[0][k]],\n",
    "        color=colors[\"flow\"],\n",
    "        alpha=0.8,\n",
    "        linewidth=4,\n",
    "        linestyle=(0, (1, 1)),\n",
    "    )\n",
    "    (p2,) = plt.plot(\n",
    "        [],\n",
    "        color=colors[\"flow\"],\n",
    "        label=\"GF\",\n",
    "        alpha=0.8,\n",
    "        linewidth=linew,\n",
    "        linestyle=(0, (1, 1)),\n",
    "    )\n",
    "    (p3,) = plt.plot(\n",
    "        [u[0] for u in us_s[j + 1][k]],\n",
    "        [u[1] for u in us_s[j + 1][k]],\n",
    "        color=cols[j],\n",
    "        label=\"GD\",\n",
    "        alpha=0.8,\n",
    "        linewidth=linew,\n",
    "    )\n",
    "    p4 = plt.scatter(\n",
    "        [0, 0],\n",
    "        [1 / np.sqrt(3), -1 / np.sqrt(3)],\n",
    "        color=\"black\",\n",
    "        label=r\"$\\ell 1$-min. $\\mathcal{M}_{\\ell 1}$\",\n",
    "        marker=\"x\",\n",
    "        linewidth=1.5,\n",
    "        s=100,\n",
    "        zorder=3,\n",
    "    )\n",
    "    p5 = plt.scatter(\n",
    "        [np.sqrt(2), -np.sqrt(2)],\n",
    "        [0, 0],\n",
    "        label=r\"sharpness min. $\\mathcal{M}_{S_{\\mathcal{L}}}$\",\n",
    "        marker=\"o\",\n",
    "        facecolors=\"none\",\n",
    "        edgecolors=\"black\",\n",
    "        linewidth=1.5,\n",
    "        s=100,\n",
    "        zorder=4,\n",
    "    )\n",
    "    p6 = plt.scatter(\n",
    "        [-w[0][0][0], -w[0][0][0], w[0][0][0], w[0][0][0]],\n",
    "        [-w[0][0][1], w[0][0][1], -w[0][0][1], w[0][0][1]],\n",
    "        color=\"black\",\n",
    "        marker=\"*\",\n",
    "        zorder=5,\n",
    "        linewidth=1.5,\n",
    "        s=100,\n",
    "    )\n",
    "    (p8,) = plt.plot(\n",
    "        [], color=\"#85B5D9\", label=r\"sharpness bound $2/\\eta$\", linewidth=linew\n",
    "    )\n",
    "    (p9,) = plt.plot(\n",
    "        [], ls=\"-\", color=cols[1], label=\"GD\", zorder=3, ms=8, linewidth=linew\n",
    "    )\n",
    "    (p10,) = plt.plot(\n",
    "        [], ls=\"-\", color=cols[4], label=\"GD\", zorder=3, ms=8, linewidth=linew\n",
    "    )\n",
    "    (p11,) = plt.plot(\n",
    "        [], ls=\"-\", color=cols[7], label=\"GD\", zorder=3, ms=8, linewidth=linew\n",
    "    )\n",
    "    (p12,) = plt.plot(\n",
    "        [], ls=\"-\", color=cols[10], label=\"GD\", zorder=3, ms=8, linewidth=linew\n",
    "    )\n",
    "    plt.tight_layout()\n",
    "    plt.xlim(lrx, rrx)\n",
    "    plt.ylim(lry, rry)\n",
    "\n",
    "    plt.savefig(\"mainfig_diag_iterates_\" + str(j) + \"_both_no_legend.png\", dpi=300)\n",
    "    plt.legend(\n",
    "        [p1, p2, (p9, p10, p11, p12), p4, p5, p6, p8],\n",
    "        [\n",
    "            r\"solution manifold $\\mathcal{M}$\",\n",
    "            \"GF\",\n",
    "            \"GD\",\n",
    "            r\"$\\ell 1$-min. $\\mathcal{M}_{\\ell 1}$\",\n",
    "            r\"sharpness min. $\\mathcal{M}_{S_{\\mathcal{L}}}$\",\n",
    "            r\"best generalization $\\mathcal{M}_G$\",\n",
    "            r\"sharpness bound $2/\\eta$\",\n",
    "        ],\n",
    "        handler_map={tuple: HandlerTuple(ndivide=None)},\n",
    "        loc=\"upper right\",\n",
    "        bbox_to_anchor=(0.5, 0.99),\n",
    "    )  # plt.legend(loc=\"lower left\")\n",
    "\n",
    "    plt.savefig(\"mainfig_diag_iterates_\" + str(j) + \"_both.png\", dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e7b3474-0fae-45c3-b5fb-6ad1e63022b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(7, 4))\n",
    "w_I1 = np.array([0, 1 / np.sqrt(3)])\n",
    "generalization_w_I1 = generalization(w_I1)\n",
    "generalization_flow = generalization(us_s[0][0][-1])\n",
    "generalization_opt = generalization([w[0][0][0], w[0][0][1]])\n",
    "generalization_shap = generalization([np.sqrt(2), 0])\n",
    "plt.xlabel(r\"$\\eta$\")\n",
    "plt.ylabel(\"risk\")\n",
    "plt.plot(\n",
    "    [],\n",
    "    ls=\"-\",\n",
    "    marker=\"x\",\n",
    "    color=colors[\"bound\"],\n",
    "    label=\"theoretical prediction\",\n",
    "    zorder=3,\n",
    "    ms=8,\n",
    "    linewidth=linew,\n",
    ")\n",
    "plt.axhline(\n",
    "    generalization_flow,\n",
    "    label=r\"risk at $\\mathcal{M}_{\\ell 1}$\",\n",
    "    color=colors[\"flow\"],\n",
    "    alpha=0.7,\n",
    "    zorder=0,\n",
    "    linewidth=linew,\n",
    "    linestyle=\":\",\n",
    ")\n",
    "plt.axhline(\n",
    "    generalization_opt,\n",
    "    label=r\"risk at $\\mathcal{M}_{G}$\",\n",
    "    color=colors[\"goal\"],\n",
    "    alpha=0.7,\n",
    "    zorder=0,\n",
    "    linewidth=linew,\n",
    "    linestyle=\":\",\n",
    ")\n",
    "plt.axhline(\n",
    "    generalization_shap,\n",
    "    label=r\"risk at $\\mathcal{M}_{S_{\\mathcal{L}}}$\",\n",
    "    color=\"violet\",\n",
    "    alpha=0.7,\n",
    "    zorder=0,\n",
    "    linewidth=linew,\n",
    "    linestyle=\":\",\n",
    ")\n",
    "plt.scatter(\n",
    "    etas,\n",
    "    gens_points,\n",
    "    c=colors[\"converged\"],\n",
    "    zorder=4,\n",
    "    s=sizes[\"converged\"],\n",
    "    label=\"final value\",\n",
    ")\n",
    "plt.axvline(\n",
    "    2 / flow_shap,\n",
    "    label=\"$2/s_{GF}$\",\n",
    "    color=colors[\"flow\"],\n",
    "    zorder=2,\n",
    "    linestyle=\"--\",\n",
    "    linewidth=linew,\n",
    ")\n",
    "for j in [1, 4, 7, 10]:\n",
    "    plt.scatter([etas[j]], [gens_points[j]], c=cols[j], zorder=5, s=sizes[\"converged\"])\n",
    "\n",
    "plt.plot(\n",
    "    etas_dense[:400], gens_dense[:400], c=colors[\"bound\"], zorder=0, linewidth=linew\n",
    ")\n",
    "plt.scatter(\n",
    "    etas[~mask2],\n",
    "    gens_true[~mask2],\n",
    "    color=colors[\"bound\"],\n",
    "    marker=\"x\",\n",
    "    zorder=3,\n",
    "    s=sizes[\"bound\"],\n",
    ")\n",
    "\n",
    "mask3 = gens_dense > 0\n",
    "min_eta_ix = np.argmin(gens_dense[mask3])\n",
    "\n",
    "plt.grid(True)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"mainfig_diag_generalization_no_legend.png\", dpi=300)\n",
    "plt.legend(loc=\"upper right\")\n",
    "plt.savefig(f\"mainfig_diag_generalization.png\", dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28959763-d59d-40e1-bf58-ac728de60a24",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(7, 4))\n",
    "w_I1 = np.array([0, 1 / np.sqrt(3)])\n",
    "generalization_w_I1 = generalization(w_I1)\n",
    "generalization_flow = generalization(us_s[0][0][-1])\n",
    "generalization_opt = generalization([w[0][0][0], w[0][0][1]])\n",
    "generalization_shap = generalization([np.sqrt(2), 0])\n",
    "plt.xlabel(r\"$\\eta$\")\n",
    "plt.ylabel(\"risk\")\n",
    "plt.plot(\n",
    "    [],\n",
    "    ls=\"-\",\n",
    "    marker=\"x\",\n",
    "    color=colors[\"bound\"],\n",
    "    label=\"theoretical prediction\",\n",
    "    zorder=3,\n",
    "    ms=8,\n",
    "    linewidth=linew,\n",
    ")\n",
    "plt.axhline(\n",
    "    generalization_flow,\n",
    "    label=r\"risk at $\\mathcal{M}_{\\ell 1}$\",\n",
    "    color=colors[\"flow\"],\n",
    "    alpha=0.7,\n",
    "    zorder=0,\n",
    "    linewidth=linew,\n",
    "    linestyle=\":\",\n",
    ")\n",
    "plt.axhline(\n",
    "    generalization_opt,\n",
    "    label=r\"risk at $\\mathcal{M}_{G}$\",\n",
    "    color=colors[\"goal\"],\n",
    "    alpha=0.7,\n",
    "    zorder=0,\n",
    "    linewidth=linew,\n",
    "    linestyle=\":\",\n",
    ")\n",
    "plt.axhline(\n",
    "    generalization_shap,\n",
    "    label=r\"risk at $\\mathcal{M}_{S_{\\mathcal{L}}}$\",\n",
    "    color=\"violet\",\n",
    "    alpha=0.7,\n",
    "    zorder=0,\n",
    "    linewidth=linew,\n",
    "    linestyle=\":\",\n",
    ")\n",
    "plt.scatter(\n",
    "    etas,\n",
    "    gens_points,\n",
    "    c=colors[\"converged\"],\n",
    "    zorder=4,\n",
    "    s=sizes[\"converged\"],\n",
    "    label=\"final value\",\n",
    ")\n",
    "\n",
    "plt.axvline(\n",
    "    2 / flow_shap,\n",
    "    label=\"$2/s_{GF}$\",\n",
    "    color=colors[\"flow\"],\n",
    "    zorder=2,\n",
    "    linestyle=\"--\",\n",
    "    linewidth=linew,\n",
    ")\n",
    "for j in [1, 4, 7, 10]:\n",
    "    plt.scatter([etas[j]], [gens_points[j]], c=cols[j], zorder=5, s=sizes[\"converged\"])\n",
    "\n",
    "\n",
    "plt.plot(\n",
    "    etas_dense[:400], gens_dense[:400], c=colors[\"bound\"], zorder=0, linewidth=linew\n",
    ")\n",
    "plt.scatter(\n",
    "    etas[~mask2],\n",
    "    gens_true[~mask2],\n",
    "    color=colors[\"bound\"],\n",
    "    marker=\"x\",\n",
    "    zorder=3,\n",
    "    s=sizes[\"bound\"],\n",
    ")\n",
    "\n",
    "\n",
    "plt.axvspan(2 / flow_shap, etas_dense[420], color=\"#FF0A68\", alpha=0.09, zorder=-1)\n",
    "plt.text(0.26, 1.3, \"EoS regime\", fontsize=24, ha=\"center\", color=\"#FF8593\")\n",
    "\n",
    "mask3 = gens_dense > 0\n",
    "min_eta_ix = np.argmin(gens_dense[mask3])\n",
    "\n",
    "plt.xlim(0.09, 0.59)\n",
    "\n",
    "plt.grid(True)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"mainfig_diag_generalization_with_eos_no_legend.png\", dpi=300)\n",
    "plt.legend(loc=\"upper right\")\n",
    "plt.savefig(f\"mainfig_diag_generalization_with_eos.png\", dpi=300)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
