{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "# Put the directory two levels above the current working directory first on the module\n",
    "# search path. The path is relative, so it depends on where the kernel was started.\n",
    "# NOTE(review): presumably this is what lets the `causal_meta` imports resolve - confirm.\n",
    "sys.path.insert(0, '../..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from causal_meta.modules.mdn import mdn_nll\n",
    "from causal_meta.utils.data_utils import RandomSplineSCM\n",
    "from causal_meta.utils.train_utils import train_nll, make_alpha, train_alpha\n",
    "from models import mdn, gmm, auc_transfer_metric\n",
    "from argparse import Namespace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def normal(mean, std, N):\n",
    "    \"\"\"Draw N samples from a normal distribution and return them as a column.\n",
    "\n",
    "    Args:\n",
    "        mean: Value used as the mean of every one of the N draws.\n",
    "        std: Value used as the standard deviation of every one of the N draws.\n",
    "        N: Number of samples to draw.\n",
    "\n",
    "    Returns:\n",
    "        The torch.normal draws reshaped with view(-1, 1), i.e. one sample per row.\n",
    "    \"\"\"\n",
    "    # Expand the scalar parameters into length-N tensors so torch.normal samples element-wise.\n",
    "    loc = torch.ones(N).mul_(mean)\n",
    "    scale = torch.ones(N).mul_(std)\n",
    "    samples = torch.normal(loc, scale)\n",
    "    return samples.view(-1, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# All hyper-parameters are attributes of one Namespace, handed to the helpers below as `opt`.\n",
    "opt = Namespace()\n",
    "# Model\n",
    "opt.CAPACITY = 32\n",
    "opt.NUM_COMPONENTS = 10\n",
    "opt.GMM_NUM_COMPONENTS = 10\n",
    "# Training\n",
    "opt.LR = 0.001\n",
    "opt.NUM_ITER = 3000\n",
    "opt.CUDA = False\n",
    "opt.REC_FREQ = 10\n",
    "# Meta\n",
    "opt.ALPHA_LR = 0.1\n",
    "opt.ALPHA_NUM_ITER = 500\n",
    "opt.FINETUNE_LR = 0.001\n",
    "opt.FINETUNE_NUM_ITER = 10\n",
    "# The lambdas in this cell read opt.NUM_SAMPLES lazily, at call time, so it is fine that\n",
    "# NUM_SAMPLES is only assigned further down.\n",
    "opt.PARAM_DISTRY = lambda mean: normal(mean, 2, opt.NUM_SAMPLES)\n",
    "opt.PARAM_SAMPLER = lambda: np.random.uniform(-4, 4)\n",
    "# Sampling\n",
    "opt.NUM_SAMPLES = 1000\n",
    "opt.TRAIN_DISTRY = lambda: normal(0, 2, opt.NUM_SAMPLES)\n",
    "# Transfer distribution: std 2, with an integer mean drawn uniformly from {-4, ..., 4}.\n",
    "# This used to call `random.randint`, but `random` is not imported anywhere in the\n",
    "# notebook, so the sampler raised NameError on first use. NumPy's randint excludes its\n",
    "# upper bound (hence 5), and int() passes `normal` a plain Python int.\n",
    "opt.TRANS_DISTRY = lambda: normal(int(np.random.randint(-4, 5)), 2, opt.NUM_SAMPLES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate RandomSplineSCM; the resulting `scm` is passed to train_nll and train_alpha below.\n",
    "# NOTE(review): the five positional arguments are opaque from this notebook - check them\n",
    "# against the RandomSplineSCM signature in causal_meta.utils.data_utils before changing any.\n",
    "scm = RandomSplineSCM(False, True, 8, 10, 3, range_scale=1.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9ff998af471f43529353f6b64ad60425",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=3000), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Build the X2Y model with mdn(opt) and train it through train_nll with the mdn_nll loss,\n",
    "# passing opt.TRAIN_DISTRY as the sampler; no encoder or decoder is supplied.\n",
    "model_x2y = mdn(opt)\n",
    "frames_x2y = train_nll(\n",
    "    opt,\n",
    "    model_x2y,\n",
    "    scm,\n",
    "    opt.TRAIN_DISTRY,\n",
    "    polarity='X2Y',\n",
    "    loss_fn=mdn_nll,\n",
    "    decoder=None,\n",
    "    encoder=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bc7ac064cc66424cbfbad19ecd98b8cf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=3000), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Same setup as the X2Y cell but with polarity='Y2X': a separate mdn(opt) instance trained\n",
    "# through train_nll with the mdn_nll loss and the opt.TRAIN_DISTRY sampler.\n",
    "model_y2x = mdn(opt)\n",
    "frames_y2x = train_nll(\n",
    "    opt,\n",
    "    model_y2x,\n",
    "    scm,\n",
    "    opt.TRAIN_DISTRY,\n",
    "    polarity='Y2X',\n",
    "    loss_fn=mdn_nll,\n",
    "    decoder=None,\n",
    "    encoder=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f7d57db083a6466e85ccb3ad98f1bf7e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=500), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Create `alpha` with make_alpha(opt), then run train_alpha on it with the two direction\n",
    "# models trained above. The returned frames are what the plotting cell reads.\n",
    "alpha = make_alpha(opt)\n",
    "alpha_frames = train_alpha(\n",
    "    opt,\n",
    "    model_x2y,\n",
    "    model_y2x,\n",
    "    None,  # NOTE(review): the meaning of these two positional None slots is not visible here\n",
    "    None,\n",
    "    alpha,\n",
    "    scm,\n",
    "    opt.PARAM_DISTRY,\n",
    "    opt.PARAM_SAMPLER,\n",
    "    mdn_nll,\n",
    "    auc_transfer_metric,\n",
    "    mixmode='logsigp',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjkAAAFHCAYAAABDK11BAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xt0XXWd///nO2nSJm2Tll64BHqBCiKtxVLkqjIjXgeRUb9eaB3FS1FkfQVFYRTl8h0RhPmuH2v8emFQYRBnGKcWAS8oWOViFVqwLQKi1NKhLdCWtml6S9N8fn+ck0x6ctKmue1m5/lYa6+T8/l89t7vc/Zp8uq+nB0pJSRJkvKmIusCJEmS+oMhR5Ik5ZIhR5Ik5ZIhR5Ik5ZIhR5Ik5ZIhR5Ik5ZIhR5Ik5ZIhR5Ik5ZIhR5Ik5dKwrAsYCOPHj09TpkzJugxJktQHlixZsj6lNGFf44ZEyJkyZQqLFy/OugxJktQHIuK57ozzcJUkScolQ44kScolQ44kScolQ44kScolQ44kScolQ44kScolQ44kScqlTENORLw/Ih6MiMaIaOnG+NkR8UhEbIuIZyNi7kDUKUmSBp+s9+RsBL4BXLSvgRFRD/wMmA+MBT4BfCsiTunXCiVJ0qCU6Tcep5TuBYiIM7ox/F3ANuBrKaUE/DIiFgDzgEV7m3Hnzp2sWLFij7b6+nrGjRtHa2srK1eu7DTP2LFjGTt2LC0tLaxatapT/7hx46ivr6e5uZnnn3++U//48eOpq6tj586drF69ulP/xIkTGTVqFNu3b2ft2rWd+g855BBqa2vZtm0bL7zwQqf+Qw89lJqaGpqamnjppZc69Tc0NDB8+HAaGxtZv359p/7DDz+c6upqNm/ezIYNGzr1T5o0iWHDhrFx40Y2btwIQEqJ5uZmmpubOfjgg2lpaWHdunVs2rSJ1tZWdu/eze7du2ltbeWwww5j9+7dbNiwgcbGxj36U0pMmDCB1tZWXn75ZbZv305KidbWVgAigvHjx5NSYuPGje39bTVUVFRw0EEH0draysaNG2lubial1L6MyspKxowZ0z7/rl27aG1tbR9TVVXF6NGj2+dvaWlp7wOorKxk1KhR7fO31dxx/tra2vb+jstu6x8xYgQpJTZt2rRHbQBVVVUMHz6c1tZWGhsbO733w4cPb+9vamrq1D9ixAiqqqrYvXs3W7duLdtfXV1dtj+lRE1NDVVVVbS0tLBt27ZO89fW1jJs2LBO/W3vT1v/rl272L59e6f5R44cSWVlJc3NzezYsaNsf0VFBc3NzezcubPL/p07d5btHzVqFBHBzp07aW5u7tQ/evRoUkrs2LGDXbt27dEXEYwaNQqA7du309LS0ql/5MiRXfZXVFS092/bto3du3fv8f5UVlZSW1tbth9g2LBh1NTUALB169b2z87e+jtq+2wBNDU1deqvrq5m+PDhAGzZsqXTe9PW39raWvazU11dTXV1Na2trWU/G9XV1VRVVXXZP3z48PbPZulnI6XEiBEjGDZsWNl+oL2/paWl7GenpqaGysrKffbv2rWr7GentraWioqKffa3/Z4rNXLkyPbPXulnC2j/bJXr7/jZ2rFjxz4/e6WfnYqKivbP1r76t23b1umzUVlZ2f7Z2lf/1q1b9/hcQuGz2fbZK9ff9nsNKPt7q60/pdSjz17bZ6ur/q4Mpts6zAQeT3u+s48BHyw3OCLmUQhAHHPMMf1f3QFu165dvPjii6xZs4Y1a9awY8cOmpqaWL9+PevWraOpqYktW7bQ1NREc3Mzra2t7Ny5k+3bt7f/MSn3j16SpANVlKaxTIoo7Mm5L6XUZeiKiO8Aw1JKH+rQdh7wxZTStL0tf/bs2Wmo3LvqxRdfZPny5TzxxBMsX76cp59+mueee441a9Z0St490ZbG2/YWDBs2jMrKyk5TRUVF2fbSMRUVFUQEEbHH
z3try8vYjlNPDPR8Q2Wdg6nWLNaZRa1Sqblz5y5JKc3e17jBtCdnCzClpG0M0Hl//xCwe/dunnjiCRYvXszy5cvbp3Xr1pUdHxE0NDQwadIkJk+eTENDA2PHjqWuro76+nrq6uqoq6tj9OjR1NTUtB8yaQs0w4cPp7q6moqKrE/jkiQNdXPndu+6o8EUcpYC55S0vabYnnuNjY088MADLFq0iEWLFvHII4+UPa5ZV1fH9OnTmTFjBtOnT+e4445j6tSpNDQ0UFVVlUHlkiRlI9OQExGVQBVQXXw+oti1M3U+trIA+FpEfA64EXgdhZOR3zRA5Q64tWvXctddd3HnnXdy//33dzqRberUqZx00knMnDmTGTNmMGPGDI444gh3CUuSRPZ7cj4IfK/D87bT7adGxBEULhl/VUppVUppU0S8Hfh/wNXAWuATKaW9Xlk12OzYsYMFCxZw00038etf/7q9vaKiglNPPZXTTz+dU045hVNOOYWDDz44u0IlSTrAZX0J+S3ALV10rwRGlYx/FHhtvxaVkcbGRm644Qa+8Y1vtF/SPWLECN785jdzzjnncNZZZzFhwoSMq5QkafDIek/OkLd7926+/e1v8+Uvf7k93Bx//PGcf/75nHvuudTV1WVcoSRJg5MhJ0NPPPEEH/3oR3nkkUcAOP300/nqV7/Kaaed5nk1kiT1kiEnIz/5yU943/vex9atW2loaODGG2/kXe96l+FGkqQ+4peeZODmm2/m7LPPZuvWrZx77rk8+eSTvPvd7zbgSJLUhww5A+yee+5h3rx5tLa2csUVV/D973/f824kSeoHHq4aQEuXLuX9738/KSWuvPJKrrjiiqxLkiQpt9yTM0DWrFnDWWedxdatW5kzZw5f/vKXsy5JkqRcM+QMgK1bt3L22Wfz/PPPc9ppp3HzzTd7/o0kSf3MkNPPUkp85CMfYcmSJRx55JEsWLCAESNG7HtGSZLUK4acfnbrrbfyn//5n4wePZp77rnHby2WJGmAGHL60dq1a/n0pz8NwNe//nWOPfbYjCuSJGnoMOT0o4svvpjGxkbOOussPvjBD2ZdjiRJQ4ohp5/ce++93HHHHdTW1vL1r3/dE40lSRpghpx+sH37di644AIArrjiCiZPnpxxRZIkDT2GnH7wla98hRUrVjBjxgwuvvjirMuRJGlIMuT0saeeeoqvfe1rAHzrW9+iqqoq44okSRqaDDl9KKXEpz71KXbt2sXHP/5xTj311KxLkiRpyDLk9KH777+fhQsXMnbsWK699tqsy5EkaUgz5PSRlFL7/ag+97nPcdBBB2VckSRJQ5shp48sXLiQRYsWMW7cOC688MKsy5Ekacgz5PSRr3zlKwBcdNFFjB49OuNqJEmSIacPLFq0iF/96lfU1dW5F0eSpAOEIacPtO3FufDCCxkzZkzG1UiSJDDk9Nrjjz/OT37yE2pra7nooouyLkeSJBUZcnrpmmuuAeD8889nwoQJGVcjSZLaGHJ6Yfny5cyfP5/q6mo++9nPZl2OJEnqwJDTC1dddRUpJc4//3waGhqyLkeSJHVgyOmhpUuXMn/+fEaMGMFll12WdTmSJKmEIaeHrrrqKqBwLs5hhx2WcTWSJKmUIacHnnnmGRYsWMDw4cO59NJLsy5HkiSVYcjpgRtvvBGAuXPncuihh2ZcjSRJKseQs5+ampq49dZbAfxeHEmSDmCGnP00f/58tm7dyqmnnsr06dOzLkeSJHXBkLOf2vbifPjDH862EEmStFeGnP3wzDPPsHDhQmpqanjve9+bdTmSJGkvDDn74Vvf+hYA5557LvX19RlXI0mS9saQ003btm3je9/7HgAXXHBBxtVIkqR9MeR00x133MGmTZt47Wtfy6xZs7IuR5Ik7YMhp5u++c1vAvDJT34y40okSVJ3ZBpyIqIyIq6PiHURsSUi5kfE+L2MvyQini2O/XNEDMhxoyeffJJHH32U+vp63ve+9w3EKiVJUi9l
vSfnMuCdwEnA4cW228oNjIizgauAOSml0cA/ANdHxJv6u8jbb78dgPe85z3U1NT09+okSVIfyDrkzAOuSymtSCltBj4PvDUiJpcZOw1YmlL6HUBKaRGwDJjZnwW2trbygx/8ACjcxkGSJA0OmYWciBgDTAKWtLWllJ4FGikfXP4DqIuI0yKiIiJeBxwN/Lw/6/ztb3/LypUrOfzww3n961/fn6uSJEl9aFiG6x5dfNxc0r4JqCsz/iXgv4CF/E84uyil9ES5hUfEPAp7ipg0aVKPi/z+978PFL4bp6Ii6x1fkiSpu7L8q72l+Fj6rXpjKOzNKfUl4FzgeKCKwt6eiyPio+UWnlK6KaU0O6U0e8KECT0qsKWlhR/+8IcAzJkzp0fLkCRJ2cgs5KSUNgGrgPYvnYmIIynsxVlWZpYTgAUppSdTwR+BO4F39FeNDz/8MC+//DJHH300r371q/trNZIkqR9kffzlJuDSiJgaEXXAdcC9KaWVZcY+DJwTEa8AiIhjgXPocE5PX7vrrrsAOPvss/trFZIkqZ9keU4OwLXAWOBRYDjwS2AuQETMAb6dUhpVHHs9hUNbvyx+l87LwA+Ly+hzKaX2kPOOd/TbziJJktRPIqWUdQ39bvbs2Wnx4sX7Nc/TTz/Nsccey0EHHcSLL77IsGFZ50FJkgQQEUtSSrP3NS7rw1UHrLvvvhuAt7/97QYcSZIGIUNOF9pCjufjSJI0OBlyytiwYQMPP/wwVVVVvOUtb8m6HEmS1AOGnDJ++tOf0trayhlnnEFdXbnvJZQkSQc6Q04ZXlUlSdLgZ8gp0dLSwr333gsYciRJGswMOSUee+wxtmzZwrRp05gyZUrW5UiSpB4y5JT4zW9+A8AZZ5yRbSGSJKlXDDkl2kLOG97whowrkSRJvWHI6WD37t08+OCDgCFHkqTBzpDTwdKlS2lsbGTq1KkcccQRWZcjSZJ6wZDTgYeqJEnKD0NOB4YcSZLyw5BT1Nra6vk4kiTliCGn6A9/+AMvv/wykydP9vtxJEnKAUNO0f333w/AmWeeSURkXI0kSeotQ05RW8h54xvfmHElkiSpLxhygJ07d/LAAw8A8Ld/+7cZVyNJkvqCIQf43e9+x/bt25kxYwYHH3xw1uVIkqQ+YMjBQ1WSJOWRIQe47777gMJJx5IkKR+GfMhpamrikUceobKykte//vVZlyNJkvrIkA85v//979m9ezezZs1i9OjRWZcjSZL6yJAPOQ899BAAp512WsaVSJKkvjTkQ87DDz8MwOmnn55xJZIkqS8N6ZDT0tLCokWLAPfkSJKUN0M65CxbtoympiamTZvGIYccknU5kiSpDw3pkNN2Po6HqiRJyh9DDoYcSZLyaMiGnJSSIUeSpBwbsiHnueeeY+3atYwbN46jjz4663IkSVIfG7Ih53e/+x0AJ598MhGRcTWSJKmvDdmQ03bp+Mknn5xxJZIkqT8M2ZDTtifnlFNOybgSSZLUH4ZkyNmxYwePP/44EcGJJ56YdTmSJKkfDMmQ89hjj7Fr1y6mT59OXV1d1uVIkqR+MCRDTseTjiVJUj4NyZDjSceSJOXfkAw5nnQsSVL+DbmQs2rVKp5//nnGjBnDMccck3U5kiSpn2QaciKiMiKuj4h1EbElIuZHxPi9jJ8YEbdGxIaIaIyIP0TEYfuzzocffhiAU089lYqKIZfxJEkaMrL+K38Z8E7gJODwYttt5QZGxAjgfqAZOAYYA8wBmvZnhW0h57TTTutRwZIkaXAYlvH65wFXp5RWAETE54G/RMTklNJzJWM/RCHYXJBS2lVs++P+rrAt5HhTTkmS8i2zPTkRMQaYBCxpa0spPQs0AjPLzPI3wJ+BW4qHq56OiIv3svx5EbE4IhavW7cOgMbGRpYtW0ZVVZVfAihJUs5lebhqdPFxc0n7JqDcN/SNpxB0HgEOBeYCX4yIOeUWnlK6KaU0O6U0e8KECUDhqqrW1lZmzZpFTU1NX7wG
SZJ0gMoy5GwpPtaXtI+hsDen3PjVKaUbU0rNKaXFwPcpnNPTLQ899BDgoSpJkoaCzEJOSmkTsAqY1dYWEUdS2IuzrMwsfwBSuUV1d52edCxJ0tCR9dVVNwGXRsTUiKgDrgPuTSmtLDP2FmBcRHyqeOn5TApXV/2oOytqbm7m97//PVC4fFySJOVb1iHnWuBu4FFgNVBJ4VwbImJORLRfHl682urtwMcoHM76L+DKlNId3VnRb37zG7Zu3cr06dM5+OCD+/ZVSJKkA06ml5CnlHYDlxSn0r7bgdtL2n4NvKYn67rzzjsBOOecc3oyuyRJGmSy3pMzYH784x8DhhxJkoaKIRFydu7cyerVqzn44IOZNWvWvmeQJEmD3pAIOVu3bgXg5JNPJiIyrkaSJA2EIRVyXvva12ZciSRJGihDKuScdNJJGVciSZIGypAIOdu2bQNg9uzZGVciSZIGyn5fQh4Rw4HDgBpgXUppXZ9X1cdSShx11FHU15feQUKSJOVVt/bkRMToiPhkRDxA4YaafwGeAF6IiFUR8a8RcUDf1nv69OlZlyBJkgbQPkNORHwGWAl8BPglhRtiHg8cDZwCXElhj9AvI+LnEfGK/iq2Nww5kiQNLd05XHUy8IaU0hNd9D8CfDciPgF8FHgD8Oc+qq/PGHIkSRpa9hlyUkrvbfs5Ir4FfCGl9HKZcTuBb/RteX3HkCNJ0tCyv1dXTQX+EhEXRUSm973aX0cffXTWJUiSpAG0XyEnpfQW4EPAJ4EnIuLt/VJVH6uvr6e6ujrrMiRJ0gDa7+/JSSndDUwHbgZ+UDzZ+JV9XlkfmjZtWtYlSJKkAdajLwNMKe1KKd0AvAJYBTweETdGxAkRMaJPK5QkSeqB/TqvJiJqgTOAVwLHFKdXAsOBC4ALgdaIeCaldFzflipJktR9+3vy8K+B44DHgD8BPwH+L/AM8CxQTeE7dI7vuxIlSZL23/6GnJHAySml5V307wIeLk6SJEmZ2a+Q4yEoSZI0WAyJu5BLkqShpzv3rpra3YVFwRG9K0mSJKn3urMnZ1FEfCciTulqQESMjYhPAk9SuIGnJElSprpzTs4rgS8CP4mIVmAJsAbYAYwFXgUcS+FGnRellO7tp1olSZK6bZ97clJKm1JKnwMagK8CTwFjKNzHqgW4FXhNSuk0A44kSTpQdPvqqpTS9oj4GjA3pXRRP9YkSZLUa/t7dVUA/zsi/hQRT0fEbRHxpv4oTJIkqTd6cgn5JGA+cBswCvhxRNwcEV6OLkmSDhj7+43HAOemlH7T9iQipgH3AJdSOGdHkiQpc/u792U98FLHhpTSX4BPAx/rq6IkSZJ6a39Dzh+AeWXan6Nw9ZUkSdIBYX8PV10OLIyIBuAbwDKgBvgSsKKPa5MkSeqx/b1B5yMRcRJwI/BL/mdP0HbgPX1cmyRJUo/t94nHKaUngDdGxDjgBKAS+H1K6eW+Lk6SJKmnenJ1FQAppQ3AL/qwFkmSpD7jd9tIkqRcMuRIkqRcMuRIkqRcMuRIkqRcMuRIkqRcyjTkRERlRFwfEesiYktEzI+I8d2Y75MRkSLi8oGoU5IkDT5Z78m5DHgncBJweLHttr3NEBGTgc8Cy/u3NEmSNJhlHXLmAdellFaklDYDnwfeWgwyXfkO8EXALx+UJEldyizkRMQYYBKwpK0tpfQs0AjM7GKe84GtKaU7BqRISZI0aPX4G4/7wOji4+aS9k1AXengiJhE4QahJ3dn4RExj+Id0ydNmtTzKiVJ0qCU5eGqLcXH+pL2MRT25pS6GfinlNLq7iw8pXRTSml2Smn2hAkTelGmJEkajDILOSmlTcAqYFZbW0QcSWEvzrIys7wJuCYi1kfEeuA04B8j4sGBqFeSJA0uWR6uArgJuDQiFgIbgOuAe1NKK8uMPaLk+Q+BB4F/7tcKJUnSoJR1yLkWGAs8CgwHfgnMBYiIOcC3U0qjAFJKz3ecMSJ2Ao0ppRcHtGJJkjQoREop
6xr63ezZs9PixYuzLkOSJPWBiFiSUpq9r3FZf0+OJElSvzDkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXDLkSJKkXMo05EREZURcHxHrImJLRMyPiPFdjH17RPwqItZHxMaIeDAiXjfQNUuSpMEh6z05lwHvBE4CDi+23dbF2LHAvwDTgAnAD4CfRcQR/V2kJEkafLIOOfOA61JKK1JKm4HPA2+NiMmlA1NKt6eUFqSUNqWUWlJK3wSagBMHuGZJkjQIZBZyImIMMAlY0taWUnoWaARmdmP+GcB4YHkX/fMiYnFELF63bl3fFC1JkgaNLPfkjC4+bi5p3wTU7W3GiJgIzAduSCn9udyYlNJNKaXZKaXZEyZM6HWxkiRpcMky5GwpPtaXtI+hsDenrIg4DFgI/AL4x/4pTZIkDXaZhZyU0iZgFTCrrS0ijqSwF2dZuXkiYgrwIPCzlNKFKaXU/5VKkqTBKOsTj28CLo2IqRFRB1wH3JtSWlk6MCJeCTwE/HtK6ZKBLVOSJA02WYeca4G7gUeB1UAlMBcgIuZERFOHsZcCDcBFEdHUYZoz0EVLkqQDXwyFIz6zZ89OixcvzroMSZLUByJiSUpp9r7GZb0nR5IkqV8YciRJUi4ZciRJUi4ZciRJUi4ZciRJUi4ZciRJUi4ZciRJUi4ZciRJUi4ZciRJUi4ZciRJUi4ZciRJUi4ZciRJUi4ZciRJUi4ZciRJUi4ZciRJUi4ZciRJUi4ZciRJUi4ZciRJUi4ZciRJUi4ZciRJUi4ZciRJUi4ZciRJUi4ZciRJUi4ZciRJUi4ZciRJUi4Ny7oASZIGi8bGRl566SV27dqVdSm5VlVVxcSJE6mrq+vVcgw5kiR1Q2NjIy+++CINDQ3U1NQQEVmXlEspJbZv387q1asBehV0PFwlSVI3vPTSSzQ0NFBbW2vA6UcRQW1tLQ0NDbz00ku9WpYhR5Kkbti1axc1NTVZlzFk1NTU9PqwoCFHkqRucg/OwOmL99qQI0mScsmQI0mScsmQI0lSzpxxxhlEBA888MAe7dOmTeOWW27pl3VefvnlvOY1r6G6upozzzyz7Jjrr7+ehoYGRo4cyZlnnsmKFSv6pZY2hhxJknJo3LhxXHLJJaSUBmR9Rx11FFdffTXz5s0r23/77bdz/fXXc/fdd7Nu3Tpe9apXcfbZZ7N79+5+q8mQI0lSDn384x/n+eef59///d8HZH3nnXce73jHOxg/fnzZ/ptuuonzzz+fWbNmUVtbyzXXXMOKFSt46KGH+q0mQ44kST0UEQM67Y+RI0dy9dVX84UvfIGdO3d2a55Xv/rVjBkzpsupN4Fk6dKlnHDCCe3PR40axSte8QqWLl3a42XuiyFHkqScOu+88xg1ahQ33nhjt8YvW7aMTZs2dTmdfvrpPa5ly5Yt1NfX79E2ZswYGhsbe7zMfTHkSJLUQymlAZ32V2VlJddffz3XXHMNGzZs6Id3oPtGjx7N5s2b92jbtGlTr+9PtTeGHEmScuxtb3sbJ554IldfffU+xx533HGMGjWqy+nBBx/scR0zZ87ksccea3/e1NTEn//8Z2bOnNnjZe5LpjfojIhK4Frgw8AI4BfA+Sml9V2Mfyvwz8CRwLPAZ1JKvxiYaiVJGpxuuOEGTjrpJIYPH77XcX/84x97vI5du3axe/duWlpaaG1tZceOHURE+zrnzZvHZz7zGf7+7/+eV77ylVx++eVM
nTq1V4fA9iXru5BfBrwTOAnYAHwXuA14W+nAiDgS+BEwD/hP4H8BCyLiuJTSyoEqWJKkwWbmzJl84AMf6LfvyIHC1Vy33npr+/OamhomT57MypUrAZgzZw6rV6/m7/7u79i0aROnnHIKd911F5WVlf1WUwzU9fNlVx7xHHB1Suk7xedHAX8BpqSUnisZexXwtyml13VoexC4L6V01d7WM2PGjPTjH/94j7b6+nrGjRtHa2tr+wboaOzYsYwdO5aWlhZWrVrVqX/cuHHU19fT3NzM888/36l//Pjx1NXVsXPnzvbbxXc0ceJERo0axfbt21m7dm2n/kMOOYTa2lq2bdvGCy+80Kn/0EMPpaamhqamprJ3aW1oaGD48OE0Njayfn3nHWOHH3441dXVbN68uexx2kmTJjFs2DA2btzIxo0bO/VPmTKFiooKNmzY0OkYK8CRRx4JwPr16zudVFZRUcGUKVOAwl19m5qa9ugfNmwYkyZNAuCFF15g27Zte/RXVVVxxBFHALB27Vq2b9++R//w4cNpaGgAYPXq1Z2uKqipqeHQQw8F4L//+7873QCutraWQw45BIBVq1bR0tKyR/+oUaOYOHEiACtXrqS1tXWP/rq6uvZLKMt90ZWfPT974GdvMH72duzYwate9SpaWlrKfrdL2x6LXbt2ddo2ve2PCKqrqwFobm7udH5Ob/srKiqoqqoCKHslVm/7KysrGTZsGCklmpubu93/zDPPMHLkyE6fvaOOOmpJSml2pwWV1rWvAf0lIsYAk4AlbW0ppWeBRqDcAbqZHccWPdbFWCJiXkQsjojFvb2LqSRJGnwy25MTEUcAq4AjU0p/7dD+HPDFlNL3S8bfDzyUUrqiQ9tVwGkppfLfH100e/bstHjx4j6tX5I0tDz11FMce+yxWZcxpHT1nkfEgb0nB9hSfKwvaR9DYW9OufHdHStJkoa4zEJOSmkThT05s9raiicX1wHLysyytOPYotcU2yVJkvaQ9ffk3ARcGhFTI6IOuA64t4urpf4NmB0RH4iIqoj4AHACcGuZsZIk9blyJwyrf/TFe511yLkWuBt4FFgNVAJzASJiTkS0X/pQPCn5XcDlFA5RXQ78vZePS5IGwsiRI1m9enXZq5Ntx18NAAALmUlEQVTUd9qusFq9ejUjR47s1bIyvYR8oHjisSSpt1pbW1m/fj2bN2/udHm/+tawYcOor69n/PjxVFR03h/T3ROPs/4yQEmSBoWKigomTpzY/l1FOvBlfbhKkiSpXxhyJElSLhlyJElSLhlyJElSLhlyJElSLg2JS8gjYgvwp6zr0B7GA51vE6ysuD0OPG6TA4/b5MAxOaU0YV+Dhsol5H/qzvX0GjgRsdhtcuBwexx43CYHHrfJ4OPhKkmSlEuGHEmSlEtDJeTclHUB6sRtcmBxexx43CYHHrfJIDMkTjyWJElDz1DZkyNJkoYYQ44kScolQ44kScqlXIeciKiMiOsjYl1EbImI+RExPuu68ioi3h8RD0ZEY0S0lOl/a0T8MSK2R8QTEfHmkv5pEXFfRGyNiOcj4rMDV33+RMR1xfe7MSLWRMS/RsRBJWP+ISKejYhtEfH7iDihpH92RDxS7H82IuYO7KvIl4j4SkT8tbhNXoqI/4qISR363R4ZiYiKiPhtRKSIOLxDu9tkEMt1yAEuA94JnAS0fWhvy66c3NsIfAO4qLQjIo4EfgR8FagvPi6IiCnF/krgbuApYAJwNnBpRLxvIArPqd3AXGAcMJPCv4Fb2joj4nTgm8AngbHAfOCnEVFX7K8HflZsHwt8AvhWRJwycC8hd24Djk8p1QFTgFXAf4Db4wBwMbCtY4PbJAdSSrmdgOeAj3Z4fhSQKHwddOb15XUCzgBaStquAh4saXsQuKL4899Q+AUzqkP//wEWZv168jIBbwUaOzy/Fbitw/Og8Ef3Q8Xn5xX/DUWHMbcB38v6teRhAkYCNwAb3B6Zb4ujgWeB44t/Iw53
m+Rjyu2enIgYA0wClrS1pZSeBRop/K9WA2smHbZF0WP8z7aYCTyTUmrqol+990ZgaYfne2yTVPgN/Th7bpPHi+1t3Ca9FBHnRsRmoAn4NHBlscvtkYGIqAC+C1wCbCrpdpsMcrkNOcDo4uPmkvZNQN0A16LC9tjbtthXv3ohIt5NYVf6pzs0u00ykFL6QUqpHjiUQsBZXuxye2Tj08ALKaUFZfrcJoNcnkPOluJjfUn7GAp7czSwtrD3bbGvfvVQRPwv4F+Bs1NKj3XocptkKKX0AoXtck/xhHC3xwCLiGnAZ4ELuxjiNhnkchtyUkqbKBw7ndXWVjz5tQ5YllVdQ9hSOmyLotfwP4dPlgJHR8TILvrVAxFxHvBt4B0ppYUl3Xtsk4gICuckdNwmx5fM4zbpW8MonJtzGG6PLJxO4UKHJyJiPYVDTQDLIuIC3CaDX9YnBfXnBHwR+BMwlUK4+SHw86zryusEVAIjgDcDLcWfR1A4We8oCicWfwCoKj5uBaZ0mPcp4EaghsIvjheB92f9ugbrBPxvYANwYhf9p1M4L+SNQDWFcxJeBOqK/WOAdcDniv1vLI4/JevXNhgnCv+pvBCYWHx+OLAA+CuFsOP2GPhtUlvcDm3TyRROPJ4NjHKbDP4p8wL69cUV/nDeAKynsFvxR8D4rOvK6wR8uPgLonSaUux/K/BHYHvx8c0l808D7i+GoTXAJVm/psE8Fd/7XcVfuu1TyZh/AFYUt8kjwAkl/ScW27cXx83N+nUN1qkYcn4KvEQh4K8GbgeOcnscGBOFy/rbr65ymwz+yRt0SpKkXMrtOTmSJGloM+RIkqRcMuRIkqRcMuRIkqRcMuRIkqRcMuRI6pGIuCUi7sm6jo4i4p0R8eeIaImIW/pxPWdERIqI8f24jksiYmV/LV8aCgw50iBUDBgpIr5U0t7vf3wPcN8B5gOT2fM+XX3ttxTuPbWhH9chqZcMOdLgtQP4XERMyLqQvhQRVT2cbwwwDrg3pbQ6pVR648Q+k1JqTim9kPyiMemAZsiRBq+FwErgS10NKLdnJyKmFNtml4x5W0QsiYjtEfFgRBweEW+IiKUR0RQR90TEuDLruDwiXiyO+V5E1HToi4j4fEQ8W1zu8oiYW6aWD0TEryJiO3B+F69lbETcGhEbi8u6LyKOa3sNwMbi0F8Vl3lGF8upjojrIuL5iNgWEY9GxFvKvGdnRcQfImJH8X05oav3NSLqI+K2iHipOH5FRFzUYfykiFgQEVuK048i4vCSuj4fES8U38d/o3BbgdLaz4uIJ4vreCYiLo6Iig795xfbd0TE+oi4NyKGlXsfpKHAkCMNXq3AZcAnIuKoPljeVcBFwEnAWOAO4MvAPOAM4DjgypJ53gDMpHDPnndTuG/ZdR36/wn4KPAp4FXAV4FvR8TflSznq8A3imPu7KK+W4q1vRN4LYXbf/y8GKp+W6yPYh2HFtvK+V6x7nOB6cCtwN0RMbNk3A3ApRTuY7SCwt3Ca7tY5j8BM4CzgGOAj1C4bQPFEPJj4GDgb4rTYcCdxRs+EhHvLS7jCgo3hPwT8JmOK4iIjwPXUNgmx1K4e/alwAXF/tnA/6OwHY+hsE1+3kW90tCQ9X0lnJyc9n+i8Af/nuLPC4H/KP58BoV774wv97zYNqXYNrtkzFs6jLmw2DarQ9uVwBMlNWwCRnVomwvspHBn7ZEU7ufzupLa/z/gpyW1fHYfr/cVxXGv79BWD2wGPlZ8Pr445oy9LOcoCuFwUkn7ncA3St6POR36RxVf68dKxrS9z3cB3+1inW8CdlO8h1ux7chiHWcWn/8W+NeS+e4DVnZ4vgr4YMmYi4Aniz+/q/h+jM768+nkdKBM7saUBr9LgUURcX0vl7Osw88vFh+Xl7RNLJ0npdTU4fkiCndjPgoYTuEu9D+PiI7nrlRROMzW0eJ91HYshVCwqK0hpbQ5IpZT2PvTXbOAAJ4s7kRpMxz4VcnYjutq2se6vgn8
V/GQ1i+Bu1NKv+lQ+5qU0soOy1sREWuKy7uvOObmMuufBlA87+oICnvBvtlhzLDi66G43ueAv0bEvcAvgB+llLZ0UbOUe4YcaZBLKT0SEfOBrwH/p6S7tfjY8S96Vyf27uq42OKyS9v25xB329h3UNgL0dW6oHBX7p7an5N/K4rjTyxTw/YeF5DSzyJiMvA2CoeJfhIRP0wpnbevWbu5irb38hN0cRgupbQlImYBr6ew9+gfgWsi4sSU0ppurkfKFc/JkfLhC8DrgLeWtK8rPh7aoe34PlzvjIgY2eH5yUAz8CzwJIVDV5NTSn8pmZ7bz/U8ReH31SltDRFRR+E8mCf3YzmPUwh8h5SpaXXJ2JM7rGskhfN3nupqwSml9Sml21JKH6ZwHtKHImJ4cZ7DImJKh+UdSeG8nLban+q4vtL1p5ReBNYAR5Wp+y8dxrWklH6VUvpH4NUUDhmete+3Rcon9+RIOZBS+ktE3ETn74b5C/DfwJURcRmFc2Au78NVDwO+GxFXU/ijfS2Fc0u2AkTEDcANxRNsH6BwbsvJQGtK6aburiSl9OeI+DGFwzXzKJwf8xWgEfjBfiznmYi4HbglIj4LPAYcROEcmxUppR91GH55RKyjEC6+TCG8lV1X8fU/BvyRwnvyruLydkbEfRQOBd4eEW3b51+K49sOkd0I/FtEPAr8GngPhZOsX+6wmiuAf4mITcBPKeyRmwU0pJS+GhFnUThM+EBxvr8BRrOXYCblnXtypPy4Gmjp2FA83PR+Cie6LqVw5c0X+nCdv6Hwh30hsIDCH+3Pd+j/EoUTli8pjvslhauf/tqDdZ0HPELhJN9HgFrgrSml/T3MdB6FK6y+BjwN3EPhEE/p3qXLgH+mEEZeAZzVFt7K2EkhdC0FHqYQLt4BkFJKFK4IW0fhfVoIvACcU+wjpXQHhffpKxT2Ns0A/m/HFaSUbqZw1dYHi+t5kMKVb23v5SbgHArn+DxN4T3/WErpwW6+L1LuRPHfmCSJ9u/cWQhMSCmtz7gcSb3gnhxJkpRLhhxJkpRLHq6SJEm55J4cSZKUS4YcSZKUS4YcSZKUS4YcSZKUS4YcSZKUS4YcSZKUS/8/fCmhBG7fbD4AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 648x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# One sig_alpha value per frame returned by train_alpha, gathered into a NumPy array.\n",
    "alphas = np.asarray([record.sig_alpha for record in alpha_frames])\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(9, 5))\n",
    "\n",
    "# Light-gray dashed guide lines at y = 1 and y = 0.\n",
    "ax.axhline(1, c='lightgray', ls='--')\n",
    "ax.axhline(0, c='lightgray', ls='--')\n",
    "# NOTE(review): the legend text is hard-coded to 'N = 10'; confirm which setting it refers to.\n",
    "ax.plot(alphas, lw=2, color='k', label='N = {0}'.format(10))\n",
    "\n",
    "# The x axis spans every alpha-training iteration (0 .. ALPHA_NUM_ITER - 1).\n",
    "ax.set_xlim([0, opt.ALPHA_NUM_ITER - 1])\n",
    "ax.set_xlabel('Number of episodes', fontsize=14)\n",
    "ax.set_ylabel(r'$\\sigma(\\gamma)$', fontsize=14)\n",
    "ax.tick_params(axis='both', which='major', labelsize=13)\n",
    "ax.legend(loc=4, prop={'size': 13})\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
