{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0a0d7513",
   "metadata": {},
   "source": [
    "\n",
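    "## Quadratic-link regression\n",
    "\n",
    "We compare gradient descent (GD), accelerated gradient descent (AGD) and a continuized accelerated method on the loss\n",
    "<p>\n",
    "$$ f(w) = \\frac{1}{2n} \\sum_{i=1}^{n} \\left( \\sigma( w^\\top x_i ) - y_i \\right)^2, $$\n",
    "</p>\n",
    "where $\\sigma(z) = z^{2}$, the $x_i$ are the $n$ rows of the data matrix, and the labels are noiseless, $y_i = \\sigma(w_\\star^\\top x_i)$.\n"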
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62372a47",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import HSS\n",
    "# Shared line width for every curve, and large fonts for publication-size figures.\n",
    "kwargs = dict(linewidth=7.0)\n",
    "font = dict(weight='normal', size=28)\n",
    "matplotlib.rc('font', **font)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6371e535",
   "metadata": {},
   "outputs": [],
   "source": [
    "def quadratic_gradient(z):\n",
    "    \"\"\"Derivative of the quadratic link sigma(z) = z**2, i.e. 2*z (elementwise for arrays).\"\"\"\n",
    "    return 2*z\n",
    "\n",
    "def quadratic_link(z):\n",
    "    \"\"\"Quadratic link sigma(z) = z**2, applied elementwise to arrays.\"\"\"\n",
    "    squared = z**2\n",
    "    return squared\n",
    "\n",
    "def quadraticopt_grad(A, y, w, num):\n",
    "    \"\"\"Gradient of (1/(2*num)) * sum_i (sigma(a_i^T w) - y_i)**2 with sigma(z) = z**2.\n",
    "\n",
    "    A is the data matrix (one sample per row), y the label vector, w the current iterate\n",
    "    and num the normalizing sample count. Returns a vector with one entry per column of A.\n",
    "    \"\"\"\n",
    "    Aw = np.dot(A, w)\n",
    "    # Per-sample weight sigma'(a_i^T w) * (sigma(a_i^T w) - y_i). An elementwise product gives\n",
    "    # the same values as g1.dot(np.diag(r)) without materializing a num x num matrix.\n",
    "    weights = quadratic_gradient(Aw) * (quadratic_link(Aw) - y)\n",
    "    return weights.dot(A) / num\n",
    "\n",
    "def quadratic_fval(A, y, w, num):\n",
    "    \"\"\"Loss (1/(2*num)) * sum_i (sigma(a_i^T w) - y_i)**2 for the quadratic link.\n",
    "\n",
    "    The factor 1/2 makes this exactly the function whose gradient quadraticopt_grad\n",
    "    returns, so values and gradients handed to the optimizers are consistent.\n",
    "    \"\"\"\n",
    "    residual = quadratic_link(np.dot(A, w)) - y\n",
    "    return 0.5 * np.sum(residual**2) / num\n",
    "\n",
    "def quadratic(w, func_only=False, grad_only=False):\n",
    "    \"\"\"Objective oracle for the optimizers: value, gradient, or both at w.\n",
    "\n",
    "    Reads the notebook-level data A, y and num. Returns f(w) if func_only, the gradient\n",
    "    if grad_only, otherwise the tuple (f(w), gradient). func_only wins if both are set.\n",
    "    \"\"\"\n",
    "    if func_only:\n",
    "        return quadratic_fval(A, y, w, num)\n",
    "    if grad_only:\n",
    "        # Was quadratic_grad, a name that is never defined (NameError).\n",
    "        return quadraticopt_grad(A, y, w, num)\n",
    "    return quadratic_fval(A, y, w, num), quadraticopt_grad(A, y, w, num)\n",
    "\n",
    "\n",
    "def GDupdate(wold, eta, grad):\n",
    "    \"\"\"One gradient-descent step with step size eta: returns wold - eta*grad.\"\"\"\n",
    "    return wold - eta*grad\n",
    "\n",
    "# Synthetic problem: Gaussian design, planted signal wstar, noiseless labels.\n",
    "SEED = 0\n",
    "np.random.seed(SEED)  # fix the random data and initial point for reproducibility\n",
    "num, d  = 1000, 50\n",
    "A     = np.random.normal(0, 1, (num,d) )\n",
    "wstar = np.random.normal(0, 1, d)\n",
    "y     = quadratic_link( np.dot(A,wstar) )\n",
    "#############################################################\n",
    "w_init = (10**-2)*np.random.normal(0, 1, d)\n",
    "iters = 600\n",
    "#------------------------------------------------------------\n",
    "# Gradient descent: one run per candidate step size.\n",
    "eta_candi   = [10**-3]\n",
    "func_arrGD  = np.zeros([len(eta_candi), iters+1])\n",
    "cpu_arrGD   = np.zeros([len(eta_candi), iters+1])\n",
    "\n",
    "for q, eta in enumerate(eta_candi):\n",
    "    # Restart every candidate from the same point so the runs are comparable.\n",
    "    w_GD = w_init\n",
    "    elapsed = 0\n",
    "    func_arrGD[q,0] = quadratic(w_GD, func_only=True)\n",
    "    cpu_arrGD[q,0]  = elapsed\n",
    "    for i in range(iters):\n",
    "        start = time.time()\n",
    "        w_GD = GDupdate(w_GD, eta, quadraticopt_grad(A,y,w_GD,num) )\n",
    "        func_arrGD[q,i+1] = quadratic(w_GD, func_only=True)\n",
    "        elapsed += time.time() - start\n",
    "        cpu_arrGD[q,i+1]  = elapsed\n",
    "#########################################################\n",
    "# Step size with the smallest final loss (a single argmin; argmin of argmin was always 0).\n",
    "idd_GD = np.argmin( func_arrGD[:,iters] )\n",
    "print('=== idd_GD:', idd_GD)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7c62a31",
   "metadata": {},
   "outputs": [],
   "source": [
    "def Continuizedupdate(v, w, z, rho, mu, L, A, y, num):\n",
    "    \"\"\"One step of the continuized accelerated method; returns the new (v, w, z).\n",
    "\n",
    "    The incoming v is not used: the new v is a random mix of w and z whose weight depends\n",
    "    on an Exp(1) draw. w takes a gradient step of size 1/L from the new v, and z combines\n",
    "    a pull toward the new v with a gradient step of size 1/sqrt(mu*L).\n",
    "    \"\"\"\n",
    "    step_w = 1./L\n",
    "    step_z = 1./np.sqrt(mu*L)\n",
    "    # Exp(1) draw from the global RNG (a shape-(1,) array, broadcast below).\n",
    "    decay  = np.exp( -(1.+rho)* np.sqrt( mu/L )* np.random.exponential(1,1) )\n",
    "    mix_w  = 1./(1+rho)*(1. - decay)\n",
    "    mix_z  = rho*(1.- decay) / (rho + decay)\n",
    "    v_new  = w + mix_w*(z - w)\n",
    "    grad   = quadraticopt_grad(A, y, v_new, num)\n",
    "    w_new  = v_new - step_w*grad\n",
    "    z_new  = z + mix_z*(v_new - z) - step_z*grad\n",
    "    return v_new, w_new, z_new\n",
    "#----------------------------------------------------------------------------\n",
    "# Candidate (L, mu, rho) triples; L and mu set the step sizes 1/L and 1/sqrt(mu*L).\n",
    "L_candi   = [1000]\n",
    "mu_candi  = [100]\n",
    "rho_candi = [0.5]\n",
    "\n",
    "#-------------------------------------------------------------------------\n",
    "# The method is randomized, so loss curves are recorded for several independent runs.\n",
    "runs = 10\n",
    "shape_Conti    = (len(L_candi), iters+1, runs)\n",
    "func_arrConti  = np.zeros(shape_Conti)\n",
    "cpu_arrConti   = np.zeros(shape_Conti)\n",
    "for q, L in enumerate(L_candi):\n",
    "    print(\"q\",q)\n",
    "    mu  = mu_candi[q]\n",
    "    rho = rho_candi[q]\n",
    "    for r in range(runs):\n",
    "        # v, w and z all start from the shared initial point.\n",
    "        v_Conti = w_Conti = z_Conti = w_init\n",
    "        elapsed = 0\n",
    "        func_arrConti[q,0,r] = quadratic(w_Conti, func_only=True)\n",
    "        cpu_arrConti[q,0,r]  = elapsed\n",
    "        for i in range(iters):\n",
    "            start = time.time()\n",
    "            v_Conti, w_Conti, z_Conti = Continuizedupdate(v_Conti, w_Conti, z_Conti, rho, mu, L, A, y, num)\n",
    "            func_arrConti[q,i+1,r] = quadratic(w_Conti, func_only=True)\n",
    "            elapsed += time.time() - start\n",
    "            cpu_arrConti[q,i+1,r]  = elapsed\n",
    "# Mean loss curve over the runs, shape (len(L_candi), iters+1).\n",
    "func_Conti_mean = func_arrConti.mean(axis=2)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0e8f97c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Candidate whose mean final loss is smallest.\n",
    "final_Conti = func_Conti_mean[:, iters]\n",
    "idd_conti = np.argmin(final_Conti)\n",
    "print('=== idd_conti:', idd_conti)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe2a082e",
   "metadata": {},
   "outputs": [],
   "source": [
    "myoptions = HSS.AGD_options()\n",
    "#----------------------------------------------------------------------------\n",
    "# Candidate (L, mu, rho) triples passed to HSS.agd_strong.\n",
    "L_candi   = [1000]\n",
    "mu_candi  = [100]\n",
    "rho_candi = [0.5]\n",
    "\n",
    "# Per-candidate histories returned by HSS.agd_strong, each converted to a numpy array.\n",
    "func_arrAcc, nfunc_arrAcc, ngrad_arrAcc, cpu_arrAcc = [], [], [], []\n",
    "for q, L in enumerate(L_candi):\n",
    "    print(\"q\",q)\n",
    "    mu  = mu_candi[q]\n",
    "    rho = rho_candi[q]\n",
    "    (k_ACC, num_nostep_ACC, func_eval_ACC, grad_eval_ACC, fvals_ACC,\n",
    "     w_ACC, num_gradcalls, num_funccalls, cpu_time) = HSS.agd_strong(quadratic, rho, L, mu, w_init, myoptions)\n",
    "    func_arrAcc.append(np.array(fvals_ACC))\n",
    "    nfunc_arrAcc.append(np.array(num_funccalls))\n",
    "    ngrad_arrAcc.append(np.array(num_gradcalls))\n",
    "    cpu_arrAcc.append(np.array(cpu_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97cdcde3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final AGD loss for each candidate; keep the index of the smallest.\n",
    "n_cand = len(L_candi)\n",
    "fmins = np.zeros(n_cand)\n",
    "for i in range(n_cand):\n",
    "    fmins[i] = func_arrAcc[i][-1]  # last recorded loss of candidate i\n",
    "idd_Acc = np.argmin(fmins)\n",
    "print('=== idd_Acc:', idd_Acc)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fb0caad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare GD, AGD and the continuized method on four x-axes; figures are saved to ./figures/.\n",
    "FIG_DIR = \"./figures/\"\n",
    "os.makedirs(FIG_DIR, exist_ok=True)  # savefig raises FileNotFoundError if the folder is missing\n",
    "myname = \"quadratic\"\n",
    "\n",
    "# Mean and one-sided (mean .. mean+std) spread of the continuized runs.\n",
    "func_Conti_mean = np.mean( func_arrConti, axis=2 )\n",
    "func_Conti_std  = np.std( func_arrConti, axis=2 )\n",
    "cpu_Conti_mean  = np.mean( cpu_arrConti, axis=2 )\n",
    "\n",
    "def plot_conti(x, upto=None):\n",
    "    \"\"\"Plot the mean continuized curve (first `upto` points) against x with a mean..mean+std band.\"\"\"\n",
    "    mean = func_Conti_mean[idd_conti, :upto]\n",
    "    std  = func_Conti_std[idd_conti, :upto]\n",
    "    plt.plot(x, mean, color='g', label=\"Continuized Acc.\", linestyle='-', **kwargs)\n",
    "    plt.fill_between(x, mean, mean + std, color='gray', alpha=0.2)\n",
    "\n",
    "def finish_figure(xlabel, suffix):\n",
    "    \"\"\"Add legend and axis labels, save to FIG_DIR + myname + suffix, then show.\"\"\"\n",
    "    plt.legend()\n",
    "    plt.ylabel('function value')\n",
    "    plt.xlabel(xlabel)\n",
    "    plt.savefig(FIG_DIR + myname + suffix)\n",
    "    plt.show()\n",
    "\n",
    "# --- function value vs. iteration ------------------------------------------------\n",
    "# AGD may record a different number of iterates than GD; plot the common prefix.\n",
    "nums_r = func_arrAcc[idd_Acc].size\n",
    "shows = np.min( [iters+1, nums_r] )\n",
    "ts = np.arange( shows )\n",
    "plt.figure(figsize=(14, 8))\n",
    "plt.title('function value v.s. iteration')\n",
    "plt.yscale( 'log' )\n",
    "# color= instead of an 'r-' fmt string: fmt + linestyle= triggers a redundancy warning.\n",
    "plt.plot(ts, func_arrGD[idd_GD,:shows], color='r', label=\"GD\", linestyle='--', **kwargs)\n",
    "plt.plot(ts, func_arrAcc[idd_Acc][:shows], color='b', label=\"AGD\", linestyle=':', **kwargs)\n",
    "plot_conti(ts, shows)\n",
    "finish_figure('iteration', \"_iters.jpg\")\n",
    "\n",
    "# --- function value vs. # gradient calls (one per iteration for GD and continuized) ---\n",
    "ts = np.arange( iters+1 )\n",
    "plt.figure(figsize=(14, 8))\n",
    "plt.title('function value v.s. # gradient calls')\n",
    "plt.yscale( 'log' )\n",
    "plt.plot(ts, func_arrGD[idd_GD,:], color='r', label=\"GD\", linestyle='--', **kwargs)\n",
    "plt.plot(ngrad_arrAcc[idd_Acc], func_arrAcc[idd_Acc], color='b', label=\"AGD\", linestyle=':', **kwargs)\n",
    "plot_conti(ts)\n",
    "plt.xlim([0, 150])\n",
    "finish_figure('# gradient calls', \"_calls.jpg\")\n",
    "\n",
    "# --- function value vs. cumulative cpu time ----------------------------------------\n",
    "plt.figure(figsize=(14, 8))\n",
    "plt.title('function value v.s. cpu time (Sec.)')\n",
    "plt.yscale( 'log' )\n",
    "plt.plot(cpu_arrGD[idd_GD,:], func_arrGD[idd_GD,:], color='r', label=\"GD\", linestyle='--', **kwargs)\n",
    "plt.plot(cpu_arrAcc[idd_Acc], func_arrAcc[idd_Acc], color='b', label=\"AGD\", linestyle=':', **kwargs)\n",
    "plot_conti(cpu_Conti_mean[idd_conti,:])\n",
    "plt.xlim([0, 0.2])\n",
    "finish_figure('cpu time (Sec.)', \"_cpu.jpg\")\n",
    "\n",
    "# --- function value vs. # gradient + function calls --------------------------------\n",
    "plt.figure(figsize=(14, 8))\n",
    "plt.title('function value v.s. # gradient + function calls')\n",
    "plt.yscale( 'log' )\n",
    "plt.plot(ts, func_arrGD[idd_GD,:], color='r', label=\"GD\", linestyle='--', **kwargs)\n",
    "plt.plot(ngrad_arrAcc[idd_Acc]+nfunc_arrAcc[idd_Acc], func_arrAcc[idd_Acc], color='b', label=\"AGD\", linestyle=':', **kwargs)\n",
    "plot_conti(ts)\n",
    "plt.xlim([0, iters])\n",
    "finish_figure('# gradient + function calls', \"_fgcalls.jpg\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
