{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Search - Sequence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "\n",
    "import sys\n",
    "from vowpalwabbit import pyvw\n",
    "\n",
    "\n",
    "class SequenceLabeler(pyvw.SearchTask):\n",
    "    \"\"\"Search task that predicts one label for every word of a sentence.\"\"\"\n",
    "\n",
    "    def __init__(self, vw, sch, num_actions):\n",
    "        \"\"\"Initialize the parent task and choose the search options.\n",
    "\n",
    "        Args:\n",
    "            vw: workspace this task is attached to.\n",
    "            sch: search object used to query options and make predictions.\n",
    "            num_actions: forwarded unchanged to pyvw.SearchTask.\n",
    "        \"\"\"\n",
    "        # you must initialize the parent class;\n",
    "        # this will automatically store self.sch <- sch, self.vw <- vw\n",
    "        pyvw.SearchTask.__init__(self, vw, sch, num_actions)\n",
    "\n",
    "        # you can test program options with sch.po_exists\n",
    "        # and get their values with sch.po_get -> string and\n",
    "        # sch.po_get_int -> int\n",
    "        if sch.po_exists(\"search\"):\n",
    "            print(\"found --search\")\n",
    "            print(\n",
    "                \"--search value =\",\n",
    "                sch.po_get(\"search\"),\n",
    "                \", type =\",\n",
    "                type(sch.po_get(\"search\")),\n",
    "            )\n",
    "\n",
    "        # set whatever options you want\n",
    "        sch.set_options(sch.AUTO_HAMMING_LOSS | sch.AUTO_CONDITION_FEATURES)\n",
    "\n",
    "    def _run(\n",
    "        self, sentence\n",
    "    ):  # it's called _run to remind you that you shouldn't call it directly!\n",
    "        \"\"\"Predict a label for each (pos, word) pair of ``sentence``.\n",
    "\n",
    "        Args:\n",
    "            sentence: iterable of (pos, word) pairs; pos is passed to the\n",
    "                search object as the oracle label for that word.\n",
    "\n",
    "        Returns:\n",
    "            list with one prediction per pair, in input order.\n",
    "        \"\"\"\n",
    "        output = []\n",
    "        for n, (pos, word) in enumerate(sentence):\n",
    "            ex = self.vw.example({\"w\": [word]})\n",
    "            # try/finally guarantees that the example is finished properly,\n",
    "            # even when predict raises\n",
    "            try:\n",
    "                pred = self.sch.predict(\n",
    "                    examples=ex, my_tag=n + 1, oracle=pos, condition=(n, \"p\")\n",
    "                )\n",
    "            finally:\n",
    "                # finish on the workspace stored on this task (self.vw) rather than\n",
    "                # on a module-level global that may not exist or may differ;\n",
    "                # must pass the example in as a list because search is a MultiEx reduction\n",
    "                self.vw.finish_example([ex])\n",
    "            output.append(pred)\n",
    "        return output\n",
    "\n",
    "\n",
    "# The training data is plain Python: VW examples are not required here.\n",
    "# Each sentence is a list of (label, word) pairs using the label ids below.\n",
    "DET, NOUN, VERB, ADJ = 1, 2, 3, 4\n",
    "\n",
    "my_dataset = [\n",
    "    [\n",
    "        (DET, \"the\"),\n",
    "        (NOUN, \"monster\"),\n",
    "        (VERB, \"ate\"),\n",
    "        (DET, \"a\"),\n",
    "        (ADJ, \"big\"),\n",
    "        (NOUN, \"sandwich\"),\n",
    "    ],\n",
    "    [\n",
    "        (DET, \"the\"),\n",
    "        (NOUN, \"sandwich\"),\n",
    "        (VERB, \"was\"),\n",
    "        (ADJ, \"tasty\"),\n",
    "    ],\n",
    "    [\n",
    "        (NOUN, \"it\"),\n",
    "        (VERB, \"ate\"),\n",
    "        (NOUN, \"it\"),\n",
    "        (ADJ, \"all\"),\n",
    "    ],\n",
    "]\n",
    "\n",
    "\n",
    "# Create the workspace in the usual way, selecting 'hook' as the search_task.\n",
    "vw = pyvw.Workspace(\"--search 4 --search_task hook\", quiet=True)\n",
    "\n",
    "# Ask VW to build the search task object from our class.\n",
    "sequenceLabeler = vw.init_search_task(SequenceLabeler)\n",
    "\n",
    "# Ten training passes over the dataset defined above.\n",
    "print(\"training!\", file=sys.stderr)\n",
    "for i in range(10):\n",
    "    sequenceLabeler.learn(my_dataset)\n",
    "\n",
    "# Predict on a test sentence; each word is paired with the label 1.\n",
    "print(\"predicting!\", file=sys.stderr)\n",
    "print(\n",
    "    sequenceLabeler.predict(\n",
    "        [(1, w) for w in \"the sandwich ate a monster\".split()]\n",
    "    )\n",
    ")\n",
    "print(\"should have printed: [1, 2, 3, 1, 2]\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "2cc929a270071711921cb2ad25a09768257b52278ee4b98c603d8d8861a97a9a"
  },
  "kernelspec": {
   "display_name": "Python 3.9.7 64-bit ('test': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
