{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Search - Speech Tagger"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This tutorial walks you through writing learning to search code using the VW python interface. Once you've completed this, you can graduate to the C++ version, which will be faster for the computer but more painful for you.\n",
    "\n",
    "The \"learning to search\" paradigm solves problems that look like the following. You have a sequence of decisions to make. At the end of making these decisions, the world tells you how bad your decisions were. You want to condition later decisions on earlier decisions. But thankfully, at \"training time,\" you have access to an *oracle* that will tell you the right answer."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's start with a simple example: sequence labeling for Part of Speech tagging. The goal is to take a sequence of words (\"the monster ate a big sandwich\") and label them with their parts of speech (in this case: Det Noun Verb Det Adj Noun).\n",
    "\n",
    "We will choose to solve this problem with left-to-right search. I.e., we'll label the first word, then the second then the third and so on.\n",
    "\n",
    "For any vw project in python, we have to start by importing the pyvw library:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import vowpalwabbit"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's define our data first. We'll do this first by defining the labels (one annoying thing is that labels in vw have to be integers):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DET = 1\n",
    "NOUN = 2\n",
    "VERB = 3\n",
    "ADJ = 4\n",
    "my_dataset = [\n",
    "    [\n",
    "        (DET, \"the\"),\n",
    "        (NOUN, \"monster\"),\n",
    "        (VERB, \"ate\"),\n",
    "        (DET, \"a\"),\n",
    "        (ADJ, \"big\"),\n",
    "        (NOUN, \"sandwich\"),\n",
    "    ],\n",
    "    [(DET, \"the\"), (NOUN, \"sandwich\"), (VERB, \"was\"), (ADJ, \"tasty\")],\n",
    "    [(NOUN, \"it\"), (VERB, \"ate\"), (NOUN, \"it\"), (ADJ, \"all\")],\n",
    "]\n",
    "print(my_dataset[2])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we have an example of a (correctly) tagged sentence.\n",
    "\n",
    "Now, we need to write the structured prediction code. To do this, we have to write a new class that derives from the `pyvw.SearchTask` class.\n",
    "\n",
    "This class *must* have two functions: `__init__` and `_run`.\n",
    "\n",
    "The initialization function takes three arguments (plus `self`): a vw object (`vw`), a search object (`sch`), and the number of actions (`num_actions`) that this object has been initialized with. Within the initialization function, we must first initialize the parent class, and then we can set whatever options we want via `sch.set_options(...)`. Of course we can also do whatever additional initialization we like.\n",
    "\n",
    "The `_run` function executes the sequence of decisions on a given input. The input will be of whatever type our data is (so, in the above example, it will be a list of (label,word) pairs).\n",
    "\n",
    "Here is a basic implementation of sequence labeling:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SequenceLabeler(vowpalwabbit.pyvw.SearchTask):\n",
    "    def __init__(self, vw, sch, num_actions):\n",
    "        # you must must must initialize the parent class\n",
    "        # this will automatically store self.sch <- sch, self.vw <- vw\n",
    "        vowpalwabbit.pyvw.SearchTask.__init__(self, vw, sch, num_actions)\n",
    "\n",
    "        # set whatever options you want\n",
    "        sch.set_options(sch.AUTO_HAMMING_LOSS | sch.AUTO_CONDITION_FEATURES)\n",
    "\n",
    "    def _run(\n",
    "        self, sentence\n",
    "    ):  # it's called _run to remind you that you shouldn't call it directly!\n",
    "        output = []\n",
    "        for n in range(len(sentence)):\n",
    "            pos, word = sentence[n]\n",
    "            # use \"with...as...\" to guarantee that the example is finished properly\n",
    "            ex = self.vw.example({\"w\": [word]})\n",
    "            pred = self.sch.predict(\n",
    "                examples=ex,\n",
    "                my_tag=n + 1,\n",
    "                oracle=pos,\n",
    "                condition=[(n, \"p\"), (n - 1, \"q\")],\n",
    "            )\n",
    "            self.vw.finish_example(\n",
    "                [ex]\n",
    "            )  # must pass the example in as a list because search is a MultiEx reduction\n",
    "            output.append(pred)\n",
    "        return output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's unpack this a bit.\n",
    "\n",
    "The `__init__` function is simple. It first calls the parent initializer and then sets some options. The options it sets are two things designed to make the programmer's life easier. The first is `AUTO_HAMMING_LOSS`. Remember earlier we said that when the sequence of decision is made, you have to say how bad it was? This says that we want this to be computed automatically by comparing the individual decisions to the oracle decisions, and defining the loss to be the sum of incorrect decisions.\n",
    "\n",
    "The second is `AUTO_CONDITION_FEATURES`. This is a bit subtler. Later in the `_run` function, we will say that the label of the `n`th word depends on the label of the `n-1`th word. In order to get the underlying classifier to *pay attention* to that conditioning, we need to add features. We could do that manually (we'll do this later) or we can ask vw to do it automatically for us. For simplicity, we choose the latter.\n",
    "\n",
    "The `_run` function takes a sentence (list of pos/word pairs) as input. We loop over each word position `n` in the sentence and extract the `pos,word` pair. We then construct a VW example that consists of a single feature (the `word`) in the 'w' namespace. Given that example `ex`, we make a search-based prediction by calling `self.sch.predict(...)`. This is a fairly complicated function that takes a number of arguments. Here, we are calling it with the following:\n",
    "\n",
    " - `examples=ex`: This tells the predictor what features to use.\n",
    " - `my_tag=n+1`: In general, we want to condition the prediction of the `n`th label on the `n-1`th label. In order to do this, we have to give each prediction a \"name\" so that we can refer back to it in the future. This name needs to be an integer `>= 1`. So we'll call the first word `1`, the second word `2`, and so on. It has to be `n+1` and not `n` because of the 1-based requirement.\n",
    " - `oracle=pos`: As mentioned before, on training data, we need to tell the system what the \"true\" (or \"best\") decision is at this point in time. Here, it is the given part of speech label.\n",
    " - `condition=(n,'p')`: This says that this prediction depends on the output of whichever-prediction-was-called-`n`, and that the \"nature\" of that condition is called 'p' (for \"predecessor\" in this case, though this is entirely up to you)\n",
    "\n",
    "Now, we're ready to train the model. We do this in three steps. First, we initialize a vw object, telling it that we have a `--search` task with 4 labels, second that the specific type of `--search_task` is `hook` (you will always use the `hook` task) and finally that we want it to be quiet and use a larger `ring_size` (you can ignore the `ring_size` for now)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vw = vowpalwabbit.Workspace(\"--search 4 --search_task hook\", quiet=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we need to initialize the search task. We use the `vw.init_search_task` function to do this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sequenceLabeler = vw.init_search_task(SequenceLabeler)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we can train on the dataset we defined earlier, using `sequenceLabeler.learn` (the `.learn` function is inherited from the `pyvw.SearchTask` class). The `.learn` function takes any iterator over data. Whatever type of data it iterates over is what it will feed to your `_run` function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(10):\n",
    "    sequenceLabeler.learn(my_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Of course, we want to see if it's learned anything. So let's create a single test example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_example = [(1, w) for w in \"the sandwich ate a monster\".split()]\n",
    "print(test_example)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We've used `0` as the labels so you can be sure that vw isn't cheating and it's actually making predictions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "out = sequenceLabeler.predict(test_example)\n",
    "print(out)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we look back at our POS tag definitions, this is DET NOUN VERB DET NOUN, which is indeed correct!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Removing the AUTO features"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the above example we used both AUTO_HAMMING_LOSS and AUTO_CONDITION_FEATURES. To make more explicit what these are doing, let's rewrite our `SequenceLabeler` class without them! Here's a version that gets rid of both simultaneously. It is only modestly more complex:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SequenceLabeler2(vowpalwabbit.pyvw.SearchTask):\n",
    "    def __init__(self, vw, sch, num_actions):\n",
    "        vowpalwabbit.pyvw.SearchTask.__init__(self, vw, sch, num_actions)\n",
    "\n",
    "    def _run(self, sentence):\n",
    "        output = []\n",
    "        loss = 0.0\n",
    "        for n in range(len(sentence)):\n",
    "            pos, word = sentence[n]\n",
    "            prevPred = output[n - 1] if n > 0 else \"<s>\"\n",
    "\n",
    "            ex = self.vw.example({\"w\": [word], \"p\": [prevPred]})\n",
    "\n",
    "            pred = self.sch.predict(\n",
    "                examples=ex, my_tag=n + 1, oracle=pos, condition=(n, \"p\")\n",
    "            )\n",
    "            vw.finish_example([ex])\n",
    "\n",
    "            output.append(pred)\n",
    "\n",
    "            if pred != pos:\n",
    "                loss += 1.0\n",
    "\n",
    "        self.sch.loss(loss)\n",
    "        return output\n",
    "\n",
    "\n",
    "sequenceLabeler2 = vw.init_search_task(SequenceLabeler2)\n",
    "sequenceLabeler2.learn(my_dataset)\n",
    "print(sequenceLabeler2.predict([(1, w) for w in \"the sandwich ate a monster\".split()]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If executed correctly, this should have printed `[1, 2, 3, 1, 2]`.\n",
    "\n",
    "There are essentially two things that changed here. In order to get rid of AUTO_HAMMING_LOSS, we had to keep track of how many errors the predictor had made. This is done by checking whether `pred != pos` inside the inner loop, and then at the end calling `self.sch.loss(loss)` to tell the search procedure how well we did.\n",
    "\n",
    "In order to get rid of AUTO_CONDITION_FEATURES, we need to explicitly add the previous prediction as features to the example we are predicting with. Here, we've done this by extracting the previous prediction (`prevPred`) and explicitly adding it as a feature (in the 'p' namespace) during the example construction.\n",
    "\n",
    "**Important Note:** even though we're not using AUTO_CONDITION_FEATURES, we *still* must tell the search process that this prediction depends on the previous prediction. We need to do this because the learning algorithm automatically memoizes certain computations, and so it needs to know that, when memoizing, to remember that this prediction *might have been different* if a previous decision were different."
   ]
  },
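  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the loss bookkeeping concrete, here is a minimal stand-alone sketch of the quantity that `AUTO_HAMMING_LOSS` is described as computing above: the number of positions where the predicted label differs from the oracle label. The helper name `hamming_loss` and the `predicted` sequence are made up for illustration and are not part of the VW API."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def hamming_loss(predicted, gold):\n",
    "    \"\"\"Count the positions where the prediction differs from the oracle label.\"\"\"\n",
    "    if len(predicted) != len(gold):\n",
    "        raise ValueError(\"sequences must have the same length\")\n",
    "    return float(sum(1 for p, g in zip(predicted, gold) if p != g))\n",
    "\n",
    "\n",
    "# same label encoding as above: DET=1, NOUN=2, VERB=3, ADJ=4\n",
    "gold = [1, 2, 3, 1, 4, 2]  # the monster ate a big sandwich\n",
    "predicted = [1, 2, 3, 1, 2, 2]  # hypothetical output that tags \"big\" as NOUN\n",
    "print(hamming_loss(predicted, gold))  # one mismatch -> 1.0"
   ]
  },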
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Very silly Covington-esqu dependency parsing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's also try a variant of dependency parsing to see that this doesn't work just for sequence-labeling list tasks. First we need to define some data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the label for each word is its parent, or -1 for root\n",
    "my_dataset = [\n",
    "    [\n",
    "        (\"the\", 1),  # 0\n",
    "        (\"monster\", 2),  # 1\n",
    "        (\"ate\", -1),  # 2\n",
    "        (\"a\", 5),  # 3\n",
    "        (\"big\", 5),  # 4\n",
    "        (\"sandwich\", 2),\n",
    "    ],  # 5\n",
    "    [(\"the\", 1), (\"sandwich\", 2), (\"is\", -1), (\"tasty\", 2)],  # 0  # 1  # 2  # 3\n",
    "    [\n",
    "        (\"a\", 1),  # 0\n",
    "        (\"sandwich\", 2),  # 1\n",
    "        (\"ate\", -1),  # 2\n",
    "        (\"itself\", 2),  # 3\n",
    "    ],\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For instance, in the first sentence, the parent of \"the\" is \"monster\"; the parent of \"monster\" is \"ate\"; and \"ate\" is the root.\n",
    "\n",
    "The basic idea of a Covington-style dependency parser is to loop over all O(N^2) word pairs and ask if one is the parent of the other. In a real parser you would want to make sure that you don't have cycles, that you have a unique root and (perhaps) that the resulting graph is projective. I'm not doing that here. Hopefully I'll add a shift-reduce parser example later that *does* do this. Here's an implementation of this idea:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CovingtonDepParser(vowpalwabbit.pyvw.SearchTask):\n",
    "    def __init__(self, vw, sch, num_actions):\n",
    "        vowpalwabbit.pyvw.SearchTask.__init__(self, vw, sch, num_actions)\n",
    "        sch.set_options(sch.AUTO_HAMMING_LOSS | sch.AUTO_CONDITION_FEATURES)\n",
    "\n",
    "    def _run(self, sentence):\n",
    "        N = len(sentence)\n",
    "        # initialize our output so everything is a root\n",
    "        output = [-1 for i in range(N)]\n",
    "        for n in range(N):\n",
    "            wordN, parN = sentence[n]\n",
    "            for m in range(-1, N):\n",
    "                if m == n:\n",
    "                    continue\n",
    "                wordM = sentence[m][0] if m > 0 else \"*root*\"\n",
    "                # ask the question: is m the parent of n?\n",
    "                isParent = 2 if m == parN else 1\n",
    "\n",
    "                # construct an example\n",
    "                dir = \"l\" if m < n else \"r\"\n",
    "                ex = self.vw.example(\n",
    "                    {\n",
    "                        \"a\": [wordN, dir + \"_\" + wordN],\n",
    "                        \"b\": [wordM, dir + \"_\" + wordN],\n",
    "                        \"p\": [wordN + \"_\" + wordM, dir + \"_\" + wordN + \"_\" + wordM],\n",
    "                        \"d\": [\n",
    "                            str(m - n <= d) + \"<=\" + str(d)\n",
    "                            for d in [-8, -4, -2, -1, 1, 2, 4, 8]\n",
    "                        ]\n",
    "                        + [\n",
    "                            str(m - n >= d) + \">=\" + str(d)\n",
    "                            for d in [-8, -4, -2, -1, 1, 2, 4, 8]\n",
    "                        ],\n",
    "                    }\n",
    "                )\n",
    "                pred = self.sch.predict(\n",
    "                    examples=ex,\n",
    "                    my_tag=(m + 1) * N + n + 1,\n",
    "                    oracle=isParent,\n",
    "                    condition=[\n",
    "                        (max(0, (m) * N + n + 1), \"p\"),\n",
    "                        (max(0, (m + 1) * N + n), \"q\"),\n",
    "                    ],\n",
    "                )\n",
    "                self.vw.finish_example(\n",
    "                    [ex]\n",
    "                )  # must pass the example in as a list because search is a MultiEx reduction\n",
    "                if pred == 2:\n",
    "                    output[n] = m\n",
    "                    break\n",
    "        return output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this, `output` stores the predicted tree and is initialized with every word being a root. We then loop over every word (`n`) and every possible parent (`m`, which can be -1, though that's really kind of unnecessary).\n",
    "\n",
    "The features are basically the words under consideration, the words paired with the direction of the edge, the pair of words, and then a bunch of (binned) distance features.\n",
    "\n",
    "We can train and run this parser with:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vw = vowpalwabbit.Workspace(\"--search 2 --search_task hook\", quiet=True)\n",
    "task = vw.init_search_task(CovingtonDepParser)\n",
    "for p in range(10):  # do ten passes over the training data\n",
    "    task.learn(my_dataset)\n",
    "print(\"testing\")\n",
    "print(task.predict([(w, -1) for w in \"the monster ate a sandwich\".split()]))\n",
    "print(\"should have printed [ 1 2 -1 4 2 ]\")"
   ]
  },
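  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `my_tag` and `condition` arithmetic in `CovingtonDepParser._run` is easy to get wrong, so here is a small stand-alone check of it. It assumes nothing beyond the formulas used in that class; the helper name `pair_tag` is ours and not part of VW. The check confirms that every (candidate parent `m`, word `n`) decision gets a distinct tag `>= 1`, and that the two conditioning tags equal the tags of the decisions for `(m-1, n)` and `(m, n-1)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pair_tag(m, n, N):\n",
    "    \"\"\"Tag of the decision 'is m the parent of n?' (m in -1..N-1, n in 0..N-1).\"\"\"\n",
    "    return (m + 1) * N + n + 1\n",
    "\n",
    "\n",
    "N = 3\n",
    "tags = {(m, n): pair_tag(m, n, N) for n in range(N) for m in range(-1, N) if m != n}\n",
    "assert min(tags.values()) >= 1  # tags must be integers >= 1\n",
    "assert len(set(tags.values())) == len(tags)  # no two decisions share a tag\n",
    "\n",
    "m, n = 1, 2\n",
    "assert m * N + n + 1 == pair_tag(m - 1, n, N)  # the \"p\" condition\n",
    "assert (m + 1) * N + n == pair_tag(m, n - 1, N)  # the \"q\" condition\n",
    "print(pair_tag(m, n, N), pair_tag(m - 1, n, N), pair_tag(m, n - 1, N))  # 9 6 8"
   ]
  },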
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One could argue that a more natural way to do this would be with LDF rather than the inner loop over `m`. We'll do that next."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LDF-based Covington-style dependency parser"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One of the weirdnesses in the previous parser implementation is that it makes N-many binary decisions per word (\"is word n my parent?\") rather than a single N-way decision. The latter makes more sense.\n",
    "\n",
    "The challenge is that you cannot set this up as a vanilla multiclass classification problem because (a) the number of \"classes\" is a function of the input (a length N sentence will have N classes) and (b) class \"1\" and \"2\" don't mean anything consistently across examples.\n",
    "\n",
    "The way around this is label-dependent features (LDF). In LDF mode, the class ids are (essentially -- see caveat below) irrelevant. Instead, you simply provide features that depend on the label (hence \"LDF\"). In particular, for each possible label, you provide a *different* feature vector, and the goal of learning is to pick one of those as the \"correct\" one.\n",
    "\n",
    "Here's a re-implementation of Covington using LDF:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CovingtonDepParserLDF(vowpalwabbit.pyvw.SearchTask):\n",
    "    def __init__(self, vw, sch, num_actions):\n",
    "        vowpalwabbit.pyvw.SearchTask.__init__(self, vw, sch, num_actions)\n",
    "        sch.set_options(\n",
    "            sch.AUTO_HAMMING_LOSS | sch.IS_LDF | sch.AUTO_CONDITION_FEATURES\n",
    "        )\n",
    "\n",
    "    def makeExample(self, sentence, n, m):\n",
    "        wordN = sentence[n][0]\n",
    "        wordM = sentence[m][0] if m >= 0 else \"*ROOT*\"\n",
    "        dir = \"l\" if m < n else \"r\"\n",
    "        ex = self.vw.example(\n",
    "            {\n",
    "                \"a\": [wordN, dir + \"_\" + wordN],\n",
    "                \"b\": [wordM, dir + \"_\" + wordM],\n",
    "                \"p\": [wordN + \"_\" + wordM, dir + \"_\" + wordN + \"_\" + wordM],\n",
    "                \"d\": [\n",
    "                    str(m - n <= d) + \"<=\" + str(d)\n",
    "                    for d in [-8, -4, -2, -1, 1, 2, 4, 8]\n",
    "                ]\n",
    "                + [\n",
    "                    str(m - n >= d) + \">=\" + str(d)\n",
    "                    for d in [-8, -4, -2, -1, 1, 2, 4, 8]\n",
    "                ],\n",
    "            },\n",
    "            labelType=self.vw.lCostSensitive,\n",
    "        )\n",
    "        ex.set_label_string(\n",
    "            str(m + 2) + \":0\"\n",
    "        )  # project the m-indices (-1...N into 1...N+2)\n",
    "        return ex\n",
    "\n",
    "    def _run(self, sentence):\n",
    "        N = len(sentence)\n",
    "        # initialize our output so everything is a root\n",
    "        output = [-1 for i in range(N)]\n",
    "        for n in range(N):\n",
    "            # make LDF examples\n",
    "            examples = [\n",
    "                self.makeExample(sentence, n, m) for m in range(-1, N) if n != m\n",
    "            ]\n",
    "\n",
    "            # truth\n",
    "            parN = sentence[n][1]\n",
    "            oracle = (\n",
    "                parN + 1 if parN < n else parN\n",
    "            )  # have to -1 because we excluded n==m from list\n",
    "\n",
    "            # make a prediction\n",
    "            pred = self.sch.predict(\n",
    "                examples=examples,\n",
    "                my_tag=n + 1,\n",
    "                oracle=oracle,\n",
    "                condition=[(n, \"p\"), (n - 1, \"q\")],\n",
    "            )\n",
    "\n",
    "            output[n] = (\n",
    "                pred - 1 if pred < n else pred\n",
    "            )  # have to +1 because n==m excluded\n",
    "\n",
    "            for ex in examples:\n",
    "                ex.finish()  # clean up\n",
    "\n",
    "        return output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are a few things going on here. Let's focus first on the `__init__` function. The only difference here is that when we call `sch.set_options` we provide `sch.IS_LDF` to declare that this is an LDF task.\n",
    "\n",
    "Let's skip the `makeExample` function for a minute and look at the `_run` function. You should recognize most of this from the non-LDF version. We initialize the `output` (parent) of every word to `-1` (meaning that every word is connected to the root).\n",
    "\n",
    "For each word `n`, we construct `N`-many examples: one for every -1..(N-1), except for the current word `n` because you cannot have self-loops. If we were being more clever, we would only do the ones that won't result in the creation of a cycle, but we're not being clever.\n",
    "\n",
    "Now, because the \"labels\" are just examples, it's a bit more complicated to specify the oracle. The oracle is an *index* into the examples list. So if `oracle` is the oracle action, then `examples[oracle]` is the corresponding example. We compute the oracle as follows. `parN` is the *actual* parent, which is going to be something in the range `-1 .. (N-1)`. If `parN < n` (this is a left arrow), then the oracle index is `parN+1` because the root (`-1`) is index `0` and so on. If `parN > n` (note: it cannot be equal to `n`) then, beacuse `n == m` is left out of the examples list, the correct index is just `parN`. Phew.\n",
    "\n",
    "We then ask for a prediction. Now, instead of giving a single example, with give the list of examples. The tag works the same way, as does the conditioning.\n",
    "\n",
    "Once we get a prediction out (called `pred`) we need to figure out what parent it actually corresponds to. This is simply un-doing the computaiton two paragraphs ago!\n",
    "\n",
    "Finally -- and this is skippable if you trust the Python garbage collector -- we tell VW that we're done with all the examples we created. We do this just to be pedantic; it's optional.\n",
    "\n",
    "Okay, now let's go back to the `makeExample` function. This takes two word ids (`n` and `m`) and makes an example that roughly says \"what would it look like if I had an edge from `n` to `m`?\" We construct basically the same feautres as before. There are two major changes, though:\n",
    "\n",
    "1. When we run `self.vw.example(...)` we provide `labelType=self.vw.lCostSensitive` as an argument. This is because under the hood, vw treats LDF examples as cost-sensitive classification examples. This means they need to have cost-sensitive labels, so that's how we need to create them.\n",
    "\n",
    "1. We explicitly set the label of the this example to `str(m+2)+\":0\"`. What is this? Well, this is _optional_ but recommended. Here's the issue. In LDF mode, recall that labels have no intrinsic meaning. This means that when vw does auto-conditioning, it's not really clear what to use as the \"previous prediction.\" By giving explicit label names (in this case, m+2) we're recording what the position of the last parent, which may be useful for predicting the next parent. We could avoid this necessity if we did our own feature engineering on the history, but for now, this seems to capture the right intuition.\n",
    "\n",
    "Given all this, we can now train and test our parser:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vw = vowpalwabbit.Workspace(\"--search 0 --csoaa_ldf m --search_task hook\", quiet=True)\n",
    "task = vw.init_search_task(CovingtonDepParserLDF)\n",
    "# BUG: This currently does not work because oracle generation is incorrect (generates invalid oracle values)#for p in range(2): # do two passes over the training data\n",
    "#    task.learn(my_dataset)\n",
    "# print(task.predict( [(w,-1) for w in \"the monster ate a sandwich\".split()] ))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The correct parse of this sentence is `[1, 2, -1, 4, 2]` which is what this should have printed.\n",
    "\n",
    "There are two major things to notice in the initialization of VW. The first is that we say `--search 0`. The zero labels argument to `--search` declares that this is going to be an LDF task. We also have to tell VW that we want an LDF-enabled cost-sensitive learner, which is what `--csoaa_ldf m` does (if you're wondering, `m` means \"multiline\" -- just treat it as something you have to do). The rest should be familiar."
   ]
  },
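  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The index bookkeeping in `CovingtonDepParserLDF._run` can be checked in isolation, without VW. The sketch below assumes, as described above, that both the oracle and the prediction are 0-based indices into the `examples` list; the helper names `parent_to_index` and `index_to_parent` are ours. Note that the inverse mapping needs `<=`: index `n` holds parent `n-1`, because the root occupies index `0`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parent_to_index(parent, n):\n",
    "    \"\"\"Map a parent id in -1..N-1 (parent != n) to its position in the examples list.\"\"\"\n",
    "    if parent == n:\n",
    "        raise ValueError(\"a word cannot be its own parent\")\n",
    "    return parent + 1 if parent < n else parent\n",
    "\n",
    "\n",
    "def index_to_parent(index, n):\n",
    "    \"\"\"Inverse of parent_to_index: indices 0..n hold parents -1..n-1.\"\"\"\n",
    "    return index - 1 if index <= n else index\n",
    "\n",
    "\n",
    "N = 4\n",
    "for n in range(N):\n",
    "    candidates = [m for m in range(-1, N) if m != n]  # same order as in _run\n",
    "    for i, m in enumerate(candidates):\n",
    "        assert parent_to_index(m, n) == i\n",
    "        assert index_to_parent(i, n) == m\n",
    "print(\"round trip ok\")"
   ]
  },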
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## A simple word-alignment model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Okay, as a last example we'll do a simple word alignment model in the spirit of the IBM models. Note that this will be a *supervised* model; doing unsupervised stuff is a bit trickier.\n",
    "\n",
    "Here's some word alignment data. The dataset is triples of `E, A, F` where `A[i]` = list of words `E[i]` aligned to, or `[]` for null-aligned:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "my_dataset = [\n",
    "    (\"the blue house\".split(), ([0], [2], [1]), \"la maison bleue\".split()),\n",
    "    (\"the house\".split(), ([0], [1]), \"la maison\".split()),\n",
    "    (\"the flower\".split(), ([0], [1]), \"la fleur\".split()),\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It's going to be useful to compute alignment mismatches at the word level between true alignments (like `[1,2]`) and predicted alignments (like `[2,3,4]`). We use intersection-over-union error:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def alignmentError(true, sys):\n",
    "    t = set(true)\n",
    "    s = set(sys)\n",
    "    if len(t | s) == 0:\n",
    "        return 0.0\n",
    "    return 1.0 - float(len(t & s)) / float(len(t | s))"
   ]
  },
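  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of the formula, take the example mentioned above: the true alignment `[1,2]` and the predicted alignment `[2,3,4]` share one position out of four distinct ones. This sketch recomputes the error directly with sets instead of calling `alignmentError`, so that it runs on its own; the variable names are ours."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "true_alignment = {1, 2}\n",
    "predicted_alignment = {2, 3, 4}\n",
    "overlap = len(true_alignment & predicted_alignment)  # 1 shared position\n",
    "total = len(true_alignment | predicted_alignment)  # 4 distinct positions\n",
    "error = 1.0 - overlap / total\n",
    "print(error)  # 0.75"
   ]
  },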
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And now we can define our structured prediction task. This is also an LDF problem. Basically for each word on the English side, we'll loop over all possible phrases on the Foreign side to which it could align (maximum phrase length of three). For each of these options we'll create an example to be fed into the LDF classifier. We also ensure that the same foreign word cannot be covered by multiple English words, though this might not be a good idea in general."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class WordAligner(vowpalwabbit.pyvw.SearchTask):\n",
    "    def __init__(self, vw, sch, num_actions):\n",
    "        vowpalwabbit.pyvw.SearchTask.__init__(self, vw, sch, num_actions)\n",
    "        sch.set_options(\n",
    "            sch.AUTO_HAMMING_LOSS | sch.IS_LDF | sch.AUTO_CONDITION_FEATURES\n",
    "        )\n",
    "\n",
    "    def makeExample(self, E, F, i, j0, l):\n",
    "        f = \"Null\" if j0 is None else [F[j0 + k] for k in range(l + 1)]\n",
    "        ex = self.vw.example(\n",
    "            {\n",
    "                \"e\": E[i],\n",
    "                \"f\": f,\n",
    "                \"p\": \"_\".join(f),\n",
    "                \"l\": str(l),\n",
    "                \"o\": [str(i - j0), str(i - j0 - l)] if j0 is not None else [],\n",
    "            },\n",
    "            labelType=self.vw.lCostSensitive,\n",
    "        )\n",
    "        lab = \"Null\" if j0 is None else str(j0 + l)\n",
    "        ex.set_label_string(lab + \":0\")\n",
    "        return ex\n",
    "\n",
    "    def _run(self, alignedSentence):\n",
    "        E, A, F = alignedSentence\n",
    "\n",
    "        # for each E word, we pick a F span\n",
    "        covered = {}  # which F words have been covered so far?\n",
    "        output = []\n",
    "\n",
    "        for i in range(len(E)):\n",
    "            examples = []  # contains vw examples\n",
    "            spans = []  # contains triples (alignment error, index in examples, [range])\n",
    "\n",
    "            # empty span:\n",
    "            examples.append(self.makeExample(E, F, i, None, None))\n",
    "            spans.append((alignmentError(A[i], []), 0, []))\n",
    "\n",
    "            # non-empty spans\n",
    "            for j0 in range(len(F)):\n",
    "                for l in range(3):  # max phrase length of 3\n",
    "                    if j0 + l >= len(F):\n",
    "                        break\n",
    "                    if covered.has_key(j0 + l):\n",
    "                        break\n",
    "\n",
    "                    id = len(examples)\n",
    "                    examples.append(self.makeExample(E, F, i, j0, l))\n",
    "                    spans.append(\n",
    "                        (\n",
    "                            alignmentError(A[i], range(j0, j0 + l + 1)),\n",
    "                            id,\n",
    "                            range(j0, j0 + l + 1),\n",
    "                        )\n",
    "                    )\n",
    "\n",
    "            sortedSpans = []\n",
    "            for s in spans:\n",
    "                sortedSpans.append(s)\n",
    "            sortedSpans.sort()\n",
    "            oracle = []\n",
    "            for id in range(len(sortedSpans)):\n",
    "                if sortedSpans[id][0] > sortedSpans[0][0]:\n",
    "                    break\n",
    "                oracle.append(sortedSpans[id][1])\n",
    "\n",
    "            pred = self.sch.predict(\n",
    "                examples=examples,\n",
    "                my_tag=i + 1,\n",
    "                oracle=oracle,\n",
    "                condition=[(i, \"p\"), (i - 1, \"q\")],\n",
    "            )\n",
    "\n",
    "            for ex in examples:\n",
    "                ex.finish()\n",
    "\n",
    "            output.append(spans[pred][2])\n",
    "            for j in spans[pred][2]:\n",
    "                covered[j] = True\n",
    "\n",
    "        return output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The only really complicated thing here is computing the oracle. What we do is, for each possible alignment, compute an intersection-over-union error rate. The oracle is then that alignment that achieves the smallest (local) error rate. This is not perfect, but is good enough. One interesting thing here is that now the `oracle` could be a *list*; this is completely supported by the underlying algorithms.\n",
    "\n",
    "We can train and test this model to make sure it does the right thing:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vw = vowpalwabbit.Workspace(\n",
    "    \"--search 0 --csoaa_ldf m --search_task hook -q ef -q ep\", quiet=True\n",
    ")\n",
    "task = vw.init_search_task(WordAligner)\n",
    "# BUG: This is currently broken due to incorrect oracle generation. Currently under investigation.#for p in range(10):\n",
    "#    task.learn(my_dataset)\n",
    "# print(task.predict( (\"the blue flower\".split(), ([],[],[]), \"la fleur bleue\".split()) ))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If this worked correctly, it should have printed `[[0], [2], [1]]`."
   ]
  },
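  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the list-valued oracle in `WordAligner._run` concrete, here is a stand-alone sketch that picks every candidate span with minimal intersection-over-union error. It does not touch VW; the helper names `alignment_error` and `oracle_indices` and the `candidates` list are illustrative assumptions, and the selection is written with `min` rather than the sort used in the class, which should pick the same indices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def alignment_error(true, sys):\n",
    "    \"\"\"Intersection-over-union error between two collections of positions.\"\"\"\n",
    "    t, s = set(true), set(sys)\n",
    "    if not (t | s):\n",
    "        return 0.0\n",
    "    return 1.0 - len(t & s) / len(t | s)\n",
    "\n",
    "\n",
    "def oracle_indices(gold, candidate_spans):\n",
    "    \"\"\"Return the indices of all candidate spans with minimal error.\"\"\"\n",
    "    errors = [alignment_error(gold, span) for span in candidate_spans]\n",
    "    best = min(errors)\n",
    "    return [i for i, err in enumerate(errors) if err == best]\n",
    "\n",
    "\n",
    "# hypothetical candidate spans for a three-word foreign sentence\n",
    "candidates = [[], [0], [0, 1], [1], [1, 2], [2]]\n",
    "print(oracle_indices([2], candidates))  # [5]: only the exact span has zero error\n",
    "print(oracle_indices([0, 2], candidates))  # [1, 5]: two spans tie at error 0.5"
   ]
  },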
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  },
  "name": ""
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
