{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7GVWcjh65iGv"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%config InlineBackend.figure_formats = ['svg']\n",
    "from scipy.linalg import sqrtm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "b8aBVIt_8qHf"
   },
   "outputs": [],
   "source": [
    "plt.rcParams.update({'font.size': 14})\n",
    "plt.rcParams[\"figure.figsize\"] = (7,5)\n",
    "import seaborn as sns\n",
    "colors = sns.color_palette(\"colorblind\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8QGBVJirAqOW"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "wLEakSJI8qOy"
   },
   "outputs": [],
   "source": [
    "# Define custom dash patterns for each line without names\n",
    "custom_dashes = [\n",
    "    (1, 0),           # Solid\n",
    "    (1, 1),           # Dotted\n",
    "    (4, 2),           # Dashed\n",
    "    (3, 2, 1, 2),     # Dash-Dot\n",
    "    (5, 2, 1, 2, 1, 2),  # Dash-Dot-Dot\n",
    "    (8, 2),           # Long Dashes\n",
    "    (2, 4),           # Loosely Dashed\n",
    "    (5, 4, 1, 4),     # Sparse Dash-Dot\n",
    "    (6, 4, 1, 4, 1, 4),  # Sparse Dash-Dot-Dot\n",
    "    (1, 4),           # Loosely Dotted\n",
    "    (4, 2, 1, 2, 1, 2),  # Dashed with Dots\n",
    "    (2, 4, 1, 4, 2, 4)   # Custom Pattern 1\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9iedPQJn5jvq"
   },
   "outputs": [],
   "source": [
    "def cost(x): return 0.5*x.T@H@x + 0.25*landa*(np.sum(x**4)) - c*(1/3)*landa*(np.sum(x**3))\n",
    "def g(x): return H@x + landa*x**3 - c*landa*x**2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "oVZOW6-UeAee"
   },
   "outputs": [],
   "source": [
    "def Adam_Traj(eta, x0, nit, sigma,beta_1,beta_2,eps, seed):\n",
    "\n",
    "  np.random.seed(seed)\n",
    "  noise = sigma*np.random.normal(size=(d,))\n",
    "  np.random.seed(seed)\n",
    "\n",
    "  x = np.zeros((nit, d))\n",
    "  v = np.zeros((nit, d))\n",
    "  m = np.zeros((nit, d))\n",
    "  f = np.zeros((nit,))\n",
    "  x[0] = x0\n",
    "  f[0] = cost(x[0])\n",
    "  v[0] = g(x[0]+noise)*g(x[0]+noise)\n",
    "  m[0] = g(x[0]+noise)\n",
    "  for k in range(nit-1):\n",
    "\n",
    "    gamma_1 = 1/(1 - beta_1**(k+1))\n",
    "    gamma_2 = 1/(1 - beta_2**(k+1))\n",
    "\n",
    "    noise = sigma*np.random.normal(size=(d,))\n",
    "    m[k+1] = beta_1*m[k]+(1-beta_1)*(g(x[k])+noise)\n",
    "    v[k+1] = beta_2*v[k]+(1-beta_2)*(g(x[k])+noise)*(g(x[k])+noise)\n",
    "    x[k+1] = x[k]*(1-eta*reg) - eta*(gamma_1*m[k+1])/(np.sqrt(gamma_2*v[k+1])+eps)\n",
    "    f[k+1] = cost(x[k+1])\n",
    "\n",
    "  return (x, f.reshape(f.shape[-1],1))\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ySznf5idOUf8"
   },
   "outputs": [],
   "source": [
    "def Adam_Trajs(eta, x0, nit, sigma,beta_1,beta_2,eps, seeds):\n",
    "\n",
    "  d = np.shape(x0)[0]\n",
    "  x = np.zeros((len(seeds), nit+1,d))\n",
    "  f = np.zeros((len(seeds), nit+1,1))\n",
    "\n",
    "  i=0\n",
    "  for seed in seeds:\n",
    "    print(seed)\n",
    "    x[i,:,:], f[i,:,:]  = Adam_Traj(eta, x0, nit+1, sigma,beta_1,beta_2,eps, seed)\n",
    "    i += 1\n",
    "\n",
    "  return (x, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "K121PaKMDbr6"
   },
   "outputs": [],
   "source": [
    "d = 2\n",
    "H = np.array([[-1.0, 0.0],[0.0, 2.0]])\n",
    "eta = 0.001\n",
    "sigma = 0.1\n",
    "eps = 1e-7\n",
    "\n",
    "\n",
    "T=50\n",
    "nit= int(T/eta)\n",
    "\n",
    "reg = 0.1\n",
    "\n",
    "landa = 1.0\n",
    "c = 1\n",
    "\n",
    "mu11 = 1\n",
    "mu12 = 100\n",
    "mu21 = 2\n",
    "mu22 = 4\n",
    "mu23 = 8\n",
    "\n",
    "np.random.seed(0)\n",
    "\n",
    "x0 = -5*np.random.normal(size=(d,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "PB0Pzkm0FuKV"
   },
   "outputs": [],
   "source": [
    "seeds = [i for i in np.arange(5)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "xERLGH0VOp7C",
    "outputId": "182c06d2-ffb8-444e-9e97-e0c2883d58a1"
   },
   "outputs": [],
   "source": [
    "Adam_Trajs_11       = Adam_Trajs   (eta                , x0, nit, sigma               , 1-mu11*eta, 1-mu21*eta,  eps         , seeds)\n",
    "Adam_Trajs_12       = Adam_Trajs   (eta                , x0, nit, sigma               , 1-mu11*eta, 1-mu22*eta,  eps         , seeds)\n",
    "Adam_Trajs_13       = Adam_Trajs   (eta                , x0, nit, sigma               , 1-mu11*eta, 1-mu23*eta,  eps         , seeds)\n",
    "Adam_Trajs_21       = Adam_Trajs   (eta                , x0, nit, sigma               , 1-mu12*eta, 1-mu21*eta,  eps         , seeds)\n",
    "Adam_Trajs_22       = Adam_Trajs   (eta                , x0, nit, sigma               , 1-mu12*eta, 1-mu22*eta,  eps         , seeds)\n",
    "Adam_Trajs_23       = Adam_Trajs   (eta                , x0, nit, sigma               , 1-mu12*eta, 1-mu23*eta,  eps         , seeds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tQgz62A-O6s1"
   },
   "outputs": [],
   "source": [
    "avg_Adam_Traj_11 = np.mean(Adam_Trajs_11[0],axis=0)\n",
    "avg_Adam_Loss_11 = np.mean(Adam_Trajs_11[1],axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "69jFjN_1eSix"
   },
   "outputs": [],
   "source": [
    "avg_Adam_Traj_12 = np.mean(Adam_Trajs_12[0],axis=0)\n",
    "avg_Adam_Loss_12 = np.mean(Adam_Trajs_12[1],axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "aPAc7JVDeSlJ"
   },
   "outputs": [],
   "source": [
    "avg_Adam_Traj_13 = np.mean(Adam_Trajs_13[0],axis=0)\n",
    "avg_Adam_Loss_13 = np.mean(Adam_Trajs_13[1],axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "J_wmv0RAeSnX"
   },
   "outputs": [],
   "source": [
    "avg_Adam_Traj_21 = np.mean(Adam_Trajs_21[0],axis=0)\n",
    "avg_Adam_Loss_21 = np.mean(Adam_Trajs_21[1],axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "4G_-MVZgeSps"
   },
   "outputs": [],
   "source": [
    "avg_Adam_Traj_22 = np.mean(Adam_Trajs_22[0],axis=0)\n",
    "avg_Adam_Loss_22 = np.mean(Adam_Trajs_22[1],axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7Y2ojXMveSsI"
   },
   "outputs": [],
   "source": [
    "avg_Adam_Traj_23 = np.mean(Adam_Trajs_23[0],axis=0)\n",
    "avg_Adam_Loss_23 = np.mean(Adam_Trajs_23[1],axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 486
    },
    "id": "GQzbZMZM5vh8",
    "outputId": "4c5376ce-8859-4f4a-87a3-f7b632e7703c"
   },
   "outputs": [],
   "source": [
    "start = 0\n",
    "end = -1\n",
    "\n",
    "\n",
    "min = np.min(avg_Adam_Loss_11) - 1e-5\n",
    "\n",
    "plt.figure()\n",
    "# We plot\n",
    "line_1, = plt.plot(avg_Adam_Loss_11[start:end]- min, color = colors[0],linewidth=2)\n",
    "line_2, = plt.plot(avg_Adam_Loss_12[start:end]- min, color = colors[1],linewidth=2)\n",
    "line_3, = plt.plot(avg_Adam_Loss_13[start:end]- min, color = colors[2],linewidth=2)\n",
    "line_4, = plt.plot(avg_Adam_Loss_21[start:end]- min, color = colors[3],linewidth=2)\n",
    "line_5, = plt.plot(avg_Adam_Loss_22[start:end]- min, color = colors[4],linewidth=2)\n",
    "line_6, = plt.plot(avg_Adam_Loss_23[start:end]- min, color = colors[5],linewidth=2)\n",
    "\n",
    "lab11 = r'$\\beta_1 = $' + str(1-mu11*eta) + r', $\\beta_2 = $' + str(1-mu21*eta) \n",
    "lab12 = r'$\\beta_1 = $' + str(1-mu11*eta) + r', $\\beta_2 = $' + str(1-mu22*eta) \n",
    "lab13 = r'$\\beta_1 = $' + str(1-mu11*eta) + r', $\\beta_2 = $' + str(1-mu23*eta) \n",
    "lab21 = r'$\\beta_1 = $' + str(1-mu12*eta) + r', $\\beta_2 = $' + str(1-mu21*eta) \n",
    "lab22 = r'$\\beta_1 = $' + str(1-mu12*eta) + r', $\\beta_2 = $' + str(1-mu22*eta) \n",
    "lab23 = r'$\\beta_1 = $' + str(1-mu12*eta) + r', $\\beta_2 = $' + str(1-mu23*eta) \n",
    "\n",
    "\n",
    "plt.legend([line_1, line_2, line_3, line_4, line_5, line_6,], [lab11,lab12,lab13,lab21,lab22,lab23],fontsize=14)\n",
    "plt.title(\"Losses - AdamW\",fontsize=25)\n",
    "plt.xlabel('Iterations',fontsize=25)\n",
    "plt.ylabel('Loss',fontsize=25)\n",
    "plt.xticks(fontsize=16)\n",
    "plt.yticks(fontsize=16)\n",
    "plt.yscale('log')\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
