{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.functional as F\n",
    "import torch.utils.data as Data\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.animation as animation\n",
    "\n",
    "import numpy as np\n",
    "import imageio\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global Matplotlib style: ffmpeg animation writer and serif, LaTeX-rendered text.\n",
    "plt.rcParams.update({\n",
    "    \"animation.writer\": \"ffmpeg\",\n",
    "    \"font.family\": \"serif\",  # use serif/main font for text elements\n",
    "    \"font.size\": 12,\n",
    "    \"text.usetex\": True,     # use inline math for ticks\n",
    "    \"pgf.rcfonts\": False,    # don't setup fonts from rc parameters\n",
    "    \"hist.bins\": 20, # default number of bins in histograms\n",
    "    # Recent Matplotlib releases require the preamble to be a single string\n",
    "    # (a list of lines is no longer supported), so the lines are joined explicitly.\n",
    "    \"pgf.preamble\": \"\\n\".join([\n",
    "         \"\\\\usepackage{units}\",          # load additional packages\n",
    "         \"\\\\usepackage{metalogo}\",\n",
    "         \"\\\\usepackage{unicode-math}\",   # unicode math setup\n",
    "         r\"\\setmathfont{xits-math.otf}\",\n",
    "         r\"\\setmainfont{DejaVu Serif}\",  # serif font via preamble\n",
    "         r'\\usepackage{color}',\n",
    "    ])\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAmoAAAEICAYAAADvBtizAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAPhUlEQVR4nO3dMW9cV3oG4PeLXBAQEIwJC1C1Mbi9kdDc/IBYUqMukI38gZWQis0WKdklWqRhuUq15dpKyUrylqoo/gMT2EoEtNAOAghgI5wUHEokTXKGMsl7NPM8gCDecw853xwdD1/fc89MtdYCAEB//m7oAgAAOJ2gBgDQKUENAKBTghoAQKcENQCATglqAACd+mzoAq7CF1980b788suhywAAmOrly5d/ba3dOu3cXAa1L7/8Mtvb20OXAQAwVVX95axzlj4BADolqAEAdEpQAwDolKAGANApQQ0AoFOC2gVt7W7l3tN7+eqPX+Xe03vZ2t0auiQAYE7N5dtzXJWt3a1svNjI/rv9JMmrt6+y8WIjSXJ/5f6AlQEA88gVtQvY3Nl8H9IO7b/bz+bO5kAVAQDzTFC7gL23exdqBwD4JQS1C7h98/aF2gEAfglB7QLWV9ezdGPpWNvSjaWsr64PVBEAMM9sJriAww0Dmzub2Xu7l9s3b2d9dd1GAgDgSghqF3R/5b5gBgBcC0ufAACdEtQAADolqAEAdEpQAwDolKAGANApQQ0AoFOCGgBApwQ1AIBOCWoAAJ0S1AAAOiWoAQB0avDP+qyqB0nGSUZJdltrO2f0e5hkN0laa8+vqz4AgKEMGtSqapTkbmvt0eT4WZK7p/T7IclvW2vjydeCGgAw94Ze+vwuyU9HjsdVtXq0w+HxJKSttNa+vc4CAQCGMnRQG+Vg2fPQmyTLJ/qsJe+vvo2q6vF1FAYAMLShg9ppRqcdt9bGk/vXVqtq5eQ3VdXDqtququ3Xr19ffZUAAFds6KA2zvFgtpzJhoEjdk+0jZP8LKi11p601tZaa2u3bt263CoBAAYwdFD7PsmvjxyPTtn1+TzHw9xKku0rrgsAYHCD7vqcbBB4VlV3Jk3v7z+rqpdJvjnS5+Hk1H+21sbXXSsAwHUb/H3UWmtPz2j/elofAIB5NvTSJwAAZxDUAAA6JagBAHRKUAMA6JSgBgDQKUENAKBTghoAQKcENQCATglqAACdEtQAADolqAEAdEpQAwDolKAGANApQQ0AoFOCGgBApwQ1AIBOCWoAAJ0S1AAAOiWoAQB0SlADAOiUoAYA0ClBDQCgU4IaAECnBDUAgE4JagAAnRLUAAA6JagBAHRKUAMA6JSgBgDQKUENAKBTghoAQKc+G7qAqnqQZJxklGS3tbYzrW9r7fn1VAeQbO1uZXNnM3tv93L75u2sr67n/sr9ocsCFsCgV9SqapTkbmvteWvtaZLHU/o+uqbSAJIchLSNFxt59fZVWlpevX2VjRcb2drdGro0YAEMvfT5XZKfjhyPq2r1jL5rSZ5dfUkAH2zubGb/3f6xtv13+9nc2RyoImCRDB3URjlY9jz0JsnyyU6T8LZ93g+qqodVtV1V269fv77MGoEFtvd270LtAJdp6KB2mtFpja218Xnf1Fp70lpba62t3bp16yrqAhbQ7Zu3L9QOcJmGDmrjHA9my0l2j3aoqodJViYbCX6T5G5VrVxXgcBiW19dz9KNpWNtSzeWsr66PlBFwCIZetfn9zm+gWB0ctdna+3J4ddV9Zskz1prx8IcwFU53N1p1ycwhEGDWmttXFXPqurOpOl9aKuql0m+OVzynNyntppkVFW7whpwXe6v3BfMgEEMfUUtk7flOK396xPHO0nuXktRAAAdGPoeNQAAziCoAQB0SlADAOiUoAYA0ClBDQCgU4IaAECnBDUAgE4JagAAnRLUAAA6JagBAHRKUAMA6JSgBgDQKUENAKBTghoAQKcENQCATglqAACdEtQA
ADolqAEAdEpQAwDolKAGANApQQ0AoFOCGgBApwQ1AIBOCWoAAJ0S1AAAOiWoAQB0SlADAOiUoAYA0ClBDQCgU4IaAECnPhu6gKp6kGScZJRkt7W2c0af5SRfJ/mhtfb8OmsEABjCTEGtqn7XWvvvy37wqholudtaezQ5fpbk7ok+qzkIcE8nx39L8vll1wIAcGhrdyubO5vZe7uX2zdvZ311PfdX7l97HbMuff5zVf3L5M/fX+Ljf5fkpyPH40kwO2o5yaMjx29O6QMAcCm2drey8WIjr96+SkvLq7evsvFiI1u7W9dey0xBrbX2XWvtz621PydZqarfVdW/XsLjj3Kw7HnoTQ6C2dHHfn54xW1i+bTlUQCAy7C5s5n9d/vH2vbf7WdzZ/Paa5l16fMfJ1/+W5I7SZ4l+bGq/ikHy5DbrbX/u6SaRufU8TjJt2ece5jkYZL86le/uqRSAIBFs/d270LtV2nWzQR/TvKnJE9ba/9x8mRV/cukz0WNczyYLSfZPa3jZEPBn866mtZae5LkSZKsra21j6gFACC3b97Oq7evTm2/brPeo/bb1tq/t9Z+PHmiqv7rFzz+90l+feR4dMauz9UkO621napaqaqVX/CYAABnWl9dz9KNpWNtSzeWsr66fu21zHRFrbX2v+ec+9kVtlm11sZV9ayq7kyaHh+eq6qXSb5JspLkxxxsIkgO7lGz6xMAuBKHuzt72PVZrc3fKuHa2lrb3t4eugwAgKmq6mVrbe20cz6ZAACgU4IaAECnBDUAgE4JagAAnRLUAAA6JagBAHRKUAMA6JSgBgDQKUENAKBTghoAQKcENQCATglqAACdEtQAADolqAEAdEpQAwDolKAGANApQQ0AoFOCGgBApwQ1AIBOCWoAAJ0S1AAAOiWoAQB0SlADAOiUoAYA0ClBDQCgU4IaAECnBDUAgE4JagAAnRLUAAA6JagBAHRKUAMA6NRnQxdQVQ+SjJOMkuy21nY+pg8AwLwZNKhV1SjJ3dbao8nxsyR3L9oHAGAeDb30+V2Sn44cj6tq9SP6AADMnaGD2igHS5qH3iRZ/og+qaqHVbVdVduvX7++1CIBAIYwdFA7zehj+rTWnrTW1lpra7du3br0ogAArtvQQW2c46FrOcnuR/QBAJg7Qwe175P8+sjx6JQdnbP0AQCYO4Pu+mytjavqWVXdmTQ9PjxXVS+TfHNeHwCAeTb4+6i11p6e0f71tD4AAPNs6KVPAADOIKgBAHRKUAMA6JSgBgDQKUENAKBTghoAQKcENQCATglqAACdEtQAADolqAEAdEpQAwDolKAGANApQQ0AoFOCGgBApwQ1AIBOCWoAAJ0S1AAAOiWoAQB0SlADAOiUoAYA0ClBDQCgU4IaAECnBDUAgE4JagAAnRLUAAA6JagBAHRKUAMA6JSgBgDQKUENAKBTghoAQKc+G7qAqnqQZJxklGS3tbZzRp/lJF8n+aG19vw6awQAGMKgQa2qRknuttYeTY6fJbl7os9qDgLc08nx35J8fs2lAgBcu6GXPr9L8tOR4/EkmB21nOTRkeM3p/QBAJg7Qy99jnKw7HnoTQ6C2XuTZc6jS53Lpy2PAgDMm6GvqJ1mdNaJqnqc5Nszzj2squ2q2n79+vVV1QYAcG2u9IpaVT3MGcGrtfb7fNhEcGg5ye4ZP+tBkj+ddTWttfYkyZMkWVtbax9bMwBAL640qE3C03m+T/L4yPHojF2fq0l2Wmu7VbUy+dmnBjoAgHkx6D1qrbVxVT2rqjuTpvehrapeJvkmyUqSH3OwiSA5uEfNrk8AYO4NvZkgh2+7cUr715Mvd+LtOACABdTjZgIAACKoAQB0S1ADAOiUoAYA0ClBDQCgU4IaAECnBDUAgE4JagAAnRLUAAA6JagBAHRKUAMA6JSgBgDQKUGNS7e1u5V7T+/lqz9+lXtP72Vrd2vokvjEmEMABz4bugDmy9buVjZebGT/3X6S5NXbV9l4sZEkub9y
f8DK+FSYQwAfuKLGpdrc2Xz/C/bQ/rv9bO5sDlQRnxpzCOADQY1Ltfd270LtcJI5BPCBoMalun3z9oXa4SRzCOADQY1Ltb66nqUbS8falm4sZX11faCK+NSYQwAf2EzApTq82XtzZzN7b/dy++btrK+uuwmcmZlDAB9Ua23oGi7d2tpa297eHroMAICpqupla23ttHOWPgEAOiWoAQB0SlADAOiUoAYA0ClBDQCgU3O567OqXif5yxU/zBdJ/nrFj/GpM0bnMz7TGaPzGZ/pjNH5jM901zFG/9Bau3XaibkMatehqrbP2krLAWN0PuMznTE6n/GZzhidz/hMN/QYWfoEAOiUoAYA0ClB7eM9GbqAT4AxOp/xmc4Ync/4TGeMzmd8pht0jNyjBgDQKVfUAAA69dnQBXxqqmqU5GGScWvt1MuhVfWHJH+YHN5prf3+msrrwoxj9CDJOMkoyW5rbee66hvaLM99kebQjOOxsPMlMWfO4/VmOr+3ppvMkeUkXyf5obX2/Iw+41zzPBLULm6WLborSX5I8ry19uiK6+nRuWM0edG4ezg2VfUsyd1rqGtwF3juCzGHZhmPRZ4viTkzA6830/m9dY6qWs1B8Ho6Of5bks9P9BlloHlk6fOCJil7PKXb4xyk8sdXXlCHZhij75L8dOR4PPkPZRHM+twXZQ7NMh6LPF8Sc+ZcXm+m83trquUkR8Ppm55ehwS1q7GSg3/4UVUt4qSfZpTjLxpvcjBei2CU2Z77osyhUaaPxyx95tko5swvMcpiz59ZLez8aa2dvIq4fMqy5igDzSNLn1fg6D0AVfU/VbXSWtsdsqZPwGjoAgY0Otmw4HNodEl95tnoZMOCz5mLGg1dQG/MnwOTkPrtjN1HV1jKe4LaEVX1MGcM/Kw3Vk5uNlw50v/N5VTXh8sYo3y4GfPQcpK5eEGYYXzGmfLc530OnTDO9LkwS595No4580uMs9jzZyrz58BkHP50xiaBcQaaR4LaEWfthrmg3Rz/x1uep/8ruaQx+j7H74MYzcsurBnGZ5bnPtdz6IRZxmNu58uMzJlfZtHnzywWfv5M7jfbaa3tVtVKkpwYg8HmkTe8vaCqupODmw5HSR4fbuGtqpdJvmmtjSepPDlY83+6gBN+1jEaH37PaVuh59VZz31R59AFxuNnfRaFOXM2rzfT+b11vklI+zEfriQut9Y+n5wbfB4JagAAnbLrEwCgU4IaAECnBDUAgE4JagAAnRLUAAA6JagBAHRKUAMA6JSgBgDQKUEN4BxV9aCqfpr8PTr8eui6gMXgkwkApph8BM/dJH9IfvYZgABXxhU1gCkmn+k3SnJHSAOuk6AGMJtnSb4dughgsVj6BJiiqlaTjJOsJllprf1+2IqAReGKGsA5quphkh+SvEmyk+TxpA3gyrmiBgDQKVfUAAA6JagBAHRKUAMA6JSgBgDQKUENAKBTghoAQKcENQCATglqAACd+n+NJ0+OLRsPewAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 720x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "## Data\n",
    "# Five one-dimensional training points, stored as float32 column tensors of shape (5, 1).\n",
    "# Plain tensors are used directly: the Variable wrapper is deprecated and returns the tensor unchanged.\n",
    "x = torch.tensor([-1.5, -1., 0, 0.25, 2], dtype=torch.float32).reshape(-1, 1)\n",
    "y = torch.tensor([0.7, -0.3, -0.3, 0.4, 0.2], dtype=torch.float32).reshape(-1, 1)\n",
    "\n",
    "# view data\n",
    "plt.figure(figsize=(10,4))\n",
    "plt.scatter(x.numpy(), y.numpy(), color = \"tab:green\")\n",
    "plt.xlabel('x')\n",
    "plt.ylabel('y')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Neural network architecture and initialisation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(torch.nn.Module):\n",
    "    \"\"\"\n",
    "    1 hidden layer Relu network\n",
    "    \"\"\"\n",
    "    def __init__(self, n_feature, n_hidden, n_output, init_scale=1, bias_hidden=True, bias_output=False, balanced=True, **kwargs):\n",
    "        \"\"\"\n",
    "        n_feature: dimension of input\n",
    "        n_hidden: number of hidden neurons\n",
    "        n_output: dimension of output\n",
    "        init_scale: standard deviation of the Gaussian initialisation, i.e. parameters are drawn ~ N(0, init_scale^2)\n",
    "        bias_hidden: if True, use bias parameters in hidden layer. Use no bias otherwise\n",
    "        bias_output: if True, use bias parameters in output layer. Use no bias otherwise\n",
    "        balanced: if True, use a balanced initialisation: every output weight is a random sign\n",
    "                  times the norm of the parameters (weights and bias) of the matching hidden neuron\n",
    "        kwargs:\n",
    "            activation: module applied to the hidden layer (default: torch.nn.ReLU())\n",
    "            zero_output: if True, pair up hidden neurons with opposite output weights so that\n",
    "                         the network outputs 0 at initialisation\n",
    "            skip_connection: if True, add a free affine skip connection (initialised at 0)\n",
    "                             from the input to the output\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.init_scale = init_scale\n",
    "\n",
    "        self.hidden = torch.nn.Linear(n_feature, n_hidden, bias=bias_hidden)   # hidden layer with rescaled init\n",
    "        torch.nn.init.normal_(self.hidden.weight.data, std=init_scale)\n",
    "        if bias_hidden:\n",
    "            torch.nn.init.normal_(self.hidden.bias.data, std=init_scale)\n",
    "\n",
    "        self.predict = torch.nn.Linear(n_hidden, n_output, bias=bias_output)   # output layer with rescaled init\n",
    "        if balanced: # balanced initialisation\n",
    "            weight_norms = self.hidden.weight.data.norm(dim=1)\n",
    "            if bias_hidden:\n",
    "                neuron_norms = (weight_norms.square() + self.hidden.bias.data.square()).sqrt()\n",
    "            else:\n",
    "                neuron_norms = weight_norms\n",
    "            # random signs in {-1, +1}, rescaled (broadcast over outputs) by the neuron norms\n",
    "            self.predict.weight.data = 2*torch.bernoulli(0.5*torch.ones_like(self.predict.weight.data)) - 1\n",
    "            self.predict.weight.data *= neuron_norms\n",
    "        else:\n",
    "            torch.nn.init.normal_(self.predict.weight.data, std=init_scale)\n",
    "        if bias_output:\n",
    "            torch.nn.init.normal_(self.predict.bias.data, std=init_scale)\n",
    "\n",
    "        self.activation = kwargs.get('activation', torch.nn.ReLU()) # activation of hidden layer\n",
    "\n",
    "        if kwargs.get('zero_output', False):\n",
    "            # ensure that the estimated function is 0 at initialisation:\n",
    "            # neuron half_n + k is a copy of neuron k with the opposite output weight\n",
    "            half_n = n_hidden // 2\n",
    "            self.hidden.weight.data[half_n:2*half_n] = self.hidden.weight.data[:half_n]\n",
    "            if bias_hidden:\n",
    "                self.hidden.bias.data[half_n:2*half_n] = self.hidden.bias.data[:half_n]\n",
    "            self.predict.weight.data[:, half_n:2*half_n] = -self.predict.weight.data[:, :half_n]\n",
    "            # when n_hidden is odd the last neuron has no partner: silence it\n",
    "            self.predict.weight.data[:, 2*half_n:] = 0\n",
    "            if bias_output:\n",
    "                # the output bias has shape (n_output,) and must vanish for a zero output\n",
    "                self.predict.bias.data.zero_()\n",
    "\n",
    "        self.skip_ = kwargs.get('skip_connection', False)\n",
    "        if self.skip_:\n",
    "            # add a free skip connection in the parameterisation\n",
    "            self.skip = torch.nn.Linear(n_feature, n_output, bias=True)\n",
    "            # initialise the skip connection as 0\n",
    "            self.skip.weight.data *= 0\n",
    "            self.skip.bias.data *= 0\n",
    "\n",
    "    def forward(self, z):\n",
    "        \"\"\"Return predict(activation(hidden(z))), plus skip(z) when the skip connection is enabled.\"\"\"\n",
    "        out = self.predict(self.activation(self.hidden(z)))\n",
    "        if self.skip_:\n",
    "            out = out + self.skip(z)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Visualisation functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def multicolor_label(ax,list_of_strings,list_of_colors,axis='x',anchorpad=0,**kw):\n",
    "    \"\"\"this function creates axes labels with multiple colors\n",
    "    ax: specifies the axes object where the labels should be drawn\n",
    "    list_of_strings: a list of all of the text items\n",
    "    list_of_colors: a corresponding list of colors for the strings (same order)\n",
    "    axis:'x', 'y', or 'both' and specifies which label(s) should be drawn\n",
    "    anchorpad: padding forwarded to the AnchoredOffsetbox holding the label\n",
    "    kw: extra text properties forwarded to every TextArea (e.g. fontsize)\"\"\"\n",
    "    from matplotlib.offsetbox import AnchoredOffsetbox, TextArea, HPacker, VPacker\n",
    "\n",
    "    # x-axis label\n",
    "    if axis=='x' or axis=='both':\n",
    "        boxes = [TextArea(text, textprops=dict(color=color, ha='left',va='bottom',**kw)) \n",
    "                    for text,color in zip(list_of_strings,list_of_colors) ]\n",
    "        xbox = HPacker(children=boxes,align=\"center\",pad=0, sep=60)\n",
    "        anchored_xbox = AnchoredOffsetbox(loc=3, child=xbox, pad=anchorpad,frameon=False,bbox_to_anchor=(0.27, -0.18),\n",
    "                                          bbox_transform=ax.transAxes, borderpad=0.)\n",
    "        ax.add_artist(anchored_xbox)\n",
    "\n",
    "    # y-axis label\n",
    "    if axis=='y' or axis=='both':\n",
    "        # strings and colors are reversed together so that every string keeps its own color\n",
    "        boxes = [TextArea(text, textprops=dict(color=color, ha='left',va='bottom',rotation=90,**kw)) \n",
    "                     for text,color in zip(list_of_strings[::-1],list_of_colors[::-1]) ]\n",
    "        ybox = VPacker(children=boxes,align=\"center\", pad=0, sep=5)\n",
    "        anchored_ybox = AnchoredOffsetbox(loc=3, child=ybox, pad=anchorpad, frameon=False, bbox_to_anchor=(-0.10, 0.2), \n",
    "                                          bbox_transform=ax.transAxes, borderpad=0.)\n",
    "        ax.add_artist(anchored_ybox)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_single_frame(fig, arts, frame_number):\n",
    "    \"\"\"Export one frame of an artist animation to 'frame_<frame_number>.pdf'.\n",
    "    fig: the figure holding the animation artists\n",
    "    arts: list of frames, each frame being the list of artists drawn for it\n",
    "    frame_number: index (in arts) of the frame to export\n",
    "    \"\"\"\n",
    "    # hide the artists of every frame ...\n",
    "    for artist in (a for frame_arts in arts for a in frame_arts):\n",
    "        artist.set_visible(False)\n",
    "    # ... then reveal only the artists of the requested frame\n",
    "    selected = arts[frame_number]\n",
    "    for idx in range(len(selected)):\n",
    "        selected[idx].set_visible(True)\n",
    "    fig.savefig(\"frame_{}.pdf\".format(frame_number))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Weight decay with biases"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████| 100000/100000 [00:26<00:00, 3725.41it/s]\n"
     ]
    }
   ],
   "source": [
    "# Train the network by full-batch gradient descent on MSE + reg_lambda * 0.5 * ||theta||^2,\n",
    "# where theta gathers every parameter whose name does not contain \"skip\" (biases included),\n",
    "# and store geometrically spaced frames of the training dynamics for the animation.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# init network\n",
    "net = Net(n_feature=1, n_hidden=200, n_output=1, init_scale=1/np.sqrt(200), skip_connection=False, balanced=False, zero_output=False)     # define the network\n",
    " \n",
    "optimizer = torch.optim.SGD(net.parameters(), lr=5e-2) #Gradient descent\n",
    "loss_func = torch.nn.MSELoss(reduction='mean')  # mean squared error\n",
    "reg_lambda = 1e-3 # scale of regularization by 2 norm of parameters\n",
    "\n",
    "n_samples = x.shape[0]\n",
    "n_iterations = 100000 # number of descent steps\n",
    "\n",
    "loss = torch.Tensor(np.array([0]))\n",
    "previous_loss = torch.Tensor(np.array([np.inf]))  # np.inf: the np.infty alias was removed in NumPy 2.0\n",
    "\n",
    "# plot parameters\n",
    "iter_geom = 2 # a frame is saved at iterations 0 and 1, then whenever it == int(last_iter*iter_geom)+1 (geometric spacing)\n",
    "last_iter = 0\n",
    "frame = 0\n",
    "ims = []\n",
    "fig = plt.figure(\"Training dynamics\")\n",
    "plt.ioff()\n",
    "\n",
    "# Cosmetics\n",
    "c1 = 'tab:green' # color of left axis\n",
    "c2 = 'tab:blue' # color of right axis\n",
    "\n",
    "plt.subplots_adjust(left=0.15, right=0.85)\n",
    "\n",
    "ax1 = fig.add_subplot(111)\n",
    "ax1.set_xlim(x.min()-1,x.max()+1)\n",
    "ax1.set_ylim(y.min()-1.6,y.max()+0.8)\n",
    "ax2 = ax1.twinx()\n",
    "ax2.axhline(0, linestyle='--', alpha=0.5)\n",
    "ax1.set_ylabel(r'$f_{\\theta}(x)$', fontsize=20)\n",
    "ax2.set_ylabel(r'$\\mathsf{s}_j\\|w_j\\|$', fontsize=20)\n",
    "\n",
    "ax1.yaxis.label.set_color(c1)\n",
    "ax2.yaxis.label.set_color(c2)\n",
    "\n",
    "ax2.spines[\"left\"].set_edgecolor(c1)\n",
    "ax2.spines[\"right\"].set_edgecolor(c2)\n",
    "\n",
    "ax1.tick_params(axis='y', colors=c1)\n",
    "ax2.tick_params(axis='y', colors=c2)\n",
    "multicolor_label(ax1,(r'$x$',r'$-w_{j,2}/w_{j,1}$'),(c1,c2),axis='x', fontsize=20)\n",
    "#######\n",
    "\n",
    "iters = []   # iteration number of every stored frame (same order as ims)\n",
    "losses = []  # regularised training loss at every iteration\n",
    "z = torch.Tensor(np.linspace(x.min()-1,x.max()+1,100).reshape(-1,1))\n",
    "\n",
    "# train the network\n",
    "for it in tqdm(range(n_iterations)):\n",
    "    previous_loss = loss\n",
    "    prediction = net(x)\n",
    "    # 0.5 * squared 2-norm of all the regularised parameters\n",
    "    reg_loss = None\n",
    "    for name, param in net.named_parameters():\n",
    "        if \"skip\" not in name:\n",
    "            if reg_loss is None:\n",
    "                reg_loss = 0.5 * torch.sum(param**2)\n",
    "            else:\n",
    "                reg_loss += 0.5 * param.norm(2)**2\n",
    "    loss = loss_func(prediction, y) + reg_lambda*reg_loss    # must be (1. nn output, 2. target)\n",
    "\n",
    "    if (it<2 or it==int(last_iter*iter_geom)+1): # save frame in animation\n",
    "        im1, = ax1.plot(z.data.numpy(), net(z).data.numpy(), '-', c=c1, lw=2, animated=True)\n",
    "        # one star per hidden neuron: x = -bias/weight of the neuron, y = its output weight\n",
    "        im2 = ax2.scatter(-(net.hidden.bias.data.reshape(-1)/net.hidden.weight.data.reshape(-1)).numpy(), net.predict.weight.data.reshape(-1).numpy(), animated=True, c=c2, marker='*')\n",
    "        t = ax1.annotate(\"iteration: \"+str(it),(0.4,0.95),xycoords='figure fraction',annotation_clip=False) # add text\n",
    "        if it == 0:\n",
    "            ax1.scatter(x.data.numpy(), y.data.numpy(), color=c1)\n",
    "        ims.append([im1,im2,t])\n",
    "        last_iter = it\n",
    "        iters.append(last_iter)\n",
    "        frame += 1\n",
    "\n",
    "    losses.append(loss.data.numpy())\n",
    "    optimizer.zero_grad()   # clear gradients for next train\n",
    "    loss.backward()         # backpropagation, compute gradients\n",
    "    optimizer.step()        # descent step\n",
    "    \n",
    "# plot last iterate\n",
    "im1, = ax1.plot(z.data.numpy(), net(z).data.numpy(), '-', c=c1, lw=2, animated=True)\n",
    "im2 = ax2.scatter(-(net.hidden.bias.data.reshape(-1)/net.hidden.weight.data.reshape(-1)).numpy(), net.predict.weight.data.reshape(-1).numpy(), animated=True, c=c2, marker='*')\n",
    "t = ax1.annotate(\"iteration: \"+str(it),(0.4,0.95),xycoords='figure fraction',annotation_clip=False) # add text\n",
    "ims.append([im1,im2,t])\n",
    "iters.append(it)\n",
    "\n",
    "#############\n",
    "    \n",
    "ani = animation.ArtistAnimation(fig, ims, interval=100, repeat=False)\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the stored frames to an MP4 file at 10 frames per second and 120 dpi.\n",
    "# NOTE(review): rcParams[\"animation.writer\"] is set to \"ffmpeg\" earlier in this notebook, so ffmpeg presumably has to be installed - confirm.\n",
    "ani.save('regularisation_full.mp4', fps=10, dpi=120) # save animation as video"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save specific frames"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop the notebook's reference to the animation; the per-frame artists remain available in `ims`.\n",
    "# NOTE(review): presumably done so the animation no longer manages those artists before single frames are exported - confirm.\n",
    "del ani"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `iters` only lists the iterations whose frame was stored, so an exact lookup with\n",
    "# iters.index() raises ValueError for any other iteration. Pick the stored frame that is\n",
    "# closest to the requested iteration instead (an exact match is still selected when it exists).\n",
    "it_to_save = 399999\n",
    "frame = min(range(len(iters)), key=lambda k: abs(iters[k] - it_to_save))\n",
    "save_single_frame(fig, ims, frame) # save specific frame of animation as a .pdf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save last frame"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the current state of `net` on the figure labelled \"Training dynamics\" and save it as a PDF.\n",
    "# Grid of 10000 points over the plotted x-range, used to draw the learned function.\n",
    "z = torch.Tensor(np.linspace(x.min()-1,x.max()+1,10000).reshape(-1,1))\n",
    "\n",
    "fig = plt.figure(\"Training dynamics\")\n",
    "plt.ioff()\n",
    "\n",
    "# Cosmetics\n",
    "c1 = 'tab:green' # color of left axis\n",
    "c2 = 'tab:blue' # color of right axis\n",
    "\n",
    "plt.subplots_adjust(left=0.15, right=0.85)\n",
    "\n",
    "ax1 = fig.add_subplot(111)\n",
    "ax1.set_xlim(x.min()-1,x.max()+1)\n",
    "ax1.set_ylim(y.min()-1.6,y.max()+0.8)\n",
    "# second y-axis (right side) sharing the x-axis of ax1\n",
    "ax2 = ax1.twinx()\n",
    "ax2.axhline(0, linestyle='--', alpha=0.5)\n",
    "ax1.set_ylabel(r'$f_{\\theta}(x)$', fontsize=20)\n",
    "ax2.set_ylabel(r'$\\mathsf{s}_j\\|w_j\\|$', fontsize=20)\n",
    "\n",
    "ax1.yaxis.label.set_color(c1)\n",
    "ax2.yaxis.label.set_color(c2)\n",
    "\n",
    "ax2.spines[\"left\"].set_edgecolor(c1)\n",
    "ax2.spines[\"right\"].set_edgecolor(c2)\n",
    "\n",
    "ax1.tick_params(axis='y', colors=c1)\n",
    "ax2.tick_params(axis='y', colors=c2)\n",
    "multicolor_label(ax1,(r'$x$',r'$-w_{j,2}/w_{j,1}$'),(c1,c2),axis='x', fontsize=20)\n",
    "#######\n",
    "\n",
    "# training points\n",
    "ax1.scatter(x.data.numpy(), y.data.numpy(), color=c1)\n",
    "# plot last iterate\n",
    "im1, = ax1.plot(z.data.numpy(), net(z).data.numpy(), '-', c=c1, lw=2, animated=True)\n",
    "# stars: x = -(hidden bias / hidden weight), y = output weight (flattening assumes n_feature=1, as in this notebook)\n",
    "im2 = ax2.scatter(-(net.hidden.bias.data.reshape(-1)/net.hidden.weight.data.reshape(-1)).numpy(), net.predict.weight.data.reshape(-1).numpy(), animated=True, c=c2, marker='*')\n",
    "\n",
    "ax1.grid(alpha=0.5)\n",
    "\n",
    "plt.savefig('regularisation_full_lastiter.pdf')\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Loss profile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regularised training loss (data fit + weight decay term) recorded at every iteration.\n",
    "plt.figure()\n",
    "plt.plot(losses, lw=3)\n",
    "plt.ylim(ymin=0)\n",
    "#plt.xlim(xmin=0, xmax=20000)\n",
    "plt.ylabel(r'$L(\\theta)$',fontsize=20)\n",
    "plt.xlabel('Iterations', fontsize=20)\n",
    "plt.grid(alpha=0.2)\n",
    "plt.tight_layout()\n",
    "# savefig has no fontsize option (label sizes are set above); recent Matplotlib rejects unknown keywords\n",
    "plt.savefig('loss_profile.pdf')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Weight decay without biases"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same experiment as the section with biases, except that only the parameters whose name\n",
    "# contains \"weight\" (and not \"skip\") are regularised: the biases are left unpenalised.\n",
    "# init network\n",
    "net = Net(n_feature=1, n_hidden=200, n_output=1, init_scale=1/np.sqrt(200), skip_connection=False, balanced=False, zero_output=False)     # define the network\n",
    " \n",
    "optimizer = torch.optim.SGD(net.parameters(), lr=5e-2) #Gradient descent\n",
    "loss_func = torch.nn.MSELoss(reduction='mean')  # mean squared error\n",
    "reg_lambda = 1e-3 # scale of regularization\n",
    "\n",
    "n_samples = x.shape[0]\n",
    "n_iterations = 8000000 # number of descent steps\n",
    "\n",
    "loss = torch.Tensor(np.array([0]))\n",
    "previous_loss = torch.Tensor(np.array([np.inf]))  # np.inf: the np.infty alias was removed in NumPy 2.0\n",
    "\n",
    "# plot parameters\n",
    "iter_geom = 1.5 # a frame is saved at iterations 0 and 1, then whenever it == int(last_iter*iter_geom)+1 (geometric spacing)\n",
    "last_iter = 0\n",
    "frame = 0\n",
    "ims = []\n",
    "fig = plt.figure(\"Training dynamics\")\n",
    "plt.ioff()\n",
    "\n",
    "# Cosmetics\n",
    "c1 = 'tab:green' # color of left axis\n",
    "c2 = 'tab:blue' # color of right axis\n",
    "\n",
    "plt.subplots_adjust(left=0.15, right=0.85)\n",
    "\n",
    "ax1 = fig.add_subplot(111)\n",
    "ax1.set_xlim(x.min()-1,x.max()+1)\n",
    "ax1.set_ylim(y.min()-1.6,y.max()+0.8)\n",
    "ax2 = ax1.twinx()\n",
    "ax2.axhline(0, linestyle='--', alpha=0.5)\n",
    "ax1.set_ylabel(r'$f_{\\theta}(x)$', fontsize=20)\n",
    "ax2.set_ylabel(r'$\\mathsf{s}_j\\|w_j\\|$', fontsize=20)\n",
    "\n",
    "ax1.yaxis.label.set_color(c1)\n",
    "ax2.yaxis.label.set_color(c2)\n",
    "\n",
    "ax2.spines[\"left\"].set_edgecolor(c1)\n",
    "ax2.spines[\"right\"].set_edgecolor(c2)\n",
    "\n",
    "ax1.tick_params(axis='y', colors=c1)\n",
    "ax2.tick_params(axis='y', colors=c2)\n",
    "multicolor_label(ax1,(r'$x$',r'$-w_{j,2}/w_{j,1}$'),(c1,c2),axis='x', fontsize=20)\n",
    "#######\n",
    "\n",
    "iters = []   # iteration number of every stored frame (same order as ims)\n",
    "losses = []  # regularised training loss at every iteration\n",
    "z = torch.Tensor(np.linspace(x.min()-1,x.max()+1,100).reshape(-1,1))\n",
    "\n",
    "# train the network\n",
    "for it in tqdm(range(n_iterations)):\n",
    "    previous_loss = loss\n",
    "    prediction = net(x)\n",
    "    # 0.5 * squared 2-norm of the regularised parameters (weights only)\n",
    "    reg_loss = None\n",
    "    for name, param in net.named_parameters():\n",
    "        if \"weight\" in name and \"skip\" not in name:\n",
    "            if reg_loss is None:\n",
    "                reg_loss = 0.5 * torch.sum(param**2)\n",
    "            else:\n",
    "                reg_loss += 0.5 * param.norm(2)**2\n",
    "    loss = loss_func(prediction, y) + reg_lambda*reg_loss    # must be (1. nn output, 2. target)\n",
    "\n",
    "    if (it<2 or it==int(last_iter*iter_geom)+1): # save frame in animation\n",
    "        im1, = ax1.plot(z.data.numpy(), net(z).data.numpy(), '-', c=c1, lw=2, animated=True)\n",
    "        # one star per hidden neuron: x = -bias/weight of the neuron, y = its output weight\n",
    "        im2 = ax2.scatter(-(net.hidden.bias.data.reshape(-1)/net.hidden.weight.data.reshape(-1)).numpy(), net.predict.weight.data.reshape(-1).numpy(), animated=True, c=c2, marker='*')\n",
    "        #t = ax1.annotate(\"iteration: \"+str(it)+\", frame: \"+str(frame),(0.4,0.95),xycoords='figure fraction',annotation_clip=False) # add text\n",
    "        t = ax1.annotate(\"iteration: \"+str(it),(0.4,0.95),xycoords='figure fraction',annotation_clip=False) # add text\n",
    "        if it == 0:\n",
    "            ax1.scatter(x.data.numpy(), y.data.numpy(), color=c1)\n",
    "        ims.append([im1,im2,t])\n",
    "        last_iter = it\n",
    "        iters.append(last_iter)\n",
    "        frame += 1\n",
    "\n",
    "    losses.append(loss.data.numpy())\n",
    "    optimizer.zero_grad()   # clear gradients for next train\n",
    "    loss.backward()         # backpropagation, compute gradients\n",
    "    optimizer.step()        # descent step\n",
    "    \n",
    "# plot last iterate\n",
    "im1, = ax1.plot(z.data.numpy(), net(z).data.numpy(), '-', c=c1, lw=2, animated=True)\n",
    "im2 = ax2.scatter(-(net.hidden.bias.data.reshape(-1)/net.hidden.weight.data.reshape(-1)).numpy(), net.predict.weight.data.reshape(-1).numpy(), animated=True, c=c2, marker='*')\n",
    "t = ax1.annotate(\"iteration: \"+str(it),(0.4,0.95),xycoords='figure fraction',annotation_clip=False) # add text\n",
    "ims.append([im1,im2,t])\n",
    "iters.append(it)\n",
    "\n",
    "#############\n",
    "    \n",
    "ani = animation.ArtistAnimation(fig, ims, interval=100, repeat=False)\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the stored frames to an MP4 file at 10 frames per second and 120 dpi.\n",
    "# NOTE(review): rcParams[\"animation.writer\"] is set to \"ffmpeg\" earlier in this notebook, so ffmpeg presumably has to be installed - confirm.\n",
    "ani.save('regularisation_nobias.mp4', fps=10, dpi=120) # save animation as video"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save specific frames"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop the notebook's reference to the animation; the per-frame artists remain available in `ims`.\n",
    "# NOTE(review): presumably done so the animation no longer manages those artists before single frames are exported - confirm.\n",
    "del ani"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `iters` only lists the iterations whose frame was stored, so an exact lookup with\n",
    "# iters.index() raises ValueError for any other iteration. Pick the stored frame that is\n",
    "# closest to the requested iteration instead (an exact match is still selected when it exists).\n",
    "it_to_save = 3999999\n",
    "frame = min(range(len(iters)), key=lambda k: abs(iters[k] - it_to_save))\n",
    "save_single_frame(fig, ims, frame) # save specific frame of animation as a .pdf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save last frame"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the current state of `net` on the figure labelled \"Training dynamics\" and save it as a PDF.\n",
    "# Grid of 10000 points over the plotted x-range, used to draw the learned function.\n",
    "z = torch.Tensor(np.linspace(x.min()-1,x.max()+1,10000).reshape(-1,1))\n",
    "\n",
    "fig = plt.figure(\"Training dynamics\")\n",
    "plt.ioff()\n",
    "\n",
    "# Cosmetics\n",
    "c1 = 'tab:green' # color of left axis\n",
    "c2 = 'tab:blue' # color of right axis\n",
    "\n",
    "plt.subplots_adjust(left=0.15, right=0.85)\n",
    "\n",
    "ax1 = fig.add_subplot(111)\n",
    "ax1.set_xlim(x.min()-1,x.max()+1)\n",
    "ax1.set_ylim(y.min()-1.6,y.max()+0.8)\n",
    "# second y-axis (right side) sharing the x-axis of ax1\n",
    "ax2 = ax1.twinx()\n",
    "ax2.axhline(0, linestyle='--', alpha=0.5)\n",
    "ax1.set_ylabel(r'$f_{\\theta}(x)$', fontsize=20)\n",
    "ax2.set_ylabel(r'$\\mathsf{s}_j\\|w_j\\|$', fontsize=20)\n",
    "\n",
    "ax1.yaxis.label.set_color(c1)\n",
    "ax2.yaxis.label.set_color(c2)\n",
    "\n",
    "ax2.spines[\"left\"].set_edgecolor(c1)\n",
    "ax2.spines[\"right\"].set_edgecolor(c2)\n",
    "\n",
    "ax1.tick_params(axis='y', colors=c1)\n",
    "ax2.tick_params(axis='y', colors=c2)\n",
    "multicolor_label(ax1,(r'$x$',r'$-w_{j,2}/w_{j,1}$'),(c1,c2),axis='x', fontsize=20)\n",
    "#######\n",
    "\n",
    "# training points\n",
    "ax1.scatter(x.data.numpy(), y.data.numpy(), color=c1)\n",
    "# plot last iterate\n",
    "im1, = ax1.plot(z.data.numpy(), net(z).data.numpy(), '-', c=c1, lw=2, animated=True)\n",
    "# stars: x = -(hidden bias / hidden weight), y = output weight (flattening assumes n_feature=1, as in this notebook)\n",
    "im2 = ax2.scatter(-(net.hidden.bias.data.reshape(-1)/net.hidden.weight.data.reshape(-1)).numpy(), net.predict.weight.data.reshape(-1).numpy(), animated=True, c=c2, marker='*')\n",
    "\n",
    "ax1.grid(alpha=0.5)\n",
    "\n",
    "plt.savefig('regularisation_nobias_lastiter.pdf')\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Loss profile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regularised training loss of the run without bias penalty, zoomed on the first 5000 iterations.\n",
    "plt.figure()\n",
    "plt.plot(losses, lw=3)\n",
    "plt.ylim(ymin=0,ymax=1)\n",
    "plt.xlim(xmin=0, xmax=5000)\n",
    "plt.ylabel(r'$L(\\theta)$',fontsize=20)\n",
    "plt.xlabel('Iterations', fontsize=20)\n",
    "plt.grid(alpha=0.2)\n",
    "plt.tight_layout()\n",
    "# savefig has no fontsize option (label sizes are set above); recent Matplotlib rejects unknown keywords.\n",
    "# A dedicated file name keeps this plot from overwriting the loss profile saved for the run with biases.\n",
    "plt.savefig('loss_profile_nobias.pdf')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
