{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3b3ee8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "import torch\n",
    "import torch.backends.cudnn as cudnn\n",
    "import torch.nn.functional as F\n",
    "import torch.nn as nn\n",
    "import torchvision\n",
    "import torchvision.models as models\n",
    "import torchvision.transforms as transforms\n",
    "import time\n",
    "\n",
    "from FIT_utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9eb9b32d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Choose the compute device: a CUDA GPU when torch can see one, otherwise the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "print(f\"Using {device} device\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "764ffb05",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Data loader consumed by the EF estimation below.\n",
    "# Pass `root=` to point at your dataset instead of editing the function body.\n",
    "\n",
    "def get_imnet_loaders(train_batch_size=200, imsize=224, root='path/to/dataset', num_workers=4):\n",
    "    \"\"\"Return a shuffled DataLoader over an ImageFolder dataset at ``root``.\n",
    "\n",
    "    Each image is resized (shorter side -> ``imsize``), center-cropped to\n",
    "    ``imsize`` x ``imsize``, converted to a tensor and normalised with the\n",
    "    ImageNet channel mean/std below.\n",
    "\n",
    "    Args:\n",
    "        train_batch_size: images per batch.\n",
    "        imsize: output spatial size (224 here, 299 for Inception-v3).\n",
    "        root: dataset directory in torchvision ``ImageFolder`` layout (one\n",
    "            sub-directory per class). The default is only a placeholder.\n",
    "        num_workers: worker processes used by the DataLoader.\n",
    "\n",
    "    Returns:\n",
    "        torch.utils.data.DataLoader over ``(image, label)`` samples.\n",
    "    \"\"\"\n",
    "    # Deterministic preprocessing: no random augmentation is applied.\n",
    "    preprocess = transforms.Compose([\n",
    "        transforms.Resize(imsize),\n",
    "        transforms.CenterCrop(imsize),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "    ])\n",
    "    dataset = torchvision.datasets.ImageFolder(root=root, transform=preprocess)\n",
    "    train_loader = torch.utils.data.DataLoader(\n",
    "        dataset,\n",
    "        batch_size=train_batch_size,\n",
    "        shuffle=True,\n",
    "        num_workers=num_workers)\n",
    "\n",
    "    return train_loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee0e4c0e",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "## experiment setup: ResNet-18\n",
    "batch_size = 32\n",
    "image_size = 224  # loader crop size; also the H/W of the input shape given to FIT\n",
    "train_loader = get_imnet_loaders(batch_size, image_size)\n",
    "criterion = nn.CrossEntropyLoss().to(device)\n",
    "# nn.Module.to() and nn.Module.eval() both return the module, so they can be chained.\n",
    "model = models.resnet18(pretrained=True).to(device).eval()\n",
    "fit_computer = FIT(model, device, (3, image_size, image_size))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9541afad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the EF estimation over train_loader. EFw / EFa are the per-layer weight and\n",
    "# activation traces; fap / faa are the per-iteration histories used by the\n",
    "# convergence plots. tol / min_iterations / max_iterations presumably govern the\n",
    "# stopping rule -- see FIT_utils.\n",
    "ef_kwargs = dict(tol=5e-3, min_iterations=20, max_iterations=300)\n",
    "ef_results = fit_computer.EF(model, train_loader, criterion, **ef_kwargs)\n",
    "EFw, EFa, fap, faa, param_ranges, act_ranges = ef_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbac61a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Left column: convergence histories; right column: final per-layer traces.\n",
    "# Values are divided by the per-layer counts (param_nums, act_nums[1:]) so that\n",
    "# layers of different sizes are comparable.\n",
    "fig, axs = plt.subplots(2, 2, figsize=(10,8))\n",
    "axs[0, 0].plot(np.array(fap)/fit_computer.param_nums)\n",
    "axs[0, 0].set(title='EF W Convergence', xlabel='iteration')\n",
    "axs[1, 0].plot(np.array(faa)/fit_computer.act_nums[1:])\n",
    "axs[1, 0].set(title='EF A Convergence', xlabel='iteration')\n",
    "axs[1, 0].sharex(axs[0, 0])\n",
    "axs[0, 1].plot(EFw/fit_computer.param_nums,'o-')\n",
    "axs[0, 1].set(title='W Trace', xlabel='layer index')\n",
    "axs[0, 1].sharey(axs[0, 0])\n",
    "axs[1, 1].plot(EFa/fit_computer.act_nums[1:],'o-')\n",
    "axs[1, 1].set(title='A Trace', xlabel='layer index')\n",
    "axs[1, 1].sharey(axs[1, 0])\n",
    "for ax in axs.flat:\n",
    "    ax.set_yscale('log')\n",
    "    ax.grid(True, which='both')\n",
    "# Lay out only after switching to log scale, so the final tick labels are measured.\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe2723f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uniform 8-bit placeholder config, presumably one bit-width per weight layer\n",
    "# (param_nums) plus one per activation (act_nums[1:], the slice used for the\n",
    "# plots). The length is derived instead of hard-coded; for ResNet-18 it is 41.\n",
    "n_quant = len(fit_computer.param_nums) + len(fit_computer.act_nums[1:])\n",
    "placeholder_config = np.full(n_quant, 8.0)\n",
    "print(fit_computer.FIT(placeholder_config))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3df1aaca",
   "metadata": {},
   "outputs": [],
   "source": [
    "## experiment setup: ResNet-50\n",
    "batch_size = 32\n",
    "image_size = 224  # loader crop size; also the H/W of the input shape given to FIT\n",
    "train_loader = get_imnet_loaders(batch_size, image_size)\n",
    "criterion = nn.CrossEntropyLoss().to(device)\n",
    "# nn.Module.to() and nn.Module.eval() both return the module, so they can be chained.\n",
    "model = models.resnet50(pretrained=True).to(device).eval()\n",
    "fit_computer = FIT(model, device, (3, image_size, image_size))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1863a6ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the EF estimation over train_loader. EFw / EFa are the per-layer weight and\n",
    "# activation traces; fap / faa are the per-iteration histories used by the\n",
    "# convergence plots. tol / min_iterations / max_iterations presumably govern the\n",
    "# stopping rule -- see FIT_utils.\n",
    "ef_kwargs = dict(tol=5e-3, min_iterations=20, max_iterations=300)\n",
    "ef_results = fit_computer.EF(model, train_loader, criterion, **ef_kwargs)\n",
    "EFw, EFa, fap, faa, param_ranges, act_ranges = ef_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49e36ecc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Left column: convergence histories; right column: final per-layer traces.\n",
    "# Values are divided by the per-layer counts (param_nums, act_nums[1:]) so that\n",
    "# layers of different sizes are comparable.\n",
    "fig, axs = plt.subplots(2, 2, figsize=(10,8))\n",
    "axs[0, 0].plot(np.array(fap)/fit_computer.param_nums)\n",
    "axs[0, 0].set(title='EF W Convergence', xlabel='iteration')\n",
    "axs[1, 0].plot(np.array(faa)/fit_computer.act_nums[1:])\n",
    "axs[1, 0].set(title='EF A Convergence', xlabel='iteration')\n",
    "axs[1, 0].sharex(axs[0, 0])\n",
    "axs[0, 1].plot(EFw/fit_computer.param_nums,'o-')\n",
    "axs[0, 1].set(title='W Trace', xlabel='layer index')\n",
    "axs[0, 1].sharey(axs[0, 0])\n",
    "axs[1, 1].plot(EFa/fit_computer.act_nums[1:],'o-')\n",
    "axs[1, 1].set(title='A Trace', xlabel='layer index')\n",
    "axs[1, 1].sharey(axs[1, 0])\n",
    "for ax in axs.flat:\n",
    "    ax.set_yscale('log')\n",
    "    ax.grid(True, which='both')\n",
    "# Lay out only after switching to log scale, so the final tick labels are measured.\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "551ea92c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uniform 8-bit placeholder config, presumably one bit-width per weight layer\n",
    "# (param_nums) plus one per activation (act_nums[1:], the slice used for the\n",
    "# plots). The length is derived instead of hard-coded; for ResNet-50 it is 107.\n",
    "n_quant = len(fit_computer.param_nums) + len(fit_computer.act_nums[1:])\n",
    "placeholder_config = np.full(n_quant, 8.0)\n",
    "print(fit_computer.FIT(placeholder_config))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6794593f",
   "metadata": {},
   "outputs": [],
   "source": [
    "## experiment setup: MobileNetV2\n",
    "batch_size = 32\n",
    "image_size = 224  # loader crop size; also the H/W of the input shape given to FIT\n",
    "train_loader = get_imnet_loaders(batch_size, image_size)\n",
    "criterion = nn.CrossEntropyLoss().to(device)\n",
    "# nn.Module.to() and nn.Module.eval() both return the module, so they can be chained.\n",
    "model = models.mobilenet_v2(pretrained=True).to(device).eval()\n",
    "fit_computer = FIT(model, device, (3, image_size, image_size))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e488cfc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the EF estimation over train_loader. EFw / EFa are the per-layer weight and\n",
    "# activation traces; fap / faa are the per-iteration histories used by the\n",
    "# convergence plots. tol / min_iterations / max_iterations presumably govern the\n",
    "# stopping rule -- see FIT_utils.\n",
    "ef_kwargs = dict(tol=5e-3, min_iterations=20, max_iterations=300)\n",
    "ef_results = fit_computer.EF(model, train_loader, criterion, **ef_kwargs)\n",
    "EFw, EFa, fap, faa, param_ranges, act_ranges = ef_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2122ff5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def blockup(data, blocks, param_nums=None, normalise=False):\n",
    "    \"\"\"Sum per-layer values over contiguous layer blocks.\n",
    "\n",
    "    Args:\n",
    "        data: per-layer values; must support slicing and ``.sum()`` on the\n",
    "            slice (e.g. a NumPy array).\n",
    "        blocks: sequence of half-open ``(start, stop)`` index pairs into ``data``.\n",
    "        param_nums: one divisor per block; required when ``normalise`` is True.\n",
    "            Despite the name, any per-block counts may be passed (e.g. activation counts).\n",
    "        normalise: if True, divide each block sum by the matching ``param_nums`` entry.\n",
    "\n",
    "    Returns:\n",
    "        list with one value per block.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``normalise`` is True and ``param_nums`` is missing or does\n",
    "            not have exactly one entry per block (``zip`` would silently truncate).\n",
    "    \"\"\"\n",
    "    sums = [data[start:stop].sum() for start, stop in blocks]\n",
    "    if not normalise:\n",
    "        return sums\n",
    "    if param_nums is None:\n",
    "        raise ValueError('param_nums is required when normalise=True')\n",
    "    if len(param_nums) != len(sums):\n",
    "        raise ValueError(f'param_nums has {len(param_nums)} entries but there are {len(sums)} blocks')\n",
    "    return [total / count for total, count in zip(sums, param_nums)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30a3780d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Half-open (start, stop) layer-index ranges grouping MobileNetV2 layers into\n",
    "# blocks: a 1-layer block, a 2-layer block, sixteen 3-layer blocks, then a final\n",
    "# 1-layer block. The weight grouping has one extra trailing layer with no\n",
    "# activation counterpart (presumably the classifier).\n",
    "block_indxs_a = [(0, 1), (1, 3)] + [(s, s + 3) for s in range(3, 51, 3)] + [(51, 52)]\n",
    "block_indxs_w = block_indxs_a + [(52, 53)]\n",
    "\n",
    "# Per-block element counts, used as divisors below.\n",
    "block_param_nums = np.array(blockup(fit_computer.param_nums, block_indxs_w))\n",
    "block_act_nums = np.array(blockup(np.array(fit_computer.act_nums[1:]), block_indxs_a))\n",
    "\n",
    "# Per-block traces, normalised by the per-block element counts.\n",
    "block_EFw = np.array(blockup(EFw, block_indxs_w, param_nums=block_param_nums, normalise=True))\n",
    "block_EFa = np.array(blockup(EFa, block_indxs_a, param_nums=block_act_nums, normalise=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a6e8620",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Left column: per-layer convergence histories (divided by param_nums /\n",
    "# act_nums[1:]); right column: final traces aggregated per block.\n",
    "fig, axs = plt.subplots(2, 2, figsize=(10,8))\n",
    "axs[0, 0].plot(np.array(fap)/fit_computer.param_nums)\n",
    "axs[0, 0].set(title='EF W Convergence', xlabel='iteration')\n",
    "axs[1, 0].plot(np.array(faa)/fit_computer.act_nums[1:])\n",
    "axs[1, 0].set(title='EF A Convergence', xlabel='iteration')\n",
    "axs[1, 0].sharex(axs[0, 0])\n",
    "axs[0, 1].plot(block_EFw,'o-')\n",
    "axs[0, 1].set(title='W Trace', xlabel='block index')\n",
    "axs[0, 1].sharey(axs[0, 0])\n",
    "axs[1, 1].plot(block_EFa,'o-')\n",
    "axs[1, 1].set(title='A Trace', xlabel='block index')\n",
    "axs[1, 1].sharey(axs[1, 0])\n",
    "for ax in axs.flat:\n",
    "    ax.set_yscale('log')\n",
    "    ax.grid(True, which='both')\n",
    "# Lay out only after switching to log scale, so the final tick labels are measured.\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41d1c074",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uniform 8-bit placeholder config, presumably one bit-width per weight layer\n",
    "# (param_nums) plus one per activation (act_nums[1:], the slice used for the\n",
    "# plots). The length is derived instead of hard-coded; for MobileNetV2 it is 105.\n",
    "n_quant = len(fit_computer.param_nums) + len(fit_computer.act_nums[1:])\n",
    "placeholder_config = np.full(n_quant, 8.0)\n",
    "print(fit_computer.FIT(placeholder_config))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "628967be",
   "metadata": {},
   "outputs": [],
   "source": [
    "## experiment setup: Inception-v3 (expects 299x299 inputs)\n",
    "batch_size = 32\n",
    "image_size = 299  # loader crop size; also the H/W of the input shape given to FIT\n",
    "train_loader = get_imnet_loaders(batch_size, image_size)\n",
    "criterion = nn.CrossEntropyLoss().to(device)\n",
    "# nn.Module.to() and nn.Module.eval() both return the module, so they can be chained.\n",
    "model = models.inception_v3(pretrained=True).to(device).eval()\n",
    "# Filter out the auxiliary layers. NOTE(review): the original ('Aux') is just the\n",
    "# string 'Aux', not a 1-tuple; kept as a string -- check how FIT_utils consumes\n",
    "# layer_filter before changing it to ('Aux',).\n",
    "fit_computer = FIT(model, device, (3, image_size, image_size), layer_filter='Aux')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcae529e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the EF estimation over train_loader. EFw / EFa are the per-layer weight and\n",
    "# activation traces; fap / faa are the per-iteration histories used by the\n",
    "# convergence plots. tol / min_iterations / max_iterations presumably govern the\n",
    "# stopping rule -- see FIT_utils.\n",
    "ef_kwargs = dict(tol=5e-3, min_iterations=20, max_iterations=300)\n",
    "ef_results = fit_computer.EF(model, train_loader, criterion, **ef_kwargs)\n",
    "EFw, EFa, fap, faa, param_ranges, act_ranges = ef_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce733e25",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Layer-index boundaries of the Inception-v3 blocks; consecutive boundaries form\n",
    "# half-open (start, stop) ranges. The weight grouping adds one trailing layer with\n",
    "# no activation counterpart (presumably the classifier).\n",
    "inception_bounds = [0, 1, 2, 3, 4, 5, 12, 19, 26, 30, 40, 50, 60, 70, 76, 85, 94]\n",
    "block_indxs_a = list(zip(inception_bounds[:-1], inception_bounds[1:]))\n",
    "block_indxs_w = block_indxs_a + [(94, 95)]\n",
    "\n",
    "# Per-block element counts, used as divisors below.\n",
    "block_param_nums = np.array(blockup(fit_computer.param_nums, block_indxs_w))\n",
    "block_act_nums = np.array(blockup(np.array(fit_computer.act_nums[1:]), block_indxs_a))\n",
    "\n",
    "# Per-block traces, normalised by the per-block element counts.\n",
    "block_EFw = np.array(blockup(EFw, block_indxs_w, param_nums=block_param_nums, normalise=True))\n",
    "block_EFa = np.array(blockup(EFa, block_indxs_a, param_nums=block_act_nums, normalise=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc468e55",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Left column: per-layer convergence histories (divided by param_nums /\n",
    "# act_nums[1:]); right column: final traces aggregated per block.\n",
    "fig, axs = plt.subplots(2, 2, figsize=(10,8))\n",
    "axs[0, 0].plot(np.array(fap)/fit_computer.param_nums)\n",
    "axs[0, 0].set(title='EF W Convergence', xlabel='iteration')\n",
    "axs[1, 0].plot(np.array(faa)/fit_computer.act_nums[1:])\n",
    "axs[1, 0].set(title='EF A Convergence', xlabel='iteration')\n",
    "axs[1, 0].sharex(axs[0, 0])\n",
    "axs[0, 1].plot(block_EFw,'o-')\n",
    "axs[0, 1].set(title='W Trace', xlabel='block index')\n",
    "axs[0, 1].sharey(axs[0, 0])\n",
    "axs[1, 1].plot(block_EFa,'o-')\n",
    "axs[1, 1].set(title='A Trace', xlabel='block index')\n",
    "axs[1, 1].sharey(axs[1, 0])\n",
    "for ax in axs.flat:\n",
    "    ax.set_yscale('log')\n",
    "    ax.grid(True, which='both')\n",
    "# Lay out only after switching to log scale, so the final tick labels are measured.\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "306cc44a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uniform 8-bit placeholder config, presumably one bit-width per weight layer\n",
    "# (param_nums) plus one per activation (act_nums[1:], the slice used for the\n",
    "# plots). The length is derived instead of hard-coded; for Inception-v3 it is 189.\n",
    "n_quant = len(fit_computer.param_nums) + len(fit_computer.act_nums[1:])\n",
    "placeholder_config = np.full(n_quant, 8.0)\n",
    "print(fit_computer.FIT(placeholder_config))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
