{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7GVWcjh65iGv"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%config InlineBackend.figure_formats = ['svg']\n",
    "from scipy.linalg import sqrtm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "b8aBVIt_8qHf"
   },
   "outputs": [],
   "source": [
    "plt.rcParams.update({'font.size': 14})\n",
    "plt.rcParams[\"figure.figsize\"] = (7,5)\n",
    "import seaborn as sns\n",
    "colors = sns.color_palette(\"colorblind\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "wLEakSJI8qOy"
   },
   "outputs": [],
   "source": [
    "# Define custom dash patterns for each line without names\n",
    "custom_dashes = [\n",
    "    (1, 0),           # Solid\n",
    "    (1, 1),           # Dotted\n",
    "    (4, 2),           # Dashed\n",
    "    (3, 2, 1, 2),     # Dash-Dot\n",
    "    (5, 2, 1, 2, 1, 2),  # Dash-Dot-Dot\n",
    "    (8, 2),           # Long Dashes\n",
    "    (2, 4),           # Loosely Dashed\n",
    "    (5, 4, 1, 4),     # Sparse Dash-Dot\n",
    "    (6, 4, 1, 4, 1, 4),  # Sparse Dash-Dot-Dot\n",
    "    (1, 4),           # Loosely Dotted\n",
    "    (4, 2, 1, 2, 1, 2),  # Dashed with Dots\n",
    "    (2, 4, 1, 4, 2, 4)   # Custom Pattern 1\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9iedPQJn5jvq"
   },
   "outputs": [],
   "source": [
    "def cost(x): return 0.5*x.T@H@x\n",
    "def g(x): return H@x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "oVZOW6-UeAee"
   },
   "outputs": [],
   "source": [
    "def Adam_Traj(eta, x0, nit, sigma,beta_1,beta_2,eps, l,seed):\n",
    "\n",
    "  np.random.seed(seed)\n",
    "  noise = sigma*np.random.normal(size=(d,))\n",
    "  np.random.seed(seed)\n",
    "\n",
    "  x = np.zeros((nit, d))\n",
    "  v = np.zeros((nit, d))\n",
    "  m = np.zeros((nit, d))\n",
    "  f = np.zeros((nit,))\n",
    "  x[0] = x0\n",
    "  f[0] = cost(x[0])\n",
    "  v[0] = g(x[0]+noise)*g(x[0]+noise)\n",
    "  m[0] = g(x[0]+noise)\n",
    "  for k in range(nit-1):\n",
    "\n",
    "    gamma_1 = 1/(1 - beta_1**(k+1))\n",
    "    gamma_2 = 1/(1 - beta_2**(k+1))\n",
    "\n",
    "    noise = sigma*np.random.normal(size=(d,))\n",
    "    m[k+1] = beta_1*m[k]+(1-beta_1)*(g(x[k])+noise)\n",
    "    v[k+1] = beta_2*v[k]+(1-beta_2)*(g(x[k])+noise)*(g(x[k])+noise)\n",
    "    x[k+1] = x[k] - eta*(gamma_1*m[k+1])/(np.sqrt(gamma_2*v[k+1])+eps) - eta*l*x[k]\n",
    "    f[k+1] = cost(x[k+1])\n",
    "\n",
    "  return (x, f.reshape(f.shape[-1],1))\n",
    "\n",
    "\n",
    "def RMS_Traj(eta, x0, nit, sigma,beta,eps, l,seed):\n",
    "\n",
    "  np.random.seed(seed)\n",
    "\n",
    "  x = np.zeros((nit, d))\n",
    "  v = np.zeros((nit, d))\n",
    "  f = np.zeros((nit,))\n",
    "  x[0] = x0\n",
    "  f[0] = cost(x[0])\n",
    "  v[0] = g(x[0])*g(x[0])\n",
    "  for k in range(nit-1):\n",
    "    noise = sigma*np.random.normal(size=(d,))\n",
    "    v[k+1] = beta*v[k]+(1-beta)*(g(x[k])+noise)*(g(x[k])+noise)\n",
    "    x[k+1] = x[k]*(1-l*eta) - eta*(g(x[k])+noise)/(np.sqrt(v[k])+eps)\n",
    "    f[k+1] = cost(x[k+1])\n",
    "\n",
    "  return (x, f.reshape(f.shape[-1],1))\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ySznf5idOUf8"
   },
   "outputs": [],
   "source": [
    "def Adam_Trajs(eta, x0, nit, sigma,beta_1,beta_2,eps,l, seeds):\n",
    "\n",
    "  d = np.shape(x0)[0]\n",
    "  x = np.zeros((len(seeds), nit+1,d))\n",
    "  f = np.zeros((len(seeds), nit+1,1))\n",
    "\n",
    "  i=0\n",
    "  for seed in seeds:\n",
    "    print(seed)\n",
    "    x[i,:,:], f[i,:,:]  = Adam_Traj(eta, x0, nit+1, sigma,beta_1,beta_2,eps,l, seed)\n",
    "    i += 1\n",
    "\n",
    "  return (x, f)\n",
    "\n",
    "\n",
    "def RMS_Trajs(eta, x0, nit, sigma,beta,eps,l, seeds):\n",
    "\n",
    "  d = np.shape(x0)[0]\n",
    "  x = np.zeros((len(seeds), nit+1,d))\n",
    "  f = np.zeros((len(seeds), nit+1,1))\n",
    "\n",
    "  i=0\n",
    "  for seed in seeds:\n",
    "    print(seed)\n",
    "    x[i,:,:], f[i,:,:]  = RMS_Traj(eta, x0, nit+1, sigma,beta,eps,l, seed)\n",
    "    i += 1\n",
    "\n",
    "  return (x, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "K121PaKMDbr6"
   },
   "outputs": [],
   "source": [
    "d = 1\n",
    "landa = 1\n",
    "H = np.array([[landa]])\n",
    "eta = 0.001\n",
    "sigma = 1\n",
    "eps = 1e-7\n",
    "\n",
    "\n",
    "\n",
    "T=100\n",
    "nit= int(T/eta)\n",
    "\n",
    "mu1=100\n",
    "mu2=1\n",
    "mu=1\n",
    "\n",
    "\n",
    "l1 = 1\n",
    "\n",
    "sigma1=1e-2\n",
    "sigma2=1e-1\n",
    "sigma3=1e0\n",
    "sigma4=1e1\n",
    "sigma5=1e3\n",
    "\n",
    "\n",
    "np.random.seed(0)\n",
    "\n",
    "x0 = 1*np.random.normal(size=(d,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "4l7r9_Bk-0Di"
   },
   "outputs": [],
   "source": [
    "limit1 = 0.25*eta*np.ones((nit,))*landa*sigma1/(sigma1*l1+landa)\n",
    "limit2 = 0.25*eta*np.ones((nit,))*landa*sigma2/(sigma2*l1+landa)\n",
    "limit3 = 0.25*eta*np.ones((nit,))*landa*sigma3/(sigma3*l1+landa)\n",
    "limit4 = 0.25*eta*np.ones((nit,))*landa*sigma4/(sigma4*l1+landa)\n",
    "limit5 = 0.25*eta*np.ones((nit,))*landa*sigma5/(sigma5*l1+landa)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "PB0Pzkm0FuKV"
   },
   "outputs": [],
   "source": [
    "seeds = [i for i in np.arange(100)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "xERLGH0VOp7C",
    "outputId": "a0ea6bb7-d9f2-44ca-db0e-3d838219a562",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "Adam_Trajs_1       = Adam_Trajs   (eta, x0, nit, sigma1, 1-mu1*eta, 1-mu2*eta,  eps         ,l1, seeds)\n",
    "Adam_Trajs_2       = Adam_Trajs   (eta, x0, nit, sigma2, 1-mu1*eta, 1-mu2*eta,  eps         ,l1, seeds)\n",
    "Adam_Trajs_3       = Adam_Trajs   (eta, x0, nit, sigma3, 1-mu1*eta, 1-mu2*eta,  eps         ,l1, seeds)\n",
    "Adam_Trajs_4       = Adam_Trajs   (eta, x0, nit, sigma4, 1-mu1*eta, 1-mu2*eta,  eps         ,l1, seeds)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Adam_Trajs_5       = Adam_Trajs   (eta, x0, nit, sigma5, 1-mu1*eta, 1-mu2*eta,  eps         ,l1, seeds)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tQgz62A-O6s1"
   },
   "outputs": [],
   "source": [
    "avg_Adam_Traj_1 = np.mean(Adam_Trajs_1[0],axis=0)\n",
    "avg_Adam_Loss_1 = np.mean(Adam_Trajs_1[1],axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "avg_Adam_Traj_2 = np.mean(Adam_Trajs_2[0],axis=0)\n",
    "avg_Adam_Loss_2 = np.mean(Adam_Trajs_2[1],axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "avg_Adam_Traj_3 = np.mean(Adam_Trajs_3[0],axis=0)\n",
    "avg_Adam_Loss_3 = np.mean(Adam_Trajs_3[1],axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "avg_Adam_Traj_4 = np.mean(Adam_Trajs_4[0],axis=0)\n",
    "avg_Adam_Loss_4 = np.mean(Adam_Trajs_4[1],axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "avg_Adam_Traj_5 = np.mean(Adam_Trajs_5[0],axis=0)\n",
    "avg_Adam_Loss_5 = np.mean(Adam_Trajs_5[1],axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "start = 0\n",
    "end = -1\n",
    "\n",
    "flag = 0\n",
    "\n",
    "# Set the overall figure size\n",
    "plt.rcParams[\"figure.figsize\"] = (3.5, 5)\n",
    "\n",
    "plt.figure()\n",
    "# We plot\n",
    "line_1, = plt.plot(avg_Adam_Loss_1[start:end] + 1 * flag, color=colors[0], linewidth=2)\n",
    "line_2, = plt.plot(avg_Adam_Loss_2[start:end] + 1 * flag, color=colors[1], linewidth=2)\n",
    "line_3, = plt.plot(avg_Adam_Loss_3[start:end] + 1 * flag, color=colors[2], linewidth=2)\n",
    "line_4, = plt.plot(avg_Adam_Loss_4[start:end] + 1 * flag, color=colors[3], linewidth=2)\n",
    "line_5, = plt.plot(avg_Adam_Loss_5[start:end] + 1 * flag, color=colors[4], linewidth=2)\n",
    "\n",
    "line_6, = plt.plot(limit1[start:end] + 1 * flag, color=colors[5], linewidth=2)\n",
    "line_7, = plt.plot(limit2[start:end] + 1 * flag, color=colors[6], linewidth=2)\n",
    "line_8, = plt.plot(limit3[start:end] + 1 * flag, color=colors[7], linewidth=2)\n",
    "line_9, = plt.plot(limit4[start:end] + 1 * flag, color=colors[8], linewidth=2)\n",
    "line_10, = plt.plot(limit5[start:end] + 1 * flag, color=colors[9], linewidth=2)\n",
    "\n",
    "# Create the legend\n",
    "plt.legend([line_1, line_2, line_3, line_4, line_5, line_6, line_7, line_8, line_9, line_10], [\n",
    "    'AdamW (' + '$\\sigma = $' + str(sigma1) + ')',\n",
    "    'AdamW (' + '$\\sigma = $' + str(sigma2) + ')',\n",
    "    'AdamW (' + '$\\sigma = $' + str(sigma3) + ')',\n",
    "    'AdamW (' + '$\\sigma = $' + str(sigma4) + ')',\n",
    "    'AdamW (' + '$\\sigma = $' + str(sigma5) + ')',\n",
    "    'Limit (' + '$\\sigma = $' + str(sigma1) + ')',\n",
    "    'Limit (' + '$\\sigma = $' + str(sigma2) + ')',\n",
    "    'Limit (' + '$\\sigma = $' + str(sigma3) + ')',\n",
    "    'Limit (' + '$\\sigma = $' + str(sigma4) + ')',\n",
    "    'Limit (' + '$\\sigma = $' + str(sigma5) + ')'\n",
    "], bbox_to_anchor=(1.05, 1), loc='upper left',fontsize=14)\n",
    "\n",
    "plt.title(\"Losses\", fontsize=25)\n",
    "plt.xlabel('Iterations', fontsize=25)\n",
    "plt.ylabel('Loss', fontsize=25)\n",
    "plt.xticks(fontsize=16)\n",
    "plt.yticks(fontsize=16)\n",
    "plt.yscale('log')\n",
    "\n",
    "# Save the figure\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
