{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7GVWcjh65iGv"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%config InlineBackend.figure_formats = ['svg']\n",
    "from scipy.linalg import sqrtm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "b8aBVIt_8qHf"
   },
   "outputs": [],
   "source": [
    "# Notebook-wide matplotlib defaults: base font size and default figure size.\n",
    "plt.rcParams.update({'font.size': 14, 'figure.figsize': (7, 5)})\n",
    "import seaborn as sns\n",
    "# Palette indexed by the plotting cell to colour each curve.\n",
    "colors = sns.color_palette(\"colorblind\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "wLEakSJI8qOy"
   },
   "outputs": [],
   "source": [
    "# Dash patterns given as on/off length sequences; plots select one by index\n",
    "# and pass it through the `dashes=` keyword.\n",
    "custom_dashes = [\n",
    "    (1, 0),              #  0: solid\n",
    "    (1, 1),              #  1: dotted\n",
    "    (4, 2),              #  2: dashed\n",
    "    (3, 2, 1, 2),        #  3: dash-dot\n",
    "    (5, 2, 1, 2, 1, 2),  #  4: dash-dot-dot\n",
    "    (8, 2),              #  5: long dashes\n",
    "    (2, 4),              #  6: loosely dashed\n",
    "    (5, 4, 1, 4),        #  7: sparse dash-dot\n",
    "    (6, 4, 1, 4, 1, 4),  #  8: sparse dash-dot-dot\n",
    "    (1, 4),              #  9: loosely dotted\n",
    "    (4, 2, 1, 2, 1, 2),  # 10: dashed with dots\n",
    "    (2, 4, 1, 4, 2, 4),  # 11: custom pattern 1\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9iedPQJn5jvq"
   },
   "outputs": [],
   "source": [
    "def cost(x):\n",
    "  \"\"\"Return the quadratic form 0.5 * x^T H x; H is read from the notebook's globals.\"\"\"\n",
    "  half_xT = 0.5*x.T\n",
    "  return half_xT@H@x\n",
    "def g(x):\n",
    "  \"\"\"Return H @ x (H is read from the notebook's globals); the optimiser loops use it as the gradient.\"\"\"\n",
    "  return H@x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "oVZOW6-UeAee"
   },
   "outputs": [],
   "source": [
    "def Adam_Traj(eta, x0, nit, sigma,beta_1,beta_2,eps, l,seed):\n",
    "  \"\"\"Simulate one noisy Adam run with decoupled weight decay on `cost`.\n",
    "\n",
    "  Args:\n",
    "    eta: step size.\n",
    "    x0: initial iterate (1-D array); its length sets the problem dimension.\n",
    "    nit: number of stored iterates, including x0 (so nit-1 updates are performed).\n",
    "    sigma: scale of the Gaussian noise added to every gradient.\n",
    "    beta_1, beta_2: decay factors of the first- and second-moment averages.\n",
    "    eps: constant added to the denominator of the update.\n",
    "    l: weight-decay coefficient; each step also subtracts eta*l*x[k].\n",
    "    seed: value passed to np.random.seed (the global NumPy RNG is reseeded).\n",
    "\n",
    "  Returns:\n",
    "    (x, f): x has shape (nit, d) and holds the iterates; f has shape (nit, 1)\n",
    "    and holds cost(x[k]) for each of them.\n",
    "  \"\"\"\n",
    "  # Take the dimension from x0 (as Adam_Trajs does) rather than from a notebook-level `d`.\n",
    "  d = np.shape(x0)[0]\n",
    "\n",
    "  np.random.seed(seed)\n",
    "  noise = sigma*np.random.normal(size=(d,))\n",
    "  # Reseed so the first draw inside the loop repeats the noise used for the initial moments.\n",
    "  np.random.seed(seed)\n",
    "\n",
    "  x = np.zeros((nit, d))\n",
    "  v = np.zeros((nit, d))\n",
    "  m = np.zeros((nit, d))\n",
    "  f = np.zeros((nit,))\n",
    "  x[0] = x0\n",
    "  f[0] = cost(x[0])\n",
    "  # The noise perturbs the gradient, exactly as in the loop below, not the iterate.\n",
    "  noisy_grad = g(x[0]) + noise\n",
    "  v[0] = noisy_grad*noisy_grad\n",
    "  m[0] = noisy_grad\n",
    "  for k in range(nit-1):\n",
    "\n",
    "    # Bias-correction factors applied to m and v in the update.\n",
    "    gamma_1 = 1/(1 - beta_1**(k+1))\n",
    "    gamma_2 = 1/(1 - beta_2**(k+1))\n",
    "\n",
    "    noise = sigma*np.random.normal(size=(d,))\n",
    "    noisy_grad = g(x[k]) + noise  # evaluated once and shared by both moment updates\n",
    "    m[k+1] = beta_1*m[k]+(1-beta_1)*noisy_grad\n",
    "    v[k+1] = beta_2*v[k]+(1-beta_2)*noisy_grad*noisy_grad\n",
    "    x[k+1] = x[k] - eta*(gamma_1*m[k+1])/(np.sqrt(gamma_2*v[k+1])+eps) - eta*l*x[k]\n",
    "    f[k+1] = cost(x[k+1])\n",
    "\n",
    "  return (x, f.reshape(f.shape[-1],1))\n",
    "\n",
    "\n",
    "def RMS_Traj(eta, x0, nit, sigma,beta,eps, l,seed):\n",
    "  \"\"\"Simulate one noisy RMSprop run with decoupled weight decay on `cost`.\n",
    "\n",
    "  Args:\n",
    "    eta: step size.\n",
    "    x0: initial iterate (1-D array); its length sets the problem dimension.\n",
    "    nit: number of stored iterates, including x0 (so nit-1 updates are performed).\n",
    "    sigma: scale of the Gaussian noise added to every gradient.\n",
    "    beta: decay factor of the squared-gradient average.\n",
    "    eps: constant added to the denominator of the update.\n",
    "    l: weight-decay coefficient; each step multiplies x[k] by (1 - l*eta).\n",
    "    seed: value passed to np.random.seed (the global NumPy RNG is reseeded).\n",
    "\n",
    "  Returns:\n",
    "    (x, f): x has shape (nit, d) and holds the iterates; f has shape (nit, 1)\n",
    "    and holds cost(x[k]) for each of them.\n",
    "  \"\"\"\n",
    "  # Take the dimension from x0 (as RMS_Trajs does) rather than from a notebook-level `d`.\n",
    "  d = np.shape(x0)[0]\n",
    "\n",
    "  np.random.seed(seed)\n",
    "\n",
    "  x = np.zeros((nit, d))\n",
    "  v = np.zeros((nit, d))\n",
    "  f = np.zeros((nit,))\n",
    "  x[0] = x0\n",
    "  f[0] = cost(x[0])\n",
    "  grad_0 = g(x[0])  # the initial average uses the noiseless gradient\n",
    "  v[0] = grad_0*grad_0\n",
    "  for k in range(nit-1):\n",
    "    noise = sigma*np.random.normal(size=(d,))\n",
    "    noisy_grad = g(x[k]) + noise  # evaluated once and shared by both updates\n",
    "    v[k+1] = beta*v[k]+(1-beta)*noisy_grad*noisy_grad\n",
    "    # The step is scaled by v[k], i.e. the average before this gradient is folded in.\n",
    "    x[k+1] = x[k]*(1-l*eta) - eta*noisy_grad/(np.sqrt(v[k])+eps)\n",
    "    f[k+1] = cost(x[k+1])\n",
    "\n",
    "  return (x, f.reshape(f.shape[-1],1))\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ySznf5idOUf8"
   },
   "outputs": [],
   "source": [
    "def Adam_Trajs(eta, x0, nit, sigma,beta_1,beta_2,eps,l, seeds):\n",
    "  \"\"\"Run Adam_Traj once per seed and stack the results.\n",
    "\n",
    "  Each run stores nit+1 iterates. Returns (x, f) with shapes\n",
    "  (len(seeds), nit+1, d) and (len(seeds), nit+1, 1), d being the length of x0.\n",
    "  \"\"\"\n",
    "  n_runs = len(seeds)\n",
    "  dim = np.shape(x0)[0]\n",
    "  x = np.zeros((n_runs, nit+1, dim))\n",
    "  f = np.zeros((n_runs, nit+1, 1))\n",
    "\n",
    "  for run, seed in enumerate(seeds):\n",
    "    print(seed)\n",
    "    traj, losses = Adam_Traj(eta, x0, nit+1, sigma, beta_1, beta_2, eps, l, seed)\n",
    "    x[run] = traj\n",
    "    f[run] = losses\n",
    "\n",
    "  return (x, f)\n",
    "\n",
    "\n",
    "def RMS_Trajs(eta, x0, nit, sigma,beta,eps,l, seeds):\n",
    "  \"\"\"Run RMS_Traj once per seed and stack the results.\n",
    "\n",
    "  Each run stores nit+1 iterates. Returns (x, f) with shapes\n",
    "  (len(seeds), nit+1, d) and (len(seeds), nit+1, 1), d being the length of x0.\n",
    "  \"\"\"\n",
    "  n_runs = len(seeds)\n",
    "  dim = np.shape(x0)[0]\n",
    "  x = np.zeros((n_runs, nit+1, dim))\n",
    "  f = np.zeros((n_runs, nit+1, 1))\n",
    "\n",
    "  for run, seed in enumerate(seeds):\n",
    "    print(seed)\n",
    "    traj, losses = RMS_Traj(eta, x0, nit+1, sigma, beta, eps, l, seed)\n",
    "    x[run] = traj\n",
    "    f[run] = losses\n",
    "\n",
    "  return (x, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "K121PaKMDbr6"
   },
   "outputs": [],
   "source": [
    "# Problem: two-dimensional quadratic with diagonal matrix H = diag(landa1, landa2).\n",
    "d = 2\n",
    "landa1 = 1.0\n",
    "landa2 = 3.0\n",
    "H = np.array([[landa1, 0.0],[0.0, landa2]])\n",
    "\n",
    "# Optimiser and noise settings shared by both simulations.\n",
    "eta = 0.001   # step size\n",
    "sigma = 1     # gradient-noise scale\n",
    "eps = 1e-7    # denominator offset in the adaptive updates\n",
    "l = 4         # weight-decay coefficient\n",
    "\n",
    "# Horizon T and the matching number of steps. Rounding instead of truncating keeps\n",
    "# nit correct when T/eta lands just below an integer in floating point.\n",
    "T=20\n",
    "nit= int(round(T/eta))\n",
    "\n",
    "# Averaging rates; each decay factor is 1 - eta*rate.\n",
    "mu = 1\n",
    "mu1 = 100\n",
    "mu2 = 1\n",
    "\n",
    "beta_1 = 1-eta*mu1\n",
    "beta_2 = 1-eta*mu2\n",
    "beta = 1-eta*mu\n",
    "\n",
    "np.random.seed(0)\n",
    "\n",
    "# Starting point used for every run.\n",
    "x0 = 1*np.array([1.0,1.0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Sc6ljKPp9wTM"
   },
   "outputs": [],
   "source": [
    "# Closed-form variance levels, one per diagonal entry of H; the final figure draws\n",
    "# them as the 'Theor. Pred.' reference lines.\n",
    "var1, var2 = ((eta*sigma)/(2*(lam + sigma*l)) for lam in (landa1, landa2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "PB0Pzkm0FuKV"
   },
   "outputs": [],
   "source": [
    "# One RNG seed per independent run (0, 1, ..., 999).\n",
    "seeds = list(np.arange(1000))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "xERLGH0VOp7C",
    "outputId": "56ca768e-4a11-4adc-8296-4653f580de7e",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# One RMSprop run per seed; the result is the (iterates, losses) pair from RMS_Trajs.\n",
    "RMSprop_Trajs_ = RMS_Trajs(\n",
    "    eta=eta, x0=x0, nit=nit, sigma=sigma,\n",
    "    beta=1-mu*eta, eps=eps, l=l, seeds=seeds,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# One Adam run per seed; the result is the (iterates, losses) pair from Adam_Trajs.\n",
    "Adam_Trajs_ = Adam_Trajs(\n",
    "    eta=eta, x0=x0, nit=nit, sigma=sigma,\n",
    "    beta_1=1-mu1*eta, beta_2=1-mu2*eta, eps=eps, l=l, seeds=seeds,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tQgz62A-O6s1"
   },
   "outputs": [],
   "source": [
    "# Across-seed (axis 0) mean and variance of the Adam iterates, and the mean loss, per step.\n",
    "avg_Adam_Traj = Adam_Trajs_[0].mean(axis=0)\n",
    "var_Adam_Traj = Adam_Trajs_[0].var(axis=0)\n",
    "avg_Adam_Loss = Adam_Trajs_[1].mean(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "v6vLO8II6_pF"
   },
   "outputs": [],
   "source": [
    "# Across-seed (axis 0) mean and variance of the RMSprop iterates, and the mean loss, per step.\n",
    "avg_RMSprop_Traj = RMSprop_Trajs_[0].mean(axis=0)\n",
    "var_RMSprop_Traj = RMSprop_Trajs_[0].var(axis=0)\n",
    "avg_RMSprop_Loss = RMSprop_Trajs_[1].mean(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two side-by-side panels, one per coordinate: empirical variance across seeds for\n",
    "# both optimisers against the constant theoretical level.\n",
    "fig, axs = plt.subplots(1, 2)  # 1 row, 2 columns\n",
    "\n",
    "# Window sizes: the slices below take the last L1 / L2 iterates, excluding the final one.\n",
    "L2 = 20000\n",
    "L1 = 20000\n",
    "\n",
    "# (coordinate index, theoretical variance, y-axis label) for each panel.\n",
    "panels = [\n",
    "    (0, var1, r'$Var\\left[X_1\\right]$'),\n",
    "    (1, var2, r'$Var\\left[X_2\\right]$'),\n",
    "]\n",
    "\n",
    "for coord, theory_var, y_label in panels:\n",
    "    ax = axs[coord]\n",
    "    line_1, = ax.plot(var_Adam_Traj[-L1:-1,coord], color = colors[0],linewidth=3, dashes=custom_dashes[0])\n",
    "    line_2, = ax.plot(var_RMSprop_Traj[-L2:-1,coord], color = colors[1],linewidth=3, dashes=custom_dashes[0])\n",
    "    line_3, = ax.plot(theory_var*np.ones((nit,)), color = colors[2],linewidth=3, dashes=custom_dashes[3])\n",
    "    ax.legend([line_1, line_2,line_3], ['AdamW', 'RMSpropW', 'Theor. Pred.'], fontsize=13)\n",
    "    ax.set_ylabel(y_label, fontsize=25)\n",
    "    ax.set_xlabel(r'$t$', fontsize=25)\n",
    "    ax.set_yscale('log')\n",
    "    ax.tick_params(axis='both', which='major', labelsize=16)\n",
    "\n",
    "# Adjust layout and display the plot\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
