{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Contextual Bandits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     }
    },
    "colab_type": "code",
    "id": "FFk0grV-cDLo"
   },
   "outputs": [],
   "source": [
    "import vowpalwabbit\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     }
    },
    "colab_type": "code",
    "id": "ySF3GkA8qe6z"
   },
   "outputs": [],
   "source": [
    "# Logged training data, as could have come from an earlier randomized trial (e.g. an AB test); this is what the CB learns from.\n",
    "## Same rows as the example in https://github.com/VowpalWabbit/vowpal_wabbit/wiki/Logged-Contextual-Bandit-Example\n",
    "## Each record holds the logged action, cost and probability plus three context features (an empty string means the feature is absent).\n",
    "train_data = [\n",
    "    {\"action\": 1, \"cost\": 2, \"probability\": 0.4, \"feature1\": \"a\", \"feature2\": \"c\", \"feature3\": \"\"},\n",
    "    {\"action\": 3, \"cost\": 0, \"probability\": 0.2, \"feature1\": \"b\", \"feature2\": \"d\", \"feature3\": \"\"},\n",
    "    {\"action\": 4, \"cost\": 1, \"probability\": 0.5, \"feature1\": \"a\", \"feature2\": \"b\", \"feature3\": \"\"},\n",
    "    {\"action\": 2, \"cost\": 1, \"probability\": 0.3, \"feature1\": \"a\", \"feature2\": \"b\", \"feature3\": \"c\"},\n",
    "    {\"action\": 3, \"cost\": 1, \"probability\": 0.7, \"feature1\": \"a\", \"feature2\": \"d\", \"feature3\": \"\"},\n",
    "]\n",
    "\n",
    "## build the frame and label its rows 1..n through an index named \"index\"\n",
    "train_df = (\n",
    "    pd.DataFrame(train_data)\n",
    "    .assign(index=range(1, len(train_data) + 1))\n",
    "    .set_index(\"index\")\n",
    ")\n",
    "\n",
    "# Test data: contexts without labels that the trained CB should pick actions for (e.g. features describing new users).\n",
    "test_data = [\n",
    "    {\"feature1\": \"b\", \"feature2\": \"c\", \"feature3\": \"\"},\n",
    "    {\"feature1\": \"a\", \"feature2\": \"\", \"feature3\": \"b\"},\n",
    "    {\"feature1\": \"b\", \"feature2\": \"b\", \"feature3\": \"\"},\n",
    "    {\"feature1\": \"a\", \"feature2\": \"\", \"feature3\": \"b\"},\n",
    "]\n",
    "\n",
    "## same 1-based \"index\" labelling as the training frame\n",
    "test_df = (\n",
    "    pd.DataFrame(test_data)\n",
    "    .assign(index=range(1, len(test_data) + 1))\n",
    "    .set_index(\"index\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 238,
     "output_extras": [
      {
       "item_id": 1
      }
     ]
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 598,
     "status": "ok",
     "timestamp": 1520198312440,
     "user": {
      "displayName": "Julian Runge",
      "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s128",
      "userId": "115854653587627279362"
     },
     "user_tz": 480
    },
    "id": "IDyiQVTJ4EBs",
    "outputId": "a5f7d5e7-7478-4c63-db7d-ca9e14ca259b"
   },
   "outputs": [],
   "source": [
    "# inspect both frames: the training data first, then the test contexts\n",
    "print(train_df, test_df, sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 85,
     "output_extras": [
      {
       "item_id": 1
      }
     ]
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 669,
     "status": "ok",
     "timestamp": 1520198318632,
     "user": {
      "displayName": "Julian Runge",
      "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s128",
      "userId": "115854653587627279362"
     },
     "user_tz": 480
    },
    "id": "pXLrjx6PjKio",
    "outputId": "897a674c-f076-4ba2-8978-f54a8eb9fbc7"
   },
   "outputs": [],
   "source": [
    "# Create the model: a contextual bandit (\"--cb\") with four possible actions; its parameters are stored in this python vw object.\n",
    "vw = vowpalwabbit.Workspace(\"--cb 4\", quiet=True)\n",
    "\n",
    "# Train with the learn method, feeding the logged rows one at a time.\n",
    "for row in train_df.itertuples(index=False):\n",
    "    ## example text handed to the cb: \"action:cost:probability | feature1 feature2 feature3\"\n",
    "    label = f\"{row.action}:{row.cost}:{row.probability}\"\n",
    "    context = f\"{row.feature1} {row.feature2} {row.feature3}\"\n",
    "    ## do the actual learning\n",
    "    vw.learn(f\"{label} | {context}\")\n",
    "\n",
    "# Predict with the same, now trained, model object - again row by row.\n",
    "for row in test_df.itertuples():\n",
    "    ## no label part for prediction: only the features follow the bar\n",
    "    choice = vw.predict(f\"| {row.feature1} {row.feature2} {row.feature3}\")\n",
    "    ## row.Index is the frame's index label of the current row\n",
    "    print(row.Index, choice)\n",
    "\n",
    "# the CB assigns every instance to action 3 as it should per the cost structure of the train data; you can play with the cost structure to see that the CB updates its predictions accordingly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 34,
     "output_extras": [
      {
       "item_id": 1
      }
     ]
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 605,
     "status": "ok",
     "timestamp": 1520198381640,
     "user": {
      "displayName": "Julian Runge",
      "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s128",
      "userId": "115854653587627279362"
     },
     "user_tz": 480
    },
    "id": "60bK90HlThA2",
    "outputId": "845ef0c1-465c-4748-b5bb-6b9d332a6820"
   },
   "outputs": [],
   "source": [
    "# BONUS: save the CB model to disk and load it back\n",
    "model_file = \"cb.model\"\n",
    "\n",
    "## save the trained model, then drop the in-memory workspace\n",
    "vw.save(model_file)\n",
    "del vw\n",
    "\n",
    "## create a new workspace from the saved file and ask it for a prediction\n",
    "vw = vowpalwabbit.Workspace(\"--cb 4 -i \" + model_file, quiet=True)\n",
    "print(vw.predict(\"| a b\"))"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "default_view": {},
   "name": "Implementing a Contextual Bandit Using VW's Python Wrapper.ipynb",
   "provenance": [
    {
     "file_id": "1Njy1txYPXqVwueHudbkF40zTbwkui1xA",
     "timestamp": 1519781379506
    },
    {
     "file_id": "11qz1CSi8-8yQACKJzp3G8VPUqzB3V1Hw",
     "timestamp": 1519780916480
    }
   ],
   "version": "0.3.2",
   "views": {}
  },
  "interpreter": {
   "hash": "e7370f93d1d0cde622a1f8e1c04877d8463912d04d973331ad4851f04de6915a"
  },
  "kernelspec": {
   "display_name": "Python 3.8.5 64-bit",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
