{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### this is the plot for Section 5.2, Figure 3\n",
    "# import statements\n",
    "import numpy as np\n",
    "import numpy.linalg as la\n",
    "import matplotlib.pyplot as plt\n",
    "from fbm import FBM\n",
    "import seaborn as sns\n",
    "colors = sns.color_palette()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# some functions needed later\n",
    "def sample_fbm_perturbations(fracbm,N,dim,var_adj=0):   # sample perturbations with variance sigma^2 = 1\n",
    "    incrms = np.zeros([N,dim])\n",
    "    for j in range(dim):\n",
    "        fbmsample = fracbm.fbm() # create fBM object\n",
    "        if var_adj: # variance adjustment on?\n",
    "            H = fracbm.hurst\n",
    "            fbmsample = N**(-H) * fbmsample # rescaled sample\n",
    "        incrms[:,j] = fbmsample[1:] - fbmsample[:-1]\n",
    "    return incrms\n",
    "\n",
    "def GD(f,nabla_f,x0,eta,N_iter):\n",
    "    dim = len(x0)\n",
    "    xs = np.zeros([N_iter,dim])\n",
    "    ys = np.zeros([N_iter])\n",
    "    xs[0] = x0\n",
    "    ys[0] = f(x0)\n",
    "    for i in range(N_iter-1):\n",
    "        xs[i+1] = xs[i] - eta*nabla_f(xs[i])\n",
    "        ys[i+1] = f(xs[i+1])\n",
    "    return (xs,ys)\n",
    "\n",
    "def PGD(f,nabla_f,x0,eta,sig,N_iter,var_adj=0):\n",
    "    dim = len(x0)\n",
    "    xs = np.zeros([N_iter,dim])\n",
    "    ys = np.zeros([N_iter])\n",
    "    xs[0] = x0\n",
    "    ys[0] = f(x0)\n",
    "    if var_adj == 0: perturbs_bm = np.random.normal(0,1,(N_iter-1)*dim).reshape([N_iter-1,dim])\n",
    "    if var_adj == 1: \n",
    "        H = 1/2 # BM Hurst parameter\n",
    "        perturbs_bm = N_iter**(-H) * np.random.normal(0,1,(N_iter-1)*dim).reshape([N_iter-1,dim])\n",
    "    for i in range(N_iter-1):\n",
    "        xs[i+1] = xs[i] - eta*nabla_f(xs[i]) + sig * perturbs_bm[i]\n",
    "        ys[i+1] = f(xs[i+1])\n",
    "    return (xs,ys)\n",
    "\n",
    "def antiPGD(f,nabla_f,x0,eta,sig,N_iter):\n",
    "    dim = len(x0)\n",
    "    xs = np.zeros([N_iter,dim])\n",
    "    ys = np.zeros([N_iter])\n",
    "    xs[0] = x0\n",
    "    ys[0] = f(x0)\n",
    "    rvs = np.random.normal(0,1,N_iter*dim).reshape([N_iter,dim])\n",
    "    perturbs_wn = rvs[1:] - rvs[:-1]\n",
    "    for i in range(N_iter-1):\n",
    "        xs[i+1] = xs[i] - eta*nabla_f(xs[i]) + sig * perturbs_wn[i]\n",
    "        ys[i+1] = f(xs[i+1])\n",
    "    return (xs,ys)\n",
    "\n",
    "def fPGD(f,nabla_f,x0,eta,sig,H,N_iter,var_adj=0):\n",
    "    dim = len(x0)\n",
    "    xs = np.zeros([N_iter,dim])\n",
    "    ys = np.zeros([N_iter])\n",
    "    xs[0] = x0\n",
    "    ys[0] = f(x0)\n",
    "    fbmobj = FBM(n=N_iter, hurst=H, length=N_iter, method='daviesharte')\n",
    "    perturbs_fbm = sample_fbm_perturbations(fbmobj,N_iter,dim,var_adj=var_adj)\n",
    "    for i in range(N_iter-1):\n",
    "        xs[i+1] = xs[i] - eta*nabla_f(xs[i]) + sig * perturbs_fbm[i]\n",
    "        ys[i+1] = f(xs[i+1])\n",
    "    return (xs,ys)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "d = 400\n",
    "flip_num = 10 # number of flipped points\n",
    "eigvals = np.sort(np.abs(np.random.normal(0,1,size=d))) # sample non-negative abs-values and sort them\n",
    "eigvals[:flip_num] = -eigvals[:flip_num]\n",
    "# define Hessian H as diagonal matrix\n",
    "Hess = np.diag(eigvals)\n",
    "# regularize these variables\n",
    "p = 4\n",
    "lam = 0.0001 # 0.001\n",
    "def f(x):\n",
    "    return 1/2*(x.T@Hess@x) + lam*np.sum(x**p)\n",
    "def nabla_f(x):\n",
    "    return Hess@x + lam*p*(x**(p-1))\n",
    "def Hess_f(x):\n",
    "    return Hess + lam*p*(p-1)*np.diag(x**(p-2))\n",
    "N_sim = 5\n",
    "N_iter = 100000\n",
    "eta = 0.005 \n",
    "sig = 0.005\n",
    "var_adj = 0 # in these experiments, we don't adjust the variance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Hs = [0.1,0.2,0.3,0.4] # fBM values\n",
    "\n",
    "### storage\n",
    "xs_gd = np.zeros([N_sim,N_iter,d])\n",
    "ys_gd = np.zeros([N_sim,N_iter])\n",
    "xs_bm = np.zeros([N_sim,N_iter,d])\n",
    "ys_bm = np.zeros([N_sim,N_iter])\n",
    "xs_wn = np.zeros([N_sim,N_iter,d])\n",
    "ys_wn = np.zeros([N_sim,N_iter])\n",
    "xs_fbm = np.zeros([len(Hs),N_sim,N_iter,d])\n",
    "ys_fbm = np.zeros([len(Hs),N_sim,N_iter])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# run simulations... (this might take a few hours. to reduce time, decrease N_iter)\n",
    "for sim_idx in range(N_sim):\n",
    "    x0 = np.zeros([d,])\n",
    "    xs_gd[sim_idx],ys_gd[sim_idx] = GD(f,nabla_f,x0,eta,N_iter)\n",
    "    xs_bm[sim_idx],ys_bm[sim_idx] = PGD(f,nabla_f,x0,eta,sig,N_iter,var_adj=var_adj)\n",
    "    xs_wn[sim_idx],ys_wn[sim_idx] = antiPGD(f,nabla_f,x0,eta,sig,N_iter)\n",
    "    for idx,H in enumerate(Hs):\n",
    "        print(sim_idx,idx)\n",
    "        xs_fbm[idx,sim_idx], ys_fbm[idx,sim_idx] = fPGD(f,nabla_f,x0,eta,sig,H,N_iter,var_adj=var_adj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### compute means of function values (by averaging over simulations)\n",
    "mean_y_gd = np.mean(ys_gd,axis=0)\n",
    "mean_y_bm = np.mean(ys_bm,axis=0)\n",
    "mean_y_wn = np.mean(ys_wn,axis=0)\n",
    "mean_y_fbm = np.mean(ys_fbm,axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### plots\n",
    "plt.rc('xtick', labelsize=15) \n",
    "plt.rc('ytick', labelsize=15)\n",
    "\n",
    "plt.plot(range(1,N_iter+1),mean_y_gd,label='GD',color=colors[0])\n",
    "plt.plot(range(1,N_iter+1),mean_y_wn,label='H=0',color=colors[1])\n",
    "count = 2\n",
    "for idx,H in enumerate(Hs):\n",
    "    plt.plot(range(1,N_iter+1),mean_y_fbm[idx],label='H=%.1f'%H,color=colors[count + idx])\n",
    "    count += 1\n",
    "plt.plot(range(1,N_iter+1),mean_y_bm,label='H=0.5',color=colors[-1])\n",
    "plt.xlabel('iteration ($k$)',fontsize=15)\n",
    "plt.ylabel('$f(x_k)$',fontsize=15)\n",
    "plt.title('embedded saddle point',fontsize=15,fontweight='bold')\n",
    "plt.legend(fontsize=15)\n",
    "plt.show()\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
