{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7GVWcjh65iGv"
   },
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "from scipy import special\n",
    "from scipy.linalg import sqrtm\n",
    "\n",
    "np.random.seed(666)\n",
    "%config InlineBackend.figure_formats = ['svg']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global figure styling, plus seaborn's 'colorblind' palette\n",
    "# (the palette is not currently used by the plots below).\n",
    "plt.rcParams.update({'font.size': 14, 'figure.figsize': (7, 5)})\n",
    "colors = sns.color_palette(\"colorblind\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dash sequences (alternating on/off lengths, as taken by matplotlib's `dashes`),\n",
    "# ordered from plain to elaborate. The names only document each pattern; the\n",
    "# resulting list holds just the tuples. Not referenced elsewhere in this notebook.\n",
    "custom_dashes = list({\n",
    "    'solid': (1, 0),\n",
    "    'dotted': (1, 1),\n",
    "    'dashed': (4, 2),\n",
    "    'dash_dot': (3, 2, 1, 2),\n",
    "    'dash_dot_dot': (5, 2, 1, 2, 1, 2),\n",
    "    'long_dashes': (8, 2),\n",
    "    'loosely_dashed': (2, 4),\n",
    "    'sparse_dash_dot': (5, 4, 1, 4),\n",
    "    'sparse_dash_dot_dot': (6, 4, 1, 4, 1, 4),\n",
    "    'loosely_dotted': (1, 4),\n",
    "    'dashed_with_dots': (4, 2, 1, 2, 1, 2),\n",
    "    'custom_1': (2, 4, 1, 4, 2, 4),\n",
    "}.values())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "id": "9iedPQJn5jvq",
    "outputId": "04d295ae-704f-48a8-e032-cb74052beee2"
   },
   "outputs": [],
   "source": [
    "# Objective and gradient used by the simulations below; both read the\n",
    "# module-level matrix H defined in the configuration cell.\n",
    "def cost(x):\n",
    "  \"\"\"Quadratic loss 0.5 * x^T H x.\"\"\"\n",
    "  return 0.5*x.T@H@x\n",
    "\n",
    "\n",
    "def g(x):\n",
    "  \"\"\"Gradient H x of `cost` (exact when H is symmetric).\"\"\"\n",
    "  return H@x\n",
    "\n",
    "\n",
    "# Alternative quartic-regularised objective (uses the module-level `landa`).\n",
    "# Kept as comments: a bare string literal here would be echoed as the cell's\n",
    "# output. Uncomment to replace the quadratic pair above:\n",
    "# def cost(x): return 0.5*x.T@H@x + 0.25*landa*(np.sum(x**4))\n",
    "# def g(x): return H@x + landa*x**3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "oVZOW6-UeAee"
   },
   "outputs": [],
   "source": [
    "def SDE_Traj(eta, x0, nit, sigma, seed):\n",
    "  \"\"\"Simulate the SDE model of sign-SGD with Gaussian gradient noise.\n",
    "\n",
    "  One step (time step eta), with z ~ N(0, I_d) and m = erf(g(x) / (sqrt(2) * sigma)):\n",
    "    x_new = x - eta * m + eta * sqrt(1 - m**2) * z\n",
    "  m is the mean of sign(g(x) + sigma * z) and 1 - m**2 its per-coordinate variance.\n",
    "  Reads the module-level `d`, `cost` and `g`; re-seeds the global NumPy RNG with `seed`.\n",
    "\n",
    "  Returns (x_sde, f_sde): positions of shape (nit, d) and losses of shape (nit, 1).\n",
    "  \"\"\"\n",
    "  np.random.seed(seed)\n",
    "\n",
    "  x_sde = np.zeros((nit, d))\n",
    "  f_sde = np.zeros((nit,))\n",
    "  x_sde[0] = x0\n",
    "  f_sde[0] = cost(x_sde[0])\n",
    "  for k in range(nit-1):\n",
    "    noise = np.random.normal(size=(d,))\n",
    "    ratio = g(x_sde[k])/(np.sqrt(2)*sigma)\n",
    "    drift = special.erf(ratio)\n",
    "    # The diffusion matrix I - diag(drift**2) is diagonal, so its matrix square\n",
    "    # root is the elementwise sqrt: O(d) instead of a general O(d^3) sqrtm, and\n",
    "    # always real because |erf| <= 1.\n",
    "    diffusion = np.sqrt(1 - drift**2)\n",
    "    x_sde[k+1] = x_sde[k] - eta*drift + eta*diffusion*noise\n",
    "    f_sde[k+1] = cost(x_sde[k+1])\n",
    "\n",
    "  return (x_sde, f_sde.reshape(-1, 1))\n",
    "\n",
    "\n",
    "def SDE_Traj_2(eta, x0, nit, sigma, seed):\n",
    "  \"\"\"Like SDE_Traj, but with erf replaced by its first-order expansion at 0.\n",
    "\n",
    "  erf(z) ~ (2 / sqrt(pi)) * z with z = g(x) / (sqrt(2) * sigma), i.e. the drift\n",
    "  sqrt(2/pi) * g(x) / sigma -- the same rate used in SDE_TheoPred and\n",
    "  SDE_TheoPred_Quad. Reads the module-level `d`, `cost` and `g`; re-seeds the\n",
    "  global NumPy RNG with `seed`.\n",
    "\n",
    "  Returns (x_sde, f_sde): positions of shape (nit, d) and losses of shape (nit, 1).\n",
    "  \"\"\"\n",
    "  np.random.seed(seed)\n",
    "\n",
    "  x_sde = np.zeros((nit, d))\n",
    "  f_sde = np.zeros((nit,))\n",
    "  x_sde[0] = x0\n",
    "  f_sde[0] = cost(x_sde[0])\n",
    "  for k in range(nit-1):\n",
    "    noise = np.random.normal(size=(d,))\n",
    "    ratio = g(x_sde[k])/(np.sqrt(2)*sigma)\n",
    "    # linearised erf(ratio); equals sqrt(2/pi)*g/sigma (not sqrt(2/pi)*ratio)\n",
    "    drift = (2/np.sqrt(math.pi))*ratio\n",
    "    # per-coordinate variance 1 - drift**2; the linearised drift is unbounded,\n",
    "    # so clip at 0 rather than taking the square root of a negative number\n",
    "    diffusion = np.sqrt(np.clip(1 - drift**2, 0.0, None))\n",
    "    x_sde[k+1] = x_sde[k] - eta*drift + eta*diffusion*noise\n",
    "    f_sde[k+1] = cost(x_sde[k+1])\n",
    "\n",
    "  return (x_sde, f_sde.reshape(-1, 1))\n",
    "\n",
    "\n",
    "def Algo_Traj(eta, x0, nit, sigma, seed):\n",
    "  \"\"\"Run sign-SGD: x_new = x - eta * sign(g(x) + sigma * z), z ~ N(0, I_d).\n",
    "\n",
    "  Reads the module-level `d`, `cost` and `g`, and re-seeds the global NumPy RNG\n",
    "  with `seed`. Returns iterates of shape (nit, d) and losses of shape (nit, 1).\n",
    "  \"\"\"\n",
    "  np.random.seed(seed)\n",
    "\n",
    "  xs = np.zeros((nit, d))\n",
    "  losses = np.zeros(nit)\n",
    "  xs[0] = x0\n",
    "  losses[0] = cost(xs[0])\n",
    "  for step in range(1, nit):\n",
    "    prev = xs[step-1]\n",
    "    noisy_grad = g(prev) + sigma*np.random.normal(size=(d,))\n",
    "    xs[step] = prev - eta*np.sign(noisy_grad)\n",
    "    losses[step] = cost(xs[step])\n",
    "\n",
    "  return (xs, losses.reshape(-1, 1))\n",
    "\n",
    "def SGD_Traj(eta, x0, nit, sigma, seed):\n",
    "  \"\"\"Run plain SGD: x_new = x - eta * (g(x) + sigma * z), z ~ N(0, I_d).\n",
    "\n",
    "  Reads the module-level `d`, `cost` and `g`, and re-seeds the global NumPy RNG\n",
    "  with `seed`. Returns iterates of shape (nit, d) and losses of shape (nit, 1).\n",
    "  \"\"\"\n",
    "  np.random.seed(seed)\n",
    "\n",
    "  xs = np.zeros((nit, d))\n",
    "  losses = np.zeros(nit)\n",
    "  xs[0] = x0\n",
    "  losses[0] = cost(xs[0])\n",
    "  for step in range(1, nit):\n",
    "    prev = xs[step-1]\n",
    "    noisy_grad = g(prev) + sigma*np.random.normal(size=(d,))\n",
    "    xs[step] = prev - eta*noisy_grad\n",
    "    losses[step] = cost(xs[step])\n",
    "\n",
    "  return (xs, losses.reshape(-1, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "AieQmYGYZwIs"
   },
   "outputs": [],
   "source": [
    "def ODE_Traj(eta, x0, nit):\n",
    "  \"\"\"Deterministic sign-GD: x_new = x - eta * sign(g(x)), starting from x0.\n",
    "\n",
    "  Reads the module-level `d`, `cost` and `g`. Returns iterates of shape\n",
    "  (nit, d) and losses of shape (nit, 1).\n",
    "  \"\"\"\n",
    "  xs = np.zeros((nit, d))\n",
    "  losses = np.zeros(nit)\n",
    "  xs[0] = x0\n",
    "  losses[0] = cost(xs[0])\n",
    "  for step in range(1, nit):\n",
    "    xs[step] = xs[step-1] - eta*np.sign(g(xs[step-1]))\n",
    "    losses[step] = cost(xs[step])\n",
    "\n",
    "  return (xs, losses.reshape(-1, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3dZGFZRonr1z"
   },
   "outputs": [],
   "source": [
    "def _stack_trajs(traj_fn, eta, x0, nit, sigma, seeds, verbose=True):\n",
    "  \"\"\"Run `traj_fn` once per seed and stack the results.\n",
    "\n",
    "  `traj_fn(eta, x0, nit+1, sigma, seed)` must return (positions, losses) of\n",
    "  shapes (nit+1, d) and (nit+1, 1), where d = len(x0).\n",
    "  Returns arrays of shape (len(seeds), nit+1, d) and (len(seeds), nit+1, 1).\n",
    "  If `verbose`, prints each seed as a progress indicator.\n",
    "  \"\"\"\n",
    "  dim = np.shape(x0)[0]\n",
    "  xs = np.zeros((len(seeds), nit+1, dim))\n",
    "  fs = np.zeros((len(seeds), nit+1, 1))\n",
    "\n",
    "  for i, seed in enumerate(seeds):\n",
    "    if verbose:\n",
    "      print(seed)\n",
    "    xs[i], fs[i] = traj_fn(eta, x0, nit+1, sigma, seed)\n",
    "\n",
    "  return (xs, fs)\n",
    "\n",
    "\n",
    "def SDE_Trajs(eta, x0, nit, sigma, seeds, verbose=True):\n",
    "  \"\"\"SDE_Traj for every seed; stacked as described in _stack_trajs.\"\"\"\n",
    "  return _stack_trajs(SDE_Traj, eta, x0, nit, sigma, seeds, verbose)\n",
    "\n",
    "\n",
    "def SDE_Trajs_2(eta, x0, nit, sigma, seeds, verbose=True):\n",
    "  \"\"\"SDE_Traj_2 for every seed; stacked as described in _stack_trajs.\"\"\"\n",
    "  return _stack_trajs(SDE_Traj_2, eta, x0, nit, sigma, seeds, verbose)\n",
    "\n",
    "\n",
    "def Algo_Trajs(eta, x0, nit, sigma, seeds, verbose=True):\n",
    "  \"\"\"Algo_Traj for every seed; stacked as described in _stack_trajs.\"\"\"\n",
    "  return _stack_trajs(Algo_Traj, eta, x0, nit, sigma, seeds, verbose)\n",
    "\n",
    "\n",
    "def SGD_Trajs(eta, x0, nit, sigma, seeds, verbose=True):\n",
    "  \"\"\"SGD_Traj for every seed; stacked as described in _stack_trajs.\"\"\"\n",
    "  return _stack_trajs(SGD_Traj, eta, x0, nit, sigma, seeds, verbose)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-aIp9cv7qind"
   },
   "outputs": [],
   "source": [
    "def SDE_TheoPred(nit, mu, sigma, eta, f0, Lt):\n",
    "  \"\"\"Predicted expected loss under the SDE model on the grid t = eta * [0, nit).\n",
    "\n",
    "  f(t) = f0 * exp(-k t) + (Lt * eta / (2 k)) * (1 - exp(-k t)),\n",
    "  k = 2 mu (sqrt(2/pi) / sigma + (eta / pi) * mu / sigma**2).\n",
    "  Returns an array of length nit (broadcast against f0).\n",
    "  \"\"\"\n",
    "  rate = 2*mu*(np.sqrt(2/math.pi)*(1/sigma) + (eta/math.pi)*(mu/sigma**2))\n",
    "  decay = np.exp(-rate*(eta*np.arange(nit)))\n",
    "  plateau = (Lt*eta)/(2*rate)\n",
    "  return f0*decay + plateau*(1-decay)\n",
    "\n",
    "\n",
    "def SDE_TheoPred_Quad(nit, mu, sigma, eta, f0):\n",
    "  \"\"\"Variant of SDE_TheoPred without Lt, on the grid t = eta * [0, nit).\n",
    "\n",
    "  With c = sqrt(2/pi)/sigma + eta*mu/(pi*sigma**2) and k = 2*mu*c:\n",
    "  f(t) = f0 * exp(-k t) + 0.25 * (eta / c) * (1 - exp(-k t)).\n",
    "  \"\"\"\n",
    "  denom = np.sqrt(2/math.pi)*(1/sigma) + (eta*mu)/(math.pi*sigma**2)\n",
    "  k = 2*mu*denom\n",
    "  decay = np.exp(-k*(eta*np.arange(nit)))\n",
    "  return f0*decay + 0.25*(eta/denom)*(1-decay)\n",
    "\n",
    "def ODE_TheoPred(nit, mu, f0, eta=None):\n",
    "  \"\"\"Loss along the sign-GD ODE for f = 0.5 * mu * x**2 (d = 1), t = eta * [0, nit).\n",
    "\n",
    "  |x| shrinks at unit speed, so f(t) = 0.25 * (2 sqrt(f0) - sqrt(2 mu) t)**2 until\n",
    "  the minimiser is reached at t = sqrt(2 f0 / mu); after that the loss stays 0\n",
    "  (the unclamped parabola would wrongly rise again).\n",
    "\n",
    "  eta: time step; if None, falls back to the notebook-level `eta` (the previous,\n",
    "  implicit behaviour).\n",
    "  \"\"\"\n",
    "  if eta is None:\n",
    "    eta = globals()['eta']\n",
    "  t = np.arange(0, nit)*eta\n",
    "  gap = np.maximum(2*np.sqrt(f0) - np.sqrt(2*mu)*t, 0.0)\n",
    "  return 0.25*gap**2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "_FKnz0lfnykB"
   },
   "outputs": [],
   "source": [
    "# Problem: quadratic cost 0.5 * x^T H x in d dimensions\n",
    "d = 1\n",
    "H = np.array([[2.0]])\n",
    "sigma = 0.1   # std of the Gaussian noise added to the gradient\n",
    "eps = 1e-8    # not referenced elsewhere in this notebook\n",
    "landa = 1.0   # quartic coefficient, only used by the alternative (disabled) cost\n",
    "\n",
    "mu = 2.0      # curvature used by the theoretical predictions; here H = [[mu]]\n",
    "\n",
    "eta = 0.001   # step size, also the time step of the SDE/ODE predictions\n",
    "T   = 3       # time horizon\n",
    "# round() rather than int(): int() truncates, so e.g. 0.3/0.1 = 2.999... gives 2\n",
    "nit = round(T/eta)\n",
    "\n",
    "x0 = 1.0*np.random.normal(size=(d,))\n",
    "f0 = cost(x0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9xwBwPUZnymn"
   },
   "outputs": [],
   "source": [
    "# one independent run per seed (results are averaged over runs later)\n",
    "seeds = list(np.arange(500))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "0ZpSj88ZqScn",
    "outputId": "ca6aa6c1-5326-490a-8086-b0ca93abcb12"
   },
   "outputs": [],
   "source": [
    "# Simulate the SDE model once per seed\n",
    "SDE_Trajs_ = SDE_Trajs(eta=eta, x0=x0, nit=nit, sigma=sigma, seeds=seeds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "hDjLCXFM-RRq",
    "outputId": "14bf1578-23ae-4d8c-cb9f-1e6ac04abea4"
   },
   "outputs": [],
   "source": [
    "# Run the actual sign-SGD algorithm once per seed\n",
    "Algo_Trajs_ = Algo_Trajs(eta=eta, x0=x0, nit=nit, sigma=sigma, seeds=seeds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "k4r_2DRaoliL"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vxZxgNfJ_GjE"
   },
   "outputs": [],
   "source": [
    "# Mean and spread across seeds (axis 0) of the SDE positions and losses\n",
    "avg_SDE_Traj, avg_SDE_Loss = [arr.mean(axis=0) for arr in SDE_Trajs_]\n",
    "std_SDE_Traj, std_SDE_Loss = [arr.std(axis=0) for arr in SDE_Trajs_]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vd_fT6iJ67PL"
   },
   "outputs": [],
   "source": [
    "# Mean and spread across seeds (axis 0) of the sign-SGD positions and losses\n",
    "avg_Algo_Traj, avg_Algo_Loss = [arr.mean(axis=0) for arr in Algo_Trajs_]\n",
    "std_Algo_Traj, std_Algo_Loss = [arr.std(axis=0) for arr in Algo_Trajs_]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "r2UPOPrqpUaQ"
   },
   "outputs": [],
   "source": [
    "# Iteration at which the SDE prediction is anchored to the empirical mean loss.\n",
    "# NOTE(review): 780 looks hand-picked for this x0 -- TODO derive it from x0 and eta.\n",
    "L = 780\n",
    "\n",
    "SDE_TheoPred_ = SDE_TheoPred_Quad(nit + 1, mu, sigma, eta, f0=avg_Algo_Loss[L])\n",
    "ODE_TheoPred_ = ODE_TheoPred(nit, mu, f0=f0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 471
    },
    "id": "GQzbZMZM5vh8",
    "outputId": "60a73a64-4092-4abe-a8ed-090c6ef2237a"
   },
   "outputs": [],
   "source": [
    "start = 0\n",
    "end = -1\n",
    "\n",
    "# x-axis: iteration index; the loss arrays hold nit+1 entries, [start:end] drops the last\n",
    "x = np.arange(nit)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "line_1, = ax.plot(x, avg_Algo_Loss[start:end], color='red', linewidth=4)\n",
    "line_2, = ax.plot(x, avg_SDE_Loss[start:end], color='blue', linewidth=2)\n",
    "# the SDE prediction starts (t = 0) at iteration L, where it was anchored\n",
    "line_3, = ax.plot(x[L:end], SDE_TheoPred_[start:nit-L-1], color='green', linewidth=2)\n",
    "line_4, = ax.plot(x[start:L], ODE_TheoPred_[start:L], color='pink', linewidth=2)\n",
    "\n",
    "ax.legend([line_1, line_2, line_3, line_4], ['Real', 'SDE Ours', 'Theor Prediction SDE', 'Theor Prediction ODE'])\n",
    "ax.set_title('Losses', fontsize=15)\n",
    "ax.set_xlabel('Iterations', fontsize=15)\n",
    "ax.set_ylabel('Loss', fontsize=15)\n",
    "# switch to log scale first, then size the tick labels through tick_params so\n",
    "# the setting also covers the ticks installed by the log scale\n",
    "ax.set_yscale('log')\n",
    "ax.tick_params(axis='both', labelsize=15)\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
