{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd36b163",
   "metadata": {},
   "outputs": [],
   "source": [
    "## imports\n",
    "import copy\n",
    "import glob\n",
    "import os\n",
    "import time\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from scipy.stats import spearmanr\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.nn import Module\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision\n",
    "\n",
    "import brevitas.nn as qnn\n",
    "from brevitas.quant import Int8Bias as BiasQuant\n",
    "from brevitas import config as bconfig\n",
    "\n",
    "from training_utils import *\n",
    "from FIT_utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1871d35a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick the training device: CUDA when torch can see a GPU, otherwise the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "print(f\"Using {device} device\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8a59f96",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Define the experiment:\n",
    "experiment_config = {\n",
    "    'dataset': 'cifar10', ## oneof cifar10/mnist\n",
    "    'batch_norm': True,\n",
    "    # NOTE(review): 0.1 is an unusually high learning rate for Adam (the optimiser used below) - confirm it is intended.\n",
    "    'lr': 0.1,\n",
    "    'epochs_fp': 50,\n",
    "    'epochs_qp': 30,\n",
    "    'data_path': 'path/to/data',\n",
    "    'n_models': 100,\n",
    "    # Used as a filename *prefix* (e.g. f\"{save_path}FPNN.pt\"); normalised below to end with a separator.\n",
    "    'save_path': 'path/to/save',\n",
    "    'seed': 0,  # seeds numpy (config sampling) and torch (init, shuffling, dropout)\n",
    "}\n",
    "\n",
    "# Seed the global RNGs so the sampled configurations and the training runs are reproducible.\n",
    "np.random.seed(experiment_config['seed'])\n",
    "torch.manual_seed(experiment_config['seed'])\n",
    "\n",
    "# Without a trailing separator, 'path/to/save' + 'FPNN.pt' would become 'path/to/saveFPNN.pt'.\n",
    "experiment_config['save_path'] = os.path.join(experiment_config['save_path'], '')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2229594d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LowPrecisionNet(Module):\n",
    "    \"\"\"Brevitas quantization-aware counterpart of FullPrecisionNet.\n",
    "\n",
    "    Args:\n",
    "        config: per-layer bit-widths, indexed as\n",
    "            [0] conv1 weights, [1] relu1 activations,\n",
    "            [2] conv2 weights, [3] relu2 activations,\n",
    "            [4] conv3 weights, [5] relu3 activations,\n",
    "            [6] fc1 weights.\n",
    "        experiment_config: dict; 'dataset' == 'cifar10' selects 3 input channels and\n",
    "            doubled channel widths (otherwise 1 channel), 'batch_norm' toggles\n",
    "            BatchNorm2d vs. Identity after each conv.\n",
    "\n",
    "    Sub-module names mirror FullPrecisionNet so a float state_dict can be loaded;\n",
    "    the input is quantized to 16 bits and conv/linear biases use Int8Bias.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config, experiment_config):\n",
    "        super().__init__()\n",
    "        self.config = config\n",
    "\n",
    "        is_cifar = experiment_config['dataset'] == 'cifar10'\n",
    "        input_dim = 3 if is_cifar else 1\n",
    "        scale = 2 if is_cifar else 1\n",
    "        width1, width2, width3 = 4 * scale, 8 * scale, 16 * scale\n",
    "        w1_bits, a1_bits, w2_bits, a2_bits, w3_bits, a3_bits, fc_bits = (self.config[k] for k in range(7))\n",
    "\n",
    "        # Keep this creation order: weight init draws from the global RNG and the\n",
    "        # registration order defines the parameter order.\n",
    "        self.quant_inp = qnn.QuantIdentity(bit_width=16, return_quant_tensor=True)\n",
    "        self.conv1 = qnn.QuantConv2d(input_dim, width1, 5, weight_bit_width=w1_bits,\n",
    "                                     bias_quant=BiasQuant, return_quant_tensor=True, padding=0)\n",
    "        self.relu1 = qnn.QuantReLU(bit_width=a1_bits, return_quant_tensor=True)\n",
    "\n",
    "        self.conv2 = qnn.QuantConv2d(width1, width2, 5, weight_bit_width=w2_bits,\n",
    "                                     bias_quant=BiasQuant, return_quant_tensor=True)\n",
    "        self.relu2 = qnn.QuantReLU(bit_width=a2_bits, return_quant_tensor=True)\n",
    "\n",
    "        self.conv3 = qnn.QuantConv2d(width2, width3, 5, weight_bit_width=w3_bits,\n",
    "                                     bias_quant=BiasQuant, return_quant_tensor=True)\n",
    "        self.relu3 = qnn.QuantReLU(bit_width=a3_bits, return_quant_tensor=True)\n",
    "\n",
    "        self.fc1 = qnn.QuantLinear(width3, 10, bias=True, weight_bit_width=fc_bits, bias_quant=BiasQuant)\n",
    "\n",
    "        # NOTE(review): dropout1 is registered but never applied in forward(), whereas\n",
    "        # FullPrecisionNet applies dropout after its second pooling stage - confirm intent.\n",
    "        self.dropout1 = nn.Dropout(0.2)\n",
    "\n",
    "        if experiment_config['batch_norm']:\n",
    "            self.bn1, self.bn2, self.bn3 = (nn.BatchNorm2d(c) for c in (width1, width2, width3))\n",
    "        else:\n",
    "            self.bn1, self.bn2, self.bn3 = nn.Identity(), nn.Identity(), nn.Identity()\n",
    "\n",
    "    def forward(self, x):\n",
    "        # 16-bit input quantization, then two conv -> BN -> quant-ReLU -> 2x2 max-pool stages\n",
    "        x = self.quant_inp(x)\n",
    "        x = F.max_pool2d(self.relu1(self.bn1(self.conv1(x))), 2, 2)\n",
    "        x = F.max_pool2d(self.relu2(self.bn2(self.conv2(x))), 2, 2)\n",
    "        # for 32x32 inputs the third 5x5 conv leaves a 1x1 map, so the flatten yields width3 features\n",
    "        x = self.relu3(self.bn3(self.conv3(x)))\n",
    "        return self.fc1(x.reshape(x.shape[0], -1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7440ad2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class FullPrecisionNet(Module):\n",
    "    \"\"\"Float32 reference CNN: three 5x5 conv stages followed by a 10-way linear classifier.\n",
    "\n",
    "    Args:\n",
    "        experiment_config: dict; 'dataset' == 'cifar10' selects 3 input channels and\n",
    "            doubled channel widths (otherwise 1 channel), 'batch_norm' toggles\n",
    "            BatchNorm2d vs. Identity after each conv.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, experiment_config):\n",
    "        super().__init__()\n",
    "\n",
    "        is_cifar = experiment_config['dataset'] == 'cifar10'\n",
    "        input_dim = 3 if is_cifar else 1\n",
    "        scale = 2 if is_cifar else 1\n",
    "        width1, width2, width3 = 4 * scale, 8 * scale, 16 * scale\n",
    "\n",
    "        # Keep this creation order: weight init draws from the global RNG and the\n",
    "        # registration order defines the parameter order.\n",
    "        self.conv1 = nn.Conv2d(input_dim, width1, kernel_size=5, stride=1, padding=0)\n",
    "        self.relu1 = nn.ReLU()\n",
    "        self.conv2 = nn.Conv2d(width1, width2, kernel_size=5, stride=1, padding=0)\n",
    "        self.relu2 = nn.ReLU()\n",
    "        self.dropout1 = nn.Dropout(0.2)\n",
    "        self.conv3 = nn.Conv2d(width2, width3, kernel_size=5, stride=1, padding=0)\n",
    "        self.relu3 = nn.ReLU()\n",
    "        self.fc1 = nn.Linear(width3, 10)\n",
    "\n",
    "        if experiment_config['batch_norm']:\n",
    "            self.bn1, self.bn2, self.bn3 = (nn.BatchNorm2d(c) for c in (width1, width2, width3))\n",
    "        else:\n",
    "            self.bn1, self.bn2, self.bn3 = nn.Identity(), nn.Identity(), nn.Identity()\n",
    "\n",
    "    def forward(self, x):\n",
    "        # two conv -> BN -> ReLU -> 2x2 max-pool stages, then dropout\n",
    "        x = F.max_pool2d(self.relu1(self.bn1(self.conv1(x))), 2, 2)\n",
    "        x = F.max_pool2d(self.relu2(self.bn2(self.conv2(x))), 2, 2)\n",
    "        x = self.dropout1(x)\n",
    "        # for 32x32 inputs the third 5x5 conv leaves a 1x1 map, so the flatten yields width3 features\n",
    "        x = self.relu3(self.bn3(self.conv3(x)))\n",
    "        return self.fc1(x.reshape(x.shape[0], -1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4ea77ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_cifar10_loaders(path, train_batch_size=512, test_batch_size=2048, download=False, num_workers=4):\n",
    "    \"\"\"Build CIFAR-10 train/test DataLoaders.\n",
    "\n",
    "    Training images are augmented (random 32x32 crop with 4px padding, random\n",
    "    horizontal flip); both splits are converted to tensors and normalised per channel.\n",
    "\n",
    "    Args:\n",
    "        path: dataset root directory.\n",
    "        train_batch_size, test_batch_size: batch sizes of the two loaders.\n",
    "        download: fetch the dataset into `path` if it is missing (off by default).\n",
    "        num_workers: DataLoader worker processes per loader.\n",
    "\n",
    "    Returns:\n",
    "        (train_loader, test_loader); only the training loader shuffles.\n",
    "    \"\"\"\n",
    "    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),\n",
    "                                     (0.2023, 0.1994, 0.2010))\n",
    "    train_ds = torchvision.datasets.CIFAR10(\n",
    "            root=path,\n",
    "            train=True,\n",
    "            download=download,\n",
    "            transform=transforms.Compose([\n",
    "                transforms.RandomCrop(32, padding=4),\n",
    "                transforms.RandomHorizontalFlip(),\n",
    "                transforms.ToTensor(),\n",
    "                normalize,\n",
    "                ]))\n",
    "    test_ds = torchvision.datasets.CIFAR10(\n",
    "            root=path,\n",
    "            train=False,\n",
    "            download=download,\n",
    "            transform=transforms.Compose([transforms.ToTensor(), normalize]))\n",
    "\n",
    "    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, num_workers=num_workers)\n",
    "    test_loader = torch.utils.data.DataLoader(test_ds, batch_size=test_batch_size, shuffle=False, num_workers=num_workers)\n",
    "\n",
    "    return train_loader, test_loader\n",
    "\n",
    "def get_mnist_loaders(path, train_batch_size=512, test_batch_size=2048, download=True, num_workers=4):\n",
    "    \"\"\"Build MNIST train/test DataLoaders.\n",
    "\n",
    "    Images are zero-padded by 2 pixels per side (28x28 -> 32x32, the same input size as\n",
    "    CIFAR-10), converted to tensors and normalised.\n",
    "\n",
    "    Args and returns are as for get_cifar10_loaders, except that `download` defaults to True.\n",
    "    \"\"\"\n",
    "    transform = transforms.Compose([\n",
    "        transforms.Pad(2),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.1307,), (0.3081,))\n",
    "        ])\n",
    "    train_dataset = torchvision.datasets.MNIST(path, train=True, transform=transform, download=download)\n",
    "    test_dataset = torchvision.datasets.MNIST(path, train=False, transform=transform, download=download)\n",
    "\n",
    "    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=train_batch_size, num_workers=num_workers, shuffle=True)\n",
    "    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size, num_workers=num_workers, shuffle=False)\n",
    "\n",
    "    return train_loader, test_loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adccfc37",
   "metadata": {},
   "outputs": [],
   "source": [
    "## get the relevant dataloaders for the configured dataset:\n",
    "loader_builders = {'cifar10': get_cifar10_loaders, 'mnist': get_mnist_loaders}\n",
    "if experiment_config['dataset'] not in loader_builders:\n",
    "    raise ValueError('invalid dataset')\n",
    "train_loader, test_loader = loader_builders[experiment_config['dataset']](\n",
    "    experiment_config['data_path'], 512, 2048)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d07da9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Now generate the full precision model, placed on `device` and in training mode:\n",
    "model = FullPrecisionNet(experiment_config).to(device)\n",
    "model.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43065403",
   "metadata": {},
   "outputs": [],
   "source": [
    "## training parameters for the float32 run:\n",
    "EPOCHS = experiment_config['epochs_fp']\n",
    "\n",
    "criterion = nn.CrossEntropyLoss().to(device)\n",
    "optimizer = optim.Adam(params=model.parameters(), lr=experiment_config['lr'])\n",
    "# cosine annealing with T_max = EPOCHS (the training loop steps it once per epoch)\n",
    "scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, EPOCHS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f15919ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the float32 model, collecting per-epoch statistics and snapshotting the\n",
    "# weights of the best-validation epoch.\n",
    "best_accuracy = 0\n",
    "state_accumulator = []  # one row per epoch: [train_loss, train_acc, val_loss, val_acc, lr]\n",
    "model_state = {}\n",
    "for epoch in range(EPOCHS):\n",
    "    print('-'*80)\n",
    "\n",
    "    # Learning rate used for *this* epoch; read it before scheduler.step() advances it.\n",
    "    lr = optimizer.param_groups[0]['lr']\n",
    "\n",
    "    train_loss, train_accuracy = train(model, device, train_loader, criterion, optimizer, epoch)\n",
    "    val_loss, val_accuracy = evaluate(model, device, test_loader, criterion, epoch)\n",
    "    scheduler.step()\n",
    "\n",
    "    if val_accuracy > best_accuracy:\n",
    "        best_accuracy = val_accuracy\n",
    "\n",
    "        parameters = float_accumulator(model)\n",
    "        param_embed = torch.cat([g.view(-1) for g in parameters])\n",
    "        model_state['params'] = param_embed\n",
    "        # state_dict() tensors share storage with the live parameters, so without a\n",
    "        # deep copy this 'best' snapshot would follow every later optimizer update.\n",
    "        model_state['model_state_dict_top_val'] = copy.deepcopy(model.state_dict())\n",
    "        model_state['top_val_accuracy'] = best_accuracy\n",
    "        print('... saving ...')\n",
    "\n",
    "    print(f'Learning rate: {lr}')\n",
    "\n",
    "    state = [train_loss, train_accuracy, val_loss, val_accuracy, lr]\n",
    "    state_accumulator.append(state)\n",
    "\n",
    "# Snapshot (rather than alias) the final weights as well.\n",
    "model_state['model_state_dict_fin_val'] = copy.deepcopy(model.state_dict())\n",
    "model_state['fin_train_accuracy'] = train_accuracy\n",
    "model_state['fin_val_accuracy'] = val_accuracy\n",
    "print('... saving ...')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "979abb7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# visualise training statistics (columns 1 and 3 of each state row are train / val accuracy)\n",
    "history = np.array(state_accumulator)\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(history[..., 1], label='train')\n",
    "ax.plot(history[..., 3], label='val')\n",
    "ax.legend()\n",
    "ax.set_ylabel('Accuracy')\n",
    "ax.set_xlabel('Epoch')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca4432e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Compute EF for the full precision model\n",
    "# input spec matches the data pipeline: 3-channel CIFAR-10 or 1-channel MNIST (padded), both 32x32\n",
    "input_dim = 3 if experiment_config['dataset'] == 'cifar10' else 1\n",
    "fit_computer = FIT(model, device, input_spec=(input_dim, 32, 32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2691cd98",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimate the EF traces on the training data. Returns the weight / activation traces,\n",
    "# their convergence histories (plotted below) and the observed parameter / activation\n",
    "# ranges; tol, min_iterations and max_iterations bound the (presumably iterative) estimate.\n",
    "(EFw, EFa,\n",
    " fap, faa,\n",
    " param_ranges, act_ranges) = fit_computer.EF(model, train_loader, criterion,\n",
    "                                             tol=1e-2, min_iterations=20, max_iterations=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f49e720b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convergence history of the weight EF estimate (log scale).\n",
    "fig, ax = plt.subplots()\n",
    "ax.set_title('EF W Convergence')\n",
    "ax.plot(fap)\n",
    "ax.set_yscale('log')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6bdba7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convergence history of the activation EF estimate (log scale).\n",
    "fig, ax = plt.subplots()\n",
    "ax.set_title('EF A Convergence')\n",
    "ax.plot(faa)\n",
    "ax.set_yscale('log')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05ec73fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Weight EF trace divided element-wise by fit_computer.param_nums\n",
    "# (presumably each layer's parameter count), on a log scale.\n",
    "fig, ax = plt.subplots()\n",
    "ax.set_title('W Trace')\n",
    "ax.plot(EFw/fit_computer.param_nums, 'o-', label='EF')\n",
    "ax.grid(True, which='both')\n",
    "ax.legend()\n",
    "ax.set_yscale('log')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1cd5f573",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Activation EF trace divided element-wise by fit_computer.act_nums with its first entry\n",
    "# dropped (presumably the input activation, which EFa does not cover), on a log scale.\n",
    "fig, ax = plt.subplots()\n",
    "ax.set_title('A Trace')\n",
    "ax.plot(EFa/fit_computer.act_nums[1:], 'o-', label='EF')\n",
    "ax.grid(True, which='both')\n",
    "ax.legend()\n",
    "ax.set_yscale('log')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afaf251c",
   "metadata": {},
   "outputs": [],
   "source": [
    "## attach the EF traces and observed ranges to the float model's saved state\n",
    "model_state.update(\n",
    "    EFw=EFw,\n",
    "    EFa=EFa,\n",
    "    param_ranges=param_ranges,\n",
    "    act_ranges=act_ranges,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cba87ba1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# With batch norm enabled, record the mean learned scale (gamma, i.e. `.weight`) of\n",
    "# every BatchNorm2d layer, in module order.\n",
    "if experiment_config['batch_norm']:\n",
    "    bn_layers = [m for m in model.modules() if isinstance(m, torch.nn.modules.batchnorm.BatchNorm2d)]\n",
    "    gammas = np.array([np.mean(bn.weight.detach().cpu().numpy()) for bn in bn_layers])\n",
    "    model_state['gammas'] = gammas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a15299e",
   "metadata": {},
   "outputs": [],
   "source": [
    "## persist the float model's state (weights, accuracies, EF traces, ranges)\n",
    "fpnn_path = f\"{experiment_config['save_path']}FPNN.pt\"\n",
    "torch.save(model_state, fpnn_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86b6ea02",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "## Generate random quantization configurations.\n",
    "# Each row holds 7 bit-widths indexed like LowPrecisionNet's config:\n",
    "# [conv1 W, relu1 A, conv2 W, relu2 A, conv3 W, relu3 A, fc1 W]; conv1 weights stay at 8 bits.\n",
    "bit_precisions = [3,4,6,8]\n",
    "configs = np.array(np.meshgrid([8], *([bit_precisions] * 6))).T.reshape(-1, 7)\n",
    "# Sample WITHOUT replacement: checkpoints are named after their config, so a repeated\n",
    "# config would overwrite an earlier run and leave fewer than n_models results.\n",
    "# (numpy raises ValueError if n_models exceeds the number of distinct configs.)\n",
    "random_indxs = np.random.choice(len(configs), experiment_config['n_models'], replace=False)\n",
    "random_configs = configs[random_indxs, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1582d2e",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "## observe FIT value stats for the chosen configurations:\n",
    "q_FIT_vals = []\n",
    "for qconfig in random_configs:\n",
    "    q_FIT = fit_computer.FIT(qconfig.astype(int))\n",
    "    print(qconfig, q_FIT)\n",
    "    q_FIT_vals.append(q_FIT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21493df3",
   "metadata": {},
   "outputs": [],
   "source": [
    "## sanity check: the sampled configurations with the extreme FIT values\n",
    "idx_largest = np.argmax(q_FIT_vals)\n",
    "idx_smallest = np.argmin(q_FIT_vals)\n",
    "print(f'Config with largest FIT: {random_configs[idx_largest]}, {np.max(q_FIT_vals)}')\n",
    "print(f'Config with smallest FIT: {random_configs[idx_smallest]}, {np.min(q_FIT_vals)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bd98b83",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Brevitas config flag: ignore keys missing from a loaded state_dict, so the quantized model\n",
    "# can be initialised from the full-precision checkpoint (which lacks quantizer parameters).\n",
    "bconfig.IGNORE_MISSING_KEYS = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6470243c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quantization-aware training of each sampled configuration, starting from the final\n",
    "# float32 weights; one checkpoint is saved per configuration.\n",
    "EPOCHS = experiment_config['epochs_qp']  # same for every configuration\n",
    "for i, qconfig in enumerate(random_configs):\n",
    "    print('#'*80)\n",
    "    start = time.time()\n",
    "\n",
    "    bit_widths = [int(c) for c in qconfig]\n",
    "    q_FIT = fit_computer.FIT(qconfig.astype(int))\n",
    "\n",
    "    print(f'starting analysis for: {qconfig} - FIT: {q_FIT} (iteration {i + 1} of {len(random_configs)})')\n",
    "\n",
    "    # Store the config explicitly so the analysis does not have to parse it from the filename.\n",
    "    quant_model_state = {'config': bit_widths}\n",
    "    quant_model = LowPrecisionNet(bit_widths, experiment_config)\n",
    "    # Quantizer-specific parameters are absent from the float checkpoint; loading relies on\n",
    "    # bconfig.IGNORE_MISSING_KEYS having been set.\n",
    "    quant_model.load_state_dict(model_state['model_state_dict_fin_val'])\n",
    "    quant_model.to(device)\n",
    "    quant_model.train()\n",
    "\n",
    "    quant_model_state['FIT'] = q_FIT\n",
    "\n",
    "    parameters = quant_accumulator(quant_model)\n",
    "    param_embed_naive = torch.cat([g.view(-1) for g in parameters])\n",
    "    quant_model_state['params_pre'] = param_embed_naive\n",
    "\n",
    "    # Fine-tune with a 10x smaller learning rate than the float run.\n",
    "    optimizer = torch.optim.Adam(params=quant_model.parameters(), lr=experiment_config['lr']*0.1)\n",
    "    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, EPOCHS)\n",
    "    criterion = nn.CrossEntropyLoss().to(device)\n",
    "\n",
    "    print('Evaluating before qat ...')\n",
    "    val_loss, val_accuracy = evaluate(quant_model, device, test_loader, criterion, 0)\n",
    "    quant_model_state['init_val_accuracy'] = val_accuracy\n",
    "\n",
    "    best_accuracy = 0\n",
    "    for epoch in range(EPOCHS):\n",
    "        print('-'*80)\n",
    "        train_loss, train_accuracy = train(quant_model, device, train_loader, criterion, optimizer, epoch)\n",
    "        val_loss, val_accuracy = evaluate(quant_model, device, test_loader, criterion, epoch)\n",
    "        scheduler.step()\n",
    "        if val_accuracy > best_accuracy:\n",
    "            best_accuracy = val_accuracy\n",
    "\n",
    "            parameters = quant_accumulator(quant_model)\n",
    "            param_embed_qat = torch.cat([g.view(-1) for g in parameters])\n",
    "            quant_model_state['params_qat'] = param_embed_qat\n",
    "            # Deep copy: state_dict() aliases the live tensors, which keep training.\n",
    "            quant_model_state['model_state_dict_top_val'] = copy.deepcopy(quant_model.state_dict())\n",
    "            quant_model_state['top_val_accuracy'] = best_accuracy\n",
    "            print('Saving model ...')\n",
    "\n",
    "    quant_model_state['model_state_dict_fin_val'] = copy.deepcopy(quant_model.state_dict())\n",
    "    quant_model_state['fin_train_accuracy'] = train_accuracy\n",
    "    quant_model_state['fin_val_accuracy'] = val_accuracy\n",
    "    print('Saving model ...')\n",
    "    torch.save(quant_model_state, f\"{experiment_config['save_path']}QNN_{np.array(qconfig)}.pt\")\n",
    "    end = time.time()\n",
    "    print(f'Elapsed time: {end-start}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59f0b42c",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Plot analysis of results\n",
    "import glob\n",
    "from scipy.stats import spearmanr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18c38ca9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# All quantized checkpoints written by the QAT loop, sorted so the analysis order is\n",
    "# deterministic (glob's result order depends on the filesystem).\n",
    "quant_model_paths = sorted(glob.glob(f\"{experiment_config['save_path']}QNN*\"))\n",
    "print(len(quant_model_paths))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23a3e48e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the float checkpoint onto the current device. torch.load unpickles the file,\n",
    "# so only point it at checkpoints this notebook produced.\n",
    "fpnn_path = f\"{experiment_config['save_path']}FPNN.pt\"\n",
    "float_checkpoint = torch.load(fpnn_path, map_location=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "893e8a52",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recompute FIT and baseline metrics (noise, QR, BN-scaled noise) for every trained\n",
    "# quantized checkpoint, using the float model's EF traces and observed ranges.\n",
    "# NOTE: this rebinds `FIT`, which is also the constructor used to build `fit_computer`;\n",
    "# re-running that cell after this one will fail.\n",
    "FIT = []\n",
    "fin_val_accuracies = []\n",
    "top_val_accuracies = []\n",
    "fin_train_accuracies = []\n",
    "QR = []\n",
    "QRW = []\n",
    "QRA = []\n",
    "noise = []\n",
    "FITW = []\n",
    "FITA = []\n",
    "BN = []\n",
    "\n",
    "# Loop-invariant range statistics of the float model (first activation range excluded).\n",
    "mean_act_ranges = np.mean(float_checkpoint['act_ranges'], axis=0)[1:]\n",
    "mean_param_ranges = np.mean(float_checkpoint['param_ranges'], axis=0)\n",
    "\n",
    "def _config_from_path(path):\n",
    "    \"\"\"Parse the bit-widths from a checkpoint name like 'QNN_[8 3 4 6 8 3 4].pt'.\"\"\"\n",
    "    name = os.path.basename(path)\n",
    "    inner = name[name.rindex('[') + 1:name.rindex(']')]\n",
    "    return np.array([int(n) for n in inner.split()])\n",
    "\n",
    "for qpath in quant_model_paths:\n",
    "    quant_checkpoint = torch.load(qpath, map_location=device)\n",
    "    # Prefer the config stored in the checkpoint; fall back to the filename for older files.\n",
    "    if 'config' in quant_checkpoint:\n",
    "        config = np.array(quant_checkpoint['config'])\n",
    "    else:\n",
    "        config = _config_from_path(qpath)\n",
    "    # even indices are weight bit-widths, odd indices activation bit-widths\n",
    "    act_bits, param_bits = config[1::2], config[0::2]\n",
    "    print(config)\n",
    "\n",
    "    pert_acts = fit_computer.noise_model(mean_act_ranges, act_bits)\n",
    "    pert_params = fit_computer.noise_model(mean_param_ranges, param_bits)\n",
    "\n",
    "    f_acts_T = pert_acts*float_checkpoint['EFa']\n",
    "    f_params_T = pert_params*float_checkpoint['EFw']\n",
    "    pert_T = np.sum(f_acts_T) + np.sum(f_params_T)\n",
    "\n",
    "    # total injected noise over weights and activations\n",
    "    noise.append(np.sum(pert_params) + np.sum(pert_acts))\n",
    "    fin_val_accuracies.append(quant_checkpoint['fin_val_accuracy'])\n",
    "    top_val_accuracies.append(quant_checkpoint['top_val_accuracy'])\n",
    "    fin_train_accuracies.append(quant_checkpoint['fin_train_accuracy'])\n",
    "    FIT.append(pert_T)\n",
    "    FITW.append(np.sum(f_params_T))\n",
    "    FITA.append(np.sum(f_acts_T))\n",
    "\n",
    "    # QR: noise relative to the observed range of each quantized tensor\n",
    "    qr_w = np.sum(pert_params*(1/mean_param_ranges))\n",
    "    qr_a = np.sum(pert_acts*(1/mean_act_ranges))\n",
    "    QR.append(qr_w + qr_a)\n",
    "    QRW.append(qr_w)\n",
    "    QRA.append(qr_a)\n",
    "\n",
    "    if experiment_config['batch_norm']:\n",
    "        BN.append(np.sum(pert_acts*(1/float_checkpoint['gammas'])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75b90f81",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One scatter panel per metric, each against post-QAT validation accuracy (shared y-axis).\n",
    "fig, (ax1, ax2, ax3, ax4) = plt.subplots(1,4, sharey=True, figsize=(10,2))\n",
    "panels = [(ax1, FIT, 'FIT'), (ax2, QR, 'QR'), (ax3, noise, 'Noise'), (ax4, FITW, 'FIT (W)')]\n",
    "for ax, metric, label in panels:\n",
    "    ax.scatter(metric, fin_val_accuracies)\n",
    "    ax.set_xlabel(label, horizontalalignment='right', x=1.)\n",
    "    ax.set_ylabel('Post QAT Accuracy', horizontalalignment='right', y=1.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "613d81ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spearman rank correlation between each metric and post-QAT validation accuracy.\n",
    "metrics = {\n",
    "    'FIT': FIT,\n",
    "    'FIT (A)': FITA,\n",
    "    'FIT (W)': FITW,\n",
    "    'Noise': noise,\n",
    "    'QR': QR,\n",
    "    'QR (A)': QRA,\n",
    "    'QR (W)': QRW,\n",
    "}\n",
    "if experiment_config['batch_norm']:\n",
    "    metrics['BN'] = BN\n",
    "\n",
    "for metric_name, values in metrics.items():\n",
    "    coef, p = spearmanr(values, fin_val_accuracies)\n",
    "    print(f'{metric_name:>8}: rho = {coef:.4f}  (p = {p:.2g})')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83a56e8d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
