{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcab958f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Folder containing pytorch-lightning checkpoints named as \"epoch=0008-val_acc_seen=0.57.ckpt\" etc.\n",
    "# In the figure, we trained a ResNet-18 on CIFAR-10 from scratch with SGD-momentum (alpha = 0.9) for 400 epochs.\n",
    "ROOT = \"../resnet18-cifar10-sgd\"\n",
    "\n",
    "# Keep only files that follow the checkpoint naming scheme. Stray directory entries\n",
    "# (notebook checkpoints, logs, differently named checkpoints) would otherwise break\n",
    "# the file-name parsing that derives epoch and accuracy from these names.\n",
    "ckpt_fns = sorted(\n",
    "    fn for fn in os.listdir(ROOT)\n",
    "    if fn.startswith(\"epoch=\") and fn.endswith(\".ckpt\")\n",
    ")\n",
    "ckpts = [os.path.join(ROOT, c) for c in ckpt_fns]\n",
    "print(ckpts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1dd0d667",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "# Checkpoint names look like 'epoch=0008-val_acc_seen=0.57.ckpt'. Match the two numbers\n",
    "# by pattern instead of by fixed character offsets, so names whose epoch or accuracy\n",
    "# has a different number of digits are still parsed correctly.\n",
    "ckpt_name_pattern = re.compile(\"epoch=([0-9]+)-val_acc_seen=([0-9]+(?:[.][0-9]+)?)[.]ckpt\")\n",
    "\n",
    "eps, acc = [], []\n",
    "for c in ckpt_fns:\n",
    "    parsed_name = ckpt_name_pattern.fullmatch(c)\n",
    "    if parsed_name is None:\n",
    "        raise ValueError(\"Unexpected checkpoint file name: %r\" % c)\n",
    "    eps.append(int(parsed_name.group(1)))\n",
    "    acc.append(float(parsed_name.group(2)))\n",
    "\n",
    "print(list(zip(eps, ckpts)))\n",
    "print(acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5cbf319",
   "metadata": {},
   "outputs": [],
   "source": [
    "from data_uncertainty import MNIST_UncertaintyDM, CIFAR10_UncertaintyDM, SVHN_UncertaintyDM, ImageNet_Validation_UncertaintyDM\n",
    "\n",
    "# Input size per dataset name. The first entry is used as the channel count when the\n",
    "# network is built; the other two are presumably height and width - TODO confirm.\n",
    "input_size_dict = dict(\n",
    "    mnist = [1, 32, 32], # Resized\n",
    "    cifar10 = [3, 32, 32],\n",
    "    cifar100 = [3, 32, 32],\n",
    "    svhn = [3, 32, 32],\n",
    "    fashionmnist = [1, 32, 32], # Resized\n",
    "    imagenet = [3, 224, 224],\n",
    "    tinyimagenet = [3, 64, 64],\n",
    "    stl10 = [3, 96, 96],\n",
    "    lsun = [3, 256, 256],\n",
    "    celeba = [3, 64, 64],\n",
    "    cub200 = [3, 224, 224],\n",
    ")\n",
    "\n",
    "def get_data_module(\n",
    "    dataset_name,\n",
    "    batch_size,\n",
    "    data_augmentation=True,\n",
    "    num_workers=16,\n",
    "    data_dir='./data',\n",
    "    do_partial_train = False,\n",
    "    do_contamination = True,\n",
    "    use_full_trainset = True,\n",
    "    test_set_max = -1,\n",
    "    is_binary = 0,\n",
    "    noise_std = 0.3,\n",
    "    blur_sigma = 2.0):\n",
    "    \"\"\"Build the uncertainty data module for `dataset_name`.\n",
    "\n",
    "    Args:\n",
    "        dataset_name: One of 'mnist', 'cifar10', 'svhn' or 'imagenet'.\n",
    "        batch_size, num_workers, data_dir, do_partial_train, do_contamination,\n",
    "        use_full_trainset, test_set_max, is_binary, noise_std, blur_sigma:\n",
    "            Forwarded unchanged as keyword arguments to the data module constructor.\n",
    "        data_augmentation: Accepted for call compatibility; it is not forwarded\n",
    "            to the data module.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(data_module, input_size)` where `input_size` is the entry of\n",
    "        `input_size_dict` for `dataset_name`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If no data module is available for `dataset_name`.\n",
    "    \"\"\"\n",
    "\n",
    "    args = {\n",
    "        \"data_dir\": data_dir,\n",
    "        \"batch_size\": batch_size,\n",
    "        \"num_workers\": num_workers,\n",
    "        \"do_partial_train\": do_partial_train,\n",
    "        \"do_contamination\": do_contamination,\n",
    "        \"use_full_trainset\": use_full_trainset,\n",
    "        \"test_set_max\": test_set_max,\n",
    "        \"is_binary\": is_binary,\n",
    "        \"noise_std\": noise_std,\n",
    "        \"blur_sigma\": blur_sigma,\n",
    "    }\n",
    "\n",
    "    if dataset_name == 'mnist':\n",
    "        main_dm = MNIST_UncertaintyDM(**args)\n",
    "\n",
    "    elif dataset_name == 'cifar10':\n",
    "        main_dm = CIFAR10_UncertaintyDM(**args)\n",
    "\n",
    "    elif dataset_name == 'svhn':\n",
    "        main_dm = SVHN_UncertaintyDM(**args)\n",
    "\n",
    "    elif dataset_name == 'imagenet':\n",
    "        main_dm = ImageNet_Validation_UncertaintyDM(**args)\n",
    "\n",
    "    else:\n",
    "        # Without this branch an unsupported name fell through the chain and\n",
    "        # surfaced as an UnboundLocalError on `main_dm` at the return statement.\n",
    "        raise ValueError(\n",
    "            \"No data module available for dataset %r \"\n",
    "            \"(supported: 'mnist', 'cifar10', 'svhn', 'imagenet')\" % (dataset_name,)\n",
    "        )\n",
    "\n",
    "    return main_dm, input_size_dict[dataset_name]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ea9ea65",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "import torch\n",
    "from pytorch_lightning import LightningModule, Trainer\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from nets import net_dict\n",
    "\n",
    "# The star imports stay ahead of the functorch import so that the explicitly\n",
    "# imported functorch names win over any same-named symbols they export.\n",
    "from param_inject import *\n",
    "from paramsutils import * # <-- under ntk_utils/; you may want to make a copy\n",
    "from functorch import make_functional, make_functional_with_buffers, jvp, grad, vmap, jacrev\n",
    "\n",
    "class QuantityBase():\n",
    "    \"\"\"Baseline quantity: mean entropy of the softmax prediction of `head(net(x))`.\"\"\"\n",
    "\n",
    "    def __init__(self, args):\n",
    "        self.q_eval_ref = 1.0\n",
    "        self.args = args\n",
    "\n",
    "    def preprocess(self, net, head):\n",
    "        \"\"\"Hook called before evaluation; returns both modules unchanged.\"\"\"\n",
    "        return net, head\n",
    "\n",
    "    def evaluate(self, net, head, batch):\n",
    "        \"\"\"Return the per-sample prediction entropy for one batch, detached and on the CPU.\n",
    "\n",
    "        `batch` must unpack into four items; only the second one (the inputs) is used.\n",
    "        \"\"\"\n",
    "        _, inputs, _, _ = batch\n",
    "\n",
    "        # logits / probabilities: [batch_size, num_classes]\n",
    "        probabilities = F.softmax(head(net(inputs)), dim = 1)\n",
    "        # The 1e-8 offset keeps log() finite when a probability is exactly zero.\n",
    "        log_probabilities = torch.log(probabilities + 1e-8)\n",
    "        # per_sample_entropy: [batch_size]\n",
    "        per_sample_entropy = -(probabilities * log_probabilities).sum(dim = 1)\n",
    "\n",
    "        return per_sample_entropy.detach().cpu()\n",
    "\n",
    "    def summary(self, is_ref, results):\n",
    "        \"\"\"Average the concatenated per-batch results.\n",
    "\n",
    "        When `is_ref` is true the mean is also stored as the reference value that\n",
    "        normalises this call and later ones.\n",
    "        \"\"\"\n",
    "        mean_value = torch.cat(results).mean()\n",
    "        if is_ref:\n",
    "            self.q_eval_ref = mean_value\n",
    "\n",
    "        normalized = mean_value / self.q_eval_ref\n",
    "        return {\"Q\": mean_value, \"Q (Normalized)\": normalized}\n",
    "\n",
    "class QComputeGrad(QuantityBase):\n",
    "    \"\"\"Per-parameter gradient statistic of one network output.\n",
    "\n",
    "    For every parameter tensor, accumulates the batch mean of the per-sample\n",
    "    gradient norm divided by sqrt(number of elements); `summary` turns the\n",
    "    running totals into an average over the evaluated batches.\n",
    "    \"\"\"\n",
    "\n",
    "    def fnet_single(self, y_index = 0):\n",
    "        \"\"\"Return a function `(params, x)` giving one output entry for a single sample.\n",
    "\n",
    "        A negative `y_index` selects the largest entry of the network output.\n",
    "        \"\"\"\n",
    "\n",
    "        # Torch 2.0 / CUDA 11.7\n",
    "        # def foo(params):\n",
    "        #     return torch.func.functional_call(self, params, (x,))[0, y_index]\n",
    "\n",
    "        # Torch 1.13 / FuncTorch\n",
    "        def foo(params, x):\n",
    "\n",
    "            # Add a batch axis for the functional network call, then drop it again.\n",
    "            result = self.fnet(params, self.fbuffer, x.unsqueeze(0))[0]\n",
    "\n",
    "            resolved_y_index = y_index\n",
    "            if resolved_y_index < 0:\n",
    "                resolved_y_index = torch.argmax(result)\n",
    "\n",
    "            return result[resolved_y_index]\n",
    "\n",
    "        return foo\n",
    "\n",
    "    def preprocess(self, net, head):\n",
    "        \"\"\"Wrap `net` and `head` into one functional model and reset the accumulators.\n",
    "\n",
    "        Returns `(net, head)` unchanged.\n",
    "        \"\"\"\n",
    "\n",
    "        self.combined_net = nn.Sequential(net, head)\n",
    "        self.combined_net.eval()\n",
    "        self.y_index = 0\n",
    "\n",
    "        # Torch 1.13 / FuncTorch\n",
    "        funcresult = make_functional_with_buffers(self.combined_net)\n",
    "        self.fnet, _, self.fbuffer = funcresult\n",
    "        self.fparams = dict(self.combined_net.named_parameters())\n",
    "\n",
    "        # Torch 2.0 / CUDA 11.7\n",
    "        # self.fparams = dict(self.combined_net.named_parameters())\n",
    "\n",
    "        # Follow the device the model actually lives on instead of assuming CUDA.\n",
    "        first_param = next(self.combined_net.parameters(), None)\n",
    "        self.device = first_param.device if first_param is not None else torch.device('cpu')\n",
    "\n",
    "        self.final_result = {k: 0 for k in self.fparams}\n",
    "        self.param_count = {}\n",
    "        self.count = 0\n",
    "\n",
    "        return net, head\n",
    "\n",
    "    def evaluate(self, net, head, batch):\n",
    "        \"\"\"Add this batch's per-parameter gradient statistic to the running totals.\n",
    "\n",
    "        `batch` is a sequence whose second element holds the inputs; any other\n",
    "        elements are ignored, so batches of different lengths are accepted.\n",
    "        `net` and `head` are unused here: the model captured by `preprocess` is\n",
    "        evaluated instead. Returns None.\n",
    "        \"\"\"\n",
    "\n",
    "        x = batch[1]\n",
    "\n",
    "        # jacrev differentiates the selected output w.r.t. the parameters for one\n",
    "        # sample; vmap maps that over the first axis of `x` with shared parameters.\n",
    "        grad_result = vmap(jacrev(self.fnet_single(self.y_index)), (None, 0))(to_unnamed(self.fparams), x)\n",
    "        grad_result = unnamed_tuple_to_named(self.fparams, grad_result)\n",
    "\n",
    "        for k in grad_result:\n",
    "            flat_grad = grad_result[k].flatten(1)\n",
    "            n_elements = torch.numel(flat_grad[0])\n",
    "            norms = flat_grad.norm(dim = -1) / math.sqrt(n_elements)\n",
    "            self.final_result[k] += (norms.mean()).detach().cpu().item()\n",
    "            self.param_count[k] = n_elements\n",
    "        self.count += 1\n",
    "\n",
    "    def summary(self, is_ref, results):\n",
    "        \"\"\"Divide the running totals by the number of evaluated batches, in place.\n",
    "\n",
    "        `is_ref` and `results` are unused. Returns `self.final_result`.\n",
    "\n",
    "        Raises:\n",
    "            ZeroDivisionError: If no batch has been evaluated since `preprocess`.\n",
    "        \"\"\"\n",
    "        if self.count == 0:\n",
    "            raise ZeroDivisionError(\"QComputeGrad.summary() called before any batch was evaluated\")\n",
    "\n",
    "        for k in self.final_result:\n",
    "            self.final_result[k] /= self.count\n",
    "\n",
    "        return self.final_result\n",
    "\n",
    "\n",
    "class QComputeParam(QuantityBase):\n",
    "    \"\"\"Summarises each parameter tensor of `net` + `head` as norm / sqrt(number of elements).\"\"\"\n",
    "\n",
    "    def preprocess(self, net, head):\n",
    "        \"\"\"Record the current parameters of the combined model; returns `(net, head)` unchanged.\"\"\"\n",
    "        self.combined_net = nn.Sequential(net, head)\n",
    "        self.fparams = dict(self.combined_net.named_parameters())\n",
    "        self.paramResults = ParamsStatsticsRecorder()\n",
    "        # NOTE(review): the same parameters are recorded twice. Presumably the recorder\n",
    "        # needs two samples for its statistics - confirm before reducing this to one call.\n",
    "        for _ in range(2):\n",
    "            self.paramResults.record(self.fparams)\n",
    "        return net, head\n",
    "\n",
    "    def summary(self, is_ref, results):\n",
    "        \"\"\"Fill `final_result` and `param_count` from the recorder's mean; both arguments are ignored.\"\"\"\n",
    "        means = self.paramResults.mean()\n",
    "        self.final_result = {}\n",
    "        self.param_count = {}\n",
    "        for name in means:\n",
    "            mean_tensor = means[name]\n",
    "            n_elements = torch.numel(mean_tensor)\n",
    "            scaled_norm = mean_tensor.norm() / math.sqrt(n_elements)\n",
    "            self.final_result[name] = scaled_norm.detach().cpu().item()\n",
    "            self.param_count[name] = n_elements"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8b148ce3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def filter_rename_ckpt(state_dict, prefix):\n",
    "    \"\"\"Select the entries of `state_dict` whose key starts with `prefix` plus a dot, stripping that prefix.\n",
    "\n",
    "    For prefix 'net.0' the key 'net.0.conv1.weight' becomes 'conv1.weight';\n",
    "    entries under any other prefix are left out of the returned dict.\n",
    "    \"\"\"\n",
    "    dotted_prefix = \"%s.\" % prefix\n",
    "    return {\n",
    "        key[len(dotted_prefix):]: state_dict[key]\n",
    "        for key in state_dict\n",
    "        if key.startswith(dotted_prefix)\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09ead5c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Network\n",
    "import resnet\n",
    "\n",
    "class ResNetCIFARFactory():\n",
    "    \"\"\"Builds a ResNet-18 (CIFAR-style first conv, no max-pool) split into a body and its `fc` head.\"\"\"\n",
    "\n",
    "    def __init__(self, factory_args = None):\n",
    "        # This class has no explicit base class, so forwarding `factory_args` to\n",
    "        # the parent initialiser (object.__init__) raised a TypeError on every\n",
    "        # instantiation; the parent takes no extra arguments.\n",
    "        super().__init__()\n",
    "        self.args = factory_args\n",
    "\n",
    "    def getNets(self, input_shape, output_shape):\n",
    "        \"\"\"Create the network and return it as `(net, head)`.\n",
    "\n",
    "        Args:\n",
    "            input_shape: Sequence whose first entry is the number of input channels.\n",
    "            output_shape: Sequence whose first entry is the number of classes.\n",
    "\n",
    "        Returns:\n",
    "            `(net, head)`: the ResNet with its `fc` layer replaced by `nn.Identity()`,\n",
    "            and the `fc` layer that was removed from it.\n",
    "        \"\"\"\n",
    "\n",
    "        whole_net = resnet.resnet18(\n",
    "            pretrained = False,\n",
    "            conv1_type = 'cifar',\n",
    "            no_maxpool = True,\n",
    "            num_classes = output_shape[0],\n",
    "            input_channels = input_shape[0],\n",
    "            act = nn.ReLU(inplace = False)\n",
    "        )\n",
    "\n",
    "        # Detach the classifier so features and logits can be handled separately.\n",
    "        head = whole_net.fc\n",
    "        whole_net.fc = nn.Identity()\n",
    "\n",
    "        return whole_net, head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e12b753a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "reference_data_count = 8192\n",
    "test_set_max = 1024\n",
    "net_name = 'resnet-cifar'\n",
    "\n",
    "##############################\n",
    "# Prepare dataloader\n",
    "##############################\n",
    "\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "main_datamodule, input_dim = get_data_module(\n",
    "    \"cifar10\",\n",
    "    128,\n",
    "    data_augmentation = True,\n",
    "    num_workers=0,\n",
    "    do_partial_train = False,\n",
    "    do_contamination = False,\n",
    "    test_set_max = test_set_max,\n",
    "    is_binary = False,\n",
    "    noise_std = 0.0,\n",
    "    blur_sigma = 0.0)\n",
    "\n",
    "main_datamodule.setup()\n",
    "# Q = QComputeParam(None)\n",
    "Q = QComputeGrad(None)\n",
    "\n",
    "# Draw the reference subset once. np.random.choice samples with replacement by\n",
    "# default, so reference_data_count may exceed the dataset size. Assign `indices`\n",
    "# here instead of None to evaluate a fixed, known subset.\n",
    "indices = None\n",
    "if indices is None:\n",
    "    indices = np.random.choice(len(main_datamodule.test_dataset), size = (reference_data_count,))\n",
    "\n",
    "# Use the indices drawn above. A second, independent sample used to be drawn at\n",
    "# this point, so `indices` did not describe the data that was actually evaluated.\n",
    "calibration_data = torch.utils.data.Subset(\n",
    "    main_datamodule.test_dataset,\n",
    "    indices\n",
    ")\n",
    "\n",
    "calibration_dl = torch.utils.data.DataLoader(\n",
    "    calibration_data,\n",
    "    batch_size = 64,\n",
    "    shuffle = False,\n",
    "    num_workers = 0\n",
    ")\n",
    "\n",
    "##############################\n",
    "# Get Net\n",
    "##############################\n",
    "\n",
    "output_dim = main_datamodule.n_classes\n",
    "\n",
    "# The device does not change between checkpoints, so pick it once.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "results = []\n",
    "\n",
    "for i, path in enumerate(ckpts):\n",
    "\n",
    "    print(\"Ep #%04d\" % eps[i])\n",
    "\n",
    "    net_factory = ResNetCIFARFactory()\n",
    "    net, head = net_factory.getNets(\n",
    "        input_dim,\n",
    "        [output_dim],\n",
    "    )\n",
    "\n",
    "    # map_location lets checkpoints saved on a GPU load on a CPU-only machine;\n",
    "    # the modules are moved to `device` right after the weights are loaded.\n",
    "    loaded = torch.load(path, map_location = \"cpu\")[\"state_dict\"]\n",
    "    net.load_state_dict(filter_rename_ckpt(loaded, \"net.0\"))\n",
    "    head.load_state_dict(filter_rename_ckpt(loaded, \"net.1\"))\n",
    "\n",
    "    net = net.to(device)\n",
    "    head = head.to(device)\n",
    "\n",
    "    net, head = Q.preprocess(net, head)\n",
    "\n",
    "    ##############################\n",
    "    # Feed data\n",
    "    ##############################\n",
    "\n",
    "    q_evals = []\n",
    "\n",
    "    for batch in tqdm(calibration_dl):\n",
    "        batch = [b.to(device) for b in batch]\n",
    "        q_eval = Q.evaluate(net, head, batch)\n",
    "        q_evals.append(q_eval)\n",
    "\n",
    "    Q.summary(False, q_evals)\n",
    "\n",
    "    # preprocess() creates a fresh final_result dict per checkpoint, so the\n",
    "    # stored reference is not overwritten by the next iteration.\n",
    "    results.append(Q.final_result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1ae70dc-67ad-4cf0-9181-742715c92515",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "results_paramStats = []\n",
    "Q2 = QComputeParam(None)\n",
    "\n",
    "# The device does not change between checkpoints, so pick it once.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "for i, path in enumerate(ckpts):\n",
    "\n",
    "    print(\"Ep #%04d\" % eps[i])\n",
    "\n",
    "    net_factory = ResNetCIFARFactory()\n",
    "    net, head = net_factory.getNets(\n",
    "        input_dim,\n",
    "        [output_dim],\n",
    "    )\n",
    "\n",
    "    # map_location lets checkpoints saved on a GPU load on a CPU-only machine;\n",
    "    # the modules are moved to `device` right after the weights are loaded.\n",
    "    loaded = torch.load(path, map_location = \"cpu\")[\"state_dict\"]\n",
    "    net.load_state_dict(filter_rename_ckpt(loaded, \"net.0\"))\n",
    "    head.load_state_dict(filter_rename_ckpt(loaded, \"net.1\"))\n",
    "\n",
    "    net = net.to(device)\n",
    "    head = head.to(device)\n",
    "\n",
    "    # Parameter statistics need no data pass: record the weights, then summarise.\n",
    "    net, head = Q2.preprocess(net, head)\n",
    "    Q2.summary(False, None)\n",
    "    results_paramStats.append(Q2.final_result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b17459b9-ecde-4c17-be1d-36cb8a271391",
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "\n",
    "outpath = \"new-params.csv\"\n",
    "\n",
    "# Combine all results: one list per layer with one value per checkpoint.\n",
    "all_results_paramStats = {k: [r[k] for r in results_paramStats] for k in results_paramStats[0].keys()}\n",
    "# Conv entries only; the very first key is skipped before filtering.\n",
    "all_convs_paramStats = {k: all_results_paramStats[k] for k in list(all_results_paramStats.keys())[1:] if 'conv' in k}\n",
    "\n",
    "# The csv module expects files opened with newline=''; otherwise platforms that\n",
    "# translate line endings write an extra blank line after every row.\n",
    "with open(outpath, 'w', newline = '') as csvfile:\n",
    "    wtr = csv.writer(csvfile, delimiter = ',')\n",
    "    wtr.writerow([\"Layer name\", *[\"%d\" % e for e in eps]])\n",
    "    for k in all_results_paramStats.keys():\n",
    "        wtr.writerow([k, *all_results_paramStats[k]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a3f6526f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "\n",
    "# outpath = \"params.csv\"\n",
    "outpath = \"new-grads_train_sgd.csv\"\n",
    "\n",
    "# Combine all results: one list per layer with one value per checkpoint.\n",
    "all_results = {k: [r[k] for r in results] for k in results[0].keys()}\n",
    "\n",
    "# The csv module expects files opened with newline=''; otherwise platforms that\n",
    "# translate line endings write an extra blank line after every row.\n",
    "with open(outpath, 'w', newline = '') as csvfile:\n",
    "    wtr = csv.writer(csvfile, delimiter = ',')\n",
    "    wtr.writerow([\"Layer name\", *[\"%d\" % e for e in eps]])\n",
    "    for k in all_results.keys():\n",
    "        wtr.writerow([k, *all_results[k]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "febea8ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "import matplotlib\n",
    "\n",
    "# Conv entries only; the very first entry of all_results is skipped before filtering.\n",
    "all_convs = {k: v for k, v in list(all_results.items())[1:] if 'conv' in k}\n",
    "fig, ax = plt.subplots(figsize = (10, 7))\n",
    "\n",
    "cmap = matplotlib.colormaps['hsv']\n",
    "\n",
    "# One scatter series per checkpoint: x is the layer index shifted by ep / 600,\n",
    "# y is that layer's value at the checkpoint, colour is ep / 400 on the hsv map.\n",
    "for i, ep in enumerate(eps):\n",
    "    ax.scatter(\n",
    "        [layer_ix + (ep / 600.0) for layer_ix in range(len(all_convs))],\n",
    "        [per_epoch_values[i] for per_epoch_values in all_convs.values()],\n",
    "        color = cmap(ep / 400.0),\n",
    "        marker = 'x',\n",
    "        alpha = 1.0\n",
    "    )\n",
    "\n",
    "ax.set_title(\"Layer-wise gradient mean abs. on training set ResNet-18 CIFAR-10 SGD over 400 epochs\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b385cc45-c9ef-4c2c-be94-cc49fb4aaa7d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Show the element count per parameter tensor, then the conv-layer parameter statistics.\n",
    "for printed_summary in (Q.param_count, all_convs_paramStats):\n",
    "    print(printed_summary)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "832393c8-8068-4330-acf7-a341946f5f59",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.colors as mcolors\n",
    "\n",
    "def create_custom_colormap(color1, color2):\n",
    "    \"\"\"\n",
    "    Build a colormap that blends linearly between two colors.\n",
    "    \n",
    "    Parameters:\n",
    "    color1 (str): Color at the start of the colormap.\n",
    "    color2 (str): Color at the end of the colormap.\n",
    "    \n",
    "    Returns:\n",
    "    matplotlib.colors.LinearSegmentedColormap: Colormap named \"custom_cmap\"\n",
    "    running from color1 to color2.\n",
    "    \"\"\"\n",
    "    return mcolors.LinearSegmentedColormap.from_list(\"custom_cmap\", [color1, color2])\n",
    "\n",
    "# Colormap shared by the figures below.\n",
    "im_cmap = create_custom_colormap('#202ff7', '#21d95b')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1be14add-910a-4afb-873f-0aa2a9635a5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "\n",
    "# Indices into eps of the two checkpoints marked with vertical lines.\n",
    "epix = 12\n",
    "late_epix = 30\n",
    "\n",
    "sns.set_style(\"white\")\n",
    "\n",
    "plt.style.use('style.mplstyle')\n",
    "# plt.tight_layout()\n",
    "\n",
    "fig, ax = plt.subplots(figsize = (1.7, 1.2))\n",
    "# fig, ax = plt.subplots()\n",
    "ax2 = ax.twinx()\n",
    "ax.grid(linestyle = '--', linewidth = 0.5, alpha = 1.0)\n",
    "ax2.grid(False)\n",
    "cmap = im_cmap\n",
    "# Per-layer colour value: the last checkpoint's parameter statistic divided by\n",
    "# sqrt(number of parameters), rescaled so that the largest value is 1.\n",
    "colors = np.asarray([(all_convs_paramStats[k][-1] ** 2) / Q.param_count[k] for k in all_convs])\n",
    "colors = np.sqrt(colors)\n",
    "colors = colors / colors.max()\n",
    "\n",
    "print(colors)\n",
    "\n",
    "for i, k in enumerate(all_convs):\n",
    "    ax.plot(eps, all_convs[k], '-', color = cmap(colors[i]), linewidth = 1.0)\n",
    "\n",
    "# axvline always spans the full height of the axes and leaves the y-limits alone,\n",
    "# so the markers no longer rely on the curves staying inside a fixed [-1, 1] range.\n",
    "ax.axvline(eps[epix], color = 'k', linestyle = 'dashed', linewidth = 0.5)\n",
    "ax.axvline(eps[late_epix], color = 'm', linestyle = 'dashed', linewidth = 0.5)\n",
    "\n",
    "ax.set_ylabel('$\\|\\nabla_{\\theta_l} f_t(x)\\| \\cdot |\\theta_l|^{-\\frac{1}{2}}$')\n",
    "ax.set_xlabel('$t$ (Epochs)')\n",
    "ax.text(-.07,.9,'\\textbf{a)}',\n",
    "        horizontalalignment='center',\n",
    "        transform=ax.transAxes)\n",
    "ax2.plot(eps, acc, '--', color = 'red', linewidth = 1.0)\n",
    "\n",
    "ax2.set_ylabel('Validation accuracy')\n",
    "\n",
    "# print(all_convs)\n",
    "# ax.set_title(\"Layer-wise gradient mean abs. on training set ResNet-18 CIFAR-10 SGD over 400 epochs\")\n",
    "plt.savefig(\"layer-grad-evolution.pdf\", format=\"pdf\", bbox_inches = 'tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f0e93b0-7ee0-409e-9bc5-4a60dbc237ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.optimize import curve_fit\n",
    "\n",
    "# Indices into eps of the checkpoints shown in the main panel and in the inset.\n",
    "epix = 12\n",
    "late_epix = 30\n",
    "print(eps[epix])\n",
    "\n",
    "plt.style.use('style.mplstyle')\n",
    "\n",
    "\n",
    "def _line_through_origin(x, slope):\n",
    "    \"\"\"Fit model y = slope * x.\n",
    "\n",
    "    NOTE(review): the model previously passed to curve_fit was named `f` and is not\n",
    "    defined anywhere in this notebook. Only the first fitted parameter is used, as\n",
    "    the slope of a line with zero intercept, so this model is assumed - confirm.\n",
    "    \"\"\"\n",
    "    return slope * x\n",
    "\n",
    "\n",
    "def _plot_grad_ratio_fit(target_ax, checkpoint_ix, marker_size, line_format):\n",
    "    \"\"\"Scatter the per-layer ratio (value at `checkpoint_ix`) / (value at the last checkpoint)\n",
    "    against 100 / sqrt(number of parameters) on `target_ax` and overlay the fitted line.\"\"\"\n",
    "    grads = [(all_convs[k][checkpoint_ix] / all_convs[k][-1]) for k in all_convs]\n",
    "    xxx = np.sqrt(np.asarray([1 / Q.param_count[k] for k in all_convs])) * 100\n",
    "    yyy = np.array(grads)\n",
    "    popt, _ = curve_fit(_line_through_origin, xxx, yyy)\n",
    "    fitted = np.poly1d((popt[0], 0))\n",
    "\n",
    "    target_ax.scatter(xxx, grads, s = marker_size, color = [cmap(colors[i]) for i in range(len(grads))])\n",
    "    target_ax.plot(xxx, fitted(xxx), line_format, linewidth = 0.5)\n",
    "\n",
    "\n",
    "fig, ax = plt.subplots(figsize = (1.2, 1.2))\n",
    "ax.grid(linestyle = '--', linewidth = 0.5, alpha = 1.0)\n",
    "_plot_grad_ratio_fit(ax, epix, 10, '--k')\n",
    "\n",
    "# Remember the autoscaled limits, set the ticks, then restore the limits so the\n",
    "# tick range does not stretch the axes.\n",
    "xm, xM = ax.get_xlim()\n",
    "ym, yM = ax.get_ylim()\n",
    "ax.set_yticks(np.arange(0, 1, 0.1))\n",
    "ax.set_xticks(np.arange(0, 1, 0.1))\n",
    "ax.set_xlim(xm, xM)\n",
    "ax.set_ylim(ym, yM)\n",
    "\n",
    "ax.set_ylabel('$\\left. \\|\\nabla_{\\theta_l} f_t(x)\\| \\middle/ \\|\\nabla_{\\theta_l} f_T(x)\\| \\right.$')\n",
    "ax.set_xlabel('$10^2 \\cdot |\\theta_l|^{-1/2}$')\n",
    "ax.yaxis.set_label_coords(-0.3,0.34)\n",
    "ax.text(.1,.87,'\\textbf{b)}',\n",
    "        horizontalalignment='center',\n",
    "        transform=ax.transAxes)\n",
    "# ax.set_xlim(0, 1.1)\n",
    "# ax.set_ylim(0, 1.1)\n",
    "# ax.set_title(\"grad scaling for epoch %d vs. termination time\" % (eps[epix]))\n",
    "\n",
    "iax = ax.inset_axes([0.55, 0.05, 0.4, 0.4])\n",
    "iax.set_xticks([])\n",
    "iax.set_yticks([])\n",
    "\n",
    "print(eps[late_epix])\n",
    "_plot_grad_ratio_fit(iax, late_epix, 4, '--m')\n",
    "# The inset label is positioned in the parent axes' coordinates (transform=ax.transAxes).\n",
    "iax.text(.88,.1,'\\textbf{c)}',\n",
    "        horizontalalignment='center',\n",
    "        transform=ax.transAxes)\n",
    "\n",
    "plt.savefig(\"layer-grad-earlyEp.pdf\", format=\"pdf\", bbox_inches = 'tight')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
