{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "1f027494-3f2e-499a-9a2a-e39887f87aa6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "import math\n",
    "from copy import deepcopy\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "\n",
    "import utils\n",
    "import analysis \n",
    "import trainer\n",
    "import data\n",
    "import models\n",
    "\n",
    "plt.style.use('./mpl.style')\n",
    "torch.set_default_tensor_type(torch.FloatTensor)\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "\n",
    "def train_net(net, dataloader, nsteps=100, lr=0.1, num_checkpoint=0, display=500):\n",
    "    print(lr)\n",
    "    history=trainer.train(net, dataloader, nsteps, lr,\n",
    "                  display=display, \n",
    "                  num_checkpoint=num_checkpoint)\n",
    "\n",
    "    # plt.semilogy(history['train'],'-')\n",
    "    print('train error', trainer.validate(net, dataloader).item())\n",
    "    return history\n",
    "\n",
    "\n",
    "def compute_geoinfo(net, dataloader, batch_size, lr):\n",
    "    ana = analysis.AnalyzeNet(net, dataloader)\n",
    "    ana.compute_grads()\n",
    "\n",
    "    H = ana.hessian_fro()\n",
    "    F = ana.fisher_fro()\n",
    "    mu, alpha = ana.mu(), ana.alpha()\n",
    "    bound = math.sqrt(batch_size/mu)/lr\n",
    "\n",
    "    print('Hessian:{:.1e}'.format(H), \n",
    "          'alpha: {:.1e}'.format(alpha), \n",
    "          'mu: {:.2e}'.format(mu),  \n",
    "          'bound: {:.1e}'.format(bound))\n",
    "    \n",
    "    return H, mu, alpha, bound\n",
    "    \n",
    "def myplot(x, y, label=None, xlabel=None, ylabel=None, title=None, savefile=None):\n",
    "    plt.plot(x, y, lw=3, label=label, color='C2')\n",
    "    plt.xlabel(xlabel)\n",
    "    plt.ylabel(ylabel)\n",
    "    plt.title(title)\n",
    "    if savefile is not None:\n",
    "        plt.savefig(savefile)\n",
    "    \n",
    "def plot_alignment_factors(stats, title=None, savefile=None, loc=0):\n",
    "    iters = stat['iter']\n",
    "\n",
    "    plot = plt.plot\n",
    "    plot(iters, stat['mu'], label=r'$\\mu(\\theta_t)$')\n",
    "    plot(iters, stat['alpha'], label=r'$\\alpha(\\theta_t)$')\n",
    "    plot(iters, stat['beta'], label=r'$\\beta(\\theta_t)$')\n",
    "    plt.xlabel(r'\\# of steps')\n",
    "    # plt.ylim(bottom=0.5)\n",
    "    plt.legend(loc=loc)\n",
    "\n",
    "    if title is not None:\n",
    "        plt.title(title)\n",
    "    if savefile is not None:\n",
    "        plt.savefig(savefile)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "48f6b7b2-c762-44d0-aecf-16c360ca14c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# random feature models\n",
    "d = 10\n",
    "batch_size = 5\n",
    "lr = 0.0005\n",
    "train_loader, test_loader = data.gen_rfm_data(n=200, d=10, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "6b3c97ad-67d7-40dd-833b-59c6e2311c0e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.01\n",
      "train error 2.2108060875325464e-05\n",
      "Hessian:7.5e+01 alpha: 9.7e-01 mu: 9.83e-01 bound: 2.3e+02\n",
      "(0, 0)-> took 1.9 seconds, hessian: 7.5e+01\n",
      "0.01\n",
      "train error 4.1905022953869775e-05\n",
      "Hessian:7.5e+01 alpha: 9.6e-01 mu: 1.02e+00 bound: 2.2e+02\n",
      "(0, 1)-> took 1.9 seconds, hessian: 7.5e+01\n",
      "0.01\n",
      "train error 2.0334819055278786e-05\n",
      "Hessian:7.8e+01 alpha: 9.7e-01 mu: 1.02e+00 bound: 2.2e+02\n",
      "(0, 2)-> took 1.9 seconds, hessian: 7.8e+01\n",
      "0.01\n",
      "train error 2.44935854425421e-05\n",
      "Hessian:7.5e+01 alpha: 9.7e-01 mu: 9.94e-01 bound: 2.2e+02\n",
      "(0, 3)-> took 1.9 seconds, hessian: 7.5e+01\n",
      "0.005\n",
      "train error 2.686300604182179e-06\n",
      "Hessian:1.5e+02 alpha: 9.6e-01 mu: 9.84e-01 bound: 4.5e+02\n",
      "(1, 0)-> took 2.0 seconds, hessian: 1.5e+02\n",
      "0.005\n",
      "train error 1.9194892502127914e-06\n",
      "Hessian:1.5e+02 alpha: 9.4e-01 mu: 1.02e+00 bound: 4.4e+02\n",
      "(1, 1)-> took 2.0 seconds, hessian: 1.5e+02\n",
      "0.005\n",
      "train error 8.000350817383151e-07\n",
      "Hessian:1.5e+02 alpha: 9.7e-01 mu: 9.99e-01 bound: 4.5e+02\n",
      "(1, 2)-> took 2.0 seconds, hessian: 1.5e+02\n",
      "0.005\n",
      "train error 6.805821612942964e-06\n",
      "Hessian:1.5e+02 alpha: 9.8e-01 mu: 1.01e+00 bound: 4.4e+02\n",
      "(1, 3)-> took 2.1 seconds, hessian: 1.5e+02\n",
      "0.003\n",
      "train error 1.5370437722594943e-07\n",
      "Hessian:3.0e+02 alpha: 9.7e-01 mu: 1.02e+00 bound: 7.4e+02\n",
      "(2, 0)-> took 2.3 seconds, hessian: 3.0e+02\n",
      "0.003\n",
      "train error 1.4463624609106773e-07\n",
      "Hessian:3.1e+02 alpha: 9.2e-01 mu: 9.70e-01 bound: 7.6e+02\n",
      "(2, 1)-> took 2.3 seconds, hessian: 3.1e+02\n",
      "0.003\n",
      "train error 9.602536010788754e-08\n",
      "Hessian:3.0e+02 alpha: 9.6e-01 mu: 1.00e+00 bound: 7.4e+02\n",
      "(2, 2)-> took 2.3 seconds, hessian: 3.0e+02\n",
      "0.003\n",
      "train error 3.0721952271051123e-07\n",
      "Hessian:3.1e+02 alpha: 8.9e-01 mu: 1.03e+00 bound: 7.3e+02\n",
      "(2, 3)-> took 2.3 seconds, hessian: 3.1e+02\n",
      "0.001\n",
      "train error 1.8709168614350347e-07\n",
      "Hessian:6.0e+02 alpha: 9.4e-01 mu: 9.96e-01 bound: 2.2e+03\n",
      "(3, 0)-> took 2.8 seconds, hessian: 6.0e+02\n",
      "0.001\n",
      "train error 1.3035644030878757e-07\n",
      "Hessian:6.0e+02 alpha: 9.3e-01 mu: 1.01e+00 bound: 2.2e+03\n",
      "(3, 1)-> took 2.8 seconds, hessian: 6.0e+02\n",
      "0.001\n",
      "train error 8.928589778633977e-08\n",
      "Hessian:6.0e+02 alpha: 9.2e-01 mu: 9.99e-01 bound: 2.2e+03\n",
      "(3, 2)-> took 2.9 seconds, hessian: 6.0e+02\n",
      "0.001\n",
      "train error 9.528863387231468e-08\n",
      "Hessian:6.0e+02 alpha: 9.2e-01 mu: 1.02e+00 bound: 2.2e+03\n",
      "(3, 3)-> took 2.8 seconds, hessian: 6.0e+02\n",
      "0.0007\n",
      "train error 5.008696124519929e-09\n",
      "Hessian:1.2e+03 alpha: 9.5e-01 mu: 9.95e-01 bound: 3.2e+03\n",
      "(4, 0)-> took 3.9 seconds, hessian: 1.2e+03\n",
      "0.0007\n",
      "train error 1.5554848076249073e-08\n",
      "Hessian:1.2e+03 alpha: 9.0e-01 mu: 9.90e-01 bound: 3.2e+03\n",
      "(4, 1)-> took 3.9 seconds, hessian: 1.2e+03\n",
      "0.0007\n",
      "train error 1.5227845651111238e-08\n",
      "Hessian:1.2e+03 alpha: 9.6e-01 mu: 1.01e+00 bound: 3.2e+03\n",
      "(4, 2)-> took 3.9 seconds, hessian: 1.2e+03\n",
      "0.0007\n",
      "train error 5.0089288272658905e-09\n",
      "Hessian:1.2e+03 alpha: 9.6e-01 mu: 1.01e+00 bound: 3.2e+03\n",
      "(4, 3)-> took 3.9 seconds, hessian: 1.2e+03\n"
     ]
    }
   ],
   "source": [
    "widths = [400, 800, 1600, 3200, 6400]\n",
    "learning_rates = [0.01, 0.005, 0.003, 0.001, 0.0007]\n",
    "ntries = 4\n",
    "\n",
    "model_size = torch.zeros(len(widths))\n",
    "hessian_fro = torch.zeros(len(widths), ntries)\n",
    "mu = torch.zeros(len(widths), ntries)\n",
    "bounds = torch.zeros(len(widths), ntries)\n",
    "alphas = torch.zeros(len(widths), ntries)\n",
    "\n",
    "for i, m in enumerate(widths):\n",
    "    lr = learning_rates[i]\n",
    "    for j in range(ntries):\n",
    "        time_st = time.time()\n",
    "        \n",
    "        net = models.build_rfm(d, m)\n",
    "        train_net(net, train_loader, nsteps=20000, lr=lr, display=40000)\n",
    "        H_i, mu_i, alpha_i, bound_i = compute_geoinfo(net, train_loader, batch_size, lr)\n",
    "        hessian_fro[i,j], mu[i,j], alphas[i,j], bounds[i,j] = H_i, mu_i, alpha_i, bound_i\n",
    "        \n",
    "        print('({:}, {:})-> took {:.1f} seconds, hessian: {:.1e}'.format(i,j, time.time()-time_st, H_i))\n",
    "    model_size[i] = utils.num_para(net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "cacbc6a4-052e-4258-b4c1-056e76507040",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWQAAAEvCAYAAAByhLuPAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAjGklEQVR4nO3dXXCc133f8e9/AYIUSZFLUFTq2JaopR07yYwTgVAS140dW4B74ZukAaUbT6aeqYAm06ukAaWLXjQXlcF0cpHpJAWUmV40nY5IOr7oeMYx4NfWdlyBkOPa4yoxVpSlxoopgktJhEgQ2H8vzlng4YN9BRbAAfD7zFDA87pnd4Xfnj3Pec4xd0dERHZeYacLICIigQJZRCQRCmQRkUQokEVEEqFAFhFJhAJZRCQRvTtdAJGtZGbjwDAwFFddBhaA/szPaXefanH8YFw1E4+r6QdK8R/ufqLO41Zq69so7yVgBJgDyu5+rq0nKnuCqR+y7AdmdgOYdffhOtsmCcE57O7lTo+P20eBSeBM9hzx3KPAOXe/3KKMReA5QiCfaVQW2bvUZCH7xQJQqbfB3ceAMjDd4viGYg17CijmNs0TatVjbZRxKFOGpo8ne5MCWSSYBkpmNrKJc0wSmy7qrB+KNeBmas0osk8pkEWCWvNAvUBti7vPEUI1v77WVPFEo2PNrATMbvSxZW9QIIsEj8WfM50cZGZDuVWN2n0v07zZYigGuuxjCmTZ92JTwihwYQOheM9FPndvFOiTwECsCYvUpUCWfcvMirF3xCXgKXc/3+nxwEA7+8agrhB6UOTPM0CHNXPZm9QPWfaTUgzgmrOENuPzbdaMB2M3NghtxUN01u57kdBscSF/3kb9oGV/USDLflLOB1+s5V4xs0l3zwdl3mzsIlc77gmgkxs3JoFRMxvIfQCoZ4UAarKQfc7dK4Ra60Qn7bvuXonh3nabcwzhMvBkbV28KKiLeQIokEVgrWdEvsdEO+Y73P8y4QJiTUl35EmNAllkrcngbKcHbqDtdxIoZrrLqblCVqkNWWTNlndJc/eymZWBc2YG6l0hGaohy74X25EhE8ixS1w3AvpknXWThAuCpcxjiyiQZe+LPSJKhG5vxQa7TeW2D7F+mM11t0W38bj1xsaoDUJUr7mio8eQvUXDb8qeFsclPpNbPV1vKEwzmyAE9/MQxqCIxz/GWrBeBhZq3d9aPO5YPN8c8Gz2Mc3sUnas41gbP0/4IKgdM9vqcWRvUSCLiCRCTRYiIolQIIuIJEKBLCKSCAWyiEgidGNIAw888ICfPn16p4shInvMlStX3nD3U/W2KZAbOH36NLOzmlFHRLrLzF5ptE1NFiIiiVAgi4gkQoEsIpIIBbKISCIUyCIiiUiul0VmSvaT7cwCHAdxKRNHycoOGJ6Z9wziADOdziwsIrJdkgrkOItCkfWjczXafwJ4oTaKlplNmNlIZlStCcKMwpW4/YqZjbcxmaWIyLZLqsnC3WdimFbaPGQ0N4zi84QhD2sGuXeetDJhKEURkeQkVUPuhJkN1FldIRPA7p6fI22AUGsW6YrTT38BgKuf/dQOl0T2gqRqyB3qZ/2MCw0njIxtzTMbmJRSRGRb7OZALjbakJ2mJ86NVpt2vemU7WY2amazZjZ77dq1rhRSRKRduzmQK6yff2zdfGTuXnH3qXghb9jMLjU6Ydxv0N0HT52qO/aHiMiW2c2BvMD6WnIRQgjHmvF4bvs09SedFBHZcbs2kN19jvW9MfqBmfj7IDDRZJZhEZGk7KpANrOSmWVruBdzy8PAJIQudGT6IGe2qw+yiCQpqW5vsSvbELFZIdMzYi7uMkII1csA7j5mZuPxhpISMJ/rl3w502xxkjD9uwJZRJKUVCDH4J2jQS02humFOusana/c6FwiIqnZVU0WIiJ7mQJZRCQRCmQRkUQokEVEEqFAFhFJhAJZRCQRCmQRkUQokEVEEqFAFhFJ
hAJZRCQRCmQRkUQokEVEEqFAFhFJhAJZRCQRCmQRkUQokEVEEqFAFhFJhAJZRCQRCmQRkUQokEVEEqFAFhFJhAK5C04//QVOP/2FnS6GiOxyCmQRkUQokEVEEqFAFhFJhAJZRCQRCmQRkUT07nQB8sysCIwCJ939fBv7jwNloB/A3afqnAvgMWA6u11EJCVJBbKZDQFF4Eyb+08AL7j75dqymY3UloFnsqFuZvNmhkJZRFKUVJOFu8/EMK20echoJnwBngfGYLV2XMrtPwm0rHWLiOyEpAK5E2Y2UGd1BRjKLA+ZWSm3PR/SIiJJSKrJokP9wEJu3eqyu1eAE7ntw8DM1hZLRGRjdnMgFxttMLNiDOR71hFqz483OW6UeBHwoYce6kYZ94XabeNXP/upHS6JyO62a5ssCM0P/bl1+eWs54Bz7j7XaAd3n3L3QXcfPHXqVBeKKCLSvt0cyAusryUXYbW5YlXsGjfp7mquEJFk7dpAjjXdSm51P7k2YjMbAeZqYRy71omIJGdXBbKZlWLA1lzMLQ8TurbV9h8ihPSsmRVjj4t6vTNERHZcUhf1Yle2IWAkLo8DM5l23xFC6F4GcPcxMxuPwVsC5jM3iRSB6XjcakjXjhURSU1SgRyDdw640GD7hfy2uK7evhXAulxEEZEts6uaLERE9jIFsohIIhTIIiKJUCCLiCRCgSwikggFsohIIhTIIiKJUCCLiCRCgSwikggFsohIIhTIIiKJUCCLiCRCgSwikggFsohIIhTIIrJhp5/+wuokt7J5CmQRkUQokEVEEqFAFhFJhAJZRCQRCmQRkUQokEVEEqFAFhFJhAJZRGQDtqIPtgJZRCQRCmQRkUQokEVEEtG70wXIM7MiMAqcdPfzbew/DpSBfgB3n9rM+UREdkpSgWxmQ0ARONPm/hPAC+5+ubZsZiOZ5Y7OJyKyk5JqsnD3mRimlTYPGa2Fb/Q8MLaJ84mI7JikArkTZjZQZ3UFGNrmooiIdMWuDWRCm/FCbl1+WURk19jNgVxstCFeyOuYmY2a2ayZzV67dm2j5RIR2ZBNB7KZHTOzXzazT8Sfp7tQrnZUiD0rMvLLHXH3KXcfdPfBU6dObeZUIiId67iXRQzcEeCThPbaG/FfJe5SNLN+4DgwB8wAk+5+dfPFvccC62vJRQB3ryAissu0Hchm9igwATghZM+7+4ttHDMITJnZCeApd//uxou7xt3nzKySW90fyyYisuu0Fchm9p8JteBz7n6z3ZPHwH4ReM7MjgPPmNmYu//uRgprZiVgINPV7WK23zEwDExu5NwiIjutaSDHEJ0AJtz95c08UAzyp83suJn9OaGG/Wbu8QYIzSAjcXkcmHH3ubjLCCF0L8dzjpnZeLwBpATMZ/slt3E+EZFktKohP+Hu/7qbDxiD+XfN7Cngudy2OUK784UGx17Ib4vrGj1W0/OJiKSkaS8Ld3+u2fbN2Mpzi4jsRru5H7KIyJ7SSS+LTwADwMm4ah4ou/tXtqJgIiL7TctAjl3XvhwXFwhdy4px2c0MYJrQ1/jzW1BGEZF9oZ0mi1HgEXfvd/f3xZ8FQm+HZwAj3CTyOTO7bmb/YQvLKyKyZ7UTyOV6fY/d/cuxh8NMDOhPErqjPR2D+eNdLquIyJ7WTiCfNLNfarLdYXXs4bEYzhOEGvMfdKOQIiL7QctAdvenCeHadlNErDmXgF+NFwNFRKSFdru9DQKfrLURm9mxzDard4C7V9z9CUJbs4iItNBWIMdwHSQ0RTwN3DCz/21m/5bYZCGyn9y6s8yfTL+0ujzwR1/iT6Zf4tad5R0slex2HQ2/Wbt12cxGgXHiLclmtkK4RXmW0D+5Quga9xhhRmiRPePWnWV+68++ySvXF1fXLSzeZfLrZb74/df5/O99hCMHk5o/uOtu3Vlm8hvzq8sDf/QlPv3hhxn76Jk9/9y30oZeOXefIgypOQA8QRjA5wxwNu5SIYTzpLt/rgvlFEnG5Dfm
eeX6IneWq/esv7Nc5ZXri0x+Y57fH/7ADpVu6+kDaets6lXLDN6zL+33WsJuff7Ly1WuLy5x/e07XL+1xPW3l7ixuERlcYmb7yzz1u27vHV7mVt3lnl7aZl3lla4fXeF23erLK1UqSzebXjuO8tV/vTLP+I/feVHmBkGmIFh4adBwcLvhbiuUAj79ZiF323t94JBT8EomNFTsLDejJ6eteWegtFbCPv3xn89BaO3p0BvoUBPAXp7ChwoGD09BQ70GAcKBXoKxoFe40Dc1ttToK+3QG/B6OvpCfv1FDjQW6Av/t7XW+D5F17l5TducXfF1z33/fCBtJVaDb95LD9EZrds5bm3w36vJWzX869Wq7x1e5lrb2XC850lKot3qSwu8ebtEKBv31lm8c4Kt5aWV8PzznKVu8tV7q5UWa46K+74Nl3xqDrc+2D741LLneUqf/E/X+bc4Ht574nDO12cXafVX8yTZvZCt2b5qDGzXyb03PiLbp53OzX72lq+dovfv/hdPvWhd1Eg1GZqtaGegmEWaj7GWu2nYEahUKs9Gb0F4n5h30KtxlQ7phDOWegJNa1CwegpQIEChULt3OH3AtBb+73QnfGkmj3/q2/c4t//jx/wz3/xn3D91hI3F5eovLPMzcUl3roTAvTWnRUWl5ZZXKpyZznUQJeWq9xdcZarVVaqHkJtC9Vqor2Z2uHB3gKHDhQ4dKCHw309HDnYy9GDvdx/6ADHDvVy/L4D/NnX5llcWml43mOHevkvn/kVlldCjXppucryinNnZYW7y+HD4e5ylaUVZ6Uaft5dCfvUPjyWq87ySvgwWcksr/5edVbiv+WqU11dV6VahRX3+BrGn1WnCvf8Xo3b3aHq4fX22k/Ceq/ze6v3ZXFphV+f+CoFg+P3HeCh/sP8wruO8WtnTvIbP3eK44f7uvo+7iVNA9ndnzOzPzSzIXf/j914QDP7w3Dq7pxvp/zlt19ZF0Y1y1Xnr3/wj/z1D/5xm0u1dfJ9G5v9TS6tOBdnX+Pi7GtdffzaV/XegtHXW6Cvp8Ch3gKH+nq4r6+HI30hPI8e7OXYfQc4fl8vx+/ro/9IHycOH+DU/Qc5efQgJw/30du78Q+mpZUqk18v133/D/YW+JcfOc3Zh09s4tmmbeCPvsRCk2abmqrDjcW73Fi8yd++dpP//sKrABzoMU4e6eORB47yofcc45+9/xS/+shJ+jbxnuwVLb9Tuvsfm9mjZvYl4EvAVKdNDbHf8hjh4t/5bte4d8KNNv6HPBj/B/P4H8/EWPbbrN/zi68LO1/3y/Z/Ad7I4x3oMXoLoc0y1D57OHSgwH0Hejh8sJejB3s4evAA98ea5/H7DnDiSB/9hw9w4shBHjx6kBNH+jh6KK2mn7GPnuGL33993TeEg70FHj55mLGPntnB0m29T3/44aYfSGMfK/FvPv5+vvPydf7X31/je6+9yctvvM31W0vcXXHurjivv3mH19+8w7fL15n8RpiM6L4DPTx47CDvO3WUgYdP8LGfO8UvvOv+rn2r2w3MO2hUM7PfJgTrI4TJROcJ3doqhJHgYG00uBKh58Vw3G/S3f+qS+XecoODgz47O9twe6taQv+RPub+XXr3xFTjV9rlajV+bQ0fAeGrLFSpslKNX2dxqiuEn7mvvE9Mfps3bzfuc5vq8++W2gXNP/3yj4DwfD/9aw8lf0GzG7LXD+p9IDW7fnBzcYmv/d01vlO+zg/+4U1evfEOlcWlhs0gBhw91MvPHr+PD77rfh57+AQf//kHeXdx59unTz/9BQCufvZTHR1nZlfifR3rt3USyJkTHifUdh8jBG8x/oS1cJ4DXiAMPtT2xKipaBXIfzL9Ustawl6+0rzfn3/NRv8od7tufyC9en2Rr/7dT3nh6gIvvf4WP6nc5q0mN9n0mHH8cGif/sWfPcaHSyf52AdOcf+hAxt+Tp3aikDeaD/km8Dn4r99ab9/bd3vz3+/O3Kwl98f/sBqIG/2
29B7Tx7mdz58mt/58OnVddVqle//w5t87aVrfPfVG/zop2/z07fucPtulRV3Fm4tsXBrie++WuG/fefHQGgme+DoQR554Ai/9J4iv/7+B/iV0/2bumawnfb2d6stdORgL5//vY/s26+t+/35y9YrFAp86D1FPvSe4j3rl5arfGv+Db75ozf43ms3ufrGLRYW19qnf3LzNj+5eZtvzV/nz78e+snfd6CHnzl2kPc/eJSzD/fz0Q88wAd/Jr326S35q4kjvM3u5n7G7eh2LWG32e/PX3ZGX2+B3/jAg/zGBx68Z31lcYmvvXSNb5ev88OfvMmrC4vcfOcuVYd37q5w9foiV68vMv3Dn/LZL4b26fsP9fLu4n188F3HeOx0P4///IP8zLFDTR9/K2+I2vDRLW7suAGMmVk/8OxeD2YR2XnFw3385qPv5jcfffc961+5fouv/N+fcuWVG6F9+uY7vH1nBQfevL3Mm6+/xQ9ff4vPv/j/4POhffrE4QM8dPIwv/izx/mn7zvJr7/vFEcP9W75DVGbifOr8eLeHGFOvRlirdjdXwReBDCzZwlTPYmIbLuHTx7hMx95hM985JHVddVqlb997Sbf+PtrvPjjCvM/fZtrb93h9nJon37j1hJv3Fpi7scV/uvfvAJAX7x56NbS8ro7Prt12/hmAnmQ0AXuUcKQnE8TJj0tE8OZUFMuNTyDiMgOKBQKPPrQCR596N4beG4vLfOt8gLf/NEb/J/XKly9vsjCrSWWqx7uulypfzMYhFD+y7/58c4EsruXgfO15Tg79TChO9yThLB24NyGSyciso0O9fXyiQ8+yCc+eG/79MLbS3z1pX/kDy59r+nxNxaXNvX4XbvE6O4vuvsFd/+ku/cThuX84910M4iISD39R/v47bPvpf9w837OJzY5TseW9U1y98tmdsXM/pW7tz2IkJkVgVHgpLufb7E7ZjZOuFuwPz7uVCfbRUTa1eq28U//2kObOv+WdsJz95eBtkdZMbMh1ga7L7ax/wRQdvfLMWjPmNlIu9tFRDox9tEzPHzy8Oo4NTXduiFqw4FsZn9uZs+a2b/ITXqa1/a92e4+4+6XCbdft2M07l/zPKHtut3tIiJtq90QNfaxtb4K/Uf6GPtYqStjgG/m6DdZu6iX7V0xTegKt0Co7Q4DXR9qM04flVeJj9lyu4jIRmzlDVEbriG7+3l3LxDm0XsGeJnQu+IyYXS3G8AloNyiBr1R/ayNMFez0MF2EZGkbLoNuU7vijPA7wJ/BdwkNBHcMLO/j80cv9WlgC422hAvDLbaXm/9qJnNmtnstWvXNlk8EZHOdP2inru/7O5T7n4uF9AvEmrQnwOudOGhKsSeExn9HWxfJ5Z70N0HT506tekCioh0YsuH5Io9LabiP8zsEdroQdGGhTrnKcbHrJhZ0+1deHwRka7a9jESY0B34zxzZlbJre4nXFhsuV1EJDVpDQbagpmVcv2IL+aWh4HJDraLiCQjqUA2s4F4Z90IMGRm47nuayNk+hG7+xhQMrMhMxsF5rP9jlttFxFJSVLTOrj7HKEP84UG2y/kt8V1zc7ZdLuISCqSqiGLiOxnCmQRkUQokEVEEqFAFhFJhAJZRCQRCmQRkUQokEVEEqFAFhFJhAJZRCQRCmQRkUQokEVEEqFAFhFJhAJZRCQRSY32JiK7y9XPfmqni7CnqIYsIpIIBbKISCIUyCIiiVAgi4gkQoEsIpIIBbKISCIUyCIiiVAgi4gkQoEsIpIIBbKISCIUyCIiiVAgi4gkIsnBhcxsHCgD/QDuPtVi/wlgHjgDzOf3b7VdRCQFyQVyDM8X3P1ybdnMRmrLdfafBibcfSYuT5rZUGa56XYRkY3YipHuUmyyGM2F7/PAWL0dzawE5MN1GjjfznYRkZQkFchmNlBndQUYanBIvf3Lmf1bbRcRSUZSgUxoM17IrcsvZ80BmFkxd47aulbbRUSSkVogFxttqBeg7l4mhG4ps3qg
3e11HmPUzGbNbPbatWvtl1pEpAtSu6hXIdZgM/LLeY8Dz5jZIKE2XQZw90qb21fF3hdTAIODg76RJyAislGpBfIC62vJRagfoJn1qxfpzGyI2FTRznYRkVQk1WTh7nOEWnJWP9Cwi1rsSZF1Dphsd7uISCqSCuToopmNZJaHyQVsbvuVWu+M2M48mLvxo9V2EZEkpNZkgbuPmdl4bFooEe6sy/ZLHiGEdG3dU0ApthGfcfezuVO22i4ikoTkAhnA3S+02HYhs1z3Dr52t4uIpCLJQJbdZStuIRXZj1JsQxYR2ZcUyCIiiVAgi4gkQoEsIpIIBbKISCIUyCIiiVAgi4gkQoEsIpIIBbKISCIUyCIiidCt012gW4dFpBtUQxYRSYQCWUQkEQpkEZFEKJBFRBKhQBYRSYQCWUQkEQpkEZFEKJBFRBKhQBYRSYQCWUQkEQpkEZFEKJBFRBKhQBYRSYQCWUQkEQpkEZFEJDkespmNA2WgH8Ddp1rsPwHMA2eA+fz+8XyVuFh09wvdLrOIyGYlV0OO4Vp298sxWM+Y2UiT/aeBaXefcvfzwFkzG8psH3f3C3H7FDATA1pEJCnJBTIw6u6XM8vPA2P1djSzEjDk7jOZ1dPA+czyk9lj3H0OeKxLZRUR6ZqkAtnMBuqsrgBDddYD1Nu/nNt/wcwuZR5jlBDyIiJJSSqQCW3GC7l1+eWsOQAzK+bOkV03BgyZ2Y3YVLGQq4GLiCQhtUAuNtqQC10A3L1MCOVSZvVAnX2eBWaBCZo0V5jZqJnNmtnstWvXOiq4iMhmpRbIFWINNyO/nPc48GQM0xFCkwXuXgEws0lgxt2HgWFgNNuEkRUv/A26++CpU6c2/ixERDYgtW5vC6yvJRdhLWDz4vrVi3ixh0WtKWMAqMQLebj7jJk9Arzc3WKLiGxeUoHs7nNmVsmt7gdm6uwOhJ4WsVmi5hwwmTn2eu4xKmbW8Hwinbj62U/tdBFkD0mtyQLgYq7f8TBrAYuZlXLbr9R6Z8R25sHajSGxO9xw9uRxn2yAi4gkIakaMoC7j5nZeGx6KBHuvMv2ihghhGxt3VNAycwGgTPufjZ3yrHMnXy1xziPiEhizN13ugxJGhwc9NnZ2Z0uhojsMWZ2xd0H621LsclCRGRfUiCLiCRCgSwikggFsohIIhTIIiKJUCCLiCRC3d4aMLNrwCs7XY597AHgjZ0uhLSk96lzD7t73cFyFMiSJDObbdRXU9Kh96m71GQhIpIIBbKISCIUyJKqpjONSzL0PnWR2pBFRBKhGrKISCKSG35T9r44JvUTcfEMtB4SNU5QWyZO6VUb81q2TqfvU5zR/SxQmyLtHDCRm0BCmlCThWy7OM/h+cy8h1eA5939QoP9J4AXauNi55dla2zgfRolTCRcJEyj9lRt+jRpj5osZCcMAkOZ5TJNZgMHRnPh+zwwthUFk3t0+j7h7ifc3dz9rMK4c2qykG1XZ1aXAULNap3a9Fw5Fe4NCtkCnbxP0h0KZNlRsW14pkmbcD9hNvKs/LJssTbep9p+o4T3R239G6BAlh2Ru2A032TXYrNz1No3ZWt08D4BzAKV2kU8M7tkZgtq62+fLurJjjOzacIf8rk624aAS+5+IrOuRAiHEwrk7dPsfWqw/zjwZJ2mD2lAF/VkW5lZMf6hZk0TZhOvZ4H1teQigMJ462zgfap9eGaVCe3O0iYFsmy3QWAifhVuKV6pr+RW9wMz3S2W5HT0PsVvLdN19lcf5A4okGVbufsMmb6t0TCw2rfVzEpmlq2JXcwtDwOTW1rQfa7T9ym2G+f3fxL1yuiI2pBl28XaVC1gTwLXszcbxK/Kw+4+nFs3B5RAV++3Q6fvU5395/U+dUaBLCKSCDVZiIgkQoEsIpIIBbKISCIUyCIiiVAgi4gkQoEsIpIIBbLsKWY22u7dZbL99P40p0CW2h1X02Z2I84SsSvFGxVmUx7jIr7O83H2jU6P3fXvU7xRRHfvNaBA
Fty9HO+2miWOY7vb1AayT32Wivg6b2g4yq18n+JgQvNxPOOtNrFbP1C2mgJZsnbzQDATjeZ6S1CrcYVb2c3vU23ci0qD2WD2NQ1QL7teHOBmeqfLsZvFZp4z2/iQzxJmpx5uteN+ohqy7AVjgAax2UUyM1mXdrgoSVEgy64Wr9iXUr6QJw01HfB+P1KThbQlXuwpEgaLLxKm8mlYK43tg9n51UqE2SPmauu6ZIg6g9XHx7tEGK7zInCetbnhzgK4+1jct3Yhq0iY5v6pegHf6WsQjxmJZagd0/Sio5lNsNbGfBaY3OyFyvhaDLE2+Wgxlmd1vrs4PdMgcLH2usT1Hss8w71t3xP5fTdQ/pl4nt3S9r/13F3/9A93hzDo+6U66y8BQ7l1Iw32HYn7D2R+H4/LA/WO2WSZJ4DxJtun4/Mar7N+ghBUxcz6ceDKZl6D3OuZf9yB+Nj1HuMKMJBZLhLCbaDOedt6HeM5Jhs8n5E6r8ml3LHTdY4dB25kX7dOyp87xrf7//OU/6nJQpqqzavmYQaJVR5qVqVsN6lYK75EqGHOxX2eB8bi8py3OUFmB0o073VQJtSM8zXZaWAU1s3NNwcMZG9e6OQ1yBwzQgjwC7lj5qhTS441y7JnapOxXJPAc02eXytDrJ8CC8JFtbz869hPeD+z5SwRPsju+RaxifJXdKPIGgWytDJBCNV6Jrm3k/+ThCaJSmbdDCG08hNgdkuR+oGTVfb1TRBlQg0v39xRC6VsP99OXoPsMY36G1+vs26c+j1FZtjcRKFlYDTfxSwGZztNR7O55UvA5fhhlLXR8i8QZ4ERtSFLE5kr4I3+cMtA0cyKdQJvu5QIf9TNNCp/pc66e861idegRJv9jTNheSY3dyDEDwYzK/kG2t7dfc7MZoArZlYmfEhMu/uMt2ibzj9e/CZQAh7vYvkr7NKbkbaCAlmaqf2hVVrsVyJ8DV9tBsgYJFz8qnfhrUhoTjjnmfnzMtvHM49dzH/93yadvgbZEG/1QZE9FuD5BiG5qS597n4uBuWThPdnPIb0uXY/SONzmmxwzGbL3+7rtOepyUKaqdVoii32q8BqG+uMmU3GW3EHCL0bHs8fELc9Ec+9roaUabed8tCTYa7B7baVesd3UUevAdxTs2y3XFt2u3ftw8HdL7v7OXc/QbgBpJ/OxpRY11SRaYbaTPmLtP6w2zcUyNJQprYz2GCXel3bJgkhPAT0u/twvVpTvMA3ReOmgGfI1Kxi2NcbZ2GB1mG5YZ2+Bhll2rzzLXNs3Xb2TV70GqrTflwGzjV6vDqPX2uqeCq3qTZ+yGbK37+Rppi9SoEsrVwg/PHW8yQhfGsGCE0LlVgjW9dM0Y4Y7HXbpetcHJxj6y8KdfIa1NS61NXzGOs/RM7Hc9Wz2QF/1p03hmCl1YGZD9l6fbOzHzgbLX+xVRn2EwWyNOXu52G1G9eq2KSw4PfeGDEHPNOF22EbHV9h/R/wC7SuieaPgcbNCevWd/ga1I6ZAhbyXeLiB0ox/zixfXyh1lST2X+AzQ8mNJp/T2I56jUBFXPLdXtVxONXn8NGyh+3JT0633YzD52zZR+LXymfY61GNwOcz36VzPyhVWhyl1ps583XiMqZc1Zy+48Az7j72cy6IUJPAMvtO08Y1W0qs65IuMniTG7fEvfenbf6nGIZhwjBf5lwMepyfI5PEu8ojOsvZM7Z1muQK0ftmDIhwMrx/BPxMc5nv0nE/U8SemgskLkg2s77VOfxR+Jj1gI5G6JTTc77VFy+RPiGUOsxUiTU8EeAKV9/p17D8tcp2yhwpvaBJwpk6aJ4c8ALuQs/RUIYjAFPxItK2WM6CeQbhACayq2fpoMeA5KG+L6NqQ15jZospCtisBbzX21je/JcrEnN1umnWk+jP9Big231auWSsPhBrQt6OQpk6ZZ+2rhjrp0T+doA5uvakut9/Y0fAhpXd3d5hvW9NvY9BbJ0RWxGGGh0i3RcX6pzy22ji2vP
kumlEGvWzdprz+cvKEmaYu242OpOwf1IbcjSVbkhKmuKhPEksm3LJcKFodpFtAusb38eZ61W/Viriz9x/5a3BMvOMrPJ/MVACRTIsqfED4SLusCXJr0/zSmQRUQSoTZkEZFEKJBFRBKhQBYRSYQCWUQkEQpkEZFEKJBFRBLx/wFReD0J5IYryQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 360x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "x = model_size.log10()\n",
    "y_mean = mu.mean(dim=1)\n",
    "y_std = mu.std(dim=1)\n",
    "plt.plot(x,y_mean,'-o', color='C0', markersize=8)\n",
    "plt.errorbar(x,y_mean, yerr=y_std, label=r'$\\mu(\\theta)$')\n",
    "\n",
    "\n",
    "# y_mean = alphas.mean(dim=1)\n",
    "# y_std = alphas.std(dim=1)\n",
    "# plt.plot(x,y_mean,'-o', color='C1', markersize=8)\n",
    "# plt.errorbar(x,y_mean, yerr=y_std, label=r'$\\alpha(\\theta)$')\n",
    "plt.xlabel(r'$\\log_{10}$(model size)')\n",
    "plt.ylabel(r'$\\mu(\\theta)$')\n",
    "# plt.legend()\n",
    "plt.title('RFM')\n",
    "plt.savefig('../figs/alignment_rfm_modelsize.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "25ea5d60-3e1a-466c-96c4-bac83a736336",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([2.6021, 2.9031, 3.2041, 3.5052, 3.8062])"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "ac14960e-414e-4e06-bd94-193d6bba4b0a",
   "metadata": {},
   "source": [
    "# Linear network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "ee72111e-8158-464d-bac5-5b45a093828c",
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 50\n",
    "batch_size = 5\n",
    "d = 100\n",
    "\n",
    "train_loader, test_loader = data.gen_linear_net_data(d=d, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "8b983824-9aa9-4f11-9229-fb144e28fedd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.3\n",
      "train error 5.01096856540882e-16\n",
      "Hessian:2.3e+00 alpha: 8.7e-01 mu: 1.05e+00 bound: 7.3e+00\n",
      "(0, 0)-> took 1.1 seconds, hessian: 2.3e+00\n",
      "0.3\n",
      "train error 4.73792325969666e-16\n",
      "Hessian:2.4e+00 alpha: 9.4e-01 mu: 1.03e+00 bound: 7.3e+00\n",
      "(0, 1)-> took 1.1 seconds, hessian: 2.4e+00\n",
      "0.3\n",
      "train error 1.057110458942286e-15\n",
      "Hessian:2.0e+00 alpha: 9.4e-01 mu: 9.75e-01 bound: 7.5e+00\n",
      "(0, 2)-> took 1.1 seconds, hessian: 2.0e+00\n",
      "0.3\n",
      "train error 6.098640396604515e-16\n",
      "Hessian:2.2e+00 alpha: 8.9e-01 mu: 1.09e+00 bound: 7.1e+00\n",
      "(0, 3)-> took 1.1 seconds, hessian: 2.2e+00\n",
      "0.3\n",
      "train error 3.3285467269470056e-16\n",
      "Hessian:2.2e+00 alpha: 8.8e-01 mu: 9.72e-01 bound: 7.6e+00\n",
      "(0, 4)-> took 1.1 seconds, hessian: 2.2e+00\n",
      "0.3\n",
      "train error 2.6420299642687434e-16\n",
      "Hessian:1.9e+00 alpha: 8.9e-01 mu: 1.04e+00 bound: 7.3e+00\n",
      "(1, 0)-> took 1.2 seconds, hessian: 1.9e+00\n",
      "0.3\n",
      "train error 3.188814878012016e-16\n",
      "Hessian:2.5e+00 alpha: 9.0e-01 mu: 9.91e-01 bound: 7.5e+00\n",
      "(1, 1)-> took 1.2 seconds, hessian: 2.5e+00\n",
      "0.3\n",
      "train error 5.525834366023794e-16\n",
      "Hessian:2.0e+00 alpha: 8.8e-01 mu: 9.58e-01 bound: 7.6e+00\n",
      "(1, 2)-> took 1.2 seconds, hessian: 2.0e+00\n",
      "0.3\n",
      "train error 4.860740919466165e-16\n",
      "Hessian:1.9e+00 alpha: 9.4e-01 mu: 1.00e+00 bound: 7.4e+00\n",
      "(1, 3)-> took 1.1 seconds, hessian: 1.9e+00\n",
      "0.3\n",
      "train error 4.582491655084312e-16\n",
      "Hessian:2.4e+00 alpha: 9.1e-01 mu: 9.61e-01 bound: 7.6e+00\n",
      "(1, 4)-> took 1.1 seconds, hessian: 2.4e+00\n",
      "0.3\n",
      "train error 2.37869880297261e-16\n",
      "Hessian:2.4e+00 alpha: 7.9e-01 mu: 1.04e+00 bound: 7.3e+00\n",
      "(2, 0)-> took 1.2 seconds, hessian: 2.4e+00\n",
      "0.3\n",
      "train error 3.6403999059288054e-16\n",
      "Hessian:2.3e+00 alpha: 8.4e-01 mu: 1.00e+00 bound: 7.4e+00\n",
      "(2, 1)-> took 1.2 seconds, hessian: 2.3e+00\n",
      "0.3\n",
      "train error 3.4916977133845966e-16\n",
      "Hessian:2.2e+00 alpha: 7.8e-01 mu: 1.09e+00 bound: 7.1e+00\n",
      "(2, 2)-> took 1.2 seconds, hessian: 2.2e+00\n",
      "0.3\n",
      "train error 3.41996646354857e-16\n",
      "Hessian:2.3e+00 alpha: 9.0e-01 mu: 9.74e-01 bound: 7.6e+00\n",
      "(2, 3)-> took 1.2 seconds, hessian: 2.3e+00\n",
      "0.3\n",
      "train error 4.866639445152607e-16\n",
      "Hessian:2.2e+00 alpha: 9.2e-01 mu: 9.62e-01 bound: 7.6e+00\n",
      "(2, 4)-> took 1.2 seconds, hessian: 2.2e+00\n",
      "0.3\n",
      "train error 4.326186367387857e-16\n",
      "Hessian:2.4e+00 alpha: 6.4e-01 mu: 1.15e+00 bound: 6.9e+00\n",
      "(3, 0)-> took 1.4 seconds, hessian: 2.4e+00\n",
      "0.3\n",
      "train error 2.379978352118556e-16\n",
      "Hessian:2.3e+00 alpha: 8.2e-01 mu: 9.69e-01 bound: 7.6e+00\n",
      "(3, 1)-> took 1.4 seconds, hessian: 2.3e+00\n",
      "0.3\n",
      "train error 2.368464262689614e-16\n",
      "Hessian:2.5e+00 alpha: 7.1e-01 mu: 1.19e+00 bound: 6.8e+00\n",
      "(3, 2)-> took 1.5 seconds, hessian: 2.5e+00\n",
      "0.3\n",
      "train error 3.3861396734043776e-16\n",
      "Hessian:2.5e+00 alpha: 7.2e-01 mu: 1.00e+00 bound: 7.4e+00\n",
      "(3, 3)-> took 1.5 seconds, hessian: 2.5e+00\n",
      "0.3\n",
      "train error 4.040824143831435e-16\n",
      "Hessian:2.6e+00 alpha: 7.4e-01 mu: 1.06e+00 bound: 7.2e+00\n",
      "(3, 4)-> took 1.5 seconds, hessian: 2.6e+00\n",
      "0.3\n",
      "train error 1.8666950472770693e-16\n",
      "Hessian:2.4e+00 alpha: 6.4e-01 mu: 1.04e+00 bound: 7.3e+00\n",
      "(4, 0)-> took 2.3 seconds, hessian: 2.4e+00\n",
      "0.3\n",
      "train error nan\n",
      "Hessian:nan alpha: nan mu: nan bound: nan\n",
      "(4, 1)-> took 2.2 seconds, hessian: nan\n",
      "0.3\n",
      "train error 1.5349944305032232e-16\n",
      "Hessian:2.4e+00 alpha: 8.3e-01 mu: 1.09e+00 bound: 7.1e+00\n",
      "(4, 2)-> took 2.2 seconds, hessian: 2.4e+00\n",
      "0.3\n",
      "train error 1.488352031867073e-16\n",
      "Hessian:2.5e+00 alpha: 7.6e-01 mu: 1.12e+00 bound: 7.1e+00\n",
      "(4, 3)-> took 2.2 seconds, hessian: 2.5e+00\n",
      "0.3\n",
      "train error 1.6316185228211754e-16\n",
      "Hessian:2.4e+00 alpha: 7.3e-01 mu: 9.62e-01 bound: 7.6e+00\n",
      "(4, 4)-> took 2.2 seconds, hessian: 2.4e+00\n"
     ]
    }
   ],
   "source": [
    "widths = [10, 20, 40, 80, 160]\n",
    "learning_rates = [0.01, 0.005, 0.003, 0.001, 0.0007]\n",
    "ntries = 5\n",
    "\n",
    "model_size = torch.zeros(len(widths))\n",
    "hessian_fro = torch.zeros(len(widths), ntries)\n",
    "mu = torch.zeros(len(widths), ntries)\n",
    "bounds = torch.zeros(len(widths), ntries)\n",
    "alphas = torch.zeros(len(widths), ntries)\n",
    "\n",
    "for i, m in enumerate(widths):\n",
    "    lr = 0.3\n",
    "    for j in range(ntries):\n",
    "        time_st = time.time()\n",
    "        \n",
    "        net = models.build_linear_net(d, m)\n",
    "        train_net(net, train_loader, nsteps=5000, lr=lr, display=40000)\n",
    "        H_i, mu_i, alpha_i, bound_i = compute_geoinfo(net, train_loader, batch_size, lr)\n",
    "        hessian_fro[i,j], mu[i,j], alphas[i,j], bounds[i,j] = H_i, mu_i, alpha_i, bound_i\n",
    "        \n",
    "        print('({:}, {:})-> took {:.1f} seconds, hessian: {:.1e}'.format(i,j, time.time()-time_st, H_i))\n",
    "    model_size[i] = utils.num_para(net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "596fae76-5a70-451a-9bc5-f287a93fee71",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWQAAAEvCAYAAAByhLuPAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAn4UlEQVR4nO3de3RcV2Hv8e/W25IfY8l2XsQPmTxIAiGyHBICARI7hbKAppXj0gZKH5ECC9pV2srJXeX2Lm5pkNPF5RaSIqeUe7l0QWynAUraECk8EpJAbCsB8nISKXZC4qfksS3Jeoxm3z/OHvl4NDMaaUaaPdLvs5ZXMuecObN1pPnNnn32w1hrERGRwispdAFERCSgQBYR8YQCWUTEEwpkERFPKJBFRDyhQBYR8YQCWQrKGNNqjNlhjLHu3w5jTHuhyyVSCEb9kMUHxphjQI+1dl2Wx0eAPUCbtXbbTJZtLjHGNFhruwpdDklNNWTxRZ/7JzNrc6ELIOmVFboAItNhrY0CawtdjmLivlVEClwMyUA1ZJH5o7nQBZDMFMgi84AxpgFoK3Q5JDM1WUhRMsZ0AI3Admtti9tWD+wA6oHtwBbgJveUtUAkcWyK87UB3e7hOqA9+eaX+8rfDEQJvvqvBba45pPEMeEy3AHsBJrcsd3W2q0ZfqYZKb8xpgnY6Mq9IdyLJd35pDDUy0K8YIzpJuhlsXEKz+kAotbaTSm295AUgG57l7V2S9Lxe4BbQgEWIejBsSkcysaY1qTzbSAI0DXhUA6dsxPotdZudYHZZK2dtN17Bsu/A+hTCPtLTRZSzHoybL8JSO4O10FQWx3ngrInHFwuXNuBe0LHbQA2urBLHNdJ0DMkVdtsj3utbe7YLQQ112zkvfxSHBTIMlf1JNdaCYKuNmlbK0HQJesEGkKP+wiaSJKf30P63h7RcBlSlCeTfJdfioDakKVoGGMiUwi1dLXn8PkSgbXWtbOG1bpj6q21iRro0tBz6wnaemszvNakZcggr+XPoRwyixTIUkyagbQ3xaah3v333jSj185oMnAhnGi/7SCohWYKu5ke6DKl8ov/FMgyn2U9hNjVRhM3ynaGts9EubKV0xBo1x5er6HU/lAbshST9fk8Weir/IZU+8M38AhukO0Mh7ETCR2f8jwzZYrlT6d+8kNktiiQpSgYY5qZmSaALaSf3yHce6IB2JXimEQ7cuL/Z1u25Yfg+oVvCtaTWzu35JkCWXyR3HtgnLth1U4wsCFZJM3TUm2f8Bqun2+fMaY16TUbODOsukiqoYfKla4MuUh1zlzKD0Gf6XBtulHNFX7RwBApKBckGzkdFDs5XROuJaiZJmqeLdbabe6r+D2h53QCt7jjw6PbOglG0vW40Wkb3Ll2EtwIC7cFtwJ1BKPd+gi6rHWG9kc4PfR4j/vvbmttV2LABac/NNpCZdsOdKRo6kh1LRI3DfNe/jTH9aQ6RgpHgSwi4gk1WYiIeEKBLCLiCQWyiIgnFMgiIp7QSL00li1bZlevXl3oYojIHLNnz56j1trlqfYpkNNYvXo1u3fvLnQxRGSOMcbsT7dPTRYiIp5QIIuIeEKBLCLiCQWyiIgnFMgiIp5QIIuIeMK7bm9uVq1moC55ufPpHO9mtxpfHNJaq2VtRMRLXgWyW3EhQvpVfKd0vFsifVdimkJjTJsxpimbqRBFRGabV00W1tpOF5bRPB3fnBS+9wItORVSRARYfdsDrL7tgbye06tAzqfQEulhUdKsPyYiUmhzNpAJ2oyT12Cb6WXZRUSmbS4HciTdjnSr8Rpjmo0xu40xu48cOTJT5RIRSWkuB3KUiYtCpl1IE4IeGNbaRmtt4/LlKSdjEhGZMXM5kPuYWEuOAFhro7NcFhGRSc3ZQHbLm0eTNtcSrOQrIuKdogpkY0y9MaZpCk/ZnnT8RoKl2kVEvOPbwJAGgm5pTe5xK9Dparu47RuBndkcb61tMca0ugEk
9UC3BoWIiK+8CmQXpF3A1jT7t4b3TXZ86DkiIt4rqiYLEZG5TIEsIuIJBbKIiCcUyCIinlAgi4h4QoEsIuIJBbKIiCcUyCIinlAgi4h4QoEsIuIJBbKIiCcUyCIinlAgi4h4QoEsIuIJBbKIiCcUyCIinlAgi4h4QoEsIuIJBbKIiCcUyCIinlAgi4h4QoEsIuIJBbKIiCcUyCIinlAgi4h4QoEsIuIJBbKIiCcUyCIinlAgi4h4QoEsIuIJBbKIiCcUyCIinlAgi4h4QoEsIuIJBbKIiCcUyCIinlAgi4h4QoEsIuIJBbKIiCcUyCIinlAgi4h4QoEsIuIJBbKIiCcUyCIinlAgi4h4QoEsIuIJBbKIiCcUyCIinlAgi4h4QoEsIuIJBbKIiCcUyCIinlAgi4h4oqzQBUhmjIkAzUCdtXZLFse3Aj1ALYC1dltoXzOwDtjhNm0C2qy1PXkutohIzrwKZGPMBiACrM3y+DZgl7V2Z+KxMaYp8di5iSDgu4BbFMYi4iuvAtla2wlgjFlPEMyTaU6qRd8LtAHjgWytXZrPMoqIzJSibUM2xjSk2BwFNsxyUURE8sKrGvIU1QJ9SduSHyfakftI0cYsIuKTYg7kSLodxpiItTYK7AaiiXZjY8wOY0xfUhtz+HnNBO3NrFy5Mu8FFhHJpGibLAiaJ2qTtp3x2FrblXQTbxdwe7oTWmu3WWsbrbWNy5cvz1tBRUSyUcyB3MfEWnIEwNWOE702wnqAVG3PIiIFV7SBbK3tIqglh9UCiZ4a9UCH69ccpm5vIuKlogpkY0y9MaYptGl70uONQDuAa6rYkqgtO5sJusWJiHjHq5t6rivbBqDJPW4FOl1tGLd9I66fsbW2xRjT6pom6oHupBt2O905AOqADvWyEBFfeRXILni7gK1p9m9N3ue2pTtfT7pziYStvu0BAPZ98YMFLonMZ0XVZCEiMpcpkEVEPKFAFhHxhAJZRGQKBoZjfKlj7/jjhs8/xJc69jIwHMv53ApkEZEsDQzHuPHux2j/6enhDH2Do7T/tIcb734s51DOOZCNMYuNMW83xlzn/rs613OKiPio/ZFu9vcOMhyLn7F9OBZnf+8g7Y9053T+KXd7c4HbBNxA0Gf4mPsXdYdEjDG1wBKCLmydQLu1dl9OJRURKbBvPbF/QhgnDMfifOvnr/LZjRdN+/xZB7Ix5gqCUW6WIGS3WGufyuI5jcA2Y8xSghU7np52aUVECujY4Ogk+0dyOn9WgWyM+RpBLXiTtfZ4tid3gf0UcI8xZglwuzGmxVr7yWmVVkSkgMpKDaNjNu3+pdUVuZ0/004Xom0EC4O+kssLuSC/zRizxBjzzwQ17BO5nFNEZDb0D8X4nbsfyxjGlWUl3HxVbvOoT3ZT7yZr7a25hnGYtfa4qyFvztc5RURmyouHTnL1HQ/z8uF+AJZUlVFZdmZ0VpaVsKqumpZrs1qfOa2MgWytvSensxfo3CIyfatve2B8bo/57ntPv84HvvwIJ4djlBjY2vQ2Hr/9elreUz9+TG1NBS3vqef+T11DTWVu0wN5NbmQiIgv/u77z/J/H98HQHVFKTtvvZpLzl0CwGc3XsQ/PfwyAF2f25i315xKL4vrCFbbqHObuoEea+2P8lYaEZECG4nF2bztCZ56NQrA6rpq/uMz72JRVfmMv/akgey6rj3sHiZWb464x9YYA9BB0Nf4/hkoo4jIrHg9OsiHv/IYvQNB97X3X3oWd/9hAyUlszOoOZsacjOwJrm7mzHmemAd8EWCQSI3GGOOEQTzf8t7SUVEZtBP9x7mz765m9ExiwFu+8DFtLwnt5t0U5VN7Pek6ntsrX3YTQ7faa0tIQjlnQRd23qNMe/Lc1lFRGbEPz38En/0jV2MjlkqSkv49i1XzXoYQ3aBXGeMuTzDfgtgre201ra4cG4D7jPG/FU+CikiMhPi8Th//I0n+VLHiwCsWFTJo1vex1Vr6yZ55syYNJCttbcR
hOs/ZHtSV3OuB97hbgaKiHglOjjCtXf+hB/vPQLAVfW1PL7lOs5aXFWwMmXbUt1I0Ebca4z5B2PM4tA+k+oJ1tqotfYmgkVJRUS88cvXjnH1HT/iN8dOAdBy7Rq+03w1ZWWFnZE4q25v1too0OhWcP4isMUYswfYjmuyEBEpBv/2i/387XefwVooLTF89Q+u4AOXnVPoYgFTHBiSWPXZGNMMtOJWdDbGjBFMtbmboH9ylKBr3HqgJ9W5RERm219tf5r7ul4HYMmCMr77qWtYs3xhgUt12rRG6llrtxFMqdkA3EQwL/Jagm5wEATyboIucPfloZwiItM2OBLjd+9+nBcOngTgLecs4r5PvpPqCr8GK+dUGmttF0HNWETES91H+rnxrsc4MRQsr7Rp3Zu4c1OmjmOFM9n0m4tnaorMmTz3bEtMxLLvix8scElEJOw/f3WAz3znKcbiFmPgC7/zVv7gHblNkTmTJqshbzbG7Mr3Kh/GmLcT9Nz4l3yeV0Qk4QsPPMc9jwYzBy8oL+U7ze/g8vOXFrhUmWUMZGvtPcaYvzHGbLDW/mM+XtAY8zfBqfNzPhGRsFgszh98/Rc8+UofAOcvXcB/fOZdRHJczWM2TNqGbK290xhzhTHmIeAhYNtUmxpcv+UWgpt/W7SuXnFQU4wUm0MnhvjQV37G4ZPDAFx/8Qru+fi6WZscKFfZ9kN+imBgyO8BO40xawgWOu0m6NYWJZgJDk7PBldP0PNiozuu3Vp7Zz4LLyKS8PjLR/nEN3YxMhasCv3XN1zIp6+7oMClmpqp9kO+j2AY9RKC2u564EpOBzCcDucugtC+bSoLo4qITNXXftJN24MvYIHyUsO/fmI9775geaGLNWXT7Yd8HLjP/RMRKYh4PM6t3+rioecOAbCspoL/+My7OCeyoMAlmx6/ekWLiGTp+OAIH77rMfb3DgLQuGop37nlqoLPR5GLGSm5Mea6pAmIRETy5rk3jnP1F380HsZ/cs1qdn7ynUUdxpBDDXmSgR3HgBZjTC1wx1wZACIihbd992vcdt+viFsoNYb/9fuX8+HLzyt0sfIilyaLfe7mXhfBmnqdwG5r7QnXK+MpAGPMHcDtOZdUROa92+/7Fd/e9RoAi6rKuP+T7+TNZy0qcKnyJ5dAbiToW3wFcJv7Z40xPbhwJqgp16c9g4hIFoZGYjR97QmeeSP4sn3hWQv5909ew8KquXUbbNo/jbW2B9iSeOxWp95I0B1uM0FYW2BTjmUUkXlsf+8AH7nrMaKDowB85O3n8r9//4oCl2pm5O3jJdRMkZgjuQlotNb+e75eQ0Tml4eePcin/q2LWDxYCfp/fPhS/uidqwtdrBkzY/V9a+1OY8weY8yfWWs1iZCITMmdP9zLXT9+GYCq8hL+7c+uYt0qvycHytWMNsBYa18xxsztKygieRWPx/nY15/kse5eAM5dUsUPPvNuahf6PzlQrnLp9vbPBMOkdwGdGbq2ac09EcnK0f4hPvSVxzhwfAiAay9cxv/5xPqimRwoV7nUkE9w+qZeuHdFB0FXuD6CG3wbAU21KSIZPflKLx/7+pMMx4LJgf7i+jfzlxsvKnCpZlcuvSy2EKw+na53RcK2ubQ6iIjk39cf7eHvH3h+fHKg9o81ct3FKwpdrFmXcxtyit4VawgCeiNwPUFAN4dq0A8BDyugRSQej/Pn33maH/zqAABLq8v5/mfexflLqwtcssLI+009a+0rwDb3LxzQ4Rp0N1BcE5WKSF71D8X4yF0/o/vIAACXn7+EHS3vpKLI56PIxYwPc0kT0JGZfl0R8dcLB07Q9LXH6R8eA+Dmd6zk7298a4FLVXizPu7QBbSIzFP3d73OX+14mriFEgN3Nr2N31t3fqGL5YW5NRBcRLz237/3DN98Yj8ANZWl7Lj1ai45Z0mBS+UPBbKIzLiRWJyb2p/g6deiAKxZVsP3P30Ni6rKC1swzyiQRWRGvXZskI989TH6BkYA+OBb
z+YrH71i3gz2mApdEZnXBoZjfKlj7/jjhs8/xJc69jIwHCtgqeaOn+w9zHX/+BP6BkYwwOc++Bbu+sN1CuM0VEOWeWtgOMaNd59ekw2gb3CU9p/28OAzB7n/U9dQU6m3yHR9ufNFvtz5EgCVZSV880+u5B31dQUuld/0MSXzUjwe5x8f2ssrRwbGh+omDMfi7O8dpP2R7gKVrrjF43H+6F+fHA/jsxZX8mjr+xTGWfDu498YEwGagTo3PHuy41uBHqAWwFq7bSr7czEwHDvjTdvw+Ye4+epVtFy7VjWrAhqJxXnx0EleOHCC7iMD7O8b4MDxIXpPDnN8aJTBkTFGxzLPeTUci/PVH73M3gMnefcFy3j/W89m2cKqWfoJildf/wgf+urPeD16CoB3rq3jm398ZdEvPjpbvEoNY8wGgkEja7M8vg3YZa3dmXhsjGkKP860Pxf6ulsYff0jPHvgOC8ePMkrRwd47dgpDp8Yom9ghJPDMYZGx4jnaX7BuIUfPneIHz53iL/93rNUlJVwzpIqLjt3MddeuJz3X3o2S6rn/pSQ2Xrq1WN89J6fMzQafOP41HvX0vr+iwtcquLiVWJYazsBjDHryW40X3NSLfpeoA3YmeX+aWt/pJv9vYMZv+5+dp7NVJWLeDxOz5EBnjt4gu7D/ezvPcUb0UGO9A8THRxlYGSM0Vg867lcDUG7ZU1lGZHqCs5aXMl5kQWsqqvmghWLuOScxXz4q49y7FT6m3cVpYbzllbzRvQUw7E4I+53u793kAd+fZAt9/2aqrISzo0s4G1vivDei5az4S1nzbl13rLxzSf28XfffxZroazEcPcfNnDDpWcXulhFp2j/cowxDSk2RwnmzJh0f66+9cT+CWGcMByLc9ePX+bJV/pYWFnGwspyFi8oY3FVOUury4lUl1O3sJLamgqWL6ykbmGlV+P3890U0z8U4zlXq+05OsBrfYMcOB7Uak8MxTg1OsbYFKq1pSWGBeUlLKoqp7amgrMXV3F+bTX1y2u46KxFXHLu4qz6t37snatp/2lPyt9jZVkJLe+pH/9QjQ6O8F/PHOTRl47w7OsnOHB8iJGxOEOxOD1HB+g5OsB3n34dgAXlpZxfu4DL3xThfRet4LqLl1NVUbRvtUn95b1Pc/9Twc8eWVDO9z59DavqagpcquJUzH8ltQRzLof1TWF/To65BRfTGYvDz3um9nLGQKkxlJYYyksN5aUlVJaVUFVeyoLyUqorS6muKGNRVRkLK8tYsqCcJQvKWVpTwdIFFdQtrKCupoJliypZXFU2ra5FU2mKicfjHDgxxLOvn+Clw/3s6x3gjWOnOHxymGMDI/SPxBiOxbFTaEKoKC2hurKUyIJyli+q5JwlC1hdV83aFQt5yzmLqa+ryVt7ZMu1a3nwmYMTvulUlpWwqq6almtPt5xFqiv46JUr+eiVK8e3HT4xxIPPHuRnLx3luQMnOHRiiNExy6nRMV481M+Lh/rZsec3ANRUlLKytpq3nx/h+res4NoLV3j1ITwdgyMxbrzrcfYeOgnApecu5r5br57THz4zrZivXCTdDndjMON+a200xfZmghuKrFy5Mnn3GZZWl9OXIZTLSw1XnB9hYHiMU6NjDI2OBV97x+KMjsWJjVni1p7R3mktxKwlFrcE3WDHMpYhGyUmqFGWlZRQXmqoKCulqjwR8sFX+pqKMhZVlbOoqoznD5yg+0g/Y0mVxuFYnJcO9XP1HQ8DcGp08htjyeWoLCtlUVUZtTUVrFhcyZsiC1izrIYLz1rEpectnvWbZjWVZdz/qWtof6Sbf3o4WLuttqaCm69amdW3gRWLq/j41av5+NWrx7e9Hh3kwV8f5LHuo7xw4CSHTw4Ti1sGRsZ4/uBJnj94km/veg2ARZVlrKyrpmHlUja8ZQXXrF1WNDe/Xj50khv/+XFODgVNPpvXn0/b772twKUqfsUcyFFcz4mQ2insn8D1wNgG0NjYmDFtbr56VdZf
dycTi8XpGxzh6MAwvf0j9A0E/46fGiU6OMrJoVFODsUYGI4xMDLG4Ehw82poNAj4kVicWNwy5kI+XPC4hfiYZXRsjFOjANMf8GCBE0MTn19WYqiuKGXJgqAp5pwlVaysXUD98qBWe+GKhd7Wmmoqy/jsxovGA7nrcxtzOt95kWr+9N31/Om768e37e8d4L9+fZDHe46y9+BJjp4cYcxaTg7HePaNEzz7xgn+38+D+R0WV5WxZlkN61Yt5YZLzubKNUu9G0Txg1++wV9852nGrKXEwB2/+1Y2r89cgZHs+PkuyU4fE2vBEQBrbdQYk3F/ri8+la+7kykrK2HF4ipWLM5fDbF/KEbvwDBH+0fo7R/m2OAIxwZGOX5qlOOngh4JJ4diDA7HzqjFv+HWMsvkr2+4MLgxdt5izltS5V1g+GZVXQ23vnctt7739N/Ey4dO8p/PHOQXPb28eLif3v5h4jb4wPvlb47zy98c518f24cBllSXU7+shvWrl3LDpedwxflLCnbN/+cPnuPrPwsmbKyuKGV7y9Vcdp4mB8qXog1ka22XMSaatLmWYFWSSffnKtevuzNtYVUZC6vKpnxzpeHzD2VsiqmtqeDT12ltgVy9+axF/PlZi+D609fyuTeO8+AzB/nFK328fLifvsERrIXo4Chdr0bpejVK+yOvYIClNRW8eXkN69fU8YHLzp7xUIzF4nz0X37Orn3HAFhVW833P32Nuv3lWVEFsjGmHmgI9SPentSveCPQHnrKZPtzku+vuz6YrCnm5qv01XSmXHLuEi4593SwxuNxfv36cR585hC79vXRfaSf6OAoFugbGOHJgRGe3HeMu378MiUm+LC8YMUi3lFfy2+/9RwuPGtRXsp1IHqKD331MY72DwNwwyVn8bWbG/TNaAZ4Fciuq9oGoMk9bgU6rbVd7pAmglDdCWCtbTHGtLoBJfVAd3jQx2T7ZaJ8NsVIbkpKSrj8/KVcfv7S8W3xeJzd+4/x0HOH2L3vGK8cHeD4qVHiFo72j3C0v5cnenr5cudLlBqoW1jJRWcv4qr6On77srNZs3xh2tdL1d3x2guX88CvDzA6ZjHAlvdffEbTi+SXV4HsgrcLt2Bqiv1bk/e5bZnOmXG/nMn3ppj5rqSkhCvX1HHlmtPzQsRicX7+Si8dzx+ma38f+3oHOTkUY8zC4ZPDHD45zKMvHeXOH+6ltMSwYlEQ0tesreO333YO50Wq03Z3/O7TbwBBr6FvfGI977pg+az/zPOJ3l0ywVxsipnLyspKeNcFy88Iy1gsziMvH+Xh5w/x1KvHeLVvkP7hYADOgeNDHDg+xE/2HuEL//kC5aWGqvJSBoZjaYedf/yq1QrjWaBAFpmDyspKuO7iFVx38YrxbUMjMX764lEefuEQv3ztOK8dGxyfaGl0LHN3yPuffp3PfeiSmS72vKdAFpknqirK+K3Lzua3Ljs9x8TgSIzO5w/x599+OuNzjw2OzHDpBDQfssi8Vl1RxocvP4/a6sxzfyxV97ZZoUAWEW6+ehWVaYZtq7vj7FEgiwgt165lVV31hFBWd8fZpUAWkfHuji3vOT0HR21NBS3vqddiC7NIV1lEAHV39IFqyCIinlAgi4h4QoEsIuIJBbKIiCcUyCIinlAgi4h4QoEsIuIJBbKIiCcUyCIinlAgi4h4QoEsIuIJBbKIiCcUyCIinlAgi4h4QoEsIuIJBbKIiCc0Qb2IyDTs++IH835OBXIezMQvRkTmHzVZiIh4QoEsIuIJNVlIWmqKEZldqiGLiHhCgSwi4gkFsoiIJxTIIiKeUCCLiHhCgSwi4gkFsoiIJxTIIiKeUCCLiHhCgSwi4gkFsoiIJxTIIiKeUCCLiHhCgSwi4glNvymCphoVP6iGLCLiCQWyiIgnFMgiIp5QIIuIeEKBLCLiCQWyiIgnFMgiIp5QIIuIeEKBLCLiCS9H6hljWoEeoBbAWrttkuPbgG5gLdAdPt4Y0wysA3a4TZuA
NmttzwwUXURk2ryrIbtw7bHW7nTButYY05Th+A6gw1q7zVq7BVhnjNmQdNhNQAfQBrQrjEXER94FMtBsrd0Zenwv0JLqQGNMPbDBWtsZ2twBbAkfZ61daq011tp11tquvJdYRCQPvApkY0xDis1RILnGm5Dq+J4Mx4uIeMu3NuRaoC9pW/LjsC4AY0zEWhsNneOMba4duY8s26RFRArBt0COpNuRFLoAWGt7jDFdQD0unJlYa94NRBPtxsaYHcaYvqRmkcRrNAPNACtXrpzuzyBS1DQVaeF41WRB0DxRm7Qt+XGy64HNxphmd/OvByAR3tbarqSbeLuA21OdyN0YbLTWNi5fvnwaxRcRmT7fash9TKwlR+B0wCZz28dv4rkeFl3hx0k3/XpI3fYsIlJQXtWQXQ+IaNLmWqBz4tEB19MibBPQHtrXYYyJJB2jbm8i4h2vAtnZntTveCMuYCEI2aT9exK9M1zwNiZu2rmmii1JtevNBP2RRUS84luTBdbaFmNMq2t6qCcYeRe+AddEENKJbbcA9caYRmCttXZd0il3upF/AHW4QSQz+COIiEyLsdYWugxeamxstLt37y50MURkjjHG7LHWNqba52OThYjIvKRAFhHxhAJZRMQTakNOwxhzBNhf6HJ4bBlwtNCF8Jyu0eTm4zVaZa1NOfJMgSzTYozZne7GhAR0jSana3QmNVmIiHhCgSwi4gkFskyXBtdMTtdocrpGIWpDFhHxhGrIIiKe8G4uCyk8N0nTTe7hWgC3gGy64+f9yt7GmHZrbcq1H0PHTGk19blmsmukvyMFsqTWRmiWPGPMHmNMq7V2a4bn3ESw2koXcMt8ehO5ldKTp4FNdcyuxERZxpg2Y0xTqpVr5qJsrpEzb/+OQIEsqTUSLBSbCIseYH2mJ1hrl850oXyUZmHeVJqTvmXcS/DBN+cDeQrXaN7+HSWoDVkmsNauS6q5NQAdhSqP5xqZ5NpMYzX1uWbSayQB1ZAlI9fu2TlZe+d8XNnbLZSwnSBwMpnqaupzxhSuUeL4efd3FKZAlpSSbux1T3J41it7zxXu+kSttVFjzGSHRzKdJ916kcVuitcI5uHfUTI1WUhK1tqoW4V7K7DRGLMjw7FZr+w9h9yUtHhuJlGmvpr6XDCVazRf/47OoECWMxhjIqElrxI6CJbOSvec5LbQOb2yt2sTzjpomMZq6sVuGtdo3v0dpaImC0nWCLQZY7ZlExahlb2XJh0/l7sr1QIbQl/D1xOs69gK7EzuqmWt7TLGRFOcY0qBVWSmdI3m6d/RBApkOYO1ttMYk7xS90ZgvA+ye/M0WGt3Wmt7Uhw/p1f2dl/Dx8PU3YiqD/fTDl8jt2l7Ur/jM1ZTn2umeo3m499RKprLQiZwb5REE0Ud0Jv0RmoFNlprN6Y5vnu+3B13QbOJ4JvFHcA2dxPrjGvkjm0lGPBQD/OnB0G212g+/x0lKJBFRDyhm3oiIp5QIIuIeEKBLCLiCQWyiIgnFMgiIp5QIIuIeEKBLHOKMabZTWojHtLvJzMFsmCMqTfGdBhjjhljinb0mBtosNvn+SHcde42xuyZxnOL/vfkBnrMq9F3U6FAFqy1PW601G6KdBayxCTw1tquQpclE3edpzWd5Ez+ntykUt1uVN1MayvWD5SZpkCWsGKeyKVtkjX/fDLZ/NKTKebfE25ioehUlnaaLzS5kBQ9tyqFlgjKgWvmWTuLL3kHwerSGyc7cD5RDVnmghZgXk1CU+xCK5pnsxL1vKFAlqLm7tjX+3wjT9LKuPDBfKQmC8mKu9kTIViOKEKw9lnaWqlrHwyvj1ZPsPpD8jI9udpAione3evtIJjqcjuwhdNrBK4DsNa2uGMTN7IiBBOp35Iq4Kd6DdxzmlwZEs/JeNPRGNPG6TbmdUB7rjcq3bXYwOnFQyOuPOPr1RljOgimx9yeuC5uu3Vl7uTMtu+25GOnUf5Od55iafufedZa/dM/rLUQTJi+I8X2HcCG
pG1NaY5tcsc3hP6/1T1uSPWcHMvcBrRm2N/hfq7WFNvbCIIqEtreCuzJ5RokXc/k121wr53qNfYQTNieeBwhCLeGFOfN6jq6c7Sn+XmaUlyTHUnP7Ujx3FbgWPi6TaX8Sc+xs/137vM/NVlIRon19WzSYpU2qFnVh7tJuVrxDoIaZpc75l6gxT3ustZuynMR68nc66CHoGacXJPtAJphwrp2XUBDePDCVK5B6DlNBAG+Nek5XaSoJbuaZY8N1SZdudqBezL8fJPZQFAbTnZHim3J17GW4PcZLmc9wQfZGd8icih/VANFTlMgy2TaCEI1lXbO7OS/maBJIhra1kkQWskLWOZLhNSBE9ZjJzZB9BDU8JKbOxKhFO7nO5VrEH5Ouv7GvSm2tZK6p0gnuS302QM0J3cxc8GZTdPR7qTHOwjWxEv+2aZb/j7cCiqiNmTJIHQHPN0btweIGGMiKQJvttQTvKkzSVf+aIptZ5wrh2tQT5b9jUNhudbVrMNqE+Ww02h7t8ECq53AHmNMD8GHRIe1ttNO0jad/HqJdfGA6/NY/ihFOhhpJiiQJZPEGy06yXH1BF/Dx5sBQhoJbn6luvEWIWhO2GRDa8+F9reGXjuS/PV/lkz1GoRDfLIPivBzAe5NE5I5demz1m5yQbmZ4PfT6kJ6U7YfpO5nak/znFzLn+11mvPUZCGZJGo0kUmOi8LplYaNMe1uKG4DQe+G65Of4Pbd5M49oYYUarfdZoOeDF1phttGUz0/j6Z0DeCMmmW25Zqx4d6JDwcbrOy8yVq7lGAASC1Tm1NiQlNFqBkql/JHmPzDbt5QIEtaodpOY5pDUnVtaycI4Q1ArbV2Y6pak7vBt430TQG3E6pZubBPNc9CH5OH5bRN9RqE9JDlyLfQc1O2s+d402tDivbjHoJVoLNq1w81VdyStCsxf0gu5a+dTlPMXKVAlslsJXjzprKZIHwTGgiaFqKuRjahmSIbLthTtkunuDnYxczfFJrKNUhIdKlLZT0TP0S2uHOlkuuEPxPO60IwOtkTQx+yqfpmhz9wplv+yGRlmE8UyJKRtXYLjHfjGueaFPrsmQMjuoDb8zAcNt3zo0x8A+9i8ppo8nMgfXPChO1TvAaJ52wD+pK7xLkPlEjy67j28b5EU03o+AZyn0yoOfl34sqRqgkokvQ4Za8K9/zxn2E65Xf7vJ6db7YZG3TOlnnMfaW8h9M1uk5gS/irZOiNFiXDKDXXzptcI+oJnTOadHwTcLu1dl1o2waCngAm6dhuglndtoW2RQgGWaxNOraeM0fnjf9MrowbCIJ/J8HNqJ3uZ9yMG1Hotm8NnTOra5BUjsRzeggCrMedv829xpbwNwl3fB1BD40+QjdEs/k9pXj9JveaiUAOh+i2DOe9xT3eQfANIdFjJEJQw28CttmJI/XSlj9F2ZqBtYkPPFEgSx65wQG7km78RAjCoAW4yd1UCj9nKoF8jCCAtiVt72AKPQbED+731qI25NPUZCF54YI1kvzV1rUnd7ma1O4U/VRTSfcGjaTZl6pWLh5zH9S6oZdEgSz5UksWI+ayOZE9PYH5hLbkVF9/3YeA5tUtLrczsdfGvKdAlrxwzQgN6YZIu+31KYbcpru5dgehXgquZp2pvXZL8g0l8ZOrHUcmGyk4H6kNWfIqaYrKhAjBfBLhtuV6ghtDiZtoW5nY/tzK6Vr1+slu/rjjJx0SLIVljGlPvhkoAQWyzCnuA2G7bvD5Sb+fzBTIIiKeUBuyiIgnFMgiIp5QIIuIeEKBLCLiCQWyiIgnFMgiIp74/2tykT8aWsA0AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 360x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "x = model_size.log10()\n",
    "y_mean = mu.mean(dim=1)\n",
    "y_std = mu.std(dim=1)\n",
    "plt.plot(x,y_mean,'-o', color='C0', markersize=8)\n",
    "plt.errorbar(x,y_mean, yerr=y_std, label=r'$\\mu(\\theta)$')\n",
    "\n",
    "\n",
    "# y_mean = alphas.mean(dim=1)\n",
    "# y_std = alphas.std(dim=1)\n",
    "# plt.plot(x,y_mean,'-o', color='C1', markersize=8)\n",
    "# plt.errorbar(x,y_mean, yerr=y_std, label=r'$\\alpha(\\theta)$')\n",
    "plt.xlabel(r'$\\log_{10}$(model size)')\n",
    "plt.ylabel(r'$\\mu(\\theta)$')\n",
    "# plt.legend()\n",
    "plt.title('Linear net')\n",
    "plt.savefig('../figs/alignment_linear_net_modelsize.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "a6aa9a6e-137e-4f1e-9f5a-06fee6f4aead",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAVwAAAEvCAYAAAAJoHlDAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAjuklEQVR4nO3df5RU5Zkn8O/TNKCAUjQQJSg/qlUUFUzTOCa4zGboNo5mNEqDx7Oe3Uw8dmeys8meZOyWmMTZODPQOJnZM9l47GYzM5n17I7QmGh0zNhtVBI1waYjiBqVLkCiqNBNgdABbOrZP+5b1bdv18+uqvfeqvp+zumjdeu9VW9X0d+69dz3vq+oKoiIqPiq/O4AEVGlYOASEVnCwCUisoSBS0RkCQOXiMgSBi4RkSUMXCoqEWkVkS0iouZni4h0+N0vIj8Ix+GSDSJyBEBEVZdl2T4EYAeAdlXtLGbfyomI1Klqn9/9oOR4hEu2DJofKq7b/O4ApVbtdweIklHVKIBav/tRSsy3gpDP3aA0eIRLVD6a/e4ApcfAJSoDIlIHoN3vflB6LClQIIlIN4B6AJtVtcVsCwPYAiAMYDOANgBrzS61AELxtkkerx1Av7m5DECH9+SS+UreDCAK56t5LYA2U96It3H3YT2ALgBNpm2/qm5M8zsVpf8i0gSg0fS7wT0KJNXjkT84SoGsEJF+OKMUGnPYpxtAVFXXJNkegSfgzPY+VW3ztN8B4C5XQIXgjIBY4w5dEWn1PF4DnIBc6A5d12P2ABhQ1Y0mEJtUNWPduYj93wJgkCEbXCwpUJBF0mxfC8A7XKwbztFmggnCiDuYTHh2ANjkatcAoNGEWbxdD5yRFclqoxHzXJ2mbRucI89sFLz/VBoYuFSqIt6jTjhBVuPZ1gonyLx6ANS5bg/CKWF4948g9WiJqLsPSfqTTqH7TyWANVwKDBEJ5RBaqY5+3Y8XD6RaU+d0qzFtwqoaP4Kc4do3DKfWWpPmuTL2IY2C9j+PfpBFDFwKkmYAKU86jUPY/PeRFFdfjfpKb0I2Xj/thnMUmS7Min0hR079p+Bj4FI5y/oSV3M0GT8R1eXaXox+ZSuvS3RNPTrMS32DgzVcCpLlhXww11fthmT3u0+QwTkB1eUOWyPkap/0cYolx/6nEs7chGxh4FIgiEgzivMVvQ2p5xdwjz6oA/BykjbxOm78/23Ltv+A8/q5T7qFkV+dmQqMgUu2eM++J5gTQh1wBu57hVLslmz7mOcw41wHRaTV85x1GB1GffAcYbv6laoP+Uj2mPn0H3DGDLuPhutZTggWXvhARWWCohEjQdCFkSPZGjhHlvEjxxZV7TRflTe59ukBcJdp7746qwfOlWARc3VVg3msLjgnmty12FYAM+FcrTUIZ0hXj+v+EEYujd1h/turqn3xCwow8qHQ7urbZgDdSUoRyV6L+Em5gvc/RbtIsjbkHwYuEZElLCkQEVnCwCUisoSBS0RkCQOXiMiSwF1p5pqTFHCG6XRnWkTQ7LMOI/OF9mYaDjNr1ixdsGBBXn0lIvLasWPHYVWdney+wAUugHXu+UBFpF9EkCp0Tdhuic+zagbQrwOwJln7uAULFqC3t7dwvSYiAiAi+1PdF6iSQvzab8/mDoxMKJLMJtMmLj6TPhFRoAQqcI0GM0A8Lor0l1Q2AegRkbCI1KlqlNPVEVEQBSpwTVjO8ARmI5wrcsZwzRda79q2JctJPYiIrApU4HqZ4GxA6hJB4sjXNYn0I0ix9IiINItIr4j0Hjp0qNDdJSJKK9CBCyc416QZcRA1/3Wf/YqvNTWGqnaqar2q1s+enfQkIhFR0QQ2cM0kHB0ZJt+IAGPWkoqa/UPF6hsR0XgEMnDNtHh98bBNNfGzqfVGPeEagmdxPyKiIAhc4JpwrQHQKyIhM2KhznV/2LOg3nqMTHcHOJM1r7fSWSKiHATqwgdzpBpfEto9ttY912gTnJEL
XYAzQbOItLomaB4wkzYTEY3bgnueBADs23BjwR4zUIFrygBpV+0zYboxyTYiokALXEmBiKhcMXCJiCxh4BIRWcLAJSKyhIFLRGQJA5eIyBIGLhGRJQxcIiJLGLhERJYwcImILGHgEhFZwsAlIrKEgUtEZAkDl4jIEgYuEZElDFwiIksYuEREljBwiYgsYeASEVnCwCUisoSBS0RkCQOXiMgSBi4RkSUMXCIiSxi4RESWMHCJiCxh4BIRWcLAJSKyhIFLRGQJA5eIyBIGLhGRy4lTw/i77jcTt+u++zT+rvtNnDg1nPdjM3CJiIwTp4Zxy4MvoOP5SGLb4NDH6Hg+glsefCHv0GXgEhEZHdv6sX9gCKeGY6O2nxqOYf/AEDq29ef1+AxcIiLj4Zf2jwnbuFPDMTz8q3fyenwGLhGRMTj0cdr7jwydzuvxGbhEVPE+OHYSX/yn7RnbzZgyKa/nqc5rbyKiEnZgYAitW3fhV5EBaIa2k6urcMc18/J6PgYuEVWcPR98hNatu9D3TjSxbdrkCfjStQvx1KsH8c7g70fVcidXV2H+zCloWVmb1/MycImoYux+9yjatu7Ca+8dS2wLnT0RX111Mb74mfmoqqpCy8padGzrxz88swcAUDN1Eu64Zh5aVtZi6uT8IpOBS0Rlb/veAdz74914+8PjiW2zpk7CNz63CLdfPbpMMHVyNb7euCgRuH3fbixYPxi4RFS2nn/zQ3zn8dewf2AosW3O9LOw7oZLcdPSudb7w8AlorLz1O6DuP+JN/Be9PeJbfNmnI37brocqy47z7d+MXCJqGxs3XEA7T97Ex9+dCqxrXb2VNx/8xX4zEWzfOyZg4FLRCXvX17ah7/vfgtHXBcuXHb+OfibW6/Ep+bN8LFnozFwiagkxWIxdG7bix88twcfnRyZVGbphdPRfusSXDrnXB97lxwDl4hKSiwWw//seRv/+5d7MXT6TGL71Qtr0H7rlVg4e5qPvUsvcIErIiEAzebmcgDdqtqZw/4dqtpSjL4RkX+Gh2PY8O+/xf9xTTAjAK69eBY2rL4Sc0NT/O1gFgIXuADWqWpb/IaI9IsIsgldEWkHEC5q7yijBfc8CQDYt+FGn3tC5eDk6WF894nXsWXH7/DxGecC3CoBVl36CaxffSVmTTvL5x5mL1CBa45uvYHZAaANQNrAFZG6InWLiHzw0cmP8Z3HXsNPd76H4ZgTtBNEcMOS8/FXN1+B6XlOJOOHQAWu0SAiYVWNT7keRXZHrfUAugEweIlKWHToNNY9+ir+/bX3YXIW1VWCW+vm4i9vuhxTJgUxtrITqJ6rahSAdwxHI4CedPuJSBOAzXBCl4hK0AfHTuKerbvw3FuHoCZoJ1VX4far5+HeGy7DpOrSn002UIHrZUoMDQBWZWgTVdWoiGR6vGaYE3Lz5uU3zRoRFcaBI0No69qFl/pHpkg8e+IEfPEzC/CNxktQXQZBGxfowAWwCcAaVe1L02ZttqMYTLtOAKivr880/SVRIJTrSchUUyQ2r6zFn3+2FlVV5RO0cYENXBFpBdChqinLCeZEWdpyQ6GU6z96ItuymSKxXAUycE1Nti8etiLSkCJ4a+CcZIvfXg4gbMK6y3XijYh8tn3vAL71k91464PMUySWq8AFrog0wAnSHlOfrYEz8iAevmEAdaraZUK4x7VvM4Cwqm603nEiSur5Nz/EfY+/hn2uKRLPP3cy1t1wGW6+yv4UiX4KVOCagO02Nztcd3W5/r8JzsgF97Z42K7ByBFupxn1QEQ+SDZF4oUzzsZ9f7IYDYvP97Fn/glU4JqATDvUwBy9jjmCdZ8QIyL/BH2KRD8FKnCJqHSVyhSJfmLgEtG4pZwi8YLpWL/6SiyeM93H3gUPA5eIcpZyisQFM9C+ekmgp0j0EwOXiLJWDlMk+omBS0QZldMUiX5i4BJRSuU4RaKfGLhENEY5T5HoJ75qRJRQCVMk+omBS0QVNUWinxi4RBWsEqdI
9BMDl6gCVfIUiX5i4BJVEE6R6K+8A1dEFgAYVNVjmdoSkT84RWLuirHYQCGOcCMAOkRkB5zgfbQAj0lEBcApEoOlEIHbpap/VoDHIaIC4RSJwVSoI1wiCoCHX9qH73GKxMAqROAOuG+IyGo4KzIcAdCtqj8vwHMQUQqxWAybfrEXP3h2D45xisRAK0TgjlpuXFW3ishvADwEp7Z7Lk+oERUep0gsPYUI3OUico6qfhTfoKoREelW1X0FeHyiinTi1DA6tvUnbtd992nc8en5uHPFQnz/2T14+KX9OMkpEktKIQJ3DYAmEYnAWUG3F8Az8Bz5AoCIXKWqrxTgOYnK2olTw7jlwRew3zWMa3DoY/yvn+/BPzyzJ7GNUySWlkIEbiecFXYbATQAuA3AdABREVkOZxXeblXdb+5/pQDPSVTWOrb1Y//AUGKS77j4zF0C4PNL53CKxBJTiMDtUNXfAPgNzGq6IrIQTrg2mm3TRSRq2v9tAZ6TqKw97FpRIZnQlIn4/u11FntEhZD3BdMmbL3b9qrqJlVdq6o1AGoAtMAZuUBEaUSHTmPQNawraZvfp7+fgsnKXAqqehRAlznyJaIkDh8/ibauV/HzNz/M2HYGywglyfbkNZ2Wn48o8D44dhKtXbuw7a1DiTPN1VUCheJMkqrC5Ooq3HENJ5opRVYD1xzpEhGAd6POpN8v7BmZ9HvKpAloXhnGnSsWYPVDL405cTa5ugrzZ05By8pafzpNeeH0jESWHRgYwt1bd+JXkcHEtmmTJ+Arn70IX14ZTsxF++OvrEDHtv7EMLCaqZNwxzXz0LKyFlMn80+3FPFdI7Jk76HjaN26Cy/vGzl3fO5Z1fhvqy7CnSsWjpn0e+rkany9cVEicPu+3Wi1v1R4aQNXRB4CUKwTXRHOMkaVYM8HH+EvunbhlQPRxLbQ2RPxtYaL8acreB65kmQ6wm0HECrSc0eL9LhEgfD6waNo69qFV98dmUpkxpSJ+IvPLcJ/+oP5PvaM/JI2cFV1r62OEJWL3e8eRWvXTrx+MDG9CGZNm4TW6y/F2voLfewZ+Y01XKIC2bH/CNY9umvUemGfOGcyvnnDZfjCp7iMDWURuCJybjGemFM2UrnYvncA6x7djf5DI0E7Z/pZ+PaNi3HDkjk+9oyCJtNJs6cBhIvxxCLSr6qfK8ZjE9nw4p7DuPcnr2Lv4ZEZveaGnPXCrruc64XRWJlquNfZ6ghRqXjuzQ/xncd2453BkYUZ59Wcje/efAX+46JP+NgzCjrWcImy1PP6+/jLn76O3x0ZCdqFs6bg/puvwLUXz/axZ1QqGLhEGfzbroO4/8nXcfDoycS22tnT8De3XIE/CM/0sWdUarI5afZHlbwQZKplTnh5Zfl77JV38ddPvjFqqfFF5zkr4C6bzxVwKXfZJEYbgIoM3FTLnHQ8H8HPdr+PH39lBUO3DHX1HsCGn/0Wh4+fTmxbPOccbFi9BEsuCPnXMSp52aRF2LtIZKVItczJqeEY9g8MoWNbP77euMin3lGh/d9fv4O/ffpNDJ4YCdor556L9qYlXGqcCiKbwJ0JZ32ybjiLRPZUykKQ6ZY5OTUcw7+8uJ+BWwZ+9OI+/H33W6NWUbjqwhA2Ni3BJeed42PPqNxk+334KIDrzI+KCOCEbx+cBSJTlhxE5FZVfTTfjvrhSBbLnCy7vxvXXjwLd167kF83S0gsFsMPX9iL7z+zB8dODie2L18wAxtXL8HC2dN87B2Vq2wCt1dVrzPL49TBCd1VcBaIbATQagJ4B5zl0bvNPvEryVoAlGTgzpgyMePaUgMnTuOxV97DY6+8h7MmVmHJBSHcVn8Bbl46F9XVeS8ZRwUWi8Xw0LYIHny2H8dPjQTtNeEaPLB6KS6cOcXH3lG5yyZwI0BiIpu9ALYCgIhMB7AcwGbTpt78tJr7++EEcH3Be23JHZ+ej47nI0nLCpOrq7C6bi5ODcew7e3DOPTRKZz8OIbtewexfe8g
7u7ahYWzpuL6K+bgzhULUTONa1D5KRaL4fvP7kHH8xEMnT4DwFlqfMVFM9HetARzQwxaKr6MgauqX06x/SiAHhHpjV+RJiKr4BwBNwD4FICLgMTqISWnZWUtfrb7/ZTLnNx74+LEKIWjQ6fxzy/tw5O7DmLPh8cRU6D/0An84Nk9+MGzezBz6iSWHnwQi8Xwve638Y+/3IvffzwStCsvmY2NTUtw3rln+dtBqiiFGNOUCFRVfQbOUS0AQESaAHQU4Dl8MXVyddbLnEyfMglfW3UJvrbqEsRiMTzx6kH86/YDeOVAFEOnz7D0YNnwcAwPPP1b/OjF/ThpPixFgD9a9Am0N12JWdMYtGRfNhc+XDXeUQmq2iUia3PZR0RCAJrNzeVwTsqlXO031/a5Gs8yJ1VVVbhp6VzctNSZku/1g0fxw1/sTVl6WDBzKv74SpYeCuH0cAwbnnoDD//6HZw2QVslwHWLz8OG1UsQ4vLi5KNsjnDbAeQzq9fLObZfp6pt8Rsi0i8iSBOiuba3bvGc6fje2qsAJC89RA6PLj2suGgW7rx2AZZeyKuZsnXy9DD+6sk38EjvAXx8xvnSVSXADVfOwV9/4QpMZ9BSAGQTuDX5PIGqPpBtW3O06p0OsgPO1W5jAjTX9kGQrPTwyPYD+I2r9PD4zvfw+E6WHrJx8vQw7vvp69i643cYjjlBO0EEf7L0k7j/C5fjnLMm+txDohHZBG6tiByGuegBzoUP+1z3S4H71CAiYVWNmNtRpJ+TN9f2gZGq9PCLtw/jQ5Ye0jp+chj3Pb4bj73yXiJoq6sEt9TNxf+46XJMmcRLril4sv1XWQNgLYA1ACAiUTilgh5kGIWQy+Q3qhoF4P0e3WieJ+/2QZeq9ND/4XGcYekBAPDRyY/xrZ/sxhM7D+KMOv/0Jk4QrFl2Ab7z+cU4qwyDdt+GG/3uAhVILhc+1MG54OE6AMsw+sqzM3AueEh26e+4J78xJYMG87x5txeRZpgTbPPmzRtPl6zxlh7+7dX38f+2v5O29LB22QX4wlXlWXo4OnQa3/zxq3hq9/swB7SYOEFw+9Xz8K0bF2NSGf7OVH6yCdwoAKhqH5xLeR8AAHPlWQOcI8oGjL30N37lWT5f7zcBWGOeO+/25kRaJwDU19eXzPjgqqoqfH7pJ/H5pZ8E4JQe/umX+/D8W4fGlB5at46UHv50xfySH/40ePw07nl0F3re+CARtJOrq3DHNfOx7vpLy/LDhcpXNhc+JB3WZa4822R+vAFch5Erz8YVbCLSCqBDVbMqD+TavpQtnjMdD6xZCiBz6aFm6iRcW4Klh8PHT6K1axeeffMQTOUAZ02swn/59ALcfd0iBi2VpIIVvJIE8HQAtwFYn+tjmQsm+uLhKSIN6YI01/blJFPpYdBdeqiuwpILpmNt/YWBLT18cOwk7u7aiV+8dTjxSX32xAm46z8sxH9vuBhVVcHrM1G2inaGwVz622nCMGsi0gDnJF2PqcnWwDlijodpGECdqnZl076SZCw9DMewfd8RbN93JFF6uP6K8/Glaxf4Xnp4NzqE1i278GL/QCJop0yagC//YS3+/LO1DFoqCzZO6WZbf42f9Oo2N92XBHe5/r8JTtmiK8v2FSub0sODz/Xjwef6fSs9HBgYwje6dmL73sHEtnMmV+O/fvYiNK9cyKClslL0wFXVe3JoG0WGcb2quhHAxmzbkyNopYf+Q8fR2rUTO/ZHR/p4djW+tuoSfOnahQV/PqIgKL9Bi5SRt/Tw24PH8MNf7k1Zepg/cyr+uEClh7c++Ah3d+3EzgNHE9tCUybi642X4D9/ekFej00UdAxcwqVzzh1VevjRS/vxxK73EqWHvZ7Sw4qLZuJLKxbiU/NGlx7SrXC8f+AEWrt2Yfd7xxL310ydhLs/twi3Xx3sMdFEhSKqJTMctaDq6+u1t7c36/YL7nkSQGVd9RMvPfzry++g751oYuLuOHfpoXHxeVjT8dKYuYMn
ThAIBKfPjGybfc5k3HP9IqxedqG134XIFhHZoapJF15Ie4QrIg8BKFZBLaKqf1akx6YCyKX0kIozc5fzoX7euZNx742XJeaOIKo0mUoK7QBCRXruaJEel4okU+khnWmTq/HrbzZY6CVRcKUNXHMxA9EY06dMwldXXYyvrroYsVgM4W8+lbb9idPDae8nqgQc5Eh5q6qqQs2U9PPOzuAE4ERZLbFzbjGe2LWMOpWBTCsc33ENRyIQZTpp9jSKNJm3iPSraj5L91CAZFrhuGVlrY+9IwqGTDXc62x1hEpbLiscE1Uq/hVQwYxnhWOiSsKTZkREljBwiYgsYeASEVnCwCUisoSBS0RkCQOXiMgSBi4RkSUMXCIiSxi4RESWMHCJiCxh4BIRWcLAJSKyhIFLRGQJA5eIyBIGLhGRJQxcIiJLGLhERJZwxYcs7dtwo99dIKISxyNcIiJLGLhERJYwcImILGHgEhFZwsAlIrKEgUtEZAkDl4jIEgYuEZElDFwiIksYuEREljBwiYgsYeASEVnCwCUisoSBS0RkCQOXiMgSBi4RkSWBm4BcREIAms3N5QC6VbUzwz6tACIAagAgU3siIj8ELnABrFPVtvgNEekXkZQhKiLtAF5W1a74bRFpit8mIgqKQJUUzNFt2LO5A0Db2NYJzZ5wfQRAS4G7RkSUt0AFrtEgIu7QjWJsCAMARKQuyeYogIbCd4uIKD+BKimoahTADM/mRgA9KXapATDo2ea9TUQUCIEKXC9TYmgAsCpFk1C6fU2Au7c1w5yQmzdvXkH6SESUrSCWFNw2AVijqn0p7o/CjExw8d5OUNVOVa1X1frZs2cXqItERNkJbOCaoV4dqpqqnAA45YOQZ1sISJQniIgCI5CBKyJNAPriYSsiSU+CmSPfqGdzDVLXfImIfBO4wDXhWgOgV0RCZsRCnev+sAnkuM2e241whpIREQVKoALXnCTrhhOYR8xPP5wrzuKa4Bpnq6otAMIi0mBOivXzogciCqJAjVIwdVfJ0GYjgI1JthERBVqgjnCJiMoZA5eIyBIGLhGRJQxcIiJLGLhERJYwcImILGHgEhFZwsAlIrKEgUtEZAkDl4jIEgYuEZElDFwiIksYuEREljBwiYgsYeASEVnCwCUisoSBS0RkCQOXiMgSBi4RkSUMXCIiSxi4RESWMHCJiCxh4BIRWVLtdweo/OzbcKPfXSAKJB7hEhFZwsAlIrKEgUtEZAkDl4jIEgYuEZElDFwiIksYuEREljBwiYgsYeASEVkiqup3H3whIocA7Pe7HxVsFoDDfneCkuJ7k5/5qjo72R0VG7jkLxHpVdV6v/tBY/G9KR6WFIiILGHgEhFZwsAlv3T63QFKie9NkbCGS0RkCY9wiYgs4QTkVHAiEgKw1tysBQBVbUvTvhnAMgBbzKY1ANpVNVLEbhIAEelQ1ZYMbVoBRADUAICqsuQwTgxcKoZ2AG2qGgUAEdkhIq2qujHNPmsBNAPoA3AXw7b4RKQdQDiLNi+ralf8tog0xW9TblhSoGKoB9Dguh0BsDzdDqo6Q1VFVZepal9Re0cQkbosmzZ7wvURAGmPiCk1Bi4VnAlN9x9pHYBuv/pDSdUjw3uSIpSjGP1hSjlgSYGKytT/ejLV/UwddxCsExadiDQB2AwndNOpgfOeuHlvUw4YuFQUnhNn/Rma9wKIxuu2IrJFRAZZJyw8875EVTUqIpmah9I9TrxGT9ljSYGKQlWjqtppTpQ1isiWNG37PCfJXgawruidrExrVbUny7ZRmG8cLt7blAMGLhWUiIRMGcGtG0BTmn28NcEInLovFZCpyWYbtoBTPgh5toUA5wO1IJ2qMCwpUKHVA2gXkc5s/ihFJAygW0RmeNpzWFjh1QBocJUSlgMImw/ILu9QPFXtE5FoksfIJbTJhYFLBaWqPSLS5gnPRgCJMbgmZOtUtUtVI0na3wZnLC8VkCklJMLSnKgMu8dHu98bs2mzZ9xtI4AO
W30uN5xLgQrO/NHGSwgzAQx4/qhbATSqamOK9v0cpVBcJmzXwPlGsh5ApzmRNuq9MW1b4VyQEgY4giQfDFwiIkt40oyIyBIGLhGRJQxcIiJLGLhERJYwcImILGHgEhFZwsClkiIizWYCFgogvj/pMXArgIiERaRbRI6ISMleJWQG4PcG+Tp+8zr3i8iOcexb8u+TuSiCVwmmwMCtAKoaMVcO9aJEZ3uKT4Yd9NUgzOs8rmkli/k+mUmF+s0VZsXWXqofGMXGwK0spTwhTHuGNdGCJNP8v5mU8vsEMwlONIdlfCoGJ6+hwDMrFHCJnjyYMkytxadcD2cV5sZMDSsJj3CpFLQA4IQpJcS1YnPaVYErDQOXAs2c8Q4H+UQZpZR24vlKxJICAUhM1xeCs6xKCM66VymPKk19zr0OWRjOKg3e5XLy1YAkE16b59sCZ8rAzQDaMLKG2jIAUNUW0zZ+oigEZ9Ltu5IFeK6vgdmnyfQhvk/ak3oi0o6RGu8yAB35ngg0r0UDRhbhDJn+JNaFE5FuOFMxbo6/Lma7mj73YHTtud3bdhz97zGPUyq19+JTVf5UyA+ciaO3JNm+BUCDZ1tTirZNpn2d6/9bze26ZPvk2ed2AK1p7u82v1drku3tcIIo5NreCmBHPq+B5/X0Pm98Sfhkz7EDzuTe8dshOOFVl+Rxs3odzWN0pPh9mpK8Jls8+3Yn2bcVwBH365ZL/z37qO1/50H+YUmhwsXXH1PPwoLqHBmF3cOIzFHtFjhHiH2mzSMAWsztPlVdU+AuhpH+rH0EzpGt90i0G0AzMGb9rT4Ade7B+bm8Bq59muAE9EbPPn1IcpRrjgwj6joaNP3qALApze+XSQOco1mv9Um2eV/HGjjvp7ufYTgfVKO+BeTR/ygvhBjBwKV2OKGZTAdGD2K/DU7JIOra1gMnlLwLQRZKCMkDxS2iY0sEEThHaN5yRDx03ONcc3kN3PukGm87kGRbK5KPtOhBfgtmRgA0e4dgmWDMprTT67m9Bc76Zt7fbbz9H4RZKYJYw61orjPIqf4wIwBCIhJKEmi2hOH80aaTqv/RJNtGPVYer0EYWY63dYVhrTkydquJ90PHUftWZ6HHHgA7RCQC50OgW1V7NENt2Pt88TXOAKwqYP+jKNGLbYqBgVvZ4n9I0QztwnC+Jie+prvUwzm5lOzEVgjO1/016lojy3V/q+u5Q96v55bk+hq4QzrTB4F7XwB4JEUI5jXkTVXXmCC8Dc7702pCeE22H5Tmd+pIsU++/c/2dSp7LClUtvgRSShDuygwsuqriHSYS0Xr4IwOWOXdwdy31jz2mCMcV920U52RAH0pLgeNJtu/gHJ6DYBRR4bZ9qtolyPHw1+dFZDXqOoMOBc41CC3OQ3GlBJcZaJ8+h9C5g+zisHArWCuo5X6FE2SDf3qgBOyDQBqVLUx2VGPOYHWidRf1dfBdWRkwjzZdf6DyByG45bra+ASQZZXbrn2TVrnzvOkUkOS+m0Ezoq8WdXVXaWEuzx3xeevyKf/NeMplZQrBi5thPPHmcxtcMI1rg7OV/+oOaIaU0bIhgnupHXhJCffEstzF1Eur0FcfMhZMssx9kOizTxWMvlOKDPmcU3IRTPt6PoQTTY22f2BMt7+hzL1oZIwcCucqrYBiWFOCeYr/6COHvjfB2BdAS7XTLV/FGP/QF9G5iNJ7z5A6q/7Y7bn+BrE9+kEMOgdMmY+MELe5zH16cF4KcXVvg75T1bT7H1PTD+SlWhCnttJRyWY/RO/w3j6b+4L9Oxutok6g5OpjJmvfJswckTWA6DN/VXP9YcURZqrrEyd1XtEE3E9ZtTTvgnAOlVd5trWAOdMunja9sOZFazTtS0E5yKCWk/bMEZfXZb4nUwfG+AEexeckz1d5ne8DeaKOLN9o+sxs3oNPP2I7xOBE1AR8/jt5jna3N8ETPuZcEY4DMJ1wjGb9ynJ8zeZ54wHrjskO9M87l3m9hY4R/jxERch
OEfoTQA6deyVZin7n6RvzQBq4x9oxMClHJjB7y97TqyE4PyxtwBYa07auPfJJXCPwAmYTs/2buRwxp2CwbxvLazhjmBJgbJigjPk/epp6rl95kioN8k4zWRS/QGGUtyX7KiaAsx8EPOEmQcDl7JVgyyu+MrmgXRkguoxtdxkX09NyHNe1dKyDmNHPVQ8Bi5lxXzNr0t1Ca/ZHk5ySWiqk1fr4TrLb46M09VL27wnbCiYzNFtKNOVbpWINVzKiWcKw7gQnPkM3LXdMJwTL/GTVBsxtv7bipGj4uWZTq6Y9hkvWSV/iUiH92QbORi4VFJM4G/mCbRg4vuTHgOXiMgS1nCJiCxh4BIRWcLAJSKyhIFLRGQJA5eIyBIGLhGRJf8fub11r2fMPYEAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 360x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "x = model_size.log10()\n",
    "y_mean = hessian_fro.mean(dim=1)\n",
    "y_std = hessian_fro.std(dim=1)\n",
    "plt.plot(x,y_mean,'-o', color='C0', markersize=8)\n",
    "plt.errorbar(x,y_mean, yerr=y_std, label=r'$\\|H\\|_F$')\n",
    "\n",
    "\n",
    "plt.xlabel(r'$\\log_{10}$(model size)')\n",
    "plt.ylabel(r'$\\|H\\|_F$')\n",
    "# plt.legend()\n",
    "plt.title('Linear net')\n",
    "plt.savefig('../figs/hessian_fro_linear_net_modelsize.pdf')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e6f602b-de4d-4000-879a-114bad43d670",
   "metadata": {},
   "source": [
    "---\n",
    "# Fully-connected networks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "f37839fb-c392-4134-96c2-b6074f89d7e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 500\n",
    "d = 28*28\n",
    "\n",
    "train_loader, test_loader = data.load_mnist(n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 146,
   "id": "01b9dbc6-a625-4f6d-81d6-6b666e341d58",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1, loss_tr: 7.9e-01\n",
      "5001, loss_tr: 1.6e-04\n",
      "10001, loss_tr: 2.0e-06\n",
      "15001, loss_tr: 2.5e-04\n",
      "train error 0.0042374138647574\n",
      "test error 0.10372681757056852\n",
      "(0, 0)-> took 6.1 seconds, hessian: 2.2e+00\n",
      "1, loss_tr: 1.1e+00\n",
      "5001, loss_tr: 3.8e-04\n",
      "10001, loss_tr: 1.2e-06\n",
      "15001, loss_tr: 5.7e-08\n",
      "train error 1.7918618340218017e-07\n",
      "test error 0.1315695960447192\n",
      "(0, 1)-> took 6.1 seconds, hessian: 2.2e+00\n",
      "1, loss_tr: 1.2e+00\n",
      "5001, loss_tr: 3.2e-06\n",
      "10001, loss_tr: 3.2e-03\n",
      "15001, loss_tr: 3.6e-08\n",
      "train error 0.0060125723062810724\n",
      "test error 0.10401776490441988\n",
      "(0, 2)-> took 6.0 seconds, hessian: 2.2e+00\n",
      "1, loss_tr: 2.6e-01\n",
      "5001, loss_tr: 1.8e-06\n",
      "10001, loss_tr: 7.9e-06\n",
      "15001, loss_tr: 2.6e-05\n",
      "train error 0.003968279832042754\n",
      "test error 0.09879927735804812\n",
      "(0, 3)-> took 6.1 seconds, hessian: 2.2e+00\n",
      "1, loss_tr: 2.2e-01\n",
      "5001, loss_tr: 8.8e-06\n",
      "10001, loss_tr: 7.7e-06\n",
      "15001, loss_tr: 2.6e-05\n",
      "train error 5.576228994641497e-06\n",
      "test error 0.09652764024736825\n",
      "(0, 4)-> took 6.0 seconds, hessian: 2.2e+00\n",
      "1, loss_tr: 3.0e-01\n",
      "5001, loss_tr: 7.9e-07\n",
      "10001, loss_tr: 2.0e-07\n",
      "15001, loss_tr: 2.2e-06\n",
      "train error 2.4241772280220173e-06\n",
      "test error 0.08578911185264587\n",
      "(1, 0)-> took 6.8 seconds, hessian: 2.2e+00\n",
      "1, loss_tr: 4.2e-01\n",
      "5001, loss_tr: 1.7e-04\n",
      "10001, loss_tr: 5.8e-04\n",
      "15001, loss_tr: 2.7e-06\n",
      "train error 2.5186150423905928e-06\n",
      "test error 0.09182244957017247\n",
      "(1, 1)-> took 6.8 seconds, hessian: 2.2e+00\n",
      "1, loss_tr: 5.3e-01\n",
      "5001, loss_tr: 1.5e-05\n",
      "10001, loss_tr: 4.8e-06\n",
      "15001, loss_tr: 1.8e-06\n",
      "train error 4.5756680265185425e-06\n",
      "test error 0.10068266286049038\n",
      "(1, 2)-> took 6.8 seconds, hessian: 2.2e+00\n",
      "1, loss_tr: 2.6e-01\n",
      "5001, loss_tr: 2.9e-06\n",
      "10001, loss_tr: 6.5e-07\n",
      "15001, loss_tr: 9.4e-07\n",
      "train error 1.7654282146395418e-06\n",
      "test error 0.0761227853957098\n",
      "(1, 3)-> took 6.8 seconds, hessian: 2.2e+00\n",
      "1, loss_tr: 3.5e-01\n",
      "5001, loss_tr: 8.5e-05\n",
      "10001, loss_tr: 3.5e-06\n",
      "15001, loss_tr: 4.2e-07\n",
      "train error 1.8034077129414072e-06\n",
      "test error 0.09751759295293595\n",
      "(1, 4)-> took 6.7 seconds, hessian: 2.2e+00\n",
      "1, loss_tr: 4.3e-01\n",
      "5001, loss_tr: 4.3e-05\n",
      "10001, loss_tr: 3.2e-05\n",
      "15001, loss_tr: 4.4e-06\n",
      "train error 3.827696173175355e-06\n",
      "test error 0.083276548760914\n",
      "(2, 0)-> took 8.4 seconds, hessian: 2.2e+00\n",
      "1, loss_tr: 1.2e+00\n",
      "5001, loss_tr: 3.5e-05\n",
      "10001, loss_tr: 8.4e-06\n",
      "15001, loss_tr: 5.4e-07\n",
      "train error 2.3936329171192484e-06\n",
      "test error 0.08005727534007746\n",
      "(2, 1)-> took 8.4 seconds, hessian: 2.2e+00\n",
      "1, loss_tr: 1.6e-04\n",
      "5001, loss_tr: 2.7e-05\n",
      "10001, loss_tr: 1.2e-05\n",
      "15001, loss_tr: 5.6e-06\n",
      "train error 2.6244122636853716e-06\n",
      "test error 0.07952072160434909\n",
      "(2, 2)-> took 8.4 seconds, hessian: 2.2e+00\n",
      "1, loss_tr: 1.1e+00\n",
      "5001, loss_tr: 1.4e-05\n",
      "10001, loss_tr: 4.1e-06\n",
      "15001, loss_tr: 1.2e-06\n",
      "train error 1.469277094656718e-06\n",
      "test error 0.08362035720812855\n",
      "(2, 3)-> took 8.5 seconds, hessian: 2.2e+00\n",
      "1, loss_tr: 2.8e-01\n",
      "5001, loss_tr: 4.6e-06\n",
      "10001, loss_tr: 1.8e-05\n",
      "15001, loss_tr: 2.3e-06\n",
      "train error 2.2104827849034336e-06\n",
      "test error 0.08336481946753338\n",
      "(2, 4)-> took 8.5 seconds, hessian: 2.2e+00\n",
      "1, loss_tr: 3.6e-01\n",
      "5001, loss_tr: 1.7e-05\n",
      "10001, loss_tr: 3.0e-06\n",
      "15001, loss_tr: 4.8e-06\n",
      "train error 4.001610932391486e-06\n",
      "test error 0.0756818820250919\n",
      "(3, 0)-> took 14.0 seconds, hessian: 2.2e+00\n",
      "1, loss_tr: 8.5e-01\n",
      "5001, loss_tr: 9.8e-07\n",
      "10001, loss_tr: 1.1e-05\n",
      "15001, loss_tr: 4.3e-06\n",
      "train error 2.740742775131366e-06\n",
      "test error 0.07877223661782409\n",
      "(3, 1)-> took 13.9 seconds, hessian: 2.2e+00\n",
      "1, loss_tr: 1.4e-02\n",
      "5001, loss_tr: 8.5e-06\n",
      "10001, loss_tr: 1.3e-07\n",
      "15001, loss_tr: 2.2e-06\n",
      "train error 1.811734409784549e-06\n",
      "test error 0.0855066563881337\n",
      "(3, 2)-> took 13.9 seconds, hessian: 2.2e+00\n",
      "1, loss_tr: 7.5e-01\n",
      "5001, loss_tr: 9.1e-06\n",
      "10001, loss_tr: 6.5e-06\n",
      "15001, loss_tr: 4.5e-06\n",
      "train error 4.084530792169971e-06\n",
      "test error 0.07063384105172191\n",
      "(3, 3)-> took 14.2 seconds, hessian: 2.2e+00\n",
      "1, loss_tr: 2.8e-01\n",
      "5001, loss_tr: 2.5e-05\n",
      "10001, loss_tr: 2.4e-06\n",
      "15001, loss_tr: 1.6e-06\n",
      "train error 2.922173553088214e-06\n",
      "test error 0.06944931099737005\n",
      "(3, 4)-> took 14.1 seconds, hessian: 2.2e+00\n",
      "1, loss_tr: 6.4e-01\n",
      "5001, loss_tr: 8.7e-06\n",
      "10001, loss_tr: 1.0e-06\n",
      "15001, loss_tr: 1.4e-06\n",
      "train error 3.182188174832845e-06\n",
      "test error 0.06912924969929009\n",
      "(4, 0)-> took 18.3 seconds, hessian: 2.2e+00\n",
      "1, loss_tr: 3.0e-01\n",
      "5001, loss_tr: 6.7e-06\n",
      "10001, loss_tr: 2.4e-06\n",
      "15001, loss_tr: 8.7e-06\n",
      "train error 1.7927198314282578e-06\n",
      "test error 0.07050567884522024\n",
      "(4, 1)-> took 17.3 seconds, hessian: 2.2e+00\n",
      "1, loss_tr: 1.1e+00\n",
      "5001, loss_tr: 9.2e-06\n",
      "10001, loss_tr: 9.3e-06\n",
      "15001, loss_tr: 1.8e-06\n",
      "train error 2.128914593413356e-06\n",
      "test error 0.0755820055591903\n",
      "(4, 2)-> took 18.6 seconds, hessian: 2.2e+00\n",
      "1, loss_tr: 1.1e+00\n",
      "5001, loss_tr: 2.5e-05\n",
      "10001, loss_tr: 1.9e-06\n",
      "15001, loss_tr: 4.1e-06\n",
      "train error 1.5643743154214462e-06\n",
      "test error 0.06849744356652082\n",
      "(4, 3)-> took 17.9 seconds, hessian: 2.2e+00\n",
      "1, loss_tr: 3.2e-01\n",
      "5001, loss_tr: 5.4e-05\n",
      "10001, loss_tr: 1.5e-06\n",
      "15001, loss_tr: 1.1e-07\n",
      "train error 2.5077759801206413e-06\n",
      "test error 0.07431668295437703\n",
      "(4, 4)-> took 17.3 seconds, hessian: 2.2e+00\n",
      "1, loss_tr: 2.7e-03\n",
      "5001, loss_tr: 1.8e-05\n",
      "10001, loss_tr: 1.2e-05\n",
      "15001, loss_tr: 2.5e-06\n",
      "train error 1.713665642455453e-06\n",
      "test error 0.06625859521707753\n",
      "(5, 0)-> took 33.5 seconds, hessian: 2.2e+00\n",
      "1, loss_tr: 3.1e-01\n",
      "5001, loss_tr: 4.9e-05\n",
      "10001, loss_tr: 4.9e-06\n",
      "15001, loss_tr: 1.5e-06\n",
      "train error 7.320251029341307e-07\n",
      "test error 0.06953134578550817\n",
      "(5, 1)-> took 33.4 seconds, hessian: 2.2e+00\n",
      "1, loss_tr: 6.9e-03\n",
      "5001, loss_tr: 4.1e-06\n",
      "10001, loss_tr: 3.0e-07\n",
      "15001, loss_tr: 2.6e-08\n",
      "train error 5.95947562942456e-07\n",
      "test error 0.06513986085970828\n",
      "(5, 2)-> took 34.2 seconds, hessian: 2.2e+00\n",
      "1, loss_tr: 2.1e-04\n",
      "5001, loss_tr: 2.0e-05\n",
      "10001, loss_tr: 5.1e-06\n",
      "15001, loss_tr: 1.3e-06\n",
      "train error 1.1615659104791121e-06\n",
      "test error 0.06934031586701167\n",
      "(5, 3)-> took 33.3 seconds, hessian: 2.2e+00\n",
      "1, loss_tr: 3.3e-01\n",
      "5001, loss_tr: 9.5e-06\n",
      "10001, loss_tr: 1.3e-07\n",
      "15001, loss_tr: 1.5e-06\n",
      "train error 7.863453902245965e-07\n",
      "test error 0.06501565858281538\n",
      "(5, 4)-> took 34.0 seconds, hessian: 2.2e+00\n",
      "1, loss_tr: 7.0e-04\n",
      "5001, loss_tr: 7.8e-06\n",
      "10001, loss_tr: 2.3e-07\n",
      "15001, loss_tr: 2.3e-07\n",
      "train error 3.5652781207318183e-07\n",
      "test error 0.06719566651940113\n",
      "(6, 0)-> took 70.4 seconds, hessian: 2.2e+00\n",
      "1, loss_tr: 3.5e-01\n",
      "5001, loss_tr: 5.9e-06\n",
      "10001, loss_tr: 5.3e-07\n",
      "15001, loss_tr: 1.6e-06\n",
      "train error 4.107417623799847e-07\n",
      "test error 0.07070904229967709\n",
      "(6, 1)-> took 68.4 seconds, hessian: 2.2e+00\n",
      "1, loss_tr: 5.8e-01\n",
      "5001, loss_tr: 9.0e-06\n",
      "10001, loss_tr: 2.0e-06\n",
      "15001, loss_tr: 3.2e-07\n",
      "train error 2.031745339081681e-07\n",
      "test error 0.06413690404195221\n",
      "(6, 2)-> took 68.5 seconds, hessian: 2.2e+00\n",
      "1, loss_tr: 3.1e-01\n",
      "5001, loss_tr: 9.7e-06\n",
      "10001, loss_tr: 2.6e-06\n",
      "15001, loss_tr: 2.4e-07\n",
      "train error 3.4326398008488467e-07\n",
      "test error 0.0722404130815994\n",
      "(6, 3)-> took 68.4 seconds, hessian: 2.2e+00\n",
      "1, loss_tr: 6.2e-01\n",
      "5001, loss_tr: 4.2e-06\n",
      "10001, loss_tr: 9.1e-08\n",
      "15001, loss_tr: 1.2e-07\n",
      "train error 1.8992479056123557e-07\n",
      "test error 0.06715589981657104\n",
      "(6, 4)-> took 68.4 seconds, hessian: 2.2e+00\n"
     ]
    }
   ],
   "source": [
    "widths = [10, 20, 40, 80, 160]\n",
    "learning_rates = [0.01, 0.005, 0.003, 0.001, 0.0007]\n",
    "ntries = 5\n",
    "\n",
    "model_size = torch.zeros(len(widths))\n",
    "hessian_fro = torch.zeros(len(widths), ntries)\n",
    "mu = torch.zeros(len(widths), ntries)\n",
    "bounds = torch.zeros(len(widths), ntries)\n",
    "alphas = torch.zeros(len(widths), ntries)\n",
    "\n",
    "for i, m in enumerate(widths):\n",
    "    lr = 0.2\n",
    "    for j in range(ntries):\n",
    "        time_st = time.time()\n",
    "        \n",
    "        net = models.build_linear_net(d, m)\n",
    "        train_net(net, train_loader, nsteps=5000, lr=lr, display=40000)\n",
    "        H_i, mu_i, alpha_i, bound_i = compute_geoinfo(net, train_loader, batch_size, lr)\n",
    "        hessian_fro[i,j], mu[i,j], alphas[i,j], bounds[i,j] = H_i, mu_i, alpha_i, bound_i\n",
    "        \n",
    "        print('({:}, {:})-> took {:.1f} seconds, hessian: {:.1e}'.format(i,j, time.time()-time_st, H_i))\n",
    "    model_size[i] = utils.num_para(net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 147,
   "id": "c16285cc-41bf-4bb3-a6c3-8a372307f39e",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = {'model_size':model_size, \n",
    "        'hessian_fro': hessian_fro, \n",
    "        'mu': mu, \n",
    "        'batch_size': 3, \n",
    "       'learning_rate':0.2}\n",
    "torch.save(data, 'fcn-para.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 170,
   "id": "2a024abd-504d-4bc9-af4a-4d7bdfc561d7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZQAAAFkCAYAAAAHV825AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAy+UlEQVR4nO3deXhT550v8O/PLGY1wthgCIuRAyGQhMSYsGXHTrqvJkmbZiENdtKZp23aGQjTmTvTuXduapK2k7mdmdi0IUmzDIGkTdKmCyYLCRDAdkI2lmCxhdWAhc0O9u/+cV7Jx7JkS/aRjmR/P8+jxzqLjl4dyfrqnHc5oqogIiLqqjS3C0BERN0DA4WIiBzBQCEiIkcwUIiIyBEMFCIicgQDhYiIHMFAIYqCiBSLyCoRqQ9z0zA3b4TteERkodlWrW39WjNvoYjkm3VLbM+pIrIqinKWi0i1rWyrRKTY6f1BFI6wHwpR9MyX8woAflUdGrLMCyAfwApVlZBlHgBlAEoArASwHECNqvrMMi+A2wAsBAD740Wk3DwOAJao6qIoylkNYLmqLunEyyTqlN5uF4AoxfgjLVBVHwCfiPjs803QVJvJIlWtDHmcH0ANgBoTHqFHIrXmeT0AForIJlVd2UE5fe2VlSgeeMqLKDbHolgneARhjj6qYYXBvNAwCWVCyR9m0QsAKsz9FYHTYg6UlcgxDBQih4UcPSyFFSYrOwoTmwURtlsK60gGsELF09kyEsUDA4XIQSKywnbfAyBQIV4e7TZUtaadxXNhHcF4YdXlECUNBgqRQ0yA2E9FFdru++AAU98yN7B9ESlzYrtETmCgEHWOJ7SpMIB6AJm2dYL3Td2II8wRTKmZXMhmwZQsGChEneNXVbHfAOShdUW4J15PrqoVsJofA9FX0hPFFQOFyCGBZsO2WcH7kTo6dvH55qGlkn41K+nJbQwUImeV2u7bW3XF67RUoJLeA2B1nJ6DKCoMFCIH2etKTAV6IFTK4nSU4kdLJX2+6RhJ5AoGClF82Y9Y4tLMN6SSvoSV9OQWBgpRbDI7XqWFOWKZZybzzSCQ7Vagi0hhrPUhoZX0aN18mSghOJYXUWw8IX87pKorRSQP1uCQxQCqRaQS1hd/4BSZB8B0s9ynqkW2TeTB6sjY0fPMM4NC5kezPpHTONowURTMaaRStO2s6IM1Rpc/yu3k27aTCStI/GY7VQDKAz3lRaTErBs42qgBUNneaMPmyGan2W6pOXIhSggGClE3ZILL72SHSqKOMFCIiMgRrJQnIiJHMFCIiMgRDBQiInIEA4WIiBzRY/qhZGVlaW5urtvFICJKKdXV1UdUNTuadXtMoOTm5qKqqsrtYhARpRQR2R3tujzlRUREjmCgEBGRIxgoRETkCAYKERE5goFCRESOYKAQEZEjGChEROQIBgoRETmCgUJERI5goBARkSMYKERE5IgeM5YXkdt+uWp78P6DRRNdLAlRfDBQiBLksdWfBu8zUKg74ikvIiJyBAOFiIgcwUAhIiJHMFCIiMgRSRcoIlIuIvXmVhZhnWIRKTO3FSLiSXAxiYgoRFK18hKRYgC1AMYDKASwQkQ2qepK2zpeAItVdZqZXghgMYBFLhSZiIiMpAoUAH5VXWLurxSRGgDekHW8ADz2xwDwxb9oRETUnqQ65aWqlSGzPABqwq0jIrUiUgLAYz+CsROREhGpEpGqurq6eBSZiIiMpAoUOxHJB+ALEzIAUArrqKQMQKk5DdaGqlaoaoGqFmRnZ8extERElJSBYirZF6tqUZhl+QBKVbVIVYfCOoIpT3ARiYgoRFIGCqwjjwURlt0GYLltehHa1rMQEVGCJV2giEg5rECBiHgCp7NMCzDAagVmP3LxAghbh0JERImTVK28TJiUmFurRQCWiohPVStMX5WFsFp4eVSVTYaJiFyWVIGiqqWwKtzDLRsash4RESWRpDvlRUREqYmBQkREjkiq
U17kPF4lkIgShYHSzfEqgUSUKDzlRUREjmCgEBGRIxgoRETkCAYKERE5goFCRESOYKAQEZEjGChEROQIBgoRETmCgUJERI5goBARkSMYKERE5AgGChEROYKBQkREjmCgEBGRIxgoRETkCAYKERE5goFCRESO4BUbiajb4aWv3cFAIaJuh5e+dgdPeRERkSN4hELUQ/A0EMUbA4Woh+BpIIo3nvIiIiJHMFCIiMgRDBQiInIE61AoqbEimSh1MFAoqbEimSh18JQXERE5IukCRUTKRaTe3Mo6WLdERBaKiCdBxSMiogiS6pSXiBQDqAUwHkAhgBUisklVV4ZZdwWAh1W1JsHFJCKiMJIqUAD4VXWJub9SRGoAeENXYpgQESWfpDrlpaqVIbM8AFqFhogUAsgH4BWRVeYUmScxJSQiokiSKlDsRCQfgC9MyBQByATgBzAP1hHMigjbKBGRKhGpqquri2dxiYh6vKQMFHPEsVhVi8Is9gCoVNVKVfUDKINV39KGqlaoaoGqFmRnZ8eruEREhCQNFFghsSDCstqQaR+soxUiInJR0gWKiJTDChSIiEdEvOZ+sVllJaw6lIBCABUJLSQREbWRVK28TJiUmFurRQCWiohPVWtEZJFZtxqAR1UXJbqsRETUWlIFiqqWAiiNsGyo7f5KWEcqRESUJJLulBcREaUmBgoRETmCgUJERI5goBARkSMYKERE5AgGChEROYKBQkREjmCgEBGRIxgoRETkCAYKERE5goFCRESOYKAQEZEjGChEROQIBgoRETmCgUJERI5goBARkSMYKERE5AgGChEROYKBQkREjmCgEBGRI3q7XQCinuBCU3Or6WVrd+KSnMG4NCcDQwf2dalUqelCUzOOnDiHw41ncKjhbPBvnW3absfhRlw8fLBLpe1ZGChEcXbuQjN+8D/vtZr301c/Cd4fkZGOSTkZmDTSCphJIwfDmzUIfXv3rBMI55uaceTEWSsUGs7gcGPL30PBv2dx9ORZqEa/3cJfrEHhpcOx4Fovrh6fCRGJ34vo4RgoRHF05nwTHnimGm9sq4u4zqGGszjUUIe3tres06eXIC97EC4dmYFJOYMxaWQGLs0ZjOzB6Sn3hXi+qRl1tlAIFxJ1jWdw9OS5mIIiFpVbDqNyy2FMHeNB6XVe3DIlB73SUms/pgIGClGcnDp3AQuersLaHUfbLJs3bTS2HmzEtkONOHehuc3y802KrQcbsfVgY6v5Qwf0aXM0M3HEYPTr0yturyOScxeaUXfCBIM51XS4wR4UZ1DXeBZHT55z9HlFgGED+yJ7cD+MyEjH8MHpGJHRD8MHp2O4+fv1/1oX9rGb9/rxvWdrMG7YANx3zXgUTxuD/n0Tv++6KwYKURw0nDmPe5dtQtXu+rDLH5k3FYBVH7Dr6ClsPdiArQcasfVgA7YcaMQ+/+mwj6s/dR7rfUex3tcSUmkC5GYNtALGHM1MyhmM0UP7d+po5uyFJhMQbY8m7POOxSUo0iOGxIiMfhiekY6sQeno0yv604GVP7oev37bh5dq9uGcqcvaffQU/unlj/HLyk9x58xxuGvWOAwblO7o60kWv1y1PXj/waKJcX0uBgqRw/ynzuGuJzbig8+OB+f9/S2X4JG/bGuzbu9eabh4+CBcPHwQvnRFy/yGM+ex/WAjthxsxNYDDdbRyoEGnDzX1GYbzQr46k7CV3cSf/zwQHD+oPTeuCRncDBk7F778AAON5zBocazrY8uGs/Af+q8A3uhRZoAwwYFgsI6qmg5umj5mzWoL3rHEBTRunj4IPzsm1fgRzdPxFPrduG363ej4cwFAMCxk+fw2OpP8fhbtZhXMBr3XeNFbtZAx8vgpsdWfxq8z0AhSiF1jWdx5282tDpV9U9fmozvXjM+bKBEktGvDwpyM1GQmxmc19ys2Oc/jS2BgDlo/d155GTYuocTZy+genc9qsMcJX3v2ZrYXlgYaQJkDw4fEvYjimED4xMUsRo+
uB/+/pZJ+N4NF2P5pr34zTs7g0eCZy8045l39+DZDXvwuSk5WHCdF/ljh7pc4tTDQCFyyIHjp3HHrzfAV3cSgHUK59++djm+PWOsI9tPSxOMyRyAMZkDcPOUnOD80+ea8OnhRmw90IgttlNn9Z080uiVJsgelI7hgXDISMeIwF/bvGED01OyYntgem/ce8143DVrHF776CAq1tTio30NAABV4E8fHcSfPjqI6blDUXJdHuZOGo60FHydbmCgEDlg77FT+Pav38XeY9Yv3jQBHp03Fd/IHx335+7ftxeuGO3BFaM9wXmqisONZ4OnyrYebMTv3tsXXF40eUSreooRGf2Qbf5mDuybkkERq9690vCVqaPw5StGYn3tUZSv8bVqabdpVz027apCXvZALLjWi69ddZErjR9SCQOFqIt8dSdwx6834MBxq0Nd7zTBY7dfhS9eMdK1MokIRmT0w4iMfrh+YjYAtAqUpXcVuFW0pCMimH1xFmZfnIWtBxtQscaHV97fjwvN1nnE2rqTeOilD/HoX7dj/pxcfGfGOAwZ0MflUicn909sEqWwbQcbcWv5u8Ew6ds7DeV3TnM1TKjzJuVk4Be3Xom3F92Ikuu8GJTe8pv7yImzeOQv2zDrZ6vx01c/xmf1p1wsaXJioBB10kf7juP2ivU4cuIsAKB/n1544u7pmHvpCJdLRl01ckh//MMXLsW6xTdh8ecnISejX3DZqXNNWLZ2F65/5E18//n38NG+4+1sqWfhKS+iTqjeXY97lm1Eo2l+Oii9N5bNn47ptlZZlPoy+vVB6fV5mD9nPF7ZvB9L1/iw7ZDVgq+pWfHK5v14ZfN+zLl4GBZc68X1E7NTbiQDJyXdEYqIlItIvbmVdbBusYisSFTZiABgXe0R3PmbDcEwGdK/D569bwbDpBvr2zsNxdNG488/vBZPzp+O2XnDWi1fu+Mo7lm2CZ9/7G28WP1Z2NEPeoKkChQRKQZQC2A8gAUAFpp54db1AFgMwJOo8hG9ue0w5i/bhFOmg+GwgX3x/IKZmDrG427BKCFEBDdcMhzPLZiJV//2Gnx56ijYG8RtPdiIH6/YjOuWvIGKNbVoPONsJ9Fkl1SBAsCvqktU1a+qKwHUAPBGWLcEQDkAf6IKRz3bXz4+iAVPV+Gs+fU5IiMdy0tnYvKojA4eSd3R5aOH4P996yq89fc34p7Zuehva1J8sOEM/u9rWzH74dfx8GtbcPD4mXa21H0kVaCoamXILA+sUGlFRPLDzQ+zXomIVIlIVV1d5NFeiTry8vv78L1na3C+yWpKepGnP14oncXrbBDGZA7Av3xlCtYvvgl/d/NEZA1qub5N49kLKF/jw7VLXsePX9iMbSGDfXY3SRUodiY0fGFCBgAKI8xvRVUrVLVAVQuys7OdLyT1CC9s2osfLn8fTaZfQu6wAXjh/lkYN6x7jflEXeMZ0Bd/e9MEvLPoJjz8jcvhzW75fJxvUrxY8xlu+fc1uPuJjVi34wg0XmP1uygpW3kF6kdUtSjMsmIAKxNeKOqRnlq3C//8ysfB6QnDB+HZ+2ZguK0ZKZFdvz698K2rx+K2gjGo3HIIFWt8rUadfmu7de2byy7KQMl1efjCZTlJMdaZE5IyUACUwaqUD6cUgNc0zcsE4BGRWlXNS1ThqGd4/K1a/OxPW4PTk0dm4Jn7ZiCTl+ylKKSlCW6ekoObp+Sgenc9lq7x4S+fHAwO5PnRvgZ8//n3sGRof3z3mvG4tWAMBqYn61dydByJRRH5hohcKSJdrp0UkXJYgQIR8YiI19wvBgBVLVLVPBMgiwCsZJiQk1QVv1y1vVWYXDnGg+cXzGSYUKdMGzcUj985Da//+AbcMWMs0m2Xd/6s/jR++uonmP2z1/HoX7ahrvGsiyXtGqeOs1YCKETkFllRMWFSAqvpcL251ZrFS029SmDdElhHK4XmPlGXqSp+9qetra4hMWN8Jp65bwbHb6IuG581EP/29cux9qGb
8P25EzDU9pk6fvo8fvXGDswpex2LX/oAtXUnXCxp5zgVKDWq+qiqvh+YISJDRGRTLBtR1VJVldCbWTZUVWts61ao6jQzv8Kh10E9WHOz4p9f+Rjla3zBeddOyMKT869uNaYTUVdlDUrHj4omYt1Dc/GvX52CsZkDgsvOXWjG8xv3Yu7P38J9T1WhatcxF0saG6f+S9pcNFtVj0tPHoOAUkpTs+KhFz/AiurPgvOKJo/Ar759FdJ7c8hyio/+fXvhrlm5uGPGOPzZXJtls+1Kn5VbDqFyyyHkj/Wg5DoviibnJPWlBeL9syt1opV6rPNNzfjRC5vx6ub9wXlfnjoKv7h1akzXLifqrF5pgi9eMRJfuDwHG3ceQ8UaH1ZvPRxcXrPHj/ufqcH4rIH47jXjUTxtdFJem8Wp/5Y8EfmxiNwUUjHfpqG1iNzk0HMSddnZC0343rM1rcJk3rTR+PfbrmSYUMKJCGZ4h+E390zHqgevw60Fo9HX9jnceeQk/vH3H2HOz17HY5Wf4tjJcy6Wti2n/mMyAfwEQCWAehFpMvUnXhH5bkjQtOlbQvHRHTtOOen0uSYseLoaqz45FJx358xxKPvmFUl9WoF6hgkjBmNJ8VS8s+hGPHBDHgb3azmhdPTkOfyycjtm/2w1/tfLH2HP0eS4NotTgVKhqpmqmgZgOoAHAFQD2AngEdiCBsBCh56TOvDjFZtbTZ853+RSSZLPibMXMP/JjVhju+RryXVe/OtXp/D64ZRUhmf0w6LPTcL6xXPxj1+8FKOGtHSqPXO+GU+v340bHn0Df/NsDTbv9btXUDhXh7I8cMe0xGo1zpaIDAFQAGAarL4jFGcf7z+Ol2r2tZp37ZI3rEuYzhyHjH49twns8dPncc+yjXhvjz847wdzJ+CHhRN69LUsKLkNSu+N+6714u7ZufjjBwdQvsaHLQcaAADNCvzxwwP444cHcPX4TNx/vRc3TBye8B9HjgSKqr7XwfLjAFYDWC0ivGhEAjy5dlebeXWNZ7Hkz9vwX2/U4o4ZY3HvNeMxoocNIXLs5Dnc+ZsN+Hh/Q3DeQ5+fhPuvZ99YSg19eqXha1ddhK9eOQrv7DiCijU+vP3pkeDyjTuPYePOY5gwfBAWXNulroExS3jjelV9KNHP2dMcPXEWL9sqmUOdMCOgPrF2J75+1UUouS4PFw8flMASuuNwwxl85zcbsP1QS4exn35lCu6eneteoYg6SURw7YRsXDshG5/sb8DSt314dfN+XDCDmH56+AQWvvhBQsvEZizd0HMb9oS9YtyS4iuQFzIC6gtVn6HwF29hwdNVqLYNYNfd7Pefxm0V7wbDRARY8s0rGCbULUwelYFf3nYl1iy8EfddMx4D+7rTpJiB0s2cu9CMp9/dHXbZrQVjsOrB67H0rgJMGze01bJVnxzCN/97HW59fD1WbzmE5ubu00Js99GTmPf4euw8chKA1eb/32+7ErdOH+NyyYicNcrTH//4pclYt3guFn1uEoYPTk/o8zNQupnXPjwQHFwu3IcpLU1QNHkEXnxgNlbePwuFlw5vtXzjrmP47lNV+Nxja7CyG1wbe8fhE7i1fD32+U8DAPr0Evznt/Px1SsvcrlkRPEzpH8fPHBDHt5edGNCn7fdOhQReRzW9d3jwaeqD8Rp2z2SqmLZ2p3B6TtnjsPPV22PuH5BbiZ+nZuJTw81onyNDy+/vy94RcLth07g71Zsxs//ug3fvWY8br96bMqNZ7XlQAO+8+sNOGo6f6X3TsPjd07DjZcM7+CRRN1DoocN6ugbogzWZXjjwR+n7fZYNXv8wXGA+vZOw7dnjG03UAImjBiMR+dNxY9vnogn3tmJ5zbswclzVp+VA8fP4P/8cQv+Y/WnuGtWLu6Zk4usQYk9jO6MzXv9uOuJjTh++jwAYEDfXvj13QWYnZflcsmIuq92A0VVd7a3nJKL/ejka1eOwrAYv/hHDumPn3xxMv72xgl4ZsNuLFu7E0dO
WL/uG85cwK/e2IGlb/tQPG00Sq7zJu0lcDftOob5yzbhxNkLAIDB6b3x5L3TMW0cW6wTxRPrULqJ/f7T+NNHB4PT8+d0/kzlkAF98Dc3Xox3Ft2Ef/v6Zcgd1jK09tkLzXh2wx7c+Oib+Jtna/ChbWTUZPDOp0dw1282BsPEM6APnlswk2FClAAdnhR34iqM4ahqQ8drUbR+++5uNJmWWTO9mbh0ZNfftn59euGOGeNw+/Sx+MvHB/H4W7X4wASIvWfunIuH4f7r83DNxVmu9jR/fesh3P9MTbAhQdagdDx73wxckjPYtTIR9SQdVcr/FV28CmM7265V1Vvise2e5vS5Jjy/cU9w+t4uHJ2E0ytN8IXLR+Lzl+Vgve8oHn/L12oMrLU7jmLtjqOYMioDpdfn4QuX5aB3gkfqfe3DA/j+8+8FO3XlZPTDswtmIC+7+3fYJEoWHdWh3JyoglDn/f79ffCfsiqfx2T2x9xLR8TleUQEs/OyMDsvCx/vP46KNT784YMDwSOjj/c34PvPv4dHMvtjwbVezJs2Bv0T0MHqd+99hh+/sBmBrjNjMvvjuftmYoztKnhEFH+sQ0lxoU2F756Vm5Ch16eMGoLHbr8Kb/7dDbh71jj069PyUdp77DT+18sfY06Zdc2G+jhes+G5DXvwI1uYeLMG4oXSWQwTIhcwUFLc2h1Hg8OJDOzbK+G9v8dkDsBPv3oZ1j00Fz+YOwGeAS2jGB8LXrPhdfzLKx/js3pnr9nwxDs78Q+/+xCBy75MyhmM5aWzMHJIf0efh4iiE1WgiEhunMtBnWQ/OimeNtq1YekzB/bFg0UTse6hm/AvX56MizwtX+qnzzfhyXW7cP0jb+LB5e8Hh9zuiv98Ywf+9Q+fBKevGD0Ezy+YiewEDzVBRC2i7fpcDoAV6Elm15GTeH1by3Wn73G4Mr4zBvTtjXvmjMcdM8fhtQ8P4L/frMXWg40AgKZmxe/e24ffvbcPN1ySjfuvz8OM8ZkxtQxTVfz8r9vxqzd2BOdNGzcUy+ZP79HXeKHWfjB3gttF6JGiDZTpInIjgOpYm/uKSAabCMfHk+t2BU/33DRpOMZnJU9Hwz690vDVKy/CV6aOwlvb61D+lg/rfUeDy9/cVoc3t9Vh6hgPHrjei6LJOR3W/agq/vcftuAJ21HZ7LxhWHpXAQam2LAwFF8PFk10uwg9Uiz/hZUAICJ+AFWwrsq4CdaYXO+387ilAG7rZPkogoYz57Giam9wev6cXPcK0w4RwQ2XDMcNlwzH5r1+lK+pxZ8+OhgMws17/bj/mRp4swZiwXVefP2qi9CvT/iWYT/5/Ud4bkNL8+gbL8nGf39nWsT1iSixYgmUF2H1SfECKDI3BRA4ZeGDLWQA1KjqLsSpH0tPt6Lqs+B4WxOGD8I1Fyf/GFVTx3jwX3dMw84jJ1GxxocXa1pGM/YdOYnFL32IX6zaHvEyxfYw+dyUHPzHt65C395sV0KULKINlEpVvTUwYbtGfB6Ax2EFSb6ZnofWQUMOa2pWPLVuV3D6njm5KbWvx2cNxMPfuBwPFk3Ak2t34bfv7kbjGWuolNDLFIfztStH4dF5UxPeeZKI2hftf+Ry+4SqHlfV1apaYU1qgaqmoSVQHoJ1DXnWncTB61sPY88xqwnukP598I2rRrtcos4ZPrgfFn5uEtY9dBN+8oVLkWO7vn3gMsWhbp8+Bj+/9UqGCVESiuoIRVVfjHK9nQACNaaPAICIVHWuaBSJvanwt64em5De6PE0uF8fLLjOi7tn5+Ll9/ehfI0POw6faLPePbNz8c9fnpxSR2NEPUkifua1/ZlJnbblQAPW1VqtpXqlCe6aNc7lEjmnb+80zCsYg7/+8DosvasABSGXKWaYECW3RATKggQ8R4/x5Npdwfufm5KDUZ7u1ys8cJnilQ/MbjWfYUKU3KI65dWVviSqmlwXzEhhx06ew+/f3xecvveaXPcKQymHnf0o3qJt5cW+JEng
+Y17cNY0s71i9BDkjx3awSOIWrCzH8VbtKe88kUkIVcpEpFyEak3t7LOrtPdnG9qxtPrdwWn56dYU2Ei6v6iPUIZBsAvIvbOi9UwLbpEZLCqNoZ7oIg8rKqLo3kSESkGUAtgPIBCACtEZJOqroxlne7otQ8P4FDDWQBA9uB0fPHyUS6XiIiotVh6yu+E1c+kVedFo0ZEamC16NqEll7ygPWlH1WgAPCr6hJzf6XZZmhP+2jW6XaW2SrjvzNjHHuIE1HSiTZQqgJXb7T1kg/0jLffB1r3kvcD8ERbGFWtDJnlgXVEFNM63c17e+rx/l4/AKBvrzTcMTN8D3IiIjdFGygrAndMq63V5hYkIuNhHSnkA5hu/nrR+kgmaiKSD2vgydAAiXodESkBUAIAY8em7pew/ejkK1eOQtYgXvODiJJPtD3ll0axTqCXfDBoRMQDa2TimJjHLVbVoq6sY4aGqQCAgoKCTgWb2w4eP4PXPjwQnE7WUYWJKDklsrl4XC8ioap+U8cRqzJ03CEymnVS3m/f3YUL5oLpV4/PxJRRQ1wuERGlkkQ2F497za59lOJoiEg5rLCAiHhExGvuF3e0Tndz5nxTqyHb7+XRCRElsaS6zJ0JimC9h30RgKWm2XJpO+t0Ky+/vw/1p84DAEYP7Y+iyTkul4iIKLKkanuqqqWqKqE3s2yoqta0t053oqqtKuPvnpXb4SVyiYjclFSBQi3W1x7F1oNWX9EBfXvh1uljXC4REVH7GChJ6gnb0ck380djSP8+kVcmIkoC7dahiMjjsIY4iQefqj4Qp22ntN1HT2L11kPB6XtYGU9EKaCjSvkyxNDTPUb+OG035T21bjfU9Jq54ZJs5GUPcrdARERRaDdQTGdFSqDGM+fxQtXe4PT8OfE6QCQichbrUJLMyurPcOLsBQBAXvZAXDchy+USERFFp8N+KCKSEY8n7uwVILuz5mbFU+t2BafvmTOe1zwhopTRUaX8XxGnoeFFpFZVb4nHtlPVG9sOY9fRUwCAjH698c38i1wuERFR9DqqQ7k5UQWh1qMKf+vqsRjQN6kGMiAiahfrUJLEtoONeGfHEQBAmgB3zhrncomIiGLDQEkST65raVB3y5QcjB46wMXSEBHFjoGSBOpPnsNLNfuC02wqTESpiIGSBJ7ftAdnLzQDAC67KAPTc4e6XCIiotgxUFx2vqkZv12/Ozg9fzabChNRamKguOzPHx3EgeNnAABZg9LxpakjXS4REVHnMFBctmxtS2X8HTPGIr13LxdLQ0TUeQwUF23e60fNHj8AoG+vNNwxc6y7BSIi6gL2nHOR/ejkS1NHYvjgfi6WhuLtB3MnuF0EorhioLjkUMMZ/OGDA8Hpe9lUuNt7sGii20Ugiiue8nLJM+/uxoVm66In03OH4rKLhrhcIiKirmGguODM+SY8t2FPcJodGYmoO2CguOCVzftx9OQ5AMBFnv64efIIl0tERNR1DJQEU9VWowrfNWscevfi20BEqY/fZAn2ru8Ythywri3Wv08v3D6dTYWJqHtgoCSYvanwN/IvwpABfVwsDRGRcxgoCbT32Cms2nIoOD1/Tq57hSEichgDJYGeWrcLarUUxnUTs3Hx8MHuFoiIyEEMlAQ5cfYClm/aG5zm0QkRdTcMlAR5sfozNJ69AADwZg3E9ROyXS4REZGzGCgJ0NyseHLdruD0PXNykZbGa54QUffCQEmAt7bXYeeRkwCAwf1645v5o10uERGR8xgoCfCEranw7dPHYGA6x+Qkou6HgRJnnx5qxNufHgEApAlw16xcdwtERBQnSRcoIlIuIvXmVhZhnXwRWSgiZSKyItFljMUyW91J0eQRGJM5wL3CEBHFUVIFiogUA6gFMB7AAgALzbxQS1V1iaouAuCLFDxu8586h5dqPgtOc1RhIurOkipQAPhNUPhVdSWAGgBe+woiUgKg0jZrFYDCBJYxav+zaS/OnG8GAEwemYEZ4zNdLhERUfwk
VaCoamXILA+sULHLA3DUNu1DSOgEiEiJiFSJSFVdXZ1j5YzGhaZmPG073TV/Ti5E2FSYiLqvpAoUOxHJB+CLEDJRUdUKVS1Q1YLs7MR2JPzLx4ew//gZAMCwgX3x5amjEvr8RESJlpSBIiIeAItVtSjMYj+AYSHzjsW7TLGyjyp8x4yx6Nenl4ulISKKv6QMFABlsCrlwzmK1qe4PGh7WsxVH352HFW76wEAfXoJvjNznMslIiKKv6TrYSci5bACJXCkkqmqPhEpNhX1KwEstj2kEMDyeJXnl6u2B+8/WDQxqsfYj06+dMUoDM/o53i5iIiSTVIFigmTEnNrtQjAUhHxqWqNiCwy61YDgAmauHhs9afB+9EEyuGGM3j1g/3BaY4qTEQ9RVIFiqqWAiiNsGyo7X5FwgoVo2c27MH5JuuiJ9PGDcUVoz3uFoiIKEGStQ4lJZ290ITnNuwOTvPohIh6EgaKg17dfABHTpwDAIwc0g+3TMlxuURERInDQHGIqraqjL9rVi769OLuJaKeg994Dtm48xg+3t8AAOjXJw3funqMyyUiIkosBopDlq3dFbz/9atGwzOgr3uFISJyAQPFAXuPncJfPzkYnGZlPBH1RAwUB/z23d1otloK49oJWZg4YrC7BSIicgEDpYtOnr2A/9m4JzjNoxMi6qkYKF30Us1naDhzAQAwPmsgbpg43OUSERG5I6l6yqea5mZtdYnfu2eNQ1pacl3z5AdzJ7hdBCLqIRgoXbDm0zr46k4CAAan90ZxQfI1FY52QEsioq7iKa8usDcVvnX6GAxKZz4TUc/FQOmkHYdP4K3t1mWFRYC7Z+W6WyAiIpcxUDrpyXUtw6wUXjoCY4cNcLE0RETuY6B0wvFT5/Fi9b7gNJsKExExUDpledUenD7fBACYlDMYs7yhl7gnIup5GCgxutDUjKfWtVzz5N454yGSXE2FiYjcwECJ0apPDmGf/zQAIHNgX3zlylEul4iIKDkwUGJkbyr87avHol+fXu4VhogoiTBQYvDRvuPYuOsYAKB3muDOWeNcLhERUfJgoMTAfnTyxStGYkRGP/cKQ0SUZBgoMXh18/7g/flzxrtYEiKi5MNAicG5pmYAwFVjPbhyjMfdwhARJRkGSifw6ISIqC0GSoxyMvrh85fluF0MIqKkw0CJ0Z2zxqFPL+42IqJQ/GaMQXrvNHz76rFuF4OIKCnxAh4x+PpVF2HowL5uF6NH4RUniVIHA6UdgSFWAu7hqMIJxytOEqUOnvJqx+GGM62mJ+VkuFQSIqLkx0Bpx1Vjh7pdBCKilMFAISIiRyRdHYqIeADcCmCeqhZFWKcYwHQz6QWwQFX9CSkgERGFlVRHKCZMysykN8I6XgCLVXWRqi4CsAnA4sSUkIiIIkmqIxRzlFEqIvntrOYF4LFN+wH44lcqIiKKRlIdoURDVSsBQERqRaQEgEdVV7pcLCKiHi/lAsUohXVUUgbriCbS6bESEakSkaq6urqEFpCIqKdJuUAxp8NKVbVIVYcCqAFQHm5dVa1Q1QJVLcjOzk5oOYmIepqUCxQAtwFYbptehAgV+ERElDgpEyimqTAA1AKwNyf2AmAdChGRy5IuUESkEFYzYK+ILLTVjywVkXxVrTDrLTSV8vmm+TAREbkoqZoNA8FWXJVh5g+13S9NaKGIiKhDSXeEQkREqYmBQkREjmCgEBGRIxgoRETkCAYKERE5goFCRESOYKAQEZEjGChEROQIBgoRETmCgUJERI5goBARkSMYKERE5AgGChEROYKBQkREjmCgEBGRIxgoRETkCAYKERE5goFCRESOYKAQEZEjku6a8snmB3MnuF0EIqKUwEDpwINFE90uAhFRSuApLyIicgQDhYiIHMFAISIiRzBQiIjIEQwUIiJyBAOFiIgcwUAhIiJHMFCIiMgRDBQiInIEA4WIiBwhqup2GRJCROoA7E7w02YBOJLg50xl3F/R476KDfdX9EL31ThVzY7m
gT0mUNwgIlWqWuB2OVIF91f0uK9iw/0Vva7sK57yIiIiRzBQiIjIEQyU+KpwuwAphvsretxXseH+il6n9xXrUIiIyBE8QiEiIkcwUCjpiIjHft8+3ZNwP8Qmmv3FfRpfDBSHiEixiKyIsCxfRBaKSFmkdXqa9vYXgHoRURFRAPUAyhJYtGTS4X7gZ6uVaD43/GzZiEiJ+fx4wiyL+bPFOhQHmDdjNYBjqloUZnm1qk4z98sAQFUXJbSQSSSK/VWuqqUJL1iSiWY/8LPVIsr9xc+WYULiYVWtibA85s8Wj1CcUQKgHIA/dIGIlACotM1aBaAwMcVKWhH3lxFpfk/jb28hP1tt+B1ap9uLIkw69dlioHSRiOQDCPumGHkAjtqmfQC8cS1UEotifwFAvohUm1MTq3rwee6O9gM/W61F87np8Z8tESkEkA/Aa/ZBuVOfLQZK1xWqamU7yz2JKkiK6Gh/AcAKc6g9FEAmgKXxL1ZS6mg/eBJeouQWzeeGny2gCNZr9wOYBysoQutIPJ3ZMAOlC0SkGMDKDlbzAxgWMu9YXAqU5KLcX1DVCvPXD+BhWL+mepwo9oMf/GwFRfO54WcLgBUWlapaafZDGdqezvKjE5+t3g4UricrhXXYCFiJ7xGRWlXNs61zFMB027QHHZ/y6a6i2V+h/Oi5+8vOj7b7gZ+tyPzoeF9Es053VAvr/y/Ah7Z1S536bPEIpQtUtUhV88wX4iIAKwNfjubXOGD9IrenfyGA5YktaXKIZn+JyMKQhxWZdXuU9vYDP1ttRbO/+NkKWonWR2aFMMOtdPWzxWbDDjAtIkphnYtcpKoVIlIPYK6q1pjl0wBUA/Co6hIXi+u69vaXWaUM1q+ho7BCx+dOSd1jGi+E3Q/8bLUVzf4yq/b4zxYQDI4ihHxuuvrZYqAQEZEjeMqLiIgcwUAhIiJHMFCIiMgRDBQiInIEA4WIiBzBQCEiIkcwUChmIuI110noVgMRmkHzUoZ5H0q6+j50l/fT9EUhFzFQKGrmoljVsIZuKEM3GtnWDOedEl9IIlJoex/K0cn3Id7vp7lAU4mT2+zAilT7UdDdMFAoaqq60ozU2tFowSlFRMqRQr3MzaB+0xDFQJsdbCfe7+dqAOW24TzirQhWqKTED4PuiINDUmf43S6AU8yXXQms4cxTjVMjC/sd2k6oF2CNAZWQARhV1SciD8MKslR8P1Mej1Cop1sKoMIM400OUtVSMxhoIsfLqoA1inXoQJCUAAwU6rHMl44HVv0BdQPmh0EFgLKeeDVGtzFQyHGmMnaFubxorbnkarv/4OYx5eYx9eYxgWHHPXH6cigF4Av9BW2er9CUuda0grJfOrbe/gvYrFtrlgXL3c7rjGnfmMcV2/ZNvan3abcS3ZRrlXmOWnO/y/ULgcp822uoNeUJfZ0LzXqFIctKbI+rte2HaltZQx8Ty2sJXH0wkQ0CCABUlTfeYrrB+odVWJfzDV1WBqvVkNc2r9CsXw8gP8JjAkNkA9ZRQ7V5jP3mdfA1eM02V4RZttC8hsDzBqbLzM0+v8SUdSGsFleBZU7tm3yz/TJz32ues7aD51pl36ch71txtO9nR++VmVdsfZW0ek0rIpUvsB/DbDuw/2o7+1psnx8FsMrt/5WednO9ALyl3i3SF5DtyzHcP3lxhC+L4gjbyg+3voOvoSQQCu2sUxvui8lWZgVQHrJsYYTHdGbf5LdXRrSEbui+W2Xme0LmB75o66N5PyM8Z9jyhPvybudzUo6QHwe2/dPqh0Osr8W2vN4ecrwl5sZTXuSkcsBqjhq6wMzzwboEsP1UxG3m77GQ9QMtg+yXKnVS4LLD/nbWCSxrdVW/kNcXesW/wLLQ01Gd2TdlZnmk5sxVoTNM58RCADUa0tDATPthVVrH3OfEdlrutjCn6ObFsKkytZ1mNNsKnKaapy0XxurKazkWUmZKADYbJkeYf1wv2v+CroR1ZDAtzDIvbM1LbV8E8eojEdh+
e01vA8v8YZb5Yf1qbrVMraargC0Iu7BvCmEFTSwC9Qr5IlIb5jmPdWKbAKwvcRGpMc9Rb+5Xwjo6ifp90ratvlbDej+WhARuV15LYN3MMI+jOGGgkFOi+cVba/7ajzrKYZ3yKRORGtuXTRmsX6ax/PJNVjHvmy4MgxLYtzVqdVp0lKpOMxXwJbC+8PMBLBQRH6yji5j6nJht5Zvyhh7txfW1kPN4youc4jd/PVGsuylwx/yynWceH2jtswpAdXtfIraWRmFb+pjl5aZFVZuWRggfbvHiN389Uawb2DeBo6P2giVc2QOBHM1zdYpa/UsE1tHUIlivzwurT0/UbJ1K/Wi55rtdV15L4DFOdf6kKDBQyBEh58QjjacUqLdodXrEVodQpKrTVLVIVSvCbcA0sV0Fa5iNiGEC4DbzxTcP1pdeaDPTuH/xBnRm39jqCAKvJ5xwYROoV/HGYwgSex2Pqtao6hJVDfRK98SwHS9a15v4wyzvymvJNGX0d7AeOYiBQk4KVB5H6ih4K6xe6a1Oi4g1QGF+lOfhl6tqEUwldwRl9uVmuz5Y/U4CAmWY3s52nDx66cy+eSHwmNDKZREJNCMGbF/k5gs0sB9XhKuU7uzpNLOtogiL7c8bzXZWmckloe+7CY/Szr4Ws54HCRryhVqwDoU6wxPyFwCgqovML/B8EVlhr/8QazTfKlW1f6kH/vnzzf1qtG255IcVIjXmOdr9kjBfRvZftwGVsCq5A2X1mfP+0fzy9USaJyJtKubDPaYz+0ZVS81jvACqTX2DH9YpwhpYIRVuiJF5sCq6883jVsL6AvfAalV3DK3D1RPyNxIvgGIRKbEfQZpw84WWv53trjDbqgxTbwJTxsApyVhfCwAUmL/LO3g95DS32y3zljo3WJXn9g6H9TD/3CHrLTTr1cP68liBMP0vbOuX2LYZ6Rapn0p+mDJqmOcoC52Plj4joX0iQjs21sN0gERL58TAslqEdI4MWVbWxX3jCXnO+sD6aN2RclWY1xH6uFX254r2/bStHzhNVW9u1Wb9kmg/J7Z9Htg/1bZbrVk3UmfIiK8l3Hsduj94i/9NzBtA5ApzRFEG65TUIjW/9m1NbctgHVlUqnWqy/64agDT1HbUYs7xl6tVaWx/noWwvtzFNs8D68tpiYb/pUwpSETqYX1eukMLwZTCOhRyjflCrwaQqVYFuj+wTFX9Jihi/VLwRLuieb5F4JhP3YZpwOBB2w6nlAAMFHJToEK1vcrcQMX4qnbWsfO3s6xNJzi1eqH7hMOddxdlsCr0EzlkPhkMFHKN7VRVexXjpbCCIGwz4jDsw3bYDUPkXtVzAZR2oTMhJQHTaKFSIzQ5p/hjoJDbSgEUhjtCMK2HimH1T/FHub1A667Q/h6FiNDU2Gy7CG1bC1GKMKdP/dq2pRklECvlyXWmgn0xrFNgx9By2mpVpF+bpjntKlhhE9qPoRxWK6E82/aXKofvIIorBgqlFHNaahGsIw4vrNNYKwE8bD+KsXX8C5zmWhTDUQ4RdQIDhYiIHME6FCIicgQDhYiIHMFAISIiRzBQiIjIEQwUIiJyBAOFiIgc8f8Bu7bONe/6jMgAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "data = torch.load('fcn-para.pt')\n",
    "plot_flatness_size(data['model_size'], \n",
    "                   data['hessian_fro'], \n",
    "                   data['mu'], \n",
    "                   xlabel='model size', \n",
    "                   title='FCN', savefile='../figs/hessian_fro_fcn_modelsize.pdf')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2242b29-4766-4833-a005-b6d5920025d6",
   "metadata": {},
   "source": [
    "- number size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 133,
   "id": "24611183-92f7-4a54-95fb-3aeee729090f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1, loss_tr: 5.3e-01\n",
      "5001, loss_tr: 3.9e-07\n",
      "10001, loss_tr: 7.9e-08\n",
      "15001, loss_tr: 1.7e-08\n",
      "train error 4.475082970856192e-09\n",
      "test error 0.1740084735304117\n",
      "(0, 0)-> took 28.3 seconds\n",
      "1, loss_tr: 4.6e-01\n",
      "5001, loss_tr: 1.2e-06\n",
      "10001, loss_tr: 1.1e-07\n",
      "15001, loss_tr: 3.0e-08\n",
      "train error 1.635442714587043e-08\n",
      "test error 0.1973014259338379\n",
      "(0, 1)-> took 28.1 seconds\n",
      "1, loss_tr: 3.6e-01\n",
      "5001, loss_tr: 2.2e-07\n",
      "10001, loss_tr: 2.2e-08\n",
      "15001, loss_tr: 2.4e-09\n",
      "train error 2.0669796430006215e-10\n",
      "test error 0.1823489224165678\n",
      "(0, 2)-> took 28.0 seconds\n",
      "1, loss_tr: 5.6e-01\n",
      "5001, loss_tr: 5.6e-07\n",
      "10001, loss_tr: 1.8e-07\n",
      "15001, loss_tr: 6.4e-08\n",
      "train error 2.2767840945903117e-08\n",
      "test error 0.17692646861076355\n",
      "(0, 3)-> took 28.3 seconds\n",
      "1, loss_tr: 4.4e-01\n",
      "5001, loss_tr: 4.1e-07\n",
      "10001, loss_tr: 1.8e-08\n",
      "15001, loss_tr: 5.2e-09\n",
      "train error 2.162143131201333e-09\n",
      "test error 0.18967681244015694\n",
      "(0, 4)-> took 28.1 seconds\n",
      "1, loss_tr: 3.9e-01\n",
      "5001, loss_tr: 4.1e-07\n",
      "10001, loss_tr: 2.1e-08\n",
      "15001, loss_tr: 2.4e-09\n",
      "train error 1.9003048960797297e-10\n",
      "test error 0.15611951794475318\n",
      "(1, 0)-> took 28.6 seconds\n",
      "1, loss_tr: 4.3e-01\n",
      "5001, loss_tr: 3.4e-06\n",
      "10001, loss_tr: 5.2e-07\n",
      "15001, loss_tr: 1.3e-07\n",
      "train error 2.0595204652806842e-08\n",
      "test error 0.14881017982959746\n",
      "(1, 1)-> took 28.6 seconds\n",
      "1, loss_tr: 5.7e-01\n",
      "5001, loss_tr: 2.3e-06\n",
      "10001, loss_tr: 1.1e-07\n",
      "15001, loss_tr: 3.1e-08\n",
      "train error 7.585349237615446e-09\n",
      "test error 0.15234483987092973\n",
      "(1, 2)-> took 28.9 seconds\n",
      "1, loss_tr: 3.8e-01\n",
      "5001, loss_tr: 7.6e-06\n",
      "10001, loss_tr: 1.3e-06\n",
      "15001, loss_tr: 2.7e-07\n",
      "train error 9.978703729984773e-08\n",
      "test error 0.161190020814538\n",
      "(1, 3)-> took 31.1 seconds\n",
      "1, loss_tr: 5.0e-01\n",
      "5001, loss_tr: 1.9e-06\n",
      "10001, loss_tr: 3.4e-07\n",
      "15001, loss_tr: 6.9e-08\n",
      "train error 1.3989363001343236e-08\n",
      "test error 0.16161490865051747\n",
      "(1, 4)-> took 31.4 seconds\n",
      "1, loss_tr: 4.5e-01\n",
      "5001, loss_tr: 1.8e-05\n",
      "10001, loss_tr: 1.8e-06\n",
      "15001, loss_tr: 6.9e-07\n",
      "train error 3.4506966528624616e-07\n",
      "test error 0.10034602833911777\n",
      "(2, 0)-> took 32.9 seconds\n",
      "1, loss_tr: 4.7e-01\n",
      "5001, loss_tr: 2.4e-05\n",
      "10001, loss_tr: 9.4e-06\n",
      "15001, loss_tr: 3.1e-06\n",
      "train error 1.950493555114008e-06\n",
      "test error 0.0953292231541127\n",
      "(2, 1)-> took 33.9 seconds\n",
      "1, loss_tr: 4.9e-01\n",
      "5001, loss_tr: 2.2e-05\n",
      "10001, loss_tr: 1.2e-06\n",
      "15001, loss_tr: 4.7e-07\n",
      "train error 3.221955395815712e-07\n",
      "test error 0.10277419161051511\n",
      "(2, 2)-> took 32.6 seconds\n",
      "1, loss_tr: 3.1e-01\n",
      "5001, loss_tr: 1.9e-05\n",
      "10001, loss_tr: 4.5e-06\n",
      "15001, loss_tr: 1.1e-06\n",
      "train error 7.9580273393276e-07\n",
      "test error 0.09361816056072712\n",
      "(2, 3)-> took 32.4 seconds\n",
      "1, loss_tr: 5.8e-01\n",
      "5001, loss_tr: 3.5e-05\n",
      "10001, loss_tr: 3.4e-06\n",
      "15001, loss_tr: 2.3e-06\n",
      "train error 6.568768782244661e-07\n",
      "test error 0.10441724086180329\n",
      "(2, 4)-> took 32.6 seconds\n",
      "1, loss_tr: 3.7e-01\n",
      "5001, loss_tr: 2.3e-04\n",
      "10001, loss_tr: 1.4e-05\n",
      "15001, loss_tr: 6.8e-06\n",
      "train error 3.0046401917616095e-06\n",
      "test error 0.07159467979334294\n",
      "(3, 0)-> took 38.6 seconds\n",
      "1, loss_tr: 4.1e-01\n",
      "5001, loss_tr: 6.8e-05\n",
      "10001, loss_tr: 7.7e-06\n",
      "15001, loss_tr: 6.5e-06\n",
      "train error 3.008488903333273e-06\n",
      "test error 0.07505115023814142\n",
      "(3, 1)-> took 38.4 seconds\n",
      "1, loss_tr: 4.6e-01\n",
      "5001, loss_tr: 3.6e-05\n",
      "10001, loss_tr: 1.4e-05\n",
      "15001, loss_tr: 4.5e-06\n",
      "train error 3.0287180834420724e-06\n",
      "test error 0.07642010480165481\n",
      "(3, 2)-> took 38.6 seconds\n",
      "1, loss_tr: 4.7e-01\n",
      "5001, loss_tr: 7.3e-05\n",
      "10001, loss_tr: 1.6e-05\n",
      "15001, loss_tr: 2.1e-05\n",
      "train error 3.4831102482257847e-06\n",
      "test error 0.07238797749392688\n",
      "(3, 3)-> took 38.5 seconds\n",
      "1, loss_tr: 5.2e-01\n",
      "5001, loss_tr: 4.1e-05\n",
      "10001, loss_tr: 1.6e-05\n",
      "15001, loss_tr: 9.0e-06\n",
      "train error 3.5575208698901406e-06\n",
      "test error 0.07707730459049345\n",
      "(3, 4)-> took 38.7 seconds\n",
      "1, loss_tr: 4.9e-01\n",
      "5001, loss_tr: 1.8e-04\n",
      "10001, loss_tr: 3.1e-05\n",
      "15001, loss_tr: 2.2e-05\n",
      "train error 2.764375176411704e-05\n",
      "test error 0.05935943013988435\n",
      "(4, 0)-> took 59.1 seconds\n",
      "1, loss_tr: 3.7e-01\n",
      "5001, loss_tr: 1.5e-04\n",
      "10001, loss_tr: 2.8e-05\n",
      "15001, loss_tr: 1.2e-05\n",
      "train error 5.588196032135784e-06\n",
      "test error 0.05520988123957068\n",
      "(4, 1)-> took 59.1 seconds\n",
      "1, loss_tr: 6.4e-01\n",
      "5001, loss_tr: 1.9e-04\n",
      "10001, loss_tr: 4.7e-05\n",
      "15001, loss_tr: 2.2e-05\n",
      "train error 1.1625768649992096e-05\n",
      "test error 0.055511578484438356\n",
      "(4, 2)-> took 59.2 seconds\n",
      "1, loss_tr: 4.6e-01\n",
      "5001, loss_tr: 1.5e-04\n",
      "10001, loss_tr: 3.6e-05\n",
      "15001, loss_tr: 2.3e-05\n",
      "train error 9.902575300202443e-06\n",
      "test error 0.05960482557304204\n",
      "(4, 3)-> took 59.2 seconds\n",
      "1, loss_tr: 6.7e-01\n",
      "5001, loss_tr: 2.9e-04\n",
      "10001, loss_tr: 3.7e-05\n",
      "15001, loss_tr: 2.0e-05\n",
      "train error 2.959052744699875e-05\n",
      "test error 0.060806230138987304\n",
      "(4, 4)-> took 59.4 seconds\n",
      "1, loss_tr: 4.6e-01\n",
      "5001, loss_tr: 6.7e-04\n",
      "10001, loss_tr: 1.5e-04\n",
      "15001, loss_tr: 7.1e-05\n",
      "train error 4.4490865377611044e-05\n",
      "test error 0.04267239287961275\n",
      "(5, 0)-> took 141.7 seconds\n",
      "1, loss_tr: 6.2e-01\n",
      "5001, loss_tr: 4.0e-04\n",
      "10001, loss_tr: 9.3e-05\n",
      "15001, loss_tr: 4.7e-05\n",
      "train error 2.1565026628422856e-05\n",
      "test error 0.04327935257926583\n",
      "(5, 1)-> took 141.9 seconds\n",
      "1, loss_tr: 5.8e-01\n",
      "5001, loss_tr: 9.4e-04\n",
      "10001, loss_tr: 2.1e-04\n",
      "15001, loss_tr: 1.3e-04\n",
      "train error 5.929897076839552e-05\n",
      "test error 0.04049858972779475\n",
      "(5, 2)-> took 141.5 seconds\n",
      "1, loss_tr: 4.5e-01\n",
      "5001, loss_tr: 7.2e-04\n",
      "10001, loss_tr: 1.5e-04\n",
      "15001, loss_tr: 9.4e-05\n",
      "train error 3.8740628610867134e-05\n",
      "test error 0.04505785422632471\n",
      "(5, 3)-> took 140.9 seconds\n",
      "1, loss_tr: 5.2e-01\n",
      "5001, loss_tr: 5.4e-04\n",
      "10001, loss_tr: 1.3e-04\n",
      "15001, loss_tr: 6.0e-05\n",
      "train error 3.809088292427987e-05\n",
      "test error 0.04548293613595888\n",
      "(5, 4)-> took 141.5 seconds\n",
      "1, loss_tr: 4.9e-01\n",
      "5001, loss_tr: 1.1e-03\n",
      "10001, loss_tr: 5.7e-04\n",
      "15001, loss_tr: 2.4e-04\n",
      "train error 0.00014336424862904096\n",
      "test error 0.03220311668643262\n",
      "(6, 0)-> took 469.3 seconds\n",
      "1, loss_tr: 4.1e-01\n",
      "5001, loss_tr: 1.9e-03\n",
      "10001, loss_tr: 3.3e-04\n",
      "15001, loss_tr: 1.6e-04\n",
      "train error 0.00010441661663662671\n",
      "test error 0.03106562932167435\n",
      "(6, 1)-> took 468.7 seconds\n",
      "1, loss_tr: 4.7e-01\n",
      "5001, loss_tr: 2.7e-03\n",
      "10001, loss_tr: 6.0e-04\n",
      "15001, loss_tr: 2.1e-04\n",
      "train error 0.0002522502088595502\n",
      "test error 0.034387633969890884\n",
      "(6, 2)-> took 466.4 seconds\n",
      "1, loss_tr: 5.7e-01\n",
      "5001, loss_tr: 1.2e-03\n",
      "10001, loss_tr: 2.9e-04\n",
      "15001, loss_tr: 1.0e-04\n",
      "train error 7.045501774882723e-05\n",
      "test error 0.02932969419271103\n",
      "(6, 3)-> took 464.0 seconds\n",
      "1, loss_tr: 5.8e-01\n",
      "5001, loss_tr: 1.6e-03\n",
      "10001, loss_tr: 2.5e-04\n",
      "15001, loss_tr: 1.2e-04\n",
      "train error 0.00022717959791407338\n",
      "test error 0.030910217170603573\n",
      "(6, 4)-> took 467.3 seconds\n"
     ]
    }
   ],
   "source": [
    "# Sweep over training-set sizes: for each size n, train `ntries` freshly built\n",
    "# networks and record `ana.hessian_fro()` and `ana.mu()` for each trained net.\n",
    "# NOTE(review): `mnist`, `train_model` and `AnalyzeLargeNet` are not defined by the\n",
    "# import cell at the top of this notebook (it imports `data`, `trainer`, `analysis`);\n",
    "# this cell only runs if they were defined earlier in the same kernel session.\n",
    "sample_sz = [100, 200, 400, 800, 1600, 3200, 6400]\n",
    "ntries = 5\n",
    "m = 80  # passed to models.build_FNN (presumably the hidden width)\n",
    "lr = 0.05\n",
    "\n",
    "sample_size = torch.tensor(sample_sz).float()\n",
    "hessian_fro = torch.zeros(len(sample_sz), ntries)  # rows: sample size, cols: repeat\n",
    "mu = torch.zeros(len(sample_sz), ntries)\n",
    "\n",
    "for i, n in enumerate(sample_sz):\n",
    "    # Batch size equals the current sample size (full-batch). Assigned here rather\n",
    "    # than before the loop, where `n` is undefined on a fresh kernel (or holds a\n",
    "    # stale value left behind by another cell).\n",
    "    bz = n\n",
    "    for j in range(ntries):\n",
    "        time_st = time.time()\n",
    "        X_tr, y_tr, X_te, y_te = mnist.load_data(n)\n",
    "\n",
    "        net = models.build_FNN(m)\n",
    "        train_model(net, X_tr, y_tr, X_te, y_te,\n",
    "                    lr=lr, batch_size=bz, nsteps=20000, display=5000, plot=False)\n",
    "\n",
    "        ana = AnalyzeLargeNet(net, X_tr, y_tr)\n",
    "        ana.compute_grads()\n",
    "        hessian_fro[i, j] = ana.hessian_fro()\n",
    "        mu[i, j] = ana.mu()\n",
    "\n",
    "        print('({:}, {:})-> took {:.1f} seconds'.format(i, j, time.time() - time_st))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 142,
   "id": "d42aa0e7-6da4-4798-9e43-41a0f2d15314",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bundle the sweep results with the hyperparameters that actually produced them\n",
    "# (taken from the sweep cell's `bz` and `lr` instead of hard-coded values).\n",
    "# Named `results` rather than `data` so it does not shadow the `data` module\n",
    "# imported at the top of the notebook.\n",
    "# NOTE(review): `bz` is whatever value the sweep cell last assigned; confirm it\n",
    "# was the same for every sample size before relying on this field.\n",
    "results = {'sample_size': sample_size,\n",
    "           'hessian_fro': hessian_fro,\n",
    "           'mu': mu,\n",
    "           'batch_size': bz,\n",
    "           'learning_rate': lr}\n",
    "# Saving is disabled; the plotting cell below reads an existing 'fcn-sample.pt'.\n",
    "# torch.save(results, 'fcn-sample.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 169,
   "id": "932e64e4-0cec-44b5-a2a4-e358e6681986",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZYAAAFkCAYAAAADoh2EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAA4q0lEQVR4nO3deXwcd3k/8M+j07Z8rGTJcezYkSWfickhy+Q+iKUcEGgLdgI0AYobOWmBEFJsQqEXlFSmEMJRkExaQuAHjgwtEEJAckhCyIElxUmIb8m3Y1uStbYl2zqf3x/z3dXsaldaSbM7s6vP+/Wal3buZ2dX8+zM9xhRVRARETklze0AiIgotTCxEBGRo5hYiIjIUUwsRETkKCYWIiJyFBMLERE5iomFaAxEZKWI1IpIe4RBIwxFUbbjE5G1ZltNtuWbzLS1IlJilq2w7VNFpDaGOKtEpMEWW62IrHT6eBABgLAdC9HYmZN0DQC/quaGzSsCUAKgRlUlbJ4PQCWACgCbAGwE0KiqzWZeEYA7AawFAPv6IlJl1gOA9aq6LoY4GwBsVNX1o3ibRDHJcDsAohThjzZDVZsBNItIs326STgNZrRcVevC1vMDaATQaJJI+JVJk9mvD8BaEdmiqpuGibN5qFiJnMBbYUTOOBHDMsErCnM10gArKawKTyrhTHLyR5j1JIBq87omcLvMgViJRo2JhShBwq4mNsBKKpuGSyo290TZ7hpYVzaAlVx8o42RyAlMLEQJICI1ttc+AIGC86pYt6GqjUPMXgHriqYIVlkPkWuYWIjizCQS+y2qMtvrZjjAlMesCGxfRCqd2C7RaDCxEDnLF17FGEA7gDzbMsHXpuzEEeaKZo0ZXcvqxOQWJhYiZ/lVVewDgGKEFpj74rVzVa2GVW0ZiL0wn8hRTCxEcRaobmybFHwdrcHkGPe3CgOF+ZtZmE+JxsRClBhrbK/ttcDidbsqUJjvA7A5TvsgioiJhSgB7GUppqA9kFwq43TV4sdAYX6JaWBJlBBMLETusF/BxKV6cFhhfgUL8ylRmFiInJE3/CIDzBXMKjNaYjqbHLKgXUTKRlpeEl6Yj9Bqz0Rxwb7CiJzhC/s7LFXdJCLFsDqhXAmgQUTqYCWAwK0zH4DlZn6zqpbbNlEMq0HkcPtZZTqfLIlleaKxYu/GRGNgbi+tweBGj82w+gDzx7idEtt28mAlFL/ZTj2AqkDLexGpMMsGrj4aAdQN1buxudLZa7a7xlzJEMUFEwvROGISmN/JhplE4ZhYiIjIUSy8JyIiRzGxEBGRo5hYiIjIUUwsRETkKLZjAZCfn6+FhYVuh0FElFQaGhpaVbUgfDoTC4DCwkLU19e7HQYRUVIRkf2RpvNWGBEROSqlEou9HyUR8fE5FEREiee5xCIiVSLSbobK4aaHaQ97HCyf+01ElGCeKmMx/S41AZgHq8+kGhHZYmYPmq6qm8I2Ua2qa0BERK7xVGKB1YfRevN6k4g0wuqNtTHK9EHrJyBGIiIagqduhalqXdgkH6ykEnF6hE2UiEiDuR1WyzIWIqLE81RisTO9sDaHJ5Vo040aVV0GIBdW1+Mbhth+hYjUi0h9S0uLk6ETEY1rnkws5krjobCHGkWdHhB4xoR5BsbDGOJpeaparaqlqlpaUDCofQ8REY2SJxMLrNpc94xgeiR+RL5dRkREceS5xCIiVTDVhE1blKJhpq80f9eGbaocQNQn6hERUXx4qlaYSR4VZrBPr440HYAA2CAizQDqRKQW1lVKG6xHufIpeURECeapxGLaoERrhxJxuqrm2kYjlr0QJaNHancFXz9QvtDFSIhGxlOJhYgGPLp5d/A1EwslE8+VsRARUXLjFQsRxR1v640vTCxEFHe8rTe+8FYYERE5iomFiIgcxcRCRESOYmIhIiJHMbEQEZGjmFiIiMhRTCxEROQoJhYiInIUEwsRETmKiYWIiBzFxEJERI5iX2GU
UtjZIZH7mFgIQOqckNnZIZH7mFgIAE/IROQclrEQEZGjmFiIiMhRTCxEFFd9/TrkOKUeJhYiipuunj58sOrlkGl3P/Yqk0uKY+E9EY1Id28/2jq70Hq6G60dXWjp6EJrx8D4wNCNE53dg9bfetCP53Yex4ol57kQPSUCEwsR4VxPXzAZtJ4OTQ4tHV22ad04ebZnTPs6292HbUdOMbGkMM8lFhGpAnCHGa1W1XVmegmAMgDTARSp6qoI6w67DNF40dnVG0wQLeFXEyHj3ejo6k1YXFkZabho1tSE7Y8Sz1OJRURWAmgCMA9WgqgRkS2qugnABlVdZparFJHKQNKxiWUZChOpcDU9TVyKhoDIn0maAKe7es3Vgy0xnO5CS8fgxHG2py8usaUJkJeTjfzJWSiYko38ydZr62828qeYeZOzMW1iJv7mB1vwUlObbX3BtfPz4xIbeYOnEgsAv6quN683iUgjgCIRqQBQZ1uuFkClfcVYlqHB+voVdz/2asi0ux97FU+svoLJJcFUFW+fPIfdx07jX371Vsi8JV98BgpFT198Cr0z0gTT7clhcjbyp1jJwT6ePzkbuZOyRvTdeGL1FSj+/NPB8bM9fXiy/iDuvqowDu+EvMBTiUVV68Im+QA0AigH0Gab3gygKGzZ4hiWoTDP7TyOLftOhEx7pbkNqx/fgmvn52Nefg4K83MwJ3cSsjJYidAJ53r6sLe1E00tHWhusf4GXp/pjnyV0d3XP+L9ZGWkmcQwkDCCycN2VZFvrizS4vRDIlIS+nrtLrzvstmYNjEzLvskd3kqsdiZ8pJmVa0TkVUITRqR+GJYxr79CgAVADB37tzRhpn0auoPDfoV3K/Acztb8NzOluC09DTBBbkTUTg9x0o20yehMN96Pds3ERnpTDp2qorWju5ByaOppQOH2s9CR3nhMTEzPXjlEBgKJmeZRGG7LTUlG1OyMyDizavO9jM9+Nbm3fjC7Re5HQrFgScTi4j4ADykquVmkh9WgbzdibDxWJYJUtVqANUAUFpaOi4r1TceaEfd9mMxLdvXr9jfdgb7287g+V0tIfMy0wVz8iZh3nTr6qYwP8e8noRZ0ybG7ZewF/T09ePAiTNoOt6BJnsCOd6BU+dGXiA+bWIm8idnYX/bGfTaylkmZqbhP1ddivdcMsvJ8F31+Mv78NdXXoh5+Tluh0IO82RigVU2co9tvA3Actu4D9YtMoxwGTKO+M+i4ocNISevgMUzp+CvLp+N/SfOYF9rJ/a1duLIyXNRt9XTp2hu6URzS+egeVkZadbVTeBKJz8n+Pq8qdme/UUd7uSZHjS1dgxKIAfCEkAs0gSYkzcJxQWTUVyQg+KCySgyr/NystCvVjmXvcD78rm5uHXp+U6/LVcsL8zFln3t6OlTPPz0dlR/pNTtkMhhnkssprpxpXntA5AHYBOAh2yLlQHYaJZZaWqNRV2GQp3p7sU9P6xHa0cXAMA3MQP+swO/rn/9qesG3Rc/19OH/W1nsLe1E3tNstnbZv09fror6r66e/ux61gHdh3rGDRvYma6uZ1mJZ7ArbXC6TnIn5yV8KTT16844j+LPS2hCaS5pQOtHYMb+g0nJysdxTMmW4kjPyf4+sLpkzAhMz3qeukyuMA7lSpTfPH2i/C+b/8RAPC7bcfwUlMrri5mLbFU4qnEYpJKsOwjQFVFRNaZ+Q1m2iYze4OINKtq4xDLkKGq+GzNG3jryCkAVm2g791dig9WvxJcJtIJbEJmOhbNnIJFM6cMmtfZ1Yt9bbaE03oG+0zSaYvQ8jrgbE8ftr99CtvfPjVo3pTsDNtttUkht9hyc7JG89ZD4g0UntsTyN7WTnT1jryQfNa0CSieEZo8igsmj+mKLPwzSJWkAgCXXODD+0tm4+eNhwEAX3pqO5765LUp9R7HO08lFlVdA2BNlHnVUabnDrcMDfjm5j349ZtvB8f/7S+W4sqi8KKpkcnJzsDFs6bh4lnTBs07ebYH
+03SGbjSsW6xDdWC+3RXL948fBJvHj45aJ5vUqatEoFVljMvPwdz8iaFLHe4/Sz2tQ0kkObWTjQd7xjytl40WRlpVuII3L4yCWRefg5ysj31b5QU1t6yGL9582jwx8WmhoO4c/n4rUSTavgfMY785s238UjdwJMiP3Z1IT58RXz/madNzMQlF/hwyQW+QfPaO7uDt9OCiaetE/tazwzZEtx/pgdbz/ix9aB/yH1fU/nsiOPNn5yNooLQBDK/YDJm+SbyF7WDZk6bgHtvKA5+H7/62114zyWzMJlJOiXwUxwn/nz4JD7z5OvB8Wvn5+ML71niYkRAbk4WcnOyUDI3N2R6oKruvvArndZO7G87M+YW5RlpgrnTJwVvWQWvQPInY9oktqtIlIrri/DTLQfw9slzaO3own/9fg/W3rrY7bDIAUws48Dx0+dQ8cP64Al5Xn4OvvPhEs+2PRERFEzJRsGUbCwvzAuZp6o4dqrLdnXTiebWTjTsb4/Yk+750ybgmvn5IQlkbt4kZHr0vY8nE7PSse7Wxfj0xq0AgO+/uBcfeufcQbc0KfkwsaS4rt4+3PtEQ7BcYcqEDGz4SGnS/jIXEcycNgEzp03AVcUDZUObtx/DJ3/yWkjL9UlZ6fjyXy5lL7oe9r5LZ+F/XtqH1w/60d3bj/94Zge+8+ESt8OiMeLPthSmqnjo52+i8YAfgNV+4tsfLsH8GZPdDSwOblw0A5fN8YVMu2yODzcumuFOQBSTtDTBP90+cEv212+8jfp9Uds1U5JgYklhG/7QHKzSCQD/+J6LcMPCAhcjip/0NMETq68ImZZKbT9S2bIL8/DeSwd6FPi3p7ahn0+YTGpMLCnq2R3H8PBvdgTH7yydg49fU+heQAmQym0/Ut26Wxch23Ry+sahk/i/rYeHWYO8jGUsKWj3sdP41E+2Bjs6XF6Yiy/95dIhG+vdv2JBgqIjGuyC3En42+vm4Tu/bwIArH9mJ25dOhOTsniKSkb81FJMe2c3Vj9eH2wHMts3Ed+9a9mwXd4/UL4wEeERRXXfjfPxZP0htJzuwtFT51D1fDO/l0mKt8JSSE9fP/7ux404cOIMAKtW1Pc/Wor8ydkuR0Y0vMnZGfjszYuC41UvNOHtk2ddjIhGi4klhfzrr97Cy80DPeI+cudlWHI+ny1OyeMDyy7AReY7e66nH199ZqfLEdFoMLGkiCde3ocfvXIgOP7ZWxbhlotnuhgR0YD7VywIDkNJTxN80fbwr5+/dnjYrnvIe1jGkgJe2tOKf/nVtuD4+y6dhb+7sdjFiIhCjaSs5Kri6bjl4vPw27esh9B9+altqLn3qqR5dg/xiiXp7WvtxH0/bkSfqfd/yQXTsH7lJfwnpKT20G1LkJlufYfr97eH9MhN3sfEksROnevB6se3BLufP29qNjZ8pHTIh0gRJYPC/Bz8zTXzguMPP70D58bY+SglDhNLkurrV3zqJ6+hyTwOODsjDdV3l+K8qRNcjozIGZ+4aT7yzEPdDvvP4rEX97ocEcWKiSVJ/cdvtuO5nS3B8fUrL8GlYX1lESWzqRMyQ8pm/uv3e3D89Mgf0kaJx8SShGrqD2LDHwZ+vf39u4rxF5fNdjEiovj40PI5WHie1WlqZ3cfvvbbXcOsQV7AxJJk6vedwD/+75+D4+UXnYcHyxcNsQZR8spIT8MX3jNQ/fjJhoN468jgx1WTtzCxJJHD/rO490cN6O7rBwAsnjkF37jzMqSxs0VKYdcvLMBNi63HH6gCX3pqG1TZ+7GXMbEkic6uXvzt4/Vo7bCekpiXk4UNHylFDp8RTuPA59+9BBnmB9QrzSfwu23HXI6IhsLEkgT6+xUPPvk6tr99CgCQmS743l3L+AhXGjfmz5iMu668MDj+lae3o6uX1Y+9KqV+7oqIT1X9gdcAEBhPZt+o24Vn3joaHP/yXy7FO+flDbHG+MXu/1PX/SsW4H9fO4yTZ3uwv+0MfvjSftxzfZHbYVEE
nkssJiHcAWCVqpabaRUAqiIsXq6qdbbx9rAW59UA1sQp1IT41etH8M1n9wTHP37NPNy5fK6LEXlbKnWzziQZKjcnC/evWIB/e8rqvuibz+7G+0tmY3qCe+9+pHagZloqfd+c5KnEYpJKJYAGAOE/RZapaqNtuQ1hSQUAqlU1qROJ3RuH/PiHmteD49cvLMDn373YxYgokXjSGuzuqy7Ej17Zj+bWTpw+14tv1O3Gl/5yaUJjeHTz7uBrfkaReaqMRVX9JjHUh81qDiQVowLAwxE24Y9XbJE8UrsrODjt+KlzqPhhA7p6rRpgRQU5+NaHLkdGuqc+MqKEykxPwz++Z0lw/Mev7seuY6ddjIgiSYqzlP3KxFytFIclmoASEWkQERWR2kA5S7w8unl3cHDSuZ4+3PNEA46esloZT52Qge9/pBTTJmY6uh+iZHTT4hm4dn4+AKCf1Y89KSkSS5gNsG6XRVKjqssA5ALIM8smFVXF5372Bl43z6BITxN8569LUFQw2d3AiDxCRPCF25cg0HzrD7tbQ7o3IvclVWIRkTIAJ1S1OdJ8Va02f/2wbpWVDLGtChGpF5H6lhbvfCm/93wz/m/rkeD4F9+zBNctKHAxIiLvWTxzKj74zoFKLF/+9Tb0mIbD5L6kSiywaoati3FZP4BIt8sAWElIVUtVtbSgwBsn7tptx7D+tzuC4x9651x89OpC9wIi8rDPlC/EFNNAuKmlE//v1QPDrEGJkjSJRUTWAmgMb5ciIitt8+3KEXsSct2Oo6fw6Z++hsCt4nfOy8O/vu9iPrCLKIr8ydn4+5vmB8cfqduFk2d6XIyIAjyXWMztrocAFInIWhEpslVDjlQTbIOIlACoMwX2lSbJVEW7ZeY1bR1d+NvH69HZbbUkviB3Ir531zJkZXju4yHylL+5phBzTQ8U/jM9jlekodHxVDsWIFgDLLx9CgBE/Omuqrm20fK4BBVH3b39uO/HjTjUfhYAkJOVjsc+ujz4gCMiii47Ix0P3bYY9/3Yuuv9w5f34a4r57Kyi8v4k9hFqop//uWf8ae9JwAAIsCjH7wci2ZOcTkyouRx69KZwS6OevsVX3l6u8sREROLix5/aR9+8qeDwfG1tyxG2UXnuRgRUfIREfzT7RchUBxZt/04/rin1d2gxjkmFpf8YXdLsM8jAPiry2fj3hvYoR7RaCydPQ0fKLkgOP6lp7ahr5+NJt3CxOKC5pYO/P2PGxH43l82x4eH3/8O1gAjGoPP3rIIk7LSAQA7jp7Gxi0Hh1mD4oWJJcFOnunB3z5ej1PnegEAM6dOQPXdyzAhM93lyIiS23lTJ+C+G4qD41/73U6cOsfqx25gYkmg3r5+fOInjWhu7QQATMhMw4aPlGLG1AkuR0aUGu65vgizpln/T22d3fjO7/cMswbFAxNLAn3l6R34w+6BQsX/XHUp3nHBNBcjIkotEzLTse62gUdL/M+L+3Cg7YyLEY1PTCwJ8tM/HcB//3FvcPxTKxbg9ktmuRgRUWp636WzcPlcHwCgu68f//EMqx8nGhNLAvxp7wl88Rd/Do7ftnQmPs2nAxLFhYjgi7dfFBx/+s2jeLW5zcWIxh8mljg7eOIM7v1RA3r6rCpgS86fiq/dcSnS0lgDjCheSubm4i8uG7gj8OVfb0c/qx8nDBNLHHV09eKeH9bjRGc3ACB/cha+/9FSTMryXE86RCln7a2LkW3623vz8En8/LXDLkc0fjCxxEl/v+KBjVux46j12NSs9DRU3b0Ms30TXY6MaHyY7ZuIiusHGh2vf2YHOrt6XYxo/GBiiZOv1e5E7bZjwfF//6ulWHZhnosREY0/995QjBlTsgEAx093oer5JpcjGh+YWOLgF1sP4zu/H/gC33PdPKwqneNiRETjU052Bj57y6LgeNULzTjsP+tiROMDE4vDth7047Ob3giO37ioAJ+7bYmLERGNbx8ouQBLZ08FAHT19mP9MzuGWYPGionFQUdP
nkPFD+vR3Ws9e3v+jMn45ocuRzprgBG5Ji1N8E+3Xxwc/8XWI2g80O5iRKmPicUh53r6UPFEPY6f7gIATJuYie9/pBRTJ2S6HBkRvXNeHm5bOjM4/qWntkGV1Y/jhYnFIZ/d9AbeOHQSAJCeJvjuX5egMD/H5aiIKOCh25YgK9065b12wI9fvn7E5YhSFxOLQ35l+5L+y3svwtXz812MhojCzZ0+CX9zbWFwvPI3O3Cup8+9gFIYE4vD7rpyLu6+qtDtMIgogk+8az6m52QBAI6cPIcNLzS7HFFqYmIZpUhPp7uqaDr++b0XR1iaiLxgyoRMPHjzQPXj7z7fhGOnzrkYUWpiYhmFvn7FB6tfDpmWnZGGb33ocmSm85ASedmdy+dg8cwpAIAz3X34z9/udDmi1MOz4CjUbTuGhv2h1RVFgNcP+d0JiIhilp4m+MJ7Bno/3tR4CH8+fNLFiFJP0iQWEfHZX9vHE23H0VMIvxPW1dOPbUdOuRMQEY3ItQvyUbZkBgBAFfg3Vj92lOcSi0kaFSJSGzarXURURBRAO4DKCOuWiMhaEakUkZp4xbh09jRMygp9Rv3ErHRcNGtqvHZJRA77/LuXIMM0Xv7T3hP47VtHXY4odXgqsZirkEDCKAqbXa2qYhvWRNjEBlVdr6rrADSLyKDk44QbF83AZXN8IdMum+PDjYtmxGN3RBQHRQWTcfdVFwbHv/L0DnT1svqxEzyVWFTVbxJGfYTZ/qHWFZEKAHW2SbUAypyLbkB6muCJ1VeETHti9RXsuoUoydy/YgF8k6zeMQ6cOIMf/HGfuwGlCE8llmGUiEiDuR1WG6GMpRiA/fmjzRh81eOY8CTCpEKUfHyTskIeE/7tZ/egtaPLxYhSQzIllhpVXQYgF0AegA1h830Jj4iIkt5fX3khigus7pdOd/Xi67W7XI4o+SVNYlHVavPXD+BhACVhi/gBTA+bdiLa9kwFgXoRqW9paXEwUiJKJpnpaSHVj3/6pwPYcZQ1PMfCkcQiIu8XkctEJFHVovwAGsOmtSH01pcvwjJBqlqtqqWqWlpQUOB4gESUPG5cVIDrFlj9+/Ur8OWntrP68Rg4dcWyCVZBeVzKNERkbdikcgDrzLyVYTEElAHYGI94iCi1iAi+ePtFCBSVvrinFc/uOO5uUEnMqcTSqKr/qapbAxNEZJqIbBnphkSkDMBDAIpMm5QiAHWmwL7SJJkqVQ30HrdBRErM+DoRqTI1xKCqm8b8zohoXFh43hR8+Iq5wfF///V29PT1uxhR8spwaDtt4RNU9aSIjLiqlKrWIbTacEB5lOVzba+rR7o/IqKAB8oW4hdbj+D0uV40t3biiZf34+PXznM7rKQT78L7qIXnREReM31yNj5100D140c374b/TLeLESUnpxJLsYg8KCI3hRXgDyr9EpGbHNonEZHjPnL1hbhw+iQAwMmzPfhG3W6XI0o+TiWWPAD/COsWVruI9JnylSIRWR2WcCLe0iIi8oLsjHQ8dNuS4PgTr+zHnuMdLkaUfJxKLNWqmqeqaQCWA7gPQAOAvQC+ClvCARBew4uIyFNuufg8XFmUB8B6/tJXnt7uckTJxanEEqzWq6qNpo3Ivap6s6rmwWotfzOs2l5+h/ZJRBQXgerHgepHz+44jhd2sSF1rBxJLKr62jDzT6rqZlVdj8FdsRARec7Fs6Zh1bILguNf/vU29LL6cUwS3qWLqn4u0fskIhqNf7h5EXLMs5d2HevAT7ccdDmi5JA0fYURESXajKkT8Hfvmh8cZweVsWFiISIawupr52G2byIA4EQn27TEgomFiGgIEzLT8bnbFkec19fPjiojGbJLFxH5HoB49WfQrKr3xWnbRESOuf2S8/E/f9yLxgP+kOl3P/Yqnx4bwXB9hVUifg/Q8sdpu0REjhIR3PaO8wcllq0H/Xhu53GsWHKeO4F51JCJRVX3JioQIiIvO9vdF3HatiOnmFjCsIyFiCgGF8+aiomZ
oafMrIw0XDQrUc83TB7Ddpsfr6dCqiqf/UlESePGRTNw+dxcvNQ08JSQrIw03LCQT6ANN1zh/e8Qv6dCNqnqLfHYNhGR09LTBE+svgLFn386OO30uV681NSG65lcQgxXxnJzogIhIvK6SLW/vla7C9ctyMconmuYsljGQkQ0Bq8f9GPz9uNuh+EpTCxERGP0tdpd6GdjyaCYnnkvIoWqui/OsSSd+1csGH4hIkpZEzPTcbanD9vfPoVn3jqKd7/jfLdD8oSYEguAKgAsaA/zQPlCt0MgIhd99OpCfO/5JgBWB5W3XDyTrfAR+62w5SLyrtFUPY5XdWUiIretub4Ik7Ot3+d7jnfgl68fdjkibxhJGUvg8cJtIvJbEXlYRN4vIpcNsx4f7EVEKSk3Jwurrx3oTvHRut3o4cPARpRYfgbgNQACoBzAOgA1ABpEpE9EdovIRhH5B5NwCs16cWkHE4mI+Oyv7eNERPGw+rp5mDYxEwCwr+0Mft54yOWI3BdrYqlT1TtUtdT2DPtyAPfBSjSvASgGsArAelgJp0lE+gCUjCQgkxAqRKQ2bHqViLSboTLK6u0ioiKiANphdaJJRBQ3UydkouL6gd/P39y8B129g/sVG09iTSwb7SO2Z9hXW6NaqqppGEgunwOwGcCIum0xVxiBZFBkm74SQBOsLvzvAbDWTAtXrapiG9aMZP9ERKPxsasLMT0nCwBw2H8WT47zRxjHlFhU9WcxLrdXVX+mql9V1ZtVNRfW1UxMVNVvkkF92Cy/qq438zcBaETkW2z+WPdFROSUnOwM3HdjcXD8W8/uwbme8XvVkogGks1j3YCq1oVN8sFKLuFKRKTB3A6rZRkLESXKXVdeiBlTsgEAx0934Uev7Hc5IvckIrHc4+TGRKQE1tMnw5MNANSo6jJYZUB5YI00IkqQCZnp+ORN84Pj332uCZ1dvS5G5J6YEstY2qKo6snRrhshDh+Ah1S1PMq+qs1fP4CHMUTFAVNBoF5E6ltaWpwKkYjGsTuWz8Fs30QAQFtnN37w0j53A3JJrFcsXvnlX4nYr4D8iHy7DICVhEylg9KCAnZ5TURjl52Rjk+tGLhqqX6hGafO9bgYkTtiTSwlIjIlrpEMQ0SqYGqMmSrJReb1SvN3bdgqgbY2REQJ8/6SC1A4fRIA4OTZHjz2h/H3hPdYE8t0AP6wRpDvCjSCHCrpiMjDIwlIRMoAPASgSETWikiRSSoVsKoct5uhyayywZS71JkC+0qTZKpUdcwVB4iIRiIzPQ33lw10UPvYi3vR3tntYkSJF2snlACwF1Y7lUBbFXsf0Y0i0girBtgWAI223pADiSImplA+vGB+jRkiLZ9rG41Y9kJElEjvu3Q2vvP7Juw53oGOrl5UvdCMz9222O2wEibWK5Z6VZ1vGkEGWt1/DlbZi73VfaCblybTzUsbRtjynogo2aWnCT5j6/388Zf2oeV0l4sRJVasiaUm8MLW6v6rqnpvWKv7QML5GawrnNzImyMiSm23XjwTS863KtSe7enDd59rGmaN1BFry/tha4WZVveBhHOHqs6H1ZZk/JVcEdG4l5YmeNB21fKjV/fj7ZNnXYwoceLaQNK0J4la5ZeIKJWtWDIDl87xAQC6e/vx7Wf3uBtQgsS95b2q3hHvfRAReZFI6FXLk/UHcfDEGRcjSoxEdOlCRDRuXbcgH+8szAMA9PQpvrl5t8sRxR8TCxFRHIkIHrx54KrlZ42H0NzS4WJE8cfEQkQUZ1cUTce18/MBAP0KPJriVy1MLERECfAZ21XLL18/gp1HT7sYTXwN2fJeRL4H66mN8dCsqvfFadtERJ5SMjcXKxbPwOYdx6EKPFK7C9+7e5nbYcXFcF26VMJ6qFY8+OO0XSIiT3qgfCE27zgOAHjmraP48+GTWDp7mstROW/IxKKqbNxIROSQpbOn4balM/GbPx8FAHy9dhf++2PLXY7KeSxjISJKoAfKF0LEev3sjuNo2N/u
bkBxMGxiEZGp8RgS8eaIiLxm4XlT8L5LZwXHH6nd5WI08TFc4f3vABTFY8ci0qSqt8Rj20REXnb/igV46o230deveHFPK15uasNVxdPdDssxw5Wx3JyoQIiIxouigsn4QMlsPFl/CADw9dqdeLLoKkjgHlmSYxkLEZELPnnTAmSmW4lky752/GF3q8sROYeJhYjIBXPyJuHO5XOC41/73U6o6hBrJA8mFiIil3ziXQuQlWGdhl8/dBJ124+7HJEzmFiIiFwyc9oE3HXFhcHxr9fuQn9/8l+1MLEQEbnovhuLMTEzHQCw/e1TwcaTyYyJhYjIRQVTsvGxawqD44/U7UJfkl+1MLEQEblszfVFmJJttf7Yc7wDv3z9sMsRjQ0TCxGRy3yTsvDxawc6kv9G3W709PW7GNHYeC6xiIhPRCpEpDZseomIrBWRShGpibLusMsQEXnR6uvmYdrETADA/rYz+FnDIZcjGj1PJRYR8cHqqh8Y3JXMBlVdr6rrADSLSCUGi2UZIiLPmTohE2tuGDjtfevZPejq7XMxotHzVGJRVb+qrgFQb58uIhUA6myTagGUjXQZIiIv++hVhZiekwUAOOw/i41bDroc0eh4KrEMoRhAm228GYOvaGJZhojIs3KyM3DfjcXB8W89uwdnu5PvqiVZEovPoWWIiDztrisvxHlTswEALae78KNX9rsc0cglS2LxAwjvU/rEKJYJMhUE6kWkvqWlZcwBEhE5YUJmOj7xrvnB8e8+34TOrl4XIxq5ZEksbQi9reUD0DiKZYJUtVpVS1W1tKCgwKEwiYjG7o7lczDbNxEAcKKzGz94aZ+7AY1QsiSWTQgtiC8DsBEARGTlcMsQESWT7Ix03L9iQXC86vkmnDzb42JEI+O5xCIiZQAeAlBk2qQUqWozgHUiUmVqf0FVN5lVNohIyTDLEBEllfeXzEbh9EkAgFPnevHYi3tdjih2Qz5B0g2qWofQasOB6dVRls8dbhkiIqfYryTiKSM9DZ8uW4hPb9wKAPjvF/fiY1cXIs9UR/Yyz12xEBF52QPlC4NDvL330llYMGMyAKCjqxdVLzTFfZ9OYGIhIvKo9DQJSWCPv7QPx0+fczGi2DCxEBF52K0Xz8RF508FAJzr6cd3n/P+VQsTCxGRh6WlCR68eeCq5cevHMDbJ8+6GNHwmFiIiDzupsUzcNkcHwCgu68f3352j7sBDYOJhYjI40RCr1o2bjmIgyfOuBjR0JhYiIiSwLXz8/HOeXkAgN5+xaObd7scUXRMLERESUBE8KCthtjPGw+huaXDxYiiY2IhIkoSVxRNx3UL8gEA/Wo9wni0HqndFRycxsRCRJREPmO7avnVG0ew8+jpUW3n0c27g4PTmFiIiJLI5XNzsWLxDACAKuJyxTFWTCxEREnG3hr/mbeO4s1DJ12MZjAmFiKiJLN09jS8+x0zg+Nfr93pYjSDMbEQESWhT5cthIj1+vc7W9Cwv93dgGyYWIiIktDC86bgLy6dFRz30lULEwsRUZK6v2wh0tOsy5Y/7mnDy01tLkdkYWIhIkpS8/Jz8IGS2cHxr9fuhKq6GJGFiYWIKIl98qYFyEy3rlq27GvHC7tbXY6IiYWIKKnNyZuEDy6fGxz/2u/cv2phYiEiSnJ//675yMqwTudvHDqJ2m3HXI2HiYWIKMnNnDYBd195YXD867W70N/v3lULEwsRUQq478ZiTMxMBwDsOHoaT//5bddiYWIhIkoB+ZOz8bFrCoPjj9TuQp9LVy1MLEREKWLN9UWYkp0BAGhq6cQvth52JY6USywi4rO/to8TEaUy36QsrL5uXnD8G3W70dPXn/A4kiaxiEiFiGiEoSxs0fbAPADtACpdCJeIyBUfv3Yepk3MBAAcOHEGmxoOJTyGpEksxjJVFVUVALkANqlqXdgy1YFlzLDGhTiJiFwxdUIm1txQFBz/1ubd6OrtS2gMyZRYmlW10TZeAeDhCMv5ExMOEZE3fezq
QuRPzgIAHDl5Dj/908GE7j9pEov9ysSUmxSHJZqAEhFpMLfDalnGQkTjzaSsDNx7Q3Fw/Nu/34Oz3Ym7akmaxBJmA6KXndSo6jJYt8ryzLKDmDKbehGpb2lpiVOYRETuuOvKC3He1GwAQMvpLvzolf0J23fSJRZTWH9CVZsjzVfVavPXD+tWWUm05VS1VFVLCwoK4hUuEZErJmSm4xM3LQiOf/f5JnR09SZk30mXWABUAVgX47J+AJFulxERpbw7S+dgtm8iAOBEZzd+8Me9CdlvUiUWEVkLoNFcjdinr7TNtytH7EmIiCilZGWk4f4VA1ct1S804+TZnrjvN2kSiymEr0TkmmAbRKQEQJ0psK80SaYq2i0zIqLx4P0ls1E4fRIA4NS5Xjz2h/ifEjPivgeHmKsUiTIv1zZanpCAiIiSQEZ6Gj5dthCf3rgVAPDYi3vxsWvmDb3SGCXNFQsREY3Oey+dhQUzJgMAOrv7UPVCU1z3x8RCRJTi0tMEnylfGBx//KV9cd0fEwsR0Thwy8UzcdH5UwEA53ri2zElEwsR0TiQliZ48OaFwy/oxL4SshciInLdTYtn4LI5vkHTnX4gGBMLEdE4ISJ4oGzBoOl3P/aqo8mFiYWIaBzp6etHWljDja0H/Xhu53HH9sHEQkQ0jmx7+zQ07OLkbHcfth055dg+mFiIiMaRi2dNxcSs9JBpE7PScdGsqY7tg4mFiGgcuXHR4AL8y+b4cOOiGY7tg4mFiGgcSU8TPLH6ipBpT6y+AunhBS9jwMRCRDTOhCcRJ5MKwMRCREQOY2IhIiJHMbEQEZGjmFiIiMhRTCxEROQoJhYiInIUEwsRETmKiYWIiBzFxEJERI5iYiEiIkcxsRARkaOSKrGIiM/+2j5ORETekOF2ACPULhLSWVo1gDX2CSJSAqAMwHQARaq6KnHhERFRUl2xAKhWVbENayIss0FV16vqOgDNIlKZ6CCJiMazZEss/qFmikgFgDrbpFpYVy9ERJQgyZZYSkSkQURURGojlLEUA2izjTcDKEpYdERElHSJpUZVlwHIBZAHYEPYfF+sGxKRChGpF5H6lpYWB0MkIhrfkiqxqGq1+esH8DCAkrBF/LAK7e1ORNuWqpaqamlBQYHDkRIRjV9JlVjC+AE0hk1rQ+itL1+EZYiIKI6SJrGIyNqwSeUA1pl5K820TQgtrC8DsDH+0RERUUAytWOpE5FaWFcgbQCqVLXZzNsgIs2q2igi60SkCkADAKjqJpfiJSIal5ImsahqI6yrlEjzcm2vqxMWFBERDZI0t8KIiCg5MLEQEZGjmFiIiMhRTCxEROQoJhYiInIUEwsRETmKiYWIiBzFxEJERI5KmgaSRETknPtXLIjbtplYiIjGoQfKF8Zt27wVRkREjmJiISIiRzGxEBGRo5hYiIjIUUwsRETkKCYWIiJyFBMLERE5iomFiIgcxcRCRESOYmIhIiJHiaq6HYPrRKQFwP5Rrp4PoNXBcOKFcTqLcTovWWJlnAMuVNWC8IlMLGMkIvWqWup2HMNhnM5inM5LllgZ5/B4K4yIiBzFxEJERI5iYhm7arcDiBHjdBbjdF6yxMo4h8EyFiIichSvWIiIyFFMLOOYiPjsr+3jbvBaPCOV7PF7jdeOp9fiGY1EvQcmlihEpEpE2s1QGWWZEhFZKyKVIlIz3HQX44y2TLuIqIgogHYAEddPVJzR4vHS8RSRikCMYUNZtPjjFGeliDQlwfczlji98P0cNs5o8STyeMYSqye+o6rKIWwAsBLAWgA+81oBrIywXIPtdSWAyqGmuxHnUMsAqPLY8YwYj8eOZwWAEtu4D0BNgo9nCYAK87rIxFnmwe/nsHF65PsZ6/F09fs5gmPq+neUVyyR+VV1var6VXUTgEZYH2KQiFQAqLNNqgVQFm26W3EOs4w/TnGNJs6I8XjweDaraqNtvALAw4H14xRXOL+qVgOAqjbDirPZvoCHvp9DxgnvfD+HizNi
PAk+nkBssbr+HWViiUBV68Im+WB9gHbFANps482w/iGiTXdcLHEOs0yJiDSYS+PaeN1vjfF4RovHs8czEJ/tnzhRxzN4IjEntir7NMML389h4/TI9zOW4xktnoQdz1hj9cJ3lIllGCJSAusXQKR/gEiiTY+rIeIcapkaVV0GIBdAHoANLscZKR5fvGOKJJbjCSs++z3qhB5PEakCsAxApbl/bueLslq06XEzTJz25Vz9fsYQp5e+nzEdU7j1HU3E/bZkHWC7NxlhXsi9VFi/UpqiTXcrzliXgXV/2/U4w+Px8PEswxD3qxNxPG37qgzfl5e+n0PFGesxd/t4RovHzeMZwzF17TvKK5ahVQK4J8q8NoRe8vpgXcJHmx5PQ8UZ6zJ+eCPOAD+8fTyrAKwbYr4fcYozwu2LLRh879z172eMcQa49v0cYZwBfrjw/RxhrK59RxOSVZNxMB9KEawvig9AUSDLm79FANpty6+F9Qsg4nS34oy2DIC1YdupDKzr0vGMGI9Hj+dahP26TuTxNNv2hcVc5rXvZyxxeuH7GePx9Mr3M9Zj6up3lF26RGDuX1aET1dVEZF2ACtUtdEUni0D0ADrw15v1o843Y04AayJtIyJrxIDv7o2aeQCy0TFiWjxeOx4NsOq/79MbTVvTPlAoo5nGaxfoo2wbsfUB2Lx2Pdz2Djhje9nLHEiWjyJOp4jiNX97ygTCxEROYllLERE5CgmFiIichQTCxEROYqJhYiIHMXEQkREjmJiISIiRzGx0KiJSJF5DkXcOt1zwzB9L6WUVP0Mh2Lac1AcMbHQiInIShFpwEBfSSlzUjIPakr5E4+XPkPzoKxIjSTjpWY8/XhwAxMLjZiqblKrh9Shev5NOqblfVxbTnuFxz7DzQCqRGRlgvZXDiu5pPwPCLdkuB0AJTW/2wE4xZzUKmB1Jz6e+N0OAMCTsHrijXfnogCsZ5qIyMOwEtp4+7wTglcsRJYNAKpV1e92IOONqq5R1eJ49VsVRTUAn4isTeA+xw0mFhr3zMnFh9AHIlEKMz8gqmE9KMvnbjSph4mF4sYUytaYR6A2mUeiDvmPbNapMuu0m3VWmnm+OJ0E1sB6cmHEX8yBgm7b+2gy5THhy/nM+2syj34dtJxZpsy2XJF5z4HHxbbbf0WbZQPbCx6LaNsz04pMrIHt1Yy21pfZdq3Zd5N5HVPZRCzHzbz3tWa5srB5Fbb1AkODbWiKsM5I4q0xfxNZcWB8iNdzAzik/gDrH1NhngcRNq8SVo2jItu0MrN8O4CSKOsEuh4HrKuIBrOOfXDsGRKwakMphn5SaDAmM22l9a8TslyJeV8V5n1WmPevsD3FD9ZzMpps7yUwXmkG+/QKs++1sJ67EZhXFra9Gtu8ErNOlS32wLxBzwkZ5jOsjfDea6Jta6THzRwne+xlEbZRGWHbgWMR/tTMEcVrvl8KoNbt/6VUG1wPgEPyDtFOSrYEEumfeWWUk8LKKNsqibS8g++hInAijzI/4rzwk5GJvzJsWuDEpRHWDySXSNsJnGirwuatjXYitK0zKEHa1lP7SXeYz7A2yvKB99Q+zHGN6bgNE0MVwn5E2L5bIT8wRhsvrB8Dgz4fDmMbXA+AQ/IOQ5wQmob6Z7WdVCsibCvSlcywJ7IxvIfK8Fhs8wInpZBfwYF5MW4/2gmvYZj3G2mdwNXVoCQbbZ0I+6sImz7oM7TtpyHKttrDT+xjOW5DfI/Ck4rPtu+VTsRr+y7G9HlyiG1gdWNylCkDKcLQ1VjrYF0pLIswrwi2aqe2MpV4tbcIbP9E+AxV9YtII8xtLvO6Dtav7qjxmPKMMljvr9RMzkPoMQnszz4Ntmk+DauhplY12cC2YlrHZqN5H5GOebhAmUSJKbcJ3+YJWE8pjGi0xy3CdsL3sRnW57VeVTc5FG9g2fDPh8aAiYWcFkshcZP5az9BVsHcThKRRttJpRJAo6qucjDGmKnqMhl4ZHGJGdaKSDOAVRr6
6Ne1sCoC1MH6Ff6kmVWBgQTmFr/52zTUQkbgc2lUqxHliI3kuMXCbKvExLTO6XjJWawVRk7zm7++GJbdEnhhfsmuMusHav/Uwrq9EfVkYat5FLHmj5lfZWpGDap5hMhJLoRa7SwE1q/9dSbGIlhtXwL7qYGVBFeZ5euGuHpwQ7H5u2nIpSyBpO4byw5jOW6xsDVe9WPg+fN2Y4k3sM6gK1YaPSYWcpT99sUQ/TEFTnIht0XM7Y1mAOWqukxVy1W1OtIGTPXaWljdc0RNKgDuNCe4VbBObuHVT4c8Kdn7sFLVRlVdr6q59nXM7bqVsH4xJ6T1+CisBFAX4fZSJPXmb9Fouz2J5bjFuJ0iDFQLXhWerM38scSbZ2L0D7McjQATC8VDoK+taA0O74DVyj3kJCxWp4glMd6H36iq5bBuoUVTaZ9vttsM63ZVQCCG5eErm4RRHmXbfgwkxsDtP1/Y+oFf2pFEvUIai0jtVcxJPg+h7zsqc5INvLeaSG2HhmoXM4LjNiSznVozuj78e2GSyJrRxmuW8yFBXcmMJyxjobHwhf0FAKjqOnO1UiIiNfbyEXPLqF5VQ05y5p+8xLxuwMCv0AA/rGTSaPYx5MnAnHTsv2YD6mAVrAdibTb3/SP90i0CsFJEKuxXTiJSCatBZeA9BK4Cisz724iBE2vI/iLwRZsmItEK4yOtE9AkIuswcLJcZd7HvGG2Fb7NVbAKy0sANIjIJlgneR+AO2HdOoqWqGI9bsPFUGO2VRehXAUmjsCtzNHEG6hYsTHK+6DRcrtaGofkG2DdVrE3vGuH+ScOW26tWa4d1kmiBkM0rMNAm5KhhmjtXEoixKgR9lEZPh0D7TzCq7cGbsO0m6HBvM9IVZPLbMckcOUFDFQ3bjL7CW8g2Q7T9gQDjUrVtk5N2H7s8ypt0wNVamtsr5uiHe8RfIb2mALLDNc4MqbjNlQMCG1702SWCwxNtvcYqVFlTPFioKq5Yw1uOViDmANM5CpzhVEJ69f/OjW/rm3Vlythnbzr1LoFZl+vAcAyDa2hVQGrgaGE7WctrBOy2Kb5YJ2E1mvkX8aeJyLtsJKCDLswAQgeszp1qcZhKmMZC7nOnNgbAOSpVdDuD8xTVb9JGCP95/fFuqDZ3zqwz6hxw5R/+WB97uQwJhbygkDB6lCFuoHC7tohlrHzDzFvUM0otR7u1SzsRn28qIRV8J/IrvrHDSYWcp3tFtZQVUXXwEoIEasfR9AMRKwNNB3RW2GvALBmtD0Bu8zndgDJwjS2rNMoVdlp7JhYyCvWACiLdMVgahOthNW+xR/j9gK1wcJrZJUhShVls+1yxFgtl5KPue3q18E108hBLLwnzzAF8Q/BujV2AgO3s2qj/bo01ZprYSWd8HYOVbBqDRXbtr9BU6jbj8B7xMDtxGZYv8Z54iTXMLFQUjK3q9Zh4KTaDKu7koftVzXmaqcEA7e/1o3gqoeIRoGJhYiIHMUyFiIichQTCxEROYqJhYiIHMXEQkREjmJiISIiRzGxEBGRo/4/v5KB37u42pIAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plot the mean +/- std of ||H||_F across repeats (dim 1 of the tensor built by\n",
    "# the sweep cell) against log10 of the training-set size, then save as PDF.\n",
    "data = torch.load('fcn-sample.pt')\n",
    "fig, ax = plt.subplots(figsize=(6, 5))\n",
    "ax.errorbar(data['sample_size'].log10(),\n",
    "            data['hessian_fro'].mean(dim=1),\n",
    "            yerr=data['hessian_fro'].std(dim=1),\n",
    "            linestyle='-', marker='o', linewidth=3, markersize=5)\n",
    "\n",
    "ax.set_xlabel(r' $\\log_{10}$(sample size)', fontsize=23)\n",
    "ax.set_ylabel(r'$\\|H\\|_F$', fontsize=23)\n",
    "ax.set_title(r'FCN', fontsize=23)\n",
    "ax.tick_params(axis='both', labelsize=13)\n",
    "\n",
    "fig.savefig('../figs/hessian_fro_fcn_samplesize.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 167,
   "id": "3a42f548-e904-4333-8194-8120e2083671",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZMAAAFiCAYAAAAz0jXdAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAv30lEQVR4nO3deXQb53ku8OflIlGUSEGUqH0FKW/yIlOU7XjVQjlNmjROTMlpE2epbdJxHLtpUim6vefe9rS3Lp3m9thxYlOK45s4rmORidNsdUNS3hdZJGXH8UqRWqjFFkURIrVzee8f8wEcgAAJcgAMAD6/c+YIGAyAD0NoHnzLfCOqCiIiIicy3C4AERGlPoYJERE5xjAhIiLHGCZEROQYw4SIiBxjmBARkWMME6IoiEi5iNSJSFeERW1LdYTX8IpItYg0iUir7blNZn25iHjMthW294v4mrbXLgspX52IVMVhVxCFxTAhioKq1qrqOgD3AfAAaFTVabZFABQBqAXQan+uCZEmAP7lDlUtUtVpAJYA2ASgAEANgA3m/baY96s3L1MhIhXDlK/ebL8JQJuqrlPVTTHbAUQjYJgQjU5bpAdUtQ1AHQCff52IlMMKl2MAlpiQaLY9x2eCYD2sIGoMedljtvesFpGSEcp3zCxECcUwIYqterPAHPhrYIXLelX1jfBce3D4+QBUwwoaAGjwN4UNY6T3IYo5hglRDKlqm6mhAFaQAFazli+K51ZG2M5nai7NsJrYGmJQVKKYYpgQxYiIlIjIRv9tAF7zUH3kZ43KWli1jhIRqRlhW6KEYpgQxU6Z7Xap+dcXTa0kGuZ11pq75f7gIkoGDBOisSkzQ3YDQ4MB2IfiFpl/Y9oZbjrv15u7VaaDn8h1DBOisalXVQkZGlxpe7zT/OsN81xHVLUWwP3mbo2IxPw9iEaLYUIUO/ZhvYHhv1EM5x01cw6Jvy+mLooRXkRxxTAhihFVbVbV+83tegwO0Y3XmejrYQ0l9mJw5BiRKxgmRPFzh/m3bLiz18fKdMivs73HsFOuEMUTw4QoTkL6NqqjmSsrTP+HxyyR3qMNg4FSgeB+G6KEYZgQjY7/YF8Qzcamb2MdrOaojWbkV7WZyLHELOVmXReGBkcBgOkjvEc9rDm5gODhyUQJI6rqdhmIkp4ZgluJ4IN1M6wJH6OqDZimrvWwzkHxmNVtZqlR1S0h21YCKIHV99IIoNrUdiK9fg2AclgjzdZF2o4oHhgmRGnENJN57JNJEiUCw4SIiBxjnwkRETnGMCEiIscYJkRE5BjDhIiIHMtyuwCJMmPGDF28eLHbxSAiShlNTU1HVbUwmm3HTZgsXrwYjY2hl9cmIqJIRGRftNuymYuIiBxjmBARkWMMEyIicoxhQkREjjFMiIjIMYYJERE5xjAhIiLHGCZERORYUp20KCIeABsArI90cR9zkaKV5q4XwB3mWthEROSSpKmZmCDxXyM79DrY/m28ADar6iZzOdSdADYnpoRERBRJ0tRMTO2iUkRKhtnMi+BrZPtgXfKUiIhclDRhEg1VrRcRiEgrrFqMx37d7Hj497oPAre/ue68eL4VEVHKSqkwMSoBbIIVJsdEpFZV41Y7eaChJXCbYUJEFF7S9JlEwzSBVarqOlWdBqAZQPUw21eISKOINHZ0dCSsnERE401KhQmAWwA8Zbu/CRE66wFAVbeoaqmqlhYWRjUlPxERjUFKhIkZDgwArQDsQ4a9AGoTXyIiIrJLqjARkTJYQ329IrLRDAUGgK0iUuLvbDePVQAoMUOEiYjIRUnVAa+q9QDqw6yfZrtdmdBCERHRiJKqZkJERKmJYUJERI4xTIiIyDGGCREROcYwISIixxgmRETkGMOEiIgcY5gQEZFjDBMiInKMYUJERI4xTIiIyDGGCREROcYwISIixxgmRETkGMOEiIgcY5gQEZFjDBMiInKMYUJERI4xTIiIyDGGCREROcYw
ISIixxgmRETkGMOEiIgcY5gQEZFjDBMiInKMYUJERI4xTIiIyDGGCREROcYwISIixxgmRETkGMOEiIgcY5gQEZFjDBMiInKMYUJERI4lVZiIiEdEKkSkLoptK0Rko4h4ElA0IiIaRtKEiQmFKnPXO8K2NQAaVfV+VfXFuWhERDSCLLcL4GdCoVJESobbzgTJfaranJCCERHRiJKmZhINESkDUALAKyJ1IlLNZi4iIvelVJgAWAegAIAPwHpYzWE1kTY2/SqNItLY0dGRmBISEY1DqRYmHgD1qlpvmsWqAJRF2lhVt6hqqaqWFhYWJqiIRETjT6qFSWvI/TZYtRQiInJRSoSJiJSbm7Ww+kz8ygBsSXyJiIjILqnCxHSwb4bVwb5RRPxDhLeKSImqtgHYZDreKwB4VHWTawUmIiIASTQ0GABUtR5AfZj102y3a2HVUIiIKEkkVc2EiIhSE8OEiIgcY5gQEZFjDBMiInKMYUJERI4xTIiIyDGGCREROcYwISIixxgmRETkGMOEiIgcY5gQEZFjDBMiInKMYUJERI4xTIiIyDGGCREROcYwISIixxgmRETkGMOEiIgcY5gQEZFjDBMiInKMYUJERI4xTIiIyDGGCREROcYwISIixxgmRETkGMOEiIgcY5gQEZFjDBMiInKMYUJERI4xTIiIyDGGCREROcYwISIixxgmRETkWFKFiYh4RKRCROqi2LZcRGoSUS4iIhpe0oSJiHgAVJm73ii23QzAE9dCERFRVJImTFTVp6qVABqj2LwCQDUAX1wLRUREUUmaMImWiJQAaHa7HERENCjlwgRAmarWu10IIiIalFJhIiLlAGpHsX2FiDSKSGNHR0ccS0ZENL5luV2AUaoE4BURACgA4BGRVlUtCrexqm4BsAUASktLNWGlJCIaZ1KiZmJqJFDVdapaZMJjE4DaSEFCRESJk1RhIiJlsIb8ekVko4j4hwhvNR3v/u0qYNVSysxtIiJyUVI1c5mO9SGd66o6LeR+oPmKiIjcl1Q1EyIiSk0MEyIicoxhQkREjjFMiIjIMYYJERE5xjAhIiLHGCZEROQYw4SIiBxjmBARkWMMEyIicmzM06mIyGIAJbAusTvdrO6EdfXDNlXd7rRwRESUGkYVJiZAvgNgPazrrx8HcAzBl8/1wJqoUWFdEfFfVPVp50UlIqJkFVUzl4jki8g2AE1mVQWAAlUtUNViVS21LcWqmgGgFNZkjH8vIi0i8tn4fIT46R/QYe8TEZFlxDARkcsBbAfwB1Wdrqp3quovVPX4cM9T1V2qulVVSwF8HMBfisjDsSl2/PUPKG59dEfQulsf3cFAISIKY9gwEZG1sJq11qrqj8b6JqrapqobADSIyFNjfZ1E+t0fD+PV1s6gdW+0+/Dc+0dcKhERUfIaqWbiVdVbRqqFREtVawF8R0SWx+L14mlv50mE1kFOn+vHO4e6XSkPEVEyG7YDXlW3xvoNVXVPrF8zHpbNzcfErAyc7RsIrMvJzsBFc/NdLBURUXLieSYRrDp/JlYsCrrAI/InZWPV+TNdKhERUfIaVZiIyBoR+baI3Gf+/ZwZLpx2MjMEj992ZdC6zhNn8VH3GZdKRESUvKI6z8R0xG+DdQ6J2B5S87gPwFMAqlR1X2yL6J7MDAm63zcAPPxcK/7ppotdKhERUXKKtmZSDaABwI0AimzLLbDCpQHA5wG0icjPRSQvDmVNCk/tbMfh46fdLgYRUVKJuplLVTeoaoOq7rEttdZDukFVCwCshHVW/F4RWRSvQrvpXP8AHnmu1e1iEBEllWjDpH6YcAgMG1bVZlWthBUq3xWRtBz69OTOdvadEBHZRBUmqnongK0isjrcw2G295+kGPOhxW66bIEHAHCubwAPs3ZCRBQwmtFc/kDZKSLfsvWLeIZ5TtmYS5aE/mbt0sDtJ1/fjyOsnRARARhdn0kbrMkb9wL4LgCfiLQAgAmXz5mhw58zw4Z3wppROG2sOr8Q
l86fCgA42zeAR55vc7lERETJYVTnmaiqT1XXAygG8CNYI7kEVrjUAKgDUAvgfgDTAGyIaWldJiK411Y7eWLHPhzpYe2EiGhMZ8CbPpFKVS2GFRorYAXHBgDrABSZqeh3xa6oyWHNBTNx8TxrXMHZvgFsfYG1EyIix9OpqOpxM938L8zSkCrzb42FiOCeNYO1k8df24ejJ866WCIiIveNNAX9mli/obnQ1vJYv24irbtoFi6aY9VOzvSydkJENFLNpCiWF7QyIVIDIKWPviKCe2x9Jz99dR86WTshonFs2DAxU9DvMsOBL3PyRiLyd7DOO9mgqil/UZAbL5qFC2Zbo6NP9/Zj64tp27JHRDSiEftMVHULgEoAj5pQuT3amYLNUOFHRKTTeildGasLbbktIyN4ZNdPX92LYyfPuVgiIiL3RDVrsKo2AygVkXIAFQC2iEgXrOaqNgyeT1IA6yRGr1l8ALYAKE3HTvmPL5uN82fl4f2PenDqXD8efakNf/fxC9wuFhFRwo32PJNaVb0R1nDgSlizBXcBmA5rFuFpAPZgMECmq+rmaINERDwiUiEidcNsUy0iXWapGk35Yy0jQ/CNtcWB+z95ZR98p1g7IaLxJ6qaSSjTVFVrlpgQEQ+AKgBNsGo14bYpB9AKYAmsqVpqRGSnmb3YFZ+8eA6WzmxBy5ETOHG2D4++tAffuvF8t4pDROSKpLlsrzm7vhJA4zCb+VT1frNtLYBmRAieRLFqJ4N9J//v5b04fqrXxRIRESVe0oRJNFS1PmSVB1aguOrPL5mDosLJAICes3149OW06x4iIhpWSoWJnYiUAGgLEzAJl5kRfN7JYy/vwfHTrJ0Q0fgxpjAZ6Qx2Ebk82uHDY3x/D4DNqrpuhO0qRKRRRBo7OjriVRwAwKcunQuvv3Zypg+PsXZCROPIWGsmtSLSb847+RdzPkngqopmgsd4jrSqAnDHSBup6hZVLVXV0sLCwjgWx6qdfGPN4MiuH7+0B91nWDshovFhrGGyHsC/wZp+/juwpp7vEpEWEXlKRL4NwCsi/xKLQppRXP7b1TBBZYYSu9oBb/fpS+diyQyrdtJ9pg8/eXmvuwUiIkqQsU5Bv0tVN6lqKaxzS74GYJdZVsC6nskKWOeiRE1EygBshhVEG21BsVVESkyQVMAaHtxllqS5fm5WZga+vnqwdvKjl/agh7UTIhoHYjUF/RYTLDth1VqmmX9HddleVa1X1fWqKmYIcJtZP01Vm801VCR0cfoZYumm5XOxaHouAOD46V789NV9LpeIiCj+YjqaS1W/C2tKlbXm2iZpd3GskYTWTra+2IYTZ/tcLBERUfyNeTRXpNFaqtoA4LiI3O6kYKnss5fPw4KCSQAA36le/PTVve4WiIgozsY8mgtAq21E18P22YRNoKwTkdWxKmgqyc7MwN322skLbTjJ2gkRpbGxhskWWBM6fhdWJ/gtZl0gYGD1l1THpJQp6HMl8zF/mlU76TrVi8dfY98JEaWvsU70eL+IeFX1O/51IjIVQCmsUVxlsIYNF5srNfrMZp3mdr2q7h17sZNftuk72fzLtwBYtZMvfWwRcieMaZcTESU1Jx3wm+zXiDejuhrMKKwbVbUA1qiuWljXOykCcCesGkyNk0KniptL5mOex6qddJ48hyde2+9yiYiI4mPMYWLCY3sU2zSo6ndVdYOqFqtqBqxzRdLehKwMfG1VUeB+9QutOH2u38USERHFhysTPY6nIcPrS+djztQcAMDRE+fwxA72nRBR+knZWYNTxcSsTNwVVDtpw5le1k6IKL0wTBJgw8oFmJ1v1U46es7iP3aw74SI0gvDJAEmZmUG9Z088nwraydElFYYJglyy8oFmJU/EQBwpOcsfv46aydElD4YJgmSk52JO28YrJ08zNoJEaURhkkC/eUVC1GYZ9VOPuo+i5rGdpdLREQUGwyTBMrJzkTl9YPX8vrhc60428faCRGlPoZJgn3hykWYMcWq
nRw+fgY1jQdcLhERkXMMkwSbNCG4dvLwc6041zfgYomIiJxjmLjgC1ctxPTJEwAAB32nUdvE2gkRpTaGiQtyJ2ShwlY7+cGzu1k7IaKUxjBxyRevWoQCW+3kl82snRBR6mKYuGTyxCzcft2SwP0fPLcbvf2snRBRamKYuOhLH1sMT242AKD92Gk8veugyyUiIhobhomLpkzMwh3XBfed9LF2QkQpiGHisi99bBGmTrJqJ/s6T+FXbxxyuURERKPHMHFZXk42br92sO/koe0tMaud/HvdB4GFiCiestwuAAFfvmYxtr7Yhu4zfdjbeQq/fvMQPlcy3/HrPtDQErj9zXXnOX49olRh/wHF735isGaSBPJzsnHbtYN9Jw9t343+AXWxRESp7YGGlsBCicEwSRJfuWYx8nKsimLb0ZP47R/Zd0JEqYNhkiSmTsrGV68Z7Dt5sKGFtRMiShkMkyRy2zVLkDfRqp20dpzE79467HKJiIiiwzBJIlNzs/GVaxYH7n+/oQUDrJ2kDY6uo3TGMEkyt127BFNM7aTlyAn8/k+snaQLdgpTOmOYJBlP7gR8+epFgfsPsnZCMcBaEcUbwyQJ3X6tF5MnZAIAPvjoBJ55+0OXS0SpjrUiirekChMR8YhIhYjUDbNNiYhsFJEqEalJZPkSZdrkCfjS1YsD91k7IaJklzRhIiIeAFXmrneYTbeq6v2quglAm4hUDbNtyrrjOi9yTe3kvQ978Id3PnK5REREkSVNmKiqT1UrATRG2kZEKgDU21bVASiLd9ncUDB5Am69KrjvRJW1k1QVes4QzyGidJM0YRKlIgCdtvttGL4Wk9LuuN6LSdlW7eSdw92oY+0kJfUPKG59dEfQulsf3cFAobSSahM9ehAcJsMyNZkKAFi4cGGcihQ/M6ZMxBevWoitL+4BYHWirrtoFkTE5ZJRJGd6+9HWcRK7O05g95ETaO04gTf3+3DAdzpou9f3HMP//vWfcHPJfFwwOx+TTJMmUapKtTDxAZgesu5YpI1VdQuALQBQWlqakj8DK64vwuOv7cOZ3gG8fagbDe8eQdlFs9wu1rjnO3UOu4+cCCytHSewu+MEDnSdRjStkX0Dip+9th8/e20/MgTwFk7Bsrn5ZpmKZXPz4cmdEP8PkobCNSlmZvAHWLylWph0Alhpu+8B0OxOURKjMG8ivnDlIjz60mDtZO2FM1k7SYCBAcXh7jPBgXHkBFqPnEDnyXOxex9F4D3+03ZxtHmeSbgoJGDmTM0Zt3/7vv4BnDjbh+7Tfeg+04vu073oPhN823fqHP4r5ETfWx/dgcdvu5KBEmcpESYiUq6qtQBqAWy2PVQG4Cl3SpU4lTd48bPX9uFs3wDeOngcz75/BGsuSP/aSaKuSXGubwD7Ok8OqWW0HjmJ0739o3qtDAEWFuSieOYUFBVOQdHMKVgyYzK++9/v4/U9g5XoRdNzcem8qXjncDfajp4MW5s56DuNg77TQX1l03KzA8FykQmZJTMmp8SBsrd/AN2ne9ETCAB7EJj1/oAw67pP96HnjLXuxNm+Mb3vzr3H8Pu3DuPTl82N8Sciu6QKExEpA1AJwCsiGwHUqmobgK0i0qaqzSKySUSqATQBgAmZtDYzLwd/deVCPPbyXgDAA/UtWH1++tdOYn1xr54zvWjtODmklrHv2KlRd4bnZGfAO2NKIDSKZ1rLoum5yMke2v/x5B1Xoeh//D5wf/u3VgUC4NS5Prx7uAfvHDqOtw914+1D3Xj/wx6cC3PFza5TvXhp91G8tPtoYN2k7ExcOCcvEC7L5ubjvFl5gXLEqtnnbF//kFpBT5hQsAeAfd1ogzlWevsV3655E7uPnMBfX7MEU3OzXSlHukuqMFHVegQP/fWvn2a7vSWhhUoSd95QhCd27Me5vgG8eeA4nv+gA6vOn+l2sZKOqqKj56xVy+iwwsLfGf5R99lRv9603OxAUPhrGsWFUzDPMwkZozgghx687fdzJ2Rh
xaJpWLEo8DVHb/8Adh85YcLlON451I13DnWjJ8yv89O9/Wje70Pzfl9gXVaGoHjmFFw4Jw+7bOsBoPzhV7D5kxfi5LngmkBPSJNRaI3hbF9sLic9ViJA3sQs5E/KRn5ONvInZSE/Jxt5ttsfdp/B080HhwTx2b4BPNDQgkdf2oMvX70It13rRcFk9knFUlKFCUU2Kz8Hf7lyAX7y6j4A1q/2G84rTPvaSSR9/QNo7zo9tD+j4wR6zoy+OWSeZ9KQWkbxzCmuHXCyMzNw4Zx8XDgnH+UrrEs4qyraj53G24EajPXvkZ6hIdk3oHjvwx6892HPkMd2tfuwofrVuH+GUJkZgvycrKCDvz0U8idlIy9n8HZ+jgkOc3vyhKwRA7x/QNF+7BReaR0c9JmTnYEzvVa4nDjbhx8824rHXt6LW69ahNuv86Iwb2JcP/d4wTBJIXeuKsKTr7fjXP8Adu334cWWo7j+vEK3ixUXoU0zTzcfwJ6jg0Nu9x49FbYZaDjZmYIlMyYHhUZR4RR4Cycjd0Ly/1cQESycnouF03PxiUvmBNZ39JwNBMs7JmT2dp6K+ftnZ8qQA33Yg3+OWR9Sg8idkBn3Hz+ZGYLHb7syqEnxzf91I555+0N8f/tu7D5yAgBw6lw/ql9ow09e3Yu/umIRKm/wYlZ+TlzLlu6S/38QBcyZOgm3rFyAx18brJ1ct3RG2tVOjp/uxc0/fDlo3Te3vRn18/MmZqEopJZRVDgZCwtykZWZaufpjqwwbyJWnT8zqNmz50wv3j3cg6d3HcC2xgNB4SwAFs+YjAUFuSOGwtRJpiaRk42c7IyU+K6FNilOzM7EZ5bPw6cunYv/+tNhPLR9d6DGdqZ3AD9+eQ9+tmMfPr9yAe68oQhzPZPcKHbKY5ikmK+tKsLPd+5Hb7+iaV8XXt7diWuXznC7WI6pKhr3dWHbznb8+s1DUbXPz8qfGNw0Zfo0ZuZNTImDXjzl5WTjiiUFWLFoGvZ1Bjf7fKxo+rgcKpuZIfjUpXPxyYvnoO7dj/BgQwvePtQNwBrR99NX9+HJ1/ejfMUC3LWqCAsKcl0usXOJGhEJMExSzlzPJGwoXYAnduwHADzQ8AGuKZ6esgfPI91n8Ivmg6hpbEfb0ZPDbrtycQFuWbkAxTOtpqn8nNQalXPv2qUJf89wzT7jMUjsMjIEH182GzdeNAvPvn8EDzTsxpvtPgDWyK8nX9+PbY3t+Nzl8/D11cVYPGOyuwV2INYjIofDMElBd60uxrbGdvT2K3bu7cKrrZ24ujh1aie9/QPY/t4RbNvZjuc+6Ag7LFcA2NfmTsjEnTd4sfbC1D2/Jt7/mSMZbiTZeCYiWHPBLKw+fyZebDmKBxta0LivC4DVZ1fTdAC/aD6Azyy3QqV45hSXS5zcGCYpaJ5nEspXLMCTr/trJy0pESYtH/VgW2M7nt51EEdPDD2DPG9iFj69fC5uLpmP7/3h/aCmmeULPBwKTXEhIrj+vEJct3QGXm3rxIMNLXitzTrBdECBp3cdxK/eOIg/v2QOvrFmKc6fnedyiZMTwyRF3bWqCDWN7egbUOzYcwyvtXXiKm/otGXu6znTi9+8eRjbGtvxhmlKCHWVtwAbShfgExfPCUx4yKYZSjQRwdVFM3B10Qy8vucYvr+9BS+2WCeHqgK//eNh/PaPh/Fny2bj7jXFuHjeVJdLnFwYJilqQUEubi6Zj6ca2wFYZ8VfVZEcYaJqBdy2xnb8/q3DgTH+drPzc1C+Yj7Wl87HoulD26TZNENuumJJAR6/7Uo07+/CQ9t3Y/t7RwKPPfP2h3jm7Q+x9oKZ+MbapVi+wONeQZMIwySFfX11MWqbrWGfr7Z14vU9x3DFkgLXynP4+Gn8svkgtjW2Y1+Y8xyyMwU3XjQb60vn47qlhQwISnolC6fhx19ZibcOHMf3t7cEXfG04b0jaHjvCK4/rxD3ri3GikXu/d9LBgyTFLZwei4+d/k81DQdAGCN7Hri
9qsSWoazff1oePcItjW244UPOhBuiqsLZudhQ+kC3HT5vFGdUe7G6CeicC6ZPxVbvlSKdw9346Htu/H7Px0OTM75wgcdeOGDDlxdNB33rF2alM3NicAwSXF3rynGL3cdRP+A4uXdnWjcewyli+P/C+m9D7uxbecBPL3rALpO9Q55PC8nCzctn4cNpQtw8bz8MQ1ddmv0E1EkF87Jxw++UIKWj3rw0LO78Zs3DwV+QL3S2olXWjtxxeIC3LN2aUoP2R8LhkmKWzR9Mm5aPg+/aPbXTlrw+G1XxuW9jp/uxa/fPISaxnb88cDxsNtcUzwdG0oX4OPLZoedPZcoHSydlYcHPn857l27FD94thW/euNgYIj763uP4YuP7sDlCz24Z+1SrBonc+gxTNLA3WuK8fSuAxhQ4MWWo2ja1xU0A60TAwOK19o68VRjO57504dhz0y3hirPR/mK+Wlx1jBRtLyFU/C9DZfh3rVL8cPndqO26QD6TKjs2u/DVx/biUvnT8Xdq4vT/pLbDJM0sGTGZHxm+Tw8vesgAODBhhb85K+vcPSaB32nUdt4ADVN7TjQdXrI4xOyMvDxZbOxoXQ+rimaMarp2IniLdH9bQun5+Jfb74Ud68pxiPPt2LbzgOBiUj/eOA4Kh5vwoVz8vGNNcX4s2Wz0/L/C8MkTdy9phj/+cZBDCjw/AcdEc/pGM6Z3n7UvfMRtjW246XdR8Ne/W/Z3HzcsnIB/uKyubxGeQoZb4MZ3Opvmz8tF/980yW4e/VSPPJ8K558fX+gNv/u4W7c9UQzzps1BV9fXYxPXTo3rUY0MkzSRFHhFHz6srmBa4g/UP/BCM8Y9KeDx1HT2I5fvXEIx08P7UyfOikbn718HspXzOeJWimKgxkSa/bUHPzDXyzDXauLsPWFNvzstf2BK01+8NEJ3PvzN/BAQwu+vqoYn1k+Ny1ms2aYpJFvrCnGr988BFXg2fc7ht3Wd+oc/vONQ9jW2B6YOdVOBLi2eAY2lC7AuotmsTOdaAxm5uXg7//8Itx5QxF+9NIe/PSVvTh5zgqVto6T+FbNm1aorC7CZy+fjwlZqRsqDJM0UjwzD5+6dC5+8+ahsI9bw4ePYltjO/7w9kdhLy41f5o1K/HNK+ZjHq/rQBQT06dMxKY/uwCV13vx45f24LFX9gauCLr/2Cls+sVbeLBhN762qgjrS+djYlbq/XhjmKSZe9YU47d/PDSkv+Pf/vA+nm4+iIO+oZ3pE7My8ImLZ2ND6QJc5Z2elp2DRMnAkzsBf3vj+bjtOi9+8spePPrSnkDT8kHfafzPX/0JD23fjTtv8OLzVyxMqRYBhkmaWTorD5+4eDZ+/9aHQesf2r57yLaXzp+KDaUL8OnL5mLqpNS6NghRKps6KRv3rF2Kr16zGI+/tg8/enEPjp20ZtL+sPsM/uE37+AHz7Wi8nov/urKhSlxWenkLyGN2pVLpg8JE79pudn47OXWBIsXzslPcMmIyC4vJxt3rSrGV65ejP/YsR+PPN+GoyfOAgA6es7in3/3Lh5+rhW3X+fFrR9bhCkTk/eQnbwlozELNyILAD596Rx8b8PylO7kI0pHuROycPt1XnzxqkV48vX9eOT5VnzUbYVK58lzqHrmPVS/0IrbrlmCL1+zOCmvMsqjShpaNjcfuROC21pzJ2TipsvnMUiIklhOdia+es0SPP93q/FPN10cNAjGd6oX36v7ANf863b837oP4Ds19AJzbuKRJQ2tOn/mkGss8EqFRKkjJzsTt161CM9+exWqbr4EC23TFPWc6cODDS24tupZ3P/Me+g0zWJuY5ikocwMGTLZI69USJR6JmRl4JaVC7H9Wzfge+svw5IZgxeSO3G2Dz98rhXXVj2L//O7d3Ck54yLJWWYpC1eqZAofWRlZuDmFfNR/7c34IHPL8fSmVMCj53u7cfWF/fguqpn8Y+/eRsfHncnVBgmREQpIjND8Jnl8/Dff3M9fviFElwwOy/w
2Nm+ATz28l5cf/+z+J+/egv7jwVf7bQ/3JXrYohhQkSUYjIyBJ+8ZA5+f8912HLrClw8b3CY/7n+Afzstf244f5ng55z66M74hooDBMiohSVkSG4cdls/Obua/HYV1YGDbwJjY032n147v0j8StL3F6ZiIgSQkSw+oKZePquq/H4bVdgnidnyDanz/XjnTCTusYKw4SIKE2ICK5bWoh//ItlmBhyTtmkCZm4aG78Zr1gmBARpZnVF8wacunueJ9rxjAhIkozbpxrllRzc4lICYAyANMBeFV1fZhtygGsNHe9AO5QVV/CCklElAISfa5ZstVMtqrq/aq6CUCbiFTZHxQRL4DNqrrJbLMTwGY3CkpERIOSJkxEpAJAvW1VHaxaip0XgMd23wcrUIiIyEXJ1MxVBKDTdr8NVngEqGq9iEBEWgFUAfCo6pZ4FuretUvj+fJERGkhmcLEg+AwiaQSwCZYYXJMRGpVtS3chqa2UwEACxcuHFOhvrnuvDE9j4hoPEmaZi5YTVbTQ9Yds98xHfSVqrpOVacBaAZQHekFVXWLqpaqamlhYWGsy0tEREYyhUkngpu1PLDCwu4WAE/Z7m8KeQ4REbkgmcKkFsEd7mUwwWGGAwNAK4B1tm285nlEROSipOkzUdU2EdkkItUAmsw6f1BsFZE2Vd0iItUishFWs5jHDBEmIiIXJU2YAFYfR4T102y3KxNXIiIiikYyNXMREVGKYpgQEZFjDBMiInKMYUJERI4xTIiIyDGGCREROcYwISIixxgmRETkGMOEiIgcY5gQEZFjDBMiInKMYUJERI4xTIiIyDGGCREROZZUU9ATEVHs3Lt2acLei2FCRJSmvrnuvIS9F5u5iIjIMYYJERE5xjAhIiLHGCZEROQYw4SIiBzjaK40lshhgUQ0vjFM0lgihwUS0fjGZi4iInKMYUJERI4xTIiIyDGGCREROcYwISIixxgmRETkGMOEiIgcY5gQEZFjDBMiInKMYUJERI6JqrpdhoQQkQ4A+xL8tjMAHE3we6Yq7qvR4f6KHvfV6Nj31yJVLYzmSeMmTNwgIo2qWup2OVIB99XocH9Fj/tqdMa6v9jMRUREjjFMiIjIMYZJfG1xuwAphPtqdLi/osd9NTpj2l/sMyEiIsdYMyEiIscYJpRURMRjv22/P95wX0Qvmn3F/RlfDJMYEZFyEamJ8FiJiGwUkapI24wnw+0rAF0ioiKiALoAVCWwaMlmxH3B71ZANN8bfrdCiEiF+f54wjw2qu8W+0xiwPwhGgAcU9V1YR5vUtUV5nYVAKjqpoQWMklEsa+qVbUy4QVLQtHsC363LFHuK363bExA3KeqzREeH9V3izWT2KgAUA3AF/qAiFQAqLetqgNQlphiJaWI+8qItH488g33IL9bQXwx2mZciCJIRv3dYpg4JCIlAML+QYwiAJ22+20AvHEtVJKKYl8BQImINJnmiLpx3q490r7gd2tQNN8bfrcAiEgZgBIAXrMfqmPx3WKYOFemqvXDPO5JVEFSwEj7CgBqTNV6GoACAFvjX6ykNdK+8CS8RMkrmu8Nv1uWdbA+vw/AelghEdon4hntizJMHBCRcgC1I2zmAzA9ZN2xuBQoiUW5r6CqW8y/PgD3wfoFNS5FsS984HcLQHTfG363AjwA6lW13uyLKgxtwvJhlN+trBgVbryqhFVVBKyk94hIq6oW2bbpBLDSdt+DkZt60lE0+yqUD+NzX4Xjw9B9we9WeD6MvB+i2SZdtcL6P+jXhqH9SaP+brFm4oCqrlPVInNA3ASg1n9wNL/EAevXuD31ywA8ldiSui+afSUiG0Oets5sO+4Mty/43QoWzb7idytILYJrZWUwU6g4+W5xaHAMmJEPlbDaHjep6hYR6QKwVlWbzeMrADQB8Kjq/S4W11XD7SuzSRWsX0CdsAKnzZ2SussMVgi7L/jdChbNvjKb8rtlmNBYh5DvjZPvFsOEiIgcYzMXERE5xjAhIiLH
GCZEROQYw4SIiBxjmBARkWMMEyIicoxhQmNiLi5UZq530JQOk+aZz1M+8pbuEBGvuRZM9ViuXZIOfzPzGUJPQKQkwDChsfLCOumpCmkwx5GZSbUawdNuJ5syWPu8AmPb53H9m4lIjYg0xfp17cxcUisZKMmHJy2SIyLiv87BNPMfPeWIiBfWWb5rI13fIZmYKwW2jTCv2XDPj8vfzJQLAIrieXa5qVHtgZlBIV7vQ6PDmgmRNf32tlQIkiRXBGBFvKcpMQG4CUC463CQSxgmNK7ZLhRU7XZZUp2qtiUqkE2NxAdgcyLej0bGKegpbsyvxg2wJosrALATw0ywZ7avwODU1zsTMHFhFYDmcAdBU54yVa0190tgfY42+2cw25XCuq59xIPpaPfHWITsw2OwLgjluB/INAX6J+gErGnLnwr9vCISzQXQRnovD6L7DNsAbBSR+1K1iTWtqCoXLmNeYF0bWmHNKmpfXw7rugllsK6FUAKrX0IBbAzzOv7tvWb7jWbbLrO+FdZBJZZl95j3qI5QFrX+iwR9Tv9SZdaXmTL617cC8A7z+aLaH7bnbbRt12rbL60RtrW/R0WEzxf2bzZMGSpMGcrMfX/Za2z7oM62Hzy251aZdU0hi3/b1rF8Bts+VQAVbv8/4KIMEy7OlnAHJnOwUf/BJ2T71tDHIm0Pqy9DYTVBlYc7SDsse0W4g5E5iJXZAqIOQLl5zB4e1aaM/gAMHDhDXm9U+8NWtibbQdV/YB3uIKyh+8i2D+37e7RhEmkf1djueyOESTWAkpDnem37tmQsnyHkdWL6I4PL2BbXC8AltZcIYdIKoCvC9v5fk622ddXhDm62g31VnMpeFe4gFfI5gg54Ic9rCvOcLpjajIP9URHufc1jnghhEqk8/lCstq2LOkxsB+xwNYPQ/RLuuzDkb4fBmlZoQEX9GUZ6DpfEL+yAp5gy7d1eWG3qQ6jpf8Bg2ztg9TeE43+N1pgUbih/GYa9trUO7QepM/82htm8EQj0MYx1f0Tsx9EwfQNmEAEQ/nP413nDPDYitfpzfAAqRKRVRKrMiZOecOUL8/ygqxmKSDWsmlqt2ob1OvwMY/psFFvsgKdYixQMdj5Y14D3moNVG6wDzAaYy4ca68y222JcxljxRbFuVPvD3PdghIAL4X9eqTmHJNQWDAbgWKyA1dRUAqspCgAgIveHhsVwzJX7KmAF5fqQh+P9GSjOGCYUa9H8Ej4GqynE/2v9PljNPVUi0qaq9eaXajmA9eF+jQOBg9MmDXPynm3kVBGsUWG1odtgsLZQMPxHcmRU+8OMGBttmfzvUR/mIO2Y+TutMGFXBmC9+XejiETat0FsMwz4zPNDOfkMowleihM2c1FM2Zo+PMOcUOaFdS1u+3OKYDUR+eedKlHVIg0zJNQ0tXTBag4acpA2B70aWFOjVAPYbJpXQvmbzyKV07HR7g/b9iXhto/wGv5QLAvzmCP++bxM2dpUdYuqrsNgIIzYxGTK7J9LbL2GHwo96s9g2xe+aJ9D8cMwoXjwN1VVhT5gaxu/I+ShcgAwAbJehz+/xANgLayzoMOpgdXx6z8fZD2sNv/QA5//ADbWNndPlNuNdn/4A3TI9gC2hq4wAdQGK7DCPceJAoSvSfhDL5pzSppg7atNoT8OzKSVnjF+Bv/fLZnnUxs/3B4BwCW1FwyOzPHa1nkwOBKqImR9E4aO4rEPFa2GdRC1L+UR3rsCQ0dOeUPXmfVdCH9+SxeAugivP2Rkllk/3Ogi/zBWJ/vDg8FhtjWwgrYCZogywo/mKgnznBLzbx2ChwYP+ZsN8/cdMqzZlK8u9PNH+C74R3iFHb6L4FFsUX8G+98fYUa9cUn84noBuKTmYv6D+w8eag6WoQeXKgyecFhtlrAHMAyeYxBpGXLAjxAm5QgzDNeUNdzBvzp0ewye42H/bBvNYzUIPkmxDsHnStTYnlPjYH/4m+paba/nNY/53zsoiMxBvtpW9i77e0Tz
N4tQDv8Jia3mdmg4hX3dkL9pk+25TbbtQ0Nx2M8Qsm0dOCw4aRbOGkyuM53OW2E1VwSmxvBPZ2Ie88CaRLDZ9rwKWAdDsa3bCGCzqk4LeY86AD4N6dw1TV+tsNryR+xIpuRgvhtd4N8tabDPhFxlDgpNsOa72qS2kVuq6jMHitFMMz59NO+vVp/KFljzTlHqqID1nWGQJAmGCbnNfx7GiOcQaHQz0rYicsd4pBMHKwF4bZ3hlMTMD5DNCD8wgFzCMCG3+c8iXzHMNuUAop09ONI5BwUY/kz6dbDOc/FE+T7knhoAd0T544IShGFCrtLBCx1VhF5/3VzzvA7WiWzRnmndbJ4bWssowTBn0uvgEOJozlgnl5iwr2bzVvJhBzwlBXPw9weGz/bQfZF+gZrO9ip7B7xZH9TZbjrqV5jmLCKKA4YJpRwz+qsS1nQpHlijwOrUdqKjOeO9AKafZBQ1GyIaA4YJERE5xj4TIiJyjGFCRESOMUyIiMgxhgkRETnGMCEiIscYJkRE5Nj/B0rFYTn9M8EhAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Alignment mu(theta) vs. FCN model size: mean +/- std over the second axis of `mu`.\n",
    "# NOTE(review): assumes fcn_results['mu'] is (num_model_sizes, num_runs) -- confirm against the script that wrote fcn-para.pt.\n",
    "# Named `fcn_results` (not `data`) so the `data` module imported in the first cell is not shadowed.\n",
    "# map_location='cpu' so tensors saved from a GPU run can still be loaded and converted to NumPy by matplotlib.\n",
    "fcn_results = torch.load('fcn-para.pt', map_location='cpu')\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(6, 5))\n",
    "ax.errorbar(fcn_results['model_size'].log10(),\n",
    "            fcn_results['mu'].mean(dim=1),\n",
    "            yerr=fcn_results['mu'].std(dim=1),\n",
    "            linestyle='-', marker='o', linewidth=3, markersize=5)\n",
    "\n",
    "ax.set_xlabel(r'$\\log_{10}$(model size)', fontsize=22)\n",
    "ax.set_ylabel(r'$\\mu(\\theta)$', fontsize=22)\n",
    "ax.set_title('FCN', fontsize=22)\n",
    "ax.tick_params(axis='both', labelsize=13)\n",
    "\n",
    "fig.savefig('../figs/alignment_fcn_modelsize.pdf', bbox_inches='tight')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
