{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7GVWcjh65iGv"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%config InlineBackend.figure_formats = ['svg']\n",
    "from scipy.linalg import sqrtm\n",
    "from numpy import linalg as LA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams.update({'font.size': 14})\n",
    "plt.rcParams[\"figure.figsize\"] = (7,5)\n",
    "import seaborn as sns\n",
    "colors = sns.color_palette(\"colorblind\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define custom dash patterns for each line without names\n",
    "custom_dashes = [\n",
    "    (1, 0),           # Solid\n",
    "    (1, 1),           # Dotted\n",
    "    (4, 2),           # Dashed\n",
    "    (3, 2, 1, 2),     # Dash-Dot\n",
    "    (5, 2, 1, 2, 1, 2),  # Dash-Dot-Dot\n",
    "    (8, 2),           # Long Dashes\n",
    "    (2, 4),           # Loosely Dashed\n",
    "    (5, 4, 1, 4),     # Sparse Dash-Dot\n",
    "    (6, 4, 1, 4, 1, 4),  # Sparse Dash-Dot-Dot\n",
    "    (1, 4),           # Loosely Dotted\n",
    "    (4, 2, 1, 2, 1, 2),  # Dashed with Dots\n",
    "    (2, 4, 1, 4, 2, 4)   # Custom Pattern 1\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9iedPQJn5jvq"
   },
   "outputs": [],
   "source": [
    "def cost(x): return 0.5*x.T@H@x\n",
    "def g(x): return H@x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "oVZOW6-UeAee"
   },
   "outputs": [],
   "source": [
    "def Adam_Traj(eta, x0, nit, sigma, seed):\n",
    "\n",
    "  np.random.seed(seed)\n",
    "  noise = sigma*np.random.normal(size=(d,))\n",
    "  np.random.seed(seed)\n",
    "  x = np.zeros((nit, d))\n",
    "  v = np.zeros((nit, d))\n",
    "  m = np.zeros((nit, d))\n",
    "  f = np.zeros((nit,))\n",
    "  x[0] = x0\n",
    "  f[0] = cost(x[0])\n",
    "  v[0] = g(x[0]+noise)*g(x[0]+noise)\n",
    "  m[0] = g(x[0]+noise)\n",
    "    \n",
    "  for k in range(nit-1):\n",
    "    gamma_1 = 1/(1 - beta_1**(k+2))\n",
    "    gamma_2 = 1/(1 - beta_2**(k+1))\n",
    "\n",
    "    noise = sigma*np.random.normal(size=(d,))\n",
    "    m[k+1] = beta_1*m[k]+(1-beta_1)*(g(x[k])+noise)\n",
    "    v[k+1] = beta_2*v[k]+(1-beta_2)*(g(x[k])+noise)*(g(x[k])+noise)\n",
    "    x[k+1] = x[k]*(1-eta*reg) - eta*(gamma_1*m[k+1])/(np.sqrt(gamma_2*v[k+1])+eps)\n",
    "    f[k+1] = cost(x[k+1])\n",
    "\n",
    "  return (x, f.reshape(f.shape[-1],1))\n",
    "\n",
    "\n",
    "def RMS_Traj(eta, x0, nit, sigma, seed):\n",
    "\n",
    "  np.random.seed(seed)\n",
    "  noise = sigma*np.random.normal(size=(d,))\n",
    "  np.random.seed(seed)\n",
    "  x = np.zeros((nit, d))\n",
    "  v = np.zeros((nit, d))\n",
    "  f = np.zeros((nit,))\n",
    "  x[0] = x0\n",
    "  f[0] = cost(x[0])\n",
    "  v[0] = g(x[0]+noise)*g(x[0]+noise)\n",
    "  \n",
    "  for k in range(nit-1):\n",
    "    noise = sigma*np.random.normal(size=(d,))\n",
    "    v[k+1] = beta*v[k]+(1-beta)*(g(x[k])+noise)*(g(x[k])+noise)\n",
    "    x[k+1] = x[k]*(1-eta*reg) - eta*(g(x[k])+noise)/(np.sqrt(v[k+1])+eps)\n",
    "    f[k+1] = cost(x[k+1])\n",
    "\n",
    "  return (x, f.reshape(f.shape[-1],1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "jyuC11NvEIBq"
   },
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "def AdamSDE_Traj(eta, x0, nit, sigma, seed):\n",
    "\n",
    "  np.random.seed(seed)\n",
    "  noise = sigma*np.random.normal(size=(d,))\n",
    "  x = np.zeros((nit, d))\n",
    "  v = np.zeros((nit, d))\n",
    "  m = np.zeros((nit, d))\n",
    "  f = np.zeros((nit,))\n",
    "  x[0] = x0\n",
    "  f[0] = cost(x[0])\n",
    "  v[0] = g(x[0]+noise)*g(x[0]+noise)\n",
    "  m[0] = g(x[0]+noise)\n",
    "  np.random.seed(seed)\n",
    "  for k in range(nit-1):\n",
    "\n",
    "    noise_wm = np.random.normal(size=(d,))\n",
    "\n",
    "    gamma_1 = 1 - np.exp(-mu1*eta*(k+1))\n",
    "    gamma_2 = 1 - np.exp(-mu2*eta*(k+1))\n",
    "\n",
    "\n",
    "    x[k+1] = x[k]*(1-eta*reg) - eta*(np.sqrt(gamma_2)/gamma_1)*(m[k] + eta*mu1*(g(x[k])-m[k]))/(np.sqrt(v[k])  + eps)\n",
    "    v[k+1] = v[k] + mu2*eta*( (g(x[k]))*(g(x[k])) + sigma**2 - v[k]  )\n",
    "    m[k+1] = m[k] + mu1*eta*( g(x[k]) - m[k]  ) + eta*mu1*sigma*noise_wm\n",
    "\n",
    "    f[k+1] = cost(x[k+1])\n",
    "\n",
    "  return (x,  f.reshape(f.shape[-1],1))\n",
    "\n",
    "\n",
    "def RMSSDE_Traj(eta, x0, nit, sigma, seed):\n",
    "\n",
    "  np.random.seed(seed)\n",
    "  noise = sigma*np.random.normal(size=(d,))\n",
    "  np.random.seed(seed)\n",
    "  x = np.zeros((nit, d))\n",
    "  v = np.zeros((nit, d))\n",
    "  f = np.zeros((nit,))\n",
    "  x[0] = x0\n",
    "  f[0] = cost(x[0])\n",
    "  v[0] = g(x[0]+noise)*g(x[0]+noise)\n",
    "  for k in range(nit-1):\n",
    "    noise_x = np.random.normal(size=(d,))\n",
    "\n",
    "    invP = np.linalg.inv(np.diag(np.sqrt(v[k])  + eps ))\n",
    "    x[k+1] = x[k]*(1-eta*reg) - eta*invP@(g(x[k]))-eta*sigma*invP@(noise_x)\n",
    "    v[k+1] = v[k] + mu*eta*( (g(x[k]))*(g(x[k])) + sigma**2 - v[k]  )\n",
    "    f[k+1] = cost(x[k+1])\n",
    "\n",
    "  return (x,  f.reshape(f.shape[-1],1))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ySznf5idOUf8"
   },
   "outputs": [],
   "source": [
    "def AdamSDE_Trajs(eta, x0, nit, sigma, seeds):\n",
    "\n",
    "  d = np.shape(x0)[0]\n",
    "  x_sde = np.zeros((len(seeds), nit+1,d))\n",
    "  f_sde = np.zeros((len(seeds), nit+1,1))\n",
    "\n",
    "  i=0\n",
    "  for seed in seeds:\n",
    "    print(seed)\n",
    "    x_sde[i,:,:], f_sde[i,:,:]  = AdamSDE_Traj(eta, x0, nit+1, sigma, seed)\n",
    "    i += 1\n",
    "\n",
    "  return (x_sde, f_sde)\n",
    "\n",
    "\n",
    "def RMSSDE_Trajs(eta, x0, nit, sigma, seeds):\n",
    "\n",
    "  d = np.shape(x0)[0]\n",
    "  x_sde = np.zeros((len(seeds), nit+1,d))\n",
    "  f_sde = np.zeros((len(seeds), nit+1,1))\n",
    "\n",
    "  i=0\n",
    "  for seed in seeds:\n",
    "    print(seed)\n",
    "    x_sde[i,:,:], f_sde[i,:,:]  = RMSSDE_Traj(eta, x0, nit+1, sigma, seed)\n",
    "    i += 1\n",
    "\n",
    "  return (x_sde, f_sde)\n",
    "\n",
    "\n",
    "\n",
    "def Adam_Trajs(eta, x0, nit, sigma, seeds):\n",
    "\n",
    "  d = np.shape(x0)[0]\n",
    "  x = np.zeros((len(seeds), nit+1,d))\n",
    "  f = np.zeros((len(seeds), nit+1,1))\n",
    "\n",
    "  i=0\n",
    "  for seed in seeds:\n",
    "    print(seed)\n",
    "    x[i,:,:], f[i,:,:]  = Adam_Traj(eta, x0, nit+1, sigma, seed)\n",
    "    i += 1\n",
    "\n",
    "  return (x, f)\n",
    "\n",
    "\n",
    "def RMS_Trajs(eta, x0, nit, sigma, seeds):\n",
    "\n",
    "  d = np.shape(x0)[0]\n",
    "  x = np.zeros((len(seeds), nit+1,d))\n",
    "  f = np.zeros((len(seeds), nit+1,1))\n",
    "\n",
    "  i=0\n",
    "  for seed in seeds:\n",
    "    print(seed)\n",
    "    x[i,:,:], f[i,:,:]  = RMS_Traj(eta, x0, nit+1, sigma, seed)\n",
    "    i += 1\n",
    "\n",
    "  return (x, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "K121PaKMDbr6"
   },
   "outputs": [],
   "source": [
    "d = 2\n",
    "H = np.array([[10, 0.6],[0.6, 2]])\n",
    "eta = 0.001\n",
    "sigma = 1\n",
    "eps = 1e-8\n",
    "\n",
    "T= 50\n",
    "nit= int(T/eta)\n",
    "nit_alg = nit\n",
    "nit_sde = nit\n",
    "mu1 = 1\n",
    "mu2 = 1\n",
    "mu = 1\n",
    "reg=0.1\n",
    "\n",
    "beta_1 = 1-eta*mu1\n",
    "beta_2 = 1-eta*mu2\n",
    "beta = 1-eta*mu\n",
    "\n",
    "np.random.seed(0)\n",
    "\n",
    "x0 = 1.0*np.random.normal(size=(d,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "PB0Pzkm0FuKV"
   },
   "outputs": [],
   "source": [
    "seeds = [i for i in np.arange(100)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "DpL6hyoyOp02",
    "outputId": "f59a148c-1f11-4b88-d582-d1041ad1c3fd",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "AdamSDE_Trajs_ = AdamSDE_Trajs(eta, x0, nit, sigma, seeds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "RMSSDE_Trajs_ = RMSSDE_Trajs(eta, x0, nit, sigma, seeds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "Adam_Trajs_ = Adam_Trajs(eta, x0, nit, sigma, seeds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "xERLGH0VOp7C",
    "outputId": "3a1ede4f-f3b9-4b37-fc91-b7c2212e92b2",
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "RMS_Trajs_ = RMS_Trajs(eta, x0, nit, sigma, seeds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tQgz62A-O6s1"
   },
   "outputs": [],
   "source": [
    "avg_AdamSDE_Traj = np.mean(AdamSDE_Trajs_[0],axis=0)\n",
    "avg_AdamSDE_Loss = np.mean(AdamSDE_Trajs_[1],axis=0)\n",
    "\n",
    "std_AdamSDE_Traj = np.std(AdamSDE_Trajs_[0],axis=0)\n",
    "std_AdamSDE_Loss = np.std(AdamSDE_Trajs_[1],axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "avg_RMSSDE_Traj = np.mean(RMSSDE_Trajs_[0],axis=0)\n",
    "avg_RMSSDE_Loss = np.mean(RMSSDE_Trajs_[1],axis=0)\n",
    "\n",
    "std_RMSSDE_Traj = np.std(RMSSDE_Trajs_[0],axis=0)\n",
    "std_RMSSDE_Loss = np.std(RMSSDE_Trajs_[1],axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "avg_Adam_Traj = np.mean(Adam_Trajs_[0],axis=0)\n",
    "avg_Adam_Loss = np.mean(Adam_Trajs_[1],axis=0)\n",
    "\n",
    "std_Adam_Traj = np.std(Adam_Trajs_[0],axis=0)\n",
    "std_Adam_Loss = np.std(Adam_Trajs_[1],axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "y_2nfBUYO7IG"
   },
   "outputs": [],
   "source": [
    "avg_RMS_Traj = np.mean(RMS_Trajs_[0],axis=0)\n",
    "avg_RMS_Loss = np.mean(RMS_Trajs_[1],axis=0)\n",
    "\n",
    "std_RMS_Traj = np.std(RMS_Trajs_[0],axis=0)\n",
    "std_RMS_Loss = np.std(RMS_Trajs_[1],axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 471
    },
    "id": "GQzbZMZM5vh8",
    "outputId": "96d8073b-a828-44c8-9ea8-cb6a12f50aba"
   },
   "outputs": [],
   "source": [
    "start = 0\n",
    "end = -1\n",
    "\n",
    "min = np.min(avg_Adam_Loss) - 1e-3\n",
    "\n",
    "\n",
    "fact = 0.3\n",
    "\n",
    "plt.figure()\n",
    "# We plot\n",
    "\n",
    "\n",
    "line_3, = plt.plot(avg_RMS_Loss[start:end]-min, color = colors[3],linewidth=4)\n",
    "\n",
    "plt.fill_between( np.arange(0, nit),(avg_RMS_Loss[start:end]-min - fact*std_RMS_Loss[start:end]).flatten(), (avg_RMS_Loss[start:end]-min + fact*std_RMS_Loss[start:end]).flatten(),\n",
    "                 color=colors[3], alpha=0.4)\n",
    "\n",
    "\n",
    "line_4, = plt.plot(avg_RMSSDE_Loss[start:end]-min, color = colors[1],linewidth=2, dashes=custom_dashes[2])\n",
    "\n",
    "\n",
    "plt.fill_between( np.arange(0, nit),(avg_RMSSDE_Loss[start:end]-min - fact*std_RMSSDE_Loss[start:end]).flatten(), (avg_RMSSDE_Loss[start:end]-min + fact*std_RMSSDE_Loss[start:end]).flatten(),\n",
    "                 color=colors[1], alpha=0.4)\n",
    "\n",
    "line_1, = plt.plot(avg_Adam_Loss[start:end]-min, color = colors[4],linewidth=4)\n",
    "\n",
    "plt.fill_between( np.arange(0, nit),(avg_Adam_Loss[start:end]-min - fact*std_Adam_Loss[start:end]).flatten(), (avg_Adam_Loss[start:end]-min + fact*std_Adam_Loss[start:end]).flatten(),\n",
    "                 color=colors[4], alpha=0.4)\n",
    "\n",
    "\n",
    "line_2, = plt.plot(avg_AdamSDE_Loss[start:end]-min, color = colors[2],linewidth=2, dashes=custom_dashes[2])\n",
    "\n",
    "\n",
    "plt.fill_between( np.arange(0, nit),(avg_AdamSDE_Loss[start:end]-min - fact*std_AdamSDE_Loss[start:end]).flatten(), (avg_AdamSDE_Loss[start:end]-min + fact*std_AdamSDE_Loss[start:end]).flatten(),\n",
    "                 color=colors[2], alpha=0.4)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "plt.legend([line_1, line_2, line_3, line_4], ['AdamW','AdamW SDE','RMSpropW','RMSpropW SDE'],fontsize=14)\n",
    "plt.title(\"Losses\",fontsize=25)\n",
    "plt.xlabel('Iterations',fontsize=25)\n",
    "plt.ylabel('Loss',fontsize=25)\n",
    "plt.xticks(fontsize=16)\n",
    "plt.yticks(fontsize=16)\n",
    "plt.yscale('log')\n",
    "plt.show()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "HaAoI8e8Z_jc"
   },
   "outputs": [],
   "source": [
    "max_x = np.max(np.abs(avg_Adam_Traj[:,0]))*1.1\n",
    "max_y = np.max(np.abs(avg_Adam_Traj[:,1]))*1.1\n",
    "x_values = np.linspace(-max_x, max_x, 100)\n",
    "y_values = np.linspace(-max_y, max_y, 100)\n",
    "X, Y = np.meshgrid(x_values, y_values)\n",
    "Z = np.zeros_like(X)\n",
    "\n",
    "for i in range(X.shape[0]):\n",
    "    for j in range(X.shape[1]):\n",
    "        Z[i, j] = cost(np.array([X[i, j], Y[i, j]]))\n",
    "plt.contour(X, Y, Z, cmap='coolwarm')  # Adjust the colormap ('cmap') as needed\n",
    "plt.colorbar()  # Add a colorbar to show the values\n",
    "line_1, = plt.plot(avg_Adam_Traj[:,0], avg_Adam_Traj[:,1], color = colors[4],linewidth=4)\n",
    "line_2, = plt.plot(avg_AdamSDE_Traj[:,0], avg_AdamSDE_Traj[:,1], color = colors[2],linewidth=3, dashes=custom_dashes[2])\n",
    "\n",
    "line_3, = plt.plot(avg_RMS_Traj[:,0], avg_RMS_Traj[:,1], color = colors[3],linewidth=4)\n",
    "line_4, = plt.plot(avg_RMSSDE_Traj[:,0], avg_RMSSDE_Traj[:,1], color = colors[1],linewidth=3, dashes=custom_dashes[2])\n",
    "\n",
    "\n",
    "plt.xlim([-max_x, max_x])\n",
    "plt.ylim([-max_y, max_y])\n",
    "\n",
    "plt.legend([line_1, line_2, line_3, line_4], ['AdamW','AdamW SDE','RMSpropW','RMSpropW SDE'])\n",
    "\n",
    "plt.title(\"Trajectories\",fontsize=25)\n",
    "plt.xlabel(r'$X_1$',fontsize=25)\n",
    "plt.ylabel(r'$X_2$',fontsize=25)\n",
    "plt.xticks(fontsize=16)\n",
    "plt.yticks(fontsize=16)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
