{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import statements\n",
    "import numpy as np\n",
    "import numpy.linalg as la\n",
    "import matplotlib.pyplot as plt\n",
    "from fbm import FBM\n",
    "import seaborn as sns\n",
    "colors = sns.color_palette()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# some functions needed below\n",
    "def sample_fbm_perturbations(fracbm,N,dim,var_adj=0):   # sample perturbations with variance sigma^2 = 1\n",
    "    incrms = np.zeros([N,dim])\n",
    "    for j in range(dim):\n",
    "        fbmsample = fracbm.fbm() # create fBM object\n",
    "        if var_adj: # variance adjustment on?\n",
    "            H = fracbm.hurst\n",
    "            fbmsample = N**(-H) * fbmsample # rescaled sample\n",
    "        incrms[:,j] = fbmsample[1:] - fbmsample[:-1]\n",
    "    return incrms\n",
    "\n",
    "def GD(f,nabla_f,x0,eta,N_iter):\n",
    "    dim = len(x0)\n",
    "    xs = np.zeros([N_iter,dim])\n",
    "    ys = np.zeros([N_iter])\n",
    "    xs[0] = x0\n",
    "    ys[0] = f(x0)\n",
    "    for i in range(N_iter-1):\n",
    "        xs[i+1] = xs[i] - eta*nabla_f(xs[i])\n",
    "        ys[i+1] = f(xs[i+1])\n",
    "    return (xs,ys)\n",
    "\n",
    "def PGD(f,nabla_f,x0,eta,sig,N_iter,var_adj=0):\n",
    "    dim = len(x0)\n",
    "    xs = np.zeros([N_iter,dim])\n",
    "    ys = np.zeros([N_iter])\n",
    "    xs[0] = x0\n",
    "    ys[0] = f(x0)\n",
    "    if var_adj == 0: perturbs_bm = np.random.normal(0,1,(N_iter-1)*dim).reshape([N_iter-1,dim])\n",
    "    if var_adj == 1: \n",
    "        H = 1/2 # BM Hurst parameter\n",
    "        perturbs_bm = N_iter**(-H) * np.random.normal(0,1,(N_iter-1)*dim).reshape([N_iter-1,dim])\n",
    "    for i in range(N_iter-1):\n",
    "        xs[i+1] = xs[i] - eta*nabla_f(xs[i]) + sig * perturbs_bm[i]\n",
    "        ys[i+1] = f(xs[i+1])\n",
    "    return (xs,ys)\n",
    "\n",
    "def antiPGD(f,nabla_f,x0,eta,sig,N_iter):\n",
    "    dim = len(x0)\n",
    "    xs = np.zeros([N_iter,dim])\n",
    "    ys = np.zeros([N_iter])\n",
    "    xs[0] = x0\n",
    "    ys[0] = f(x0)\n",
    "    rvs = np.random.normal(0,1,N_iter*dim).reshape([N_iter,dim])\n",
    "    perturbs_wn = rvs[1:] - rvs[:-1]\n",
    "    for i in range(N_iter-1):\n",
    "        xs[i+1] = xs[i] - eta*nabla_f(xs[i]) + sig * perturbs_wn[i]\n",
    "        ys[i+1] = f(xs[i+1])\n",
    "    return (xs,ys)\n",
    "\n",
    "def fPGD(f,nabla_f,x0,eta,sig,H,N_iter,var_adj=0):\n",
    "    '''Perturbed gradient descent driven by fractional-Brownian-motion increments.\n",
    "\n",
    "    Update: x_{k+1} = x_k - eta*nabla_f(x_k) + sig*dB_k, where dB_k is row k of the output of\n",
    "    sample_fbm_perturbations (a freshly sampled fBM path with Hurst index H per coordinate).\n",
    "    var_adj is forwarded to sample_fbm_perturbations (N_iter**(-H) rescaling when truthy).\n",
    "\n",
    "    Returns (xs, ys): the (N_iter, len(x0)) iterates (row 0 is x0) and their losses f(xs[k]).\n",
    "    '''\n",
    "    n_dim = len(x0)\n",
    "    traj = np.zeros((N_iter, n_dim))\n",
    "    losses = np.zeros(N_iter)\n",
    "    traj[0] = x0\n",
    "    losses[0] = f(x0)\n",
    "    # N_iter steps on [0, N_iter], i.e. a unit time step, sampled with the Davies-Harte method\n",
    "    generator = FBM(n=N_iter, hurst=H, length=N_iter, method='daviesharte')\n",
    "    noise = sample_fbm_perturbations(generator, N_iter, n_dim, var_adj=var_adj)\n",
    "    for k in range(N_iter-1):  # only the first N_iter-1 of the N_iter increments are used\n",
    "        traj[k+1] = traj[k] - eta*nabla_f(traj[k]) + sig * noise[k]\n",
    "        losses[k+1] = f(traj[k+1])\n",
    "    return (traj, losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the loss function: embbeded saddle point\n",
    "d = 400\n",
    "flip_num = 10 # number of flipped points\n",
    "eigvals = np.sort(np.abs(np.random.normal(0,1,size=d))) # sample non-negative abs-values and sort them\n",
    "eigvals[:flip_num] = -eigvals[:flip_num]\n",
    "# define Hessian H of saddle point, as diagonal matrix\n",
    "Hess = np.diag(eigvals)\n",
    "# regularize these variables\n",
    "p = 4\n",
    "lam = 0.0001\n",
    "def f(x):\n",
    "    return 1/2*(x.T@Hess@x) + lam*np.sum(x**p)\n",
    "def nabla_f(x):\n",
    "    return Hess@x + lam*p*(x**(p-1))\n",
    "def Hess_f(x):\n",
    "    return Hess + lam*p*(p-1)*np.diag(x**(p-2))\n",
    "N_sim = 5\n",
    "N_iter = 100000 \n",
    "eta = 0.005\n",
    "var_adj = 0 # in this plot, we don't adjust the variance for all H "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# noise levels sigma to try (largest to smallest)\n",
    "sigs = [5e-2, 5e-3, 5e-4, 5e-5]\n",
    "\n",
    "# Hurst indices of the non-Markovian (fBM) runs\n",
    "Hs = [0.1, 0.2, 0.3, 0.4]\n",
    "\n",
    "# loss curves, axes (sigma, simulation, iteration); the fBM array has an extra leading H axis\n",
    "ys_bm = np.zeros((len(sigs), N_sim, N_iter))\n",
    "ys_wn = np.zeros((len(sigs), N_sim, N_iter))\n",
    "ys_fbm = np.zeros((len(Hs), len(sigs), N_sim, N_iter))\n",
    "# full trajectories: same axes plus a trailing coordinate axis of size d.\n",
    "# NOTE: (2+len(Hs))*len(sigs)*N_sim*N_iter*d float64 values, ~38 GB with the default settings\n",
    "xs_bm = np.zeros((len(sigs), N_sim, N_iter, d))\n",
    "xs_wn = np.zeros((len(sigs), N_sim, N_iter, d))\n",
    "xs_fbm = np.zeros((len(Hs), len(sigs), N_sim, N_iter, d))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# run the simulations (this might take a few hours; to reduce time, decrease N_iter)\n",
    "# Only the loss curves ys_* are used by the later cells. Writing every trajectory into xs_*\n",
    "# touches (2+len(Hs))*len(sigs)*N_sim*N_iter*d float64 values (~38 GB with the default\n",
    "# settings), so trajectories are copied into xs_* only when store_trajectories is True.\n",
    "store_trajectories = False\n",
    "for sim_idx in range(N_sim):\n",
    "    x0 = np.zeros([d,])  # every run starts exactly at the saddle point x = 0\n",
    "    # run Markovian cases: PGD (Brownian increments) and anti-PGD\n",
    "    for sig_idx,sig in enumerate(sigs):\n",
    "        print('sim %d: Markovian, sigma=%.0E' % (sim_idx, sig))\n",
    "        xs, ys_bm[sig_idx,sim_idx] = PGD(f,nabla_f,x0,eta,sig,N_iter,var_adj=var_adj)\n",
    "        if store_trajectories:\n",
    "            xs_bm[sig_idx,sim_idx] = xs\n",
    "        xs, ys_wn[sig_idx,sim_idx] = antiPGD(f,nabla_f,x0,eta,sig,N_iter)\n",
    "        if store_trajectories:\n",
    "            xs_wn[sig_idx,sim_idx] = xs\n",
    "    # run non-Markovian cases: fPGD for each Hurst index\n",
    "    for idx,H in enumerate(Hs):\n",
    "        for sig_idx,sig in enumerate(sigs):\n",
    "            print('sim %d: H=%.1f, sigma=%.0E' % (sim_idx, H, sig))\n",
    "            xs, ys_fbm[idx,sig_idx,sim_idx] = fPGD(f,nabla_f,x0,eta,sig,H,N_iter,var_adj=var_adj)\n",
    "            if store_trajectories:\n",
    "                xs_fbm[idx,sig_idx,sim_idx] = xs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### average the loss curves over the N_sim runs (the simulation axis)\n",
    "mean_y_bm = ys_bm.mean(axis=1)    # (len(sigs), N_iter)\n",
    "mean_y_wn = ys_wn.mean(axis=1)    # (len(sigs), N_iter)\n",
    "mean_y_fbm = ys_fbm.mean(axis=2)  # (len(Hs), len(sigs), N_iter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# arrange mean_y_wn, mean_y_fbm and mean_y_bm into one array (easier to plot), ordered by Hurst index:\n",
    "#   row 0        -> anti-PGD (mean_y_wn), plotted as H = 0\n",
    "#   rows 1 .. -2 -> fPGD, one row per H in Hs (mean_y_fbm)\n",
    "#   row -1       -> PGD with Brownian increments (mean_y_bm), H = 1/2\n",
    "y_values = np.concatenate([mean_y_wn[np.newaxis], mean_y_fbm, mean_y_bm[np.newaxis]], axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rc('xtick', labelsize=15) \n",
    "plt.rc('ytick', labelsize=15)\n",
    "\n",
    "# one panel per row of y_values: anti-PGD (shown as H=0), fPGD for each H in Hs, PGD (H=0.5)\n",
    "hurst_parameters = [0.0, *Hs, 0.5]\n",
    "\n",
    "# all panels live in one 2-column grid, so x labels go on the bottom row, y labels on the\n",
    "# left column, and a single legend describes every panel\n",
    "n_cols = 2\n",
    "n_rows = -(-len(hurst_parameters) // n_cols)  # ceiling division\n",
    "fig, axes = plt.subplots(n_rows, n_cols, sharex=True, sharey=True, squeeze=False,\n",
    "                         figsize=(6*n_cols, 4*n_rows))\n",
    "iterations = range(1, N_iter+1)\n",
    "\n",
    "for H_idx, H in enumerate(hurst_parameters):\n",
    "    ax = axes[H_idx // n_cols, H_idx % n_cols]\n",
    "    for sig_idx, sig in enumerate(sigs):\n",
    "        # every second palette colour (0, 2, 4, ...), wrapping around if there are many sigmas\n",
    "        ax.plot(iterations, y_values[H_idx, sig_idx], label=r'$\\sigma$=%.2E' % sig,\n",
    "                color=colors[(2*sig_idx) % len(colors)])\n",
    "    ax.set_title('H=%.1f' % H, fontsize=15)\n",
    "    if H_idx // n_cols == n_rows - 1:  # bottom row\n",
    "        ax.set_xlabel('iteration ($k$)', fontsize=15)\n",
    "    if H_idx % n_cols == 0:  # left column\n",
    "        ax.set_ylabel('$f(x_k)$', fontsize=15)\n",
    "    ax.set_ylim(bottom=-1.75, top=1.0)\n",
    "\n",
    "# hide grid cells left over when the number of panels is odd\n",
    "for ax in axes.flat[len(hurst_parameters):]:\n",
    "    ax.set_visible(False)\n",
    "axes[-1, 0].legend(loc='lower left')\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
