{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train a Meta Neural Network on NASBench\n",
    "## Predict the accuracy of neural networks to within one percent!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# autoreload mode 2 picks up edits to imported project modules without a kernel restart\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "# render matplotlib figures inline in the notebook output\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from matplotlib import pyplot as plt\n",
    "from nasbench import api\n",
    "\n",
    "from data import Data\n",
    "from meta_neural_net import MetaNeuralnet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define a function to plot the meta neural networks\n",
    "\n",
    "def plot_meta_neuralnet(ytrain, train_pred, ytest, test_pred, max_disp=500, title=None,\n",
    "                        ax_lim=(5, 15)):\n",
    "    \"\"\"Scatter predicted values against true values for the training and test sets.\n",
    "\n",
    "    Args:\n",
    "        ytrain: true values of the training points (x-axis).\n",
    "        train_pred: predictions for the training points (y-axis).\n",
    "        ytest: true values of the test points (x-axis).\n",
    "        test_pred: predictions for the test points (y-axis).\n",
    "        max_disp: at most this many leading points of each set are drawn.\n",
    "        title: optional plot title; no title is set when it is None.\n",
    "        ax_lim: (low, high) range applied to both axes. Points outside\n",
    "            this range are not visible in the figure.\n",
    "    \"\"\"\n",
    "    plt.scatter(ytrain[:max_disp], train_pred[:max_disp], label='training data', alpha=0.7, s=64)\n",
    "    plt.scatter(ytest[:max_disp], test_pred[:max_disp], label='test data', alpha=0.7, marker='^')\n",
    "\n",
    "    # identical range on both axes, so the diagonal below is the line predicted == true\n",
    "    ax_lim = np.array([min(ax_lim), max(ax_lim)])\n",
    "    plt.xlim(ax_lim)\n",
    "    plt.ylim(ax_lim)\n",
    "\n",
    "    # 45-degree line\n",
    "    plt.plot(ax_lim, ax_lim, 'k:')\n",
    "\n",
    "    plt.gca().set_aspect('equal', adjustable='box')\n",
    "    if title is not None:\n",
    "        # skip the call entirely so that no literal 'None' can end up as the title\n",
    "        plt.title(title)\n",
    "    plt.legend(loc='best')\n",
    "    plt.xlabel('true percent error')\n",
    "    plt.ylabel('predicted percent error')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Load the NASBench dataset (takes about 1 minute).\n",
    "# `search_space` is read as a global by meta_neuralnet_experiment.\n",
    "search_space = Data('nasbench')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# method which runs a meta neural network experiment\n",
    "def meta_neuralnet_experiment(params,\n",
    "                              ns=(100, 500),\n",
    "                              num_ensemble=3,\n",
    "                              test_size=500):\n",
    "    \"\"\"Fit meta neural networks on random datasets and report their error.\n",
    "\n",
    "    For every training-set size in `ns`, and for encode_paths False then\n",
    "    True, a training set and a test set are generated from the global\n",
    "    `search_space`, the test set is passed through\n",
    "    `search_space.remove_duplicates`, and `num_ensemble` fits are run.\n",
    "    The mean absolute train/test errors, averaged over the fits, are\n",
    "    printed and the predictions of the last fit are plotted.\n",
    "\n",
    "    Args:\n",
    "        params: keyword arguments forwarded to `MetaNeuralnet.fit`.\n",
    "        ns: iterable of training-set sizes.\n",
    "        num_ensemble: number of fits per configuration; must be >= 1.\n",
    "        test_size: number of test points requested before\n",
    "            `remove_duplicates` is applied.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `num_ensemble` is less than 1.\n",
    "    \"\"\"\n",
    "    if num_ensemble < 1:\n",
    "        # with zero fits there would be no predictions to average or to plot\n",
    "        raise ValueError('num_ensemble must be at least 1, got {}'.format(num_ensemble))\n",
    "\n",
    "    for n in ns:\n",
    "        for encode_paths in [False, True]:\n",
    "\n",
    "            train_data = search_space.generate_random_dataset(num=n,\n",
    "                                                encode_paths=encode_paths)\n",
    "\n",
    "            test_data = search_space.generate_random_dataset(num=test_size,\n",
    "                                                encode_paths=encode_paths)\n",
    "\n",
    "            test_data = search_space.remove_duplicates(test_data, train_data)\n",
    "\n",
    "            # NOTE(review): d[1] is used as the network input and d[2] as the target;\n",
    "            # confirm the tuple layout against Data.generate_random_dataset\n",
    "            xtrain = np.array([d[1] for d in train_data])\n",
    "            ytrain = np.array([d[2] for d in train_data])\n",
    "\n",
    "            xtest = np.array([d[1] for d in test_data])\n",
    "            ytest = np.array([d[2] for d in test_data])\n",
    "\n",
    "            train_errors = []\n",
    "            test_errors = []\n",
    "            meta_neuralnet = MetaNeuralnet()\n",
    "            for _ in range(num_ensemble):\n",
    "                meta_neuralnet.fit(xtrain, ytrain, **params)\n",
    "                train_pred = np.squeeze(meta_neuralnet.predict(xtrain))\n",
    "                train_errors.append(np.mean(np.abs(train_pred - ytrain)))\n",
    "                test_pred = np.squeeze(meta_neuralnet.predict(xtest))\n",
    "                test_errors.append(np.mean(np.abs(test_pred - ytest)))\n",
    "\n",
    "            # average the per-fit mean absolute errors\n",
    "            train_error = np.round(np.mean(train_errors), 3)\n",
    "            test_error = np.round(np.mean(test_errors), 3)\n",
    "            print('Meta neuralnet training size: {}, encode paths: {}'.format(n, encode_paths))\n",
    "            print('Train error: {}, test error: {}'.format(train_error, test_error))\n",
    "\n",
    "            if encode_paths:\n",
    "                title = 'Path encoding, training set size {}'.format(n)\n",
    "            else:\n",
    "                title = 'Adjacency list encoding, training set size {}'.format(n)\n",
    "\n",
    "            # train_pred / test_pred hold the predictions of the last fit above;\n",
    "            # plot_meta_neuralnet calls plt.show() itself, so no second show() is needed\n",
    "            plot_meta_neuralnet(ytrain, train_pred, ytest, test_pred, title=title)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters passed to the experiment; try out your favorite settings\n",
    "# to get the best accuracy on NASBench.\n",
    "meta_neuralnet_params = {\n",
    "    'loss': 'mae',\n",
    "    'num_layers': 10,\n",
    "    'layer_width': 20,\n",
    "    'epochs': 200,\n",
    "    'batch_size': 32,\n",
    "    'lr': 0.01,\n",
    "    'regularization': 0,\n",
    "    'verbose': 0,\n",
    "}\n",
    "\n",
    "# training-set sizes to evaluate\n",
    "ns = [100, 500]\n",
    "\n",
    "# takes about 1 minute\n",
    "meta_neuralnet_experiment(meta_neuralnet_params, ns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
