{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     }
    },
    "colab_type": "code",
    "id": "FFk0grV-cDLo"
   },
   "outputs": [],
   "source": [
    "# load required packages\n",
    "import pandas as pd\n",
    "import sklearn as sk\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 357,
     "output_extras": [
      {
       "item_id": 8
      }
     ]
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 7226,
     "status": "ok",
     "timestamp": 1520197158295,
     "user": {
      "displayName": "Julian Runge",
      "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s128",
      "userId": "115854653587627279362"
     },
     "user_tz": 480
    },
    "id": "OPlledP-8oo9",
    "outputId": "53fe3b47-619f-44c4-805f-88e1b26b35db"
   },
   "outputs": [],
   "source": [
    "# install vowpal wabbit together with the boost / zlib packages requested alongside it\n",
    "# %pip (instead of !pip) installs into the environment of the running kernel,\n",
    "# so the vowpalwabbit import in a later cell resolves against the same python\n",
    "%pip install boost\n",
    "!apt-get install libboost-program-options-dev zlib1g-dev libboost-python-dev -y\n",
    "%pip install vowpalwabbit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     }
    },
    "colab_type": "code",
    "id": "ySF3GkA8qe6z"
   },
   "outputs": [],
   "source": [
    "# logged interactions for the CB to explore, as they could originate from a previous random trial, e.g. an AB test\n",
    "## the rows are equivalent to the example in https://github.com/JohnLangford/vowpal_wabbit/wiki/Contextual-Bandit-Example\n",
    "train_data = [\n",
    "    {'action': 1, 'cost': 2, 'probability': 0.4, 'feature1': 'a', 'feature2': 'c', 'feature3': ''},\n",
    "    {'action': 3, 'cost': 0, 'probability': 0.2, 'feature1': 'b', 'feature2': 'd', 'feature3': ''},\n",
    "    {'action': 4, 'cost': 1, 'probability': 0.5, 'feature1': 'a', 'feature2': 'b', 'feature3': ''},\n",
    "    {'action': 2, 'cost': 1, 'probability': 0.3, 'feature1': 'a', 'feature2': 'b', 'feature3': 'c'},\n",
    "    {'action': 3, 'cost': 1, 'probability': 0.7, 'feature1': 'a', 'feature2': 'd', 'feature3': ''},\n",
    "]\n",
    "\n",
    "## build the dataframe and label its rows 1..n through an index named \"index\"\n",
    "train_df = (\n",
    "    pd.DataFrame(train_data)\n",
    "    .assign(index=range(1, len(train_data) + 1))\n",
    "    .set_index(\"index\")\n",
    ")\n",
    "\n",
    "# contexts that you want the CB to make decisions for, e.g. features describing new users, for the CB to exploit\n",
    "test_data = [\n",
    "    {'feature1': 'b', 'feature2': 'c', 'feature3': ''},\n",
    "    {'feature1': 'a', 'feature2': '', 'feature3': 'b'},\n",
    "    {'feature1': 'b', 'feature2': 'b', 'feature3': ''},\n",
    "    {'feature1': 'a', 'feature2': '', 'feature3': 'b'},\n",
    "]\n",
    "\n",
    "## same 1..n \"index\" labelling as for the train dataframe\n",
    "test_df = (\n",
    "    pd.DataFrame(test_data)\n",
    "    .assign(index=range(1, len(test_data) + 1))\n",
    "    .set_index(\"index\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 238,
     "output_extras": [
      {
       "item_id": 1
      }
     ]
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 598,
     "status": "ok",
     "timestamp": 1520198312440,
     "user": {
      "displayName": "Julian Runge",
      "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s128",
      "userId": "115854653587627279362"
     },
     "user_tz": 480
    },
    "id": "IDyiQVTJ4EBs",
    "outputId": "a5f7d5e7-7478-4c63-db7d-ca9e14ca259b"
   },
   "outputs": [],
   "source": [
    "# take a look at the dataframes: train rows first, then the test rows\n",
    "for frame in (train_df, test_df):\n",
    "  print(frame)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 85,
     "output_extras": [
      {
       "item_id": 1
      }
     ]
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 669,
     "status": "ok",
     "timestamp": 1520198318632,
     "user": {
      "displayName": "Julian Runge",
      "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s128",
      "userId": "115854653587627279362"
     },
     "user_tz": 480
    },
    "id": "pXLrjx6PjKio",
    "outputId": "897a674c-f076-4ba2-8978-f54a8eb9fbc7"
   },
   "outputs": [],
   "source": [
    "# import vowpal wabbit's python wrapper\n",
    "from vowpalwabbit import pyvw\n",
    "\n",
    "# contextual bandit with four possible actions; the python vw object stores the model parameters\n",
    "vw = pyvw.vw(\"--cb 4\")\n",
    "\n",
    "# feed the train data to the learn method one row at a time\n",
    "for row in train_df.itertuples():\n",
    "  ## label part of the example, built as action:cost:probability\n",
    "  label = \"{}:{}:{}\".format(row.action, row.cost, row.probability)\n",
    "  ## feature part of the example, the three feature values separated by spaces\n",
    "  features = \"{} {} {}\".format(row.feature1, row.feature2, row.feature3)\n",
    "  ## do the actual learning\n",
    "  vw.learn(label + \" | \" + features)\n",
    "\n",
    "# the model object trained above is reused for the predictions\n",
    "\n",
    "# predict one test row at a time and print the row's index next to the returned choice\n",
    "for row in test_df.itertuples():\n",
    "  choice = vw.predict(\"| \" + \"{} {} {}\".format(row.feature1, row.feature2, row.feature3))\n",
    "  print(row.Index, choice)\n",
    "\n",
    "# the CB assigns every instance to action 3 as it should per the cost structure of the train data; you can play with the cost structure to see that the CB updates its predictions accordingly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 34,
     "output_extras": [
      {
       "item_id": 1
      }
     ]
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 605,
     "status": "ok",
     "timestamp": 1520198381640,
     "user": {
      "displayName": "Julian Runge",
      "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s128",
      "userId": "115854653587627279362"
     },
     "user_tz": 480
    },
    "id": "60bK90HlThA2",
    "outputId": "845ef0c1-465c-4748-b5bb-6b9d332a6820"
   },
   "outputs": [],
   "source": [
    "# BONUS: save and load the CB model\n",
    "# save model: hand the target file name to the trained object's save method\n",
    "vw.save('cb.model')\n",
    "# remove the trained object from the notebook namespace, so the vw used below is the one rebuilt from the file\n",
    "del vw\n",
    "# load from saved file\n",
    "# NOTE(review): presumably -i makes vw read its parameters from cb.model - confirm against the vw command line reference\n",
    "vw = pyvw.vw(\"--cb 4 -i cb.model\")\n",
    "# predict for an example with the two features a and b and print what the reloaded model returns\n",
    "print(vw.predict('| a b'))"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "default_view": {},
   "name": "Implementing a Contextual Bandit Using VW's Python Wrapper.ipynb",
   "provenance": [
    {
     "file_id": "1Njy1txYPXqVwueHudbkF40zTbwkui1xA",
     "timestamp": 1519781379506
    },
    {
     "file_id": "11qz1CSi8-8yQACKJzp3G8VPUqzB3V1Hw",
     "timestamp": 1519780916480
    }
   ],
   "version": "0.3.2",
   "views": {}
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
