{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7GVWcjh65iGv"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%config InlineBackend.figure_formats = ['svg']\n",
    "from scipy.linalg import sqrtm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "b8aBVIt_8qHf"
   },
   "outputs": [],
   "source": [
    "# Global figure styling shared by every plot in this notebook.\n",
    "import seaborn as sns\n",
    "\n",
    "plt.rcParams.update({'font.size': 14, \"figure.figsize\": (7, 5)})\n",
    "# Colour-blind-safe palette; the plotting cells index into it by position.\n",
    "colors = sns.color_palette(\"colorblind\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "wLEakSJI8qOy"
   },
   "outputs": [],
   "source": [
    "# Define custom dash patterns for each line without names\n",
    "custom_dashes = [\n",
    "    (1, 0),           # Solid\n",
    "    (1, 1),           # Dotted\n",
    "    (4, 2),           # Dashed\n",
    "    (3, 2, 1, 2),     # Dash-Dot\n",
    "    (5, 2, 1, 2, 1, 2),  # Dash-Dot-Dot\n",
    "    (8, 2),           # Long Dashes\n",
    "    (2, 4),           # Loosely Dashed\n",
    "    (5, 4, 1, 4),     # Sparse Dash-Dot\n",
    "    (6, 4, 1, 4, 1, 4),  # Sparse Dash-Dot-Dot\n",
    "    (1, 4),           # Loosely Dotted\n",
    "    (4, 2, 1, 2, 1, 2),  # Dashed with Dots\n",
    "    (2, 4, 1, 4, 2, 4)   # Custom Pattern 1\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9iedPQJn5jvq"
   },
   "outputs": [],
   "source": [
    "def cost(x):\n",
    "  \"\"\"Quadratic loss 0.5 * x^T H x, using the notebook-level matrix H.\"\"\"\n",
    "  return 0.5 * x.T @ H @ x\n",
    "\n",
    "\n",
    "def g(x):\n",
    "  \"\"\"Noise-free gradient H x (the gradient of `cost` for symmetric H).\"\"\"\n",
    "  return H @ x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "oVZOW6-UeAee"
   },
   "outputs": [],
   "source": [
    "def Adam_Traj(eta, x0, nit, sigma, beta_1, beta_2, eps, l, seed):\n",
    "  \"\"\"Simulate one AdamW trajectory on `cost` with additive Gaussian gradient noise.\n",
    "\n",
    "  The stochastic gradient at x is g(x) + noise, with noise ~ N(0, sigma^2 I).\n",
    "\n",
    "  Parameters\n",
    "  ----------\n",
    "  eta : float\n",
    "      Learning rate.\n",
    "  x0 : array_like of shape (d,)\n",
    "      Starting point; its length sets the problem dimension d.\n",
    "  nit : int\n",
    "      Number of stored iterates including x0 (i.e. nit - 1 updates).\n",
    "  sigma : float\n",
    "      Standard deviation of the additive gradient noise.\n",
    "  beta_1, beta_2 : float\n",
    "      EMA factors of the first- and second-moment estimates m and v.\n",
    "  eps : float\n",
    "      Offset added to sqrt(v) in the update denominator.\n",
    "  l : float\n",
    "      Decoupled weight-decay coefficient (the -eta*l*x[k] term).\n",
    "  seed : int\n",
    "      Seed of the noise stream.\n",
    "\n",
    "  Returns\n",
    "  -------\n",
    "  (x, f) : x has shape (nit, d) and holds the iterates; f has shape (nit, 1)\n",
    "      and holds cost(x[k]) for every iterate.\n",
    "  \"\"\"\n",
    "  # Take the dimension from x0 instead of the notebook-global `d`.\n",
    "  d = np.size(x0)\n",
    "  # Private generator: yields the same numbers as np.random.seed(seed) followed by\n",
    "  # np.random.normal, without overwriting NumPy's global random state.\n",
    "  rng = np.random.RandomState(seed)\n",
    "  noise = sigma*rng.normal(size=(d,))\n",
    "  # Deliberate re-seed (kept from the original): the first update below reuses\n",
    "  # the very noise sample that initialises m and v.\n",
    "  rng.seed(seed)\n",
    "\n",
    "  x = np.zeros((nit, d))\n",
    "  v = np.zeros((nit, d))\n",
    "  m = np.zeros((nit, d))\n",
    "  f = np.zeros((nit,))\n",
    "  x[0] = x0\n",
    "  f[0] = cost(x[0])\n",
    "  # NOTE(review): m and v start from a gradient evaluated at x0 + noise (so the\n",
    "  # noise is multiplied by H, unlike the additive noise in the loop) and are then\n",
    "  # bias-corrected by gamma_1/gamma_2 as if they had started at zero. Kept as is\n",
    "  # to preserve results -- confirm this initialisation is intended.\n",
    "  grad0 = g(x[0]+noise)\n",
    "  v[0] = grad0*grad0\n",
    "  m[0] = grad0\n",
    "  for k in range(nit-1):\n",
    "    # Adam bias-correction factors.\n",
    "    gamma_1 = 1/(1 - beta_1**(k+1))\n",
    "    gamma_2 = 1/(1 - beta_2**(k+1))\n",
    "\n",
    "    noise = sigma*rng.normal(size=(d,))\n",
    "    # Evaluate the noisy gradient once and reuse it for both moment updates.\n",
    "    grad = g(x[k]) + noise\n",
    "    m[k+1] = beta_1*m[k] + (1-beta_1)*grad\n",
    "    v[k+1] = beta_2*v[k] + (1-beta_2)*grad*grad\n",
    "    x[k+1] = x[k] - eta*(gamma_1*m[k+1])/(np.sqrt(gamma_2*v[k+1])+eps) - eta*l*x[k]\n",
    "    f[k+1] = cost(x[k+1])\n",
    "\n",
    "  return (x, f[:, None])\n",
    "\n",
    "\n",
    "def RMS_Traj(eta, x0, nit, sigma, beta, eps, l, seed):\n",
    "  \"\"\"Simulate one RMSpropW trajectory on `cost` with additive Gaussian gradient noise.\n",
    "\n",
    "  The stochastic gradient at x is g(x) + noise, with noise ~ N(0, sigma^2 I).\n",
    "\n",
    "  Parameters\n",
    "  ----------\n",
    "  eta : float\n",
    "      Learning rate.\n",
    "  x0 : array_like of shape (d,)\n",
    "      Starting point; its length sets the problem dimension d.\n",
    "  nit : int\n",
    "      Number of stored iterates including x0 (i.e. nit - 1 updates).\n",
    "  sigma : float\n",
    "      Standard deviation of the additive gradient noise.\n",
    "  beta : float\n",
    "      EMA factor of the squared-gradient estimate v.\n",
    "  eps : float\n",
    "      Offset added to sqrt(v) in the update denominator.\n",
    "  l : float\n",
    "      Decoupled weight-decay coefficient (the x[k]*(1 - l*eta) shrinkage).\n",
    "  seed : int\n",
    "      Seed of the noise stream.\n",
    "\n",
    "  Returns\n",
    "  -------\n",
    "  (x, f) : x has shape (nit, d) and holds the iterates; f has shape (nit, 1)\n",
    "      and holds cost(x[k]) for every iterate.\n",
    "  \"\"\"\n",
    "  # Take the dimension from x0 instead of the notebook-global `d`.\n",
    "  d = np.size(x0)\n",
    "  # Private generator: same numbers as np.random.seed(seed) + np.random.normal,\n",
    "  # without overwriting NumPy's global random state.\n",
    "  rng = np.random.RandomState(seed)\n",
    "\n",
    "  x = np.zeros((nit, d))\n",
    "  v = np.zeros((nit, d))\n",
    "  f = np.zeros((nit,))\n",
    "  x[0] = x0\n",
    "  f[0] = cost(x[0])\n",
    "  # NOTE(review): unlike Adam_Traj, v starts from the noise-free squared gradient.\n",
    "  grad0 = g(x[0])\n",
    "  v[0] = grad0*grad0\n",
    "  for k in range(nit-1):\n",
    "    noise = sigma*rng.normal(size=(d,))\n",
    "    # Evaluate the noisy gradient once and reuse it in both updates.\n",
    "    grad = g(x[k]) + noise\n",
    "    v[k+1] = beta*v[k] + (1-beta)*grad*grad\n",
    "    # NOTE(review): the step divides by sqrt of the pre-update v[k], whereas\n",
    "    # Adam_Traj uses the freshly updated v[k+1]; confirm which form is intended.\n",
    "    x[k+1] = x[k]*(1-l*eta) - eta*grad/(np.sqrt(v[k])+eps)\n",
    "    f[k+1] = cost(x[k+1])\n",
    "\n",
    "  return (x, f[:, None])\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ySznf5idOUf8"
   },
   "outputs": [],
   "source": [
    "def Adam_Trajs(eta, x0, nit, sigma, beta_1, beta_2, eps, l, seeds, verbose=False):\n",
    "  \"\"\"Run Adam_Traj once per seed and stack the trajectories.\n",
    "\n",
    "  All arguments except `seeds` and `verbose` are forwarded to Adam_Traj, which is\n",
    "  asked for nit + 1 stored iterates (x0 plus nit updates).\n",
    "\n",
    "  seeds : sized iterable of int\n",
    "      One trajectory is simulated per seed.\n",
    "  verbose : bool, default False\n",
    "      Print each seed before running it. Off by default because one printed\n",
    "      line per seed floods the notebook output.\n",
    "\n",
    "  Returns (x, f) with shapes (len(seeds), nit + 1, d) and (len(seeds), nit + 1, 1).\n",
    "  \"\"\"\n",
    "  d = np.shape(x0)[0]\n",
    "  n_runs = len(seeds)\n",
    "  x = np.zeros((n_runs, nit+1, d))\n",
    "  f = np.zeros((n_runs, nit+1, 1))\n",
    "\n",
    "  for i, seed in enumerate(seeds):\n",
    "    if verbose:\n",
    "      print(seed)\n",
    "    x[i], f[i] = Adam_Traj(eta, x0, nit+1, sigma, beta_1, beta_2, eps, l, seed)\n",
    "\n",
    "  return (x, f)\n",
    "\n",
    "\n",
    "def RMS_Trajs(eta, x0, nit, sigma, beta, eps, l, seeds, verbose=False):\n",
    "  \"\"\"Run RMS_Traj once per seed and stack the trajectories.\n",
    "\n",
    "  All arguments except `seeds` and `verbose` are forwarded to RMS_Traj, which is\n",
    "  asked for nit + 1 stored iterates (x0 plus nit updates).\n",
    "\n",
    "  seeds : sized iterable of int\n",
    "      One trajectory is simulated per seed.\n",
    "  verbose : bool, default False\n",
    "      Print each seed before running it. Off by default because one printed\n",
    "      line per seed floods the notebook output.\n",
    "\n",
    "  Returns (x, f) with shapes (len(seeds), nit + 1, d) and (len(seeds), nit + 1, 1).\n",
    "  \"\"\"\n",
    "  d = np.shape(x0)[0]\n",
    "  n_runs = len(seeds)\n",
    "  x = np.zeros((n_runs, nit+1, d))\n",
    "  f = np.zeros((n_runs, nit+1, 1))\n",
    "\n",
    "  for i, seed in enumerate(seeds):\n",
    "    if verbose:\n",
    "      print(seed)\n",
    "    x[i], f[i] = RMS_Traj(eta, x0, nit+1, sigma, beta, eps, l, seed)\n",
    "\n",
    "  return (x, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "K121PaKMDbr6"
   },
   "outputs": [],
   "source": [
    "d = 2\n",
    "landa1 = 1\n",
    "landa3 = 3\n",
    "H = np.array([[landa1, 0.0],[0.0, landa3]])\n",
    "eta = 0.001\n",
    "sigma = 1\n",
    "eps = 1e-7\n",
    "\n",
    "\n",
    "\n",
    "T=20\n",
    "nit= int(T/eta)\n",
    "\n",
    "mu1=100\n",
    "mu2=1\n",
    "mu=1\n",
    "\n",
    "beta=1-mu*eta\n",
    "beta_1 = 1 - mu1*eta\n",
    "beta_2 = 1 - mu2*eta\n",
    "\n",
    "\n",
    "\n",
    "l1 = 1\n",
    "l3 = 4\n",
    "\n",
    "kappa = 2\n",
    "\n",
    "np.random.seed(0)\n",
    "\n",
    "x0 = 0.1*np.random.normal(size=(d,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "4l7r9_Bk-0Di"
   },
   "outputs": [],
   "source": [
    "# Theoretical loss prediction ('Theor. Pred.' in the plots) for weight decay wd:\n",
    "# a sum over the diagonal entries lam of H of 0.25*eta*lam*sigma/(sigma*wd + lam),\n",
    "# broadcast to nit entries so it can be drawn next to the loss curves.\n",
    "def _limit_term(lam, wd):\n",
    "  return 0.25*eta*np.ones((nit,))*lam*sigma/(sigma*wd+lam)\n",
    "\n",
    "limit1 = _limit_term(landa1, l1) + _limit_term(landa3, l1)\n",
    "limit3 = _limit_term(landa1, l3) + _limit_term(landa3, l3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "PB0Pzkm0FuKV"
   },
   "outputs": [],
   "source": [
    "seeds = [i for i in np.arange(1000)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "xERLGH0VOp7C",
    "outputId": "db5a81f0-35a9-4488-deb9-2d89edad96ff",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Baseline AdamW runs (base eta, sigma and betas) for the two weight decays.\n",
    "Adam_Trajs_11 = Adam_Trajs(eta, x0, nit, sigma, beta_1, beta_2, eps, l1, seeds)\n",
    "Adam_Trajs_13 = Adam_Trajs(eta, x0, nit, sigma, beta_1, beta_2, eps, l3, seeds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Rescaled AdamW (R): eta -> kappa*eta, sigma -> sigma/kappa,\n",
    "# 1 - beta -> kappa**2 * (1 - beta) for both betas, and weight decay l -> kappa*l.\n",
    "Adam_Trajs_11_Rescaled = Adam_Trajs(\n",
    "    eta*kappa, x0, nit, sigma/kappa,\n",
    "    1-kappa**2*(1-beta_1), 1-kappa**2*(1-beta_2),\n",
    "    eps, l1*kappa, seeds)\n",
    "Adam_Trajs_13_Rescaled = Adam_Trajs(\n",
    "    eta*kappa, x0, nit, sigma/kappa,\n",
    "    1-kappa**2*(1-beta_1), 1-kappa**2*(1-beta_2),\n",
    "    eps, l3*kappa, seeds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NR runs: same rescaling of eta, sigma and both betas as the R runs, but the\n",
    "# weight decay is left at its base value (l1 / l3, not kappa*l).\n",
    "Adam_Trajs_11_NRescaled = Adam_Trajs(\n",
    "    eta*kappa, x0, nit, sigma/kappa,\n",
    "    1-kappa**2*(1-beta_1), 1-kappa**2*(1-beta_2),\n",
    "    eps, l1, seeds)\n",
    "Adam_Trajs_13_NRescaled = Adam_Trajs(\n",
    "    eta*kappa, x0, nit, sigma/kappa,\n",
    "    1-kappa**2*(1-beta_1), 1-kappa**2*(1-beta_2),\n",
    "    eps, l3, seeds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "9RQj_DmQBS85",
    "outputId": "ab10bcd8-0b32-428f-cb4b-0c6ada6306e8",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Baseline RMSpropW runs (base eta, sigma and beta) for the two weight decays.\n",
    "RMS_Trajs_11 = RMS_Trajs(eta, x0, nit, sigma, beta, eps, l1, seeds)\n",
    "RMS_Trajs_13 = RMS_Trajs(eta, x0, nit, sigma, beta, eps, l3, seeds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Rescaled RMSpropW (R): eta -> kappa*eta, sigma -> sigma/kappa,\n",
    "# 1 - beta -> kappa**2 * (1 - beta), and weight decay l -> kappa*l.\n",
    "RMS_Trajs_11_Rescaled = RMS_Trajs(\n",
    "    eta*kappa, x0, nit, sigma/kappa, 1-kappa**2*(1-beta), eps, l1*kappa, seeds)\n",
    "RMS_Trajs_13_Rescaled = RMS_Trajs(\n",
    "    eta*kappa, x0, nit, sigma/kappa, 1-kappa**2*(1-beta), eps, l3*kappa, seeds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# NR runs: same rescaling of eta, sigma and beta as the R runs, but the weight\n",
    "# decay is left at its base value (l1 / l3, not kappa*l).\n",
    "RMS_Trajs_11_NRescaled = RMS_Trajs(\n",
    "    eta*kappa, x0, nit, sigma/kappa, 1-kappa**2*(1-beta), eps, l1, seeds)\n",
    "RMS_Trajs_13_NRescaled = RMS_Trajs(\n",
    "    eta*kappa, x0, nit, sigma/kappa, 1-kappa**2*(1-beta), eps, l3, seeds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tQgz62A-O6s1"
   },
   "outputs": [],
   "source": [
    "# Seed average (axis 0) of the baseline AdamW run with weight decay l1.\n",
    "avg_Adam_Traj_l1, avg_Adam_Loss_l1 = (np.mean(arr, axis=0) for arr in Adam_Trajs_11)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed average (axis 0) of the baseline AdamW run with weight decay l3.\n",
    "avg_Adam_Traj_l3, avg_Adam_Loss_l3 = (np.mean(arr, axis=0) for arr in Adam_Trajs_13)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed average (axis 0) of the rescaled (R) AdamW run with weight decay l3.\n",
    "avg_Adam_Traj_l3_R, avg_Adam_Loss_l3_R = (np.mean(arr, axis=0) for arr in Adam_Trajs_13_Rescaled)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed average (axis 0) of the rescaled (R) AdamW run with weight decay l1.\n",
    "avg_Adam_Traj_l1_R, avg_Adam_Loss_l1_R = (np.mean(arr, axis=0) for arr in Adam_Trajs_11_Rescaled)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed average (axis 0) of the NR AdamW run with weight decay l1.\n",
    "avg_Adam_Traj_l1_NR, avg_Adam_Loss_l1_NR = (np.mean(arr, axis=0) for arr in Adam_Trajs_11_NRescaled)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed average (axis 0) of the NR AdamW run with weight decay l3.\n",
    "avg_Adam_Traj_l3_NR, avg_Adam_Loss_l3_NR = (np.mean(arr, axis=0) for arr in Adam_Trajs_13_NRescaled)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed averages (axis 0) of every RMSpropW run: baseline, rescaled (R) and\n",
    "# rescaled with base weight decay (NR), for weight decays l1 and l3.\n",
    "avg_RMS_Traj_l1, avg_RMS_Loss_l1 = (np.mean(a, axis=0) for a in RMS_Trajs_11)\n",
    "avg_RMS_Traj_l3, avg_RMS_Loss_l3 = (np.mean(a, axis=0) for a in RMS_Trajs_13)\n",
    "\n",
    "avg_RMS_Traj_l1_R, avg_RMS_Loss_l1_R = (np.mean(a, axis=0) for a in RMS_Trajs_11_Rescaled)\n",
    "avg_RMS_Traj_l3_R, avg_RMS_Loss_l3_R = (np.mean(a, axis=0) for a in RMS_Trajs_13_Rescaled)\n",
    "\n",
    "avg_RMS_Traj_l1_NR, avg_RMS_Loss_l1_NR = (np.mean(a, axis=0) for a in RMS_Trajs_11_NRescaled)\n",
    "avg_RMS_Traj_l3_NR, avg_RMS_Loss_l3_NR = (np.mean(a, axis=0) for a in RMS_Trajs_13_NRescaled)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 486
    },
    "id": "GQzbZMZM5vh8",
    "outputId": "42598e0a-bfa6-4ee7-c05d-d7c77ca037ff"
   },
   "outputs": [],
   "source": [
    "# Seed-averaged AdamW losses (baseline, R, NR) for both weight decays, against\n",
    "# the theoretical predictions limit1 / limit3.\n",
    "start = 0\n",
    "# None keeps the final iterate; the previous end = -1 silently dropped it.\n",
    "end = None\n",
    "\n",
    "# Constant vertical offset added to every empirical curve (0 = no offset).\n",
    "flag = 0\n",
    "\n",
    "plt.figure()\n",
    "\n",
    "# Drawing order (later curves are drawn on top): NR, R and baseline for\n",
    "# gamma = l1, then the same for gamma = l3, then the two predictions.\n",
    "line_3, = plt.plot(avg_Adam_Loss_l1_NR[start:end]+1*flag, color=colors[3], linewidth=2)\n",
    "line_2, = plt.plot(avg_Adam_Loss_l1_R[start:end]+1*flag, color=colors[1], linewidth=2)\n",
    "line_1, = plt.plot(avg_Adam_Loss_l1[start:end]+1*flag, color=colors[0], linewidth=2)\n",
    "\n",
    "line_7, = plt.plot(avg_Adam_Loss_l3_NR[start:end]+1*flag, color=colors[6], linewidth=2)\n",
    "line_6, = plt.plot(avg_Adam_Loss_l3_R[start:end]+1*flag, color=colors[5], linewidth=2)\n",
    "line_5, = plt.plot(avg_Adam_Loss_l3[start:end]+1*flag, color=colors[4], linewidth=2)\n",
    "\n",
    "line_4, = plt.plot(limit1[start:end], color=colors[2], linewidth=2)\n",
    "line_8, = plt.plot(limit3[start:end], color=colors[7], linewidth=2)\n",
    "\n",
    "# Raw f-strings: a plain '$\\gamma$' literal contains the invalid escape \\g,\n",
    "# deprecated since Python 3.6 and a SyntaxWarning from Python 3.12.\n",
    "plt.legend(\n",
    "    [line_1, line_2, line_3, line_4, line_5, line_6, line_7, line_8],\n",
    "    [rf'AdamW ($\\gamma = ${l1})',\n",
    "     rf'AdamW R ($\\gamma = ${l1})',\n",
    "     rf'AdamW NR ($\\gamma = ${l1})',\n",
    "     rf'Theor. Pred. ($\\gamma = ${l1})',\n",
    "     rf'AdamW ($\\gamma = ${l3})',\n",
    "     rf'AdamW R ($\\gamma = ${l3})',\n",
    "     rf'AdamW NR ($\\gamma = ${l3})',\n",
    "     rf'Theor. Pred. ($\\gamma = ${l3})'],\n",
    "    fontsize=14)\n",
    "plt.title(\"Losses\", fontsize=25)\n",
    "plt.xlabel('Iterations', fontsize=25)\n",
    "plt.ylabel('Loss', fontsize=25)\n",
    "num_ticks = 6  # number of x-axis ticks spread over [1, nit]\n",
    "xticks_positions = np.linspace(1, nit, num_ticks, dtype=int)\n",
    "plt.xticks(xticks_positions, fontsize=16)\n",
    "plt.yticks(fontsize=16)\n",
    "plt.yscale('log')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed-averaged RMSpropW losses (baseline, R, NR) for both weight decays,\n",
    "# against the theoretical predictions limit1 / limit3.\n",
    "start = 0\n",
    "# None keeps the final iterate; the previous end = -1 silently dropped it.\n",
    "end = None\n",
    "\n",
    "# Constant vertical offset added to every empirical curve (0 = no offset).\n",
    "flag = 0\n",
    "\n",
    "plt.figure()\n",
    "\n",
    "# Drawing order (later curves are drawn on top): NR, R and baseline for\n",
    "# gamma = l1, then the same for gamma = l3, then the two predictions.\n",
    "line_3, = plt.plot(avg_RMS_Loss_l1_NR[start:end]+1*flag, color=colors[3], linewidth=2)\n",
    "line_2, = plt.plot(avg_RMS_Loss_l1_R[start:end]+1*flag, color=colors[1], linewidth=2)\n",
    "line_1, = plt.plot(avg_RMS_Loss_l1[start:end]+1*flag, color=colors[0], linewidth=2)\n",
    "\n",
    "line_7, = plt.plot(avg_RMS_Loss_l3_NR[start:end]+1*flag, color=colors[6], linewidth=2)\n",
    "line_6, = plt.plot(avg_RMS_Loss_l3_R[start:end]+1*flag, color=colors[5], linewidth=2)\n",
    "line_5, = plt.plot(avg_RMS_Loss_l3[start:end]+1*flag, color=colors[4], linewidth=2)\n",
    "\n",
    "line_4, = plt.plot(limit1[start:end], color=colors[2], linewidth=2)\n",
    "line_8, = plt.plot(limit3[start:end], color=colors[7], linewidth=2)\n",
    "\n",
    "# Raw f-strings: a plain '$\\gamma$' literal contains the invalid escape \\g,\n",
    "# deprecated since Python 3.6 and a SyntaxWarning from Python 3.12.\n",
    "plt.legend(\n",
    "    [line_1, line_2, line_3, line_4, line_5, line_6, line_7, line_8],\n",
    "    [rf'RMSpropW ($\\gamma = ${l1})',\n",
    "     rf'RMSpropW R ($\\gamma = ${l1})',\n",
    "     rf'RMSpropW NR ($\\gamma = ${l1})',\n",
    "     rf'Theor. Pred. ($\\gamma = ${l1})',\n",
    "     rf'RMSpropW ($\\gamma = ${l3})',\n",
    "     rf'RMSpropW R ($\\gamma = ${l3})',\n",
    "     rf'RMSpropW NR ($\\gamma = ${l3})',\n",
    "     rf'Theor. Pred. ($\\gamma = ${l3})'],\n",
    "    fontsize=14)\n",
    "plt.title(\"Losses\", fontsize=25)\n",
    "plt.xlabel('Iterations', fontsize=25)\n",
    "plt.ylabel('Loss', fontsize=25)\n",
    "num_ticks = 6  # number of x-axis ticks spread over [1, nit]\n",
    "xticks_positions = np.linspace(1, nit, num_ticks, dtype=int)\n",
    "plt.xticks(xticks_positions, fontsize=16)\n",
    "plt.yticks(fontsize=16)\n",
    "plt.yscale('log')\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
