{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 1\n",
    "\n",
    "from nasbench import api\n",
    "\n",
    "nasbench_path = '../data/nasbench_only108.tfrecord'\n",
    "nb = api.NASBench(nasbench_path)\n",
    "\n",
    "import torch\n",
    "from io_model.datasets.arch2vec_dataset import get_labeled_unlabeled_datasets\n",
    "\n",
    "#torch.backends.cudnn.benchmark = True\n",
    "# Use the GPU when one is available, otherwise fall back to the CPU so the\n",
    "# notebook still runs on machines without CUDA.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# NOTE(review): an earlier remark here suggested passing device=None, because\n",
    "# otherwise the whole dataset is stored on the GPU -- confirm which is intended.\n",
    "dataset, _ = get_labeled_unlabeled_datasets(nb, device=device, seed=seed,\n",
    "                                            train_pretrained=None,\n",
    "                                            valid_pretrained=None,\n",
    "                                            train_labeled_path='../data/train_long.pt',\n",
    "                                            valid_labeled_path='../data/valid_long.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from io_model.datasets.io.semi_dataset import labeled_network_dataset\n",
    "from scripts.train_vae import get_transforms\n",
    "\n",
    "# Transforms configured from the pickled scale file named after the train split.\n",
    "train_scale_path = '../data/scales/scale-train-include_bias.pickle'\n",
    "transforms = get_transforms(train_scale_path, True, None, True)\n",
    "\n",
    "# Labeled dataset over the 'train' split with the transforms attached.\n",
    "labeled = labeled_network_dataset(dataset['train'], transforms=transforms)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at the first item only and print the shape of its element at index 3.\n",
    "for item in labeled:\n",
    "    print(item[3].shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Element 3 of every item of the labeled dataset, collected as NumPy arrays.\n",
    "np_outs = []\n",
    "for item in labeled:\n",
    "    np_outs.append(item[3].detach().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean over all collected items (axis 0), converted back to a torch tensor.\n",
    "labeled_mean = torch.Tensor(np.mean(np_outs, axis=0))\n",
    "labeled_mean.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard deviation over all collected items (axis 0).\n",
    "np.asarray(np_outs).std(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "gen = torch.utils.data.DataLoader(labeled, batch_size=32, shuffle=True, num_workers=0)\n",
    "loss = nn.MSELoss()\n",
    "#loss = nn.L1Loss()\n",
    "\n",
    "# Baseline: per-batch error of always predicting the mean of the collected outputs.\n",
    "losses = []\n",
    "for b in gen:\n",
    "    target = b[3]\n",
    "    # Expand the mean to the actual batch size: the last batch can hold fewer\n",
    "    # than 32 items, and a fixed 32-row reference would not match its shape.\n",
    "    ref = labeled_mean.expand_as(target)\n",
    "    #ref = torch.full_like(target, float(np.mean(np_outs)))\n",
    "    l = loss(ref, target)\n",
    "    losses.append(l.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average of the per-batch losses collected above.\n",
    "np.asarray(losses).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same steps for the 'valid' split, with the scale file named after that split.\n",
    "valid_scale_path = '../data/scales/scale-valid-include_bias.pickle'\n",
    "transforms_val = get_transforms(valid_scale_path, True, None, True)\n",
    "labeled_val = labeled_network_dataset(dataset['valid'], transforms=transforms_val)\n",
    "\n",
    "# Element 3 of every validation item, then its mean over axis 0 as a tensor.\n",
    "val_outs = []\n",
    "for item in labeled_val:\n",
    "    val_outs.append(item[3].detach().numpy())\n",
    "\n",
    "labeled_mean_val = torch.Tensor(np.mean(val_outs, axis=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gen2 = torch.utils.data.DataLoader(labeled_val, batch_size=32, shuffle=False, num_workers=0)\n",
    "\n",
    "# Constant prediction used as the baseline; swap in labeled_mean to score the\n",
    "# training mean against the validation data instead.\n",
    "baseline = labeled_mean_val\n",
    "#baseline = labeled_mean\n",
    "#loss = nn.MSELoss()\n",
    "loss = nn.L1Loss()\n",
    "\n",
    "losses = []\n",
    "batch_sizes = []\n",
    "for b in gen2:\n",
    "    target = b[3]\n",
    "    # Expand the baseline to the real batch size so that the shorter final\n",
    "    # batch is scored too (it used to be skipped with a debug print).\n",
    "    ref = baseline.expand_as(target)\n",
    "    l = loss(ref, target)\n",
    "    losses.append(l.item())\n",
    "    batch_sizes.append(len(target))\n",
    "\n",
    "# Weight every batch by its size so the smaller last batch is not over-represented.\n",
    "np.average(losses, weights=batch_sizes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path of the saved checkpoint (epoch 29 according to the file name).\n",
    "check_path = '../data/vae_checkpoints/2021-11-07_17-05-53/model_dense_epoch-29.pt'\n",
    "\n",
    "# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.\n",
    "trained_checkpoint = torch.load(f=check_path, map_location=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from arch2vec.extensions.get_nasbench101_model import get_arch2vec_model\n",
    "from io_model.models.utils import load_extended_vae\n",
    "\n",
    "# Base arch2vec model; the optimizer returned alongside it is not used in this notebook.\n",
    "model, optimizer = get_arch2vec_model(device=device)\n",
    "\n",
    "# [model, 3, 513] is handed to load_extended_vae as its second argument; presumably\n",
    "# this restores the checkpointed extended model -- confirm against load_extended_vae.\n",
    "extended_args = [model, 3, 513]\n",
    "model, _ = load_extended_vae(check_path, extended_args, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "how_many = 20\n",
    "\n",
    "orig = []\n",
    "pred = []\n",
    "\n",
    "# Inference only: switch to eval mode (as the later evaluation cells do) and\n",
    "# skip gradient tracking.\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    for i, b in enumerate(gen):\n",
    "        if i >= how_many:\n",
    "            break\n",
    "\n",
    "        # The last element of the model result is kept as the prediction.\n",
    "        res = model(b[1].to(device), b[0].to(device), b[2].to(device))\n",
    "        pred.append(res[-1].detach().cpu().numpy())\n",
    "        orig.append(b[3].numpy())\n",
    "\n",
    "orig = np.vstack(orig)\n",
    "pred = np.vstack(pred)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prediction vs original comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "%matplotlib notebook\n",
    "\n",
    "plt.close()\n",
    "\n",
    "# One heatmap per matrix, both drawn with the same colour-scale maximum (vmax=10).\n",
    "for title, values in ((\"Original outputs\", orig), (\"Predicted outputs\", pred)):\n",
    "    plt.figure(figsize=(5,5))\n",
    "    plt.title(title)\n",
    "    sns.heatmap(values, vmax=10)\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "how_many = 20\n",
    "\n",
    "orig = []\n",
    "pred = []\n",
    "\n",
    "# Inference only: switch to eval mode (as the later evaluation cells do) and\n",
    "# skip gradient tracking.\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    for i, b in enumerate(gen2):\n",
    "        if i >= how_many:\n",
    "            break\n",
    "\n",
    "        # The last element of the model result is kept as the prediction.\n",
    "        res = model(b[1].to(device), b[0].to(device), b[2].to(device))\n",
    "        pred.append(res[-1].detach().cpu().numpy())\n",
    "        orig.append(b[3].numpy())\n",
    "\n",
    "orig = np.vstack(orig)\n",
    "pred = np.vstack(pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "%matplotlib notebook\n",
    "\n",
    "plt.close()\n",
    "\n",
    "# One heatmap per matrix, both drawn with the same colour-scale maximum (vmax=10).\n",
    "val_panels = (\n",
    "    (\"Original outputs - val\", orig),\n",
    "    (\"Predicted outputs - val\", pred),\n",
    ")\n",
    "for title, values in val_panels:\n",
    "    plt.figure(figsize=(5,5))\n",
    "    plt.title(title)\n",
    "    sns.heatmap(values, vmax=10)\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "orig = []\n",
    "pred = []\n",
    "\n",
    "# Element 2 of the very first item seen; only rows with an identical element 2 are kept.\n",
    "first = None\n",
    "\n",
    "# Inference only: switch to eval mode (as the later evaluation cells do) and\n",
    "# skip gradient tracking.\n",
    "model.eval()\n",
    "\n",
    "print(len(gen))\n",
    "with torch.no_grad():\n",
    "    for i, b in enumerate(gen):\n",
    "        if i % 100 == 0:\n",
    "            print(i)\n",
    "\n",
    "        if first is None:\n",
    "            first = b[2][0]\n",
    "\n",
    "        res = model(b[1].to(device), b[0].to(device), b[2].to(device))\n",
    "\n",
    "        pbatch = res[-1].detach().cpu().numpy()\n",
    "        obatch = b[3].numpy()\n",
    "\n",
    "        for ins, o, p in zip(b[2], obatch, pbatch):\n",
    "            if (ins == first).all():\n",
    "                pred.append(p)\n",
    "                orig.append(o)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the per-row lists into 2-D arrays for plotting.\n",
    "orig_im = np.array(orig)\n",
    "pred_im = np.array(pred)\n",
    "\n",
    "same_image_panels = (\n",
    "    (\"Original outputs - same image\", orig_im),\n",
    "    (\"Predicted outputs - same image\", pred_im),\n",
    ")\n",
    "for title, values in same_image_panels:\n",
    "    plt.figure(figsize=(5,5))\n",
    "    plt.title(title)\n",
    "    sns.heatmap(values, vmax=10)\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "orig = []\n",
    "pred = []\n",
    "\n",
    "# Reference pair (element 0, element 1) of the very first item seen; only rows\n",
    "# whose pair is identical are kept.\n",
    "first = None\n",
    "model.eval()\n",
    "\n",
    "print(len(gen))\n",
    "for step, batch in enumerate(gen):\n",
    "    if step % 100 == 0:\n",
    "        print(step)\n",
    "\n",
    "    adj_batch, ops_batch, inputs = batch[0], batch[1], batch[2]\n",
    "    if first is None:\n",
    "        first = adj_batch[0], ops_batch[0]\n",
    "\n",
    "    res = model(ops_batch.to(device), adj_batch.to(device), inputs.to(device))\n",
    "\n",
    "    pbatch = res[-1].detach().cpu().numpy()\n",
    "    obatch = batch[3].numpy()\n",
    "\n",
    "    for adj, ops, o, p in zip(adj_batch, ops_batch, obatch, pbatch):\n",
    "        same_net = (adj == first[0]).all() and (ops == first[1]).all()\n",
    "        if same_net:\n",
    "            pred.append(p)\n",
    "            orig.append(o)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "%matplotlib notebook\n",
    "\n",
    "# Stack the per-row lists into 2-D arrays for plotting.\n",
    "orig_im = np.array(orig)\n",
    "pred_im = np.array(pred)\n",
    "\n",
    "same_net_panels = (\n",
    "    (\"Original outputs - same net\", orig_im),\n",
    "    (\"Predicted outputs - same net\", pred_im),\n",
    ")\n",
    "for title, values in same_net_panels:\n",
    "    plt.figure(figsize=(5,5))\n",
    "    plt.title(title)\n",
    "    sns.heatmap(values, vmax=6)\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "how_many = 20\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "gen = torch.utils.data.DataLoader(labeled, batch_size=32, shuffle=True, num_workers=0)\n",
    "\n",
    "orig = []\n",
    "pred = []\n",
    "\n",
    "model.eval()\n",
    "\n",
    "# Runs over the whole loader (the how_many cut-off is not applied here) and\n",
    "# prints the batch index every 100 batches.\n",
    "for step, batch in enumerate(gen):\n",
    "    if step % 100 == 0:\n",
    "        print(step)\n",
    "\n",
    "    res = model(batch[1].to(device), batch[0].to(device), batch[2].to(device))\n",
    "    pred.append(res[-1].detach().cpu().numpy())\n",
    "    orig.append(batch[3].numpy())\n",
    "\n",
    "orig = np.vstack(orig)\n",
    "pred = np.vstack(pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Largest absolute difference per column (maximum taken over axis 0).\n",
    "np.abs(orig - pred).max(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "%matplotlib notebook\n",
    "\n",
    "# Element-wise absolute difference between the original and predicted arrays.\n",
    "diff = np.abs(orig - pred)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(5,5))\n",
    "ax.set_title(\"diff\")\n",
    "sns.heatmap(diff, vmax=5, ax=ax)\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smallest predicted value overall.\n",
    "np.min(pred)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Scratch area: the cells below are unorganised experiments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from io_model.datasets.io.semi_dataset import labeled_network_dataset\n",
    "from scripts.train_vae import get_transforms\n",
    "\n",
    "# NOTE: rebinds `transforms` and `labeled`, this time with the axis_0 scale file.\n",
    "axis0_scale_path = '../data/scales/scale-train-include_bias-axis_0.pickle'\n",
    "transforms = get_transforms(axis0_scale_path, True, 0, True)\n",
    "\n",
    "labeled = labeled_network_dataset(dataset['train'], transforms=transforms)\n",
    "gen3 = torch.utils.data.DataLoader(\n",
    "    labeled,\n",
    "    batch_size=32,\n",
    "    shuffle=True,\n",
    "    num_workers=4,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "how_many = 20\n",
    "\n",
    "# Element 3 of the first `how_many` batches of gen3, stacked row-wise.\n",
    "orig = []\n",
    "for batch_idx, batch in enumerate(gen3):\n",
    "    if batch_idx >= how_many:\n",
    "        break\n",
    "    orig.append(batch[3].numpy())\n",
    "\n",
    "orig = np.vstack(orig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "%matplotlib notebook\n",
    "\n",
    "# `orig` holds the unfiltered first batches of gen3 here, so the title names that\n",
    "# (the previous title was copied from the 'same image' plot).\n",
    "plt.figure(figsize=(5,5))\n",
    "plt.title(\"Original outputs - first batches, include_bias-axis_0 scale\")\n",
    "sns.heatmap(orig)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from io_model.datasets.io.semi_dataset import labeled_network_dataset\n",
    "from scripts.train_vae import get_transforms\n",
    "\n",
    "# NOTE: rebinds `transforms`, `labeled` and `gen` with the include_bias scale file.\n",
    "bias_scale_path = '../data/scales/scale-train-include_bias.pickle'\n",
    "transforms = get_transforms(bias_scale_path, True, None, True)\n",
    "\n",
    "labeled = labeled_network_dataset(dataset['train'], transforms=transforms)\n",
    "gen = torch.utils.data.DataLoader(\n",
    "    labeled,\n",
    "    batch_size=32,\n",
    "    shuffle=True,\n",
    "    num_workers=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "how_many = 20\n",
    "\n",
    "# Element 3 of the first `how_many` batches of gen, stacked row-wise.\n",
    "orig = []\n",
    "for batch_idx, batch in enumerate(gen):\n",
    "    if batch_idx >= how_many:\n",
    "        break\n",
    "    orig.append(batch[3].numpy())\n",
    "\n",
    "orig = np.vstack(orig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "%matplotlib notebook\n",
    "\n",
    "# `orig` holds the unfiltered first batches of gen here, so the title names that\n",
    "# (the previous title was copied from the 'same image' plot).\n",
    "plt.figure(figsize=(5,5))\n",
    "plt.title(\"Original outputs - first batches, include_bias scale\")\n",
    "sns.heatmap(orig)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "orig = []\n",
    "\n",
    "# Reference pair (element 0, element 1) of the very first item seen; only rows\n",
    "# whose pair is identical are kept.\n",
    "first = None\n",
    "\n",
    "print(len(gen3))\n",
    "for step, batch in enumerate(gen3):\n",
    "    if step % 100 == 0:\n",
    "        print(step)\n",
    "\n",
    "    adj_batch, ops_batch = batch[0], batch[1]\n",
    "    if first is None:\n",
    "        first = adj_batch[0], ops_batch[0]\n",
    "\n",
    "    obatch = batch[3].numpy()\n",
    "\n",
    "    for adj, ops, o in zip(adj_batch, ops_batch, obatch):\n",
    "        same_net = (adj == first[0]).all() and (ops == first[1]).all()\n",
    "        if same_net:\n",
    "            orig.append(o)\n",
    "\n",
    "orig = np.array(orig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "%matplotlib notebook\n",
    "\n",
    "# Heatmap of the rows collected into `orig`, on its own figure and axes.\n",
    "fig, ax = plt.subplots(figsize=(5,5))\n",
    "ax.set_title(\"Original outputs\")\n",
    "sns.heatmap(orig, ax=ax)\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "text_representation": {
    "extension": ".Rmd",
    "format_name": "rmarkdown",
    "format_version": "1.2",
    "jupytext_version": "1.13.0"
   }
  },
  "kernelspec": {
   "display_name": "bioinf",
   "language": "python",
   "name": "bioinf"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
