{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7GVWcjh65iGv"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%config InlineBackend.figure_formats = ['svg']\n",
    "from scipy.linalg import sqrtm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "b8aBVIt_8qHf"
   },
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "\n",
    "# Global plot defaults and a colorblind-safe palette shared by the plotting cells.\n",
    "plt.rcParams.update({'font.size': 14, 'figure.figsize': (7, 5)})\n",
    "colors = sns.color_palette(\"colorblind\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "wLEakSJI8qOy"
   },
   "outputs": [],
   "source": [
    "# Define custom dash patterns for each line without names\n",
    "custom_dashes = [\n",
    "    (1, 0),           # Solid\n",
    "    (1, 1),           # Dotted\n",
    "    (4, 2),           # Dashed\n",
    "    (3, 2, 1, 2),     # Dash-Dot\n",
    "    (5, 2, 1, 2, 1, 2),  # Dash-Dot-Dot\n",
    "    (8, 2),           # Long Dashes\n",
    "    (2, 4),           # Loosely Dashed\n",
    "    (5, 4, 1, 4),     # Sparse Dash-Dot\n",
    "    (6, 4, 1, 4, 1, 4),  # Sparse Dash-Dot-Dot\n",
    "    (1, 4),           # Loosely Dotted\n",
    "    (4, 2, 1, 2, 1, 2),  # Dashed with Dots\n",
    "    (2, 4, 1, 4, 2, 4)   # Custom Pattern 1\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9iedPQJn5jvq"
   },
   "outputs": [],
   "source": [
    "def cost(x):\n",
    "    \"\"\"Quadratic objective 0.5 * x^T H x, using the global matrix H.\"\"\"\n",
    "    return 0.5 * x.T @ H @ x\n",
    "\n",
    "\n",
    "def g(x):\n",
    "    \"\"\"Gradient of `cost` (valid for symmetric H): H @ x.\"\"\"\n",
    "    return H @ x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "oVZOW6-UeAee"
   },
   "outputs": [],
   "source": [
    "def SignSGD_Traj(eta, x0, nit, sigma, seed):\n",
    "  \"\"\"Simulate one SignSGD trajectory on the quadratic `cost`.\n",
    "\n",
    "  Update rule: x[k+1] = x[k] - eta * sign(g(x[k]) + sigma * z_k), z_k ~ N(0, I).\n",
    "\n",
    "  Args:\n",
    "    eta: step size.\n",
    "    x0: initial point (1-D, length d).\n",
    "    nit: number of stored iterates, including x0 (so nit - 1 updates).\n",
    "    sigma: standard deviation of the additive Gaussian gradient noise.\n",
    "    seed: seed for this trajectory's noise; equal seeds give identical paths.\n",
    "\n",
    "  Returns:\n",
    "    (x, f): iterates of shape (nit, d) and losses of shape (nit, 1).\n",
    "  \"\"\"\n",
    "  x0 = np.asarray(x0, dtype=float)\n",
    "  dim = x0.shape[0]\n",
    "  # A private generator seeded per trajectory makes each run reproducible from\n",
    "  # `seed` alone, independent of the global NumPy RNG state.\n",
    "  rng = np.random.default_rng(seed)\n",
    "  # Draw all noise up front (one row per update) instead of one call per step.\n",
    "  noise = sigma * rng.normal(size=(nit - 1, dim))\n",
    "\n",
    "  x = np.zeros((nit, dim))\n",
    "  f = np.zeros((nit,))\n",
    "  x[0] = x0\n",
    "  f[0] = cost(x[0])\n",
    "  for k in range(nit - 1):\n",
    "    x[k+1] = x[k] - eta * np.sign(g(x[k]) + noise[k])\n",
    "    f[k+1] = cost(x[k+1])\n",
    "\n",
    "  return (x, f.reshape(-1, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ySznf5idOUf8"
   },
   "outputs": [],
   "source": [
    "def SignSGD_Trajs(eta, x0, nit, sigma, seeds, verbose=True):\n",
    "  \"\"\"Run one SignSGD trajectory per seed and stack the results.\n",
    "\n",
    "  Args:\n",
    "    eta, x0, sigma: forwarded unchanged to SignSGD_Traj.\n",
    "    nit: number of updates per run; each run stores nit + 1 points (x0 included).\n",
    "    seeds: sized sequence of seeds, one trajectory each.\n",
    "    verbose: if True (default), print each seed as its run starts.\n",
    "\n",
    "  Returns:\n",
    "    (x, f): arrays of shape (len(seeds), nit + 1, d) and (len(seeds), nit + 1, 1),\n",
    "    where d = len(x0).\n",
    "  \"\"\"\n",
    "  d = np.shape(x0)[0]\n",
    "  x = np.zeros((len(seeds), nit + 1, d))\n",
    "  f = np.zeros((len(seeds), nit + 1, 1))\n",
    "\n",
    "  for i, seed in enumerate(seeds):\n",
    "    if verbose:\n",
    "      print(seed)\n",
    "    x[i], f[i] = SignSGD_Traj(eta, x0, nit + 1, sigma, seed)\n",
    "\n",
    "  return (x, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "K121PaKMDbr6"
   },
   "outputs": [],
   "source": [
    "d = 1\n",
    "landa = 1\n",
    "H = np.array([[landa]])\n",
    "eta = 0.001\n",
    "sigma = 1\n",
    "eps = 1e-7\n",
    "\n",
    "\n",
    "\n",
    "T = 100\n",
    "nit= int(T/eta)\n",
    "\n",
    "\n",
    "sigma1=1e-2\n",
    "sigma2=1e-1\n",
    "sigma3=1e0\n",
    "sigma4=1e1\n",
    "sigma5=1e3\n",
    "\n",
    "\n",
    "\n",
    "np.random.seed(0)\n",
    "\n",
    "x0 = 1*np.random.normal(size=(d,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def delta(sigma):\n",
    "    return np.sqrt(2/np.pi)*(1/sigma) + eta*landa/(np.pi*sigma**2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "4l7r9_Bk-0Di"
   },
   "outputs": [],
   "source": [
    "# Constant reference curves 0.25 * eta / delta(sigma), one value per iteration.\n",
    "limit1, limit2, limit3, limit4 = (\n",
    "    np.full((nit,), 0.25*eta/delta(s)) for s in (sigma1, sigma2, sigma3, sigma4)\n",
    ")\n",
    "# The sigma5 ensemble is run with step 10*eta, hence the extra factor of 10.\n",
    "# NOTE(review): delta() itself still evaluates its eta-term with the base eta.\n",
    "limit5 = np.full((nit,), 10*0.25*eta/delta(sigma5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "PB0Pzkm0FuKV"
   },
   "outputs": [],
   "source": [
    "seeds = [i for i in np.arange(100)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "xERLGH0VOp7C",
    "outputId": "a0ea6bb7-d9f2-44ca-db0e-3d838219a562",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# SignSGD ensemble at noise level sigma1, one trajectory per seed.\n",
    "SGD_Trajs_1 = SignSGD_Trajs(eta=eta, x0=x0, nit=nit, sigma=sigma1, seeds=seeds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average over seeds (axis 0) of the iterates and of the losses.\n",
    "avg_SignSGD_Traj_1, avg_SignSGD_Loss_1 = (np.mean(arr, axis=0) for arr in SGD_Trajs_1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# SignSGD ensemble at noise level sigma2, one trajectory per seed.\n",
    "SGD_Trajs_2 = SignSGD_Trajs(eta=eta, x0=x0, nit=nit, sigma=sigma2, seeds=seeds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average over seeds (axis 0) of the iterates and of the losses.\n",
    "avg_SignSGD_Traj_2, avg_SignSGD_Loss_2 = (np.mean(arr, axis=0) for arr in SGD_Trajs_2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# SignSGD ensemble at noise level sigma3, one trajectory per seed.\n",
    "SGD_Trajs_3 = SignSGD_Trajs(eta=eta, x0=x0, nit=nit, sigma=sigma3, seeds=seeds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average over seeds (axis 0) of the iterates and of the losses.\n",
    "avg_SignSGD_Traj_3, avg_SignSGD_Loss_3 = (np.mean(arr, axis=0) for arr in SGD_Trajs_3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# SignSGD ensemble at noise level sigma4, one trajectory per seed.\n",
    "SGD_Trajs_4 = SignSGD_Trajs(eta=eta, x0=x0, nit=nit, sigma=sigma4, seeds=seeds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average over seeds (axis 0) of the iterates and of the losses.\n",
    "avg_SignSGD_Traj_4, avg_SignSGD_Loss_4 = (np.mean(arr, axis=0) for arr in SGD_Trajs_4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# sigma5 ensemble: step 10*eta and 10*nit updates; the loss plot keeps every 10th value.\n",
    "SGD_Trajs_5 = SignSGD_Trajs(eta=10*eta, x0=x0, nit=10*nit, sigma=sigma5, seeds=seeds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average over seeds (axis 0) of the iterates and of the losses.\n",
    "avg_SignSGD_Traj_5, avg_SignSGD_Loss_5 = (np.mean(arr, axis=0) for arr in SGD_Trajs_5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "teDDHrLk-Rkq"
   },
   "outputs": [],
   "source": [
    "# Averaged SignSGD losses versus their constant reference limits (log scale).\n",
    "start = 0\n",
    "end = -1  # [start:end] with end=-1 leaves out the last stored value of every curve\n",
    "\n",
    "sigma_values = [sigma1, sigma2, sigma3, sigma4, sigma5]\n",
    "loss_curves = [\n",
    "    avg_SignSGD_Loss_1[start:end],\n",
    "    avg_SignSGD_Loss_2[start:end],\n",
    "    avg_SignSGD_Loss_3[start:end],\n",
    "    avg_SignSGD_Loss_4[start:end],\n",
    "    # The sigma5 run has 10x as many iterations; keep every 10th value so it has\n",
    "    # as many plotted points as the other runs.\n",
    "    avg_SignSGD_Loss_5[start:end][::10],\n",
    "]\n",
    "limit_curves = [limit[start:end] for limit in (limit1, limit2, limit3, limit4, limit5)]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(3.5, 5))\n",
    "handles, labels = [], []\n",
    "# Raw strings keep the LaTeX backslash in '$\\sigma$' from being read as an escape sequence.\n",
    "for k, (curve, s) in enumerate(zip(loss_curves, sigma_values)):\n",
    "    handles += ax.plot(curve, color=colors[k], linewidth=2)\n",
    "    labels.append(rf'SignSGD ($\\sigma = ${s})')\n",
    "for k, (curve, s) in enumerate(zip(limit_curves, sigma_values)):\n",
    "    handles += ax.plot(curve, color=colors[5 + k], linewidth=2)\n",
    "    labels.append(rf'Limit ($\\sigma = ${s})')\n",
    "\n",
    "ax.legend(handles, labels, bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=14)\n",
    "ax.set_title('Losses', fontsize=25)\n",
    "ax.set_xlabel('Iterations', fontsize=25)\n",
    "ax.set_ylabel('Loss', fontsize=25)\n",
    "ax.set_yscale('log')\n",
    "# tick_params also applies to ticks created after the switch to log scale.\n",
    "ax.tick_params(labelsize=16)\n",
    "\n",
    "# Display the figure (it is not written to disk).\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
