{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.insert(0, '../..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from tqdm import tnrange, tqdm_notebook\n",
    "from tqdm.auto import trange\n",
    "\n",
    "from causal_meta.utils.data_utils import generate_data_categorical\n",
    "from models import StructuralModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Size of the categorical distributions; passed to np.random.dirichlet in the\n",
    "# training loop and to StructuralModel here.\n",
    "N = 10\n",
    "\n",
    "# Seed the global NumPy and torch generators before anything random happens so\n",
    "# that Restart & Run All repeats the same np.random.dirichlet draws.\n",
    "# NOTE(review): full reproducibility also assumes generate_data_categorical and\n",
    "# StructuralModel only draw from these two generators -- TODO confirm.\n",
    "SEED = 0\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "model = StructuralModel(N, dtype=torch.float64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inner loop: plain SGD over whatever model.modules_parameters() returns\n",
    "optimizer = torch.optim.SGD(\n",
    "    params=model.modules_parameters(),\n",
    "    lr=0.1,\n",
    ")\n",
    "# Outer loop: RMSprop over the single tensor model.w\n",
    "meta_optimizer = torch.optim.RMSprop(\n",
    "    params=[model.w],\n",
    "    lr=0.01,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Outer loop sizes; the commented values (10 and 100) are presumably those of\n",
    "# the full experiment -- TODO confirm.\n",
    "num_runs, num_training = 1, 1  # 10, 100\n",
    "# Transfer episodes per training distribution, and SGD steps per episode\n",
    "num_transfer, num_gradient_steps = 1000, 2\n",
    "\n",
    "# Batch sizes (only transfer_batch_size is read by the cells visible here)\n",
    "train_batch_size, transfer_batch_size = 1000, 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4c1ad719f3454d1fb51e02767d45188b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=1), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=1), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# alphas[j, i, k]: sigmoid(model.w) logged after transfer episode k of\n",
    "# training distribution i in run j\n",
    "alphas = np.zeros((num_runs, num_training, num_transfer))\n",
    "# All-ones Dirichlet concentration shared by every distribution sampled below\n",
    "concentration = np.ones(N)\n",
    "\n",
    "# NOTE(review): only model.w is reset at the start of a run; the state kept by\n",
    "# meta_optimizer carries over from one run to the next -- confirm this is\n",
    "# intended before raising num_runs above 1.\n",
    "for j in trange(num_runs):\n",
    "    # Restart the structural parameter at zero, i.e. sigmoid(w) = 0.5\n",
    "    model.w.data.zero_()\n",
    "    for i in trange(num_training, leave=False):\n",
    "        # Step 1: Sample a joint distribution before intervention\n",
    "        pi_A_1 = np.random.dirichlet(concentration)\n",
    "        pi_B_A = np.random.dirichlet(concentration, size=N)\n",
    "\n",
    "        transfers = trange(num_transfer, leave=False)\n",
    "        for k in transfers:\n",
    "            # Step 2: Hand the pre-intervention distribution to the model at the\n",
    "            # start of every episode (presumably this stands in for training the\n",
    "            # modules on it -- verify in StructuralModel.set_ground_truth)\n",
    "            model.set_ground_truth(pi_A_1, pi_B_A)\n",
    "            # Step 3: Sample a joint distribution after intervention; only the\n",
    "            # marginal of A is redrawn, pi_B_A is reused\n",
    "            pi_A_2 = np.random.dirichlet(concentration)\n",
    "\n",
    "            # Step 4: Do num_gradient_steps steps of gradient descent for\n",
    "            # adaptation on the distribution after intervention\n",
    "            model.zero_grad()\n",
    "            loss = torch.tensor(0., dtype=torch.float64)\n",
    "            for _ in range(num_gradient_steps):\n",
    "                x_train = torch.from_numpy(generate_data_categorical(transfer_batch_size, pi_A_2, pi_B_A))\n",
    "                # Accumulate the meta-loss on this batch before the modules are\n",
    "                # updated on it\n",
    "                loss += -torch.mean(model(x_train))\n",
    "                # Inner update: both sub-models take one SGD step on the batch\n",
    "                optimizer.zero_grad()\n",
    "                inner_loss_A_B = -torch.mean(model.model_A_B(x_train))\n",
    "                inner_loss_B_A = -torch.mean(model.model_B_A(x_train))\n",
    "                inner_loss = inner_loss_A_B + inner_loss_B_A\n",
    "                inner_loss.backward()\n",
    "                optimizer.step()\n",
    "\n",
    "            # Step 5: Update the structural parameter model.w from the\n",
    "            # accumulated meta-loss\n",
    "            meta_optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            meta_optimizer.step()\n",
    "\n",
    "            # Log the values of alpha\n",
    "            alpha = torch.sigmoid(model.w).item()\n",
    "            alphas[j, i, k] = alpha\n",
    "            transfers.set_postfix(alpha='{0:.4f}'.format(alpha), grad='{0:.4f}'.format(model.w.grad.item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjkAAAFHCAYAAABDK11BAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xt8VOW97/HPLyTkQkgIBAyXhMhNhSqIseKFlqIVRQXrhX2sWFFavPZsOfWC17J73CKire5a6xZUkKq1ipSqLYgVLy0oFxEoSBDYIRAuSSDkAoEE8pw/ZpKTTBJIIMnKrHzfr9d6TeZ5nlnzm1kj8/VZa80y5xwiIiIifhPhdQEiIiIizUEhR0RERHxJIUdERER8SSFHREREfEkhR0RERHxJIUdERER8SSFHREREfEkhR0RERHxJIUdERER8KdLrAlpCcnKyS09P97oMERERaQKrVq3Kd851Pd64NhFy0tPTWblypddliIiISBMws20NGafdVSIiIuJLCjkiIiLiSwo5IiIi4ksKOSIiIuJLCjkiIiLiSwo5IiIi4ksKOSIiIuJLnoYcM/tfZva5mRWZ2ZEGjM8ws+VmdtDMtpjZ+JaoU0RERMKP1zM5BcALwD3HG2hmicDfgHlAEnA78KKZnd+sFYqIiEhY8vQXj51ziwDMbEQDhl8DHASecs45YLGZzQcmAcuO9cDDhw+zdevWGm2JiYl06dKFiooKsrKyaj0mKSmJpKQkjhw5QnZ2dq3+Ll26kJiYSFlZGTt27KjVn5ycTEJCAocPHyYnJ6dWf7du3YiPj6e0tJRdu3bV6k9JSSEuLo6DBw+ye/fuWv3du3cnNjaWkpIScnNza/X37NmT6OhoioqKyM/Pr9Xfq1cv2rdvT2FhIXv37q3Vn5aWRmRkJAUFBRQUFNTqT09PJyIigr1791JYWFirv0+fPgDk5+dTVFRUoy8iIoLKy2zk5uZSUlJSoz8yMpK0tDQAdu/ezcGDB2v0R0VFkZqaCsCuXbsoLS2t0R8dHU3Pnj0ByMnJ4fDhwzX6Y2Nj6d69OwDbt2+nvLy8Rn9cXBwpKSkAZGdnc+RIzUnG+Ph4unXrBkBWVhYVFRU1+hMSEkhOTgao9bkDffb02dNnT589ffZCnexnrz5ez+Q0xmBgdTDgVPoq2F6LmU0ys5VmtjJ0Y4qIiIj/Wc3M4FERgZmcj5xz9c4smdnLQKRz7uZqbbcADzvn+h1r/RkZGU7XrhIREQHnHEePHqW8vLzGUlZW1qC2xoytbDty5EiT3R45coTc3NxVzrmM473WcLpAZzGQHtLWCSiqPVRERKTlVFRUcOjQIQ4ePMjBgwcpLS2t98u/vvtlZWWUlZVx+PDhY96eaNCovrSGCY6WEE4hZw1wdUjb2cF2ERGRGsrLy2uEjsOHD1cFhbKyMkpLS6v6qy/1tR+r79ChQ16/3EYxM9q3b09UVFTVEnr/WO2NaYuMjKy6X/n3yd5WHj90PJ6GHDNrB0QB7YP3Y4Jdh13tmDkfeMrM7gOeA4YTOBj5hy1UroiINJGKigoOHjzIgQMHqoJC5d8lJSUcOHCg6rb6mMYsoQfPNrfY2Fji4uKIjY0lJiam6kv/WGEitC86Opro6Gjat29f721ThI927dq16HvjFa9ncm4CXq12v/Jw8VPNLJXAKeMDnXPZzrn9ZjYa+B3wK2AXcLtz7phnVomIyMlxzlFaWkphYWHVUhk8SktLq5aSkhKKiopqLMXFxTUCS3FxcVVbc4uIiKBDhw7ExcURExNDTExMVVBo3759VSCJi4urc2lMX0xMDBER4XQuT9vg9Snks4HZ9XRnAfEh41cA323WokREfObQoUM1AkphYSH79++v1VZf+/79+5tlViQuLq4qhFT/u0OHDnTo0IH4+Hji4+Or7tcXOOoLH1FRUZhZk9ct4cPrmRwRETmOiooKSkpKqpbKY0ry
8vLYs2dP1bJ7925yc3MpKCioEVJCfzPlRERHR9OpUycSExNJTEwkPj6e2NjYqiUmJoaOHTuSkJBQY+nYsWONoNKxY0c6duxIhw4dNPMhzU4hR0SkhVRUVFBUVERubm5VMMnPzycvL6/GbX5+ftVuncrdPCcjKiqqRkCpXOpqq689Ojq6id4FkZajkCMicoIOHDhATk4OO3furLHs3r27xm6foqIiCgsLKS4uPuFTdytnQTp06FB1XEmXLl1ISUnhlFNO4ZRTTiElJYVu3brRuXPnGgElJiZGu22kTVLIEREJqqioIC8vj127drF9+3ays7PZsWMHubm55OXlsW/fvqqf/N+3b98J7Qaq/Hn8ymDStWtXunbtSnJyctVt5U/YV9/Vo107Io2nkCMivldRUUF+fj65ubk1lj179lTNvuTk5LBly5ZG/d5J5fWCevbsSffu3enRowc9evQgJSWFpKSkqpmUhISEqtu2cuquSGugkCMivnDo0CE2bdpEVlYW27ZtY/PmzXz77bds3bqVrKysBs+6dO7cmR49etCzZ0969+5NamoqKSkpJCcn07lzZzp37lx1McHY2FjtBhJpxRRyRKTVq6ioYP/+/WzcuJGdO3eSn59PTk4O2dnZ7Ny5k+zsbDZv3lzrysjVde7cueqYla5du9KtWze6detWNfvSo0cP0tPT6dSpUwu+MhFpTgo5IuKpkpISduzYwfbt29m+fTtZWVmsWLGCPXv2UFBQUHUA7/EO2I2IiGDAgAH069eP1NRU+vbtS//+/enbty/p6el07NixhV6RiLQWCjki0qxKSkrYs2cPO3bsYMOGDZSUlJCTk8OaNWtYu3Yt+/bta9B6OnbsSL9+/UhPT6dLly707NmT1NTUqmNi+vfvT0xMzPFXJCJthkKOiDSJo0ePkp2dTWZmJpmZmWzcuJG1a9fyxRdfHHM3UnR0NL169aJXr16kpaWRmprKkCFD6NOnD0lJSXTq1ImEhAQiI/XPlYg0jv7VEJFGOXToEOvXr2fVqlVkZmaSnZ3Nxo0b2bRpE2VlZbXGR0VFkZaWRpcuXTjrrLPo0qVL1d+DBw+mR48eOnhXRJqFQo6I1CsvL48vv/ySr7/+mtWrV7NmzRq2bt1a7/ExPXv2ZMCAAZx22mlVywUXXEBiYmILVy4iopAjIkH79u1jyZIlfPXVV2zYsIFvvvmGzMzMWuPatWtH//79GTp0KN/5znfo3bt3VbDRwb0i0poo5Ii0ESUlJVW/I7N37172799PZmYmmzZtYufOnXXO0MTExDBs2DCGDh3K2WefzZAhQxgwYADt27f36FWIiDScQo6ID+Xk5PDFF1/wzTffsGLFClasWEFubi5Hjx6t9zHt27fnggsu4MILL2TQoEGcccYZnH766TpjSUTClkKOSJjav38/y5YtY/369axfv54NGzawZcsWSktLOXjwYK3xZsagQYPo06cP3bp1o2PHjvTt25fTTz+d1NRU0tLSiI2N9eCViIg0D4UckTCRn5/P+++/z7Jly1i3bh3Lly+vd2YmMTGR8847j8GDBzNw4ECGDx9Or169iI6ObuGqRUS8o5Aj0grl5uaydOlS1q5dy7p161i3bh2bNm2qccxMZGQkF154IWeffTaDBg1i4MCBDBgwgA4dOuiq1SIiKOSItBpZWVm8/PLLvPfee6xZs6ZWf1RUFCNHjuTiiy/m9NNPZ8SIETqbSUTkGBRyRDzinOObb77hD3/4A7NmzSIvL6+qLyYmhgsuuIChQ4dy1llnceaZZ3LGGWdod5OISCMo5Ii0oE2bNjFr1iy+/vpr1qxZQ25ublVfp06duOSSS5g0aRLDhw/XWU0iIidJIUekGe3YsYOXX36ZL7/8kszMTLZu3Vqj/5RTTuGSSy7h5ptv5pJLLtHlDUREmpBCjkgT2b59O1u2bCErK4tPP/2UZcuW1frF4IiICMaOHcvNN9/MWWedRXp6uoKNiEgzUcgROQm7d+/m
1Vdf5e2332b16tW1+uPj4xk5ciTjx49n0KBB9O3bV8fViIi0EIUckUbKzs5mzpw5vPHGG2zcuLGqPS4ujqFDh5KSksJ5553HRRddxDnnnENUVJSH1YqItF0KOSINcOTIkaqzoJYuXVr1ezUxMTGMGjWKn/3sZ4wcOVK/GCwi0ooo5IjUo7y8nM8//5xp06axdOnSqkslREdHc/XVV3PLLbcwcuRIzdSIiLRSCjki1ZSVlfH3v/+dBQsWsGDBAnbv3l3Vd9pppzFlyhSuueYaEhISPKxSREQaQiFH2rw9e/Ywc+ZMXnvtNb799tsafenp6Vx55ZVMmTKFnj17elShiIicCIUcaZMqKiqYNWsW//Vf/8X69etr9J111lmMGTOGsWPHcvbZZ9OuXTuPqhQRkZOhkCNtSm5uLs899xxvvPEGWVlZQOBClxdffDGTJ09mxIgROsVbRMQnFHLE95xzLFq0iFmzZrFw4UIOHDgAQGpqKg8++CC33HKLLqEgIuJDCjniW2VlZfzlL3/h+eef59NPP61qv+KKK7j33nv53ve+R0REhIcViohIc1LIEV9xzrFx40YWLVrECy+8UHUgcVJSEvfeey+jR49myJAhHlcpIiItQSFHfCEvL48XXniBV199lW3btlW19+3bl7vvvpsJEybQqVMnDysUEZGWppAjYSs/P5+33nqLDz/8kI8++qjqx/q6du3KpZdeyqWXXsp1111HXFycx5WKiIgXFHIkrJSWlvLaa6+xcOFCFi5cyKFDh6r6fvjDHzJ58mRGjRqlY21EREQhR8JDeXk5r7/+Og8//DA7d+6sah81ahRXXXUVV111FWlpaR5WKCIirY1CjrRqR48e5Y9//CNTp05l8+bNAAwaNIhbb72Va665hvT0dG8LFBGRVsvTkGNm7YAngQlADPAhcJtzLr+e8fcCdwDdgN3Ab5xzL7RMtdIS1q9fz7Jly3j//fdZvXo1eXl5lJaWAtC/f3/uuusufv7zn2t3lIiIHJfXMzlTgLHAecBe4BVgLnB56EAzGwP8B3Cxc+4LMzsf+MjMvnXOLW7BmqUZbNq0iWeeeYaZM2finKvRd+qpp/Loo49y0003ERnp9UdWRETChdffGJOAXznntgKY2f3AZjPr7ZzbFjK2H7DGOfcFgHNumZmtBQYDCjlhyjnHAw88wIwZM4DAJRbGjBnDd7/7Xa655hpSUlKIj4/HzDyuVEREwo1nIcfMOgFpwKrKNufcFjMrIhBcQkPOH4FbzexCYBlwITAAWNgyFUtT2rdvHy+//DKzZ89mw4YNAIwePZoZM2YwcOBAj6sTERE/8HImp2PwtjCkfT+QUMf4XOAdYAlQeUDGPc65f9W1cjObRGCmSGfdtCIfffQRM2fO5K9//SslJSUAdO7cmVmzZvGjH/3I4+pERMRPvAw5xcHbxJD2TkBRHeMfBW4AhgDfAAOBv5hZqXPu5dDBzrmXgJcAMjIyXGi/tIxDhw6xdOlSNmzYwPPPP09mZmZV34UXXsiUKVMYNWoUUVFRHlYpIiJ+5FnIcc7tN7NsYCjwNYCZ9SEwi7O2joecA8x3zm0I3l9vZn8GrgJqhRzx3ocffsikSZNqXGYhKiqK++67j+uuu44hQ4boWBsREWk2Xh94/BLwgJktIXB21XRgkXMuq46x/wQmmNks59y3ZnYGcDUwu6WKlYZZs2YNN998M2vWrAECZ0cNGzaM0aNHc9111xETE+NxhSIi0hZ4HXKeBJKAFUA0gbOkxgOY2Y3Afzvn4oNjZxDYtbXYzJKBfcDbwXWIx0pLS3nrrbf44IMPWLBgAeXl5cTGxvLoo49y33336dRvERFpcRb6myR+lJGR4VauXOl1Gb61bt06Ro8ezY4dO6raxo0bx6uvvqqLY4qISJMzs1XOuYzjjdP/XstJycrKYsyYMezYsYMzzzyTCRMmcNppp3HZZZfRrl07r8sTEZE2TCFHTlhBQQGXX345WVlZnHvuuXz2
2Wc63kZERFoNXQBIGu3o0aP88pe/pG/fvmzcuJHvfOc7LF68WAFHRERaFc3kSKOUl5czYcIE3njjDQDOOecc5s2bR2Ji6M8diYiIeEszOdJgX375Jeeffz5vvPEG8fHxLFq0iBUrVtC7d2+vSxMREalFIUeOa9++fUycOJFhw4axatUqevXqxUcffcSll16qH/MTEZFWS7urpF4bNmxg7ty5vPPOO2zevJmoqCjuvfdeHnroIeLj44+/AhEREQ8p5EgtBQUFTJw4kQULFlBRUQHAwIEDmT9/PgMGDPC4OhERkYZRyJEqhYWFvPjiizzxxBMUFRXRrl07xo0bx2WXXca4cePo0KGD1yWKiIg0mEKOALBkyRLGjRtHfn4+ABkZGcyZM4eBAwd6XJmIiMiJ0YHHwu9+9ztGjhxJfn4+5557LosWLWL58uUKOCIiEtYUctqwgoICJk+ezN133w3AlClT+Oyzz3TWlIiI+IJ2V7VRX375JRMnTmT9+vVAIOBMmzbN46pERESajkJOG5OXl8e0adN49tlncc7Rp08f5s2bx5AhQ7wuTUREpElpd1Ub8s477zBo0CB+85vf4Jxj8uTJLF++XAFHRER8STM5bcTs2bP56U9/ytGjR7nwwgv5zW9+w7nnnut1WSIiIs1GMzk+55xj4cKF3HnnnRw9epQHHniAzz77TAFHRER8TyHHx+bNm0evXr24/PLLKS0t5YYbbuDJJ58kIkKbXURE/E/fdj70zTffMHr0aK677jp27txJ165deeKJJ5g7d67XpYmIiLQYHZPjI845pk6dyrRp0ygvL6ddu3Y89thjPProo/rdGxERaXMUcnygoqKCV155hTlz5vCPf/wDgH/7t3/j8ccfp1+/fh5XJyIi4g2FHB948sknefjhhwGIi4vjzTffZMyYMR5XJSIi4i0dkxPmnn322aqA89vf/pYdO3Yo4IiIiKCZnLBWUVHBE088AcBjjz1WdQ0qERER0UxO2HLO8cgjj5CXl0fv3r2ZOnWq1yWJiIi0Kgo5YWrWrFlMmzaNiIgIpk6dqrOnREREQmh3VRj6+uuvuf/++4HA5RpuuukmjysSERFpfRRywsyyZcv4/ve/T3l5OZdddhnjx4/3uiQREZFWSburwsiOHTu45pprKC8v5/LLL+ftt9/WbioREZF6KOSEgbKyMu68805SU1PZvXs3Z511Fn/84x+Jj4/3ujQREZFWSyGnlXPO8dhjj/H73/8egFNOOYUFCxaQkJDgcWUiIiKtm0JOK/fggw8yffp0AN59911ycnJIT0/3tigREZEwoAOPWynnHK+//jrTp0/HzHj99df50Y9+5HVZIiIiYUMzOa3QoUOHuO2226pODR87diw33HCDx1WJiIiEF83ktDLbtm3jiiuuYP369URGRvL4449z1113eV2WiIhI2FHIaSUyMzN5++23eeqppyguLiY5OZlXXnmFq666yuvSREREwpJCTivwzDPPcO+991bdv/rqq5k5cybJyckeViUiIhLedEyOx+bMmVMVcK699loWL17M/PnzFXBEREROkmZyPFBQUMDChQtZtWoVzzzzDADTp0+vuh6ViIiInDyFnBZ24MABzj//fDIzM6vafvzjH3PPPfd4WJWIiIj/eLq7yszamdkMM8szs2Izm2dm9e6nMbNuZjbHzPaaWZGZfW1mPVqy5pM1efLkqoBz6aWX8vrrr/OHP/yB9u3be1yZiIiIv3g9kzMFGAucB+wFXgHmApeHDjSzGODvwBfAacA+4AygpKWKPVl/+tOfmDlzJtHR0axYsYIzzzzT65JERER8y+sDjycB051zW51zhcD9wGVm1ruOsTcDnYA7nXP5zrkK59x651xRSxZ8Ig4dOsS0adO49dZbAXj66acVcERERJqZZyHHzDoBacCqyjbn3BagCBhcx0N+AHwLzA7urtpoZpOPsf5JZrbSzFbm5eU1cfUNV1xczJgxY3jooYc4cOAA48eP
14/7iYiItAAvZ3I6Bm8LQ9r3A3VdYjuZQNBZDnQHxgMPm9mNda3cOfeScy7DOZfRtWvXJiq5cY4cOcJtt93G4sWL6dy5M/Pnz2fOnDmYmSf1iIiItCVehpzi4G1iSHsnArM5dY3Pcc4955wrc86tBP5A4JieVicvL4/rrruON998k6ioKP7+979z9dVXExHh9R5CERGRtsGzb1zn3H4gGxha2WZmfQjM4qyt4yFfA66uVTVLgSfh888/Z/DgwSxYsIDY2FgWLlzIkCFDvC5LRESkTfF6WuEl4AEzO9XMEoDpwCLnXFYdY2cDXczsruCp54OBG4F3W6zaBigsLGTs2LHs2rWL4cOHs3btWkaOHOl1WSIiIm2O1yHnSeA9YAWQA7QjcKwNZnajmVWdHu6c2waMBn5KYHfWO8BU59xbLV10fY4ePcq4ceMoKCjgoosu4uOPP6Zfv35elyUiItImmXOtbm9Pk8vIyHArV65s1uf485//zF133cXOnTtJSEhg6dKlDBo0qFmfU0REpC0ys1XOuYzjjfP6xwB94U9/+hM33HADFRUVJCQkMGvWLAUcERERj3m9uyrsZWZm8pOf/ISKigp+/vOfs2fPHq6//nqvyxIREWnzNJNzEioqKpg4cSKHDx9mwoQJPPfcc/oNHBERkVZCMzknqKKignvuuYd//vOfpKSk8Otf/1oBR0REpBXRTM4JOHLkCBMnTuS1116jffv2zJ49m6SkJK/LEhERkWo0k9NIn3zyCWeccQavvfYaHTp04IMPPmDUqFFelyUiIiIhGj2TY2bRQA8gFshzznl39csWVlBQwNixYykqKqJLly68//77DBs2zOuyREREpA4Nmskxs45mdoeZfUbggpqbgX8Bu80s28xmmtm5zVloa/D8889TVFTEsGHD2Lp1qwKOiIhIK3bckGNm/wfIAm4FFhO4IOYQYABwPjCVwIzQYjNbaGb9m6tYL61bt44ZM2YA8MQTT5CQUNeF0kVERKS1aMjuqmHA951z/6qnfznwipndDkwEvg9820T1tQrvvvsud9xxB8XFxYwbN44RI0Z4XZKIiIgcx3FDjnNuXOXfZvYi8JBzbl8d4w4DLzRted5bunQp119/PRUVFVxyySXMmTNHp4qLiIiEgcaeXXUqsNnM7jEz359+vnr1aq6++moqKiq44447WLhwITExMV6XJSIiIg3QqJDjnBsF3AzcAfzLzEY3S1WtQFlZGddffz15eXmMGDGCZ599lnbt2nldloiIiDRQo2djnHPvmdlC4N+BN8zsC+Ae59zGJq/OA/v27WPBggV88sknbNmyhdNPP52FCxfSvn17r0sTERGRRjihXU7OuXLgaTObA/wnsNrMXgJeA9Y75w41YY0tZtWqVVx77bVs27YNADNjxowZREdHe1yZiIiINFajQo6ZxQEjgNOB04LL6UA0cCdwN1BhZpucc4OattTmNXv2bG677TbKysro1asXw4YNY8KECVxxxRVelyYiIiInoLEzOZ8Ag4CvgEzgA+DXwCZgC9CewG/oDGm6EptXUVERL774IlOmTME5x+23386zzz6r2RsREZEw19iQ0wEY5pxbV09/OfDP4NJqOedYsWIF999/P59++mlV+3333cdTTz3lYWUiIiLSVBoVcsJtF1RdPvvsM5566ik++OCDqrazzz6bX/ziF9xwww0eViYiIiJNyfe/dVPdkiVLuPjii3HO0aFDB66//nquuuoqxowZQ2Rkm3orREREfO+43+xmdqpz7n8asjIL/BRwL+fc9pOurIlt3ryZG2+8EeccV155JS+88AKpqalelyUiIiLNpCE/BrjMzF42s/PrG2BmSWZ2B7CBwAU8W5XS0lIyMjLYtWsXw4cPZ/78+Qo4IiIiPteQfTSnAw8DH5hZBbAK2AkcApKAgcAZBC7UeY9zblEz1XrCvv32W8rLy/nBD37Am2++qV1TIiIibYA55xo20CyWwG/hpAK9gVggH1gN
LDrGVco9Z2Zu+PDhfPjhh7r2lIiISJgzs1XOuYzjjmtoyAmu9Cgw3jn35skU19JiY2NdTk4OnTt39roUEREROUkNDTmNvQq5Af/bzDLNbKOZzTWzH55YiS2nf//+CjgiIiJtTGNDDkAaMA+YC8QDC8xslpmdyLpahC6uKSIi0vacyBG4P3bOVf1MsJn1A94HHgCmNVVhIiIiIiejsbMv+UBu9Qbn3Gbg34GfNlVRIiIiIiersSHna2BSHe3bgJ4nX46IiIhI02js7qpHgCVm1hN4AVhL4FTyR4GtTVybiIiIyAlr7AU6l5vZecBzwGL+/0xQKXBdE9cmIiIicsIafeBx8Ef/LjazLsA5QDvgS+fcvqYuTkREROREnfD1DZxze4EPm7AWERERkSbTan/bRkRERORkKOSIiIiILynkiIiIiC8p5IiIiIgvKeSIiIiIL3kacsysnZnNMLM8Mys2s3lmltyAx91hZs7MHmmJOkVERCT8eD2TMwUYC5wH9Aq2zT3WA8ysN/ALYF3zliYiIiLhzOuQMwmY7pzb6pwrBO4HLgsGmfq8DDwM6McHRUREpF6ehRwz6wSkAasq25xzW4AiYHA9j7kNOOCce6tFihQREZGwdcK/eNwEOgZvC0Pa9wMJoYPNLI3ABUKHNWTlZjaJ4BXT09LSTrxKERERCUte7q4qDt4mhrR3IjCbE2oW8LhzLqchK3fOveScy3DOZXTt2vUkyhQREZFw5FnIcc7tB7KBoZVtZtaHwCzO2joe8kPgCTPLN7N84ELgQTP7vCXqFRERkfDi5e4qgJeAB8xsCbAXmA4scs5l1TE2NeT+28DnwDPNWqGIiIiEJa9DzpNAErACiAYWA+MBzOxG4L+dc/EAzrkd1R9oZoeBIufcnhatWERERMKCOee8rqHZZWRkuJUrV3pdhoiIiDQBM1vlnMs43jivfydHREREpFko5IiIiIgvKeSIiIiILynkiIiIiC8p5IiIiIgvKeSIiIiILynkiIiIiC8p5IiIiIgvKeSIiIiILynkiIiIiC8p5IiIiIgvKeSIiIiILynkiIiIiC8p5IiIiIgvKeSIiIiILynkiIiIiC8p5IiIiIgvKeSIiIiILynkiIiIiC8p5IiIiIgvKeR8GD8EAAARCElEQVSIiIiILynkiIiIiC8p5IiIiIgvKeSIiIiILynkiIiIiC8p5IiIiIgvKeSIiIiILynkiIiIiC8p5IiIiIgvKeSIiIiILynkiIiIiC8p5IiIiIgvKeSIiIiILynkiIiIiC8p5IiIiIgvKeSIiIiILynkiIiIiC8p5IiIiIgvKeSIiIiILynkiIiIiC95GnLMrJ2ZzTCzPDMrNrN5ZpZcz9jRZvaxmeWbWYGZfW5mw1u6ZhEREQkPXs/kTAHGAucBvYJtc+sZmwT8FugHdAXeAP5mZqnNXaSIiIiEH69DziRgunNuq3OuELgfuMzMeocOdM697pyb75zb75w74pz7PVACnNvCNYuIiEgY8CzkmFknIA1YVdnmnNsCFAGDG/D4M4FkYF09/ZPMbKWZrczLy2uaokVERCRseDmT0zF4WxjSvh9IONYDzawbMA942jn3bV1jnHMvOecynHMZXbt2PeliRUREJLx4GXKKg7eJIe2dCMzm1MnMegBLgA+BB5unNBEREQl3noUc59x+IBsYWtlmZn0IzOKsresxZpYOfA78zTl3t3PONX+lIiIiEo68PvD4JeABMzvVzBKA6cAi51xW6EAzOx34B/Cmc+7eli1TREREwo3XIedJ4D1gBZADtAPGA5jZjWZWUm3sA0BP4B4zK6m23NjSRYuIiEjrZ21hj09GRoZbuXKl12WIiIhIEzCzVc65jOON83omR0RERKRZKOSIiIiILynkiIiIiC8p5IiIiIgvKeSIiIiILynkiIiIiC8p5IiIiIgvKeSIiIiILynkiIiIiC8p5IiIiIgvKeSIiIiILynkiIiIiC8p5IiIiIgvKeSIiIiI
LynkiIiIiC8p5IiIiIgvKeSIiIiILynkiIiIiC8p5IiIiIgvKeSIiIiILynkiIiIiC8p5IiIiIgvKeSIiIiILynkiIiIiC8p5IiIiIgvKeSIiIiILynkiIiIiC8p5IiIiIgvKeSIiIiILynkiIiIiC8p5IiIiIgvKeSIiIiILynkiIiIiC8p5IiIiIgvKeSIiIiILynkiIiIiC8p5IiIiIgvKeSIiIiILynkiIiIiC95GnLMrJ2ZzTCzPDMrNrN5ZpZ8jPGXmdl6Mys1s3+Z2aUtWa+IiIiED69ncqYAY4HzgF7Btrl1DTSzPsC7wDQgMXg738zSm71KERERCTuRHj//JOBXzrmtAGZ2P7DZzHo757aFjL0ZWOWc+0Pw/utmdnuw/T+O9SSHDx9m69atNdoSExPp0qULFRUVZGVl1XpMUlISSUlJHDlyhOzs7Fr9Xbp0ITExkbKyMnbs2FGrPzk5mYSEBA4fPkxOTk6t/m7duhEfH09paSm7du2q1Z+SkkJcXBwHDx5k9+7dtfq7d+9ObGwsJSUl5Obm1urv2bMn0dHRFBUVkZ+fX6u/V69etG/fnsLCQvbu3VurPy0tjcjISAoKCigoKKjVn56eTkREBHv37qWwsLBWf58+fQDIz8+nqKioRl9ERATp6ekA5ObmUlJSUqM/MjKStLQ0AHbv3s3Bgwdr9EdFRZGamgrArl27KC0trdEfHR1Nz549AcjJyeHw4cM1+mNjY+nevTsA27dvp7y8vEZ/XFwcKSkpAGRnZ3PkyJEa/fHx8XTr1g2ArKwsKioqavQnJCSQnByYkAz93IE+e/rs6bOnz54+e6FO9rNXH89mcsysE5AGrKpsc85tAYqAwXU8ZHD1sUFf1TMWM5tkZivNbGXoxhQRERH/M+ecN09slgpkA32cc/9TrX0b8HC1GZvK9r8D/3DO/bJa238AFzrnLjnWc2VkZLiVK1c2af0iIiLiDTNb5ZzLON44L4/JKQ7eJoa0dyIwm1PX+IaOFRERkTbOs5DjnNtPYCZnaGVb8ODiBGBtHQ9ZU31s0NnBdhEREZEavD676iXgATM71cwSgOnAIudcVh1jXwMyzOwGM4sysxuAc4A5LVeuiIiIhAuvQ86TwHvACiAHaAeMBzCzG82s6vDz4EHJ1wCPENhF9Qjwo3oCkYiIiLRxnh143JJ04LGIiIh/hMOBxyIiIiLNRiFHREREfEkhR0RERHxJIUdERER8SSFHREREfEkhR0RERHypTZxCbmbFQKbXdUiVZKD2JYLFS9omrY+2Seui7dG69HbOdT3eoMiWqKQVyGzI+fTSMsxspbZH66Jt0vpom7Qu2h7hSburRERExJcUckRERMSX2krIecnrAqQGbY/WR9uk9dE2aV20PcJQmzjwWERERNqetjKTIyIiIm2MQo6IiIj4kkKOiIiI+JKvQ46ZtTOzGWaWZ2bFZjbPzJK9rsuPzGy6ma03syIz22lmM82sc8iYn5jZFjM7aGZfmtk5If0ZZrY82L/FzMa37KvwJzOLMLOlZubMrFe1dm0PD5jZJWb2hZmVmFm+mb1QrU/bpIWZWYqZvRX8nigws4/NbHC1fm2TMObrkANMAcYC5wGV/7jP9a4cXzsKjAe6AIMJvN+zKzvN7CLg98AdQBIwD/irmSUE+xOBvwXbk4DbgRfN7PyWewm+NRk4WL1B28MbZjYCeAd4msB/K72AWcE+bRNvvAB0BgYApwArgfctQNsk3DnnfLsA24CJ1e73BRyBn4P2vD4/L8BlQFG1+3OAudXuG5AN3By8f0twe1m1MXOBV71+LeG8EPiHewswJPjZ76Xt4en2WAY8WU+ftok322QtMKna/dOC/60ka5uE/+LbmRwz6wSkAasq25xzW4AiAjMN0rwuBtZUuz+YmtvCAav5/9tiMLA62F7pK7StTpiZRQCvAPcC+0O6tT1amJl1AL4LRJrZV8FdVZ+YWeWlArRNvDEDuNbMuppZDDAJ+Idz
Lh9tk7Dn25ADdAzeFoa07wcSWriWNsXMriUwbfvv1Zo7cuxtcbx+abx/B3Y75+bX0aft0fKSCPybewMwAegBfEhg90cntE288k+gHZALlADXAD8L9mmbhDk/h5zi4G1iSHsnArM50gzM7HpgJjDGOfdVta5ijr0tjtcvjWBm/YBfAHfXM0Tbo+VV/pv0qnNurXOuDJgGRAEXoG3S4oKznR8Bmwi8t3HAfwKfm9kpaJuEPd+GHOfcfgL7TodWtplZHwIJe61XdfmZmd0C/DdwlXNuSUj3GmpuCyNwnMiaav1DQh5zNjV3eUnDXQR0Bf5lZvkEptAB1prZnWh7tDjnXCGQReB4jxpdwUXbpOV1Bk4FfuucK3LOlTnnZhH4bjwfbZPw5/VBQc25AA8DmQQ+xAnA28BCr+vy4wL8b2AvcG49/RcRmAq+GGhP4DiRPUBCsL8TkAfcF+y/ODj+fK9fWzguBP6PtFe1ZRiBL9IMIF7bw7Ptch+wAxgIRAL3A7sIzAZom3izTTKB3wIdgtvkVqAM6KNtEv6L5wU064sL7Gd9GsgnMK34LpDsdV1+XIJfoOXB/8CrlpAxPwG2AqXAcuCckP5zg+2lwXHjvX5dflmAdKqdXaXt4dl2MOBXwG4Cx24sAYZom3i6Tc4A3g9+TxQSONB4rLaJPxZdoFNERER8ybfH5IiIiEjbppAjIiIivqSQIyIiIr6kkCMiIiK+pJAjIiIivqSQIyInxMxmm9n7XtdRnZmNNbNvzeyImc1uxucZYWbOzJKb8TnuNbOs5lq/SFugkCMShoIBw5nZoyHtzf7l28q9DMwDelPz2mlNbSnQncAPYIpIK6WQIxK+DgH3mVlXrwtpSmYWdYKP6wR0ARY553Jc4DIKzcIFfv5/t9MPjYm0ago5IuFrCYFrIT1a34C6ZnbMLD3YlhEy5nIzW2VmpWb2uZn1MrPvm9kaMysxs/fNrEsdz/GIme0JjnnVzGKr9ZmZ3W9mW4LrXWdm4+uo5QYz+9jMSoHb6nktSWY2x8wKguv6yMwGVb4GoCA49OPgOkfUs572ZjbdzHaY2UEzW2Fmo+p4z640s6/N7FDwfTmnvvfVzBLNbK6Z5QbHbzWze6qNTzOz+WZWHFzeNbNeIXXdb2a7g+/jawQuvxFa+y1mtiH4HJvMbHLwIpOV/bcF2w+ZWb6ZLTKzyLreB5G2QCFHJHxVAFOA282sbxOs7z+Ae4DzgCTgLeAxYBIwAhgETA15zPeBwQSu2XMtcCkwvVr/48BE4C4C12uaBvy3mV0Rsp5pwAvBMX+up77ZwdrGAt8FDgILg6FqabA+gnV0D7bV5dVg3T8GvgPMAd4zs8Eh454GHiBwva+twPtmFlfPOh8HzgSuBE4jcP2jHKi60vUC4BTgB8GlB/Dn4AUfMbNxwXX8ksAFITOB/1P9CczsZ8ATBLbJGQSuMv8AcGewPwP4HYHteBqBbbKwnnpF2gavryuhRYuWxi8EvvDfD/69BPhj8O8RBK5RlVzX/WBberAtI2TMqGpj7g62Da3WNhX4V0gN+4H4am3jgcMELnbYgcD1fIaH1P4s8NeQWn5xnNfbPzjue9XaEglca+inwfvJwTEjjrGevgTCYVpI+5+BF0Lejxur9ccHX+tPQ8ZUvs9/AV6p5zl/CBwF0qu19QnWcUnw/lJgZsjjPgKyqt3PBm4KGXMPsCH49zXB96Oj159PLVpay6JpTJHw9wCwzMxmnOR61lb7e0/wdl1IW7fQxzjnSqrdX0bgasx9gWgghsBsS/VjV6II7GarbuVxajuDQChYVtngnCs0s3UEZn8aaiiBi2RuCE6iVIoGPg4ZW/25So7zXL8H3gnu0loMvOec+7Ra7Tudc1nV1rfVzHYG1/dRcMysOp6/H0DwuKtUArNgv682JjL4egg+7zbgf8xsEfAh8K5zrriemkV8TyFHJMw555ab2TzgKeD/hnRXBG+rf6PXd2BvefXVBtcd2taYXdyVY68i
MAtR33MBHGjEekM15uDfiOD4c+uoofSEC3Dub2bWG7icwG6iD8zsbefcLcd7aAOfovK9vJ16dsM554rNbCjwPQKzRw8CT5jZuc65nQ18HhFf0TE5Iv7wEDAcuCykPS94271a25AmfN4zzaxDtfvDgDJgC7CBwK6r3s65zSHLtkY+zzcE/r06v7LBzBIIHAezoRHrWU0g8KXUUVNOyNhh1Z6rA4Hjd76pb8XOuXzn3Fzn3AQCxyHdbGbRwcf0MLP0auvrQ+C4nMrav6n+fKHP75zbA+wE+tZR9+Zq44445z52zj0InEVgl+GVx39bRPxJMzkiPuCc22xmL1H7t2E2A9uBqWY2hcAxMI804VNHAq+Y2a8IfGk/SeDYkgMAZvY08HTwANvPCBzbMgyocM691NAncc59a2YLCOyumUTg+Jj/BIqANxqxnk1m9jow28x+AXwFdCZwjM1W59y71YY/YmZ5BMLFYwTCW53PFXz9XwHrCbwn1wTXd9jMPiKwK/B1M6vcPr8Njq/cRfYc8JqZrQA+Aa4jcJD1vmpP80vgt2a2H/grgRm5oUBP59w0M7uSwG7Cz4KP+wHQkWMEMxG/00yOiH/8CjhSvSG4u+l/ETjQdQ2BM28easLn/JTAF/sSYD6BL+37q/U/SuCA5XuD4xYTOPvpf07guW4BlhM4yHc5EAdc5pxr7G6mWwicYfUUsBF4n8AuntDZpSnAMwTCSH/gysrwVofDBELXGuCfBMLFVQDOOUfgjLA8Au/TEmA3cHWwD+fcWwTep/8kMNt0JvDr6k/gnJtF4Kytm4LP8zmBM98q38v9wNUEjvHZSOA9/6lz7vMGvi8ivmPB/8ZERISq39xZAnR1zuV7XI6InATN5IiIiIgvKeSIiIiIL2l3lYiIiPiSZnJERETElxRyRERExJcUckRERMSXFHJERETElxRyRERExJcUckRERMSX/h9eCNFaQr4dIAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 648x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Median of the logged sigmoid(model.w) over every (run, training distribution)\n",
    "# pair, taken separately at each transfer episode\n",
    "alpha_curves = alphas.reshape((-1, num_transfer))\n",
    "alphas_50 = np.percentile(alpha_curves, 50, axis=0)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(9, 5))\n",
    "ax.tick_params(axis='both', which='major', labelsize=13)\n",
    "\n",
    "# Dashed guides at 1 and 0, the two limits of a sigmoid\n",
    "for bound in (1, 0):\n",
    "    ax.axhline(bound, c='lightgray', ls='--')\n",
    "ax.plot(alphas_50, lw=2, color='k')\n",
    "\n",
    "ax.set_xlim([0, num_transfer - 1])\n",
    "ax.set_xlabel('Number of episodes', fontsize=14)\n",
    "ax.set_ylabel(r'$\\sigma(\\gamma)$', fontsize=14)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
