{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%autosave 0\n",
    "\n",
    "import torch\n",
    "\n",
    "seed = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "\n",
    "data_dir = '../data/'\n",
    "os.listdir(data_dir)\n",
    "\n",
    "\n",
    "def from_data(path, data=False):\n",
    "    return os.path.join(data_dir, 'data' if data else '', path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nasbench import api\n",
    "\n",
    "nasbench_path = from_data('nasbench_only108.tfrecord')\n",
    "nb = api.NASBench(nasbench_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "model_dir = os.path.join(data_dir, 'out_0534eeefa12ca7c2b177541ad24929f1')\n",
    "\n",
    "subdirs = os.listdir(model_dir)\n",
    "subdirs = [os.path.join(model_dir, d) for d in subdirs]\n",
    "subdirs = [d for d in subdirs if os.path.isdir(d)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from io_model.datasets.io.create_dataset import dataset_from_pretrained\n",
    "from nasbench_pytorch.datasets.cifar10 import prepare_dataset\n",
    "import torch\n",
    "import random\n",
    "\n",
    "seed_dir = os.path.join(data_dir, 'seed_experiment_models')\n",
    "if not os.path.exists(seed_dir):\n",
    "    os.mkdir(seed_dir)\n",
    "\n",
    "pretrain = False\n",
    "    \n",
    "if pretrain:\n",
    "    for sd in subdirs:\n",
    "        base_path = os.path.basename(sd)\n",
    "        print(base_path)\n",
    "    \n",
    "        cifar_batch = 128\n",
    "        \n",
    "        random.seed(1)\n",
    "        torch.manual_seed(1)\n",
    "        cifar = prepare_dataset(cifar_batch, validation_size=1000, num_workers=4, root=from_data('cifar'), random_state=1)\n",
    "\n",
    "        dataset_from_pretrained(sd, nb, cifar, os.path.join(seed_dir, f'net_seeds_{base_path}.pt'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from io_model.datasets.arch2vec_dataset import prepare_labeled_dataset\n",
    "\n",
    "all_labeled = {}\n",
    "seed_dir = os.path.join(data_dir, 'seed_experiment_models')\n",
    "\n",
    "for net_path in os.listdir(seed_dir):\n",
    "    print(net_path)\n",
    "    \n",
    "    net_path_pt = os.path.join(seed_dir, net_path)\n",
    "\n",
    "    labeled, _ = prepare_labeled_dataset(net_path_pt, nb, nb_dataset=from_data('nb_dataset.json'), dataset=from_data('cifar'),\n",
    "                                         remove_labeled=False)\n",
    "    \n",
    "    all_labeled[net_path] = labeled"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_labeled['net_seeds_1.pt']['outputs']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from io_model.datasets.io.semi_dataset import labeled_network_dataset\n",
    "from io_model.datasets.io.transforms import get_transforms\n",
    "\n",
    "def get_labeled_data(data):\n",
    "    transforms = get_transforms(from_data('scales/scales/scale-train-include_bias.pickle'),\n",
    "                                True, None, True, scale_whole_path=None)\n",
    "    transforms.transforms = [transforms.transforms[0], transforms.transforms[2]]\n",
    "\n",
    "    labeled = labeled_network_dataset(data, transforms=transforms, return_ref_id=True)\n",
    "    return torch.utils.data.DataLoader(labeled, batch_size=32, shuffle=False, num_workers=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def get_pred_and_orig(gen, model=None, print_freq=1000):\n",
    "    orig = []\n",
    "    pred = []\n",
    "    info = []\n",
    "    weights = []\n",
    "    labels = []\n",
    "\n",
    "\n",
    "    for i, batch in enumerate(gen):\n",
    "        if i % print_freq == 0:\n",
    "            print(f\"Batch {i}\")\n",
    "\n",
    "        info.append({w: batch[w] for w in ['label', 'hash', 'ref_id']})\n",
    "\n",
    "        b = batch['adj'], batch['ops'], batch['input'], batch['output']\n",
    "\n",
    "        if model is not None:\n",
    "            res = model(b[1].to(device), b[0].to(device), b[2].to(device))\n",
    "            pred.append(res[-1].detach().cpu().numpy())\n",
    "        orig.append(b[3].numpy())\n",
    "        weights.append(np.concatenate([batch['weights'], batch['bias'][:, :, np.newaxis]], axis=-1))\n",
    "        labels.append(batch['label'].numpy())\n",
    "\n",
    "    orig = np.vstack(orig)\n",
    "    weights = np.vstack(weights)\n",
    "    labels = np.hstack(labels)\n",
    "    \n",
    "    if model is None:\n",
    "        return orig, info, weights, labels\n",
    "    \n",
    "    pred = np.vstack(pred)\n",
    "    return orig, pred, info, weights, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "features = {}\n",
    "\n",
    "for k, v in all_labeled.items():\n",
    "    gen = get_labeled_data(v)\n",
    "    o, i, w, lab = get_pred_and_orig(gen)\n",
    "    \n",
    "    features[k] = (o, i, w, lab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def feat_norm(f, axis=1, no_div=False):\n",
    "    if axis is None:\n",
    "        res = (f - np.mean(f))\n",
    "        return res if no_div else res / np.std(f)\n",
    "    \n",
    "    args = (1, -1) if axis == 0 else (-1, 1)\n",
    "    res = (f - np.mean(f, axis=axis).reshape(*args))\n",
    "    return res if no_div else res / np.std(f, axis=axis).reshape(*args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "other_net = from_data('test_net', data=False)\n",
    "other_net_pt = from_data(f'other_net.pt', data=False)\n",
    "\n",
    "cifar_batch = 128\n",
    "\n",
    "random.seed(1)\n",
    "torch.manual_seed(1)\n",
    "cifar = prepare_dataset(cifar_batch, validation_size=1000, num_workers=4, root=from_data('cifar'), random_state=1)\n",
    "\n",
    "dataset_from_pretrained(other_net, nb, cifar, other_net_pt)\n",
    "\n",
    "\n",
    "other_net_labeled, _ = prepare_labeled_dataset(other_net_pt, nb, device=torch.device('cpu'), nb_dataset=from_data('nb_dataset.json'), dataset=from_data('cifar'),\n",
    "                                               remove_labeled=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gen = get_labeled_data(other_net_labeled)\n",
    "o_test, i_test, w_test, lab_test = get_pred_and_orig(gen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "o_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mult_by_weights(feats, ws, labs):\n",
    "    ws = -np.sort(-ws, axis=-1)\n",
    "    mult_out = feats * ws[np.arange(1000), labs]\n",
    "    return mult_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mw = mult_by_weights(net_1[0], net_1[2], net_1[3])\n",
    "mw2 = mult_by_weights(net_2[0], net_2[2], net_2[3])\n",
    "mw_test = mult_by_weights(o_test, w_test, lab_test)\n",
    "\n",
    "mw = feat_norm(mw, axis=1, no_div=True)\n",
    "mw2 = feat_norm(mw2, axis=1, no_div=True)\n",
    "mw_test = feat_norm(mw_test, axis=1, no_div=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "top_k = 5\n",
    "\n",
    "vmin = 0\n",
    "vmax = 0.5\n",
    "xlim = (0, 0.5)\n",
    "ylim = (0, 5000)\n",
    "\n",
    "binwidth = 0.025\n",
    "\n",
    "sns.heatmap(mw[:, :top_k])\n",
    "plt.show()\n",
    "sns.heatmap(mw2[:, :top_k])\n",
    "plt.show()\n",
    "sns.heatmap(mw_test[:, :top_k])\n",
    "plt.show()\n",
    "\n",
    "mw_diff = mw - mw2\n",
    "mw_diff = np.square(mw_diff)\n",
    "sns.heatmap(mw_diff[:, :top_k], vmin=vmin, vmax=vmax)\n",
    "plt.show()\n",
    "sns.histplot(mw_diff[:, :top_k].flatten(), binwidth=binwidth)\n",
    "plt.xlim(*xlim)\n",
    "plt.ylim(*ylim)\n",
    "plt.show()\n",
    "\n",
    "mw_diff = mw - mw_test\n",
    "mw_diff = np.square(mw_diff)\n",
    "sns.heatmap(mw_diff[:, :top_k], vmin=vmin, vmax=vmax)\n",
    "plt.show()\n",
    "\n",
    "sns.histplot(mw_diff[:, :top_k].flatten(), binwidth=binwidth)\n",
    "plt.xlim(*xlim)\n",
    "plt.ylim(*ylim)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "net_1 = features['net_seeds_0.pt']\n",
    "net_2 = features['net_seeds_6.pt']\n",
    "\n",
    "sns.heatmap(mult_by_weights(net_1[0], net_1[2], net_1[3])[:, :10], vmax=1.6, vmin=0.2)\n",
    "plt.show()\n",
    "\n",
    "sns.heatmap(mult_by_weights(net_2[0], net_2[2], net_2[3])[:, :10], vmax=1.6, vmin=0.2)\n",
    "plt.show()\n",
    "\n",
    "sns.heatmap(mult_by_weights(o_test, w_test, lab_test)[:, :10], vmax=1.6, vmin=0.2)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "sns.set()\n",
    "\n",
    "sns.heatmap(features['net_seeds_0.pt'][0])\n",
    "plt.show()\n",
    "sns.heatmap(features['net_seeds_1.pt'][0])\n",
    "plt.show()\n",
    "\n",
    "sns.heatmap(feat_norm(features['net_seeds_0.pt'][0]), vmax=6, vmin=0)\n",
    "plt.show()\n",
    "sns.heatmap(feat_norm(features['net_seeds_1.pt'][0]), vmax=6, vmin=0)\n",
    "plt.show()\n",
    "sns.heatmap(feat_norm(o_test), vmax=6, vmin=0)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.colors import LogNorm\n",
    "\n",
    "diff = feat_norm(features['net_seeds_0.pt'][0]) - feat_norm(features['net_seeds_6.pt'][0])\n",
    "diff = np.square(diff)\n",
    "plt.title('Same net, different seeds')\n",
    "sns.heatmap(diff[:], cmap='YlGnBu', vmax=2, vmin=0)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "diff.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.histplot(diff.flatten())\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "diff_2 = feat_norm(features['net_seeds_0.pt'][0]) - feat_norm(o_test)\n",
    "diff_2 = np.square(diff_2)\n",
    "plt.title('Different nets')\n",
    "sns.heatmap(diff_2[:], cmap='YlGnBu', vmax=2, vmin=0)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.histplot(diff_2.flatten())\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "top_k = 100\n",
    "\n",
    "plt.title(f\"Abs. Difference histogram for top {top_k} features\")\n",
    "sns.histplot(diff[:, :top_k].flatten())\n",
    "plt.xlim(0, 6)\n",
    "#plt.ylim(0, 2000)\n",
    "plt.show()\n",
    "\n",
    "sns.histplot(diff_2[:, :top_k].flatten())\n",
    "plt.xlim(0, 6)\n",
    "#plt.ylim(0, 2000)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "top_kk = 10\n",
    "xlim = 90\n",
    "\n",
    "plt.title(f\"MSE between features (per image) - same net, top {top_kk}\")\n",
    "err1 = np.square(feat_norm(features['net_seeds_3.pt'][0])[:, :top_kk] - feat_norm(features['net_seeds_6.pt'][0])[:, :top_kk]).sum(axis=1)\n",
    "sns.histplot(err1, binwidth=5)\n",
    "plt.xlim(0, xlim)\n",
    "print(np.mean(err1))\n",
    "print(np.median(err1))\n",
    "print(np.std(err1))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "err2 = np.square(feat_norm(features['net_seeds_3.pt'][0])[:, :top_kk] - feat_norm(o_test)[:, :top_kk]).sum(axis=1)\n",
    "sns.histplot(err2, binwidth=5)\n",
    "plt.xlim(0, xlim)\n",
    "plt.title(f\"MSE between features (per image) - different nets, top {top_kk}\")\n",
    "print(np.mean(err2))\n",
    "print(np.median(err2))\n",
    "print(np.std(err2))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.mean(features['net_seeds_1.pt'][0], axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.mean(o_test, axis=1).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "text_representation": {
    "extension": ".Rmd",
    "format_name": "rmarkdown",
    "format_version": "1.2",
    "jupytext_version": "1.13.0"
   }
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
