{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Edge Probing Predictions Sandbox\n",
    "\n",
    "Use this notebook as a starting point for #datascience on Edge Probing predictions. The code below (from `probing/analysis.py`) will load predictions from a run, do some pre-processing for convenience, and expose two main DataFrames for analysis: `example_df` (one row per example) and `target_df` (one row per target).\n",
    "\n",
    "We load the data into Pandas so it's easier to filter by various fields and to select particular columns of interest (such as `label.khot` and `preds.proba` for computing metrics). For an introduction to Pandas, see https://pandas.pydata.org/pandas-docs/stable/10min.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, os, re, json\n",
    "import itertools\n",
    "import collections\n",
    "from importlib import reload\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn import metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The latest runs are here:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[0m\u001b[01;34mcove-edges-constituent-ontonotes\u001b[0m/\r\n",
      "\u001b[01;34mcove-edges-coref-ontonotes-conll\u001b[0m/\r\n",
      "\u001b[01;34mcove-edges-dep-labeling-ewt\u001b[0m/\r\n",
      "\u001b[01;34mcove-edges-dpr\u001b[0m/\r\n",
      "\u001b[01;34mcove-edges-ner-ontonotes\u001b[0m/\r\n",
      "\u001b[01;34mcove-edges-spr1\u001b[0m/\r\n",
      "\u001b[01;34mcove-edges-spr2\u001b[0m/\r\n",
      "\u001b[01;34mcove-edges-srl-conll2012\u001b[0m/\r\n",
      "\u001b[01;34melmo-chars-edges-constituent-ontonotes\u001b[0m/\r\n",
      "\u001b[01;34melmo-chars-edges-coref-ontonotes-conll\u001b[0m/\r\n",
      "\u001b[01;34melmo-chars-edges-dep-labeling-ewt\u001b[0m/\r\n",
      "\u001b[01;34melmo-chars-edges-dpr\u001b[0m/\r\n",
      "\u001b[01;34melmo-chars-edges-ner-ontonotes\u001b[0m/\r\n",
      "\u001b[01;34melmo-chars-edges-spr1\u001b[0m/\r\n",
      "\u001b[01;34melmo-chars-edges-spr2\u001b[0m/\r\n",
      "\u001b[01;34melmo-chars-edges-srl-conll2012\u001b[0m/\r\n",
      "\u001b[01;34melmo-full-edges-constituent-ontonotes\u001b[0m/\r\n",
      "\u001b[01;34melmo-full-edges-coref-ontonotes-conll\u001b[0m/\r\n",
      "\u001b[01;34melmo-full-edges-dep-labeling-ewt\u001b[0m/\r\n",
      "\u001b[01;34melmo-full-edges-dpr\u001b[0m/\r\n",
      "\u001b[01;34melmo-full-edges-ner-ontonotes\u001b[0m/\r\n",
      "\u001b[01;34melmo-full-edges-spr1\u001b[0m/\r\n",
      "\u001b[01;34melmo-full-edges-spr2\u001b[0m/\r\n",
      "\u001b[01;34melmo-full-edges-srl-conll2012\u001b[0m/\r\n",
      "\u001b[01;34melmo-ortho-edges-constituent-ontonotes\u001b[0m/\r\n",
      "\u001b[01;34melmo-ortho-edges-coref-ontonotes-conll\u001b[0m/\r\n",
      "\u001b[01;34melmo-ortho-edges-dep-labeling-ewt\u001b[0m/\r\n",
      "\u001b[01;34melmo-ortho-edges-dpr\u001b[0m/\r\n",
      "\u001b[01;34melmo-ortho-edges-ner-ontonotes\u001b[0m/\r\n",
      "\u001b[01;34melmo-ortho-edges-spr1\u001b[0m/\r\n",
      "\u001b[01;34melmo-ortho-edges-spr2\u001b[0m/\r\n",
      "\u001b[01;34melmo-ortho-edges-srl-conll2012\u001b[0m/\r\n",
      "\u001b[01;34mfailed\u001b[0m/\r\n",
      "\u001b[01;34mglove-edges-constituent-ontonotes\u001b[0m/\r\n",
      "\u001b[01;34mglove-edges-coref-ontonotes-conll\u001b[0m/\r\n",
      "\u001b[01;34mglove-edges-dep-labeling-ewt\u001b[0m/\r\n",
      "\u001b[01;34mglove-edges-dpr\u001b[0m/\r\n",
      "\u001b[01;34mglove-edges-ner-ontonotes\u001b[0m/\r\n",
      "\u001b[01;34mglove-edges-spr1\u001b[0m/\r\n",
      "\u001b[01;34mglove-edges-spr2\u001b[0m/\r\n",
      "\u001b[01;34mglove-edges-srl-conll2012\u001b[0m/\r\n",
      "\u001b[01;34mopenai-edges-constituent-ontonotes\u001b[0m/\r\n",
      "\u001b[01;34mopenai-edges-coref-ontonotes-conll\u001b[0m/\r\n",
      "\u001b[01;34mopenai-edges-dep-labeling-ewt\u001b[0m/\r\n",
      "\u001b[01;34mopenai-edges-dpr\u001b[0m/\r\n",
      "\u001b[01;34mopenai-edges-ner-ontonotes\u001b[0m/\r\n",
      "\u001b[01;34mopenai-edges-spr1\u001b[0m/\r\n",
      "\u001b[01;34mopenai-edges-spr2\u001b[0m/\r\n",
      "\u001b[01;34mopenai-edges-srl-conll2012\u001b[0m/\r\n",
      "scores.tsv\r\n"
     ]
    }
   ],
   "source": [
    "ls /nfs/jsalt/home/iftenney/exp/edges-20180913/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `elmo-chars` experiments probe the char CNN layer only (lexical baseline), while the `elmo-full` models use full ELMo with learned mixing weights. The run dir for each is just called \"run\" by default. Below, we load the test-set predictions from the `elmo-full-edges-spr2` run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of examples: 276\n",
      "Number of total targets: 582\n",
      "Labels (20 total):\n",
      "['awareness', 'change_of_location', 'change_of_possession', 'change_of_state', 'change_of_state_continuous', 'changes_possession', 'existed_after', 'existed_before', 'existed_during', 'exists_as_physical', 'instigation', 'location_of_event', 'makes_physical_contact', 'partitive', 'predicate_changed_argument', 'sentient', 'stationary', 'volition', 'was_for_benefit', 'was_used']\n"
     ]
    }
   ],
   "source": [
    "import analysis\n",
    "reload(analysis)\n",
    "\n",
    "run_dir = \"/nfs/jsalt/home/iftenney/exp/edges-20180913/elmo-full-edges-spr2/run\"\n",
    "preds = analysis.Predictions.from_run(run_dir, 'edges-spr2', 'test')\n",
    "print(\"Number of examples: %d\" % len(preds.example_df))\n",
    "print(\"Number of total targets: %d\" % len(preds.target_df))\n",
    "print(\"Labels (%d total):\" % len(preds.all_labels))\n",
    "print(preds.all_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Top-level example info\n",
    "\n",
    "`preds.example_df` contains information on the top-level examples. Mostly, this just stores the input text and any metadata fields that were present in the original data. This is useful if you want to link the targets back to the text, but you shouldn't need it to compute most metrics."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>idx</th>\n",
       "      <th>info.grammatical</th>\n",
       "      <th>info.sent-id</th>\n",
       "      <th>info.sent_id</th>\n",
       "      <th>info.source</th>\n",
       "      <th>info.split</th>\n",
       "      <th>text</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>idx</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>1008</td>\n",
       "      <td>1008</td>\n",
       "      <td>SPR2</td>\n",
       "      <td>test</td>\n",
       "      <td>In a timid voice , he says : &amp;quot; If an airp...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>5.0</td>\n",
       "      <td>1009</td>\n",
       "      <td>1009</td>\n",
       "      <td>SPR2</td>\n",
       "      <td>test</td>\n",
       "      <td>&amp;quot; Wonderful ! &amp;quot; Winston beams .</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>5.0</td>\n",
       "      <td>1017</td>\n",
       "      <td>1017</td>\n",
       "      <td>SPR2</td>\n",
       "      <td>test</td>\n",
       "      <td>&amp;quot; Our new lunar transportation system uti...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>2.0</td>\n",
       "      <td>1023</td>\n",
       "      <td>1023</td>\n",
       "      <td>SPR2</td>\n",
       "      <td>test</td>\n",
       "      <td>They want to use LTS to tie into NASA &amp;apos; s...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>5.0</td>\n",
       "      <td>1024</td>\n",
       "      <td>1024</td>\n",
       "      <td>SPR2</td>\n",
       "      <td>test</td>\n",
       "      <td>&amp;quot; We are so excited that the White House ...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     idx  info.grammatical info.sent-id info.sent_id info.source info.split  \\\n",
       "idx                                                                           \n",
       "0      0               5.0         1008         1008        SPR2       test   \n",
       "1      1               5.0         1009         1009        SPR2       test   \n",
       "2      2               5.0         1017         1017        SPR2       test   \n",
       "3      3               2.0         1023         1023        SPR2       test   \n",
       "4      4               5.0         1024         1024        SPR2       test   \n",
       "\n",
       "                                                  text  \n",
       "idx                                                     \n",
       "0    In a timid voice , he says : &quot; If an airp...  \n",
       "1            &quot; Wonderful ! &quot; Winston beams .  \n",
       "2    &quot; Our new lunar transportation system uti...  \n",
       "3    They want to use LTS to tie into NASA &apos; s...  \n",
       "4    &quot; We are so excited that the White House ...  "
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "preds.example_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Target info and predictions\n",
    "\n",
    "`preds.target_df` contains the per-target input fields (`span1`, `span2`, and `label`) as well as any metadata associated with individual targets. The `idx` column references a row in `example_df` that this target belongs to, if you need to recover the original text.\n",
    "\n",
    "The loader code does some preprocessing for convenience. In particular, we add a `label.ids` column, which maps the list-of-strings `label` column to a list of integer ids (indices into `preds.all_labels`) for each target, as well as a `label.khot` column, which contains a K-hot encoding of these ids.\n",
    "\n",
    "Each entry in `label.khot` should align to the corresponding entry in `preds.proba`, which contains the model's predicted probabilities $\\hat{y} \\in [0,1]$ for each class.\n",
    "\n",
    "For some analyses, it may be easier to work with the wide and long forms of this DataFrame; see the cells below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>idx</th>\n",
       "      <th>info.is_pilot</th>\n",
       "      <th>info.pred_lemma</th>\n",
       "      <th>info.span1_text</th>\n",
       "      <th>info.span2_txt</th>\n",
       "      <th>label</th>\n",
       "      <th>preds.proba</th>\n",
       "      <th>span1</th>\n",
       "      <th>span2</th>\n",
       "      <th>label.ids</th>\n",
       "      <th>label.khot</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>False</td>\n",
       "      <td>say</td>\n",
       "      <td>says</td>\n",
       "      <td>he</td>\n",
       "      <td>[awareness, existed_after, existed_before, exi...</td>\n",
       "      <td>[0.9507238268852234, 0.08021300286054611, 0.00...</td>\n",
       "      <td>(6, 7)</td>\n",
       "      <td>(5, 6)</td>\n",
       "      <td>[0, 6, 7, 8, 10, 15, 17, 19]</td>\n",
       "      <td>[1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>False</td>\n",
       "      <td>carry</td>\n",
       "      <td>carrying</td>\n",
       "      <td>winston peters</td>\n",
       "      <td>[awareness, change_of_location, change_of_stat...</td>\n",
       "      <td>[0.8147344589233398, 0.8972967863082886, 0.146...</td>\n",
       "      <td>(12, 13)</td>\n",
       "      <td>(13, 15)</td>\n",
       "      <td>[0, 1, 4, 6, 7, 8, 10, 15, 17, 18, 19]</td>\n",
       "      <td>[1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0</td>\n",
       "      <td>False</td>\n",
       "      <td>blow</td>\n",
       "      <td>blown</td>\n",
       "      <td>an airplane carrying winston peters</td>\n",
       "      <td>[change_of_location, change_of_state, existed_...</td>\n",
       "      <td>[0.20997169613838196, 0.7638567686080933, 0.11...</td>\n",
       "      <td>(16, 17)</td>\n",
       "      <td>(10, 15)</td>\n",
       "      <td>[1, 3, 7, 8]</td>\n",
       "      <td>[0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>False</td>\n",
       "      <td>beam</td>\n",
       "      <td>beams</td>\n",
       "      <td>winston</td>\n",
       "      <td>[awareness, change_of_state_continuous, existe...</td>\n",
       "      <td>[0.5660699605941772, 0.15035615861415863, 0.03...</td>\n",
       "      <td>(5, 6)</td>\n",
       "      <td>(4, 5)</td>\n",
       "      <td>[0, 4, 6, 7, 8, 10, 13, 15, 17, 18, 19]</td>\n",
       "      <td>[1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2</td>\n",
       "      <td>False</td>\n",
       "      <td>tell</td>\n",
       "      <td>told</td>\n",
       "      <td>kistler</td>\n",
       "      <td>[awareness, existed_after, existed_before, exi...</td>\n",
       "      <td>[0.9896626472473145, 0.022328440099954605, 0.0...</td>\n",
       "      <td>(30, 31)</td>\n",
       "      <td>(29, 30)</td>\n",
       "      <td>[0, 6, 7, 8, 10, 15, 17, 19]</td>\n",
       "      <td>[1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, ...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   idx  info.is_pilot info.pred_lemma info.span1_text  \\\n",
       "0    0          False             say            says   \n",
       "1    0          False           carry        carrying   \n",
       "2    0          False            blow           blown   \n",
       "3    1          False            beam           beams   \n",
       "4    2          False            tell            told   \n",
       "\n",
       "                        info.span2_txt  \\\n",
       "0                                   he   \n",
       "1                       winston peters   \n",
       "2  an airplane carrying winston peters   \n",
       "3                              winston   \n",
       "4                              kistler   \n",
       "\n",
       "                                               label  \\\n",
       "0  [awareness, existed_after, existed_before, exi...   \n",
       "1  [awareness, change_of_location, change_of_stat...   \n",
       "2  [change_of_location, change_of_state, existed_...   \n",
       "3  [awareness, change_of_state_continuous, existe...   \n",
       "4  [awareness, existed_after, existed_before, exi...   \n",
       "\n",
       "                                         preds.proba     span1     span2  \\\n",
       "0  [0.9507238268852234, 0.08021300286054611, 0.00...    (6, 7)    (5, 6)   \n",
       "1  [0.8147344589233398, 0.8972967863082886, 0.146...  (12, 13)  (13, 15)   \n",
       "2  [0.20997169613838196, 0.7638567686080933, 0.11...  (16, 17)  (10, 15)   \n",
       "3  [0.5660699605941772, 0.15035615861415863, 0.03...    (5, 6)    (4, 5)   \n",
       "4  [0.9896626472473145, 0.022328440099954605, 0.0...  (30, 31)  (29, 30)   \n",
       "\n",
       "                                 label.ids  \\\n",
       "0             [0, 6, 7, 8, 10, 15, 17, 19]   \n",
       "1   [0, 1, 4, 6, 7, 8, 10, 15, 17, 18, 19]   \n",
       "2                             [1, 3, 7, 8]   \n",
       "3  [0, 4, 6, 7, 8, 10, 13, 15, 17, 18, 19]   \n",
       "4             [0, 6, 7, 8, 10, 15, 17, 19]   \n",
       "\n",
       "                                          label.khot  \n",
       "0  [1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, ...  \n",
       "1  [1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, ...  \n",
       "2  [0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, ...  \n",
       "3  [1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, ...  \n",
       "4  [1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, ...  "
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "preds.target_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Wide and Long Data\n",
    "\n",
    "For background on these views, see https://altair-viz.github.io/user_guide/data.html#long-form-vs-wide-form-data\n",
    "\n",
    "Here's a \"wide\" version of the data, with the usual metadata plus `2 * num_labels` columns: `label.true.<label_name>` and `preds.proba.<label_name>` for each target class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>idx</th>\n",
       "      <th>info.is_pilot</th>\n",
       "      <th>info.pred_lemma</th>\n",
       "      <th>info.span1_text</th>\n",
       "      <th>info.span2_txt</th>\n",
       "      <th>span1</th>\n",
       "      <th>span2</th>\n",
       "      <th>label.true.awareness</th>\n",
       "      <th>label.true.change_of_location</th>\n",
       "      <th>label.true.change_of_possession</th>\n",
       "      <th>...</th>\n",
       "      <th>preds.proba.instigation</th>\n",
       "      <th>preds.proba.location_of_event</th>\n",
       "      <th>preds.proba.makes_physical_contact</th>\n",
       "      <th>preds.proba.partitive</th>\n",
       "      <th>preds.proba.predicate_changed_argument</th>\n",
       "      <th>preds.proba.sentient</th>\n",
       "      <th>preds.proba.stationary</th>\n",
       "      <th>preds.proba.volition</th>\n",
       "      <th>preds.proba.was_for_benefit</th>\n",
       "      <th>preds.proba.was_used</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>False</td>\n",
       "      <td>say</td>\n",
       "      <td>says</td>\n",
       "      <td>he</td>\n",
       "      <td>(6, 7)</td>\n",
       "      <td>(5, 6)</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.936860</td>\n",
       "      <td>0.004401</td>\n",
       "      <td>0.003261</td>\n",
       "      <td>0.254228</td>\n",
       "      <td>0.002805</td>\n",
       "      <td>0.975760</td>\n",
       "      <td>0.003733</td>\n",
       "      <td>0.938958</td>\n",
       "      <td>0.143198</td>\n",
       "      <td>0.945751</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>False</td>\n",
       "      <td>carry</td>\n",
       "      <td>carrying</td>\n",
       "      <td>winston peters</td>\n",
       "      <td>(12, 13)</td>\n",
       "      <td>(13, 15)</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.742384</td>\n",
       "      <td>0.006977</td>\n",
       "      <td>0.006286</td>\n",
       "      <td>0.078980</td>\n",
       "      <td>0.006507</td>\n",
       "      <td>0.667234</td>\n",
       "      <td>0.004935</td>\n",
       "      <td>0.652489</td>\n",
       "      <td>0.387525</td>\n",
       "      <td>0.932438</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0</td>\n",
       "      <td>False</td>\n",
       "      <td>blow</td>\n",
       "      <td>blown</td>\n",
       "      <td>an airplane carrying winston peters</td>\n",
       "      <td>(16, 17)</td>\n",
       "      <td>(10, 15)</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.173488</td>\n",
       "      <td>0.009159</td>\n",
       "      <td>0.013924</td>\n",
       "      <td>0.277387</td>\n",
       "      <td>0.004450</td>\n",
       "      <td>0.194129</td>\n",
       "      <td>0.004115</td>\n",
       "      <td>0.029275</td>\n",
       "      <td>0.124122</td>\n",
       "      <td>0.724787</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>False</td>\n",
       "      <td>beam</td>\n",
       "      <td>beams</td>\n",
       "      <td>winston</td>\n",
       "      <td>(5, 6)</td>\n",
       "      <td>(4, 5)</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.918481</td>\n",
       "      <td>0.009543</td>\n",
       "      <td>0.007555</td>\n",
       "      <td>0.120318</td>\n",
       "      <td>0.007820</td>\n",
       "      <td>0.808103</td>\n",
       "      <td>0.007416</td>\n",
       "      <td>0.578732</td>\n",
       "      <td>0.544744</td>\n",
       "      <td>0.919867</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2</td>\n",
       "      <td>False</td>\n",
       "      <td>tell</td>\n",
       "      <td>told</td>\n",
       "      <td>kistler</td>\n",
       "      <td>(30, 31)</td>\n",
       "      <td>(29, 30)</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.922490</td>\n",
       "      <td>0.015724</td>\n",
       "      <td>0.011915</td>\n",
       "      <td>0.314700</td>\n",
       "      <td>0.009969</td>\n",
       "      <td>0.963026</td>\n",
       "      <td>0.009797</td>\n",
       "      <td>0.985373</td>\n",
       "      <td>0.749102</td>\n",
       "      <td>0.960411</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 47 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   idx  info.is_pilot info.pred_lemma info.span1_text  \\\n",
       "0    0          False             say            says   \n",
       "1    0          False           carry        carrying   \n",
       "2    0          False            blow           blown   \n",
       "3    1          False            beam           beams   \n",
       "4    2          False            tell            told   \n",
       "\n",
       "                        info.span2_txt     span1     span2  \\\n",
       "0                                   he    (6, 7)    (5, 6)   \n",
       "1                       winston peters  (12, 13)  (13, 15)   \n",
       "2  an airplane carrying winston peters  (16, 17)  (10, 15)   \n",
       "3                              winston    (5, 6)    (4, 5)   \n",
       "4                              kistler  (30, 31)  (29, 30)   \n",
       "\n",
       "   label.true.awareness  label.true.change_of_location  \\\n",
       "0                     1                              0   \n",
       "1                     1                              1   \n",
       "2                     0                              1   \n",
       "3                     1                              0   \n",
       "4                     1                              0   \n",
       "\n",
       "   label.true.change_of_possession          ...           \\\n",
       "0                                0          ...            \n",
       "1                                0          ...            \n",
       "2                                0          ...            \n",
       "3                                0          ...            \n",
       "4                                0          ...            \n",
       "\n",
       "   preds.proba.instigation  preds.proba.location_of_event  \\\n",
       "0                 0.936860                       0.004401   \n",
       "1                 0.742384                       0.006977   \n",
       "2                 0.173488                       0.009159   \n",
       "3                 0.918481                       0.009543   \n",
       "4                 0.922490                       0.015724   \n",
       "\n",
       "   preds.proba.makes_physical_contact  preds.proba.partitive  \\\n",
       "0                            0.003261               0.254228   \n",
       "1                            0.006286               0.078980   \n",
       "2                            0.013924               0.277387   \n",
       "3                            0.007555               0.120318   \n",
       "4                            0.011915               0.314700   \n",
       "\n",
       "   preds.proba.predicate_changed_argument  preds.proba.sentient  \\\n",
       "0                                0.002805              0.975760   \n",
       "1                                0.006507              0.667234   \n",
       "2                                0.004450              0.194129   \n",
       "3                                0.007820              0.808103   \n",
       "4                                0.009969              0.963026   \n",
       "\n",
       "   preds.proba.stationary  preds.proba.volition  preds.proba.was_for_benefit  \\\n",
       "0                0.003733              0.938958                     0.143198   \n",
       "1                0.004935              0.652489                     0.387525   \n",
       "2                0.004115              0.029275                     0.124122   \n",
       "3                0.007416              0.578732                     0.544744   \n",
       "4                0.009797              0.985373                     0.749102   \n",
       "\n",
       "   preds.proba.was_used  \n",
       "0              0.945751  \n",
       "1              0.932438  \n",
       "2              0.724787  \n",
       "3              0.919867  \n",
       "4              0.960411  \n",
       "\n",
       "[5 rows x 47 columns]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "preds.target_df_wide.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can easily compute per-label metrics from the wide form by selecting the appropriate pair of columns and thresholding the predicted probability at 0.5:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "awareness                     0.897436\n",
      "change_of_location            0.251969\n",
      "change_of_possession          0.048780\n",
      "change_of_state               0.387097\n",
      "change_of_state_continuous    0.595890\n",
      "changes_possession            0.000000\n",
      "existed_after                 0.951686\n",
      "existed_before                0.919081\n",
      "existed_during                0.987826\n",
      "exists_as_physical            0.000000\n",
      "instigation                   0.806565\n",
      "location_of_event             0.000000\n",
      "makes_physical_contact        0.000000\n",
      "partitive                     0.055944\n",
      "predicate_changed_argument    0.000000\n",
      "sentient                      0.888519\n",
      "stationary                    0.000000\n",
      "volition                      0.845735\n",
      "was_for_benefit               0.640316\n",
      "was_used                      0.917910\n",
      "dtype: float64\n",
      "Macro average F1: 0.4597\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/share/anaconda3/lib/python3.6/site-packages/sklearn/metrics/classification.py:1135: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 due to no predicted samples.\n",
      "  'precision', 'predicted', average, warn_for)\n"
     ]
    }
   ],
   "source": [
    "wide_df = preds.target_df_wide\n",
    "scores_by_label = {}\n",
    "for label in preds.all_labels:\n",
    "    y_true = wide_df['label.true.' + label]\n",
    "    y_pred = wide_df['preds.proba.' + label] >= 0.5\n",
    "    score = metrics.f1_score(y_true=y_true, y_pred=y_pred)\n",
    "    scores_by_label[label] = score\n",
    "scores = pd.Series(scores_by_label)\n",
    "print(scores)\n",
    "print(\"Macro average F1: %.04f\" % scores.mean())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And here's a \"long\" version of the same data, with one row per (target, label) pair: a single `label` column naming the class, plus one column each for `label.true` and `preds.proba` for that label:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>idx</th>\n",
       "      <th>label</th>\n",
       "      <th>label.true</th>\n",
       "      <th>preds.proba</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>awareness</td>\n",
       "      <td>1</td>\n",
       "      <td>0.950724</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>change_of_location</td>\n",
       "      <td>0</td>\n",
       "      <td>0.080213</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0</td>\n",
       "      <td>change_of_possession</td>\n",
       "      <td>0</td>\n",
       "      <td>0.007079</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0</td>\n",
       "      <td>change_of_state</td>\n",
       "      <td>0</td>\n",
       "      <td>0.093276</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>change_of_state_continuous</td>\n",
       "      <td>0</td>\n",
       "      <td>0.160939</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   idx                       label  label.true  preds.proba\n",
       "0    0                   awareness           1     0.950724\n",
       "1    0          change_of_location           0     0.080213\n",
       "2    0        change_of_possession           0     0.007079\n",
       "3    0             change_of_state           0     0.093276\n",
       "4    0  change_of_state_continuous           0     0.160939"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the first few rows of the long-format frame.\n",
    "preds.target_df_long.head(n=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can easily get the set of available labels from the `label` column:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['awareness', 'change_of_location', 'change_of_possession',\n",
       "       'change_of_state', 'change_of_state_continuous',\n",
       "       'changes_possession', 'existed_after', 'existed_before',\n",
       "       'existed_during', 'exists_as_physical', 'instigation',\n",
       "       'location_of_event', 'makes_physical_contact', 'partitive',\n",
       "       'predicate_changed_argument', 'sentient', 'stationary', 'volition',\n",
       "       'was_for_benefit', 'was_used'], dtype=object)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Distinct values of the `label` column, in order of first appearance.\n",
    "preds.target_df_long['label'].unique()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
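    "Because the long format pools every (`idx`, `label`) pair into a single binary problem, we can compute micro-averaged metrics directly by comparing the `label.true` column against the thresholded `preds.proba` column:"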
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8297897060532125"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pooling every (idx, label) row into one binary problem yields micro-averaged F1.\n",
    "long_df = preds.target_df_long\n",
    "metrics.f1_score(y_true=long_df['label.true'], y_pred=(long_df['preds.proba'] >= 0.5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
