{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/Envs/venv/lib/python3.6/site-packages/sklearn/utils/deprecation.py:144: FutureWarning: The sklearn.utils.testing module is  deprecated in version 0.22 and will be removed in version 0.24. The corresponding classes / functions should instead be imported from sklearn.utils. Anything that cannot be imported from sklearn.utils is now part of the private API.\n",
      "  warnings.warn(message, FutureWarning)\n",
      "/home/ubuntu/Envs/venv/lib/python3.6/site-packages/sklearn/externals/joblib/__init__.py:15: FutureWarning: sklearn.externals.joblib is deprecated in 0.21 and will be removed in 0.23. Please import this functionality directly from joblib, which can be installed with: pip install joblib. If this warning is raised when loading pickled models, you may need to re-serialize those models with scikit-learn 0.21+.\n",
      "  warnings.warn(msg, category=FutureWarning)\n"
     ]
    }
   ],
   "source": [
    "import os.path as osp\n",
    "import argparse\n",
    "import math\n",
    "import time\n",
    "import random\n",
    "import string\n",
    "\n",
    "import torch\n",
    "import torch_geometric.transforms as T\n",
    "from torch.utils.data import DataLoader\n",
    "import numpy as np\n",
    "import wandb\n",
    "import os\n",
    "\n",
    "from games_dataset import Games\n",
    "#from indian_village_dataset import IndianVillageGames\n",
    "#from yelp import Yelp\n",
    "from utils import get_encoder, get_decoder, mask_diagonal, get_loss_weights#, permute_features\n",
    "from evaluation import eval_baseline, eval_everything, eval\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/generalized_graph_learning/gnn\n"
     ]
    }
   ],
   "source": [
    "# The dataset root used below ('../data/tmp') is relative to the kernel's working directory;\n",
    "# show where it actually resolves (portable, unlike a shell `pwd`).\n",
    "osp.abspath('../data/tmp')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset-generation parameters passed to games_dataset.Games.\n",
    "n_graphs = 100\n",
    "n_nodes = 20\n",
    "m = 1\n",
    "n_games = 50\n",
    "# The spectral-radius sweep below uses its own grid of values, not this single default.\n",
    "target_spectral_radius = 0.8\n",
    "alpha = 1\n",
    "marginal_benefits_noise_variance = 0.1\n",
    "game_type = \"linear_quadratic\"\n",
    "graph = \"erdos_renyi\"\n",
    "regenerate_data = True\n",
    "\n",
    "# Split ratios (train gets 1 - val_ratio - test_ratio) and DataLoader batch sizes.\n",
    "val_ratio = 0.05\n",
    "test_ratio = 0.1\n",
    "batch_size = 1\n",
    "test_batch_size = 1\n",
    "\n",
    "# Seed every global RNG the notebook imports so regenerated datasets are reproducible.\n",
    "# NOTE(review): assumes Games draws from these global generators - TODO confirm in games_dataset.\n",
    "SEED = 0\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating graphs\n",
      "Finished generating graphs. It took 0.5315923690795898s\n",
      "Tuning Graphical Lasso regularization parameter... Done! Best regularization parameter found: 1\n",
      "Correlation baseline     --- train ROC_AUC:0.1462+-0.0055, val ROC_AUC:0.1688+-0.0129, test ROC_AUC:0.1465+-0.0142\n",
      "Anticorrelation baseline --- train ROC_AUC:0.8538+-0.0055, val ROC_AUC:0.8312+-0.0129, test ROC_AUC:0.8535+-0.0142\n",
      "Lasso baseline           --- train ROC_AUC:0.0000+-0.0000, val ROC_AUC:0.5000+-0.0000, test ROC_AUC:0.5000+-0.0000\n",
      "Generating graphs\n",
      "Finished generating graphs. It took 0.5955779552459717s\n",
      "Tuning Graphical Lasso regularization parameter... Done! Best regularization parameter found: 1\n",
      "Correlation baseline     --- train ROC_AUC:0.3438+-0.0072, val ROC_AUC:0.3238+-0.0319, test ROC_AUC:0.3477+-0.0128\n",
      "Anticorrelation baseline --- train ROC_AUC:0.6562+-0.0072, val ROC_AUC:0.6762+-0.0319, test ROC_AUC:0.6523+-0.0128\n",
      "Lasso baseline           --- train ROC_AUC:0.0000+-0.0000, val ROC_AUC:0.5000+-0.0000, test ROC_AUC:0.5000+-0.0000\n",
      "Generating graphs\n",
      "Finished generating graphs. It took 1.1476397514343262s\n",
      "Tuning Graphical Lasso regularization parameter... Done! Best regularization parameter found: 0.01\n",
      "Correlation baseline     --- train ROC_AUC:0.5824+-0.0085, val ROC_AUC:0.6189+-0.0301, test ROC_AUC:0.5623+-0.0182\n",
      "Anticorrelation baseline --- train ROC_AUC:0.4176+-0.0085, val ROC_AUC:0.3811+-0.0301, test ROC_AUC:0.4377+-0.0182\n",
      "Lasso baseline           --- train ROC_AUC:0.0000+-0.0000, val ROC_AUC:0.6329+-0.0336, test ROC_AUC:0.5728+-0.0184\n",
      "Generating graphs\n",
      "Finished generating graphs. It took 1.1320297718048096s\n",
      "Tuning Graphical Lasso regularization parameter... Done! Best regularization parameter found: 0.1\n",
      "Correlation baseline     --- train ROC_AUC:0.7650+-0.0071, val ROC_AUC:0.7619+-0.0104, test ROC_AUC:0.7860+-0.0163\n",
      "Anticorrelation baseline --- train ROC_AUC:0.2350+-0.0071, val ROC_AUC:0.2381+-0.0104, test ROC_AUC:0.2140+-0.0163\n",
      "Lasso baseline           --- train ROC_AUC:0.0000+-0.0000, val ROC_AUC:0.7796+-0.0148, test ROC_AUC:0.7767+-0.0176\n",
      "Generating graphs\n",
      "Finished generating graphs. It took 1.1449918746948242s\n",
      "Tuning Graphical Lasso regularization parameter... Done! Best regularization parameter found: 0.01\n",
      "Correlation baseline     --- train ROC_AUC:0.8492+-0.0044, val ROC_AUC:0.8472+-0.0198, test ROC_AUC:0.8525+-0.0201\n",
      "Anticorrelation baseline --- train ROC_AUC:0.1508+-0.0044, val ROC_AUC:0.1528+-0.0198, test ROC_AUC:0.1475+-0.0201\n",
      "Lasso baseline           --- train ROC_AUC:0.0000+-0.0000, val ROC_AUC:0.8507+-0.0187, test ROC_AUC:0.8539+-0.0194\n",
      "Generating graphs\n",
      "Finished generating graphs. It took 0.5946147441864014s\n",
      "Tuning Graphical Lasso regularization parameter... Done! Best regularization parameter found: 0.0001\n",
      "Correlation baseline     --- train ROC_AUC:0.8949+-0.0038, val ROC_AUC:0.9218+-0.0031, test ROC_AUC:0.9118+-0.0077\n",
      "Anticorrelation baseline --- train ROC_AUC:0.1051+-0.0038, val ROC_AUC:0.0782+-0.0031, test ROC_AUC:0.0882+-0.0077\n",
      "Lasso baseline           --- train ROC_AUC:0.0000+-0.0000, val ROC_AUC:0.9226+-0.0040, test ROC_AUC:0.9127+-0.0075\n",
      "Generating graphs\n",
      "Finished generating graphs. It took 0.5242044925689697s\n",
      "Tuning Graphical Lasso regularization parameter... Done! Best regularization parameter found: 0.0001\n",
      "Correlation baseline     --- train ROC_AUC:0.9293+-0.0029, val ROC_AUC:0.9268+-0.0087, test ROC_AUC:0.9234+-0.0059\n",
      "Anticorrelation baseline --- train ROC_AUC:0.0707+-0.0029, val ROC_AUC:0.0732+-0.0087, test ROC_AUC:0.0766+-0.0059\n",
      "Lasso baseline           --- train ROC_AUC:0.0000+-0.0000, val ROC_AUC:0.9242+-0.0083, test ROC_AUC:0.9210+-0.0063\n",
      "Generating graphs\n",
      "Finished generating graphs. It took 0.5220906734466553s\n",
      "Tuning Graphical Lasso regularization parameter... Done! Best regularization parameter found: 0.0001\n",
      "Correlation baseline     --- train ROC_AUC:0.9352+-0.0023, val ROC_AUC:0.9139+-0.0081, test ROC_AUC:0.9378+-0.0050\n",
      "Anticorrelation baseline --- train ROC_AUC:0.0648+-0.0023, val ROC_AUC:0.0861+-0.0081, test ROC_AUC:0.0622+-0.0050\n",
      "Lasso baseline           --- train ROC_AUC:0.0000+-0.0000, val ROC_AUC:0.9130+-0.0073, test ROC_AUC:0.9333+-0.0053\n",
      "Generating graphs\n",
      "Finished generating graphs. It took 0.6664915084838867s\n",
      "Tuning Graphical Lasso regularization parameter... Done! Best regularization parameter found: 0.0001\n",
      "Correlation baseline     --- train ROC_AUC:0.9370+-0.0022, val ROC_AUC:0.9157+-0.0110, test ROC_AUC:0.9409+-0.0069\n",
      "Anticorrelation baseline --- train ROC_AUC:0.0630+-0.0022, val ROC_AUC:0.0843+-0.0110, test ROC_AUC:0.0591+-0.0069\n",
      "Lasso baseline           --- train ROC_AUC:0.0000+-0.0000, val ROC_AUC:0.9056+-0.0127, test ROC_AUC:0.9339+-0.0080\n"
     ]
    }
   ],
   "source": [
    "# Sweep the target spectral radius and record the correlation baseline's mean test ROC-AUC per value.\n",
    "# Only baselines are evaluated here; no model is trained.\n",
    "# np.linspace gives exactly 9 points including an exact 0.0, unlike np.arange(-0.8, 1.0, 0.2),\n",
    "# whose float step makes the point count fragile and leaves the middle point slightly off 0.\n",
    "spectral_radii = np.round(np.linspace(-0.8, 0.8, 9), 2)\n",
    "train_ratio = 1 - val_ratio - test_ratio  # loop-invariant\n",
    "\n",
    "all_baseline_res = []\n",
    "# A separate loop variable keeps the config cell's `target_spectral_radius` from being overwritten.\n",
    "for radius in spectral_radii:\n",
    "    dataset = Games('../data/tmp', n_graphs=n_graphs, n_nodes=n_nodes, m=m, n_games=n_games,\n",
    "                    target_spectral_radius=float(radius), alpha=alpha,\n",
    "                    marginal_benefits_noise_variance=marginal_benefits_noise_variance, game_type=game_type, transform=None,\n",
    "                    regenerate_data=regenerate_data, graph_type=graph, cost_distribution=None)\n",
    "\n",
    "    # Contiguous split: train and val sizes are floored, test takes the remainder.\n",
    "    n_train_samples = math.floor(len(dataset) * train_ratio)\n",
    "    n_val_samples = math.floor(len(dataset) * val_ratio)\n",
    "    train_dataset = dataset[:n_train_samples]\n",
    "    val_dataset = dataset[n_train_samples:n_train_samples + n_val_samples]\n",
    "    test_dataset = dataset[n_train_samples + n_val_samples:]\n",
    "\n",
    "    baseline_results = eval_baseline(train_dataset, val_dataset, test_dataset)\n",
    "    all_baseline_res.append(baseline_results['correlation_test_roc_auc_mean'])\n",
    "\n",
    "assert len(all_baseline_res) == len(spectral_radii)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x7f93b3600908>]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nO3deXyU5b338c8vISEQdgh7Isgmi7IYUKTW3YPVQq1WcDtqbXmOlj4utFWr5bS2T1tttdqjHkt73Fot4I6Kh+JuKWrCLjuyJayBsMfsv+ePGXGMwQxkJvdk8n2/Xnl17nuuZL7N8vXimmvmNndHREQav5SgA4iISGyo0EVEkoQKXUQkSajQRUSShApdRCRJNAvqgTt16uS9evUK6uFFRBqlBQsW7HL3rNruC6zQe/XqRX5+flAPLyLSKJnZpiPdpyUXEZEkoUIXEUkSKnQRkSShQhcRSRIqdBGRJKFCFxFJEip0EZEkEdg+dBGRRFRV7VRUVVNRVU1lVfh2tVNZVU1F+LiyyqmorqaisprKaq91bGWVU15VHbpd/dnt0H3nDOzC0Ox2Mc+uQheRRsnd2V9aSdGBstDHwbLPbx8oY/ehMkorqsLl6+HyjSxjp7L6yyXdEJeI6NI2Q4UuIsmvtKKKXZHlXKOod0acL6+s/tLnp6emkNW6OR1bpZORlkrztBQyU1JISzXSUlNolppCWorRLHyclppCsxSjWWoK6amh/22WaqRHnP/y50acO/y5KeGvaTRLSSGt2edjI79eaophZnH53qnQRSTuqqqd4kPltRZ06Lj08PH+0spav0bHzHSyWjcnq3Vzju+Uefj2Zx+dWzcnq1UGbVo0i1thJjoVuogck6pqZ/+nFRSXlLOrxsy5ZmHvPlhGdS1LGZnpqYcL+YSubTi9X7igW32xrDtkppOWqj0cdVGhiwjlldXs/bScPYcq2FNSzt6Scopr3N5bUs6eknL2lITO7/u0otb15mYpdriIu7XN4KSebT8v54ii7tSqOZnNVUGxpO+mSJL5tLwqXLxfUdAlEQV9qIKDZbUvcwBkpKXQvmV66CMzjW7tWtChZTrtW6bRLnwuq1UGnduECrttizRSUprmkkfQVOgijUB5ZTV5G4vZeaCUPeHZcnF4tlxzBl1a8eUnCj/Tunkz2mWmhQs5neM7ZdI+M1zWLdMO327XMu1wibdIT23A/6dSH1EVupmNBR4EUoG/uPtva9x/HPAYkAUUA1e5e2GMs4o0OWt3HGBGXgEvLNpC8aHyw+fNoF2LtMPl26NdBoO7t/lCKbf/rJQzQ2PatUgnvZnWoZNZnYVuZqnAw8B5QCGQZ2az3H1FxLDfA0+5+5NmdjbwG+DqeAQWSXaHyip5delWZuQVsHDzXtJSjXMHduGSET05PiuT9i3TadMijVQta0gN0czQRwHr3H09gJlNB8YDkYU+CLg1fPtt4KVYhhRJdu7OooK9zPiogFeXbuVQeRV9O7firgsHcvHwHnRs1TzoiNIIRFPoPYCCiONC4JQaY5YA3ya0LHMx0NrMOrr77shBZjYJmASQk5NzrJlFkkbxoXJeWFjIjLwC1u48SMv0VC46qRsTRuYwIqddk91PLccmVk+K/gh4yMyuBd4DtgBVNQe5+zRgGkBubm4DvMBWJPFUVTv/XLeLGXmbmbtiBxVVzvCcdvz22ydy0dDutNJWPjlG0fzmbAGyI457hs8d5u5bCc3QMbNWwCXuvjdWIUWSQUFxCc8uKOS5/AK27iulfcs0/n10LyaMzKZ/l9ZBx5MkEE2h5wH9zKw3oSKfCFwROcDMOgHF7l4N3EFox4tIk1dWWcXcFTuYkVfAP9ftAuD0flnceeEgzh3UmebNtCVQYqfOQnf3SjObDMwhtG3xMXdfbmZ3A/nuPgs4E/iNmTmhJZcfxDGzSMJbvT203fDFRYXsKamgR7sW3HROPy49uSc9
27cMOp4kKfOGeK/IWuTm5np+fn4gjy0SDwfLKnllSWi74eKC0HbD8wd1ZcLIbMb07aRthhITZrbA3XNru0/PvojUg7uzYNMeZuQV8OrSbXxaUUX/Lq342UWDuHh4DzpkpgcdUZoQFbrIMdh1sOzwdsNPig6RmZ7K+GHdmTAym2HZ2m4owVChi0Spqtp5b00RM/IKeGPlDiqrnZOPa8+9l/bhwhO76Z0DJXD6DRSpQ0FxCTPzC3huQSHb9pXSMTOd68aEthv27azthpI4VOgitSitqGLO8u3MzC9g3rrdpBh8vX8WUy8axDkDu+hNriQhqdBFInxSdJC/zt/Ei4u2sO/TCnq2b8Gt5/Xn0pN70r1di6DjiXwlFbo0edXVzntri3h83kbeXVNEemoK/zakKxNyszmtT0ddrEEaDRW6NFkl5ZU8v3ALT8zbwCdFh8hq3Zxbz+vPFafk0EnvbiiNkApdmpzCPSU8NX8T0z/azP7SSk7q2ZY/TBjKhSd219q4NGoqdGkS3J28jXt4fN4G5izfjpkxdkhXvjumFyNy2mvfuCQFFboktbLKKl5dso3H5m1g+db9tG2RxqSv9+HfRx+nJzkl6ajQJSntPFDK0x9s5ukPN7PrYBn9Orfi1xefyMXDe+iix5K0VOiSVD7eso/H5m3g1SXbKK+q5uwTOnPdmF58rW8nLatI0lOhS6NXWVXN3BU7eGzeBvI27qFleiqXj8rmmtN6cXxWq6DjiTQYFbo0WvtKKpiet5mn5m9iy95P6dm+BXddOJDLRmbTJiMt6HgiDU6FLo3Oup0HeeJfG3h+wRY+raji1OM7MPWbgzh3YBe957g0aVEVupmNBR4kdMWiv7j7b2vcnwM8CbQLj7nd3WfHOKs0YdXVzrvhV3O+t6aI9GYpjB/anevG9GZQ9zZBxxNJCHUWupmlAg8D5wGFQJ6ZzXL3FRHD7gJmuvt/m9kgYDbQKw55pYk5VFbJCwsLefxfG1lfdIjOrZszJfxqzo56NafIF0QzQx8FrHP39QBmNh0YD0QWugOfTZPaAltjGVKanoLiEp6av5HpeQUcKK1kaM+2PDBhGN84sZtezSlyBNEUeg+gIOK4EDilxpifA/8wsx8CmcC5tX0hM5sETALIyck52qyS5NydjzYU8/i8jfxjReSrOXszIkdXARKpS6yeFL0ceMLd7zOz0cBfzWyIu1dHDnL3acA0CF0kOkaPLY1cWWUVryzZxuPhV3O2a5nG/zmjD1efqldzihyNaAp9C5AdcdwzfC7S9cBYAHefb2YZQCdgZyxCSnLaeaCUv32wmWc+3MSug+V6NadIPUVT6HlAPzPrTajIJwJX1BizGTgHeMLMBgIZQFEsg0pyeWf1TiY9tYDyqmrOOaEz143pzZi+HbWsIlIPdRa6u1ea2WRgDqEtiY+5+3IzuxvId/dZwBTgz2Z2C6EnSK91dy2pSK32HCrnx88tpXenTB69+mR6d8oMOpJIUohqDT28p3x2jXNTI26vAMbENpokq7te/pi9JeU8cd1IlblIDGn/lzSoV5Zs5bWl27jpnH4M7t426DgiSUWFLg1m5/5SfvbyxwzLbsd/nNEn6DgiSUeFLg3C3bnt+aWUVlRx32VDaZaqXz2RWNNflTSIGXkFvL26iNvGnkAfvaWtSFyo0CXuCopL+OWrKxh9fEeuGd0r6DgiSUuFLnFVXe386NklmBm/+85JpOjtbUXiRoUucfXYvA18uKGYqd8cRM/2LYOOI5LUVOgSN+t2HuDeOas5d2BnvnNyz6DjiCQ9FbrERUVVNbfOXEJmeiq//vaJekm/SAPQJegkLh55+xOWFu7jkStH0Ll1RtBxRJoEzdAl5pYV7uO/3lrL+GHd+caJ3YKOI9JkqNAlpkorqrh15mI6ZKbzi3GDg44j0qRoyUVi6v65a1i78yCPXzeSdi3Tg44j0qRohi4x89GGYv78/nouH5XDWQM6Bx1HpMlRoUtMHCqr5EfPLqFn
+xbceeHAoOOINElacpGY+H+zV1Kwp4QZk0bTqrl+rUSCENUM3czGmtlqM1tnZrfXcv8fzGxx+GONme2NfVRJVO+s3skzH27me1/rzajeHYKOI9Jk1TmVMrNU4GHgPKAQyDOzWeGrFAHg7rdEjP8hMDwOWSUB7Sup4Lbnl9KvcyumnD8g6DgiTVo0M/RRwDp3X+/u5cB0YPxXjL8c+HsswknimzrrY3YfLOf+y4aRkZYadByRJi2aQu8BFEQcF4bPfYmZHQf0Bt46wv2TzCzfzPKLioqONqskmNnLtvHy4q1MPrsvJ/bU5eREghbrXS4Tgefcvaq2O919mrvnuntuVlZWjB9aGtLOA6Xc+eIyTuzRlh+c1TfoOCJCdIW+BciOOO4ZPlebiWi5Jem5Oz99YRmHyqu4/7KhpOlyciIJIZq/xDygn5n1NrN0QqU9q+YgMzsBaA/Mj21ESTTPLijkjZU7+cm/DaBfl9ZBxxGRsDoL3d0rgcnAHGAlMNPdl5vZ3WY2LmLoRGC6u3t8okoiKNxTwt2vrGBU7w58d0zvoOOISISoXgHi7rOB2TXOTa1x/PPYxZJEVF3t/PjZpbg7931nqC4nJ5JgtPgpUXty/kbmr9/NXRcNIruDLicnkmhU6BKVT4oO8tvXV3HmgCwmjsyu+xNEpMGp0KVOleHLyWWkpXLPJSfpcnIiCUrvoiR1evTdT1hSsJc/Xj6cLm10OTmRRKUZunyl5Vv38eCba7nwpG6MG9o96Dgi8hVU6HJEZZVVTJm5hHYt0/nV+CFBxxGROmjJRY7ogTfWsmr7AR67Npf2mbqcnEii0wxdarVgUzF/evcTJuRmc/YJXYKOIyJRUKHLl5SUVzJl5hK6tW3BXRfpcnIijYWWXORLfvv6KjbuLuHv3z+V1hlpQccRkShphi5f8M+1u3hq/ia+O6Y3o/t0DDqOiBwFFboctu/TCn783BL6ZGXyk7G6nJxIY6MlFznsF68sZ+eBMl644TRdTk6kEdIMXQCYs3w7Lyzcwg/O7MPQ7HZBxxGRY6BCF3YdLOOnLyxjcPc2TD67X9BxROQYacmliXN37nxxGQdKK3nm+8NIb6b/xos0VlH99ZrZWDNbbWbrzOz2I4y5zMxWmNlyM3smtjElXl5ctIU5y3cw5fz+DOiqy8mJNGZ1ztDNLBV4GDgPKATyzGyWu6+IGNMPuAMY4+57zKxzvAJL7Gzd+yn/OWs5I3u153unHx90HBGpp2hm6KOAde6+3t3LgenA+Bpjvg887O57ANx9Z2xjSqy5O7c9v5Sqauf33xlKqi4nJ9LoRVPoPYCCiOPC8LlI/YH+ZjbPzD4ws7G1fSEzm2Rm+WaWX1RUdGyJJSb+9sEm3l+7i59+YyDHdcwMOo6IxECsngFrBvQDzgQuB/5sZl/a++bu09w9191zs7KyYvTQcrQ27jrEr2ev4uv9s7jylJyg44hIjERT6FuAyItI9gyfi1QIzHL3CnffAKwhVPCSYKqqnSnPLiEt1bhXl5MTSSrRFHoe0M/MeptZOjARmFVjzEuEZueYWSdCSzDrY5hTYmTae+tZsGkPd48fQte2upycSDKps9DdvRKYDMwBVgIz3X25md1tZuPCw+YAu81sBfA28GN33x2v0HJsVm3fzx/mruGCIV0ZP0yXkxNJNubugTxwbm6u5+fnB/LYTVF5ZTXjH55H0YFS5tz8dTq2ah50JBE5Bma2wN1za7tPrxRtIv745lpWbtvPtKtPVpmLJCm9zrsJWLR5D4+8s45LRvTk/MFdg44jInGiQk9yn5ZXMWXmErq2yeA/xw0KOo6IxJGWXJLcb15fyfpdh3j6e6fQRpeTE0lqmqEnsTdX7uCp+Zu4/mu9GdO3U9BxRCTOVOhJaueBUn7y3FJO6Npal5MTaSK05JKE3J0fP7uUg2WV/H3SqTRvpsvJiTQFmqEnoSf/tZF31xRx54UD6d9F73Eu0lSo0JPMqu37+fXrqzj7hM5cfepx
QccRkQakQk8ipRVV3PT3xbTJaMa9l+qNt0SaGq2hJ5F7/ncVq3cc4PHrRtJJrwYVaXI0Q08S76zeyePzNnLtab04a4CuACjSFKnQk8Cug2X86NmlDOjSmtsvOCHoOCISEC25NHLuzm3PLWV/aQV/+94oMtK0RVGkqdIMvZH724ebeXPVTu644ARO6Nom6DgiEiAVeiO2dscBfvXqCs7on8W1p/UKOo6IBCyqQjezsWa22szWmdnttdx/rZkVmdni8Mf3Yh9VIpVVVvF/py+mVfNm/O472qIoIlGsoZtZKvAwcB6hi0Hnmdksd19RY+gMd58ch4xSi9/PWc3Kbfv5n2ty6dxa1wYVkehm6KOAde6+3t3LgenA+PjGkq/y/toi/vz+Bq4+9TjOGdgl6DgikiCiKfQeQEHEcWH4XE2XmNlSM3vOzLJr+0JmNsnM8s0sv6io6BjiSvGhcqbMXELfzq346TcGBh1HRBJIrJ4UfQXo5e4nAXOBJ2sb5O7T3D3X3XOzsrJi9NBNh7tz2/NL2VtSwYMTh9EiXVsUReRz0RT6FiByxt0zfO4wd9/t7mXhw78AJ8cmnkT6+0cFzF2xg5+MHcDg7m2DjiMiCSaaQs8D+plZbzNLByYCsyIHmFm3iMNxwMrYRRSAdTsPcveryzm9Xye+O6Z30HFEJAHVucvF3SvNbDIwB0gFHnP35WZ2N5Dv7rOA/2tm44BKoBi4No6Zm5zyympunrGIFmmp/P47Q0lJ0RZFEfmyqF767+6zgdk1zk2NuH0HcEdso8ln7pu7mo+37OdPV59MlzbaoigitdMrRRPcv9btYtp767l8VA7/Nrhr0HFEJIGp0BPY3pJybp25hN6dMvnZRdqiKCJfTYWeoNydO15Yxu5DZfxx4nBapuuNMUXkq6nQE9Sz+YW8/vF2ppw/gCE9tEVRROqmQk9AG3Yd4uevLOe0Ph2ZdPrxQccRkUZChZ5gKqqquXn6ItJSU7jvMm1RFJHoaWE2wTzwxhqWFO7jv68cQbe2LYKOIyKNiGboCeTD9bt55J1PmJCbzQUndqv7E0REIqjQE8S+kgpumbGY4zq0ZOo3BwUdR0QaIS25JAB356cvLWPngTKev+E0MpvrxyIiR08z9ATwwsItvLZ0G7ec15+h2e2CjiMijZQKPWCbdh9i6ssfM6p3B/7jjD5BxxGRRkyFHqCKqmpumr6Y1BTjDxOGkaotiiJSD1qsDdB/vbWOxQV7eeiK4fRopy2KIlI/mqEHJH9jMQ+9tZZLRvTkopO6Bx1HRJKACj0A+0sruGn6Ynq2b8kvxg8OOo6IJImoCt3MxprZajNbZ2a3f8W4S8zMzSw3dhGTz9SXPmb7/lIemDiMVtqiKCIxUmehm1kq8DBwATAIuNzMvvTKFzNrDdwEfBjrkMnkpUVbeGnxVm46px8jctoHHUdEkkg0M/RRwDp3X+/u5cB0YHwt434J3AOUxjBfUikoLuFnL31M7nHtufFMbVEUkdiKptB7AAURx4Xhc4eZ2Qgg291f+6ovZGaTzCzfzPKLioqOOmxjVllVzS0zFgPwhwnDaJaqpy9EJLbq3SpmlgLcD0ypa6y7T3P3XHfPzcrKqu9DNyqPvPMJ+Zv28KuLh5DdoWXQcUQkCUVT6FuA7IjjnuFzn2kNDAHeMbONwKnALD0x+rkFm/bw4JtruXh4D8YP61H3J4iIHINoCj0P6Gdmvc0sHZgIzPrsTnff5+6d3L2Xu/cCPgDGuXt+XBI3MgdKK7h5xiK6tc3QFkURias6C93dK4HJwBxgJTDT3Zeb2d1mNi7eARu7n89awZY9n/LAhGG0yUgLOo6IJLGoNkG7+2xgdo1zU48w9sz6x0oOryzZyvMLC7npnH7k9uoQdBwRSXLaahEnW/Z+yk9fXMbwnHb88Oy+QccRkSZAhR4HVdXOLTMW4w4PThiuLYoi0iD0uvM4ePTdT/hoQzH3XzaUnI7aoigiDUNTxxhbXLCXP8xdwzeH
dufi4dqiKCINR4UeQ4fKKrl5+iK6tMngV98agpkuWCEiDUdLLjH0i1eWs7m4hOmTRtO2hbYoikjD0gw9Rl5buo2Z+YXceGZfRvXWFkURaXgq9BhYX3SQ255fyvCcdtx0br+g44hIE6VCr6fSiipufHohaanGw1eMIE1bFEUkIFpDr6f/fHk5q7Yf4InrRtJdF3oWkQBpOlkPzy0oZEZ+AZPP6suZAzoHHUdEmjgV+jFavf0Ad720jNHHd+SW8/oHHUdERIV+LA6WVXLD0wtonZHGg5cPIzVF+81FJHgq9KPk7tzxwjI27jrEHycOp3PrjKAjiYgAKvSj9rcPN/PKkq1MOX8Ao/t0DDqOiMhhKvSjsKxwH798ZQVnDcjihjP6BB1HROQLoip0MxtrZqvNbJ2Z3V7L/f9hZsvMbLGZ/dPMBsU+arD2lVRw4zML6NQqnfsvG0aK1s1FJMHUWehmlgo8DFwADAIur6Wwn3H3E919GHAvcH/MkwbI3fnRc0vYvq+Uh64cQfvM9KAjiYh8STQz9FHAOndf7+7lwHRgfOQAd98fcZgJeOwiBu8v729g7ood3HHBQEbktA86johIraJ5pWgPoCDiuBA4peYgM/sBcCuQDpxd2xcys0nAJICcnJyjzRqI/I3F/PZ/V3HBkK5cN6ZX0HFERI4oZk+KuvvD7t4HuA246whjprl7rrvnZmVlxeqh42b3wTImP7OInu1bcM+lJ+n9zUUkoUVT6FuA7IjjnuFzRzId+FZ9QiWCqmrn5hmLKS4p55ErR9AmQ+9vLiKJLZpCzwP6mVlvM0sHJgKzIgeYWeR7xl4IrI1dxGA89NY63l+7i1+MG8zg7m2DjiMiUqc619DdvdLMJgNzgFTgMXdfbmZ3A/nuPguYbGbnAhXAHuCaeIaOt3nrdvHAm2v49vAeTByZXfcniIgkgKjePtfdZwOza5ybGnH7phjnCsyO/aXcNH0RfbNa8auLdV1QEWk89H7oESqrqvnhM4s4VFbF9EkjaJmub4+INB5qrAi//8caPtpYzAMThtG3c+ug44iIHBW9l0vYmyt38Oi7n3DFKTl8a3iPoOOIiBw1FTpQUFzCrTOXMLh7G6ZelHRvQyMiTUSTL/SyyiomP7OQanceuXIEGWmpQUcSETkmTX4N/TezV7GkcB+PXnUyx3XMDDqOiMgxa9Iz9NeWbuOJf23k+q/1ZuyQrkHHERGplyZb6OuLDnLb80sZkdOO2y84Ieg4IiL11iQLvbSiihufXkhaqvHQFSNIS22S3wYRSTJNcg196ssfs2r7AZ64biTd27UIOo6ISEw0uanps/kFzMwvZPJZfTlzQOeg44iIxEyTKvRV2/fzs5c/ZvTxHbnlvP5BxxERiakmU+gHyyq58emFtM5I48HLh5GqizyLSJJpEoXu7tzxwjI27jrEHycOp3PrjKAjiYjEXJMo9L99sIlXlmxlyvkDGN2nY9BxRETiIqpCN7OxZrbazNaZ2e213H+rma0ws6Vm9qaZHRf7qMdmaeFefvnqSs4akMUNZ/QJOo6ISNzUWehmlgo8DFwADAIuN7Oa72C1CMh195OA54B7Yx30WOwrqeDGpxfSqVU69182jBStm4tIEotmhj4KWOfu6929nNBFoMdHDnD3t929JHz4AaELSQfK3Zny7BJ27C/loStH0D4zPehIIiJxFU2h9wAKIo4Lw+eO5Hrg9fqEioU/v7+eN1bu4I4LBjIip33QcURE4i6mrxQ1s6uAXOCMI9w/CZgEkJOTE8uH/oK8jcXc87+ruWBIV64b0ytujyMikkiimaFvAbIjjnuGz32BmZ0L3AmMc/ey2r6Qu09z91x3z83KyjqWvHXafbCMyc8sJLt9C+659CRd5FlEmoxoCj0P6Gdmvc0sHZgIzIocYGbDgT8RKvOdsY8Znapq5+YZi9lTUsHDV46gTUZaUFFERBpcnYXu7pXAZGAOsBKY6e7LzexuMxsXHvY7oBXwrJktNrNZ
R/hycfVfb63l/bW7+MW4wQzu3jaICCIigYlqDd3dZwOza5ybGnH73BjnOmr/XLuLB99cy7eH92DiyOy6P0FEJMkkxStFt+8r5abpi+ib1YpfXTxE6+Yi0iQ1+kKvrKrmh39fyKcVVfz3VSNomd4k3+JdRKTxX+Did/9YTd7GPTw4cRh9O7cOOo6ISGAa9Qx97ood/Ond9VxxSg7jh33Va51ERJJfoy30guISpsxczODubZh6Uc23lhERaXoaZaGXVVbxg2cW4sAjV44gIy016EgiIoFrlGvov35tJUsL9/HoVSdzXMfMoOOIiCSERjdDf2XJVp6cv4nrv9absUO6Bh1HRCRhNLpC75CZznmDunD7BScEHUVEJKE0uiWXMX07MaZvp6BjiIgknEY3QxcRkdqp0EVEkoQKXUQkSajQRUSShApdRCRJqNBFRJKECl1EJEmo0EVEkoS5ezAPbFYEbDrGT+8E7IphnFhRrqOjXEcvUbMp19GpT67j3D2rtjsCK/T6MLN8d88NOkdNynV0lOvoJWo25To68cqlJRcRkSShQhcRSRKNtdCnBR3gCJTr6CjX0UvUbMp1dOKSq1GuoYuIyJc11hm6iIjUoEIXEUkSjaLQzayDmc01s7Xh/21/hHH3mtlyM1tpZn80M0uQXDlm9o9wrhVm1isRcoXHtjGzQjN7KJ6Zos1lZsPMbH7457jUzCbEMc9YM1ttZuvM7PZa7m9uZjPC938Y75/bUeS6Nfx7tNTM3jSz4xIhV8S4S8zMzaxBtgtGk8vMLgt/z5ab2TOJkCvcC2+b2aLwz/Ib9X5Qd0/4D+Be4Pbw7duBe2oZcxowD0gNf8wHzgw6V/i+d4DzwrdbAS0TIVf4/geBZ4CHEuTn2B/oF77dHdgGtItDllTgE+B4IB1YAgyqMeZG4NHw7YnAjAb4HkWT66zPfoeAGxIlV3hca+A94AMgNxFyAf2ARUD78HHnBMk1DbghfHsQsLG+j9soZujAeODJ8O0ngW/VMsaBDELfvOZAGrAj6FxmNgho5u5zAdz9oLuXBJ0rnO1koAvwjzjniTqXu69x97Xh21uBnUCtr4qrp1HAOndf7+7lwPRwviPlfQ44J97/6osml7u/HfE79AHQM86ZosBI7i0AAANKSURBVMoV9kvgHqC0ATJFm+v7wMPuvgfA3XcmSC4H2oRvtwW21vdBG0uhd3H3beHb2wmV0Be4+3zgbUIzum3AHHdfGXQuQjPOvWb2QvifVr8zs9Sgc5lZCnAf8KM4ZzmqXJHMbBSh/0B/EocsPYCCiOPC8Llax7h7JbAP6BiHLEebK9L1wOtxTRRSZy4zGwFku/trDZAn6lyE/gb7m9k8M/vAzMYmSK6fA1eZWSEwG/hhfR80YS4SbWZvAF1ruevOyAN3dzP70l5LM+sLDOTz2cpcMzvd3d8PMheh7/HpwHBgMzADuBb4n4Bz3QjMdvfCWE46Y5Drs6/TDfgrcI27V8csYBIxs6uAXOCMBMiSAtxP6Hc70TQjtOxyJqF+eM/MTnT3vYGmgsuBJ9z9PjMbDfzVzIbU5/c9YQrd3c890n1mtsPMurn7tvAfem3/ZLoY+MDdD4Y/53VgNFCvQo9BrkJgsbuvD3/OS8Cp1LPQY5BrNHC6md1IaF0/3cwOuvsRn+xqoFyYWRvgNeBOd/+gPnm+whYgO+K4Z/hcbWMKzawZoX8W745TnqPJhZmdS+g/kme4e1mcM0WTqzUwBHgnPEHoCswys3Hunh9gLgj9DX7o7hXABjNbQ6jg8wLOdT0wFkIrDGaWQehNu455SaixLLnMAq4J374GeLmWMZuBM8ysmZmlEZq1xHvJJZpceUA7M/tsHfhsYEXQudz9SnfPcfdehJZdnqpvmccil5mlAy+G8zwXxyx5QD8z6x1+zInhfEfKeynwloefwQoyl5kNB/4EjGug9eA6c7n7Pnfv5O69wr9TH4TzxbPM68wV9hKh2Tlm1onQEsz6BMi1GTgn
nGsgoecAi+r1qPF+tjcWH4TWLd8E1gJvAB3C53OBv/jnzyr/iVCJrwDuT4Rc4ePzgKXAMuAJID0RckWMv5aG2eUSzc/xKqACWBzxMSxOeb4BrCG0Rn9n+NzdhIoIQn9gzwLrgI+A4+P9PYoy1xuEnvD/7PszKxFy1Rj7Dg2wyyXK75cRWg5aEf4bnJgguQYR2pm3JPxzPL++j6mX/ouIJInGsuQiIiJ1UKGLiCQJFbqISJJQoYuIJAkVuohIklChi4gkCRW6iEiS+P9lCuhvQ9MaKAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# x values must match the grid the sweep cell iterated over (9 evenly spaced radii in [-0.8, 0.8]);\n",
    "# linspace avoids the float drift of np.arange(-0.8, 1.0, 0.2).\n",
    "spectral_radii = np.round(np.linspace(-0.8, 0.8, 9), 2)\n",
    "assert len(spectral_radii) == len(all_baseline_res), 'sweep grid and results are out of sync; re-run the sweep cell'\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(spectral_radii, all_baseline_res, marker='o')\n",
    "ax.set(xlabel='Target spectral radius',\n",
    "       ylabel='Correlation baseline test ROC-AUC (mean)',\n",
    "       title='Correlation baseline vs. target spectral radius')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.7843270541058744,\n",
       " 0.7844227207144391,\n",
       " 0.8205610079752967,\n",
       " 0.7860841703409596,\n",
       " 0.859748512414839,\n",
       " 0.8489608006967932,\n",
       " 0.8680758639491074,\n",
       " 0.8690764998690474]"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Correlation-baseline mean test ROC-AUC for each target spectral radius, in the order the sweep ran.\n",
    "all_baseline_res"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
