{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "7c42a7ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from reluNets.plotting.network_plotting_functions import plot_single_metric_axis, save_figure\n",
    "import matplotlib.colors as colors\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets, transforms\n",
    "import numpy as np\n",
    "import random\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
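    "# This notebook plots the TNR and TPR (true negative / true positive rates) of ... (NOTE(review): this description was cut off mid-sentence; please complete it)"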
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d3663df",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of parameters in layer 1: 401408\n",
      "Number of parameters in layer 2: 5120\n"
     ]
    }
   ],
   "source": [
    "# Train a small ReLU network on MNIST (flattened 784-pixel inputs, 10 output classes).\n",
    "# hyperparameters\n",
    "input_size = 784  # length of a flattened MNIST image\n",
    "hidden_size = 512  # width of the single hidden layer (fc1 outputs)\n",
    "output_size = 10  # one output per digit class (targets are one-hot of length 10)\n",
    "leak = 0.1  # NOTE(review): not referenced by any code visible in this notebook section\n",
    "bias_magnitude = 0  # value of the constant 'bias pixel' appended by AddBias / PerPixelNormalizeBias\n",
    "lr = 0.0005  # SGD learning rate\n",
    "init_scale = 0.0002  # std of the zero-mean Gaussian used to initialise the weights\n",
    "bias_init_scale = 0.0002  # std of the Gaussian used to initialise fc1's bias when has_bias is True\n",
    "has_bias = False  # whether the hidden layer fc1 has a learned bias\n",
    "num_epochs = 300  # upper bound; train() stops early once the epoch loss drops below loss_threshold\n",
    "\n",
    "# define a pytorch model (its weights are re-initialised as small Gaussians after construction)\n",
    "class ReluNet(torch.nn.Module):\n",
    "    \"\"\"Two-layer network: Linear -> ReLU -> Linear, with no bias on the output layer.\n",
    "\n",
    "    Args:\n",
    "        input_size: number of input features; inputs are flattened to this width.\n",
    "        hidden_size: width of the hidden ReLU layer.\n",
    "        output_size: number of outputs.\n",
    "        has_bias: whether the hidden layer ``fc1`` has a bias term.\n",
    "    \"\"\"\n",
    "    def __init__(self, input_size=input_size, hidden_size=hidden_size, output_size=output_size, has_bias=False):\n",
    "        super(ReluNet, self).__init__()\n",
    "        self.fc1 = torch.nn.Linear(input_size, hidden_size, bias=has_bias)\n",
    "        self.fc2 = torch.nn.Linear(hidden_size, output_size, bias=False)\n",
    "\n",
    "    def _flatten(self, x):\n",
    "        # Flatten to the width fc1 was built with rather than the notebook-global\n",
    "        # `input_size`, so models constructed with a different input_size still work.\n",
    "        return x.view(-1, self.fc1.in_features)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the raw outputs fc2(relu(fc1(x))) for a batch of inputs.\"\"\"\n",
    "        hidden = torch.relu(self.fc1(self._flatten(x)))\n",
    "        return self.fc2(hidden)\n",
    "\n",
    "    def get_hidden(self, x):\n",
    "        \"\"\"Return the post-ReLU hidden-layer activations relu(fc1(x)).\"\"\"\n",
    "        return torch.relu(self.fc1(self._flatten(x)))\n",
    "    \n",
    "# initialize the model\n",
    "# Pass has_bias through explicitly: with the constructor default (False) fc1.bias would be\n",
    "# None and the bias initialisation below would raise whenever has_bias is True.\n",
    "model = ReluNet(has_bias=has_bias)\n",
    "# initialize the weights as zero-mean Gaussians with std init_scale\n",
    "init_weights_1 = torch.randn_like(model.fc1.weight.data)*init_scale\n",
    "init_weights_2 = torch.randn_like(model.fc2.weight.data)*init_scale\n",
    "\n",
    "# initialize the biases\n",
    "if has_bias:\n",
    "    init_bias_1 = torch.randn_like(model.fc1.bias.data)*bias_init_scale\n",
    "    model.fc1.bias.data = init_bias_1\n",
    "\n",
    "model.fc1.weight.data = init_weights_1\n",
    "model.fc2.weight.data = init_weights_2\n",
    "\n",
    "# keep detached copies of the initial weights (taken before any training) for later use\n",
    "init_weights_1 = init_weights_1.detach().clone()\n",
    "init_weights_2 = init_weights_2.detach().clone()\n",
    "\n",
    "# note: these counts cover the weight matrices only (a bias, if present, is not included)\n",
    "print(f\"Number of parameters in layer 1: {model.fc1.weight.data.numel()}\")\n",
    "print(f\"Number of parameters in layer 2: {model.fc2.weight.data.numel()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "315154f7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing per-pixel means...\n",
      "Computed means for 60000 images. Mean range: [0.0000, 0.5473]\n",
      "Masking 140 pixels (out of 784) with mean > 0.5\n",
      "Masked pixel positions: [180, 181, 182, 183, 184, 185, 186, 207, 208, 209]...\n"
     ]
    }
   ],
   "source": [
    "# Per-pixel statistics must come from un-normalised images, so this copy of the\n",
    "# MNIST training split is only converted to tensors.\n",
    "raw_transform = transforms.Compose([transforms.ToTensor()])\n",
    "raw_train_dataset = datasets.MNIST(\n",
    "    root='./data', train=True, download=True, transform=raw_transform\n",
    ")\n",
    "num_samples = len(raw_train_dataset)\n",
    "\n",
    "print(\"Computing per-pixel means...\")\n",
    "# Accumulate a 784-long running sum over the training set in batches of 1000 images.\n",
    "pixel_sums = torch.zeros(784)\n",
    "temp_loader = DataLoader(raw_train_dataset, batch_size=1000, shuffle=False)\n",
    "for images, _ in temp_loader:\n",
    "    flattened = images.view(-1, 784)  # one row per image\n",
    "    pixel_sums += flattened.sum(dim=0)\n",
    "\n",
    "# Mean value of each pixel position over the whole training split.\n",
    "pixel_means = pixel_sums / num_samples\n",
    "mean_min, mean_max = pixel_means.min(), pixel_means.max()\n",
    "print(f\"Computed means for {num_samples} images. Mean range: [{mean_min:.4f}, {mean_max:.4f}]\")\n",
    "\n",
    "\n",
    "# Custom transform that subtracts per-pixel means and appends a constant bias pixel\n",
    "class PerPixelNormalizeBias:\n",
    "    \"\"\"Flatten an image, subtract per-pixel means and append one constant 'bias pixel'.\n",
    "\n",
    "    Args:\n",
    "        pixel_means: per-pixel means, broadcastable against the flattened image.\n",
    "        bias_magnitude: value of the extra element appended to the flattened vector.\n",
    "    \"\"\"\n",
    "    def __init__(self, pixel_means, bias_magnitude):\n",
    "        self.pixel_means = pixel_means\n",
    "        self.bias_magnitude = bias_magnitude  # Bias pixel magnitude\n",
    "    \n",
    "    def __call__(self, tensor):\n",
    "        \"\"\"Return the mean-subtracted, flattened image with the bias pixel appended.\"\"\"\n",
    "        flattened = tensor.view(-1) - self.pixel_means\n",
    "        # Use the magnitude stored on this instance, not the notebook-global\n",
    "        # `bias_magnitude`, so the transform honours the value it was built with.\n",
    "        bias_pixel = torch.ones(1) * self.bias_magnitude\n",
    "        return torch.cat((flattened, bias_pixel))\n",
    "\n",
    "class PerPixelNormalize:\n",
    "    \"\"\"Transform: flatten an image to 1-D and subtract the stored per-pixel means.\"\"\"\n",
    "\n",
    "    def __init__(self, pixel_means):\n",
    "        self.pixel_means = pixel_means\n",
    "\n",
    "    def __call__(self, tensor):\n",
    "        return tensor.view(-1) - self.pixel_means\n",
    "\n",
    "class AddBias:\n",
    "    \"\"\"Transform: flatten an image to 1-D and append one constant 'bias pixel'.\"\"\"\n",
    "\n",
    "    def __init__(self, bias_magnitude):\n",
    "        self.bias_magnitude = bias_magnitude  # value of the appended pixel\n",
    "\n",
    "    def __call__(self, tensor):\n",
    "        extra_pixel = torch.ones(1) * self.bias_magnitude\n",
    "        return torch.cat((tensor.view(-1), extra_pixel))\n",
    "\n",
    "# Training/test transform: convert to tensor, then flatten and subtract the per-pixel means\n",
    "transform = transforms.Compose([transforms.ToTensor(), PerPixelNormalize(pixel_means)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "1be892b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both splits use the same transform; its per-pixel means come from the training split\n",
    "train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)\n",
    "test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)\n",
    "\n",
    "# data loaders: shuffled mini-batches for training, fixed order for evaluation\n",
    "batch_size, test_batch_size = 64, 500\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "test_loader = DataLoader(test_dataset, shuffle=False, batch_size=test_batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3c3c954",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/300], Loss: 0.3475\n",
      "Epoch [2/300], Loss: 0.1429\n",
      "Epoch [3/300], Loss: 0.1184\n",
      "Epoch [4/300], Loss: 0.1077\n",
      "Epoch [5/300], Loss: 0.0990\n",
      "Epoch [6/300], Loss: 0.0916\n",
      "Epoch [7/300], Loss: 0.0851\n",
      "Epoch [8/300], Loss: 0.0796\n",
      "Epoch [9/300], Loss: 0.0750\n",
      "Epoch [10/300], Loss: 0.0711\n",
      "Epoch [11/300], Loss: 0.0678\n",
      "Epoch [12/300], Loss: 0.0649\n",
      "Epoch [13/300], Loss: 0.0624\n",
      "Epoch [14/300], Loss: 0.0602\n",
      "Epoch [15/300], Loss: 0.0581\n",
      "Epoch [16/300], Loss: 0.0564\n",
      "Epoch [17/300], Loss: 0.0547\n",
      "Epoch [18/300], Loss: 0.0532\n",
      "Epoch [19/300], Loss: 0.0519\n",
      "Epoch [20/300], Loss: 0.0505\n",
      "Epoch [21/300], Loss: 0.0494\n",
      "Epoch [22/300], Loss: 0.0483\n",
      "Epoch [23/300], Loss: 0.0473\n",
      "Epoch [24/300], Loss: 0.0463\n",
      "Epoch [25/300], Loss: 0.0454\n",
      "Epoch [26/300], Loss: 0.0446\n",
      "Epoch [27/300], Loss: 0.0438\n",
      "Epoch [28/300], Loss: 0.0430\n",
      "Epoch [29/300], Loss: 0.0423\n",
      "Epoch [30/300], Loss: 0.0417\n",
      "Epoch [31/300], Loss: 0.0410\n",
      "Epoch [32/300], Loss: 0.0404\n",
      "Epoch [33/300], Loss: 0.0398\n",
      "Epoch [34/300], Loss: 0.0393\n",
      "Epoch [35/300], Loss: 0.0388\n",
      "Epoch [36/300], Loss: 0.0383\n",
      "Epoch [37/300], Loss: 0.0378\n",
      "Epoch [38/300], Loss: 0.0373\n",
      "Epoch [39/300], Loss: 0.0369\n",
      "Epoch [40/300], Loss: 0.0365\n",
      "Epoch [41/300], Loss: 0.0361\n",
      "Epoch [42/300], Loss: 0.0357\n",
      "Epoch [43/300], Loss: 0.0353\n",
      "Epoch [44/300], Loss: 0.0349\n",
      "Epoch [45/300], Loss: 0.0346\n",
      "Epoch [46/300], Loss: 0.0342\n",
      "Epoch [47/300], Loss: 0.0339\n",
      "Epoch [48/300], Loss: 0.0336\n",
      "Epoch [49/300], Loss: 0.0333\n",
      "Epoch [50/300], Loss: 0.0330\n",
      "Epoch [51/300], Loss: 0.0327\n",
      "Epoch [52/300], Loss: 0.0325\n",
      "Epoch [53/300], Loss: 0.0322\n",
      "Epoch [54/300], Loss: 0.0319\n",
      "Epoch [55/300], Loss: 0.0316\n",
      "Epoch [56/300], Loss: 0.0314\n",
      "Epoch [57/300], Loss: 0.0311\n",
      "Epoch [58/300], Loss: 0.0309\n",
      "Epoch [59/300], Loss: 0.0307\n",
      "Epoch [60/300], Loss: 0.0305\n",
      "Epoch [61/300], Loss: 0.0302\n",
      "Epoch [62/300], Loss: 0.0300\n",
      "Epoch [63/300], Loss: 0.0298\n",
      "Epoch [64/300], Loss: 0.0296\n",
      "Epoch [65/300], Loss: 0.0294\n",
      "Epoch [66/300], Loss: 0.0292\n",
      "Epoch [67/300], Loss: 0.0290\n",
      "Epoch [68/300], Loss: 0.0288\n",
      "Epoch [69/300], Loss: 0.0286\n",
      "Epoch [70/300], Loss: 0.0285\n",
      "Epoch [71/300], Loss: 0.0283\n",
      "Epoch [72/300], Loss: 0.0281\n",
      "Epoch [73/300], Loss: 0.0279\n",
      "Epoch [74/300], Loss: 0.0278\n",
      "Epoch [75/300], Loss: 0.0276\n",
      "Epoch [76/300], Loss: 0.0275\n",
      "Epoch [77/300], Loss: 0.0273\n",
      "Epoch [78/300], Loss: 0.0272\n",
      "Epoch [79/300], Loss: 0.0270\n",
      "Epoch [80/300], Loss: 0.0269\n",
      "Epoch [81/300], Loss: 0.0267\n",
      "Epoch [82/300], Loss: 0.0266\n",
      "Epoch [83/300], Loss: 0.0265\n",
      "Epoch [84/300], Loss: 0.0263\n",
      "Epoch [85/300], Loss: 0.0262\n",
      "Epoch [86/300], Loss: 0.0261\n",
      "Epoch [87/300], Loss: 0.0259\n",
      "Epoch [88/300], Loss: 0.0258\n",
      "Epoch [89/300], Loss: 0.0257\n",
      "Epoch [90/300], Loss: 0.0256\n",
      "Epoch [91/300], Loss: 0.0255\n",
      "Epoch [92/300], Loss: 0.0253\n",
      "Epoch [93/300], Loss: 0.0252\n",
      "Epoch [94/300], Loss: 0.0251\n",
      "Epoch [1/300], Loss: 0.3487\n",
      "Epoch [2/300], Loss: 0.1437\n",
      "Epoch [3/300], Loss: 0.1186\n",
      "Epoch [4/300], Loss: 0.1079\n",
      "Epoch [5/300], Loss: 0.0991\n",
      "Epoch [6/300], Loss: 0.0915\n",
      "Epoch [7/300], Loss: 0.0851\n",
      "Epoch [8/300], Loss: 0.0796\n",
      "Epoch [9/300], Loss: 0.0751\n",
      "Epoch [10/300], Loss: 0.0712\n",
      "Epoch [11/300], Loss: 0.0679\n",
      "Epoch [12/300], Loss: 0.0650\n",
      "Epoch [13/300], Loss: 0.0625\n",
      "Epoch [14/300], Loss: 0.0603\n",
      "Epoch [15/300], Loss: 0.0583\n",
      "Epoch [16/300], Loss: 0.0565\n",
      "Epoch [17/300], Loss: 0.0548\n",
      "Epoch [18/300], Loss: 0.0534\n",
      "Epoch [19/300], Loss: 0.0520\n",
      "Epoch [20/300], Loss: 0.0507\n",
      "Epoch [21/300], Loss: 0.0495\n",
      "Epoch [22/300], Loss: 0.0484\n",
      "Epoch [23/300], Loss: 0.0474\n",
      "Epoch [24/300], Loss: 0.0465\n",
      "Epoch [25/300], Loss: 0.0455\n",
      "Epoch [26/300], Loss: 0.0447\n",
      "Epoch [27/300], Loss: 0.0439\n",
      "Epoch [28/300], Loss: 0.0432\n",
      "Epoch [29/300], Loss: 0.0425\n",
      "Epoch [30/300], Loss: 0.0418\n",
      "Epoch [31/300], Loss: 0.0412\n",
      "Epoch [32/300], Loss: 0.0406\n",
      "Epoch [33/300], Loss: 0.0400\n",
      "Epoch [34/300], Loss: 0.0394\n",
      "Epoch [35/300], Loss: 0.0389\n",
      "Epoch [36/300], Loss: 0.0384\n",
      "Epoch [37/300], Loss: 0.0379\n",
      "Epoch [38/300], Loss: 0.0374\n",
      "Epoch [39/300], Loss: 0.0370\n",
      "Epoch [40/300], Loss: 0.0366\n",
      "Epoch [41/300], Loss: 0.0362\n",
      "Epoch [42/300], Loss: 0.0358\n",
      "Epoch [43/300], Loss: 0.0354\n",
      "Epoch [44/300], Loss: 0.0350\n",
      "Epoch [45/300], Loss: 0.0347\n",
      "Epoch [46/300], Loss: 0.0343\n",
      "Epoch [47/300], Loss: 0.0340\n",
      "Epoch [48/300], Loss: 0.0337\n",
      "Epoch [49/300], Loss: 0.0333\n",
      "Epoch [50/300], Loss: 0.0331\n",
      "Epoch [51/300], Loss: 0.0328\n",
      "Epoch [52/300], Loss: 0.0325\n",
      "Epoch [53/300], Loss: 0.0322\n",
      "Epoch [54/300], Loss: 0.0320\n",
      "Epoch [55/300], Loss: 0.0317\n",
      "Epoch [56/300], Loss: 0.0314\n",
      "Epoch [57/300], Loss: 0.0312\n",
      "Epoch [58/300], Loss: 0.0309\n",
      "Epoch [59/300], Loss: 0.0307\n",
      "Epoch [60/300], Loss: 0.0305\n",
      "Epoch [61/300], Loss: 0.0303\n",
      "Epoch [62/300], Loss: 0.0300\n",
      "Epoch [63/300], Loss: 0.0298\n",
      "Epoch [64/300], Loss: 0.0296\n",
      "Epoch [65/300], Loss: 0.0294\n",
      "Epoch [66/300], Loss: 0.0292\n",
      "Epoch [67/300], Loss: 0.0290\n",
      "Epoch [68/300], Loss: 0.0289\n",
      "Epoch [69/300], Loss: 0.0287\n",
      "Epoch [70/300], Loss: 0.0285\n",
      "Epoch [71/300], Loss: 0.0283\n",
      "Epoch [72/300], Loss: 0.0281\n",
      "Epoch [73/300], Loss: 0.0280\n",
      "Epoch [74/300], Loss: 0.0278\n",
      "Epoch [75/300], Loss: 0.0277\n",
      "Epoch [76/300], Loss: 0.0275\n",
      "Epoch [77/300], Loss: 0.0273\n",
      "Epoch [78/300], Loss: 0.0272\n",
      "Epoch [79/300], Loss: 0.0271\n",
      "Epoch [80/300], Loss: 0.0269\n",
      "Epoch [81/300], Loss: 0.0267\n",
      "Epoch [82/300], Loss: 0.0266\n",
      "Epoch [83/300], Loss: 0.0265\n",
      "Epoch [84/300], Loss: 0.0264\n",
      "Epoch [85/300], Loss: 0.0262\n",
      "Epoch [86/300], Loss: 0.0261\n",
      "Epoch [87/300], Loss: 0.0260\n",
      "Epoch [88/300], Loss: 0.0258\n",
      "Epoch [89/300], Loss: 0.0257\n",
      "Epoch [90/300], Loss: 0.0256\n",
      "Epoch [91/300], Loss: 0.0255\n",
      "Epoch [92/300], Loss: 0.0254\n",
      "Epoch [93/300], Loss: 0.0252\n",
      "Epoch [94/300], Loss: 0.0251\n",
      "Epoch [95/300], Loss: 0.0250\n",
      "Epoch [1/300], Loss: 0.3503\n",
      "Epoch [2/300], Loss: 0.1437\n",
      "Epoch [3/300], Loss: 0.1182\n",
      "Epoch [4/300], Loss: 0.1075\n",
      "Epoch [5/300], Loss: 0.0988\n",
      "Epoch [6/300], Loss: 0.0912\n",
      "Epoch [7/300], Loss: 0.0847\n",
      "Epoch [8/300], Loss: 0.0792\n",
      "Epoch [9/300], Loss: 0.0747\n",
      "Epoch [10/300], Loss: 0.0708\n",
      "Epoch [11/300], Loss: 0.0675\n",
      "Epoch [12/300], Loss: 0.0646\n",
      "Epoch [13/300], Loss: 0.0621\n",
      "Epoch [14/300], Loss: 0.0599\n",
      "Epoch [15/300], Loss: 0.0580\n",
      "Epoch [16/300], Loss: 0.0561\n",
      "Epoch [17/300], Loss: 0.0545\n",
      "Epoch [18/300], Loss: 0.0530\n",
      "Epoch [19/300], Loss: 0.0516\n",
      "Epoch [20/300], Loss: 0.0504\n",
      "Epoch [21/300], Loss: 0.0492\n",
      "Epoch [22/300], Loss: 0.0481\n",
      "Epoch [23/300], Loss: 0.0470\n",
      "Epoch [24/300], Loss: 0.0461\n",
      "Epoch [25/300], Loss: 0.0452\n",
      "Epoch [26/300], Loss: 0.0443\n",
      "Epoch [27/300], Loss: 0.0435\n",
      "Epoch [28/300], Loss: 0.0428\n",
      "Epoch [29/300], Loss: 0.0421\n",
      "Epoch [30/300], Loss: 0.0414\n",
      "Epoch [31/300], Loss: 0.0408\n",
      "Epoch [32/300], Loss: 0.0402\n",
      "Epoch [33/300], Loss: 0.0396\n",
      "Epoch [34/300], Loss: 0.0390\n",
      "Epoch [35/300], Loss: 0.0385\n",
      "Epoch [36/300], Loss: 0.0380\n",
      "Epoch [37/300], Loss: 0.0375\n",
      "Epoch [38/300], Loss: 0.0370\n",
      "Epoch [39/300], Loss: 0.0366\n",
      "Epoch [40/300], Loss: 0.0362\n",
      "Epoch [41/300], Loss: 0.0357\n",
      "Epoch [42/300], Loss: 0.0354\n",
      "Epoch [43/300], Loss: 0.0349\n",
      "Epoch [44/300], Loss: 0.0346\n",
      "Epoch [45/300], Loss: 0.0342\n",
      "Epoch [46/300], Loss: 0.0339\n",
      "Epoch [47/300], Loss: 0.0336\n",
      "Epoch [48/300], Loss: 0.0332\n",
      "Epoch [49/300], Loss: 0.0329\n",
      "Epoch [50/300], Loss: 0.0326\n",
      "Epoch [51/300], Loss: 0.0323\n",
      "Epoch [52/300], Loss: 0.0320\n",
      "Epoch [53/300], Loss: 0.0318\n",
      "Epoch [54/300], Loss: 0.0315\n",
      "Epoch [55/300], Loss: 0.0312\n",
      "Epoch [56/300], Loss: 0.0310\n",
      "Epoch [57/300], Loss: 0.0307\n",
      "Epoch [58/300], Loss: 0.0305\n",
      "Epoch [59/300], Loss: 0.0303\n",
      "Epoch [60/300], Loss: 0.0300\n",
      "Epoch [61/300], Loss: 0.0298\n",
      "Epoch [62/300], Loss: 0.0296\n",
      "Epoch [63/300], Loss: 0.0294\n",
      "Epoch [64/300], Loss: 0.0292\n",
      "Epoch [65/300], Loss: 0.0290\n",
      "Epoch [66/300], Loss: 0.0288\n",
      "Epoch [67/300], Loss: 0.0286\n",
      "Epoch [68/300], Loss: 0.0284\n",
      "Epoch [69/300], Loss: 0.0282\n",
      "Epoch [70/300], Loss: 0.0281\n",
      "Epoch [71/300], Loss: 0.0279\n",
      "Epoch [72/300], Loss: 0.0277\n",
      "Epoch [73/300], Loss: 0.0275\n",
      "Epoch [74/300], Loss: 0.0274\n",
      "Epoch [75/300], Loss: 0.0273\n",
      "Epoch [76/300], Loss: 0.0271\n",
      "Epoch [77/300], Loss: 0.0269\n",
      "Epoch [78/300], Loss: 0.0268\n",
      "Epoch [79/300], Loss: 0.0266\n",
      "Epoch [80/300], Loss: 0.0265\n",
      "Epoch [81/300], Loss: 0.0264\n",
      "Epoch [82/300], Loss: 0.0262\n",
      "Epoch [83/300], Loss: 0.0261\n",
      "Epoch [84/300], Loss: 0.0259\n",
      "Epoch [85/300], Loss: 0.0258\n",
      "Epoch [86/300], Loss: 0.0257\n",
      "Epoch [87/300], Loss: 0.0256\n",
      "Epoch [88/300], Loss: 0.0255\n",
      "Epoch [89/300], Loss: 0.0253\n",
      "Epoch [90/300], Loss: 0.0252\n",
      "Epoch [91/300], Loss: 0.0251\n",
      "Epoch [1/300], Loss: 0.3472\n",
      "Epoch [2/300], Loss: 0.1426\n",
      "Epoch [3/300], Loss: 0.1184\n",
      "Epoch [4/300], Loss: 0.1078\n",
      "Epoch [5/300], Loss: 0.0990\n",
      "Epoch [6/300], Loss: 0.0914\n",
      "Epoch [7/300], Loss: 0.0849\n",
      "Epoch [8/300], Loss: 0.0793\n",
      "Epoch [9/300], Loss: 0.0746\n",
      "Epoch [10/300], Loss: 0.0707\n",
      "Epoch [11/300], Loss: 0.0674\n",
      "Epoch [12/300], Loss: 0.0645\n",
      "Epoch [13/300], Loss: 0.0620\n",
      "Epoch [14/300], Loss: 0.0597\n",
      "Epoch [15/300], Loss: 0.0577\n",
      "Epoch [16/300], Loss: 0.0559\n",
      "Epoch [17/300], Loss: 0.0543\n",
      "Epoch [18/300], Loss: 0.0528\n",
      "Epoch [19/300], Loss: 0.0514\n",
      "Epoch [20/300], Loss: 0.0501\n",
      "Epoch [21/300], Loss: 0.0490\n",
      "Epoch [22/300], Loss: 0.0479\n",
      "Epoch [23/300], Loss: 0.0469\n",
      "Epoch [24/300], Loss: 0.0459\n",
      "Epoch [25/300], Loss: 0.0451\n",
      "Epoch [26/300], Loss: 0.0443\n",
      "Epoch [27/300], Loss: 0.0435\n",
      "Epoch [28/300], Loss: 0.0427\n",
      "Epoch [29/300], Loss: 0.0421\n",
      "Epoch [30/300], Loss: 0.0414\n",
      "Epoch [31/300], Loss: 0.0407\n",
      "Epoch [32/300], Loss: 0.0402\n",
      "Epoch [33/300], Loss: 0.0396\n",
      "Epoch [34/300], Loss: 0.0390\n",
      "Epoch [35/300], Loss: 0.0385\n",
      "Epoch [36/300], Loss: 0.0380\n",
      "Epoch [37/300], Loss: 0.0375\n",
      "Epoch [38/300], Loss: 0.0371\n",
      "Epoch [39/300], Loss: 0.0366\n",
      "Epoch [40/300], Loss: 0.0362\n",
      "Epoch [41/300], Loss: 0.0358\n",
      "Epoch [42/300], Loss: 0.0354\n",
      "Epoch [43/300], Loss: 0.0350\n",
      "Epoch [44/300], Loss: 0.0347\n",
      "Epoch [45/300], Loss: 0.0343\n",
      "Epoch [46/300], Loss: 0.0340\n",
      "Epoch [47/300], Loss: 0.0336\n",
      "Epoch [48/300], Loss: 0.0333\n",
      "Epoch [49/300], Loss: 0.0330\n",
      "Epoch [50/300], Loss: 0.0327\n",
      "Epoch [51/300], Loss: 0.0324\n",
      "Epoch [52/300], Loss: 0.0321\n",
      "Epoch [53/300], Loss: 0.0319\n",
      "Epoch [54/300], Loss: 0.0316\n",
      "Epoch [55/300], Loss: 0.0314\n",
      "Epoch [56/300], Loss: 0.0311\n",
      "Epoch [57/300], Loss: 0.0309\n",
      "Epoch [58/300], Loss: 0.0307\n",
      "Epoch [59/300], Loss: 0.0304\n",
      "Epoch [60/300], Loss: 0.0302\n",
      "Epoch [61/300], Loss: 0.0300\n",
      "Epoch [62/300], Loss: 0.0298\n",
      "Epoch [63/300], Loss: 0.0296\n",
      "Epoch [64/300], Loss: 0.0294\n",
      "Epoch [65/300], Loss: 0.0292\n",
      "Epoch [66/300], Loss: 0.0290\n",
      "Epoch [67/300], Loss: 0.0288\n",
      "Epoch [68/300], Loss: 0.0286\n",
      "Epoch [69/300], Loss: 0.0284\n",
      "Epoch [70/300], Loss: 0.0282\n",
      "Epoch [71/300], Loss: 0.0281\n",
      "Epoch [72/300], Loss: 0.0279\n",
      "Epoch [73/300], Loss: 0.0278\n",
      "Epoch [74/300], Loss: 0.0276\n",
      "Epoch [75/300], Loss: 0.0274\n",
      "Epoch [76/300], Loss: 0.0273\n",
      "Epoch [77/300], Loss: 0.0271\n",
      "Epoch [78/300], Loss: 0.0270\n",
      "Epoch [79/300], Loss: 0.0269\n",
      "Epoch [80/300], Loss: 0.0267\n",
      "Epoch [81/300], Loss: 0.0266\n",
      "Epoch [82/300], Loss: 0.0264\n",
      "Epoch [83/300], Loss: 0.0263\n",
      "Epoch [84/300], Loss: 0.0262\n",
      "Epoch [85/300], Loss: 0.0261\n",
      "Epoch [86/300], Loss: 0.0259\n",
      "Epoch [87/300], Loss: 0.0258\n",
      "Epoch [88/300], Loss: 0.0257\n",
      "Epoch [89/300], Loss: 0.0256\n",
      "Epoch [90/300], Loss: 0.0254\n",
      "Epoch [91/300], Loss: 0.0253\n",
      "Epoch [92/300], Loss: 0.0252\n",
      "Epoch [93/300], Loss: 0.0251\n",
      "Epoch [1/300], Loss: 0.3479\n",
      "Epoch [2/300], Loss: 0.1437\n",
      "Epoch [3/300], Loss: 0.1180\n",
      "Epoch [4/300], Loss: 0.1073\n",
      "Epoch [5/300], Loss: 0.0986\n",
      "Epoch [6/300], Loss: 0.0911\n",
      "Epoch [7/300], Loss: 0.0846\n",
      "Epoch [8/300], Loss: 0.0792\n",
      "Epoch [9/300], Loss: 0.0746\n",
      "Epoch [10/300], Loss: 0.0707\n",
      "Epoch [11/300], Loss: 0.0675\n",
      "Epoch [12/300], Loss: 0.0646\n",
      "Epoch [13/300], Loss: 0.0622\n",
      "Epoch [14/300], Loss: 0.0600\n",
      "Epoch [15/300], Loss: 0.0581\n",
      "Epoch [16/300], Loss: 0.0563\n",
      "Epoch [17/300], Loss: 0.0547\n",
      "Epoch [18/300], Loss: 0.0533\n",
      "Epoch [19/300], Loss: 0.0519\n",
      "Epoch [20/300], Loss: 0.0507\n",
      "Epoch [21/300], Loss: 0.0495\n",
      "Epoch [22/300], Loss: 0.0485\n",
      "Epoch [23/300], Loss: 0.0475\n",
      "Epoch [24/300], Loss: 0.0466\n",
      "Epoch [25/300], Loss: 0.0457\n",
      "Epoch [26/300], Loss: 0.0449\n",
      "Epoch [27/300], Loss: 0.0441\n",
      "Epoch [28/300], Loss: 0.0434\n",
      "Epoch [29/300], Loss: 0.0427\n",
      "Epoch [30/300], Loss: 0.0420\n",
      "Epoch [31/300], Loss: 0.0414\n",
      "Epoch [32/300], Loss: 0.0408\n",
      "Epoch [33/300], Loss: 0.0402\n",
      "Epoch [34/300], Loss: 0.0396\n",
      "Epoch [35/300], Loss: 0.0391\n",
      "Epoch [36/300], Loss: 0.0386\n",
      "Epoch [37/300], Loss: 0.0381\n",
      "Epoch [38/300], Loss: 0.0377\n",
      "Epoch [39/300], Loss: 0.0372\n",
      "Epoch [40/300], Loss: 0.0368\n",
      "Epoch [41/300], Loss: 0.0364\n",
      "Epoch [42/300], Loss: 0.0360\n",
      "Epoch [43/300], Loss: 0.0356\n",
      "Epoch [44/300], Loss: 0.0352\n",
      "Epoch [45/300], Loss: 0.0348\n",
      "Epoch [46/300], Loss: 0.0345\n",
      "Epoch [47/300], Loss: 0.0342\n",
      "Epoch [48/300], Loss: 0.0338\n",
      "Epoch [49/300], Loss: 0.0335\n",
      "Epoch [50/300], Loss: 0.0332\n",
      "Epoch [51/300], Loss: 0.0329\n",
      "Epoch [52/300], Loss: 0.0326\n",
      "Epoch [53/300], Loss: 0.0323\n",
      "Epoch [54/300], Loss: 0.0321\n",
      "Epoch [55/300], Loss: 0.0318\n",
      "Epoch [56/300], Loss: 0.0315\n",
      "Epoch [57/300], Loss: 0.0313\n",
      "Epoch [58/300], Loss: 0.0311\n",
      "Epoch [59/300], Loss: 0.0308\n",
      "Epoch [60/300], Loss: 0.0306\n",
      "Epoch [61/300], Loss: 0.0304\n",
      "Epoch [62/300], Loss: 0.0301\n",
      "Epoch [63/300], Loss: 0.0299\n",
      "Epoch [64/300], Loss: 0.0297\n",
      "Epoch [65/300], Loss: 0.0295\n",
      "Epoch [66/300], Loss: 0.0293\n",
      "Epoch [67/300], Loss: 0.0291\n",
      "Epoch [68/300], Loss: 0.0289\n",
      "Epoch [69/300], Loss: 0.0287\n",
      "Epoch [70/300], Loss: 0.0286\n",
      "Epoch [71/300], Loss: 0.0284\n",
      "Epoch [72/300], Loss: 0.0282\n",
      "Epoch [73/300], Loss: 0.0280\n",
      "Epoch [74/300], Loss: 0.0279\n",
      "Epoch [75/300], Loss: 0.0277\n",
      "Epoch [76/300], Loss: 0.0276\n",
      "Epoch [77/300], Loss: 0.0274\n",
      "Epoch [78/300], Loss: 0.0272\n",
      "Epoch [79/300], Loss: 0.0271\n",
      "Epoch [80/300], Loss: 0.0269\n",
      "Epoch [81/300], Loss: 0.0268\n",
      "Epoch [82/300], Loss: 0.0266\n",
      "Epoch [83/300], Loss: 0.0265\n",
      "Epoch [84/300], Loss: 0.0264\n",
      "Epoch [85/300], Loss: 0.0262\n",
      "Epoch [86/300], Loss: 0.0261\n",
      "Epoch [87/300], Loss: 0.0260\n",
      "Epoch [88/300], Loss: 0.0258\n",
      "Epoch [89/300], Loss: 0.0257\n",
      "Epoch [90/300], Loss: 0.0256\n",
      "Epoch [91/300], Loss: 0.0255\n",
      "Epoch [92/300], Loss: 0.0254\n",
      "Epoch [93/300], Loss: 0.0253\n",
      "Epoch [94/300], Loss: 0.0251\n",
      "Epoch [95/300], Loss: 0.0250\n"
     ]
    }
   ],
   "source": [
    "# criterion: squared error summed over the batch (train() halves it and divides by batch size)\n",
    "loss = torch.nn.MSELoss(reduction='sum')\n",
    "# plain SGD over all model parameters\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=lr)\n",
    "\n",
    "n_runs = 5  # number of independent re-initialise-and-train runs\n",
    "loss_threshold = 0.025  # a run stops once its epoch loss drops below this value\n",
    "\n",
    "# train the model\n",
    "def train(model, train_loader, test_loader, optimizer, num_epochs=num_epochs):\n",
    "    \"\"\"Train `model` with `optimizer` on half the summed squared error to one-hot targets.\n",
    "\n",
    "    Reads the notebook globals `loss` (summed-MSE criterion) and `loss_threshold`\n",
    "    (training stops after the first epoch whose loss falls below it).\n",
    "\n",
    "    Returns:\n",
    "        epoch_losses: per epoch, the average over batches of the per-sample batch loss.\n",
    "        losses: per-sample training loss of every optimisation step.\n",
    "        test_losses: per-sample loss on the first test batch (only that batch is evaluated).\n",
    "        hidden_activations: numpy array of hidden activations for that first test batch.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    losses = []\n",
    "    test_losses = []\n",
    "    hidden_activations = []\n",
    "    epoch_losses = []\n",
    "    for epoch in range(num_epochs):\n",
    "        epoch_loss = 0\n",
    "        for inputs, labels in train_loader:\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(inputs)\n",
    "            # one-hot targets with one column per model output\n",
    "            labels_onehot = torch.nn.functional.one_hot(labels, num_classes=outputs.size(1)).to(outputs.dtype)\n",
    "            loss_val = 1/2 * loss(outputs, labels_onehot)\n",
    "            loss_val.backward()\n",
    "            optimizer.step()\n",
    "            # normalise by the actual batch size (the last batch may be smaller)\n",
    "            loss_per_sample = loss_val.item() / inputs.size(0)\n",
    "            losses.append(loss_per_sample)\n",
    "            epoch_loss += loss_per_sample\n",
    "        epoch_loss /= len(train_loader)\n",
    "        epoch_losses.append(epoch_loss)\n",
    "        # Log before the early-stopping check so the epoch that meets the\n",
    "        # threshold is reported as well.\n",
    "        print(f\"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}\")\n",
    "        if epoch_loss < loss_threshold:\n",
    "            break\n",
    "    # After training, evaluate on the first TEST batch and keep its hidden representations.\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        for test_inputs, test_labels in test_loader:\n",
    "            test_outputs = model(test_inputs)\n",
    "            test_labels_onehot = torch.nn.functional.one_hot(test_labels, num_classes=test_outputs.size(1)).to(test_outputs.dtype)\n",
    "            test_loss = 1/2 * loss(test_outputs, test_labels_onehot).item()\n",
    "            test_losses.append(test_loss / test_inputs.size(0))\n",
    "            hidden_activations.append(model.get_hidden(test_inputs).detach().numpy())\n",
    "            break  # only the first test batch is used\n",
    "    hidden_activations = np.concatenate(hidden_activations)\n",
    "    return epoch_losses, losses, test_losses, hidden_activations\n",
    "\n",
    "def train_n_runs(n_runs, model, train_loader, test_loader, optimizer, num_epochs, weight_scale=init_scale, bias_scale=bias_init_scale):\n",
    "    \"\"\"Re-initialise `model` and train it `n_runs` times, collecting per-run results.\n",
    "\n",
    "    Before each run the fc1/fc2 weights (and fc1's bias, if the layer has one) are\n",
    "    redrawn from zero-mean Gaussians with std `weight_scale` / `bias_scale`, which\n",
    "    default to the notebook's `init_scale` / `bias_init_scale`.\n",
    "\n",
    "    Returns:\n",
    "        losses_list: per-step training losses of each run.\n",
    "        hidden_acts: hidden activations on the first test batch for each run.\n",
    "        losses_test_list: test losses of each run.\n",
    "    \"\"\"\n",
    "    losses_list = []\n",
    "    hidden_acts = []\n",
    "    losses_test_list = []\n",
    "    for _ in range(n_runs):\n",
    "        # Re-draw the parameters by assigning .data, which keeps the same Parameter\n",
    "        # objects so `optimizer` still updates them. NOTE: optimizer state (if any,\n",
    "        # e.g. momentum buffers) is not reset between runs.\n",
    "        model.fc1.weight.data = torch.randn_like(model.fc1.weight.data) * weight_scale\n",
    "        model.fc2.weight.data = torch.randn_like(model.fc2.weight.data) * weight_scale\n",
    "        if model.fc1.bias is not None:\n",
    "            # otherwise the bias learned in the previous run would carry over\n",
    "            model.fc1.bias.data = torch.randn_like(model.fc1.bias.data) * bias_scale\n",
    "        _, losses, test_losses, hidden_activations = train(model, train_loader, test_loader, optimizer, num_epochs)\n",
    "        losses_list.append(losses)\n",
    "        hidden_acts.append(hidden_activations)\n",
    "        losses_test_list.append(test_losses)\n",
    "\n",
    "    return losses_list, hidden_acts, losses_test_list\n",
    "\n",
    "# Train the bias-free ReLU network n_runs times from fresh small initialisations\n",
    "losses_relu_no_bias, hidden_activations_relu_no_bias, losses_test_relu_no_bias = train_n_runs(\n",
    "    n_runs, model, train_loader, test_loader, optimizer, num_epochs\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec8db07f",
   "metadata": {},
   "outputs": [],
   "source": [
    "hidden_size_scaled = hidden_size // 2  # Reduce hidden size for linear model\n",
    "class LinearNet(torch.nn.Module):\n",
    "    \"\"\"Two-layer linear network fc2(fc1(x)) with no nonlinearity.\n",
    "\n",
    "    ``fc1`` has a bias iff the notebook-global ``has_bias`` is True when the model\n",
    "    is constructed; ``fc2`` never has a bias.\n",
    "    \"\"\"\n",
    "    def __init__(self, input_size, hidden_size, output_size):\n",
    "        super(LinearNet, self).__init__()\n",
    "        self.fc1 = torch.nn.Linear(input_size, hidden_size, bias=has_bias)\n",
    "        self.fc2 = torch.nn.Linear(hidden_size, output_size, bias=False)\n",
    "\n",
    "    def _flatten(self, x):\n",
    "        # Flatten to the width fc1 was built with, not the notebook-global input_size.\n",
    "        return x.view(-1, self.fc1.in_features)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return fc2(fc1(x)) for a batch of inputs.\"\"\"\n",
    "        return self.fc2(self.fc1(self._flatten(x)))\n",
    "\n",
    "    def get_hidden(self, x):\n",
    "        \"\"\"Return the hidden representation fc1(x); inputs are flattened exactly as in forward().\"\"\"\n",
    "        return self.fc1(self._flatten(x))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d186d44",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of parameters in layer 1: 401920\n",
      "Number of parameters in layer 2: 5120\n",
      "Epoch [1/300], Loss: 0.3248\n",
      "Epoch [2/300], Loss: 0.1395\n",
      "Epoch [3/300], Loss: 0.1018\n",
      "Epoch [4/300], Loss: 0.0889\n",
      "Epoch [5/300], Loss: 0.0806\n",
      "Epoch [6/300], Loss: 0.0744\n",
      "Epoch [7/300], Loss: 0.0692\n",
      "Epoch [8/300], Loss: 0.0651\n",
      "Epoch [9/300], Loss: 0.0616\n",
      "Epoch [10/300], Loss: 0.0586\n",
      "Epoch [11/300], Loss: 0.0560\n",
      "Epoch [12/300], Loss: 0.0538\n",
      "Epoch [13/300], Loss: 0.0518\n",
      "Epoch [14/300], Loss: 0.0499\n",
      "Epoch [15/300], Loss: 0.0483\n",
      "Epoch [16/300], Loss: 0.0469\n",
      "Epoch [17/300], Loss: 0.0456\n",
      "Epoch [18/300], Loss: 0.0444\n",
      "Epoch [19/300], Loss: 0.0433\n",
      "Epoch [20/300], Loss: 0.0423\n",
      "Epoch [21/300], Loss: 0.0414\n",
      "Epoch [22/300], Loss: 0.0404\n",
      "Epoch [23/300], Loss: 0.0396\n",
      "Epoch [24/300], Loss: 0.0388\n",
      "Epoch [25/300], Loss: 0.0382\n",
      "Epoch [26/300], Loss: 0.0375\n",
      "Epoch [27/300], Loss: 0.0369\n",
      "Epoch [28/300], Loss: 0.0363\n",
      "Epoch [29/300], Loss: 0.0357\n",
      "Epoch [30/300], Loss: 0.0352\n",
      "Epoch [31/300], Loss: 0.0346\n",
      "Epoch [32/300], Loss: 0.0342\n",
      "Epoch [33/300], Loss: 0.0337\n",
      "Epoch [34/300], Loss: 0.0333\n",
      "Epoch [35/300], Loss: 0.0328\n",
      "Epoch [36/300], Loss: 0.0324\n",
      "Epoch [37/300], Loss: 0.0321\n",
      "Epoch [38/300], Loss: 0.0317\n",
      "Epoch [39/300], Loss: 0.0313\n",
      "Epoch [40/300], Loss: 0.0310\n",
      "Epoch [41/300], Loss: 0.0307\n",
      "Epoch [42/300], Loss: 0.0303\n",
      "Epoch [43/300], Loss: 0.0300\n",
      "Epoch [44/300], Loss: 0.0297\n",
      "Epoch [45/300], Loss: 0.0294\n",
      "Epoch [46/300], Loss: 0.0291\n",
      "Epoch [47/300], Loss: 0.0288\n",
      "Epoch [48/300], Loss: 0.0285\n",
      "Epoch [49/300], Loss: 0.0283\n",
      "Epoch [50/300], Loss: 0.0281\n",
      "Epoch [51/300], Loss: 0.0277\n",
      "Epoch [52/300], Loss: 0.0275\n",
      "Epoch [53/300], Loss: 0.0274\n",
      "Epoch [54/300], Loss: 0.0272\n",
      "Epoch [55/300], Loss: 0.0268\n",
      "Epoch [56/300], Loss: 0.0267\n",
      "Epoch [57/300], Loss: 0.0265\n",
      "Epoch [58/300], Loss: 0.0262\n",
      "Epoch [59/300], Loss: 0.0261\n",
      "Epoch [60/300], Loss: 0.0257\n",
      "Epoch [61/300], Loss: 0.0257\n",
      "Epoch [62/300], Loss: 0.0255\n",
      "Epoch [63/300], Loss: 0.0252\n",
      "Epoch [64/300], Loss: 0.0251\n",
      "Epoch [1/300], Loss: 0.3250\n",
      "Epoch [2/300], Loss: 0.1384\n",
      "Epoch [3/300], Loss: 0.1013\n",
      "Epoch [4/300], Loss: 0.0890\n",
      "Epoch [5/300], Loss: 0.0807\n",
      "Epoch [6/300], Loss: 0.0746\n",
      "Epoch [7/300], Loss: 0.0696\n",
      "Epoch [8/300], Loss: 0.0654\n",
      "Epoch [9/300], Loss: 0.0620\n",
      "Epoch [10/300], Loss: 0.0590\n",
      "Epoch [11/300], Loss: 0.0564\n",
      "Epoch [12/300], Loss: 0.0541\n",
      "Epoch [13/300], Loss: 0.0522\n",
      "Epoch [14/300], Loss: 0.0504\n",
      "Epoch [15/300], Loss: 0.0489\n",
      "Epoch [16/300], Loss: 0.0474\n",
      "Epoch [17/300], Loss: 0.0461\n",
      "Epoch [18/300], Loss: 0.0449\n",
      "Epoch [19/300], Loss: 0.0438\n",
      "Epoch [20/300], Loss: 0.0428\n",
      "Epoch [21/300], Loss: 0.0419\n",
      "Epoch [22/300], Loss: 0.0410\n",
      "Epoch [23/300], Loss: 0.0402\n",
      "Epoch [24/300], Loss: 0.0394\n",
      "Epoch [25/300], Loss: 0.0387\n",
      "Epoch [26/300], Loss: 0.0380\n",
      "Epoch [27/300], Loss: 0.0374\n",
      "Epoch [28/300], Loss: 0.0368\n",
      "Epoch [29/300], Loss: 0.0363\n",
      "Epoch [30/300], Loss: 0.0357\n",
      "Epoch [31/300], Loss: 0.0352\n",
      "Epoch [32/300], Loss: 0.0347\n",
      "Epoch [33/300], Loss: 0.0342\n",
      "Epoch [34/300], Loss: 0.0338\n",
      "Epoch [35/300], Loss: 0.0334\n",
      "Epoch [36/300], Loss: 0.0330\n",
      "Epoch [37/300], Loss: 0.0325\n",
      "Epoch [38/300], Loss: 0.0321\n",
      "Epoch [39/300], Loss: 0.0318\n",
      "Epoch [40/300], Loss: 0.0315\n",
      "Epoch [41/300], Loss: 0.0311\n",
      "Epoch [42/300], Loss: 0.0308\n",
      "Epoch [43/300], Loss: 0.0305\n",
      "Epoch [44/300], Loss: 0.0302\n",
      "Epoch [45/300], Loss: 0.0299\n",
      "Epoch [46/300], Loss: 0.0296\n",
      "Epoch [47/300], Loss: 0.0292\n",
      "Epoch [48/300], Loss: 0.0290\n",
      "Epoch [49/300], Loss: 0.0287\n",
      "Epoch [50/300], Loss: 0.0284\n",
      "Epoch [51/300], Loss: 0.0283\n",
      "Epoch [52/300], Loss: 0.0280\n",
      "Epoch [53/300], Loss: 0.0278\n",
      "Epoch [54/300], Loss: 0.0276\n",
      "Epoch [55/300], Loss: 0.0273\n",
      "Epoch [56/300], Loss: 0.0272\n",
      "Epoch [57/300], Loss: 0.0269\n",
      "Epoch [58/300], Loss: 0.0268\n",
      "Epoch [59/300], Loss: 0.0265\n",
      "Epoch [60/300], Loss: 0.0263\n",
      "Epoch [61/300], Loss: 0.0262\n",
      "Epoch [62/300], Loss: 0.0259\n",
      "Epoch [63/300], Loss: 0.0257\n",
      "Epoch [64/300], Loss: 0.0256\n",
      "Epoch [65/300], Loss: 0.0254\n",
      "Epoch [66/300], Loss: 0.0252\n",
      "Epoch [67/300], Loss: 0.0251\n",
      "Epoch [1/300], Loss: 0.3254\n",
      "Epoch [2/300], Loss: 0.1403\n",
      "Epoch [3/300], Loss: 0.1021\n",
      "Epoch [4/300], Loss: 0.0892\n",
      "Epoch [5/300], Loss: 0.0808\n",
      "Epoch [6/300], Loss: 0.0745\n",
      "Epoch [7/300], Loss: 0.0695\n",
      "Epoch [8/300], Loss: 0.0653\n",
      "Epoch [9/300], Loss: 0.0619\n",
      "Epoch [10/300], Loss: 0.0589\n",
      "Epoch [11/300], Loss: 0.0563\n",
      "Epoch [12/300], Loss: 0.0540\n",
      "Epoch [13/300], Loss: 0.0520\n",
      "Epoch [14/300], Loss: 0.0501\n",
      "Epoch [15/300], Loss: 0.0484\n",
      "Epoch [16/300], Loss: 0.0470\n",
      "Epoch [17/300], Loss: 0.0457\n",
      "Epoch [18/300], Loss: 0.0445\n",
      "Epoch [19/300], Loss: 0.0433\n",
      "Epoch [20/300], Loss: 0.0424\n",
      "Epoch [21/300], Loss: 0.0414\n",
      "Epoch [22/300], Loss: 0.0406\n",
      "Epoch [23/300], Loss: 0.0398\n",
      "Epoch [24/300], Loss: 0.0390\n",
      "Epoch [25/300], Loss: 0.0384\n",
      "Epoch [26/300], Loss: 0.0376\n",
      "Epoch [27/300], Loss: 0.0371\n",
      "Epoch [28/300], Loss: 0.0366\n",
      "Epoch [29/300], Loss: 0.0360\n",
      "Epoch [30/300], Loss: 0.0354\n",
      "Epoch [31/300], Loss: 0.0350\n",
      "Epoch [32/300], Loss: 0.0344\n",
      "Epoch [33/300], Loss: 0.0339\n",
      "Epoch [34/300], Loss: 0.0336\n",
      "Epoch [35/300], Loss: 0.0331\n",
      "Epoch [36/300], Loss: 0.0328\n",
      "Epoch [37/300], Loss: 0.0323\n",
      "Epoch [38/300], Loss: 0.0319\n",
      "Epoch [39/300], Loss: 0.0315\n",
      "Epoch [40/300], Loss: 0.0312\n",
      "Epoch [41/300], Loss: 0.0309\n",
      "Epoch [42/300], Loss: 0.0306\n",
      "Epoch [43/300], Loss: 0.0303\n",
      "Epoch [44/300], Loss: 0.0299\n",
      "Epoch [45/300], Loss: 0.0296\n",
      "Epoch [46/300], Loss: 0.0295\n",
      "Epoch [47/300], Loss: 0.0292\n",
      "Epoch [48/300], Loss: 0.0289\n",
      "Epoch [49/300], Loss: 0.0286\n",
      "Epoch [50/300], Loss: 0.0283\n",
      "Epoch [51/300], Loss: 0.0281\n",
      "Epoch [52/300], Loss: 0.0280\n",
      "Epoch [53/300], Loss: 0.0277\n",
      "Epoch [54/300], Loss: 0.0275\n",
      "Epoch [55/300], Loss: 0.0273\n",
      "Epoch [56/300], Loss: 0.0269\n",
      "Epoch [57/300], Loss: 0.0268\n",
      "Epoch [58/300], Loss: 0.0266\n",
      "Epoch [59/300], Loss: 0.0264\n",
      "Epoch [60/300], Loss: 0.0262\n",
      "Epoch [61/300], Loss: 0.0260\n",
      "Epoch [62/300], Loss: 0.0259\n",
      "Epoch [63/300], Loss: 0.0255\n",
      "Epoch [64/300], Loss: 0.0254\n",
      "Epoch [65/300], Loss: 0.0253\n",
      "Epoch [66/300], Loss: 0.0250\n",
      "Epoch [1/300], Loss: 0.3250\n",
      "Epoch [2/300], Loss: 0.1393\n",
      "Epoch [3/300], Loss: 0.1017\n",
      "Epoch [4/300], Loss: 0.0889\n",
      "Epoch [5/300], Loss: 0.0806\n",
      "Epoch [6/300], Loss: 0.0745\n",
      "Epoch [7/300], Loss: 0.0696\n",
      "Epoch [8/300], Loss: 0.0654\n",
      "Epoch [9/300], Loss: 0.0619\n",
      "Epoch [10/300], Loss: 0.0590\n",
      "Epoch [11/300], Loss: 0.0564\n",
      "Epoch [12/300], Loss: 0.0542\n",
      "Epoch [13/300], Loss: 0.0522\n",
      "Epoch [14/300], Loss: 0.0504\n",
      "Epoch [15/300], Loss: 0.0488\n",
      "Epoch [16/300], Loss: 0.0473\n",
      "Epoch [17/300], Loss: 0.0461\n",
      "Epoch [18/300], Loss: 0.0449\n",
      "Epoch [19/300], Loss: 0.0438\n",
      "Epoch [20/300], Loss: 0.0427\n",
      "Epoch [21/300], Loss: 0.0419\n",
      "Epoch [22/300], Loss: 0.0409\n",
      "Epoch [23/300], Loss: 0.0401\n",
      "Epoch [24/300], Loss: 0.0394\n",
      "Epoch [25/300], Loss: 0.0386\n",
      "Epoch [26/300], Loss: 0.0380\n",
      "Epoch [27/300], Loss: 0.0373\n",
      "Epoch [28/300], Loss: 0.0367\n",
      "Epoch [29/300], Loss: 0.0361\n",
      "Epoch [30/300], Loss: 0.0356\n",
      "Epoch [31/300], Loss: 0.0350\n",
      "Epoch [32/300], Loss: 0.0346\n",
      "Epoch [33/300], Loss: 0.0341\n",
      "Epoch [34/300], Loss: 0.0337\n",
      "Epoch [35/300], Loss: 0.0333\n",
      "Epoch [36/300], Loss: 0.0328\n",
      "Epoch [37/300], Loss: 0.0324\n",
      "Epoch [38/300], Loss: 0.0320\n",
      "Epoch [39/300], Loss: 0.0317\n",
      "Epoch [40/300], Loss: 0.0313\n",
      "Epoch [41/300], Loss: 0.0311\n",
      "Epoch [42/300], Loss: 0.0306\n",
      "Epoch [43/300], Loss: 0.0303\n",
      "Epoch [44/300], Loss: 0.0300\n",
      "Epoch [45/300], Loss: 0.0297\n",
      "Epoch [46/300], Loss: 0.0295\n",
      "Epoch [47/300], Loss: 0.0292\n",
      "Epoch [48/300], Loss: 0.0289\n",
      "Epoch [49/300], Loss: 0.0287\n",
      "Epoch [50/300], Loss: 0.0284\n",
      "Epoch [51/300], Loss: 0.0282\n",
      "Epoch [52/300], Loss: 0.0279\n",
      "Epoch [53/300], Loss: 0.0276\n",
      "Epoch [54/300], Loss: 0.0275\n",
      "Epoch [55/300], Loss: 0.0272\n",
      "Epoch [56/300], Loss: 0.0270\n",
      "Epoch [57/300], Loss: 0.0268\n",
      "Epoch [58/300], Loss: 0.0265\n",
      "Epoch [59/300], Loss: 0.0263\n",
      "Epoch [60/300], Loss: 0.0262\n",
      "Epoch [61/300], Loss: 0.0260\n",
      "Epoch [62/300], Loss: 0.0257\n",
      "Epoch [63/300], Loss: 0.0256\n",
      "Epoch [64/300], Loss: 0.0254\n",
      "Epoch [65/300], Loss: 0.0253\n",
      "Epoch [66/300], Loss: 0.0251\n",
      "Epoch [1/300], Loss: 0.3247\n",
      "Epoch [2/300], Loss: 0.1394\n",
      "Epoch [3/300], Loss: 0.1016\n",
      "Epoch [4/300], Loss: 0.0890\n",
      "Epoch [5/300], Loss: 0.0808\n",
      "Epoch [6/300], Loss: 0.0745\n",
      "Epoch [7/300], Loss: 0.0696\n",
      "Epoch [8/300], Loss: 0.0653\n",
      "Epoch [9/300], Loss: 0.0618\n",
      "Epoch [10/300], Loss: 0.0587\n",
      "Epoch [11/300], Loss: 0.0561\n",
      "Epoch [12/300], Loss: 0.0539\n",
      "Epoch [13/300], Loss: 0.0519\n",
      "Epoch [14/300], Loss: 0.0501\n",
      "Epoch [15/300], Loss: 0.0486\n",
      "Epoch [16/300], Loss: 0.0471\n",
      "Epoch [17/300], Loss: 0.0458\n",
      "Epoch [18/300], Loss: 0.0445\n",
      "Epoch [19/300], Loss: 0.0435\n",
      "Epoch [20/300], Loss: 0.0425\n",
      "Epoch [21/300], Loss: 0.0416\n",
      "Epoch [22/300], Loss: 0.0406\n",
      "Epoch [23/300], Loss: 0.0399\n",
      "Epoch [24/300], Loss: 0.0392\n",
      "Epoch [25/300], Loss: 0.0384\n",
      "Epoch [26/300], Loss: 0.0378\n",
      "Epoch [27/300], Loss: 0.0371\n",
      "Epoch [28/300], Loss: 0.0365\n",
      "Epoch [29/300], Loss: 0.0360\n",
      "Epoch [30/300], Loss: 0.0355\n",
      "Epoch [31/300], Loss: 0.0349\n",
      "Epoch [32/300], Loss: 0.0344\n",
      "Epoch [33/300], Loss: 0.0339\n",
      "Epoch [34/300], Loss: 0.0336\n",
      "Epoch [35/300], Loss: 0.0331\n",
      "Epoch [36/300], Loss: 0.0326\n",
      "Epoch [37/300], Loss: 0.0323\n",
      "Epoch [38/300], Loss: 0.0319\n",
      "Epoch [39/300], Loss: 0.0315\n",
      "Epoch [40/300], Loss: 0.0312\n",
      "Epoch [41/300], Loss: 0.0308\n",
      "Epoch [42/300], Loss: 0.0305\n",
      "Epoch [43/300], Loss: 0.0303\n",
      "Epoch [44/300], Loss: 0.0299\n",
      "Epoch [45/300], Loss: 0.0296\n",
      "Epoch [46/300], Loss: 0.0294\n",
      "Epoch [47/300], Loss: 0.0291\n",
      "Epoch [48/300], Loss: 0.0288\n",
      "Epoch [49/300], Loss: 0.0286\n",
      "Epoch [50/300], Loss: 0.0283\n",
      "Epoch [51/300], Loss: 0.0281\n",
      "Epoch [52/300], Loss: 0.0278\n",
      "Epoch [53/300], Loss: 0.0275\n",
      "Epoch [54/300], Loss: 0.0274\n",
      "Epoch [55/300], Loss: 0.0271\n",
      "Epoch [56/300], Loss: 0.0269\n",
      "Epoch [57/300], Loss: 0.0267\n",
      "Epoch [58/300], Loss: 0.0265\n",
      "Epoch [59/300], Loss: 0.0264\n",
      "Epoch [60/300], Loss: 0.0261\n",
      "Epoch [61/300], Loss: 0.0260\n",
      "Epoch [62/300], Loss: 0.0257\n",
      "Epoch [63/300], Loss: 0.0255\n",
      "Epoch [64/300], Loss: 0.0254\n",
      "Epoch [65/300], Loss: 0.0252\n",
      "Epoch [66/300], Loss: 0.0250\n"
     ]
    }
   ],
   "source": [
    "# Bias condition: presumably PerPixelNormalizeBias adds one extra \"bias pixel\" of this magnitude (cf. input_size below).\n",
    "bias_magnitude = 10\n",
    "transform = transforms.Compose([transforms.ToTensor(), PerPixelNormalizeBias(pixel_means, bias_magnitude)])\n",
    "# Both splits share the same transform; the train split is constructed first.\n",
    "train_dataset, test_dataset = [\n",
    "    datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)\n",
    "    for is_train in (True, False)\n",
    "]\n",
    "\n",
    "# ReluNet.forward (as defined at the top of the notebook) reshapes with the *global*\n",
    "# input_size, so this module-level reassignment (784 pixels + 1 bias pixel) is load-bearing.\n",
    "input_size = 784 + 1\n",
    "output_size = 10\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)\n",
    "\n",
    "model = ReluNet(input_size=input_size, hidden_size=hidden_size, output_size=output_size)\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=lr)\n",
    "\n",
    "# Sanity check: fc1 has input_size * hidden_size weights, fc2 has hidden_size * output_size.\n",
    "print(f\"Number of parameters in layer 1: {model.fc1.weight.data.numel()}\")\n",
    "print(f\"Number of parameters in layer 2: {model.fc2.weight.data.numel()}\")\n",
    "\n",
    "losses_relu_bias, hidden_activations_relu_bias, losses_test_relu_bias = train_n_runs(\n",
    "    n_runs, model, train_loader, test_loader, optimizer, num_epochs\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53c63ac3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def effective_rank(H):\n",
    "    \"\"\"Effective rank of a 2-D matrix H.\n",
    "\n",
    "    Returns exp(Shannon entropy) of the energy spectrum p_i = s_i**2 / sum_j s_j**2,\n",
    "    where s_i are the singular values of H. Up to the 1e-12 smoothing, the result\n",
    "    lies in [1, min(H.shape)]; an all-zero matrix gives 1.\n",
    "    \"\"\"\n",
    "    S = np.linalg.svd(H, full_matrices=False, compute_uv=False)\n",
    "    p = (S**2) / (np.sum(S**2) + 1e-12)  # 1e-12 avoids 0/0 for an all-zero matrix\n",
    "    entropy = -np.sum(p * np.log(p + 1e-12))  # 1e-12 avoids log(0) for zero singular values\n",
    "    return np.exp(entropy)\n",
    "\n",
    "def compute_effective_ranks(hidden_activations_list):\n",
    "    \"\"\"Effective rank of each run's hidden-activation matrix.\n",
    "\n",
    "    Args:\n",
    "        hidden_activations_list: iterable with one 2-D activation matrix per run\n",
    "            (presumably (n_samples, hidden_size) -- TODO confirm against train_n_runs).\n",
    "\n",
    "    Returns:\n",
    "        1-D np.ndarray with one effective rank per run.\n",
    "    \"\"\"\n",
    "    # Take the SVD of the activations A themselves. Do not pass the Gram matrix A @ A.T:\n",
    "    # its singular values are already s_i**2, so effective_rank's own squaring would weight\n",
    "    # the spectrum by s_i**4 (not a standard effective rank), and it would force an SVD of\n",
    "    # an (n_rows x n_rows) matrix.\n",
    "    return np.array([effective_rank(acts) for acts in hidden_activations_list])\n",
    "\n",
    "e_rank_list_relu_no_bias = compute_effective_ranks(hidden_activations_relu_no_bias)\n",
    "e_rank_list_relu_bias = compute_effective_ranks(hidden_activations_relu_bias)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "id": "3f4e4a72",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/zm/n63_ssc512d4ygg3lbvwql400000gq/T/ipykernel_5037/1997805938.py:24: FutureWarning: \n",
      "\n",
      "Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.\n",
      "\n",
      "  sns.barplot(x=\"Bias\", y=\"Test Loss\", data=df_test,capsize=.3, palette=custom_palette, ax=ax)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAARkAAAETCAYAAAACi3JxAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAozklEQVR4nO3de1xUdf7H8dcwyiUUkVjRn3JTTLwEupqoqASppJYpZFqmRmoKeC0TUTP197NUWrEQbwv2syzNK+omoWZqpUKIq6ulGywgrFviQiJCcpvfH/yYbRqQYZgjt8/z8fCP+c75nvM9c5w33/Odc75HpdFoNAghhELM6rsBQoimTUJGCKEoCRkhhKIkZIQQipKQEUIoSkJGCKEoCRkhhKIkZIQQipKQEUIoSkJGCKEoCRkhhKIkZIQQipKQEUIoSkJGCKEoCRkhhKIkZIQQipKQEUIoqkV9N0CIxuzXX3+ltLT0oW6zRYsWWFpaPtRt1oWEjBBGio6OJi4ujvLy8oe6XTMzM8aOHUtoaOhD3a6xGuzp0pkzZwgICMDT0xNfX1+2bt1KTdMRHzp0iNGjR+Ph4YG/vz979+594PLvvPMO3bp1M2WzRTNSHwEDUF5eTlxc3EPfrrEaZMikpKQQEhJCly5diIqKYsyYMURGRrJly5Zq68THxxMWFoa3tzfR0dEMGDCAZcuWcfjw4SqX/+677/j444+V2gXRDIwdOxYzs4f/FarsyTQWqob4tIJp06Zx584d9u3bpy2LiIjg008/5dy5c1Wej/r7++Pu7s7777+vLZs/fz5Xr17l+PHjOssWFhYyZswYSkpK+Omnn7h+/bpyOyPqVXm5BjMzlWLrN2ZMpqCggEmTJmlff/LJJ7Rq1crg+g9jTMaUn1uDG5MpLi4mMTGRuXPn6pT7+/sTExNDcnIygwcP1nkvOzubjIyMKuvEx8eTnp6Oq6urtnzt2rXY29szcOBANm3apNzOiHpnZqbi+DfXycsvVGT9JSXFlJeV1apO8f0indeff3UVcwsrg+ubqdW0bGleq23WRlubRxg+2HTDCA0uZLKysigpKcHFxUWn3NnZGYCMjAy9kElLSwN4YJ3KkPn22285dOgQBw8e5C9/+YvR7SwvL69xjEjUP7VaTV5+Ibdz75l83ZcS40n9IRHq+P/g05h3aldBpcKtuxeeXiPrtN2alNUQnmq12qD1NLiQyc/PB9DrPlpbWwMVXc3fu3v3rkF17t69y9KlS5k7d65Oz8YYV65coaSkpE7rEMqysrKiR48eiq0/7YekOgeMUTQa0n5IUjxkrl+/TlFRUbXv9+3b16D1NLiQqRytV6mqPh+saqCtujqVPY3KOu+88w7t27fnlVdeqXM7e/XqJT2ZZq5L9/4m6cnUlkplRpfu/RXfjql+eW1wIWNjYwPo91ju3avo7lY1QFZdncLCQm2dr776is8//5z9+/dTXl6u/QdQWlqKmZlZrX4pqI9fFUTD4uk1kp59n6r1mExdmanVtGih3JhMJUNPh2rS4ELGyckJtVpNZmamTnnlazc3N706lac+mZmZOt3j39aJiori/v37PPPMM3r1e/bsybhx41izZo3J9kM0Dy1amDfAb1HD0uA+HgsLC/r168fx48eZNm2a9hQoISEBGxsbPDw89Oo4Ozvj6OhIQkICI0f+5zw1ISEBFxcXOnbsyOzZs3V+NgTYs2cPe/bsYd++fbRt21bZHROimWpwIQMQHBxMUFAQ8+bNIzAwkIsXLxIbG8vChQuxtLSkoKCA1NRUnJycsLOzAyAkJITw8HBsbW3x8/Pj5MmTxMfHExkZCUCnTp3o1KmTznZOnToFwOOPP/5Q90+I5qRBDiwMHDiQqKgo0tPTCQ0N5ciRIyxatIjp06cDcPXqVSZMmKANCYCAgABWrlzJ2bNnCQ0NJSkpibVr1zJq1Kh62gshBDTQK36FMKU9Ry8qcp1MU2VvZ80Lo/qY
bH0NsicjhGg6JGSEEIqSkBFCKEpCRgihKAkZIYSiJGSEEIqSkBFCKEpCRgihKAkZIYSiJGSEEIqSkBFCKEpCRgihKAkZIYSiJGSEEIqSkBFCKEpCRgihqAY5/aYwnDGPSa2rh/GYVNF0SMg0YtHR0cTFxWkf7fKwVD7wPTQ09KFuVzROcrrUiNVHwEDFw/Ti4uIe+nZF4yQh04iNHTu2Xh4yV9mTEcIQMpF4I1fbMZmCggKd50998sknVT6V80Ea25iMTCReO6aeSFzGZBq5un7ZW7VqVeuQEaI25HRJCKEoCRkFaephULYpkM+taZHTJQWpzMy4c2IXZXm36rspWvful+i8zovbQrFFy3pqjT5123a0GfZifTdDmJCEjMLK8m5Revufiq3/fmk5pbX4y19YUqbzOv+nLEpbqmu1zRZmZli0kE6wMIyETCO2469ZJKTeoi4/D86Nv1LrOirA360dU3s71mHLormQP0eN2LG0ugWMsTT/v20hDCEh04iN6NIOVT1s10xVsW0hDCGnS43Y1N6OTOzVsVZjMqYgYzKiNiRkGjmLFmZYSIdUNGDyv1MIoSgJGSGEoiRkhBCKkpARQihKQkYIoSgJGSGEoiRkhBCKkpARQihKQkYIoSgJGSGEohpsyJw5c4aAgAA8PT3x9fVl69at1DTn+aFDhxg9ejQeHh74+/uzd+9evWX27Nmjs8yOHTtqXK8QwngN8t6llJQUQkJCGDlyJPPnz+fChQtERkZSXl5OcHBwlXXi4+MJCwtjypQpDBkyhBMnTrBs2TIsLCwYM2YMAJ9++ikrV65kxowZeHt7c+nSJdauXUtRURGzZs16mLsoRLPRIEMmOjoad3d3IiIiABg6dCilpaVs27aNoKCgKmfo37BhA/7+/ixZsgSAIUOGcOfOHaKiohgzZgwajYY///nPjBw5koULFwIwcOBAMjIy2Llzp4SMEAppcKdLxcXFJCYmMmLECJ1yf39/CgsLSU5O1quTnZ1NRkZGlXVu3LhBeno6ADExMbz55ps6y7Rs2ZLi4mIT74UQolKD68lkZWVRUlKCi4uLTrmzszMAGRkZDB48WOe9tLQ0gAfWcXV1pUuXLgBoNBru3LnD8ePHiYuLY9q0abVuZ3l5eY1jOWp17ebOFf9RVlZW80IGkGNgvJqOgaGfrclDpqysjPz8fNq2bWtU/fz8fAC9B45ZW1sDFU9A/L27d+/Wqk5KSgovvfQSAD179mTy5Mm1bueVK1coKSmp9n0rKyt69OhR6/WKCtevX6eoqKhO65BjUDc1HYO+ffsatB6jQyY3N5c9e/bg4+ND9+7dAdi1axfvvfcehYWFODo6snz5cr1eR00qHyCvUlU9sWRVz36urk5lT+P3dTp16sTHH3/Mzz//TFRUFIGBgezbtw97e3uD29mrVy/5VUpB3bp1q+8mNHumOgZGhczPP//M888/z+3bt2nbti3du3fnhx9+YNWqVWg0GmxsbLhx4wazZs1i3759uLu7G7xuGxsbQL/3ce9exbOMq3qkanV1CgsLq6zj4OCAg4MDAJ6enowYMYK9e/dW+8tVVerjQffNiZzm1D9THQOjvinbtm0jJyeH4cOHM3DgQAB2796NRqMhKCiIpKQktmzZQmlpKTExMbVat5OTE2q1mszMTJ3yytdubm56dVxdXXWWqapOQUEBhw8f1lvGycmJNm3a8K9//atW7RRCGMaokPn666/p2LEjGzZswMnJCYCvvvoKlUrF1KlTAXjyySfp3bs3SUlJtVq3hYUF/fr14/jx4zqnIwkJCdjY2ODh4aFXx9nZGUdHRxISEnTKExIScHFxoWPHjqjVapYuXaoXepcvX+aXX36pVW9LCGE4o0+XfHx8tKcMf//737l16xaurq60b99eu5yDgwNXrtT+4WHBwcEEBQUxb948AgMDuXjxIrGxsSxcuBBLS0sKCgpITU3FyckJOzs7AEJCQggPD8fW1hY/Pz9OnjxJfHw8
kZGRQMUg4IwZM9i0aRO2trYMGjSI9PR0Nm7ciLu7O4GBgcZ8FEKIGhgVMq1ateLXX3/Vvj5z5gwA3t7eOsvdunVL+wtPbQwcOJCoqCg++OADQkNDcXBwYNGiRbz66qsAXL16lSlTpvDuu+8SEBAAQEBAAMXFxWzfvp39+/fj6OjI2rVrGTVqlHa9s2fPxt7enl27drFjxw7atGmjvarYwsKi1u0UQtTMqJBxdXUlOTmZf//737Ru3ZqDBw+iUqnw8/PTLpOSksJf//pX+vfvb1TDhg8fzvDhw6t8z8vLi+vXr+uVT5w4kYkTJ1a7TjMzM1566SXtz9dCCOUZNSbz/PPPU1hYyDPPPMPTTz9NWloaLi4u2kHgt99+W9vreNCXXgjR9BnVkxk7diz5+flERkaSl5dH586d2bBhg/Y6le+++47S0lLCw8MZOXKkSRsshGhcjL4Yb8qUKUycOJGCggLt4GullStX0rVrV2xtbevaPiFEI1en2wrMzc31AgbgiSeeqMtqhRBNiNGXrZaUlHDkyBEyMjK0ZV999RWjRo2iT58+vPLKK1y7ds0UbRRCNGJGhUx+fj7jxo1j0aJF2qkXMjMzmTNnDv/4xz8oKiri/PnzvPzyy2RnZ5u0wUKIxsWokImNjSU1NRVPT0/tzZG7d++mtLSU5557juTkZN566y0KCgrYtm2bSRsshGhcjAqZL7/8Ent7ez766CN69uwJwIkTJ1CpVMyaNYtWrVoxadIk3N3d+eabb0zaYCFE42JUyGRnZ9O7d2/Mzc2BilOlrKwsOnTooL1ZESruKcrJyTFNS4UQjZJRIWNhYUFpaan29ddffw3AoEGDdJbLy8vTBpEQonkyKmRcXFy4dOkSRUVFaDQaDh8+jEqlwtfXV7tMeno6KSkpdO3a1WSNFUI0PkaFzMiRI8nLyyMgIIAXX3yRy5cv065dO4YOHQrA1q1bmTRpEmVlZYwbN86kDRZCNC5GXYz3yiuvcPPmTT766CMA2rRpQ0REBC1btgRg//795ObmMnXqVCZMmGC61gohGh2jr/hdsmQJQUFB5OTk8Nhjj+k8C2n+/Pl07txZJoISQtTttoIOHTrQoUMHvfLfzuEihGje6hQyOTk57Nq1i6SkJHJycjA3N+fRRx/Fy8uLsWPHVhlAQojmxeiQOX36NG+88Qb37t3TmYv3xx9/JDExkZiYGCIiInQmshJCND9GhUxaWhrz5s3j/v37jB07llGjRtGpUyc0Gg1ZWVkcPXqUw4cP88Ybb3DgwAGdC/SEEM2LUSGzdetW7t+/z+rVq7Vz7Fbq3LkzPj4+9O/fn6VLl/Lhhx+yatUqkzRWCNH4GHWdzLlz5+jWrZtewPxWYGCg3LskhDAuZPLy8gw6BXJ1deX27dvGbEII0UQYFTJt27YlPT29xuXS09Np06aNMZsQQjQRRoXMgAEDuH79OnFxcdUuExcXx7Vr1/Dy8jK2bUKIJsCogd+ZM2eSkJDAkiVLSExM5Omnn6ZTp05AxTQQCQkJxMXFYW5uzmuvvWbSBgshGhejQsbNzY3169ezcOFCDh48qNej0Wg0WFlZsW7dOh577DFTtFMI0UgZfTHesGHDOHHiBLt37yY5OZlbt26h0Who164dTzzxBOPHj8fBwcGUbRVCNEJ1uq3A3t6e2bNnm6otQogmyOhHohgiKiqKNWvWKLkJIUQDp2jIfPbZZ+zYsUPJTQghGjhFQ0YIISRkhBCKkpARQihKQkYIoSgJGSGEoiRkhBCKMuhivClTphi18l9++cWoekKIpsOgkElKSjJ6AyqVyui6QojGz6CQeffdd5VuhxCiiTIoZORRs0IIY8nArxBCUQ02ZM6cOUNAQACenp74+vqydetWnec7VeXQoUOMHj0aDw8P/P392bt3r94yCQkJPP/88/zxj3/Ex8eHxYsXyzzEQiioQYZMSkoKISEhdOnShaioKMaMGUNkZCRbtmyptk58fDxhYWF4e3sTHR3N
gAEDWLZsGYcPH9ZZZu7cufTo0YMPPviABQsWkJSUxNSpU7l///7D2DUhmp06zSejlOjoaNzd3YmIiABg6NChlJaWsm3bNoKCgrC0tNSrs2HDBvz9/VmyZAkAQ4YM4c6dO9qQAti0aRM+Pj46z4Hq3Lkz48eP56uvvuLpp59+CHsnRPPS4HoyxcXFJCYmMmLECJ1yf39/CgsLSU5O1quTnZ1NRkZGlXVu3LhBeno65eXleHt788ILL+gsU/lolxs3bph4T4QQ0ABDJisri5KSElxcXHTKnZ2dAcjIyNCrk5aWBvDAOmZmZixevJhhw4bpLHPs2DEAmYtYCIUYdboUFxeHo6Mjffv2feByX375JVevXmXu3LkGrzs/Px+AVq1a6ZRbW1sDUFBQoFfn7t27ta4DFeGzbt06evbsydChQw1uI0B5eXmNA9FqtbpW6xT/UVZWZpL1yDEwXk3HwNDP1qiQWbx4Mc8991yNIXPo0CG+/vrrWoVMeXk5UP2VwmZm+p2v6upUhkBVddLS0ggKCsLc3Jz333+/ymUe5MqVK5SUlFT7vpWVFT169KjVOsV/XL9+naKiojqtQ45B3dR0DGr6/lcyKGTi4uIoLS3VKcvMzGTfvn3V1ikoKODcuXO0bNnSoIZUsrGx0db/rXv37gH6vZUH1SksLKyyzvnz55kzZw7W1tZs374dR0fHWrURoFevXjX2ZITxunXrVt9NaPZMdQwMCpmrV6/y8ccfa3sKKpWKS5cucenSpQfW02g0tb5a2MnJCbVaTWZmpk555Ws3Nze9OpWDt5mZmTp/uaqqc+TIEcLDw3FxcSEmJob27dvXqn2VatvzEbUjpzn1z1THwKCQmTt3rk4v4eDBgzg5OT2wu2RhYYGLiwsvvvhirRpkYWFBv379OH78ONOmTdMGW0JCAjY2Nnh4eOjVcXZ2xtHRkYSEBEaOHKktT0hIwMXFhY4dOwJw+vRpwsLC6Nu3L5s2baJ169a1apsQovYMCpnWrVvr3CR58OBBevfurdiNk8HBwQQFBTFv3jwCAwO5ePEisbGxLFy4EEtLSwoKCkhNTcXJyQk7OzsAQkJCCA8Px9bWFj8/P06ePEl8fDyRkZEA3L9/n6VLl2Jtbc2sWbO0v0hVat++vdG9GiFE9Ywa+L127Zqp26Fj4MCBREVF8cEHHxAaGoqDgwOLFi3i1VdfBSpO36ZMmcK7775LQEAAAAEBARQXF7N9+3b279+Po6Mja9euZdSoUUDFVcQ5OTkA2vX81uzZs5kzZ46i+yVEc6TS1GH08ueff8bS0pI2bdoAFRfFxcTE8NNPP+Hh4cGUKVOqHKhtTnL3vk/p7X/WdzMajRb2HbEbP8+k69xz9CK3c++ZdJ1Nmb2dNS+M6mOy9Rk1elleXs6yZcvw9fXlzJkzQMW1Ki+++CKfffYZp06dIioqihdffFH7C48QonkyKmQ+++wz9u3bR+vWrXnkkUe0ZTk5OfTq1Yvo6GhGjRrFjz/+yPbt203aYCFE42JUyBw6dAhLS0v27dvHU089BVT8kqNSqVi8eDFPPfUU69at47/+679ISEgwaYOFEI2LUSGTmprKE088ob2ILTc3lytXrmBjY6P9WVutVtOjRw+ys7NN11ohRKNjVMiUlZVhZWWlfX327Fk0Gg1eXl46yxUXF8tVsUI0c0aFTKdOnbh+/br29bFjx1CpVAwZMkRbVlBQwKVLl7QXwgkhmiejQmbQoEHcuHGDsLAw/vSnP3Hs2DEsLCwYPnw4AMnJycycOZP8/Hy9qRWEEM2LURfjhYaGcunSJQ4dOqQte/3117G1tQVg/vz53L59G09PT2bMmGGShgohGiejQsbGxoaPPvqI+Ph4cnJyeOKJJ/D09NS+/+yzz9KhQwcmTpyIubm5yRorhGh8jJ7j19zcnOeee67K98LCwoxukBCiaanzROKXL18mMTGRn376CXd3d8aPH8+pU6fw8PDQ3rwohGi+jA6Zmzdv8uab
b5KSkqIte/bZZxk/fjybN2/m2rVrvPfee9rBYCFE82TUr0t5eXm8/PLLXLhwga5du/Lqq6/qXA/j7OzM/fv3WbBggeJ3bAshGjajQmbr1q3cvHmT4OBgDh8+zJtvvqnz/rp161ixYgWlpaX8+c9/NklDhRCNk1Ehc+LECZydnZk3r/pb8idOnEjXrl1rnKJTCNG0GRUyP//8M+7u7jUu5+rqyq1bt4zZhBCiiTAqZFq3bs0//1nzREzZ2dkyj64QzZxRIdOvXz++//57kpKSql3m3LlzfP/99wY/m0UI0TQZFTIzZsxApVIRHBzM//7v/2p/QSorKyMrK4udO3cyd+5czMzMCAoKMmmDhRCNi0HXydy8eZNHHnlEe2/S448/zv/8z/+wfPly1q5dC1Q8i+no0aMcPXoUqHgu0dKlS+nTx3RzhQohGh+DejJPPfWU3uNPxo0bR1xcHC+88ALOzs5YWFjQokULOnTowHPPPceePXuYNGmSIo0WQjQeBvVkNBpNlZNPdenShZUrV5q8UUKIpkOetSqEUJSEjBBCURIyQghFGXwX9okTJ7SPP6kNlUrFiRMnal1PCNE0GBwyhYWFRj0NUqVS1bqOEKLpMDhkvL29mTlzppJtEUI0QQaHzKOPPkr//v2VbIsQogmSgV8hhKIkZIQQipKQEUIoyqAxmdmzZ9OtWzel2yKEaIIMDhkhhDCGnC4JIRQlISOEUJSEjBBCURIyQghFScgIIRQlISOEUJSEjBBCUQ02ZM6cOUNAQACenp74+vqydevWKucZ/q1Dhw4xevRoPDw88Pf3Z+/evdUuW1BQgJ+fHwcOHDB104UQv9EgQyYlJYWQkBC6dOlCVFQUY8aMITIyki1btlRbJz4+nrCwMLy9vYmOjmbAgAEsW7aMw4cP6y37yy+/MGPGDIOegimEqBuDp3p4mKKjo3F3dyciIgKAoUOHUlpayrZt2wgKCsLS0lKvzoYNG/D392fJkiUADBkyhDt37mhDqtKJEydYvXq1URNwCSFqr8H1ZIqLi0lMTGTEiBE65f7+/hQWFpKcnKxXJzs7m4yMjCrr3Lhxg/T0dADy8/OZM2cO/fv3JyYmRrmdEEJoNbieTFZWFiUlJbi4uOiUOzs7A5CRkcHgwYN13ktLSwN4YB1XV1csLS35/PPP6dy5M9nZ2XVqZ3l5eY1jRGq1uk7baM7KyspMsh45Bsar6RgY+tk2uJDJz88HoFWrVjrl1tbWQMWA7e/dvXvXoDrm5uZ07tzZJO28cuUKJSUl1b5vZWVFjx49TLKt5uj69esUFRXVaR1yDOqmpmPQt29fg9bT4EKmvLwcqH4CcjMz/TO86upU9jSqqlNXvXr1qrEnI4wnU4vUP1MdgwYXMjY2NoB+j+XevXuAfm/lQXUqB3erqlNXSgSX+A85zal/pjoGDe6b4uTkhFqtJjMzU6e88rWbm5teHVdXV51lDKkjhHg4GlzIWFhY0K9fP44fP65zOpKQkICNjQ0eHh56dZydnXF0dCQhIUGnPCEhARcXFzp27Kh4u4UQVWtwp0sAwcHBBAUFMW/ePAIDA7l48SKxsbEsXLgQS0tLCgoKSE1NxcnJCTs7OwBCQkIIDw/H1tYWPz8/Tp48SXx8PJGRkfW8N0I0bw2uJwMwcOBAoqKiSE9PJzQ0lCNHjrBo0SKmT58OwNWrV5kwYQKnTp3S1gkICGDlypWcPXuW0NBQkpKSWLt2LaNGjaqnvRBCAKg08hOJonL3vk/pbbl9wVAt7DtiN36eSde55+hFbufeM+k6mzJ7O2teGNXHZOtrkD0ZIUTTISEjhFCUhIwQQlESMkIIRUnICCEUJSEjhFCUhIwQQlESMkIIRUnICCEUJSEjhFCUhIwQQlESMkIIRUnICCEUJSEjhFCUhIwQQlESMkIIRUnICCEUJSEjhFCUhIwQQlESMkIIRUnICCEUJSEjhFCUhIwQQlESMkIIRUnICCEUJSEjhFCUhIwQQlES
MkIIRUnICCEUJSEjhFCUhIwQQlESMkIIRUnICCEUJSEjhFCUhIwQQlESMkIIRUnICCEUJSEjhFCUhIwQQlESMkIIRUnICCEU1WBD5syZMwQEBODp6Ymvry9bt25Fo9E8sM6hQ4cYPXo0Hh4e+Pv7s3fvXr1lLl++zMsvv0yfPn3w9vZm7dq1FBcXK7UbQjR7DTJkUlJSCAkJoUuXLkRFRTFmzBgiIyPZsmVLtXXi4+MJCwvD29ub6OhoBgwYwLJlyzh8+LB2mRs3bhAUFISlpSUbNmxg2rRp7Ny5k1WrVj2M3RKiWWpR3w2oSnR0NO7u7kRERAAwdOhQSktL2bZtmzYkfm/Dhg34+/uzZMkSAIYMGcKdO3e0IQUQExODtbU1mzZtwtzcHB8fHywtLfnv//5vgoOD6dix48PbSSGaiQbXkykuLiYxMZERI0bolPv7+1NYWEhycrJenezsbDIyMqqsc+PGDdLT0wH45ptvePLJJzE3N9cu8/TTT1NeXs4333yjwN4IIRpcTyYrK4uSkhJcXFx0yp2dnQHIyMhg8ODBOu+lpaUBPLBOhw4d+Oc//4mrq6vOMnZ2drRq1YqMjAyD26jRaCgtLa1xjEitVqNq2x4zldrgdTd3Kts/UFZWRllZmUnWp1araWtjhUr14GMl/sO2tZVBx0CtVmNmZoZKpXrgcg0uZPLz8wFo1aqVTrm1tTUABQUFenXu3r1bY53q1lu5XFXrrU55eTmXL182bOG2bhX/hOH++leTrs7OquKfMNQ9/mrgMejduzdq9YP/iDa4kCkvLweoNh3NzPTP8KqrU9nTMDMze2CvQ6PR1JjGv29D7969DV5eiKaqqu/j7zW4kLGxsQH0eyz37t0Dqu6JVFensLBQW6d169Y66/n9cpXvG0KlUtWY3kKICg1u4NfJyQm1Wk1mZqZOeeVrNzf9U4/KcZYH1XnkkUdwcHDQWyY3N5eCgoIq1yuEqLsGFzIWFhb069eP48eP65ziJCQkYGNjg4eHh14dZ2dnHB0dSUhI0ClPSEjAxcVF+9O0t7c3p06d0rn47osvvkCtVjNgwACF9kiI5q3BnS4BBAcHExQUxLx58wgMDOTixYvExsaycOFCLC0tKSgoIDU1FScnJ+zs7AAICQkhPDwcW1tb/Pz8OHnyJPHx8URGRmrXO336dD7//HOmT59OUFAQGRkZrF+/ngkTJtChQ4f62l0hmjZNA3Xs2DHNM888o+nZs6fGz89PExsbq33v/Pnzmscee0yzf/9+nTq7du3SDB8+XNOrVy/NyJEjNQcPHtRb73fffacZP368plevXpohQ4Zo3nvvPU1JSYnSuyNEs6XSaGq42EMIIeqgwY3JCCGaFgkZIYSiJGSEEIqSkGnEJk+eTLdu3XT+9evXjylTppCUlKSz3OTJk+uxpY3D5MmT6dGjB3/729+qfN/Pz4/FixfXeTt+fn46x8zd3R0vLy9mzZrFtWvXFNlmfWqQP2ELw/Xo0YO3334bgLKyMvLy8ti1axfTpk3jwIEDdO3aVfu+qFlZWRnh4eEcOHBA5259U/Px8SEkJASA0tJSbt26xfbt25k6dSpHjx7l0UcfBWDjxo1VXuXemEjINHKtWrXSu49q0KBBDBw4kAMHDhAWFiZXM9dC69at+fHHH4mOjmbBggWKbcfOzk7vuD3++OMMGzaML774gkmTJgEVf0QaOzldaoKsrKywsLDQ3vT5+9Ol3NxcVq5cia+vL7169aJ///6EhoaSnZ2tXSYrK4vg4GC8vLzw9PRkwoQJnD59+qHvy8PWvXt3xo4dS0xMDFeuXHngsmVlZXzyySc8++yzeHh48OSTT/Lee+9x//59o7Zta2urV/b706Xs7GwWLVrE4MGD6dmzJwMHDmTRokXk5eVpl7l69SpTp06lb9++9OnTh1deeYVLly4Z1SZTkJBp5DT/P7dNaWkpJSUl5OTksH79eoqLiwkMDKxy+ZkzZ/Ltt9/yxhtvEBsb
S0hICGfPnmX58uVAxV3tM2fOpLCwkHXr1rFp0yZsbW0JCQnRu/erKVq6dCl2dnaEh4c/cP7n5cuX88477+Dn58fmzZuZNGkSO3fuJCQkpMa5hn573IqLi7l58yarV6/G3t6ekSNHVlmnqKiIKVOmkJaWxttvv01sbCwvv/wyf/nLX1i/fj1QcZPw9OnTadu2LR988AGRkZEUFRUxbdo07ZQoD5ucLjVy3333HT179tQrf/311+nSpYte+a1bt7CysiIsLIx+/foB4OXlRXZ2Nrt37wbg3//+N2lpacyaNQsfHx8APDw82Lhxo9F/pRsTGxsbVq5cSXBwcLWnTampqezbt4/58+cTHBwMVNwb165dOxYtWsSZM2e0n11V4uLiiIuL0ylTqVRERERob5X5vYyMDNq3b8+aNWtwcnICYMCAAfztb3/TDvSnpqaSm5vL5MmT6du3LwCdO3dm9+7dFBQU1Gq2AVORkGnkevbsycqVK4GKv475+fmcOXOGyMhICgsL9b4gDg4OfPTRRwDcvHmTzMxM0tLSSElJoaSkBAB7e3vc3Nx46623OHv2LEOHDmXw4MGEh4c/3J2rR35+fowZM4aYmBhGjBihF+SVX+pnn31Wp3z06NGEh4eTmJj4wJDx9fUlNDQUqDhuubm5xMfHs3DhQoqKinjhhRf06nTv3p1PP/2U8vJysrKyyMjI4Mcff+Qf//gHpaWlAHTt2hU7OzuCg4MZOXIkPj4+2lOq+iIh08hZW1vz+OOP65QNHjyYwsJCYmJimDJlil6dw4cPs379ev71r39ha2uLu7u7zuTsKpWK7du3s3nzZo4fP87Bgwdp2bIlw4YNY8WKFVWOHTRFy5Yt49y5cyxevJj9+/frvHfnzh0A/vCHP+iUt2jRgrZt29Z4amJra6t33J588klu3bpFREQEgYGBVc5Z9OGHH7J161by8vKwt7enZ8+eWFlZabdnbW3NJ598wubNmzl69Ci7d+/GysqKMWPGsHTpUiwsLGr9OdSVjMk0Ud27d6e0tFRnMBcgOTmZsLAwhg8fzunTp0lMTGTHjh16v3Q4ODiwYsUKvvnmG+Li4pg2bRrHjh3Tuau9qWvTpg0rVqzg73//O5s3b9Z7DyAnJ0envKSkhLy8PNq2bWvUNt3d3cnPz9cZyK105MgR1qxZw6uvvsq5c+f49ttv2bZtm97c1p07dyYiIoLz58+ze/duxo4dy2effcaOHTuMalNdScg0URcvXkStVuPo6KhXXl5ezty5c2nfvj1Q8SvJ2bNngYpB34sXLzJo0CAuX76MSqWie/fuLFiwgMcee4yffvrpoe9LfRo2bBjPPPMM27ZtIzc3V1vev39/oOKL/1uff/45ZWVl2vGQ2rp06RJt2rSpMqQuXLhA69atee2117TjNvfu3ePChQvaKWi/+OILBgwYQE5ODmq1mj59+rBixQpsbGzq7djJ6VIjV1BQoDPpc0lJCV9++SVHjhxhwoQJeoOIlZN+rVq1isDAQPLz89m5c6f2StPCwkJ69OiBpaUlixYtYs6cOdjb23P27Fl++OGHKk+/mrq33nqL8+fPc/v2bW2Zm5sb48aNY+PGjfz66694eXnxww8/sHHjRry8vBgyZMgD15mbm6tz3IqKioiLi+PChQu8/vrrVZ4qeXh4sGvXLtasWYOvry+3bt0iNjaW27dva3tWf/zjHykvLyc0NJTXXnsNa2tr4uPjuXv3rt4jgx4WCZlG7vvvv2fChAna1xYWFjg5ObFgwQKmTZumt7yXlxfLly/nww8/5IsvvsDe3h4vLy82btxIaGgoFy5cwMfHh+3bt/OnP/2J1atXk5+fj4uLC6tWrSIgIOBh7l6DYGtry4oVK5g9e7ZO+erVq3F2dmb//v3ExsbSrl07Jk+eTGhoaI0TbJ8+fVrnuqNHHnkEV1dX3n77bV566aUq64wbN47s7Gz279/Pp59+ioODAz4+Prz00ku89dZbpKam4ubmRkxM
DO+//z5Lly6lqKiIrl27EhUVVW+zP8p8MkIIRcmYjBBCURIyQghFScgIIRQlISOEUJSEjBBCURIyQghFScgIIRQlISOEUJSEjBBCURIyQghFScgIIRT1f1ZH12JGt3j5AAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 250x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Final-epoch test loss of every run, per condition.\n",
    "last_test_losses_relu_no_bias = [run[-1] for run in losses_test_relu_no_bias]\n",
    "last_test_losses_relu_bias = [run[-1] for run in losses_test_relu_bias]\n",
    "\n",
    "test_losses = [last_test_losses_relu_bias, last_test_losses_relu_no_bias]\n",
    "model_names = [\"ReLU\", \"ReLU\"]\n",
    "test_loss_labels = [\"Bias\", \"No Bias\"]\n",
    "# Long-format records for seaborn. The comprehension uses `model_name` so that the\n",
    "# trained `model` global is not overwritten with a string.\n",
    "df_test_data = [\n",
    "    {\"Model\": model_name, \"Bias\": bias, \"Test Loss\": loss}\n",
    "    for model_name, bias, losses in zip(model_names, test_loss_labels, test_losses)\n",
    "    for loss in losses\n",
    "]\n",
    "df_test = pd.DataFrame(df_test_data)\n",
    "\n",
    "sns.set_style(\"whitegrid\")\n",
    "plt.rcParams[\"font.family\"] = \"sans-serif\"\n",
    "custom_palette = [sns.color_palette(\"Set2\")[1], sns.color_palette(\"Set2\")[2]]\n",
    "\n",
    "# One bar per condition (mean over runs); seaborn computes the error bars across runs.\n",
    "# hue mirrors x so the palette is applied without the `palette`-without-`hue` FutureWarning;\n",
    "# legend=False because the x tick labels already name the conditions.\n",
    "fig, ax = plt.subplots(figsize=(2.5, 3))\n",
    "sns.barplot(x=\"Bias\", y=\"Test Loss\", hue=\"Bias\", data=df_test, capsize=.3,\n",
    "            palette=custom_palette, legend=False, ax=ax)\n",
    "\n",
    "ax.set_ylabel(\"Test Loss\", fontsize=16)\n",
    "ax.tick_params(axis='x', labelsize=12)\n",
    "ax.tick_params(axis='y', labelsize=12)\n",
    "ax.set_xlabel(\"\", fontsize=14)  # no x-axis title; the tick labels name the conditions\n",
    "sns.despine(left=True)\n",
    "\n",
    "# Save before plt.show(): the inline backend closes figures on show, which would break\n",
    "# save_figure if it relies on pyplot's current figure.\n",
    "save_figure('/relu_nets_bias/', 'Test_Loss_ReLU_MNIST', fig)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "id": "a07f36f1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/zm/n63_ssc512d4ygg3lbvwql400000gq/T/ipykernel_5037/595282805.py:18: FutureWarning: \n",
      "\n",
      "Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.\n",
      "\n",
      "  sns.barplot(x=\"Bias\", y=\"Effective Rank\", data=df_e_rank,capsize=.3, palette=custom_palette, ax=ax)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPAAAAEiCAYAAADH1F8CAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAsh0lEQVR4nO3daVgUV9428Lu7WQXZXZFGRNlF0ICKUQSXjgsKqDBvjDhoYhB1iIp7zKgxA+o4YhCvBEVGjYo4PooSkeAomqghYTBRZFEUENEgPIABAVm63g8M/aRtlqa7Gij8/64rH7q66pzTFW9qO+cUj2EYBoQQTuJ3dwMIIYqjABPCYRRgQjiMAkwIh1GACeEwCjAhHEYBJoTDKMCEcBgFmBAOowATwmEUYEI4jAJMCIdRgAnhMAowIRxGASaEwyjAhHAYBZgQDqMAE8JhFGBCOEytuxtACFvq6urQ2NiosvLV1NSgpaWlsvIVQQEmvUJUVBTOnz8PsVissjr4fD68vb2xYsUKldXRWTyalZL0BtOmTVNpeFvw+XykpKSovB550TUw6RW8vb3B56v2n3PLEbgnoSMw6VZiMQM+n8dKWfJeA1dXV2PhwoWSzydOnICurm6H27F5DczW76ZrYNKt+HweUn7IRcXvNUqX1dBQD3FTU4fr1b+ulfr87bX70NDU7nA7vkAAdXUNhdvXwlCvD6a9a610OQAFmPQAFb/XoKz8lVJl/JqWhLzsNECBE8qTh/8m34o8HobbjsWosTM6XYeq0DUw6RUeZf+kUHg7hWGa6+lBKMCkV7C0dQV47FxLt4XH4zfX04PQKTTpFUaNnQH7MVPkugZWFF8ggJqa8tfAbOJUgBmGQXx8PL755hs8ffoURkZG8PT0REhISJt3ER89eoSZM2fKLLewsMDly5dV3WTShdTUNDj2L1p5nPq5hw8fxr59+7B06VKMHz8ehYWF2L9/Px4+fIjY2FjwWjmFysnJAQAcO3YMmpqakuU9rUscIYrgTIDFYjGio6Ph7++PtWvXAgDc3NxgYGCATz75BJmZmRg5cqTMdtnZ2TA1NcXYsWO7usmEqBxnbmJVV1djzpw5mD17ttRyCwsLAEBRUVGr22VnZ8PW1lbl7SOkO3DmCKynp4etW7fKLP/uu+8AACNGjGh1u5ycHFhaWsLf3x9ZWVnQ09ODj48PQkJCoK6urtI2E6JqnAlwazIyMnDo0CFMnTq11QCXlZWhrKwMPB4PoaGhGDx4MG7fvo1Dhw7h+fPn2Lt3r9x1icViUK9T9gkEgu5uQrdpaueOubz7hbMBTk9PR1BQEIRCIb744otW19HV1UVsbCwsLCwwaNAgAICrqys0NDQQERGB4OBgWFpaylVfZmYmGhoaWGs/AbS1tWFnZ9fdzeg2ubm5qK2tbfW7MWPGyFUGJwP87bffYuPGjbCwsEBMTAwMDAxaXU9LSwtubm4yyydPnoyIiAjJ6bU8HBwc6AhMWGVtrXx/aFYCfPv2bfz444/IyMhASUkJKioqoKWlBSMjI1hZWcHV1RWTJ09Gv379lK7r8OHD+Pvf/w4XFxccPHgQffv2bXPdx48fIy0tDV5eXlLPievq6gAAhoaGcter6qFq5O3DxuWDwgF+9eoVjh8/jvj4eDx//lxydNLU1IS+vj5ev36Nhw8fIjc3FxcvXoSamho8PT2xePFiuU8P3hQXF4c9e/ZgxowZ2L17NzQ02u8VU1JSgm3btkFdXR3z58+XLL906RJ0dHRgb2+vUDsI6SkUCvCpU6dw4MAB/O///i9sbGzg7+8PJycnODg4QEdHR7IewzAoKCjAr7/+ips3b+Lf//43UlJSMGXKFGzYsAFmZmZy11laWoqwsDCYmprigw8+QFZWltT3QqEQGhoayMvLg1AohJGREVxdXeHq6orw8HDU1tZi2LBhSE1NxfHjx7F+/Xro6+sr8vMJ6TEUGtBvb28PLy8vLF26tM3H
N62pq6tDYmIioqOjMWfOHKxcuVLubf/1r39hy5YtbX7fEu6AgACEhYXB19cXAFBVVYXIyEhcuXIFpaWlEAqFWLx4Mfz8/OSum6hW/KU7Sg8n5BITIx34zXRmpSyFAvzkyRMIhUKFK21qasJvv/0GU1NThcsgvQcFWHEK3ZkRCoXtPsPqiEAgoPASwgKFb2KNHj0aVlZWsLW1hYODA+zs7GBtbU29mwjpQgoHWFNTE/fu3cO9e/dw5swZAM1H1hEjRsDOzg52dnawt7eHra2t1CggQgh7lJqVsrCwEHfv3sW9e/dw9+5dZGVlob6+vrng/w7tEwgEsLCwkAQ6ICCAnZaTXoOugRXH6rSy69atQ2JiIsLDw5GTk4OsrCzk5OTg5cuXzZXxeMjOzmarOtJLUIAVp5KulHPnzsXcuXMln589e4bs7GwKLyEs65K+0IMHD8bgwYMxZcqUrqiOkLcGdfAlhMMowIRwmMIBLikpYbMdhBAFKHwN7O7ujgEDBsDR0REjR47EqFGjJMP0CCFdQ+EAjxgxAvn5+UhJSUFKSorUlK6rV6+GjY0NbGxsYGtri/79+7PSWEKINKWeA9fX1yMnJwf379/H/fv3kZWVhQcPHkhe8dgSakNDQ9ja2sLW1hahoaHstJz0GvQcWHGsvx+4oaEBubm5yMrKQmZmpiTU9fX11JGDtIoCrLhOnUKXlZXBxMSk3XXU1dXh4OAABwcHyZjbxsZG5OXlITMzU/GWEkJkdOou9MSJE1FeXt7pStTU1GBjYyM1rQ0hRHmdCjDDMDQzIyE9CHXkIITDOh3gK1eu4NGjR3QkJqQH6PRz4M8//xxNTU3Q0tKCtbU1bG1tYWdnB1tbW1hZWXU41SshhD2dDvC///1vVFZWSoYHZmVlITExEVVVVVBTU8OwYcNw4cIFVbSVEPKGTgWYx+NBTU0N1tbWsLa2hre3t+S7oqIiZGVl0XNeQrpQpwLc3nWvmZkZzMzMIBKJlG4UIUQ+nbqJ9dVXX7X7LiJCSNfq1BF48uTJKmoGIUQRCj8HVmZid0IIO2hid0I4jCZ2J4TDaGJ30u1oOKHiaGJ30u0owIqjid0J4TBOTezOMAzi4+PxzTff4OnTpzAyMoKnpydCQkKgq6vb5nYJCQmIjo5GUVERBg0ahA8//BALFixQqi2E9ARdEmC2HD58GPv27cPSpUsxfvx4FBYWYv/+/Xj48CFiY2OlJtZrkZSUhA0bNiAgIAATJ07ElStX8Omnn0JTUxNz5szphl9BCHs4E2CxWIzo6Gj4+/tj7dq1AAA3NzcYGBjgk08+QWZmJkaOHCmzXUREBEQiETZv3gygeVaRly9fIjIykgJMOI8zE7tXV1djzpw5mD17ttRyCwsLAM2DKd709OlTFBQUYPr06VLLRSIRnjx5gvz8fNU1mJAuwJmJ3fX09LB161aZ5d999x2A5nmq3/To0SMAwNChQ6WWm5ubAwAKCgokfwAI4SJOT+yekZGBQ4cOYerUqa0GuKqqCgBkbnDp6OgAaD6qy0ssFtMsJCogEAi6uwndpr3uyPLuF4UDfPHixTYndk9KSkJSUpJKJ3ZPT09HUFAQhEIhvvjii1bXEYvFACBzc6sliHy+/FcQmZmZaGhoULC1pDXa2tqws7Pr7mZ0m9zcXNTW1rb63ZgxY+QqQ6mbWBoaGnB0dISjo6NkWVsTu9+8eRO3bt1iJcDffvstNm7cCAsLC8TExMDAwKDV9fT09ADIHmlramoAyB6Z2+Pg4EBHYMIqa2trpctg/S60qid2P3z4MP7+97/DxcUFBw8ebHd8csv1bWFhodRf+sLCQgDA8OHD5a63M0drQuTBxuVDl/yrZGti97i4OOzZswfvvfceYmJiOpxcwNzcHGZmZkhOTpZanpycjKFDh8LU1FSp9hDS3TjzHLi0tBRhYWEwNTXFBx98gKysLKnvhUIhNDQ0kJeXB6FQCCMjIwBAcHAwNm3aBAMDA3h6euLq1atI
SkrCvn37uuNnEMIqhQL88ccfY9WqVXBwcOj0tnV1dThx4gS0tbXx/vvvy73d9evXUVdXh+LiYixcuFDm+5ZwBwQEICwsDL6+vgAAX19f1NfX48iRIzh79izMzMywa9cuzJw5s9NtJ6SnUWg00pw5c/Dw4UO4urpi7ty5mD59eoc3hO7du4cLFy4gMTERNTU1CA8Px4wZMxRuOOk9aDSS4hQKMMMwOHv2LA4ePIhnz56Bz+dj2LBhsLOzg7GxMfT19VFXV4eXL1+isLAQmZmZqKqqAp/Px3vvvYfVq1djyJAhrPwAwn0UYMUpNR5YLBYjNTUV586dw08//SQZ9/tHfD4f1tbWmDJlChYsWIABAwYo1WDS+1CAFafUTSw+nw9PT094enoCaO66+Ntvv6GyshKampowMjLCiBEjaCpaBdTV1aGxsVElZaupqUFLS0slZZOuxepdaEtLS1haWrJZ5FspKioK58+fl/QkYxufz4e3tzdWrFihkvJJ12EtwDk5OThx4gQePHgAPp8PoVCI8ePHQyQSQVtbm61q3gqqDC/QfOlz/vx5CnAvwEpHjuvXr2PBggU4c+YMfv31V9y5cwcJCQnYtGkT3N3dcerUKTaqeWt4e3urtOdXyxGYcB8rk9r5+PggNzcXa9aswcyZM6GlpYUXL17g2rVrOHPmDJ4/f46FCxfi008/ZaPNbwV5roGrq6ulnomfOHFCrv7dPe0amG5iKY6VU+i8vDzMmDEDH374oWSZkZERbGxs8Oc//xmfffYZTpw4gdGjR1MHCjkpEjBdXd1ODdAg3MfKeZpAIGjzua62tjbCw8Nhbm6Of/7zn2xU1yMxKrxm7cne1t/dU7ByBLawsMDdu3fb/F4gEMDDw6NXXwvz+Hy8vHIKTRUvWCnvdUMjGsXtX928ei09Pvnp6UjoaHb8ahs1Pg+a6sr/rxcY9of+1P+ndDlEcawEOCgoCCEhIUhISJCaD/qPGhoa0KdPHzaq67GaKl6gsaxY6XKO/lKE5LwX6OzNiRUnr8i1Hg+AaHh/LHYy63TbSM/Cyim0SCSCj48PNm7ciI0bN8qMFHrw4AEuXryIqVOnslFdr/fdo86HtzOY/9ZBuI+158BffPEFjIyMcPToUSQkJMDIyAgDBgxAXV0d8vPz4e7uji1btrBVXa823bK/QkdgefF5zXUQ7mP13UgAkJ+fj/Pnz+PGjRvIzc2VmpfK2NgY1tbWsLW1lUx415t6bpWf2c/KKTQAvG4Uo1FFN4jU+Hxoqil/8qVmYgqjBSFKl0OPkRTH+oB+CwsLrF69GqtXr0ZdXR2ys7ORlZUlmfQuLS0NN2/eBEAvO2uPphofml0zYQrhMJXOyKGlpQVnZ2c4O//fX5v6+no8ePAA9+/fp/ASoqQun1JHQ0NDMukdIUQ5dI5GCIdRgAnhMAowIRxGASaEwyjAhHAYBZgQDuuyx0i3bt3CV199BU1NTbi4uOCjjz6SeWsgIaRzuuwIXFZWhp9++gn+/v64ffs2YmNju6pqQnqtLjsCT5gwAceOHYOrqyvGjx+P+vr6rqqakF6LtSNwR4E0NjbG77//DgDQ0dGBoaEhW1UT8tZiLcCrVq1qcxK2qqoqrFu3DqtWrWKrOkIIWAzw9evXsXbtWpn5jK9fv45Zs2bh4sWLcHJyYqs6QghYDPD69euRnJyMTZs2AQBevXqFLVu2ICgoCFVVVdiyZQtOnjzJVnWEELB4E2vJkiXg8XjYtWsXampqcP/+fTx79gxubm7YsWMHvY2QEBVg9S50YGAgeDwewsPDIRAIsHPnTsyfP5/NKiSeP38OLy8vREVFYezYsW2u9+jRo1bnorawsMDly5dV0jZCugrrj5H+/Oc/g8fjISwsDHl5eWwXDwAoLi7G0qVLUVVV1eG6OTk5AIBjx45BU1NTsrwnvZmAEEUpHGAbG5sOe1IdPXoUR48elXzm8XgyM1Z2hlgsxrlz57B79265t8nOzoapqWm7R2lC
uErhALu4uLDZDrnk5uZi27ZteP/99+Hm5oZly5Z1uE12djZsbW27oHWEdD2FA3z8+HE22yGXQYMGISUlBQMHDkRaWppc2+Tk5MDS0hL+/v7IysqCnp4efHx8EBISAnX1jt9iQEhP1uVzYinDwMCgU+uXlZWhrKwMPB4PoaGhGDx4MG7fvo1Dhw7h+fPn2Lt3r9xlicVitDcDr0Ag6FTbepOmpiaFt6X91jp59wurAa6urkZSUhLKyspabRyPx+vSl0rr6uoiNjYWFhYWGDRoEADA1dUVGhoaiIiIQHBwsNzzUmdmZqKhoaHV77S1tWFnZ8dau7kmNzcXtbW1nd6O9lvb+23MmDFylcFagO/du4clS5agurq6zSNVVwdYS0sLbm5uMssnT56MiIgIyem1PBwcHNo9Ar/NrK2tu7sJnMTGfmMtwBEREairq8Nf/vIXjBw5EhoaGmwVrbDHjx8jLS0NXl5eUu/NraurA4BODajg82nug7a8zafBymBjv7EW4IyMDAQGBmL58uVsFam0kpISbNu2Derq6lIdSi5dugQdHR3Y29t3Y+sIUR5rAW7vJd9dpbq6Gnl5eRAKhTAyMoKrqytcXV0RHh6O2tpaDBs2DKmpqTh+/DjWr18PfX39bm0vIcpi7bzQ2dkZP//8M1vFKeT+/fvw9/dHamoqgOY/KgcPHoSvry9iY2MRFBSEW7duYceOHViyZEm3tpUQNrD2dsLc3Fx88MEHWL58Oby8vNCvXz82iuUUNt9OyAX0dkLF9Mi3E4aGhkIgEGDPnj3Ys2dPq+so25WSECKNtQAbGBh0uqMFIUQ5rAW4O7pWEvK2o4ebhHAYq10pnz9/juTkZNTU1EjNjdXU1ISqqir88MMPNIieEBaxFuDbt2/jo48+QlNTExiGAY/Hk3Q9bBk33L9/f7aqI4SAxQB//fXXUFdXx+bNmwEAO3fuRFRUFKqrq3H8+HHk5eXRpHaEsIy1a+CsrCz4+fnh/fffx/z588Hn86GmpoY5c+bg2LFjMDExQVRUFFvVEULAYoBramowYsQIAICGhgaGDBmC3NxcAM3Dxnx8fPCf//yHreoIIWAxwPr6+qiurpZ8HjJkCB49eiT53L9/f7x48YKt6gghYDHAo0aNwoULF/D69WsAzdO2pqenSwb2P3jwADo6OmxVRwgBiwFevHgxcnJyMG3aNFRWVsLLywtPnz5FYGAg/vrXv+LkyZNyzzJACJEPawEeO3Ys9u/fDxMTE+jp6cHR0RGrV6/Gzz//jNOnT8PU1BShoaFsVUcIAcsdOaZNm4Zp06ZJPn/88ceYO3cuKisrMXz4cKipcWoOPUJ6PNaOwPv27Wv1LvPAgQNhY2ND4SVEBVgL8LFjx+gxESFdjLUA9+nThyY3I6SLsRbgNWvWIDo6GnFxcSgpKZF50TchhH2sXZj+85//RH19PbZv347t27e3ug7NyEEIu2hGDkI4jGbkIITDaEYOQjiMAkwIh1GACeEwCjAhHEYBJoTDVBbg+vp66sxBiIqxGuDKykrs2LED7777LpycnJCWlob09HQEBQUhPz+fzaoIIWAxwJWVlfD398fJkyehra0tmVL25cuXSE1NxcKFC1FUVMRWdYQQsBjgAwcOoLi4GLGxsTh9+rQkwFOmTEF0dDRqampw8OBBtqojhIDFAF+9ehV+fn4YP368ZCL3FpMmTYK/vz/S0tLYqo4QAhYD/OLFC9jY2LT5vaWlJUpLS9mqDs+fP8c777wj1x+FhIQEzJo1C46OjhCJRDhz5gxr7SCkO7EWYGNjYxQXt/1y6wcPHsDQ0JCVuoqLixEYGIiqqqoO101KSsKGDRswYcIEREVFYdy4cfj0009x4cIFVtpCSHdibTDDpEmTEBcXhwULFshMH5uRkYH4+HjMnj1bqTrEYjHOnTuH3bt3y71NREQERCKR5JUvEydOxMuXLxEZGYk5c+Yo1R5CuhtrR+CVK1dCXV0dPj4+2LRpE3g8HuLi
4hAUFIRFixZBW1sbwcHBStWRm5uLbdu2wdvbW64QP336FAUFBZg+fbrUcpFIhCdPntCjLcJ5rAV4wIABOHXqFJydnXHjxg0wDIPk5GSkpqbCyckJx48fx5AhQ5SqY9CgQUhJScGmTZugpaXV4fotb4YYOnSo1HJzc3MAQEFBgVLtIaS7sXYKfe3aNbi7uyM6OhpVVVUoKCiAWCzGkCFDYGxszEodnZ0woOUaWVdXV2p5yyn+H18F0xGxWCx5NNaat3k+sJa3byiC9lvr5N0vrAV4+fLlMDExgY+PD3x9fTFy5Ei2ilZYS1fONx9rtQSRz5f/BCQzMxMNDQ2tfqetrQ07OzsFW8l9ubm5qK2t7fR2tN/a3m/yvsWEtQCvX78eCQkJOHToEA4fPgxnZ2fMmzcPM2bMQJ8+fdiqplP09PQAyB5pa2pqAMgemdvj4ODQ7hH4bWZtbd3dTeAkNvYba9fAS5YsQUJCAhISErB48WI8efIEW7ZswbvvvovNmzcjPT2drarkZmFhAQAoLCyUWt7yefjw4XKXxefzIRAI2vzvbdbefunov7cZG/uF9dFI1tbW2LhxI27cuIGvv/4aU6ZMweXLl7Fo0SKIRCK2q2uXubk5zMzMkJycLLU8OTkZQ4cOhampaZe2hxC2qex9J3w+H6NHj0Z5eTnKy8tx8+ZNPH36VFXVAWg+Vc7Ly4NQKISRkREAIDg4GJs2bYKBgQE8PT1x9epVJCUlYd++fSptCyFdgfUANzY24vr160hISEBqaioaGhowePBgrFy5Er6+vmxXJ+X+/fsICAhAWFiYpC5fX1/U19fjyJEjOHv2LMzMzLBr1y7MnDlTpW0hpCvwGJbuzPz6669ISEjApUuX8PLlS2hoaGDq1KmYP38+xo8fz0YVPV75mf1oLGu7O2lvo2ZiCqMFIUqXE3/pDsrKX7HQIm4wMdKB30xnVspi7Qjs7+8PALC3t8df/vIXeHl5oW/fvmwVTwhpBWsBDggIwPz582FlZcVWkYSQDrAW4JbBAoSQrkOzUhLCYQofgW1tbbF79254eXkBAGxsbGS6LL6J3k5ICLsUDvA777wDExMTyWcXFxdWGkQIkZ/CAX7zbYQdvZ2wsbGx3Rk7CCGdx9o1sK2tLRITE9v8/vz58yrvyEHI20bhI3BJSQlu374t+cwwDH7++Wc0NjbKrCsWi3Hx4kV6UwMhLFM4wIaGhoiMjMSzZ88ANN+gio+PR3x8vMy6LZ29aA4qQtilcIA1NDRw4MAB5OTkgGEYbN68GX5+fnB2lu0ixufzYWJignHjxinVWEKINKU6ctja2sLW1hYA8PPPP2PevHkYNWoUKw0jhHSMtZ5YYWFhAJonkhMKhVBXVwfQHOy+ffu2O+k7IUQxrN2FbmxsRGhoKGbPni01A0ZcXBx8fHzw+eef05Q0hLCMtSNwbGwsEhMTMWvWLMlgegBYtmwZtLW1cfLkSVhZWUlGLRFClMfaEfj8+fOYOXMm9u7dKxVga2tr7Ny5EyKRCKdOnWKrOkIIWAxwcXFxu3eZx48fLzO5HCFEOawFuG/fvnjy5Emb3z979kyutykQQuTHWoDd3Nxw8uRJPHz4UOa7goICnDx5kp4DE8Iy1m5iBQcHIyUlBfPmzYO7uzuGDRsGAMjPz8f3338PPp+PVatWsVUdIQQsBtjc3BwnT57Ezp07ceXKFalHRk5OTti6dask1IQQdrA6rayNjQ2++eYbVFRU4NmzZ2hsbGT15WaEEGkqmdjd0NAQOjo6UFNT69QLxAghncNquiorK7Fjxw68++67cHJyQlpaGtLT0xEUFEQv0yZEBVgLcGVlJfz9/XHy5Eloa2tLroFfvnyJ1NRULFy4EEVFRWxVRwgBiwE+cOAAiouLERsbi9OnT0sCPGXKFERHR6OmpgYHDx5kqzpCCFgM8NWrV+Hn54fx48fLzE45adIk+Pv7Iy0tja3qCCFgMcAvXrxod8igpaUlSktL2aqOEAIW
A2xsbNzurJMPHjyAoaEhW9URQsBigCdNmoS4uLhW3wGckZGB+Ph4vPvuu2xVRwgBi8+BV65ciWvXrsHHxwdjxowBj8dDXFwcjh49iu+//x66uroIDg5Wup4bN24gIiICjx49gpGREf70pz9h2bJlbb4V4tGjR62+C9jCwgKXL19Wuj2EdCfWAjxgwADExcVhx44duHHjBhiGQXJyMgBgzJgx+Otf/4ohQ4YoVUdGRgaCg4MxY8YMfPLJJ/jPf/6Dffv2QSwWY/ny5a1uk5OTAwA4duwYNDU1JctpZBTpDRQO8Lx587B8+XJMnToVQPPcV5aWloiOjkZVVRUKCgogFotZ7UoZFRUFGxsb7NmzB0DzaXtjYyOio6MRGBjYaiizs7NhamqKsWPHstIGQnoSha+BHzx4gPLycsnngIAA3Lx5E0Dz2OCRI0di1KhRrIW3vr4eaWlpmD59utRykUiEmpoapKent7pddna2ZOZMQnobhY/A/fr1Q0xMDOrr66GrqwuGYZCeno6mpqZ2t/P29laovqKiIjQ0NGDo0KFSy83NzQE0jzlu7SZZTk4OLC0t4e/vj6ysLOjp6cHHxwchISGSmTMJ4SqFA7x06VJ8/vnn+OKLLwBIv5nhzdkneTweGIYBj8dTOMC///47AEBXV1dquY6ODgCgurpaZpuysjKUlZWBx+MhNDQUgwcPxu3bt3Ho0CE8f/4ce/fulbt+sVjc7qyaAoFA7rJ6m47+aLeH9lvr5N0vCgd44cKFcHFxwYMHD1BfX4/NmzfD398fTk5OihbZrpb3KrV1t7m1UU+6urqIjY2FhYUFBg0aBABwdXWFhoYGIiIiEBwcDEtLS7nqz8zMRENDQ6vfaWtrw87OTq5yeqPc3FzU1tZ2ejvab23vtzFjxshVhlJ3oa2srGBlZQWguS/0pEmTMGXKFGWKbJOenh4A2SPtq1evAMgemYHmO81ubm4yyydPnoyIiAjJ6bU8HBwcaF7rNlhbW3d3EziJjf2m8E2sefPm4cqVK5LPu3btavW9SGwRCoUQCAQyM1u2fB4+fLjMNo8fP8apU6dkQl9XVwcAneoZxufzIRAI2vzvbdbefunov7cZG/uF1bvQt27dUrS4DmlqauKdd95BSkqK1JEwOTkZenp6cHR0lNmmpKQE27Ztk+mwcenSJejo6MDe3l5l7SWkK7B6F7qt9wP/kaI3sQBg+fLlCAwMREhICObNm4c7d+4gJiYGoaGh0NLSQnV1NfLy8iAUCmFkZARXV1e4uroiPDwctbW1GDZsGFJTU3H8+HGsX78e+vr6CreFkJ6Axyh4YXfixAl8/vnnbd5UelPLXejs7GxFqpNISUnBl19+ifz8fAwYMAALFy7EkiVLAABpaWkICAhAWFgYfH19AQBVVVWIjIzElStXUFpaCqFQiMWLF8PPz0+pdrSm/Mx+NJa1PaCjt1EzMYXRghCly4m/dAdl5a9YaBE3mBjpwG8mO5ebCgcYaD6N7uxdaB8fH0Wr6/EowIqhACuOM3ehCSGyFL6JFRYWhqysLMnnq1evYsqUKaiurpY8s/2jK1euULgJYZnCAT569CgePXoktayiogIuLi6tTp1TU1ODZ8+eKVodIaQVrE/aTJ0dCOk6NOs6IRxGASaEwyjAhHAYBZgQDqMAE8JhSnXkeHMGjpahfTdv3kRJSYnUuhkZGcpURQhphVIBbpmB400xMTEyy1r6QhNC2KNwgFeuXMlmOwghCqAAE8JhdBOLEA6jABPCYRRgQjiMAkwIh1GACeEwCjAhHEYBJoTDKMCEcBgFmBAOowATwmEUYEI4jAJMCIdRgAnhMAowIRxGASaEwyjAhHAYBZgQDqMAE8JhnAvwjRs34Ovri1GjRsHDwwNff/11h+9jSkhIwKxZs+Do6AiRSIQzZ850UWsJUS1OBTgjIwPBwcGwtLREZGQk5syZg3379uGrr75qc5ukpCRs2LABEyZMQFRUFMaNG4dPP/0UFy5c
6MKWE6IaSk0r29WioqJgY2ODPXv2AAAmTZqExsZGREdHIzAwEFpaWjLbREREQCQSYfPmzQCAiRMn4uXLl5I/AIRwGWeOwPX19UhLS8P06dOllotEItTU1CA9PV1mm6dPn6KgoKDVbZ48eYL8/HyVtpkQVeNMgIuKitDQ0IChQ4dKLTc3NwcAFBQUyGzT8gLyzmxDCJdw5hT6999/BwDo6upKLdfR0QEAVFdXy2xTVVXV6W1awzAMGhsb271ZJhAIwDMcCD5PIFeZvQHPoB+ampqkXq/TWQKBAIZ62uDx3p4Xwxv01e5wvwkEAvD5/A7fZsKZAIvFYgBo8wfx+bInE21t0xLE1rZpq+67d+92vKLh8Ob/3ia//KJ0EUbazf+9PV7hFzn2m5OTEwSC9g8InAmwnp4eANmjZssL1d48yra3TU1NTZvbtIbP58PJyalT7SVEWfIcYDgTYKFQCIFAgMLCQqnlLZ+HD5c98llYWEjWsbOzk2ub1vB4vA7/EhLSHThzE0tTUxPvvPMOUlJSpK5Fk5OToaenB0dHR5ltzM3NYWZmhuTkZKnlycnJGDp0KExNTVXebkJUiTNHYABYvnw5AgMDERISgnnz5uHOnTuIiYlBaGgotLS0UF1djby8PAiFQhgZGQEAgoODsWnTJhgYGMDT0xNXr15FUlIS9u3b182/hhAWMBzz3XffMbNnz2bs7e0ZT09PJiYmRvLdjz/+yFhZWTFnz56V2ubUqVPMtGnTGAcHB2bGjBnMuXPnurjVhKgGj2E66EhMCOmxOHMNTAiRRQEmhMMowIRwGAW4B/H09MTGjRvb/H7jxo3w9PTswhb1PosWLcKiRYvaXScyMhLW1tZd1CLlcOox0tsuODgYAQEB3d2MXm/BggWYOHFidzdDLhRgDhEKhd3dhLfCwIEDMXDgwO5uhlzoFLqHaWhowM6dO+Hi4gIXFxds2LAB5eXlAGRPoevq6rB3715Mnz4dDg4OGD16NAIDA5GdnS1Zp7y8HKGhoZgwYQJGjhyJuXPn4vz58139szrF09MTX375JXbt2gU3Nzc4Ojpi6dKlMuO3b968iffffx9jxozB2LFjsXbtWjx//lyuOqKiouDm5gZnZ2cEBwejqKhI8t2bp9BNTU2Ijo7G7Nmz4ejoCCcnJ/zpT3/C7du3Jeu8fv0a27dvx6RJk+Dg4ID33nsPR44cUXJPyKG7H0ST/+Ph4cHY2toy/v7+zJUrV5j4+HjG1dWV8ff3ZxiGYTZs2MB4eHhI1l+1ahUzbtw45syZM0xaWhpz+vRpxs3NjRGJRIxYLGYYhmGWLFnCzJ07l0lJSWFu3brFbNy4kbGysmJ+/PHHbvmN8vDw8GDGjBnDLFu2jElNTWUSEhIYV1dXxs/PT7LO+fPnGSsrK+aTTz5hUlNTmXPnzjEeHh7MxIkTmbKysjbL/uCDDxhbW1tGJBIxSUlJTGJiIuPh4cF4enoyr1+/ZhiGYb788kvGyspKsk14eDjj6OjIHDt2jElLS2MSEhKY6dOnMy4uLsyrV68YhmGYrVu3Mh4eHkxiYiLz448/Mrt37261UxHb6BS6h9HT08Phw4clI6UMDQ2xYsUK/PDDD1Lr1dfX49WrV9i6dStmzpwJAHB1dcWrV68QHh6O0tJS9O/fHz/99BOCg4MxdepUAMDYsWNhYGDQ4wdn6Onp4eDBg5J2PnnyBJGRkaioqIC+vj727NkDNzc3qS6xo0ePxsyZM3HkyBGsW7euzbL5fD5iYmIkfeEtLS3h7e2Nc+fOwd/fX2b9Fy9eYPXq1VI3v7S0tLBq1Srk5ubC2dkZP/30E9zc3DBr1iwAzfu5T58+MDQ0ZGV/tIUC3MO4u7tLDXP09PSEuro6bt26JbWehoYGYmJiADT/AyssLMTjx49x7do1AM2n4kDzP6TIyEjk5OTA3d0dkyZNwoYNG7ro1yhu5MiRUn9kWq5Ja2trUV5ejtLSUqxZs0ZqG6FQCGdnZ6SlpbVb
tpOTk9RAFhsbGwwZMgS3bt1qNcB79+4F0Hw5UlhYiPz8fFy9ehWA9H6Oi4tDSUkJPDw84O7ujhUrVijwyzuHAtzDmJiYSH3m8/kwMDCQzEjyR99//z3+9re/4fHjx9DR0YG1tbVkthHmvz1kW2btTEpKwuXLl8Hn8+Hm5oZt27bBzMxM9T9IQdra0iP8W8bGisViVFZWApDdVy3LsrKy2i27te2MjY1b3ccAcO/ePWzfvh337t2DlpYWhg8fLvkD0LKft2zZgoEDB+LChQvYvn07AMDZ2RmfffaZ1FBWttFNrB7mzX9ETU1NqKiogLGxsdTyJ0+eYMWKFbCxsUFKSgoyMjJw6tQpeHh4SK3Xt29frFu3TjIKa82aNcjIyJD8I+MiAwMDAEBZWZnMd6WlpR2etrYW1NLSUskItj+qrq7Ghx9+iD59+iAxMRF37tzB2bNnMW/ePKn1NDQ0sHz5ciQlJeHatWv47LPPUFRUhLVr13bil3UeBbiHuXXrFhobGyWfk5OT0djYiLFjx0qtl5mZidevX+Pjjz+Werz0/fffA2g+MhQXF8Pd3R2XL18GAAwbNgwfffQR3Nzc8Ntvv3XBr1ENCwsL9OvXDxcvXpRaXlRUhF9++QWjR49ud/s7d+5I5ksDgLt376K4uBjjxo2TWffx48eorKxEQEAARowYITkTuHHjBoDmM4K6ujqIRCLJXefBgwdj4cKFmDVrlsr3M51C9zBlZWVYtWoVFi1ahIKCAvzjH//AhAkTMH78eKnJ6O3t7aGmpoY9e/ZgyZIlqK+vx//8z/8gNTUVQPO0QdbW1hg4cCB27tyJ6upqCIVCZGZm4vr16/j444+76Rcqj8/nY82aNdi0aRNWr14Nb29vVFRU4MCBA9DX10dgYGC724vFYixbtgxBQUGoqKjA3r17YWVl1eo84RYWFtDV1cVXX30FNTU1qKmpITk5Gf/6178ANF+Ta2lpwd7eHgcOHIC6ujqsra2Rn5+Pc+fOQSQSqWQftKAA9zB+fn6oq6vDihUroKGhAS8vL6xbt05mYj5zc3Ps3bsXBw4cwPLly6Gvrw8nJyccP34cixYtQnp6OqytrXHgwAH84x//wP79+1FRUYFBgwZh5cqVWLZsWTf9Qnb4+vpCR0cHX3/9NVasWAFdXV1MnDgRa9asQb9+/drd1sPDA0KhEOvWrUNjYyM8PDywZcsWaGpqyqzbt29fHDx4ELt370ZISAh0dHRga2uLb775Bh999BHS09Ph6emJHTt2ICIiAkeOHEFpaSmMjY0xf/58hISEqGoXAABoPDAhHEbXwIRwGAWYEA6jABPCYRRgQjiMAkwIh1GACeEwCjAhHEYBJoTDKMCEcBgFmBAOowATwmEUYEI47P8D8jwpdsUsVTkAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 250x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Compare the effective rank, rho(H^T H), of the ReLU network with vs. without bias.\n",
    "# NOTE(review): each e_rank_list_* is assumed to be an iterable of scalar effective ranks\n",
    "# (presumably one per run) produced by an earlier cell -- confirm upstream.\n",
    "e_ranks = [e_rank_list_relu_bias, e_rank_list_relu_no_bias]\n",
    "model_names = [\"ReLU\", \"ReLU\"]\n",
    "bias_labels = [\"bias\", \"no bias\"]\n",
    "\n",
    "# Long-format DataFrame: one row per observation, built from a list of records in one go.\n",
    "# The loop variable is `bias_label` (not `bias`) so the label list is not clobbered in the kernel.\n",
    "df_data = [\n",
    "    {\"Model\": model_name, \"Bias\": bias_label, \"Effective Rank\": e_rank_value}\n",
    "    for model_name, bias_label, e_rank_array in zip(model_names, bias_labels, e_ranks)\n",
    "    for e_rank_value in e_rank_array\n",
    "]\n",
    "df_e_rank = pd.DataFrame(df_data)\n",
    "\n",
    "sns.set_style(\"whitegrid\")\n",
    "plt.rcParams[\"font.family\"] = \"sans-serif\"\n",
    "set2 = sns.color_palette(\"Set2\")\n",
    "custom_palette = [set2[1], set2[2]]\n",
    "\n",
    "# Bar height is the group mean; seaborn adds confidence-interval error bars by default.\n",
    "# hue mirrors x so the palette is applied without seaborn's palette-without-hue deprecation;\n",
    "# legend=False because the x tick labels already identify each bar.\n",
    "fig, ax = plt.subplots(figsize=(2.5, 3))\n",
    "sns.barplot(x=\"Bias\", y=\"Effective Rank\", hue=\"Bias\", data=df_e_rank, capsize=.3,\n",
    "            palette=custom_palette, legend=False, ax=ax)\n",
    "\n",
    "ax.set_ylabel(\"Effective rank, $\\\\rho(H^TH)$\", fontsize=14)\n",
    "ax.tick_params(axis='x', labelsize=12)\n",
    "ax.tick_params(axis='y', labelsize=12)\n",
    "# The tick labels are self-explanatory, so drop the x-axis label.\n",
    "ax.set_xlabel(\"\", fontsize=14)\n",
    "\n",
    "sns.despine(left=True)\n",
    "fig.tight_layout()\n",
    "save_figure('/relu_nets_bias/', 'Effective_rank_ReLU_MNIST', fig)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "behav-analysis-fmri",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
