{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Preamble"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# render matplotlib figures inline in the notebook output\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.functional as F\n",
    "import torch.utils.data as Data\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.animation as animation\n",
    "\n",
    "import numpy as np\n",
    "import imageio\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global matplotlib styling shared by every figure in this notebook.\n",
    "# NOTE: text.usetex=True needs a working LaTeX installation, and the ffmpeg\n",
    "# animation writer must be available for ani.save(...) further down.\n",
    "plt.rcParams.update({\n",
    "    \"animation.writer\": \"ffmpeg\",\n",
    "    \"font.family\": \"serif\",  # use serif/main font for text elements\n",
    "    \"font.size\": 12,\n",
    "    \"text.usetex\": True,     # render all text (labels, ticks, annotations) with LaTeX\n",
    "    \"pgf.rcfonts\": False,    # don't setup fonts from rc parameters\n",
    "    \"hist.bins\": 20,         # default number of bins in histograms\n",
    "    # Matplotlib expects pgf.preamble as a single string (list values are deprecated),\n",
    "    # so the individual preamble lines are joined with newlines.\n",
    "    \"pgf.preamble\": \"\\n\".join([\n",
    "        r\"\\usepackage{units}\",          # load additional packages\n",
    "        r\"\\usepackage{metalogo}\",\n",
    "        r\"\\usepackage{unicode-math}\",   # unicode math setup\n",
    "        r\"\\setmathfont{xits-math.otf}\",\n",
    "        r\"\\setmainfont{DejaVu Serif}\",  # serif font via preamble\n",
    "        r\"\\usepackage{color}\",\n",
    "    ]),\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAnAAAAEICAYAAADfismSAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAASlklEQVR4nO3dsW5cR7on8P+348DAAIs2MQIc+RqcF7hLEfsA11Q0KS3sCyyNjZQsBnszhytnCqX7BGMZmzmSZsOJKL6BCNzIBnghNxYw4GTwbcBDTZPulps2ydOn9fsBBLvqVHWXfVDw36eqmtXdAQBgOv7T2AMAAOB6BDgAgIkR4AAAJkaAAwCYGAEOAGBiBDgAgIn5YOwB3KU//OEP/emnn449DACAX/Tq1av/6O57y669VwHu008/zfHx8djDAAD4RVX176uuWUIFAJgYAQ4AYGIEOACAiRHgAAAmZtRDDFU1S3KUZN7dz1a0OUwyTzJLctrdJ++qBwDYdmOfQt1/18Uh4D3o7i+G8oskD1bV3+5QAYD33ben3+bJyZN8/+P3+fj3H+fR3qP8afdPdz6OUQNcd7+sqt13NHmY5PVCeV5VezkPfj+r9xQOALgt355+my//9mV++vtPSZLvfvwuX/7tyyS58xC36XvgZjlfJr3wJsnOO+oBAG7Fk5Mnb8PbhZ/+/lOenDy587FseoBbZnad+qo6qqrjqjo+Ozu7tUEBANvt+x+/v1b9bdr0ADfP5WC2k+T0HfU/093Punu/u/fv3Vv61ygAAH7Rx7//+Fr1t2nTA9zXSf64UJ4N+9xW1QMA3IpHe4/y4e8+vFT34e8+zKO9R3c+lrG/RuQg56dHZ1V12t0vh/pXST7r7nlVvRjaJcnjJFlVDwBwWy4OKmzCKdTq7jv/0LHs7++3P2YPAExBVb3q7qVfubbpS6gAAFwhwAEATIwABwAwMQIcAMDECHAAABMjwAEATIwABwAwMQIcAMDECHAAABMjwAEATIwABwAwMQIcAMDECHAAABMjwAEATIwABwAwMQIcAMDEjB7gquqwqg6G33tLrs+q6vGKvk+ram/4+fPtjxYAYHyjBriqmiV50N0vu/ubJMuC2m6So6p6Pfz8sBDWdpM8T/JFd391N6MGABjXByN//sMkrxfK86ra6+6TxUbd/dHF66o6HMJech74jpPs3PpIAQA2xNhLqLMk84Xym1wJY4thrqoOk7xcuLw7tF+5zAoAsG3GDnDLzJZVDsutu909v6jr7mfdfTqEvIOq2l3S76iqjqvq+Ozs7JaGDABwd8YOcPNcDmw7SU5XtP3XLDx9Gw49LB5ceLOs0xDy9rt7/969e79ttAAAG2DsAPd1kj8ulGdX978tOMjl5dbTXF5O3enuVeEPAGBrjHqIobvnVfWiqg6Gqrf72KrqVZLPriyZni68Phmewu3mfC/c53c0bACAUY19CjULJ0qv1t9/V/ldfQEAttnYS6gAAFyTAAcAMDECHADAxAhwAAATI8ABAEyMAAcAMDECHADAxAhwAAATI8ABAEyMAAcAMDECHADAxAhwAAATI8ABAEyMAAcAMDECHADAxAhwAAAT88HYA6iqwyTzJLMkp919sqTN0yRPh+JBd3+1bl8AgG0zaoCrqlmSB939xVB+keTBkqa7SZ4nebnQdt2+AABbZewl1IdJXi+U51W1t6Td4yT3h9/X7QsAsFXGDnCznC+BXniTZGdJu92hflZVFyFurb5VdVRVx1V1fHZ29ttHDAAwsrED3DKzqxXd/ay7L/a4HVTV7jX77nf3/r179252pAAAIxg7wM1zOXTtJDldbFBVh1X154WqN+v2BQDYRmMHuK+T/HGhPFtykvQ0ycuF8k53n67ZFwBg64x6CrW751X1oqoOhqq3hxSq6lWSz7r7ZHgKt5vzvXCf/1JfAIBtNvr3wHX3Nyvq76/RZmk9AMA2G3sJFQCAaxLgAAAmRoAD
AJgYAQ4AYGIEOACAiRHgAAAmRoADAJgYAQ4AYGIEOACAiRHgAAAmRoADAJgYAQ4AYGIEOACAiRHgAAAmRoADAJiYD8YeQFUdJpknmSU57e6TFW12ktxP8ry7Xw71T5M8HZoddPdXdzFmAIAxjRrgqmqW5EF3fzGUXyR5cKXNXs6D3TdD+YckHw2Xd5M8T/Ly4j0AALbd2EuoD5O8XijPh8C2aCfJYjh7s9Dmcc6fyj2+vSECAGyWsQPcLOfLpxfe5DywvdXdV5+u7Swss+4O7WdVJcQBAO+FsQPcMrNVF4aQ9vlFubufdffFvrmDqtpd0ueoqo6r6vjs7OxWBgwAcJfGDnDzXA5sO0lOlzUcDjL8ZeEAw2FV/XmhyZtl/YaQt9/d+/fu3buRQQMAjGnsU6hf5/L+tdmKU6h7SU66+3ThKdtpLoe9ne5eGv4AALbJqAGuu+dV9aKqDoaqt2Guql4l+Szn+9z+mvPDC8l5UPtoaHM4BLrdLCytAgBss7GfwOXi60GW1N8fXp7kH18bslZfAIBtNvYeOAAArkmAAwCYGAEOAGBiBDgAgIkR4AAAJkaAAwCYGAEOAGBiBDgAgIkR4AAAJkaAAwCYGAEOAGBiBDgAgIkR4AAAJkaAAwCYGAEOAGBiBDgAgIn5YOwBVNVhknmSWZLT7j5Zt806fQEAts1aT+Cq6n/exodX1SzJg+5+2d3fJHm8bpt1+gIAbKN1l1D/a1X9y/Dzn2/w8x8meb1QnlfV3ppt1ukLALB11lpC7e6HF6+r6p+r6iDnS5b/5zd+/iznS6AX3iTZWbPNOn0BALbOWgGuqv55ePnfkhwkeZHkr1X1X5J8lOS4u//fDY1p9hva/Ky+qo6SHCXJJ5988mvHBACwMdY9xPB/k/wlyTfd/b+uXqyqfxnaXNc8l0PXTpLTNdvsrKi/pLufJXmWJPv7+/0rxggAsFHW3QP337v7f3T3X69eqKr//Rs+/+skf1woz5acJF3VZp2+AABbp7rHfSi18FUgSZLufjnUv0ryWXfP39Fmaf0q+/v7fXx8fIOjBwC4HVX1qrv3l10b/Xvghq8AWVZ/f402S+sBALaZv8QAADAxAhwAwMQIcAAAEyPAAQBMjAAHADAxAhwAwMQIcAAAEyPAAQBMjAAHADAxAhwAwMQIcAAAEyPAAQBMjAAHADAxAhwAwMQIcAAAEyPAAQBMzAdjD6CqDpPMk8ySnHb3yYo2O0nuJ3ne3S+H+qdJng7NDrr7q7sYMwDAmEYNcFU1S/Kgu78Yyi+SPLjSZi/nwe6bofxDko+Gy7tJnid5efEeAADbbuwl1IdJXi+U50NgW7STZDGcvVlo8zjnT+Ue394QAQA2y9gBbpbz5dMLb3Ie2N7q7qtP13YWlll3h/azqhLiAID3wtgBbpnZqgtDSPv8otzdz7r7Yt/cQVXtLulzVFXHVXV8dnZ2KwMGALhLt7oHrqqOsiKQDQcO5leu7yQ5XfFeh0n+cvH0bSjvLhxceLPic54leZYk+/v7fd1/BgCATXOrAW4IT+/ydS7vX5utOIW6l+Sku08XnrKd5nLY2+nupeEPAGCbjHoKtbvnVfWiqg6GqrdhrqpeJfks5/vc/przwwvJeVD7aGhzOAS63SwsrQIAbLPRvwfu4utBltTfH16e5B9fG7JWXwCAbbaJhxgAAHgHAQ4AYGIEOACAiRHgAAAmRoADAJgYAQ4AYGIEOACAiRHgAAAmRoADAJgYAQ4AYGIEOACAiRHgAAAmRoADAJgYAQ4AYGIEOACAiflg7AFU1WGSeZJZktPuPlnS5mmSp0PxoLu/WrcvAMC2GTXAVdUsyYPu/mIov0jyYEnT3STPk7xcaLtuXwCArTL2EurDJK8XyvOq2lvS7nGS+8Pv6/YFANgqYwe4Wc6XQC+8SbKzpN3uUD+rqosQt25fAICtMvoeuCVmVyu6+9nF66r6t6raXbdvVR0lOUqSTz755GZGCAAwolsNcEN4mi27
NhxEmF+5vpPk9Mp7HCbZvTi4kPMnbVmn7/A5z5I8S5L9/f2+1j8AAMAGutUAt/jkbIWvc3lf22zJSdLTXA5mO919WlXr9AUA2DqjLqF297yqXlTVwVD1NpBV1askn3X3SVUdDsumu0k+/6W+AADbbPQ9cN39zYr6+2u0WVoPALDNxj6FCgDANQlwAAATI8ABAEyMAAcAMDECHADAxAhwAAATI8ABAEyMAAcAMDECHADAxAhwAAATI8ABAEyMAAcAMDECHADAxAhwAAATI8ABAEyMAAcAMDGjB7iqOqyqg+H33pLrs6p6vKLv06raG37+fPujBQAY36gBrqpmSR5098vu/ibJsqC2m+Soql4PPz8shLXdJM+TfNHdX93NqAEAxvXByJ//MMnrhfK8qva6+2SxUXd/dPG6qg6HsJecB77jJDu3PlIAgA0x9hLqLMl8ofwmV8LYYpirqsMkLxcu7w7tVy6zAgBsm7ED3DKzZZXDcutud88v6rr7WXefDiHvoKp2l/Q7qqrjqjo+Ozu7pSEDANydW11CraqjrAhkw561+ZXrO0lOV7zdvyb5y8J7H+Y80F3sfXuz4nOeJXmWJPv7+7324AEANtStBrghPL3L17l8cGF2df/bgoMkTxfKp7kc9na6e1X4AwDYGqMeYujueVW9qKqDoeptmKuqV0k+u7Jkerrw+mT46pHdnO+F+/yOhg0AMKqxT6Fm4UTp1fr77yq/qy8AwDbbxEMMAAC8gwAHADAxAhwAwMQIcAAAEyPAAQBMjAAHADAxAhwAwMQIcAAAEyPAAQBMjAAHADAxAhwAwMQIcAAAEyPAAQBMjAAHADAxAhwAwMQIcAAAE/PB2AOoqlmSoyTz7n62os1hknmSWZLT7j55Vz0AwDYbPcAl2X/XxSHgPejuL4byiyQPVtXf7lBX+/b02zw5eZLvf/w+H//+4zzae5Q/7f5prOEAAFts9ADX3S+ravcdTR4meb1QnlfVXs6D38/qx3gK9+3pt/nyb1/mp7//lCT57sfv8uXfvkwSIQ4AuHFT2AM3y/ky6YU3SXbeUX/nnpw8eRveLvz095/y5OTJGMMBALbcFALcMrN166vqqKqOq+r47OzsVgbz/Y/fX6seAOC3uNUl1Ko6yoqw1d1frfk28yvvsZPkNP94Cne1/urnPEvyLEn29/d7zc+8lo9//3G++/G7pfUAADftVgPcqlOl1/R1kscL5Vl3n1TV6bL6G/i8a3u09+jSHrgk+fB3H+bR3qMxhgMAbLnRDzFU1UHOT4/Oquq0u18O9a+SfNbd86p6MbRLhtC2qn4MFwcVnEIFAO5Cdd/KquJG2t/f7+Pj47GHAQDwi6rqVXcv/bq1qR5iAAB4bwlwAAATI8ABAEyMAAcAMDECHADAxLxXp1Cr6izJv9/yx/whyX/c8mdwfe7L5nFPNpP7snnck810F/fln7r73rIL71WAuwtVdbzqyC/jcV82j3uymdyXzeOebKax74slVACAiRHgAAAmRoC7eTfx91+5ee7L5nFPNpP7snnck8006n2xBw4AYGI8gQMAmJgPxh7AtqiqwyTzJLMkp919sqTN0yRPh+JBd391ZwN8D6x5D36xDTfL3NgsVTVLcpRk3t1Ll4DMk7u35n0xT+7YMBd2ktxP8ry7X65oM88dzxcB7gYME+9Bd38xlF8kebCk6W6S50leXrTlZqxzD65xn7gh5sZGeufXHpgno1nn6yjMkztUVXs5D2TfDOUfknx0pc0sI80XS6g342GS1wvl+XDjr3qc8xT/+E5G9X5Z5x6se5+4OebGhhmeIMzf0cQ8GcEa9yUxT+7aTpLFoPxmk/67IsDdjFkuT7w3Ob/xV+0O9bOqMgFv1iy/fA/WacPNmsXcmJpZzJNNZZ7coe6++qRzZ8ny6CwjzRdLqLdndrVicV9DVf1bVe129+mdjur9MruhNtys2dUKc2PjzcYeAObJmIbA/PmazWe3OJS3BLg1VNVRVtyQ
YRPp/Mr1nSSXJtWwyXF3YdPpm5se53tunl+4B2u24WbNY25MzTzmycYxT8Yz/Lv/y4rDCfOMNF8EuDWsOhG04Otc3pMwW3KjT3P5pu74P6cbtc49WKcNN8vcmB7zZDOZJyMY9rOddPdpVe0myZV/76PNF1/ke0MWjhEnebshNVX1Ksln3T0f2iTn+xi+Mflu1jXuwc/acHvMjc1SVQc535g9S/LYPNkM17gviXlyJ4bw9tf842nnTnd/NFwbfb4IcAAAE+MUKgDAxAhwAAATI8ABAEyMAAcAMDECHADAxAhwAAATI8ABAEyMAAcAMDECHMCvUFWHVfV6+D27eD32uID3g7/EAPArDX/+6EGSp8nP/kYiwK3xBA7gVxr+5uEsyYHwBtwlAQ7gt3mR5POxBwG8XyyhAvxKVbWXZJ5kL8lud3817oiA94UncAC/QlUdJXme5E2SkySPhzqAW+cJHADAxHgCBwAwMQIcAMDECHAAABMjwAEATIwABwAwMQIcAMDECHAAABMjwAEATMz/B4C3na4YnkBTAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 720x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "## Data\n",
    "# Orthogonal toy example (two points).\n",
    "x = torch.tensor([-0.5, 2.0]).reshape(-1, 1)\n",
    "y = torch.tensor([-1.0, 1.0]).reshape(-1, 1)\n",
    "# Non-orthogonal toy example: uncomment these two lines to use five points instead.\n",
    "#x = torch.tensor([-1.5, -1.0, 0.0, 0.25, 2.0]).reshape(-1, 1)\n",
    "#y = torch.tensor([0.7, -0.3, -0.3, 0.4, 0.2]).reshape(-1, 1)\n",
    "\n",
    "# Plain tensors are trained on directly: `torch.autograd.Variable` has been a\n",
    "# no-op wrapper since PyTorch 0.4, so no conversion is needed.\n",
    "\n",
    "# view data\n",
    "fig_data, ax_data = plt.subplots(figsize=(10, 4))\n",
    "ax_data.scatter(x.numpy(), y.numpy(), color=\"tab:green\")\n",
    "ax_data.set_xlabel('x')\n",
    "ax_data.set_ylabel('y')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Neural network architecture and initialisation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(torch.nn.Module):\n",
    "    \"\"\"\n",
    "    1 hidden layer Relu network architecture\n",
    "    \"\"\"\n",
    "    def __init__(self, n_feature, n_hidden, n_output, init_scale=1, bias_hidden=True, bias_output=False, balanced=True, **kwargs):\n",
    "        \"\"\"\n",
    "        n_feature: dimension of input\n",
    "        n_hidden: number of hidden neurons\n",
    "        n_output: dimension of output\n",
    "        init_scale: the weights are initialized ~ N(0, init_scale^2/(md)) where d is the input dimension of this layer (without bias) and m the width\n",
    "        bias_hidden: if True, use bias parameters in hidden layer. Use no bias otherwise\n",
    "        bias_output: if True, use bias parameters in output layer. Use no bias otherwise\n",
    "        balanced: if True, use a balanced initialisation\n",
    "        \"\"\"\n",
    "        super(Net, self).__init__()\n",
    "        self.init_scale = init_scale\n",
    "        \n",
    "        self.hidden = torch.nn.Linear(n_feature, n_hidden, bias=bias_hidden)   # hidden layer with rescaled init\n",
    "        torch.nn.init.normal_(self.hidden.weight.data, std=(init_scale/np.sqrt(n_hidden*n_feature)))\n",
    "        if bias_hidden:\n",
    "            torch.nn.init.normal_(self.hidden.bias.data, std=(init_scale/np.sqrt(n_hidden*n_feature)))\n",
    "            \n",
    "        self.predict = torch.nn.Linear(n_hidden, n_output, bias=bias_output)   # output layer with rescaled init\n",
    "        if balanced: # balanced initialisation\n",
    "            if bias_hidden:\n",
    "                neuron_norms = (self.hidden.weight.data.norm(dim=1).square()+self.hidden.bias.data.square()).sqrt()\n",
    "            else:\n",
    "                neuron_norms = (self.hidden.weight.data.norm(dim=1).square()).sqrt()\n",
    "            self.predict.weight.data = 2*torch.bernoulli(0.5*torch.ones_like(self.predict.weight.data)) -1\n",
    "            self.predict.weight.data *= neuron_norms\n",
    "        else:\n",
    "            torch.nn.init.normal_(self.predict.weight.data, std=(init_scale/np.sqrt(n_hidden)))\n",
    "        if bias_output:\n",
    "            torch.nn.init_normal_(self.predict.bias.data, std=(init_scale/np.sqrt(n_hidden)))\n",
    "            \n",
    "        self.activation = kwargs.get('activation', torch.nn.ReLU()) # activation of hidden layer\n",
    "\n",
    "    def forward(self, z):\n",
    "        z = self.activation(self.hidden(z))     \n",
    "        z = self.predict(z)             # linear output\n",
    "        return z"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Visualisation functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def multicolor_label(ax,list_of_strings,list_of_colors,axis='x',anchorpad=0,**kw):\n",
    "    \"\"\"this function creates axes labels with multiple colors\n",
    "    ax: specifies the axes object where the labels should be drawn\n",
    "    list_of_strings: a list of all of the text items\n",
    "    list_if_colors: a corresponding list of colors for the strings\n",
    "    axis:'x', 'y', or 'both' and specifies which label(s) should be drawn\"\"\"\n",
    "    from matplotlib.offsetbox import AnchoredOffsetbox, TextArea, HPacker, VPacker\n",
    "\n",
    "    # x-axis label\n",
    "    if axis=='x' or axis=='both':\n",
    "        boxes = [TextArea(text, textprops=dict(color=color, ha='left',va='bottom',**kw)) \n",
    "                    for text,color in zip(list_of_strings,list_of_colors) ]\n",
    "        xbox = HPacker(children=boxes,align=\"center\",pad=0, sep=60)\n",
    "        anchored_xbox = AnchoredOffsetbox(loc=3, child=xbox, pad=anchorpad,frameon=False,bbox_to_anchor=(0.27, -0.18),\n",
    "                                          bbox_transform=ax.transAxes, borderpad=0.)\n",
    "        ax.add_artist(anchored_xbox)\n",
    "\n",
    "    # y-axis label\n",
    "    if axis=='y' or axis=='both':\n",
    "        boxes = [TextArea(text, textprops=dict(color=color, ha='left',va='bottom',rotation=90,**kw)) \n",
    "                     for text,color in zip(list_of_strings[::-1],list_of_colors) ]\n",
    "        ybox = VPacker(children=boxes,align=\"center\", pad=0, sep=5)\n",
    "        anchored_ybox = AnchoredOffsetbox(loc=3, child=ybox, pad=anchorpad, frameon=False, bbox_to_anchor=(-0.10, 0.2), \n",
    "                                          bbox_transform=ax.transAxes, borderpad=0.)\n",
    "        ax.add_artist(anchored_ybox)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_single_frame(fig, arts, frame_number):\n",
    "    \"\"\"save as a pdf a single frame of an animation\n",
    "    fig: the figure to save\n",
    "    arts: list of images resulting in the animation\n",
    "    frame_number: the specific frame to save as a pdf\n",
    "    \"\"\"\n",
    "    # make sure everything is hidden\n",
    "    for frame_arts in arts:\n",
    "        for art in frame_arts:\n",
    "            art.set_visible(False)\n",
    "    # make the one artist we want visible\n",
    "    for i in range(len(arts[frame_number])):\n",
    "        arts[frame_number][i].set_visible(True)\n",
    "    fig.savefig(\"frame_{}.pdf\".format(frame_number))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Neuron alignment visualisation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████| 20000/20000 [00:04<00:00, 4890.36it/s]\n"
     ]
    }
   ],
   "source": [
    "# init network (seeded so the random initialisation, and thus the animation, is reproducible)\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "net = Net(n_feature=1, n_hidden=60, n_output=1, init_scale=1e-6, balanced=True)     # define the network\n",
    "\n",
    "optimizer = torch.optim.SGD(net.parameters(), lr=0.001) # Gradient descent\n",
    "loss_func = torch.nn.MSELoss(reduction='mean')  # mean squared error\n",
    "\n",
    "n_iterations = 20000 # number of descent steps\n",
    "\n",
    "# plot parameters\n",
    "# Frames are recorded at iterations 0 and 1, then at int(t_prev*iter_geom)+1 where t_prev is\n",
    "# the previously recorded iteration, so recorded iterations are roughly geometrically spaced.\n",
    "iter_geom = 1.05\n",
    "last_iter = 0\n",
    "frame = 0\n",
    "ims = []\n",
    "fig = plt.figure(\"Training dynamics\")\n",
    "plt.ioff()\n",
    "\n",
    "# Cosmetics\n",
    "c1 = 'tab:green' # color of left axis\n",
    "c2 = 'tab:blue' # color of right axis\n",
    "\n",
    "plt.subplots_adjust(left=0.15, right=0.85)\n",
    "\n",
    "# plain Python floats for the plot ranges\n",
    "x_min, x_max = x.min().item(), x.max().item()\n",
    "y_min, y_max = y.min().item(), y.max().item()\n",
    "\n",
    "ax1 = fig.add_subplot(111)\n",
    "ax1.set_xlim(x_min-1, x_max+1)\n",
    "ax1.set_ylim(y_min-2.6, y_max+1)\n",
    "ax2 = ax1.twinx()\n",
    "ax2.axhline(0, linestyle='--', alpha=0.5)\n",
    "ax1.set_ylabel(r'$h_{\\theta}(x)$', fontsize=20)\n",
    "ax2.set_ylabel(r'$\\mathsf{s}_j\\|w_j\\|$', fontsize=20)\n",
    "\n",
    "ax1.yaxis.label.set_color(c1)\n",
    "ax2.yaxis.label.set_color(c2)\n",
    "\n",
    "ax2.spines[\"left\"].set_edgecolor(c1)\n",
    "ax2.spines[\"right\"].set_edgecolor(c2)\n",
    "\n",
    "ax1.tick_params(axis='y', colors=c1)\n",
    "ax2.tick_params(axis='y', colors=c2)\n",
    "multicolor_label(ax1,(r'$x$',r'$-w_{j,2}/w_{j,1}$'),(c1,c2),axis='x', fontsize=20)\n",
    "#######\n",
    "\n",
    "losses = []  # training loss (float) at every iteration\n",
    "z = torch.Tensor(np.linspace(x_min-1, x_max+1, 100).reshape(-1,1))  # evaluation grid for plotting\n",
    "iters = []  # iteration number of each recorded frame (same order as ims)\n",
    "\n",
    "\n",
    "# train the network\n",
    "for it in tqdm(range(n_iterations)):\n",
    "    prediction = net(x)\n",
    "    loss = loss_func(prediction, y) # training loss\n",
    "\n",
    "    if (it<2 or it==int(last_iter*iter_geom)+1): # save frame in animation\n",
    "        with torch.no_grad():  # plotting only: no autograd graph needed\n",
    "            fz = net(z).numpy()  # current estimated function on the grid\n",
    "            # x-position of each neuron's kink, -b_j/w_j (assumes n_feature == 1 and a hidden bias)\n",
    "            kinks = -(net.hidden.bias.reshape(-1)/net.hidden.weight.reshape(-1)).numpy()\n",
    "            out_weights = net.predict.weight.reshape(-1).numpy().copy()\n",
    "        im1, = ax1.plot(z.numpy(), fz, '-', c=c1, lw=2, animated=True) # current estimated function\n",
    "        # showing each neuron individually\n",
    "        im2 = ax2.scatter(kinks, out_weights, animated=True, c=c2, marker='*')\n",
    "        t = ax1.annotate(\"iteration: \"+str(it),(0.4,0.95),xycoords='figure fraction',annotation_clip=False) # add text\n",
    "        if it == 0:\n",
    "            ax1.scatter(x.numpy(), y.numpy(), color=c1)\n",
    "        ims.append([im1,im2,t])\n",
    "        last_iter = it\n",
    "        iters.append(last_iter)\n",
    "        frame += 1\n",
    "\n",
    "    losses.append(loss.item())\n",
    "    optimizer.zero_grad()   # clear gradients for next train\n",
    "    loss.backward()         # backpropagation, compute gradients\n",
    "    optimizer.step()        # gradient descent step\n",
    "\n",
    "ani = animation.ArtistAnimation(fig, ims, interval=100, repeat=False)\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# write the animation to disk as an .mp4 (10 frames per second, 120 dpi)\n",
    "ani.save(filename='alignment.mp4', dpi=120, fps=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save specific frames"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop the animation object before exporting single frames in the next cell.\n",
    "# NOTE(review): presumably so the animation no longer controls artist visibility; confirm.\n",
    "del ani"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "it_to_save = 16178  # training iteration whose frame should be exported\n",
    "# Only iterations on the saving schedule are recorded in `iters`; if the requested one\n",
    "# was not recorded, fall back to the closest recorded iteration instead of raising.\n",
    "closest_it = min(iters, key=lambda i: abs(i - it_to_save))\n",
    "if closest_it != it_to_save:\n",
    "    print(\"iteration {} was not recorded; saving iteration {} instead\".format(it_to_save, closest_it))\n",
    "frame_idx = iters.index(closest_it)\n",
    "save_single_frame(fig, ims, frame_idx) # save specific frame of animation as a .pdf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Loss profile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig_loss, ax_loss = plt.subplots()\n",
    "ax_loss.plot(losses, lw=3)\n",
    "ax_loss.set_ylim(bottom=0)\n",
    "ax_loss.set_ylabel(r'$L(\\theta)$', fontsize=20)\n",
    "ax_loss.set_xlabel('Iterations', fontsize=20)\n",
    "ax_loss.grid(alpha=0.2)\n",
    "fig_loss.tight_layout()\n",
    "# no `fontsize` here: it is not a savefig parameter (Matplotlib deprecated passing unknown savefig kwargs)\n",
    "fig_loss.savefig('loss_profile.pdf')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
