{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import optim\n",
    "\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.data import TensorDataset\n",
    "\n",
    "from scipy.stats import ksone\n",
    "from scipy.stats import normaltest, kstest, shapiro\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# Get survival rates for linear and uniform modes\n",
    "def get_survival_rates(depth, mode='none', survival_prop=0.5):\n",
    "    if mode == 'uniform':\n",
    "        survival_rates = [survival_prop] * depth\n",
    "    elif mode == 'linear':\n",
    "        survival_rates =  [1 - 2 * float(i + 1) * (1-survival_prop) / float(depth+1)\n",
    "                       for i in range(depth)]\n",
    "    else:\n",
    "        survival_rates = None\n",
    "    return survival_rates\n",
    "\n",
    "# Adjust the learning rate \n",
    "def adjust_lr(optimizer, lr, e, epochs):\n",
    "    if e >= epochs/2:\n",
    "        lr /= 10\n",
    "        if e >= 3*epochs/4:\n",
    "            lr /= 10\n",
    "            \n",
    "        for param_group in optimizer.param_groups:\n",
    "            param_group['lr'] = lr\n",
    "            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Set plotting parameters\n",
    "SMALL_SIZE = 14\n",
    "MEDIUM_SIZE = 16\n",
    "BIGGER_SIZE = 18\n",
    "\n",
    "plt.rc('font', size=SMALL_SIZE)          # controls default text sizes\n",
    "plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title\n",
    "plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels\n",
    "plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels\n",
    "plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels\n",
    "plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize\n",
    "plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Section 4"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Illustrating Equation 7\n",
    "\n",
    "Here we empirically verifie that the approximation of the Loss is valid"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Import simple ResNet model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from model import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Generate synthetic data\n",
    "\n",
    "We consider here a simple regression task, where the true model is given by\n",
    "\\\\( y = \\sin( X \\cdot \\beta ) \\\\)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/mnt/disk1/fadhel/py36/lib/python3.6/site-packages/ipykernel_launcher.py:12: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  if sys.path[0] == '':\n"
     ]
    }
   ],
   "source": [
    "# Create dataset\n",
    "n = 10000\n",
    "in_features = 256 # input dimension\n",
    "\n",
    "n_train = 8000\n",
    "\n",
    "bs = 512 # Batch size\n",
    "\n",
    "beta = np.random.normal(size=in_features)\n",
    "\n",
    "X = torch.tensor(np.random.normal(size=(n, in_features)), dtype=torch.float32)\n",
    "y = torch.tensor(np.sin(X @ beta.transpose()), dtype=torch.float32)\n",
    "\n",
    "x_train = X[:n_train, :]\n",
    "y_train = y[:n_train]\n",
    "\n",
    "train_ds = TensorDataset(x_train, y_train)\n",
    "train_dl = DataLoader(train_ds, batch_size=bs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Run experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/164 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Experiment  resnet50_linear_SD_0.5\n",
      "Run  0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 164/164 [03:45<00:00,  1.37s/it]\n",
      "  0%|          | 0/164 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Experiment  resnet50_linear_SD_0.5\n",
      "Run  1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/164 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-24-a4691a3a5a7a>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     90\u001b[0m                 \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     91\u001b[0m                     \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmc_samples\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 92\u001b[0;31m                         \u001b[0mmc_outputs\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mloss_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mxb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     93\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     94\u001b[0m                 \u001b[0mloss_samples\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmc_outputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/mnt/disk1/fadhel/py36/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    887\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    888\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    890\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    891\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/mnt/disk1/fadhel/stoch_depth/model.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m    115\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    116\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfc_in\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 117\u001b[0;31m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresnet_layers\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    118\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfc_out\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    119\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/mnt/disk1/fadhel/py36/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    887\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    888\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    890\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    891\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/mnt/disk1/fadhel/py36/lib/python3.6/site-packages/torch/nn/modules/container.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    117\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    118\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 119\u001b[0;31m             \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    120\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    121\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/mnt/disk1/fadhel/py36/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    884\u001b[0m             \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbw_hook\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msetup_input_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    885\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 886\u001b[0;31m         \u001b[0;32mif\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_tracing_state\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    887\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    888\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "import itertools\n",
    "\n",
    "# Define Simple ResNet Model\n",
    "depth = 50\n",
    "mode = 'linear'\n",
    "survival_prop = 0.8\n",
    "save_epochs = [0, 40, 80, 120, 160]\n",
    "mc_samples = 200\n",
    "n_runs = 10\n",
    "\n",
    "# Optimize parameters\n",
    "lr = 0.05\n",
    "\n",
    "device = \"cuda:0\"\n",
    "\n",
    "# Init experiment params\n",
    "epochs = 164\n",
    "\n",
    "experiments = []\n",
    "    \n",
    "for r in range(n_runs):\n",
    "    expe = {\n",
    "        'name': 'resnet{}_{}_SD_{}'.format(depth, mode, survival_prop),\n",
    "        'run': r,\n",
    "        'mode': mode,\n",
    "        'depth': depth,\n",
    "        'survival_prop': survival_prop,\n",
    "        'loss_history': [],\n",
    "        'losses_avg_model': [],\n",
    "        'losses_approx': [],\n",
    "        'losses_mc': [],\n",
    "    }\n",
    "    \n",
    "    experiments += [expe]\n",
    "    \n",
    "    print(\"Experiment \", expe['name'])\n",
    "    print(\"Run \", expe['run'])\n",
    "    \n",
    "    survival_rates = get_survival_rates(depth, mode, survival_prop)\n",
    "    \n",
    "    model = SimpleResNet(\n",
    "        in_features=in_features,\n",
    "        out_features=1,\n",
    "        h_features=128,\n",
    "        depth=depth,\n",
    "        block=BasicBlock,\n",
    "        res_scale=1./np.sqrt(depth),\n",
    "        survival_rates=survival_rates,\n",
    "    )\n",
    "    \n",
    "    model = model.to(device)\n",
    "    \n",
    "    optimizer = optim.SGD(model.parameters(), lr=lr)\n",
    "    \n",
    "    for e in tqdm(range(epochs)):       \n",
    "\n",
    "        adjust_lr(optimizer, lr, e, epochs)\n",
    "        loss_fn = nn.MSELoss(reduction='mean')\n",
    "        \n",
    "        avg_loss = 0\n",
    "        n_b = 0\n",
    "        for xb, yb in train_dl:\n",
    "            xb = xb.to(device)\n",
    "            yb = yb.to(device)\n",
    "\n",
    "            pred = model(xb)\n",
    "            loss = loss_fn(pred, yb)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            \n",
    "            avg_loss += loss.detach().cpu().numpy()\n",
    "            n_b += 1\n",
    "            \n",
    "        avg_loss /= n_b\n",
    "        expe['loss_history'] += [avg_loss]\n",
    "            \n",
    "        if e in save_epochs:\n",
    "            losses_approx = []\n",
    "            losses_avg_model = []\n",
    "            losses_mc = []\n",
    "            for xb, yb in train_dl:\n",
    "                mc_outputs = []\n",
    "                xb = xb.to(device)\n",
    "                yb = yb.to(device)\n",
    "\n",
    "                model.train()\n",
    "                with torch.no_grad():\n",
    "                    for i in range(mc_samples):\n",
    "                        mc_outputs += [loss_fn(model(xb), yb).cpu().detach().numpy()]\n",
    "\n",
    "                loss_samples = np.array(mc_outputs)\n",
    "\n",
    "                L = np.mean(loss_samples)\n",
    "\n",
    "                old_p = model.survival_rates.copy()\n",
    "\n",
    "                A = torch.ones((depth, len(xb), 1), requires_grad=True, device=device)\n",
    "                model.reset_survival_rates(A)\n",
    "\n",
    "                #loss = loss_fn(model(xb), yb)\n",
    "                optimizer.zero_grad()\n",
    "                model(xb).sum().backward()\n",
    "                coeffs = A.grad.detach().cpu().numpy().squeeze()\n",
    "                A.grad.zero_()\n",
    "                optimizer.zero_grad()\n",
    "\n",
    "                model.reset_survival_rates(old_p)\n",
    "                model.eval()\n",
    "                avg_model_loss = loss_fn(model(xb), yb)\n",
    "\n",
    "                p_fact = (np.array(old_p)*(1-np.array(old_p))).reshape(-1, 1)\n",
    "                pen = np.sum(p_fact*coeffs**2, axis=0)\n",
    "\n",
    "                L_approx = np.mean(pen)+avg_model_loss\n",
    "                losses_approx += [(L-L_approx).detach().cpu().numpy()]\n",
    "                losses_avg_model += [(L-avg_model_loss).detach().cpu().numpy()]\n",
    "                losses_mc += [L]\n",
    "                \n",
    "            model.train()\n",
    "\n",
    "            expe['losses_approx'] += [np.mean(losses_approx)]\n",
    "            expe['losses_avg_model'] += [np.mean(losses_avg_model)]\n",
    "            expe['losses_mc'] += [np.mean(losses_mc)]\n",
    "        "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "losses_approx = np.zeros((len(save_epochs), n_runs))\n",
    "losses_avg_model = np.zeros((len(save_epochs), n_runs))\n",
    "losses_mc = np.zeros((len(save_epochs), n_runs))\n",
    "\n",
    "for r in range(n_runs):\n",
    "    losses_approx[:, r] = experiments[r]['losses_approx']\n",
    "    losses_avg_model[:, r] = experiments[r]['losses_avg_model']\n",
    "    losses_mc[:, r] = experiments[r]['losses_mc']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Experiment resnet50_linear_SD_0.5\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>$ | \\mathcal{L} - \\bar{\\mathcal{L}} | $</th>\n",
       "      <th>$ | \\mathcal{L} - \\bar{\\mathcal{L}} - pen| $</th>\n",
       "      <th>Ratio</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>epoch</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.010671</td>\n",
       "      <td>0.008086</td>\n",
       "      <td>1.319673</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>40</th>\n",
       "      <td>0.265010</td>\n",
       "      <td>0.183009</td>\n",
       "      <td>1.448075</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>80</th>\n",
       "      <td>0.619548</td>\n",
       "      <td>0.466230</td>\n",
       "      <td>1.328847</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>120</th>\n",
       "      <td>0.802216</td>\n",
       "      <td>0.624529</td>\n",
       "      <td>1.284515</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>160</th>\n",
       "      <td>0.841235</td>\n",
       "      <td>0.635900</td>\n",
       "      <td>1.322904</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       $ | \\mathcal{L} - \\bar{\\mathcal{L}} | $  \\\n",
       "epoch                                            \n",
       "0                                     0.010671   \n",
       "40                                    0.265010   \n",
       "80                                    0.619548   \n",
       "120                                   0.802216   \n",
       "160                                   0.841235   \n",
       "\n",
       "       $ | \\mathcal{L} - \\bar{\\mathcal{L}} - pen| $     Ratio  \n",
       "epoch                                                          \n",
       "0                                          0.008086  1.319673  \n",
       "40                                         0.183009  1.448075  \n",
       "80                                         0.466230  1.328847  \n",
       "120                                        0.624529  1.284515  \n",
       "160                                        0.635900  1.322904  "
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print('Experiment resnet{}_{}_SD_{}'.format(depth, mode, survival_prop))\n",
    "res = pd.DataFrame()\n",
    "res['epoch'] = save_epochs\n",
    "res[r'$ | \\mathcal{L} - \\bar{\\mathcal{L}} | $'] = np.mean(np.abs(losses_avg_model)/np.array(losses_mc), axis=1)\n",
    "res[r'$ | \\mathcal{L} - \\bar{\\mathcal{L}} - pen| $'] = np.mean(np.abs(losses_approx)/np.array(losses_mc), axis=1)\n",
    "res['Ratio'] = np.mean(np.abs(losses_avg_model), axis=1) / np.mean(np.abs(losses_approx), axis=1)\n",
    "\n",
    "res.set_index('epoch')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py36",
   "language": "python",
   "name": "py36"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
