{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Slates\n",
    "\n",
    "## Combinatorial bandits\n",
    "A common setting when using bandits is the desire to make multiple decisions jointly as each individual decision could impact the outcome of the others. For example, when optimizing the layout of a web page, you have a list of context boxes to decide upon, one usually only has whole-page metrics so the impact of all decisions are all connected and cant be associated with any single decision.\n",
    "\n",
    "One way to address this is, instead of making multiple decisions for each box, is make single decision with a combinatorial number of arms to cover every combination of individual decisions. This combinatorial explosion in the number of arms make traditional Contextual Bandits methods such as IPS (Inverse Propensity Score) untractable and requiring a large amount of data.\n",
    "\n",
    "## The Slates problem setting\n",
    "The [Slates](http://arxiv.org/abs/1605.04812) setting is an extension to the Contextual Bandits one. With Slates, there's a set of slots (a slot is akin to a decision) and each slot has a set of associated actions. When making a decision, an action is selected for each slot based on the policy and which slot they are associated with.\n",
    "\n",
    "A single, global, reward is produced that signals the outcome of an entire slate of decisions. There's a linearity assumption on the impact of each action on the observed reward.\n",
    "\n",
    "## Vowpal Wabbit Slates example format\n",
    "\n",
    "```\n",
    "slates shared | a\n",
    "slates action 0 | a\n",
    "slates action 0 | b\n",
    "slates action 0 | c\n",
    "slates action 1 | d\n",
    "slates action 1 | e\n",
    "slates slot | x\n",
    "slates slot | y\n",
    "```\n",
    "\n",
    "`slates shared` is the shared context for all of the decisions. Often these are user related features. Each `slates action` needs to tell which slot they are associated with, slots are 0-indexed. Finally, `slates slot` is used to list the features of each slot.\n",
    "\n",
    "## Scenario: Outfit optimization\n",
    "\n",
    "In this example our goal is to select the best outfit for the occasion, in this case it is a job interview. The shared context represents the kind of job being interviewed for, in our example it's either `corporate` or `trade`. \n",
    "\n",
    "- Each slot represent the clothing type, here it will be `torso`, `legs` and `footwear`\n",
    "- The clothing items are represented by each action\n",
    "- The reward is how suitable an outfit choice is for a given ocasion\n",
    "\n",
    "We will compare using Slates with CB to see its effectiveness.\n",
    "\n",
    "![slates_scenario.png](./slates_scenario.png)\n",
    "\n",
    "### Part One - setup\n",
    "\n",
    "Let's start by importing the python libraries we'll use in this example:\n",
    "\t"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import vowpalwabbit\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import random\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, the simulator parameters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_ITERATIONS = 2500\n",
    "shared_contexts = [\"corporate\", \"trade\"]\n",
    "torso_items = [\"tshirt\", \"buttonupshirt\", \"highvis\"]\n",
    "legs_items = [\"workpants\", \"formalpants\", \"shorts\"]\n",
    "feet_items = [\"formalshoes\", \"runners\", \"flipflops\", \"boots\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Reward function\n",
    "Our reward function receives as arguments the shared context (the occasion) and the selected actions for each slot (the types of clothing). In order to reflect real world problems, and make it more dificult, we inject normal noise in the reward value."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def noise(center, stddev=0.075):\n",
    "    return np.random.normal(loc=center, scale=0.075)\n",
    "\n",
    "\n",
    "def reward_function(shared_context, torso_index, legs_index, feet_index):\n",
    "    if shared_context == \"corporate\":\n",
    "        torso_values = [noise(0.2), noise(0.3), noise(0.1)]\n",
    "        legs_val = [noise(0.1), noise(0.3), noise(0.2)]\n",
    "        feet_values = [noise(0.4), noise(0.3), noise(0.05), noise(0.1)]\n",
    "    if shared_context == \"trade\":\n",
    "        torso_values = [noise(0.15), noise(0.2), noise(0.3)]\n",
    "        legs_val = [noise(0.4), noise(0.2), noise(0.35)]\n",
    "        feet_values = [noise(0.15), noise(0.2), noise(0.1), noise(0.3)]\n",
    "\n",
    "    return torso_values[torso_index] + legs_val[legs_index] + feet_values[feet_index]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Slates\n",
    "The following piece of code generate examples that will be used for inference and training:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_slates_text_format(shared_context):\n",
    "    return [\n",
    "        f\"slates shared |User {shared_context}\",\n",
    "        \"slates action 0 |Action tshirt\",\n",
    "        \"slates action 0 |Action buttonupshirt\",\n",
    "        \"slates action 0 |Action highvis\",\n",
    "        \"slates action 1 |Action workpants\",\n",
    "        \"slates action 1 |Action formalpants\",\n",
    "        \"slates action 1 |Action shorts\",\n",
    "        \"slates action 2 |Action formalshoes\",\n",
    "        \"slates action 2 |Action runners\",\n",
    "        \"slates action 2 |Action flipflops\",\n",
    "        \"slates action 2 |Action boots\",\n",
    "        \"slates slot |Slot torso\",\n",
    "        \"slates slot |Slot legs\",\n",
    "        \"slates slot |Slot feet\",\n",
    "    ]\n",
    "\n",
    "\n",
    "def generate_slates_text_format_with_label(\n",
    "    shared_context,\n",
    "    reward,\n",
    "    chosen_torso_index,\n",
    "    chosen_torso_prob,\n",
    "    chosen_legs_index,\n",
    "    chosen_legs_prob,\n",
    "    chosen_feet_index,\n",
    "    chosen_feet_prob,\n",
    "):\n",
    "    return [\n",
    "        f\"slates shared {-1*reward} |User {shared_context}\",\n",
    "        \"slates action 0 |Action tshirt\",\n",
    "        \"slates action 0 |Action buttonupshirt\",\n",
    "        \"slates action 0 |Action highvis\",\n",
    "        \"slates action 1 |Action workpants\",\n",
    "        \"slates action 1 |Action formalpants\",\n",
    "        \"slates action 1 |Action shorts\",\n",
    "        \"slates action 2 |Action formalshoes\",\n",
    "        \"slates action 2 |Action runners\",\n",
    "        \"slates action 2 |Action flipflops\",\n",
    "        \"slates action 2 |Action boots\",\n",
    "        f\"slates slot {chosen_torso_index}:{chosen_torso_prob} |Slot torso\",\n",
    "        f\"slates slot {chosen_legs_index}:{chosen_legs_prob} |Slot legs\",\n",
    "        f\"slates slot {chosen_feet_index}:{chosen_feet_prob} |Slot feet\",\n",
    "    ]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Slates simulation \n",
    "Let's create the `Workspace` instance we'll use with this simulator "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "slates_vw = vowpalwabbit.Workspace(\n",
    "    \"--slates --epsilon 0.2 --interactions SA UAS US UA -l 0.05 --power_t 0\", quiet=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice we're passing `--slates` which enable usage of the Slates algorithm. For an explanation of the other arguments, go back to the [Command Line Tutorial](https://vowpalwabbit.org/docs/vowpal_wabbit/python/latest/tutorials/cmd_first_steps.html)]\n",
    "\n",
    "Now, the core simulation loop:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "slates_rewards = []\n",
    "for _ in range(NUM_ITERATIONS):\n",
    "    shared_context = random.choice(shared_contexts)\n",
    "    slates_prediction = slates_vw.predict(generate_slates_text_format(shared_context))\n",
    "    torso_index, torso_prob = slates_prediction[0][0]\n",
    "    legs_index, legs_prob = slates_prediction[1][0]\n",
    "    feet_index, feet_prob = slates_prediction[2][0]\n",
    "    reward = reward_function(shared_context, torso_index, legs_index, feet_index)\n",
    "    slates_rewards.append(reward)\n",
    "    slates_vw.learn(\n",
    "        generate_slates_text_format_with_label(\n",
    "            shared_context,\n",
    "            reward,\n",
    "            torso_index,\n",
    "            torso_prob,\n",
    "            legs_index,\n",
    "            legs_prob,\n",
    "            feet_index,\n",
    "            feet_prob,\n",
    "        )\n",
    "    )\n",
    "\n",
    "slates_vw.finish()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Contextual Bandit\n",
    "\n",
    "Let's try to solve the same problem by using CB with combinatorial actions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_combinations(shared_context, torso_items, legs_items, feet_items):\n",
    "    examples = [f\"shared |User {shared_context}\"]\n",
    "    descriptions = []\n",
    "    for i, torso in enumerate(torso_items):\n",
    "        for j, legs in enumerate(legs_items):\n",
    "            for k, feet in enumerate(feet_items):\n",
    "                examples.append(f\"|Action torso={torso} legs={legs} feet={feet}\")\n",
    "                descriptions.append((i, j, k))\n",
    "\n",
    "    return examples, descriptions\n",
    "\n",
    "\n",
    "def sample_custom_pmf(pmf):\n",
    "    total = sum(pmf)\n",
    "    scale = 1 / total\n",
    "    pmf = [x * scale for x in pmf]\n",
    "    draw = random.random()\n",
    "    sum_prob = 0.0\n",
    "    for index, prob in enumerate(pmf):\n",
    "        sum_prob += prob\n",
    "        if sum_prob > draw:\n",
    "            return index, prob\n",
    "\n",
    "\n",
    "cb_vw = vowpalwabbit.Workspace(\n",
    "    \"--cb_explore_adf --epsilon 0.2 --interactions AA AU AAU -l 0.05 --power_t 0\",\n",
    "    quiet=True,\n",
    ")\n",
    "\n",
    "cb_rewards = []\n",
    "for _ in range(NUM_ITERATIONS):\n",
    "    shared_context = random.choice(shared_contexts)\n",
    "    examples, indices = generate_combinations(\n",
    "        shared_context, torso_items, legs_items, feet_items\n",
    "    )\n",
    "    cb_prediction = cb_vw.predict(examples)\n",
    "    chosen_index, prob = sample_custom_pmf(cb_prediction)\n",
    "    torso_index, legs_index, feet_index = indices[chosen_index]\n",
    "    reward = reward_function(shared_context, torso_index, legs_index, feet_index)\n",
    "    cb_rewards.append(reward)\n",
    "    examples[chosen_index + 1] = f\"0:{-1*reward}:{prob} {examples[chosen_index + 1]}\"\n",
    "    cb_vw.learn(examples)\n",
    "\n",
    "cb_vw.finish()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Comparison\n",
    "\n",
    "Let's plot the two runs so we can compare the result:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(pd.Series(cb_rewards).expanding().mean())\n",
    "plt.plot(pd.Series(slates_rewards).expanding().mean())\n",
    "plt.xlabel(\"Iterations\")\n",
    "plt.ylabel(\"Average reward\")\n",
    "plt.legend([\"cb\", \"slates\"])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we can see, Slates learns a better policy with lower regret.\n",
    "\n",
    "## Summary\n",
    "\n",
    "In this tutorial, we learned what Slates is, how to use it with VW and how it compares to solving the same problem with Contextual Bandits.\n",
    "\n",
    "For further information, visit the following links:\n",
    "\n",
    "- [Slate input format](https://github.com/VowpalWabbit/vowpal_wabbit/wiki/Slates)\n",
    "- [Examples from the original publication](https://github.com/adith387/slates_semisynth_expts)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
