{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f578b3e9-ee1b-4442-8a9f-69e064cc11ad",
   "metadata": {},
   "source": [
    "# Diagonal Model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8eb2ce6-229c-405c-955a-b3ae59c42594",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import numpy.linalg as la\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from jax import grad, hessian\n",
    "from jax import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "842b9593-b527-4a09-b28d-86902b34cd2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "# Alternative for interactive plots: \n",
    "#%matplotlib widget"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48b04ca7-4ab6-4343-9c1b-849934a4ef87",
   "metadata": {},
   "outputs": [],
   "source": [
    "def f(x,N):\n",
    "    return np.power(x,N)\n",
    "\n",
    "def grad_f(x,N):\n",
    "    return N*np.power(x,N-1)\n",
    "\n",
    "def Hess_f(x,N):\n",
    "    if N == 1:\n",
    "        return np.zeros([len(x),len(x)])\n",
    "    return N*(N-1)*np.diag(np.power(x,N-2))\n",
    "\n",
    "def H1(x,N,ATA,ATy):\n",
    "    return Hess_f(x,N)@np.diag((ATA @ f(x,N)  - ATy))\n",
    "\n",
    "def H2(x,N,ATA):\n",
    "    return np.diag(grad_f(x,N))@ATA@np.diag(grad_f(x,N))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88308f1c-4fd3-4de3-8d91-669a7f47216d",
   "metadata": {},
   "outputs": [],
   "source": [
    "### MODEL ###\n",
    "\n",
    "class DiagonalNetwork:\n",
    "    \n",
    "    def __init__(self, A, y, L, u0, custom_loss=None):\n",
    "        self.L = L\n",
    "        self.A = A\n",
    "        self.y = y\n",
    "        self.alpha = alpha\n",
    "        \n",
    "        self.u = np.array(u0)      \n",
    "\n",
    "        # JIT-compiled \n",
    "        self.jit_loss = jax.jit(self.loss)\n",
    "        self.jit_update = self.update\n",
    "        if(custom_loss is not None):\n",
    "            self.loss_ = custom_loss\n",
    "            \n",
    "    def set_u(self,u):\n",
    "        self.u=u\n",
    "\n",
    "    def get_u(self):\n",
    "        return self.u\n",
    "        \n",
    "    def loss(self, u):\n",
    "        # (change the loss in this line to modify the model)\n",
    "        return 1/(2*self.L)* jnp.linalg.norm((self.A@ jnp.power(u,self.L) - self.y))**2  \n",
    "\n",
    "    def custom_update(self, eta=1e-3):\n",
    "        grad_fn = grad(self.loss_)\n",
    "        \n",
    "        grads = grad_fn(self.u, self.A,self.L,self.y)\n",
    "    \n",
    "        # Perform gradient descent update\n",
    "        self.u = self.u - eta * grads\n",
    "\n",
    "        hess_fn =hessian(self.loss_) \n",
    "        \n",
    "        hess = hess_fn(self.u,self.A,self.L,self.y)\n",
    "        return la.norm(hess,ord=2)\n",
    "\n",
    "    def shaps(self, u):\n",
    "        hess_fn =hessian(self.loss) \n",
    "        hess = hess_fn(u)\n",
    "        return la.norm(hess,ord=2)\n",
    "    \n",
    "    \n",
    "    def update(self, eta=1e-3):\n",
    "        grad_fn = grad(self.loss)\n",
    "        \n",
    "        grads = grad_fn(self.u)\n",
    "    \n",
    "        # Perform gradient descent update\n",
    "        self.u = self.u - eta * grads\n",
    "\n",
    "        hess_fn =hessian(self.loss) \n",
    "        \n",
    "        hess = hess_fn(self.u)\n",
    "      \n",
    "        return la.norm(hess,ord=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43487c63-3fe0-47b2-a3d1-991434f7180c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, eta, steps=100, ls=[],s_H1=[],s_H2=[],s_all=[],us=[]):\n",
    "    for iter in range(steps):\n",
    "        ev = model.jit_update(eta=eta)\n",
    "        s_all.append(ev)\n",
    "        us.append(model.get_u())\n",
    "       \n",
    "        s_H1.append(la.norm(H1(model.u,L,ATA,ATy),2)) \n",
    "        s_H2.append(la.norm(H2(model.u,L,ATA),2))\n",
    "        loss = model.jit_loss(model.u)    \n",
    "        ls.append(loss)\n",
    "        if (iter + 1) % 100 == 0:\n",
    "            print(f'Epoch [{iter + 1}/{max_iter}], Loss: {loss:.6f}')\n",
    "        if loss < 0.0000001:\n",
    "            print(\"Loss goal satisfied with loss \" + str(loss))\n",
    "            return model, ls, s_H1, s_H2, s_all, us\n",
    "    \n",
    "        \n",
    "    return model, ls, s_H1, s_H2, s_all, us"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66f41b58-9653-4a0a-a372-3af0be2a2b0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_custom(model, eta, steps=100, ls=[],s_all=[]):\n",
    "    for iter in range(steps):\n",
    "        ev = model.custom_update(eta=eta)\n",
    "        s_all.append(ev)\n",
    "      \n",
    "        loss = model.loss_(model.u, model.A,model.L,model.y)    \n",
    "        ls.append(loss)\n",
    "    \n",
    "        if (iter + 1) % 100 == 0:\n",
    "            print(f'Epoch [{iter + 1}/{max_iter}], Loss: {loss:.6f}')\n",
    "    return model, ls,  s_all"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07313e3a-b3ca-4780-89a6-2997fa720a12",
   "metadata": {},
   "outputs": [],
   "source": [
    "### PARAMS ###\n",
    "seed = 4\n",
    "keys = random.split(random.PRNGKey(seed),2)\n",
    "L=2\n",
    "\n",
    "x_dim = 2\n",
    "\n",
    "y_dim = 1\n",
    "print('(x,y) Dimension = ({},{})'.format(x_dim,y_dim))\n",
    "\n",
    "A = np.array([[0.5,3]])\n",
    "\n",
    "y = np.array([3]) \n",
    "alpha = [0.01,0.001]\n",
    "\n",
    "max_iter = 1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f860360c-d688-4010-a4f0-177e1d241980",
   "metadata": {},
   "outputs": [],
   "source": [
    "ATA = A.T@A\n",
    "ATy = A.T@y"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da083461-d1a4-46c4-8066-6bb0a64cdb5a",
   "metadata": {},
   "source": [
    "## Series of Experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "928dc9f2-5308-4a80-a7bd-b866467dd883",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### PARAMS #### \n",
    "etas =np.array([0.001,0.08,0.091,0.102,0.113,0.124,0.135,0.146,0.157,0.167,0.178,0.189,0.2])#np.array([0.001, 0.09355, 0.0998, 0.10605, 0.1123, 0.11855, 0.1248, 0.1305, 0.1373, 0.14355, 0.1498, 0.15605, 0.1623])#np.linspace(.1,1,10) #np.linspace(0.6,1.1,16)# np.logspace(-1,1,3) #0.01,0.5,0.7\n",
    "alphas = [[0.01,0.01]]\n",
    "max_iter =10000 #500"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b104142a-327e-4208-8d68-a15c70e95f6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "run_experiments = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e703aa89-34b6-4e8b-9f1e-efe61799e7cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "if(run_experiments):\n",
    "    ls_s = []\n",
    "    s_H1_s = []\n",
    "    s_H2_s = []\n",
    "    s_all_s = []\n",
    "    us_s = []\n",
    "    for eta_ in etas:\n",
    "        ls_s_ = []\n",
    "        s_H1_s_ = []\n",
    "        s_H2_s_ = []\n",
    "        s_all_s_ = []\n",
    "        us_s_ = []\n",
    "        for alpha_ in alphas:\n",
    "            print(f'Alpha {alpha_}, Eta: {eta_}')\n",
    "    \n",
    "            model = DiagonalNetwork(A, y, L, alpha_)\n",
    "            model, ls, s_H1, s_H2, s_all, us = train(model, eta_, steps=max_iter,ls=[],s_H1=[],s_H2=[],s_all=[],us=[model.get_u()])\n",
    "            ls_s_.append(ls)\n",
    "            s_H1_s_.append(s_H1)\n",
    "            s_H2_s_.append(s_H2)\n",
    "            s_all_s_.append(s_all)\n",
    "            us_s_.append(us)\n",
    "        ls_s.append(ls_s_)\n",
    "        s_H1_s.append(s_H1_s_)\n",
    "        s_H2_s.append(s_H2_s_)\n",
    "        s_all_s.append(s_all_s_)\n",
    "        us_s.append(us_s_)\n",
    "  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2028aa3-3c61-4514-8d41-ac822fd5fff6",
   "metadata": {},
   "outputs": [],
   "source": [
    "etas = etas[1:]\n",
    "max_idx = np.argmax([s for s in s_all_s[0][0]]).item()\n",
    "flow_shap = s_all_s[0][0][max_idx]\n",
    "flow_shap_last = s_all_s[0][0][-1]\n",
    "shaps = np.array([s_all_s[i+1][0][-1] for i in range(len(s_all_s)-1)])\n",
    "maxshaps = np.array([np.max(s_all_s[i+1][0]) for i in range(len(s_all_s)-1)])\n",
    "flow_norm = np.linalg.norm(us_s[0][0][-1],ord=1)\n",
    "norms1 = np.array([np.linalg.norm(us_s[i+1][0][-1],ord=1) for i in range(len(us_s)-1)])\n",
    "flow_loss = ls_s[0][0][-1]\n",
    "losses = np.array([ls_s[i+1][0][-1] for i in range(len(ls_s)-1)])\n",
    "dists = np.array([np.linalg.norm(us_s[i+1][0][-1] - us_s[0][0][-1],ord=1) for i in range(len(us_s)-1)])\n",
    "iterations = np.array([len(us_s[i+1][0]) for i in range(len(us_s)-1)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d8ef672-fa9a-433b-8832-f829061622cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "colors = {\n",
    "    \"converged\": \"#00CC03\", \n",
    "    \"converged_last\": \"darkgreen\", \n",
    "    \"maximum\": \"black\",\n",
    "    \"bound\": \"#0081D1\",  \n",
    "    \"flow\": \"#FF7F0F\",  \n",
    "    \"heuristic\": \"#CC00F5\", \n",
    "    \"goal\": \"#70C8FF\", \n",
    "    \"a\": \"#F00A02\",\n",
    "    \"b\": \"#0C7EC3\",\n",
    "    \"c\": \"#1ABA1E\",\n",
    "    \"d\": \"#085E09\",\n",
    "}\n",
    "sizes = {\n",
    "    \"converged\": 100,\n",
    "    \"maximum\": 133,\n",
    "    \"maxwidth\": 1.5,\n",
    "    \"bound\": 160\n",
    "}\n",
    "linew = 3\n",
    "tmp=np.abs(losses)\n",
    "mask = tmp<=0.0001\n",
    "plt.rcParams['font.size'] = 14\n",
    "plt.rcParams['mathtext.fontset'] = 'stix'\n",
    "plt.rcParams['font.family'] = 'STIXGeneral'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c373bcb2-ba0a-4372-8457-65f378912445",
   "metadata": {},
   "outputs": [],
   "source": [
    "general_captions = True\n",
    "gf_mode = \"max\"\n",
    "if general_captions:\n",
    "    glabel = r\"final value\"\n",
    "    mlabel = \"value at max sharpness\"\n",
    "    if gf_mode == \"max\":\n",
    "        flabel = r\"GF value at max sharpness\"\n",
    "        flabel_last = r\"final GF sharpness\"\n",
    "\n",
    "    else:\n",
    "        flabel = r\"final GF value\"\n",
    "else:\n",
    "    glabel = r\"final sharpness\"\n",
    "    mlabel = \"maximum value\"\n",
    "    if gf_mode == \"max\":\n",
    "        flabel = r\"max GF sharpness ($s_{GF}$)\"\n",
    "        flabel_last = r\"final GF sharpness\"\n",
    "\n",
    "    else:\n",
    "        flabel = r\"final GF sharpness\"\n",
    "        flabel_last = r\"final GF sharpness\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce4df596-4657-4624-898b-b97209024d88",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(shaps)\n",
    "print(losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1bafda3-5eb7-43d1-bcd0-501612b383df",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(7,4))\n",
    "mask2 = np.array(shaps<=0)\n",
    "t1 = np.linspace(np.max((0,np.min(etas[~mask2]) - 0.03 * (np.max(etas[~mask2]) - np.min(etas[~mask2])))), np.max(etas[~mask2]) + 0.03 * (np.max(etas[~mask2]) - np.min(etas[~mask2])), 100)\n",
    "if(t1[0]==0):\n",
    "    t1=t1[1:]\n",
    "plt.plot(t1, 2/t1, color=colors[\"bound\"],zorder=0,linewidth=linew)#, alpha=0.2)\n",
    "plt.scatter(etas[~mask2], 2/etas[~mask2],color=colors[\"bound\"],marker=\"x\",zorder=3,s=sizes[\"bound\"])\n",
    "plt.plot([],ls=\"-\", marker=\"x\", color=colors[\"bound\"], label = r\"$2/\\eta$\",zorder=3,ms=8, linewidth=linew)\n",
    "plt.axhline(flow_shap_last, label = flabel_last,color=colors[\"converged_last\"],alpha=.5, zorder=2,linewidth=linew,linestyle=\":\")\n",
    "plt.axhline(flow_shap, label = flabel,color=colors[\"flow\"],zorder=1,linewidth=linew)\n",
    "plt.axvline(2/flow_shap, label = \"$2/s_{GF}$\", color=colors[\"flow\"],zorder=1,linestyle=\"--\",linewidth=linew)\n",
    "plt.scatter(etas[mask & ~mask2], shaps[mask & ~mask2], c=colors[\"converged\"], label=glabel, zorder=4,s=sizes[\"converged\"])\n",
    "plt.xlabel(r\"$\\eta$\")\n",
    "plt.ylabel(\"sharpness\")\n",
    "flow_shap_np = np.array([flow_shap])\n",
    "lims = np.concatenate((shaps[~mask2], flow_shap_np))\n",
    "plt.ylim(np.min(lims) - 0.1*(np.max(lims) - np.min(lims)), np.max(lims) + 0.2*(np.max(lims) - np.min(lims)))\n",
    "plt.grid(True, linestyle = '-')\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"diag_sharpness_no_legend.png\", dpi=300)  \n",
    "plt.legend(loc=\"upper right\")\n",
    "\n",
    "plt.savefig(\"diag_sharpness.png\", dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b52a6c5e-ab54-411f-8e75-4b3ee89a0e63",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(7,4))\n",
    "mask2 = np.array(norms1<=0)  \n",
    "plt.axhline(flow_norm, label = flabel,color=colors[\"flow\"],zorder=1, linewidth=linew)\n",
    "\n",
    "plt.axvline(2/flow_shap, label = \"$2/s_{GF}$\",color=colors[\"flow\"],zorder=2,linestyle=\"--\", linewidth=linew)\n",
    "plt.scatter(etas[mask & ~mask2], norms1[mask & ~mask2],c=colors[\"converged\"], label=glabel,zorder=3,s=sizes[\"converged\"])\n",
    "plt.xlabel(r\"$\\eta$\")\n",
    "plt.ylabel(r\"$\\ell 1$-norm\")\n",
    "plt.grid(True, linestyle = '-')\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"diag_norm_l1_no_legend.png\", dpi=300)\n",
    "plt.legend(loc=\"upper left\")\n",
    "plt.savefig(\"diag_norm_l1.png\", dpi=300)  # Save the plot as a high-quality PNG file\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c82a44b-b44e-4726-98e8-c5fe5b1e4376",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(7,4))\n",
    "mask2 = np.array(dists<0) \n",
    "plt.axvline(2/flow_shap, label = \"$2/s_{GF}$\",color=colors[\"flow\"],zorder=0,linestyle=\"--\", linewidth=linew)\n",
    "plt.scatter(etas[mask & ~mask2], dists[mask & ~mask2],c=colors[\"converged\"], label=glabel,zorder=3,s=sizes[\"converged\"])\n",
    "plt.xlabel(r\"$\\eta$\")\n",
    "plt.ylabel(\"distance from GF\")\n",
    "plt.grid(True, linestyle = '-')\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"diag_distance_GF_no_legend.png\", dpi=300)\n",
    "plt.legend(loc=\"upper left\")\n",
    "plt.savefig(\"diag_distance_GF.png\", dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14efbf42-1cdf-413b-b1f7-8d7d3ca12239",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(7,4))\n",
    "mask2 = np.array(iterations<0) \n",
    "plt.axvline(2/flow_shap, label = \"$2/s_{GF}$\",color=colors[\"flow\"],zorder=0,linestyle=\"--\", linewidth=linew)\n",
    "plt.scatter(etas[mask & ~mask2], iterations[mask & ~mask2],c=colors[\"converged\"], label=glabel,zorder=3,s=sizes[\"converged\"])\n",
    "plt.xlabel(r\"$\\eta$\")\n",
    "plt.ylabel(\"#iterations\")\n",
    "plt.grid(True, linestyle = '-')\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"diag_iterations_no_legend.png\", dpi=300)\n",
    "plt.legend(loc=\"upper right\")\n",
    "plt.savefig(\"diag_iterations.png\", dpi=300) \n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b890d863-b64b-43f0-864f-71a15213f74b",
   "metadata": {},
   "outputs": [],
   "source": [
    "goal = []\n",
    "for i in range(y_dim):\n",
    "    t1 = np.arange(-2.8, 2.8, 0.001)\n",
    "    lab ='Coordinate goal'\n",
    "    if A[i, 1] != 0:\n",
    "        if L % 2 == 0:\n",
    "            t1 = [t for t in t1 if (y[i] - A[i,0]*t**L)/A[i,1] >= 0]\n",
    "        goal = goal + list(zip(t1, [((y[i] - A[i,0]*t**L)/A[i,1])**(1/L) for t in t1]))\n",
    "        if L % 2 == 0:\n",
    "            goal = goal + list(reversed(list(zip(t1, [-((y[i] - A[i,0]*t**L)/A[i,1])**(1/L) for t in t1]))))\n",
    "    elif A[i, 0] != 0:\n",
    "        if L % 2 == 0:\n",
    "            t1 = [t for t in t1 if (y[i] - A[i,1]*t**L)/A[i,0] >= 0]\n",
    "        goal = goal + list(zip([((y[i] - A[i,1]*t**L)/A[i,0])**(1/L) for t in t1], t1))\n",
    "        if L % 2 == 0:\n",
    "            goal = goal + list(reversed(list(zip([-((y[i] - A[i,1]*t**L)/A[i,0])**(1/L) for t in t1], t1))))\n",
    "goal.append(goal[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c95efc5d-5d23-45b9-bd8e-a6f16cdca764",
   "metadata": {},
   "outputs": [],
   "source": [
    "lrx = -2.6\n",
    "rrx = 2.7\n",
    "lry = -1.5\n",
    "rry = 1.6\n",
    "ds = 0.01\n",
    "xx = np.arange(lrx, rrx+0.01, ds)\n",
    "yy = np.arange(lry, rry+0.01, ds)\n",
    "shapmap = np.zeros((len(yy),len(xx)))\n",
    "for i in range(len(xx)):\n",
    "    for j in range(len(yy)):\n",
    "        try:\n",
    "            shapmap[j][i] = model.shaps(np.array([xx[i],yy[j]]))\n",
    "        except la.LinAlgError:\n",
    "            shapmap[j][i] = 0\n",
    "        \n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58b27486-3a18-4f2a-ae47-38c49656af0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "masked_shapmaps = []\n",
    "for eta in etas:\n",
    "    threshold_value = 2/eta\n",
    "    tolerance = 0.8\n",
    "    lower_bound = threshold_value - tolerance\n",
    "    upper_bound = threshold_value + tolerance\n",
    "    mask = (shapmap >= lower_bound) & (shapmap <= upper_bound)\n",
    "    shapm2 = shapmap.copy()\n",
    "    shapm2[mask] = 0\n",
    "    shapm2[~mask] = 1\n",
    "    masked_shapmaps.append(shapm2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89e2f04a-f5e8-4587-83e2-c17d2c514cd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "shapquant = []\n",
    "for k, eta_ in enumerate(etas):\n",
    "    shapq = np.zeros((len(yy),len(xx)))\n",
    "    for i in range(len(xx)):\n",
    "        for j in range(len(yy)):\n",
    "            if shapmap[j][i] <= 1/etas[k]:\n",
    "                shapq[j][i] = 2\n",
    "            elif shapmap[j][i] <= 2/etas[k]:\n",
    "                shapq[j][i] = 2\n",
    "            else:\n",
    "                shapq[j][i] = 3\n",
    "    shapquant.append(shapq)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05009c12-33d4-43dc-b214-83ab60613f43",
   "metadata": {},
   "outputs": [],
   "source": [
    "bounds = []\n",
    "for eta in etas:\n",
    "    bound = []\n",
    "    for i in range(y_dim):\n",
    "        t1 = np.arange(-2.5, 2.5, 0.01)\n",
    "        if A[i, 1] != 0:\n",
    "            if L % 2 == 0:\n",
    "                t1 = [t for t in t1 if ((2/eta) - A[i,0]**2*t**L)/A[i,1]**2 >= 0]\n",
    "            bound = bound + list(zip(t1, [(((2/eta) - A[i,0]**2*t**L)/A[i,1]**2)**(1/L) for t in t1]))\n",
    "            if L % 2 == 0:\n",
    "                bound = bound + list(reversed(list(zip(t1, [-(((2/eta) - A[i,0]**2*t**L)/A[i,1]**2)**(1/L) for t in t1]))))\n",
    "        elif A[i, 0] != 0:\n",
    "            if L % 2 == 0:\n",
    "                t1 = [t for t in t1 if ((2/eta) - A[i,1]**2*t**L)/A[i,0]**2 >= 0]\n",
    "            bound = bound + list(zip([(((2/eta) - A[i,1]**2*t**L)/A[i,0]**2)**(1/L) for t in t1], t1))\n",
    "            if L % 2 == 0:\n",
    "                bound = bound + list(reversed(list(zip([-(((2/eta) - A[i,1]**2*t**L)/A[i,0]**2)**(1/L) for t in t1], t1))))\n",
    "    bounds.append(bound)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d8434ff-7fab-4abf-999c-757b8e6da2c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "k = 0\n",
    "for j in range(len(etas)):\n",
    "    plt.figure(figsize=(7,4))\n",
    "    plt.imshow(shapmap, cmap=\"Blues\", interpolation=\"spline16\", origin=\"lower\", extent=(lrx - 0.5/len(xx), rrx- 0.5/len(xx), lry- 0.5/len(yy), rry- 0.5/len(yy)))\n",
    "    plt.imshow(masked_shapmaps[j], cmap=\"tab20c\", interpolation=\"nearest\", origin=\"lower\", extent=(lrx - 0.5/len(xx), rrx- 0.5/len(xx), lry- 0.5/len(yy), rry- 0.5/len(yy)),alpha=0.4)\n",
    "    plt.axhline(0, color=\"gray\",alpha=0.5)\n",
    "    plt.axvline(0, color=\"gray\",alpha=0.5)\n",
    "    plt.plot([g[0] for g in goal], [g[1] for g in goal], color=\"yellow\", label=r\"solution manifold $\\mathcal{M}$\",linewidth=linew)\n",
    "    plt.plot([u[0] for u in us_s[0][k]], [u[1] for u in us_s[0][k]], color=colors[\"flow\"],alpha=0.8,linewidth=4,linestyle=(0, (1, 1)))\n",
    "    plt.plot([], color=colors[\"flow\"], label=\"GF\", alpha=0.8,linewidth=linew,linestyle=(0, (1, 1)))\n",
    "    plt.plot([u[0] for u in us_s[j+1][k]], [u[1] for u in us_s[j+1][k]], color=colors[\"converged\"], label=\"GD\", alpha=0.8,linewidth=linew)\n",
    "    plt.scatter([0,0],[1,-1], color=\"black\", label=r\"$\\ell 1$-min. $\\mathcal{M}_{\\ell 1}$\", marker=\"x\",linewidth=1.5,s=100,zorder=3)\n",
    "    plt.scatter([np.sqrt(6),-np.sqrt(6)],[0,0], label=r\"sharpness min. $\\mathcal{M}_{S_{\\mathcal{L}}}$\", marker=\"o\",facecolors='none', edgecolors=\"black\",linewidth=1.5,s=100,zorder=4)\n",
    "\n",
    "    plt.plot([], color=\"#85B5D9\",label = r\"sharpness bound $2/\\eta$\",linewidth=linew)\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(\"diag_iterates_\"+str(j)+\"_both_no_legend.png\", dpi=300)\n",
    "    plt.legend(loc=\"upper left\")\n",
    "\n",
    "    plt.savefig(\"diag_iterates_\"+str(j)+\"_both.png\", dpi=300)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jax",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
