{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Search - Covington"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training non-LDF\n",
      "testing non-LDF\n",
      "[1, 2, -1, 4, 2]\n",
      "should have printed [ 1 2 -1 4 2 ]\n",
      "training LDF\n",
      "testing LDF\n",
      "[1, 2, -1, 1, 2]\n",
      "should have printed [ 1 2 -1 4 2 ]\n"
     ]
    }
   ],
   "source": [
    "from vowpalwabbit import pyvw\n",
    "\n",
    "# Each sentence is a list of (word, parent) pairs: parent is the index of the\n",
    "# word's head within the same sentence, or -1 when the word is the root.\n",
    "# The trailing comment on each pair is that word's own index.\n",
    "my_dataset = [\n",
    "    [\n",
    "        (\"the\", 1),  # 0\n",
    "        (\"monster\", 2),  # 1\n",
    "        (\"ate\", -1),  # 2\n",
    "        (\"a\", 5),  # 3\n",
    "        (\"big\", 5),  # 4\n",
    "        (\"sandwich\", 2),  # 5\n",
    "    ],\n",
    "    [\n",
    "        (\"the\", 1),  # 0\n",
    "        (\"sandwich\", 2),  # 1\n",
    "        (\"is\", -1),  # 2\n",
    "        (\"tasty\", 2),  # 3\n",
    "    ],\n",
    "    [\n",
    "        (\"a\", 1),  # 0\n",
    "        (\"sandwich\", 2),  # 1\n",
    "        (\"ate\", -1),  # 2\n",
    "        (\"itself\", 2),  # 3\n",
    "    ],\n",
    "]\n",
    "\n",
    "\n",
    "class CovingtonDepParser(pyvw.SearchTask):\n",
    "    \"\"\"Dependency parser that asks one yes/no question per (word, candidate) pair.\n",
    "\n",
    "    For each word n the candidate parents m = -1 (root), 0, ..., N-1 are tried\n",
    "    in order, skipping m == n. Action 2 means \"m is the parent of n\" and\n",
    "    action 1 means it is not; the first candidate answered with 2 wins.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, vw, sch, num_actions):\n",
    "        pyvw.SearchTask.__init__(self, vw, sch, num_actions)\n",
    "        sch.set_options(sch.AUTO_HAMMING_LOSS | sch.AUTO_CONDITION_FEATURES)\n",
    "\n",
    "    def _run(self, sentence):\n",
    "        \"\"\"Predict a parent index for every word of ``sentence``.\n",
    "\n",
    "        sentence is a list of (word, parent) pairs; parent (-1 for root) is\n",
    "        only used to build the oracle action. Returns a list holding one\n",
    "        parent index per word, with -1 for words left attached to the root.\n",
    "        \"\"\"\n",
    "        N = len(sentence)\n",
    "        # initialize our output so everything is a root\n",
    "        output = [-1] * N\n",
    "        # cut points used to turn the signed distance m - n into indicator features\n",
    "        thresholds = [-8, -4, -2, -1, 1, 2, 4, 8]\n",
    "        for n in range(N):\n",
    "            wordN, parN = sentence[n]\n",
    "            for m in range(-1, N):\n",
    "                if m == n:\n",
    "                    continue\n",
    "                # only m == -1 is the artificial root; index 0 is a real word\n",
    "                wordM = sentence[m][0] if m >= 0 else \"*root*\"\n",
    "                # ask the question: is m the parent of n?\n",
    "                isParent = 2 if m == parN else 1\n",
    "\n",
    "                # construct an example\n",
    "                direction = \"l\" if m < n else \"r\"\n",
    "                ex = self.vw.example(\n",
    "                    {\n",
    "                        # features of the child word\n",
    "                        \"a\": [wordN, direction + \"_\" + wordN],\n",
    "                        # features of the candidate parent word\n",
    "                        \"b\": [wordM, direction + \"_\" + wordM],\n",
    "                        # features of the (child, candidate) pair\n",
    "                        \"p\": [\n",
    "                            wordN + \"_\" + wordM,\n",
    "                            direction + \"_\" + wordN + \"_\" + wordM,\n",
    "                        ],\n",
    "                        # bucketed distance between candidate and child\n",
    "                        \"d\": [str(m - n <= d) + \"<=\" + str(d) for d in thresholds]\n",
    "                        + [str(m - n >= d) + \">=\" + str(d) for d in thresholds],\n",
    "                    }\n",
    "                )\n",
    "                pred = self.sch.predict(\n",
    "                    examples=ex,\n",
    "                    # tags are laid out row by row: one row of N tags per candidate m\n",
    "                    my_tag=(m + 1) * N + n + 1,\n",
    "                    oracle=isParent,\n",
    "                    condition=[\n",
    "                        # tag of (m - 1, n), clamped to 0 when it would be non-positive\n",
    "                        (max(0, (m) * N + n + 1), \"p\"),\n",
    "                        # the tag immediately before this one, clamped the same way\n",
    "                        (max(0, (m + 1) * N + n), \"q\"),\n",
    "                    ],\n",
    "                )\n",
    "                # use the workspace this task was created with rather than a\n",
    "                # module-level variable; the example must be passed as a list\n",
    "                # because search is a MultiEx reduction\n",
    "                self.vw.finish_example([ex])\n",
    "                if pred == 2:\n",
    "                    output[n] = m\n",
    "                    break\n",
    "        return output\n",
    "\n",
    "\n",
    "class CovingtonDepParserLDF(pyvw.SearchTask):\n",
    "    \"\"\"LDF variant of the parser: one multi-example prediction per word.\n",
    "\n",
    "    For each word n, one cost-sensitive example is built per candidate parent\n",
    "    m (the root -1 plus every other word) and a single predict call chooses\n",
    "    among them by 1-based position in that list.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, vw, sch, num_actions):\n",
    "        pyvw.SearchTask.__init__(self, vw, sch, num_actions)\n",
    "        sch.set_options(\n",
    "            sch.AUTO_HAMMING_LOSS | sch.IS_LDF | sch.AUTO_CONDITION_FEATURES\n",
    "        )\n",
    "\n",
    "    def makeExample(self, sentence, n, m):\n",
    "        \"\"\"Build the example describing \"m is the parent of n\" (m == -1 is the root).\"\"\"\n",
    "        wordN = sentence[n][0]\n",
    "        wordM = sentence[m][0] if m >= 0 else \"*ROOT*\"\n",
    "        direction = \"l\" if m < n else \"r\"\n",
    "        # cut points used to turn the signed distance m - n into indicator features\n",
    "        thresholds = [-8, -4, -2, -1, 1, 2, 4, 8]\n",
    "        ex = self.vw.example(\n",
    "            {\n",
    "                \"a\": [wordN, direction + \"_\" + wordN],\n",
    "                \"b\": [wordM, direction + \"_\" + wordM],\n",
    "                \"p\": [wordN + \"_\" + wordM, direction + \"_\" + wordN + \"_\" + wordM],\n",
    "                \"d\": [str(m - n <= d) + \"<=\" + str(d) for d in thresholds]\n",
    "                + [str(m - n >= d) + \">=\" + str(d) for d in thresholds],\n",
    "            },\n",
    "            labelType=self.vw.lCostSensitive,\n",
    "        )\n",
    "        # The label string is \"<id>:<cost>\". The id is the offset n - m shifted\n",
    "        # by 100, which keeps it >= 1 for sentences of up to 100 words. The\n",
    "        # cost of 0 is irrelevant here and could be any number.\n",
    "        ex.set_label_string(str(100 + n - m) + \":0\")\n",
    "        return ex\n",
    "\n",
    "    def _run(self, sentence):\n",
    "        \"\"\"Predict a parent index for every word of ``sentence``.\n",
    "\n",
    "        sentence is a list of (word, parent) pairs; parent (-1 for root) is\n",
    "        only used to build the oracle. Returns one parent index per word.\n",
    "        \"\"\"\n",
    "        N = len(sentence)\n",
    "        # initialize our output so everything is a root\n",
    "        output = [-1] * N\n",
    "        for n in range(N):\n",
    "            # make LDF examples: one per candidate parent, skipping n itself\n",
    "            examples = [\n",
    "                self.makeExample(sentence=sentence, n=n, m=m)\n",
    "                for m in range(-1, N)\n",
    "                if n != m\n",
    "            ]\n",
    "\n",
    "            # truth\n",
    "            parN = sentence[n][1]\n",
    "\n",
    "            # Mapping from parent index to 1-based position in `examples`:\n",
    "            # -1        => 1\n",
    "            # 0...n-1   => 2...n+1\n",
    "            # n+1...N-1 => n+2...N\n",
    "            # (parents after n move down one slot because n itself is not listed)\n",
    "            oracle = parN + 2 if parN < n else parN + 1\n",
    "\n",
    "            # make a prediction\n",
    "            pred = self.sch.predict(\n",
    "                examples=examples,\n",
    "                my_tag=n + 1,\n",
    "                oracle=oracle,\n",
    "                # condition on the tags of the two preceding words\n",
    "                condition=[(n, \"p\"), (n - 1, \"q\")],\n",
    "            )\n",
    "\n",
    "            # use the workspace this task was created with rather than a\n",
    "            # module-level variable\n",
    "            self.vw.finish_example(examples)\n",
    "\n",
    "            # Reverse mapping from position back to parent index:\n",
    "            # 1         => -1\n",
    "            # 2...n+1   => 0...n-1\n",
    "            # n+2...N   => n+1...N-1\n",
    "            output[n] = pred - 2 if pred <= n + 1 else pred - 1\n",
    "\n",
    "        return output\n",
    "\n",
    "\n",
    "# TODO: if they make sure search=0 <==> ldf <==> csoaa_ldf\n",
    "\n",
    "# words of the held-out sentence both parsers are asked to parse\n",
    "TEST_WORDS = \"the monster ate a sandwich\".split()\n",
    "NON_LDF_PASSES = 2\n",
    "LDF_PASSES = 100\n",
    "\n",
    "# --- non-LDF parser: train, then parse the held-out sentence ---\n",
    "\n",
    "print(\"training non-LDF\")\n",
    "vw = pyvw.Workspace(\"--search 2 --search_task hook\", quiet=True)\n",
    "task = vw.init_search_task(CovingtonDepParser)\n",
    "# every pass feeds the whole training set to the learner once\n",
    "for pass_idx in range(NON_LDF_PASSES):\n",
    "    task.learn(my_dataset)\n",
    "print(\"testing non-LDF\")\n",
    "# parents are unknown at test time, so each word carries a placeholder -1\n",
    "unlabeled = [(word, -1) for word in TEST_WORDS]\n",
    "print(task.predict(unlabeled))\n",
    "print(\"should have printed [ 1 2 -1 4 2 ]\")\n",
    "\n",
    "# --- LDF parser: same procedure, but with many more training passes ---\n",
    "\n",
    "print(\"training LDF\")\n",
    "vw = pyvw.Workspace(\"--search 0 --csoaa_ldf m --search_task hook\", quiet=True)\n",
    "task = vw.init_search_task(CovingtonDepParserLDF)\n",
    "for pass_idx in range(LDF_PASSES):\n",
    "    task.learn(my_dataset)\n",
    "print(\"testing LDF\")\n",
    "unlabeled = [(word, -1) for word in TEST_WORDS]\n",
    "print(task.predict(unlabeled))\n",
    "print(\"should have printed [ 1 2 -1 4 2 ]\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "2cc929a270071711921cb2ad25a09768257b52278ee4b98c603d8d8861a97a9a"
  },
  "kernelspec": {
   "display_name": "Python 3.9.7 64-bit ('test': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
