{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Illustration of Forward-Backward-Forward Algorithm\n",
    "## Min-Max-Problem with box constraints (linear classifier model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The FBF algorithm was originally formulated for monotone inclusions and is also applied to variational inequality problems (VIPs).\n",
    "VIPs cover, in particular, the class of zero-sum games, i.e. min-max problems in which both players share one objective in two variables:\n",
    "\n",
    "$\\min\\limits_{x \\in H} \\max\\limits_{y \\in G} F(x, y)$\n",
    "\n",
    "This type of problem fits the Wasserstein GAN formulation with weight clipping (see https://arxiv.org/abs/1701.07875), where $x$ and $y$ are parametrisations of the generator and discriminator networks, respectively, and the constraint set $G$ is a $d$-dimensional cube."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Applied to this specific setting, the algorithm reads as follows:\n",
    "\n",
    "$u_k = P_H \\left[ x_k - \\alpha \\nabla_x F(x_k, y_k)\\right]$\n",
    "\n",
    "$v_k = P_G \\left[ y_k + \\alpha \\nabla_y F(x_k, y_k)\\right]$\n",
    "\n",
    "$x_{k+1} = u_k - \\alpha \\nabla_x F(u_k, v_k) + \\alpha \\nabla_x F(x_k, y_k)$\n",
    "\n",
    "$y_{k+1} = v_k + \\alpha \\nabla_y F(u_k, v_k) - \\alpha \\nabla_y F(x_k, y_k)$\n",
    "\n",
    "Convergence of the FBF method is guaranteed if $F(x, y)$ is differentiable (with a Lipschitz continuous gradient and a sufficiently small step size $\\alpha$), convex in $x$ and concave in $y$, and the constraint sets $H$ and $G$ are nonempty, closed and convex. This is a well-established result.\n",
    "\n",
    "In the absence of constraint sets (and thus of the projections), the method coincides with the so-called \"extra-gradient method\" (for its application to GANs see https://arxiv.org/abs/1802.10551)."
   ]
  },
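  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a short worked step (a sketch added for illustration), the reduction to the extra-gradient method can be checked directly. Assuming $H$ and $G$ are the whole spaces, both projections are the identity, so\n",
    "\n",
    "$u_k = x_k - \\alpha \\nabla_x F(x_k, y_k)$\n",
    "\n",
    "$x_{k+1} = u_k - \\alpha \\nabla_x F(u_k, v_k) + \\alpha \\nabla_x F(x_k, y_k) = x_k - \\alpha \\nabla_x F(u_k, v_k)$\n",
    "\n",
    "and analogously $y_{k+1} = y_k + \\alpha \\nabla_y F(u_k, v_k)$: the gradient is evaluated at the extrapolated point $(u_k, v_k)$ but applied at the current iterate $(x_k, y_k)$."
   ]
  },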
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The implementation of one iteration (e.g., to get from $x_{k}$ to $x_{k+1}$) is split into two phases:\n",
    "\n",
    "1. \"extrapolation\":\n",
    "    1. compute the update (either via SGD or Adam)\n",
    "    2. do the descent step\n",
    "    3. store the update (e.g., $- \\alpha \\nabla_x F(x_k, y_k)$)\n",
    "\n",
    "2. \"step\":\n",
    "    1. compute the update (either via SGD or Adam)\n",
    "    2. do the descent step and subtract the stored update\n",
    "    \n",
    "Note: The projection (for a $d$-dimensional cube this amounts to \"weight clipping\") is performed directly in the executable training file, e.g., \"train_fbfadam.py\", and is not part of the optimiser class."
   ]
  },
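  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Written out for the $x$-component with plain SGD, the two phases can be sketched as follows. The symbol $d_k$ for the stored update is introduced here purely for illustration:\n",
    "\n",
    "$d_k = - \\alpha \\nabla_x F(x_k, y_k)$ (stored during \"extrapolation\")\n",
    "\n",
    "$u_k = P_H \\left[ x_k + d_k \\right]$ (descent step, followed by the projection in the training file)\n",
    "\n",
    "$x_{k+1} = u_k - \\alpha \\nabla_x F(u_k, v_k) - d_k$ (\"step\")\n",
    "\n",
    "Subtracting $d_k$ adds back $\\alpha \\nabla_x F(x_k, y_k)$, which matches the update rule for $x_{k+1}$ stated above."
   ]
  },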
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The purpose of this notebook is to illustrate the implementation of the FBF method, in particular the two key methods of the FBF optimiser class, `extrapolation()` and `step()`. This is shown for \"FBFSGD\" (\"FBFAdam\" works in a similar fashion).\n",
    "As model we use a linear classifier (one fully connected layer without activation). Only one component is shown, since the algorithm does the same in both components apart from the opposite sign of the objective function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define toy instance of a neural network\n",
    "class LinClas(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(LinClas, self).__init__()\n",
    "        self.fc = nn.Linear(5, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.fc(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define toy example of a loss function\n",
    "def loss(x):\n",
    "    return x*x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# print function for optimiser (weights, gradient and copy of update)\n",
    "def print_opt():\n",
    "    for group in opt.param_groups:\n",
    "        for p in group['params']:\n",
    "            print(f\"Weights\\n{p}\")\n",
    "            print(f\"Gradient\\n{p.grad}\\n\")\n",
    "    print(f\"Updates_Copy:\\n{opt.updates_copy}\\n\")\n",
    "    print(f\"Old_Params_Copy:\\n{opt.old_params_copy}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# radius of d-dimensional cube\n",
    "clip = 0.25\n",
    "\n",
    "# input of neural network (whole batch)\n",
    "inp = torch.Tensor([-0.1, 0., 0.1, 0.2, 0.3])\n",
    "\n",
    "# step size = learning rate\n",
    "lr = 0.1"
   ]
  },
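  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# analytic gradient of the toy loss with respect to the weights of the linear layer\n",
    "def grad(output):\n",
    "    return 2*output*inp\n",
    "\n",
    "# detached copy of the current weights (first parameter tensor of the optimiser)\n",
    "def weights():\n",
    "    return opt.param_groups[0][\"params\"][0].data.clone().detach()\n",
    "\n",
    "# update stored for the weights during the extrapolation phase\n",
    "def update():\n",
    "    return opt.updates_copy[0]"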
   ]
  },
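  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The helper `grad` above relies on the closed-form gradient of the toy loss. As a short derivation (the symbols $w$, $b$ and $o$ are introduced here for illustration): with weights $w$, bias $b$ and input `inp`, the network output is\n",
    "\n",
    "$o = w^\\top \\mathrm{inp} + b$\n",
    "\n",
    "and the loss is $o^2$, so by the chain rule\n",
    "\n",
    "$\\nabla_w \\, o^2 = 2 \\, o \\cdot \\mathrm{inp}, \\qquad \\partial_b \\, o^2 = 2 \\, o$\n",
    "\n",
    "The first expression is what `grad(output)` returns; it should agree with the weight gradient computed by `backward()`."
   ]
  },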
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First we instantiate a fully connected (1-layer) neural network and have a look at the initial weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "OrderedDict([('fc.weight',\n",
       "              tensor([[-0.3931, -0.0523,  0.1615, -0.2265, -0.1716]])),\n",
       "             ('fc.bias', tensor([0.2289]))])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "A = LinClas()\n",
    "A.state_dict()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we import the FBF optimiser class \"FBFSGD\" and set up an instance with a fixed step size (= \"lr\"). To make it easy to keep track of what happens, no extras (e.g., \"Momentum\" or \"Nesterov\") are specified.\n",
    "To check that all parameters of the network are tracked, we call `print_opt()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Weights\n",
      "Parameter containing:\n",
      "tensor([[-0.3931, -0.0523,  0.1615, -0.2265, -0.1716]], requires_grad=True)\n",
      "Gradient\n",
      "None\n",
      "\n",
      "Weights\n",
      "Parameter containing:\n",
      "tensor([0.2289], requires_grad=True)\n",
      "Gradient\n",
      "None\n",
      "\n",
      "Updates_Copy:\n",
      "[]\n",
      "\n",
      "Old_Params_Copy:\n",
      "[]\n",
      "\n",
      "0\n"
     ]
    }
   ],
   "source": [
    "from optim import FBFSGD\n",
    "opt = FBFSGD(A.parameters(), lr = lr)\n",
    "print_opt()\n",
    "print(opt.inertia)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Iteration 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Computation of gradient"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compute output of network with respect to input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output: tensor([0.1876], grad_fn=<AddBackward0>)\n",
      "Loss: tensor([0.0352], grad_fn=<MulBackward0>)\n"
     ]
    }
   ],
   "source": [
    "outp = A(inp)\n",
    "print(f\"Output: {outp}\")\n",
    "lc_loss = loss(outp)\n",
    "print(f\"Loss: {lc_loss}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " Clear old gradients that where possibly stored."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Weights\n",
      "Parameter containing:\n",
      "tensor([[-0.3931, -0.0523,  0.1615, -0.2265, -0.1716]], requires_grad=True)\n",
      "Gradient\n",
      "None\n",
      "\n",
      "Weights\n",
      "Parameter containing:\n",
      "tensor([0.2289], requires_grad=True)\n",
      "Gradient\n",
      "None\n",
      "\n",
      "Updates_Copy:\n",
      "[]\n",
      "\n",
      "Old_Params_Copy:\n",
      "[]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "opt.zero_grad()\n",
    "print_opt()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Backpropagate the loss through the network to get gradients with respect to each weight."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Weights\n",
      "Parameter containing:\n",
      "tensor([[-0.3931, -0.0523,  0.1615, -0.2265, -0.1716]], requires_grad=True)\n",
      "Gradient\n",
      "tensor([[-0.0375,  0.0000,  0.0375,  0.0750,  0.1126]])\n",
      "\n",
      "Weights\n",
      "Parameter containing:\n",
      "tensor([0.2289], requires_grad=True)\n",
      "Gradient\n",
      "tensor([0.3752])\n",
      "\n",
      "Updates_Copy:\n",
      "[]\n",
      "\n",
      "Old_Params_Copy:\n",
      "[]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "lc_loss.backward()\n",
    "print_opt()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-0.0375,  0.0000,  0.0375,  0.0750,  0.1126], grad_fn=<MulBackward0>)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "grad(outp)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Extrapolation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.3893, -0.0523,  0.1577, -0.2340, -0.1829]], grad_fn=<SubBackward0>)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "weights() - lr*grad(outp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Weights\n",
      "Parameter containing:\n",
      "tensor([[-0.3893, -0.0523,  0.1577, -0.2340, -0.1829]], requires_grad=True)\n",
      "Gradient\n",
      "tensor([[-0.0375,  0.0000,  0.0375,  0.0750,  0.1126]])\n",
      "\n",
      "Weights\n",
      "Parameter containing:\n",
      "tensor([0.1914], requires_grad=True)\n",
      "Gradient\n",
      "tensor([0.3752])\n",
      "\n",
      "Updates_Copy:\n",
      "[tensor([[ 0.0038, -0.0000, -0.0038, -0.0075, -0.0113]]), tensor([-0.0375])]\n",
      "\n",
      "Old_Params_Copy:\n",
      "[]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "opt.extrapolation()\n",
    "print_opt()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Projection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Weights\n",
      "Parameter containing:\n",
      "tensor([[-0.2500, -0.0523,  0.1577, -0.2340, -0.1829]], requires_grad=True)\n",
      "Gradient\n",
      "tensor([[-0.0375,  0.0000,  0.0375,  0.0750,  0.1126]])\n",
      "\n",
      "Weights\n",
      "Parameter containing:\n",
      "tensor([0.1914], requires_grad=True)\n",
      "Gradient\n",
      "tensor([0.3752])\n",
      "\n",
      "Updates_Copy:\n",
      "[tensor([[ 0.0038, -0.0000, -0.0038, -0.0075, -0.0113]]), tensor([-0.0375])]\n",
      "\n",
      "Old_Params_Copy:\n",
      "[]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for p in A.parameters():\n",
    "    p.data.clamp_(-clip, clip)\n",
    "print_opt()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Computation of gradient"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compute output of network with respect to input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output: tensor([0.1305], grad_fn=<AddBackward0>)\n",
      "Loss: tensor([0.0170], grad_fn=<MulBackward0>)\n"
     ]
    }
   ],
   "source": [
    "outp = A(inp)\n",
    "print(f\"Output: {outp}\")\n",
    "lc_loss = loss(outp)\n",
    "print(f\"Loss: {lc_loss}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " Clear old gradients that where possibly stored."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Weights\n",
      "Parameter containing:\n",
      "tensor([[-0.2500, -0.0523,  0.1577, -0.2340, -0.1829]], requires_grad=True)\n",
      "Gradient\n",
      "tensor([[0., 0., 0., 0., 0.]])\n",
      "\n",
      "Weights\n",
      "Parameter containing:\n",
      "tensor([0.1914], requires_grad=True)\n",
      "Gradient\n",
      "tensor([0.])\n",
      "\n",
      "Updates_Copy:\n",
      "[tensor([[ 0.0038, -0.0000, -0.0038, -0.0075, -0.0113]]), tensor([-0.0375])]\n",
      "\n",
      "Old_Params_Copy:\n",
      "[]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "opt.zero_grad()\n",
    "print_opt()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Backpropagate the loss through the network to get gradients with respect to each weight."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Weights\n",
      "Parameter containing:\n",
      "tensor([[-0.2500, -0.0523,  0.1577, -0.2340, -0.1829]], requires_grad=True)\n",
      "Gradient\n",
      "tensor([[-0.0261,  0.0000,  0.0261,  0.0522,  0.0783]])\n",
      "\n",
      "Weights\n",
      "Parameter containing:\n",
      "tensor([0.1914], requires_grad=True)\n",
      "Gradient\n",
      "tensor([0.2610])\n",
      "\n",
      "Updates_Copy:\n",
      "[tensor([[ 0.0038, -0.0000, -0.0038, -0.0075, -0.0113]]), tensor([-0.0375])]\n",
      "\n",
      "Old_Params_Copy:\n",
      "[]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "lc_loss.backward()\n",
    "print_opt()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-0.0261,  0.0000,  0.0261,  0.0522,  0.0783], grad_fn=<MulBackward0>)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "grad(outp)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.2511, -0.0523,  0.1589, -0.2317, -0.1795]], grad_fn=<SubBackward0>)"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "weights() - lr*grad(outp) - update()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Weights\n",
      "Parameter containing:\n",
      "tensor([[-0.2511, -0.0523,  0.1589, -0.2317, -0.1795]], requires_grad=True)\n",
      "Gradient\n",
      "tensor([[-0.0261,  0.0000,  0.0261,  0.0522,  0.0783]])\n",
      "\n",
      "Weights\n",
      "Parameter containing:\n",
      "tensor([0.2028], requires_grad=True)\n",
      "Gradient\n",
      "tensor([0.2610])\n",
      "\n",
      "Updates_Copy:\n",
      "[]\n",
      "\n",
      "Old_Params_Copy:\n",
      "[]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "opt.step()\n",
    "print_opt()"
   ]
  },
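  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check of Iteration 1, the first weight can be followed by hand, using the rounded values printed above and $\\alpha = 0.1$ (so the last digit may differ slightly). Here $d_k$ denotes the stored update, a symbol introduced only for this illustration:\n",
    "\n",
    "$d_k = -0.1 \\cdot 2 \\cdot 0.1876 \\cdot (-0.1) \\approx 0.00375$\n",
    "\n",
    "$u_k = P_{[-0.25,\\, 0.25]} \\left[ -0.3931 + 0.00375 \\right] \\approx P_{[-0.25,\\, 0.25]} \\left[ -0.3893 \\right] = -0.25$\n",
    "\n",
    "$x_{k+1} = -0.25 - 0.1 \\cdot 2 \\cdot 0.1305 \\cdot (-0.1) - 0.00375 \\approx -0.2511$\n",
    "\n",
    "Note that $x_{k+1}$ lies slightly outside the box $[-0.25, 0.25]$: in the scheme stated at the top only the intermediate point $u_k$ is projected, not the final step."
   ]
  },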
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Iteration 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Computation of gradient"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compute output of network with respect to input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output: tensor([0.1436], grad_fn=<AddBackward0>)\n",
      "Loss: tensor([0.0206], grad_fn=<MulBackward0>)\n"
     ]
    }
   ],
   "source": [
    "outp = A(inp)\n",
    "print(f\"Output: {outp}\")\n",
    "lc_loss = loss(outp)\n",
    "print(f\"Loss: {lc_loss}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " Clear old gradients that where possibly stored."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Weights\n",
      "Parameter containing:\n",
      "tensor([[-0.2511, -0.0523,  0.1589, -0.2317, -0.1795]], requires_grad=True)\n",
      "Gradient\n",
      "tensor([[0., 0., 0., 0., 0.]])\n",
      "\n",
      "Weights\n",
      "Parameter containing:\n",
      "tensor([0.2028], requires_grad=True)\n",
      "Gradient\n",
      "tensor([0.])\n",
      "\n",
      "Updates_Copy:\n",
      "[]\n",
      "\n",
      "Old_Params_Copy:\n",
      "[]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "opt.zero_grad()\n",
    "print_opt()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Backpropagate the loss through the network to get gradients with respect to each weight."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Weights\n",
      "Parameter containing:\n",
      "tensor([[-0.2511, -0.0523,  0.1589, -0.2317, -0.1795]], requires_grad=True)\n",
      "Gradient\n",
      "tensor([[-0.0287,  0.0000,  0.0287,  0.0575,  0.0862]])\n",
      "\n",
      "Weights\n",
      "Parameter containing:\n",
      "tensor([0.2028], requires_grad=True)\n",
      "Gradient\n",
      "tensor([0.2873])\n",
      "\n",
      "Updates_Copy:\n",
      "[]\n",
      "\n",
      "Old_Params_Copy:\n",
      "[]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "lc_loss.backward()\n",
    "print_opt()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-0.0287,  0.0000,  0.0287,  0.0575,  0.0862], grad_fn=<MulBackward0>)"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "grad(outp)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Extrapolation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.2483, -0.0523,  0.1560, -0.2374, -0.1881]], grad_fn=<SubBackward0>)"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "weights() - lr*grad(outp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Weights\n",
      "Parameter containing:\n",
      "tensor([[-0.2483, -0.0523,  0.1560, -0.2374, -0.1881]], requires_grad=True)\n",
      "Gradient\n",
      "tensor([[-0.0287,  0.0000,  0.0287,  0.0575,  0.0862]])\n",
      "\n",
      "Weights\n",
      "Parameter containing:\n",
      "tensor([0.1741], requires_grad=True)\n",
      "Gradient\n",
      "tensor([0.2873])\n",
      "\n",
      "Updates_Copy:\n",
      "[tensor([[ 0.0029, -0.0000, -0.0029, -0.0057, -0.0086]]), tensor([-0.0287])]\n",
      "\n",
      "Old_Params_Copy:\n",
      "[]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "opt.extrapolation()\n",
    "print_opt()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Projection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Weights\n",
      "Parameter containing:\n",
      "tensor([[-0.2483, -0.0523,  0.1560, -0.2374, -0.1881]], requires_grad=True)\n",
      "Gradient\n",
      "tensor([[-0.0287,  0.0000,  0.0287,  0.0575,  0.0862]])\n",
      "\n",
      "Weights\n",
      "Parameter containing:\n",
      "tensor([0.1741], requires_grad=True)\n",
      "Gradient\n",
      "tensor([0.2873])\n",
      "\n",
      "Updates_Copy:\n",
      "[tensor([[ 0.0029, -0.0000, -0.0029, -0.0057, -0.0086]]), tensor([-0.0287])]\n",
      "\n",
      "Old_Params_Copy:\n",
      "[]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for p in A.parameters():\n",
    "    p.data.clamp_(-clip, clip)\n",
    "print_opt()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Computation of gradient"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compute output of network with respect to input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output: tensor([0.1106], grad_fn=<AddBackward0>)\n",
      "Loss: tensor([0.0122], grad_fn=<MulBackward0>)\n"
     ]
    }
   ],
   "source": [
    "outp = A(inp)\n",
    "print(f\"Output: {outp}\")\n",
    "lc_loss = loss(outp)\n",
    "print(f\"Loss: {lc_loss}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " Clear old gradients that where possibly stored."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Weights\n",
      "Parameter containing:\n",
      "tensor([[-0.2483, -0.0523,  0.1560, -0.2374, -0.1881]], requires_grad=True)\n",
      "Gradient\n",
      "tensor([[0., 0., 0., 0., 0.]])\n",
      "\n",
      "Weights\n",
      "Parameter containing:\n",
      "tensor([0.1741], requires_grad=True)\n",
      "Gradient\n",
      "tensor([0.])\n",
      "\n",
      "Updates_Copy:\n",
      "[tensor([[ 0.0029, -0.0000, -0.0029, -0.0057, -0.0086]]), tensor([-0.0287])]\n",
      "\n",
      "Old_Params_Copy:\n",
      "[]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "opt.zero_grad()\n",
    "print_opt()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Backpropagate the loss through the network to get gradients with respect to each weight."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Weights\n",
      "Parameter containing:\n",
      "tensor([[-0.2483, -0.0523,  0.1560, -0.2374, -0.1881]], requires_grad=True)\n",
      "Gradient\n",
      "tensor([[-0.0221,  0.0000,  0.0221,  0.0442,  0.0664]])\n",
      "\n",
      "Weights\n",
      "Parameter containing:\n",
      "tensor([0.1741], requires_grad=True)\n",
      "Gradient\n",
      "tensor([0.2212])\n",
      "\n",
      "Updates_Copy:\n",
      "[tensor([[ 0.0029, -0.0000, -0.0029, -0.0057, -0.0086]]), tensor([-0.0287])]\n",
      "\n",
      "Old_Params_Copy:\n",
      "[]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "lc_loss.backward()\n",
    "print_opt()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-0.0221,  0.0000,  0.0221,  0.0442,  0.0664], grad_fn=<MulBackward0>)"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "grad(outp)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.2489, -0.0523,  0.1566, -0.2361, -0.1861]], grad_fn=<SubBackward0>)"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "weights() - lr*grad(outp) - update()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Weights\n",
      "Parameter containing:\n",
      "tensor([[-0.2489, -0.0523,  0.1566, -0.2361, -0.1861]], requires_grad=True)\n",
      "Gradient\n",
      "tensor([[-0.0221,  0.0000,  0.0221,  0.0442,  0.0664]])\n",
      "\n",
      "Weights\n",
      "Parameter containing:\n",
      "tensor([0.1807], requires_grad=True)\n",
      "Gradient\n",
      "tensor([0.2212])\n",
      "\n",
      "Updates_Copy:\n",
      "[]\n",
      "\n",
      "Old_Params_Copy:\n",
      "[]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "opt.step()\n",
    "print_opt()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
