{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import hypergrad as hg\n",
    "import numpy as np\n",
    "from sklearn.datasets import make_spd_matrix as spd\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "No GPU found! Running on CPU.\n"
     ]
    }
   ],
   "source": [
    "use_cuda = False\n",
    "\n",
    "if torch.cuda.is_available() and use_cuda:\n",
    "    print(\"Num GPUs found: \", torch.cuda.device_count())\n",
    "    print(\"GPU name\", torch.cuda.get_device_name(0))\n",
    "    tensor_type = 'torch.cuda.FloatTensor'\n",
    "    cuda = True\n",
    "else:\n",
    "    print('No GPU found! Running on CPU.')\n",
    "    tensor_type = 'torch.FloatTensor'\n",
    "    cuda = False\n",
    "    \n",
    "torch.set_default_tensor_type(tensor_type)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.float64\n"
     ]
    }
   ],
   "source": [
    "Q = torch.from_numpy(spd(n_dim=10, random_state=0))\n",
    "print(Q.dtype)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Problem definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "d = 10\n",
    "seed = 0\n",
    "alpha = .1\n",
    "\n",
    "Q = torch.from_numpy(spd(n_dim=d, random_state=seed)).float() # will be torch.float32 tensor \n",
    "\n",
    "def Nesterov_func(x):\n",
    "    \n",
    "    L = 10\n",
    "    return ((torch.sum((x[:-1] - x[1:]) ** 2) + x[0] ** 2 + x[-1] ** 2) / 2 - x[0]) * L / 4\n",
    "\n",
    "def outer_func(hparams, params, posdefQ=Q):\n",
    "    hp = hparams[0]\n",
    "    w = params[0]\n",
    "    \n",
    "    out = 0.5 * (hp.unsqueeze(0) @ posdefQ @ hp.unsqueeze(1)) + w.unsqueeze(0) @ hp.unsqueeze(1)\n",
    "    #val_losses.append(out.item())\n",
    "    \n",
    "    return out.squeeze()\n",
    "\n",
    "def inner_func(hparams, params):\n",
    "    hp = hparams[0]\n",
    "    w = params[0]\n",
    "    diff = w - hp\n",
    "    out = Nesterov_func(w) + 0.5 * diff.unsqueeze(0) @ diff.unsqueeze(1)\n",
    "    \n",
    "    return out.squeeze()\n",
    "\n",
    "def map_func(hparams, params):\n",
    "    \n",
    "    g = inner_func(hparams, params)\n",
    "    inner_losses.append(g.item())\n",
    "    \n",
    "    return [params[0] - alpha * torch.autograd.grad(g, params, create_graph=True)[0]]\n",
    "\n",
    "\n",
    "def inner_solver(hparams, steps=100, params0=None, optim=None):\n",
    "\n",
    "    params = [torch.zeros(d).requires_grad_(True)]\n",
    "\n",
    "    for _ in range(steps):\n",
    "        params = map_func(hparams, params)\n",
    "\n",
    "    return params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "K = 150\n",
    "eval_interval = 10\n",
    "T = 10\n",
    "#K = 10 \n",
    "mu = 0.0001\n",
    "beta = 0.01"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'Nesterov_func' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-13-1385828cd8ec>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      9\u001b[0m     \u001b[0mstep_start_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     10\u001b[0m     \u001b[0minner_losses\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m     \u001b[0mparams\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minner_solver\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msteps\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mT\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     12\u001b[0m     \u001b[0mt1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mstep_start_time\u001b[0m \u001b[0;31m# inner loop time\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-11-2c00ec0390d6>\u001b[0m in \u001b[0;36minner_solver\u001b[0;34m(hparams, steps, params0, optim)\u001b[0m\n\u001b[1;32m     35\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     36\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msteps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 37\u001b[0;31m         \u001b[0mparams\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     38\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     39\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-11-2c00ec0390d6>\u001b[0m in \u001b[0;36mmap_func\u001b[0;34m(hparams, params)\u001b[0m\n\u001b[1;32m     24\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mmap_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     25\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m     \u001b[0mg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minner_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     27\u001b[0m     \u001b[0minner_losses\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     28\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-11-2c00ec0390d6>\u001b[0m in \u001b[0;36minner_func\u001b[0;34m(hparams, params)\u001b[0m\n\u001b[1;32m     18\u001b[0m     \u001b[0mw\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     19\u001b[0m     \u001b[0mdiff\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mw\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mhp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 20\u001b[0;31m     \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mNesterov_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mw\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m0.5\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mdiff\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mdiff\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     21\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     22\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'Nesterov_func' is not defined"
     ]
    }
   ],
   "source": [
    "hparams = [torch.ones(d).requires_grad_(True)]\n",
    "\n",
    "outer_opt = torch.optim.SGD(lr=beta, momentum=.9, params=hparams)\n",
    "\n",
    "total_time, val_losses = 0,  []\n",
    "\n",
    "for k in range(K):\n",
    "    \n",
    "    step_start_time = time.time() \n",
    "    inner_losses = []\n",
    "    params = inner_solver(hparams, steps=T)\n",
    "    t1 = time.time() - step_start_time # inner loop time\n",
    "\n",
    "    outer_opt.zero_grad()\n",
    "    _, cost = hg.hgvzoj(params, hparams, outer_func, inner_solver, mu=mu, T=T, p=d, set_grad=True)\n",
    "    t2 = time.time() - step_start_time - t1 # hypergrad estimation time \n",
    "    val_losses.append(cost.item())\n",
    "    outer_opt.step()\n",
    "    \n",
    "    step_time = time.time()-step_start_time\n",
    "    total_time +=step_time\n",
    "\n",
    "    if k % eval_interval == 0 or k == K - 1:\n",
    "        print('outer step={} ({:.2e}s)({:.2e}, {:.2e}) val loss={} '.format(k, step_time, t1, t2, val_losses[-1]))\n",
    "\n",
    "print('total time = {}'.format(total_time))\n",
    "\n",
    "plt.title('validation loss')\n",
    "plt.xlabel('outer steps')\n",
    "plt.plot(val_losses)\n",
    "#plt.savefig('plots/val_loss2.png', bbox_inches='tight')\n",
    "#plt.close()\n",
    "plt.show()\n",
    "val_zoj = val_losses\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
