{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import statements\n",
    "import numpy as np\n",
    "import numpy.linalg as la\n",
    "import matplotlib.pyplot as plt\n",
    "from fbm import FBM\n",
    "import seaborn as sns\n",
    "colors = sns.color_palette()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# some functions needed later\n",
    "def sample_fbm_perturbations(fracbm,N,dim,var_adj=0):   # sample perturbations with variance sigma^2 = 1\n",
    "    incrms = np.zeros([N,dim])\n",
    "    for j in range(dim):\n",
    "        fbmsample = fracbm.fbm() # create fBM object\n",
    "        if var_adj: # variance adjustment on?\n",
    "            H = fracbm.hurst\n",
    "            fbmsample = N**(-H) * fbmsample # rescaled sample\n",
    "        incrms[:,j] = fbmsample[1:] - fbmsample[:-1]\n",
    "    return incrms\n",
    "\n",
    "def GD(f,nabla_f,x0,eta,N_iter):\n",
    "    dim = len(x0)\n",
    "    xs = np.zeros([N_iter,dim])\n",
    "    ys = np.zeros([N_iter])\n",
    "    xs[0] = x0\n",
    "    ys[0] = f(x0)\n",
    "    for i in range(N_iter-1):\n",
    "        xs[i+1] = xs[i] - eta*nabla_f(xs[i])\n",
    "        ys[i+1] = f(xs[i+1])\n",
    "    return (xs,ys)\n",
    "\n",
    "def PGD(f,nabla_f,x0,eta,sig,N_iter,var_adj=0):\n",
    "    dim = len(x0)\n",
    "    xs = np.zeros([N_iter,dim])\n",
    "    ys = np.zeros([N_iter])\n",
    "    xs[0] = x0\n",
    "    ys[0] = f(x0)\n",
    "    if var_adj == 0: perturbs_bm = np.random.normal(0,1,(N_iter-1)*dim).reshape([N_iter-1,dim])\n",
    "    if var_adj == 1: \n",
    "        H = 1/2 # BM Hurst parameter\n",
    "        perturbs_bm = N_iter**(-H) * np.random.normal(0,1,(N_iter-1)*dim).reshape([N_iter-1,dim])\n",
    "    for i in range(N_iter-1):\n",
    "        xs[i+1] = xs[i] - eta*nabla_f(xs[i]) + sig * perturbs_bm[i]\n",
    "        ys[i+1] = f(xs[i+1])\n",
    "    return (xs,ys)\n",
    "\n",
    "def antiPGD(f,nabla_f,x0,eta,sig,N_iter):\n",
    "    dim = len(x0)\n",
    "    xs = np.zeros([N_iter,dim])\n",
    "    ys = np.zeros([N_iter])\n",
    "    xs[0] = x0\n",
    "    ys[0] = f(x0)\n",
    "    rvs = np.random.normal(0,1,N_iter*dim).reshape([N_iter,dim])\n",
    "    perturbs_wn = rvs[1:] - rvs[:-1]\n",
    "    for i in range(N_iter-1):\n",
    "        xs[i+1] = xs[i] - eta*nabla_f(xs[i]) + sig * perturbs_wn[i]\n",
    "        ys[i+1] = f(xs[i+1])\n",
    "    return (xs,ys)\n",
    "\n",
    "def fPGD(f,nabla_f,x0,eta,sig,H,N_iter,var_adj=0):\n",
    "    \"\"\"Fractional perturbed gradient descent: PGD driven by fractional-Brownian increments.\n",
    "\n",
    "    One independent fBM path with Hurst index H is sampled per coordinate from\n",
    "    FBM(n=N_iter, length=N_iter, method='daviesharte') via sample_fbm_perturbations\n",
    "    (var_adj is forwarded to it); only the first N_iter-1 increments are used.\n",
    "    Returns (xs, ys): xs of shape (N_iter, len(x0)) with xs[0] = x0, and ys[k] = f(xs[k]).\n",
    "    \"\"\"\n",
    "    dim = len(x0)\n",
    "    xs = np.zeros([N_iter, dim])\n",
    "    ys = np.zeros([N_iter])\n",
    "    xs[0] = x0\n",
    "    ys[0] = f(x0)\n",
    "    fbm_gen = FBM(n=N_iter, hurst=H, length=N_iter, method='daviesharte')\n",
    "    noise = sample_fbm_perturbations(fbm_gen, N_iter, dim, var_adj=var_adj)\n",
    "    for k in range(1, N_iter):\n",
    "        prev = xs[k-1]\n",
    "        xs[k] = prev - eta * nabla_f(prev) + sig * noise[k-1]\n",
    "        ys[k] = f(xs[k])\n",
    "    return (xs,ys)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# params objective function\n",
    "l = 30.\n",
    "m = 18.\n",
    "v0 = 0.\n",
    "v2 = -12.5\n",
    "k0 = 5.\n",
    "k2 = 5.\n",
    "\n",
    "# define objective function\n",
    "a = np.sqrt(2*(l-v0)/k0)\n",
    "c = m - np.sqrt(2*(l-v2)/k2)\n",
    "b = (a+c)/2\n",
    "v1 = 1.5*l\n",
    "k1 = 8*(v1-l)/(c-a)**2\n",
    "def f(x):\n",
    "    if x<a:\n",
    "        return v0 + (1/2) * k0 * x**2\n",
    "    elif x<c:\n",
    "        return v1 - (1/2) * k1 * (x-b)**2\n",
    "    else:\n",
    "        return v2 + (1/2) * k2 * (x-m)**2\n",
    "def nabla_f(x):\n",
    "    if x<a:\n",
    "        return k0*x\n",
    "    elif x<c:\n",
    "        return -k1*(x-b)\n",
    "    else:\n",
    "        return k2*(x-m)\n",
    "d = 1\n",
    "N_sim = 1000\n",
    "N_iter = 1000\n",
    "eta = 0.45\n",
    "var_adj = 1 # Here we turn variance adjustment on to highlight the dependence on H, by fixing the final variance for all H. Variance adjusted plots are obtained by setting '0' here.\n",
    "if var_adj == 0: sig = 2.\n",
    "if var_adj == 1: sig = 0.8 * np.sqrt(10) * N_iter**(1/2)\n",
    "x0 = np.array([0.])\n",
    "Hs = [0.3,0.4,0.5,0.6,0.7]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot the loss landscape; upper left in Figure 2\n",
    "xmin = -4.\n",
    "xmax = 23\n",
    "xs = np.arange(xmin,xmax,0.1)\n",
    "ys = np.array([f(x) for x in xs]) # f branches on x, so it is evaluated pointwise\n",
    "line_offset = 0 # shift of the dashed well boundaries; also used by the trajectory plots below\n",
    "plt.plot(xs,ys,linewidth=2.5)\n",
    "plt.plot(x0[0],f(x0[0]),marker='x',color='k',markersize=10) # common starting point x0 of all runs\n",
    "plt.axvline(a + line_offset,color='k',linestyle='--') # right edge of the left well\n",
    "plt.axvline(c - line_offset,color='k',linestyle='--') # left edge of the right well\n",
    "plt.xlabel('$x$',fontsize=15)\n",
    "plt.ylabel('$f(x)$',fontsize=15)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### storage: preallocated result arrays, filled by the simulation cell below\n",
    "# xs_*: iterates, shape (N_sim, N_iter, d); ys_*: objective values, shape (N_sim, N_iter).\n",
    "# The gd / wn arrays are only filled if the GD / antiPGD calls in the simulation cell are enabled.\n",
    "run_shape = (N_sim, N_iter)\n",
    "xs_gd, ys_gd = np.zeros(run_shape + (d,)), np.zeros(run_shape)\n",
    "xs_bm, ys_bm = np.zeros(run_shape + (d,)), np.zeros(run_shape)\n",
    "xs_wn, ys_wn = np.zeros(run_shape + (d,)), np.zeros(run_shape)\n",
    "# fBM runs carry an extra leading axis, indexed like Hs\n",
    "xs_fbm = np.zeros((len(Hs),) + run_shape + (d,))\n",
    "ys_fbm = np.zeros((len(Hs),) + run_shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# run the PGD (BM) and fPGD simulations; every run starts at x0 and draws fresh noise\n",
    "for sim_idx in range(N_sim):\n",
    "    # baselines not needed for Figure 2; uncomment to also fill the GD / antiPGD arrays\n",
    "    # xs_gd[sim_idx],ys_gd[sim_idx] = GD(f,nabla_f,x0,eta,N_iter)\n",
    "    bm_run = PGD(f,nabla_f,x0,eta,sig,N_iter,var_adj=var_adj)\n",
    "    xs_bm[sim_idx] = bm_run[0]\n",
    "    ys_bm[sim_idx] = bm_run[1]\n",
    "    # xs_wn[sim_idx],ys_wn[sim_idx] = antiPGD(f,nabla_f,x0,eta,sig,N_iter)\n",
    "    for idx in range(len(Hs)):\n",
    "        H = Hs[idx]\n",
    "        fbm_run = fPGD(f,nabla_f,x0,eta,sig,H,N_iter,var_adj=var_adj)\n",
    "        xs_fbm[idx,sim_idx] = fbm_run[0]\n",
    "        ys_fbm[idx,sim_idx] = fbm_run[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# show one trajectory for standard BM perturbations (PGD); equal in distribution to the fBM plot with H=0.5 below\n",
    "sim_idx = 0 # change sim_idx (any index in range(N_sim)) to see a different simulation; also used by the fBM plots below\n",
    "plt.plot(range(N_iter),xs_bm[sim_idx],color='c')\n",
    "plt.axhline(y=a+line_offset, color='black', linestyle='--')\n",
    "plt.axhline(y=c-line_offset, color='black', linestyle='--')\n",
    "plt.xlabel('iterations ($k$)',fontsize=15)\n",
    "plt.ylabel('$x_k$',fontsize=15)\n",
    "plt.title('BM')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# show trajectories for fBM perturbations, for all H in Hs; plots lower row Figure 2\n",
    "# (sim_idx is set in the BM cell above, so the same simulation index is shown for every H)\n",
    "for idx,H in enumerate(Hs):\n",
    "    plt.plot(range(N_iter),xs_fbm[idx,sim_idx],color='c')\n",
    "    plt.xlabel('iterations ($k$)',fontsize=15)\n",
    "    if idx == 0: plt.ylabel('$x_k$',fontsize=15) # only the leftmost panel of the figure row is labelled\n",
    "    plt.title('fPGD (H=%.1f)'%H,fontsize=20,fontweight='bold')\n",
    "    plt.axhline(y=a+line_offset, color='black', linestyle='--')\n",
    "    plt.axhline(y=c-line_offset, color='black', linestyle='--')\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# post-processing for cumulative density function (cdf) of the first-exit times (fets)\n",
    "# A run has exited once an iterate lies beyond c, i.e. inside the right (good) well.\n",
    "in_good_min = (xs_fbm > c).any(axis=-1) # (len(Hs), N_sim, N_iter); any coordinate, so d > 1 is handled\n",
    "has_exited = in_good_min.any(axis=-1)   # (len(Hs), N_sim)\n",
    "# first iteration index with x_k > c; runs that never exit get N_iter (beyond every plotted step).\n",
    "# Unlike mapping argmax == 0 to N_iter, this keeps fet = 0 for a run that starts in the good well.\n",
    "fets = np.where(has_exited, in_good_min.argmax(axis=-1), N_iter).astype(float)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plots of cdf of fets, upper right, Figure 2\n",
    "# k = 0 cannot be shown on a log axis, so the cdf is plotted for k = 1..N_iter-1 at its own k\n",
    "steps = np.arange(1,N_iter)\n",
    "plt.xscale('log')\n",
    "plt.xlabel('iterations ($k$)',fontsize=15)\n",
    "plt.ylabel('first-exit cdf',fontsize=15)\n",
    "plt.axhline(y=1, color='black', linestyle='--')\n",
    "for H_idx in range(len(Hs)):\n",
    "    # empirical cdf P(fet <= k): count of first-exit times <= k, divided by N_sim\n",
    "    cdf_fets = np.searchsorted(np.sort(fets[H_idx]), steps, side='right') / N_sim\n",
    "    plt.plot(steps,cdf_fets,label='H=%.1f'%Hs[H_idx],color=colors[H_idx])\n",
    "plt.legend(fontsize=15,loc='center right')\n",
    "plt.title('first-exit cdfs of fPGD',fontsize=20,fontweight='bold')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
