{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn off the notebook's periodic autosave (an interval of 0 disables it).\n",
    "%autosave 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from io_model.datasets.networks.pretrained import pretrain_network_cifar\n",
    "from io_model.datasets.networks.utils import load_nasbench\n",
    "from nasbench_pytorch.datasets.cifar10 import prepare_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data-loading configuration.\n",
    "seed = 42  # forwarded to prepare_dataset as random_state\n",
    "batch_size = 32\n",
    "num_workers = 2\n",
    "\n",
    "# prepare_dataset returns a 6-tuple that is unpacked below.\n",
    "# NOTE(review): presumably three loaders, each followed by its size - confirm.\n",
    "dataset = prepare_dataset(\n",
    "    batch_size,\n",
    "    root='../data/cifar/',\n",
    "    validation_size=1000,\n",
    "    random_state=seed,\n",
    "    num_workers=num_workers,\n",
    ")\n",
    "train, n_train, val, n_val, test, n_test = dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Directory holding the saved checkpoints. The trailing separator matters:\n",
    "# file names are appended directly to this string when a checkpoint is loaded.\n",
    "checkpoint_dir = '../data/train_checkpoints//'\n",
    "\n",
    "# Show what is available in that directory.\n",
    "os.listdir(checkpoint_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nasbench import api\n",
    "\n",
    "# Location of the NAS-Bench record file, opened here through the nasbench API.\n",
    "nasbench_path = '../data/nasbench_only108.tfrecord'\n",
    "nb = api.NASBench(nasbench_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from io_model.datasets.networks.utils import load_trained_net\n",
    "\n",
    "# Run on the GPU when one is available, otherwise on the CPU.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Restore a single checkpoint; load_trained_net hands back three values.\n",
    "net_path = f'{checkpoint_dir}083a8045a46ac2a25a34c8fab47f6ecb.tar'\n",
    "net_hash, net, info = load_trained_net(net_path, nb, device=device)\n",
    "net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the third value returned by load_trained_net.\n",
    "info"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the bias of the loaded network's classifier.\n",
    "net.classifier.bias"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- heatmaps todo\n",
    "  - all net weights for e.g. airplane\n",
    "  - minmaxscale to (-1,1), compare net\n",
    "    - bias?\n",
    "  - standardscale instead of normalize\n",
    "  - predict one image, sort values by weights\n",
    "    - how to normalize??"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "%matplotlib notebook\n",
    "\n",
    "# Class names used as the row labels of the heatmap.\n",
    "labels = \"airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck\".split(', ')\n",
    "\n",
    "# Detached CPU copy of the classifier weights.\n",
    "weights = net.classifier.weight.detach().cpu()\n",
    "\n",
    "# Draw on an explicit Axes so the figure does not depend on pyplot's implicit\n",
    "# current-figure state, and label it so the saved image stands on its own.\n",
    "fig, ax = plt.subplots(figsize=(5, 6))\n",
    "sns.heatmap(weights.numpy(), yticklabels=labels, ax=ax)\n",
    "ax.set_title('Classifier weights')\n",
    "ax.set_xlabel('input feature')\n",
    "ax.set_ylabel('class')\n",
    "fig.tight_layout()\n",
    "fig.savefig('cifar_weights.png')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Print the hash, then look at how NumPy represents three copies of it.\n",
    "print(net_hash)\n",
    "np.array([net_hash] * 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the Python type of the hash returned with the checkpoint.\n",
    "type(net_hash)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take only the first validation batch and move it to the compute device.\n",
    "inputs, targets = next(iter(val))\n",
    "inputs, targets = inputs.to(device), targets.to(device)\n",
    "\n",
    "print(inputs.shape)\n",
    "print(targets.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the NAS-Bench entries with the project helper; the result is indexed,\n",
    "# sliced and measured with len() further down.\n",
    "nasbench = load_nasbench(nasbench_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nasbench_pytorch.model import Network as NBNetwork\n",
    "\n",
    "# Build a fresh network from the first NAS-Bench entry and place it on `device`.\n",
    "# NOTE: this rebinds `net`, so the checkpoint loaded earlier is no longer\n",
    "# reachable under that name.\n",
    "# NOTE(review): 10 is presumably the number of output classes - confirm.\n",
    "first_entry = nasbench[0]\n",
    "net = NBNetwork((first_entry[2], first_entry[1]), 10).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Forward pass with gradient tracking disabled; show the output shape.\n",
    "with torch.no_grad():\n",
    "    outputs = net(inputs.to(device))\n",
    "\n",
    "outputs.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ask the network for its recorded outputs only (no inputs) and list their shapes.\n",
    "with torch.no_grad():\n",
    "    out_list = net.get_cell_outputs(inputs, return_inputs=False)\n",
    "\n",
    "for cell_out in out_list:\n",
    "    print(cell_out.shape)\n",
    "print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# With return_inputs=True the call yields two lists; print them side by side.\n",
    "with torch.no_grad():\n",
    "    in_list, out_list = net.get_cell_outputs(inputs, return_inputs=True)\n",
    "\n",
    "for cell_in, cell_out in zip(in_list, out_list):\n",
    "    print(cell_in.shape, ' -> ', cell_out.shape)\n",
    "print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Concatenate the first two recorded inputs along axis 0 and show the shape.\n",
    "torch.cat(in_list[:2], dim=0).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    in_list, out_list = net.get_cell_outputs(inputs, return_inputs=True)\n",
    "\n",
    "# From the second-to-last recorded input, take batch item 0 and print the\n",
    "# shape of each of its first ten slices.\n",
    "first_slices = in_list[-2][0][:10]\n",
    "for feature_map in first_slices:\n",
    "    print(feature_map.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-channel statistics shaped (3, 1, 1) so they broadcast over an image\n",
    "# whose leading axis is the channel axis.\n",
    "# NOTE(review): presumably the normalisation constants applied by the data\n",
    "# pipeline - confirm against prepare_dataset.\n",
    "means = np.array([0.4914, 0.4822, 0.4465]).reshape(3, 1, 1)\n",
    "stds = np.array([0.2023, 0.1994, 0.2010]).reshape(3, 1, 1)\n",
    "\n",
    "# Undo the normalisation for the batch item at index 1.\n",
    "im = inputs[1].detach().cpu() * stds + means"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Target of the batch item at index 1 (the same item that `im` is built from).\n",
    "targets[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Thumbnail of the de-normalised image; axis 0 is moved to the end because\n",
    "# imshow expects the colour channels on the last axis.\n",
    "fig, ax = plt.subplots(figsize=(1, 1))\n",
    "ax.imshow(im.moveaxis(0, 2))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "%matplotlib notebook\n",
    "\n",
    "# One small figure per slice: the first ten slices of batch item 1, taken\n",
    "# from the second-to-last recorded input.\n",
    "for in_image in in_list[-2][1][:10]:\n",
    "    fig, ax = plt.subplots(figsize=(1, 1))\n",
    "    ax.imshow(in_image.detach().cpu())\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check that every network produces the same cell input/output shapes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Verify that each of the first 150 NAS-Bench networks yields inputs and\n",
    "# outputs with the same shapes as the reference lists in_list / out_list.\n",
    "for net_idx, next_net in enumerate(nasbench[:150]):\n",
    "    if (net_idx % 5) == 0:\n",
    "        print(net_idx)\n",
    "\n",
    "    # The batch lives on `device`, so the freshly built network must as well;\n",
    "    # otherwise the forward pass fails whenever a GPU is in use.\n",
    "    next_net = NBNetwork((next_net[2], next_net[1]), 10).to(device)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        n_in, n_out = next_net.get_cell_outputs(inputs, return_inputs=True)\n",
    "\n",
    "    # zip() stops at the shorter sequence, so compare the lengths explicitly.\n",
    "    assert len(n_in) == len(in_list)\n",
    "    assert len(n_out) == len(out_list)\n",
    "\n",
    "    for ref_in, new_in in zip(in_list, n_in):\n",
    "        assert ref_in.shape == new_in.shape\n",
    "\n",
    "    for ref_out, new_out in zip(out_list, n_out):\n",
    "        assert ref_out.shape == new_out.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total number of loaded NAS-Bench entries (only the first 150 are checked above).\n",
    "len(nasbench)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "text_representation": {
    "extension": ".Rmd",
    "format_name": "rmarkdown",
    "format_version": "1.2",
    "jupytext_formats": "ipynb,Rmd:rmarkdown",
    "jupytext_version": "1.13.0"
   }
  },
  "kernelspec": {
   "display_name": "bioinf",
   "language": "python",
   "name": "bioinf"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
