{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torchvision.models import vgg16\n",
    "\n",
    "from utils.plotting import cm2inch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pretrained VGG16 and a randomly initialised copy.\n",
    "# The random weights come from torch's global RNG, so seed it first;\n",
    "# otherwise the 'random' histograms below change on every run.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "tvgg = vgg16(pretrained = True)\n",
    "rvgg = vgg16(pretrained = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract the raw kernels of two convolution layers from both networks:\n",
    "# features[FIRST_CONV] (suffix 'f') and features[DEEP_CONV] (suffix 'l').\n",
    "FIRST_CONV, DEEP_CONV = 0, 7\n",
    "for idx in (FIRST_CONV, DEEP_CONV):\n",
    "    assert isinstance(tvgg.features[idx], nn.Conv2d), f\"features[{idx}] is not a Conv2d\"\n",
    "\n",
    "tkerf = tvgg.features[FIRST_CONV].weight.detach()\n",
    "rkerf = rvgg.features[FIRST_CONV].weight.detach()\n",
    "\n",
    "tkerl = tvgg.features[DEEP_CONV].weight.detach()\n",
    "rkerl = rvgg.features[DEEP_CONV].weight.detach()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reduce every output kernel to a single number: its mean over axes 1-3\n",
    "# (for Conv2d weights laid out as (out, in, kH, kW)).\n",
    "def _kernel_means(kernels):\n",
    "    return kernels.numpy().mean(axis = (1, 2, 3))\n",
    "\n",
    "tmeanf, rmeanf = _kernel_means(tkerf), _kernel_means(rkerf)\n",
    "tmeanl, rmeanl = _kernel_means(tkerl), _kernel_means(rkerl)\n",
    "\n",
    "# Only the first-layer means are clipped; clipped values land in the outermost bins.\n",
    "tmeanf, rmeanf = (np.clip(m, -0.03, 0.03) for m in (tmeanf, rmeanf))\n",
    "\n",
    "# 15-bin histograms -> (counts, bin edges) for each layer and network.\n",
    "(theighf, tbinsf), (rheighf, rbinsf) = (np.histogram(m, bins = 15) for m in (tmeanf, rmeanf))\n",
    "(theighl, tbinsl), (rheighl, rbinsl) = (np.histogram(m, bins = 15) for m in (tmeanl, rmeanl))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams.update ({'font.size' : 11, \"text.usetex\": False})\n",
    "\n",
    "# Same range as the clip applied to the first-layer kernel means.\n",
    "XLIM = 0.03\n",
    "\n",
    "# First-layer histograms as overlaid bars at the bin centres. Each bar is half\n",
    "# of its own histogram's bin width (+ a small pad).\n",
    "tX = 0.5 * (tbinsf[1:] + tbinsf[:-1])\n",
    "rX = 0.5 * (rbinsf[1:] + rbinsf[:-1])\n",
    "\n",
    "twidth = 0.5 * (tbinsf[1] - tbinsf[0]) + 0.001\n",
    "rwidth = 0.5 * (rbinsf[1] - rbinsf[0]) + 0.001  # random bars use the random histogram's bins\n",
    "\n",
    "fig, ax = plt.subplots(figsize = (cm2inch(5), cm2inch(4)))\n",
    "\n",
    "ax.bar(rX, rheighf, width = rwidth, bottom = 0.5, ec = 'k', color = 'darkseagreen', alpha = 0.8)\n",
    "ax.bar(tX, theighf, width = twidth, bottom = 0.5, ec = 'k', color = 'cornflowerblue', alpha = 0.8)\n",
    "\n",
    "# Trained second-layer histogram as a line. Its x positions must come from its\n",
    "# own bin edges (tbinsl), not from the random network's (rbinsl).\n",
    "tXl = 0.5 * (tbinsl[1:] + tbinsl[:-1])\n",
    "ax.plot(tXl, theighl, c = 'crimson', lw = 1.3)\n",
    "\n",
    "ax.set_xticks([-XLIM, 0, XLIM])\n",
    "ax.set_yticks([])\n",
    "\n",
    "ax.xaxis.set_tick_params(width = 1.5)\n",
    "ax.yaxis.set_tick_params(width = 1.5)\n",
    "for s in ax.spines.values(): s.set_linewidth(1.5)\n",
    "\n",
    "ax.spines['top'].set_visible(False)\n",
    "ax.spines['left'].set_visible(False)\n",
    "ax.spines['right'].set_visible(False)\n",
    "ax.spines['bottom'].set_bounds(-XLIM, XLIM)\n",
    "\n",
    "fig.tight_layout()\n",
    "# savefig does not create missing directories.\n",
    "Path('results').mkdir(exist_ok = True)\n",
    "fig.savefig('results/Kernel_weights.png', dpi = 300)\n",
    "fig.savefig('results/Kernel_weights.pdf', dpi = 300)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "d463e5fc653bab9d8e4f24359b8c55f1b66bd68a84a58e4a6e27c30ef7709220"
  },
  "kernelspec": {
   "display_name": "Python 3.8.10 64-bit ('SISSA': venv)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
