{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed=1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.listdir('..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nasbench import api\n",
    "\n",
    "nasbench_path = '../data/nasbench_only108.tfrecord'\n",
    "nb = api.NASBench(nasbench_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from io_model.datasets.arch2vec_dataset import get_labeled_unlabeled_datasets\n",
    "\n",
    "#torch.backends.cudnn.benchmark = True\n",
    "device = torch.device('cuda')\n",
    "\n",
    "# device = None otherwise the dataset is save to the cuda as a whole\n",
    "labeled, unlabeled = get_labeled_unlabeled_datasets(nb, device=device, seed=seed,\n",
    "                                                    train_pretrained=train_pretrained,\n",
    "                                                    valid_pretrained=val_pretrained,\n",
    "                                                    train_labeled_path='../data/train_long.pt',\n",
    "                                                    valid_labeled_path='../data/valid_long.pt')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test dataset shapes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(len(labeled['valid_io']['inputs'])):\n",
    "    if i not in labeled['valid_io']['inputs']:\n",
    "        raise ValueError()\n",
    "        \n",
    "    assert (sum(labeled['train_io']['inputs'] == i)) == 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "[l.shape for l in labeled['train_io'].values() if not isinstance(l, int)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 7 GB per 4000 nets\n",
    "labeled['train_io']['outputs'].element_size() * 4000 * 512 / 1024 / 1024 / 1024 * 1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "[l.shape for l in labeled['train_net']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "unlabeled.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "unlabeled['train'][1][0].shape, unlabeled['train'][2][0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(unlabeled['train'][1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test model shapes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from arch2vec.extensions.get_nasbench101_model import get_arch2vec_model\n",
    "\n",
    "model, opt = get_arch2vec_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from arch2vec.extensions.get_nasbench101_model import get_nasbench_datasets\n",
    "\n",
    "nb_dataset = get_nasbench_datasets('../data/nb_dataset.json', batch_size=None, seed=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#model.train()\n",
    "model.eval()\n",
    "\n",
    "batch_adj, batch_ops = nb_dataset['train'][1][:32], nb_dataset['train'][2][:32]\n",
    "\n",
    "mu, logvar = model._encoder(batch_ops, batch_adj)\n",
    "z = model.reparameterize(mu, logvar)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"mu shape: {mu.shape}, logvar shape: {logvar.shape}, z shape: {z.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "m = nn.Sequential(\n",
    "    nn.Flatten(),\n",
    "    nn.Linear(z.shape[1] * z.shape[2], 5),\n",
    "    nn.ReLU()\n",
    ")\n",
    "m(z).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test unsqueeze and channels\n",
    "\n",
    "conv = nn.Conv2d(32, 8, 1, padding=0)\n",
    "conv(mu.unsqueeze(0)).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# repeat and concat\n",
    "repeated = torch.Tensor([3]).repeat(mu.shape[0], mu.shape[1], 1)\n",
    "print(repeated.shape)\n",
    "\n",
    "torch.cat([mu, repeated], axis=-1).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Extended models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from arch2vec.extensions.get_nasbench101_model import get_arch2vec_model\n",
    "from arch2vec.extensions.get_nasbench101_model import get_nasbench_datasets\n",
    "\n",
    "model, opt = get_arch2vec_model(device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(labeled['train']['dataset'].shape)\n",
    "print(labeled['train']['inputs'].shape)\n",
    "print(labeled['train']['outputs'].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from io_model.models.io_model import ConcatConvModel\n",
    "\n",
    "extended_model = ConcatConvModel(model, 128, 512).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "in_batch = labeled['train_io']['inputs'][:32]\n",
    "out_batch = labeled['train_io']['outputs'][:32]\n",
    "\n",
    "batch_adj, batch_ops = labeled['train_net'][0][:32], labeled['train_net'][1][:32]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ops_recon, adj_recon, mu, logvar, _, outputs = extended_model(batch_ops.to(device), batch_adj.to(device), in_batch.to(device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "outputs.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "### Just some tests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "_,_,_,_,z = model(batch_ops.to(device), batch_adj.to(device))\n",
    "z = extended_model.process_z(z)\n",
    "z = z.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, in_batch.shape[2], in_batch.shape[3])\n",
    "z.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "torch.cat([z, in_batch], dim=1).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "labeled_2 = labeled.copy()\n",
    "\n",
    "omax = labeled_2['train']['outputs'].max()\n",
    "labeled_2['train']['outputs'] /= omax\n",
    "\n",
    "print(labeled_2['train']['outputs'].max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from io_model.trainer import train\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    torch.backends.cudnn.benchmark = True\n",
    "    device = torch.device('cuda')\n",
    "else:\n",
    "    device = None\n",
    "    \n",
    "model = train(labeled_2, unlabeled, nb, checkpoint_path='.', device=device,\n",
    "              use_reference_model=True, batch_len_labeled=7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(labeled['train_io']['inputs'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib notebook\n",
    "\n",
    "plt.figure()\n",
    "plt.imshow(labeled_2['train_io']['outputs'], cmap='hot', interpolation='nearest')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "text_representation": {
    "extension": ".Rmd",
    "format_name": "rmarkdown",
    "format_version": "1.2",
    "jupytext_formats": "ipynb,Rmd:rmarkdown",
    "jupytext_version": "1.13.0"
   }
  },
  "kernelspec": {
   "display_name": "pyt_conda",
   "language": "python",
   "name": "pyt_conda"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
