{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c06fbf85",
   "metadata": {},
   "source": [
    "# [Experimental] Testing Basic Models with Varying Epsilon Values and Model Counts for Non-Stationary Epsilon Decay"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfefd86f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import vowpalwabbit\n",
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import itertools\n",
    "import numpy as np\n",
    "from matplotlib.pyplot import figure\n",
    "\n",
    "# Simulated world: a context is a (user, time of day) pair and every action is\n",
    "# an article topic that can be recommended.\n",
    "users = [\"Tom\", \"Anna\"]\n",
    "times_of_day = [\"morning\", \"afternoon\"]\n",
    "actions = [\"politics\", \"sports\", \"music\", \"food\", \"finance\", \"health\", \"camping\"]\n",
    "\n",
    "# Costs returned by get_cost: a liked article costs -1.0 and any other article 0.0,\n",
    "# so the negated running average cost built in run_simulation reads as a click rate.\n",
    "USER_LIKED_ARTICLE = -1.0\n",
    "USER_DISLIKED_ARTICLE = 0.0\n",
    "\n",
    "\n",
    "def get_cost(context, action, switch):\n",
    "    if context[\"user\"] == \"Tom\":\n",
    "        if context[\"time_of_day\"] == \"morning\" and action == \"politics\" and not switch:\n",
    "            return USER_LIKED_ARTICLE\n",
    "        elif context[\"time_of_day\"] == \"afternoon\" and action == \"music\" and not switch:\n",
    "            return USER_LIKED_ARTICLE\n",
    "        if context[\"time_of_day\"] == \"morning\" and action == \"health\" and switch:\n",
    "            return USER_LIKED_ARTICLE\n",
    "        elif context[\"time_of_day\"] == \"afternoon\" and action == \"camping\" and switch:\n",
    "            return USER_LIKED_ARTICLE\n",
    "    elif context[\"user\"] == \"Anna\":\n",
    "        if context[\"time_of_day\"] == \"morning\" and action == \"sports\" and not switch:\n",
    "            return USER_LIKED_ARTICLE\n",
    "        elif (\n",
    "            context[\"time_of_day\"] == \"afternoon\"\n",
    "            and action == \"politics\"\n",
    "            and not switch\n",
    "        ):\n",
    "            return USER_LIKED_ARTICLE\n",
    "        if context[\"time_of_day\"] == \"morning\" and action == \"food\" and switch:\n",
    "            return USER_LIKED_ARTICLE\n",
    "        elif context[\"time_of_day\"] == \"afternoon\" and action == \"finance\" and switch:\n",
    "            return USER_LIKED_ARTICLE\n",
    "\n",
    "    return USER_DISLIKED_ARTICLE\n",
    "\n",
    "\n",
    "# This function modifies (context, action, cost, probability) to VW friendly format\n",
    "def to_vw_example_format(context, actions, cb_label=None):\n",
    "    if cb_label is not None:\n",
    "        chosen_action, cost, prob = cb_label\n",
    "    example_string = \"\"\n",
    "    example_string += \"shared |User user={} time_of_day={}\\n\".format(\n",
    "        context[\"user\"], context[\"time_of_day\"]\n",
    "    )\n",
    "    for action in actions:\n",
    "        if cb_label is not None and action == chosen_action:\n",
    "            example_string += \"0:{}:{} \".format(cost, prob)\n",
    "        example_string += \"|Action article={} \\n\".format(action)\n",
    "    # Strip the last newline\n",
    "    return example_string[:-1]\n",
    "\n",
    "\n",
    "def sample_custom_pmf(pmf):\n",
    "    total = sum(pmf)\n",
    "    scale = 1 / total\n",
    "    pmf = [x * scale for x in pmf]\n",
    "    draw = random.random()\n",
    "    sum_prob = 0.0\n",
    "    for index, prob in enumerate(pmf):\n",
    "        sum_prob += prob\n",
    "        if sum_prob > draw:\n",
    "            return index, prob\n",
    "\n",
    "\n",
    "def get_action(vw, context, actions):\n",
    "    \"\"\"Ask ``vw`` for a distribution over ``actions`` and sample one of them.\n",
    "\n",
    "    Returns the sampled action together with the probability returned by\n",
    "    ``sample_custom_pmf`` for it.\n",
    "    \"\"\"\n",
    "    action_pmf = vw.predict(to_vw_example_format(context, actions))\n",
    "    picked_index, picked_prob = sample_custom_pmf(action_pmf)\n",
    "    return actions[picked_index], picked_prob\n",
    "\n",
    "\n",
    "def choose_user(users):\n",
    "    return random.choice(users)\n",
    "\n",
    "\n",
    "def choose_time_of_day(times_of_day):\n",
    "    return random.choice(times_of_day)\n",
    "\n",
    "\n",
    "# display preference matrix\n",
    "def get_preference_matrix(cost_fun):\n",
    "    def expand_grid(data_dict):\n",
    "        rows = itertools.product(*data_dict.values())\n",
    "        return pd.DataFrame.from_records(rows, columns=data_dict.keys())\n",
    "\n",
    "    df = expand_grid({\"users\": users, \"times_of_day\": times_of_day, \"actions\": actions})\n",
    "    df[\"cost\"] = df.apply(\n",
    "        lambda r: cost_fun({\"user\": r[0], \"time_of_day\": r[1]}, r[2]), axis=1\n",
    "    )\n",
    "\n",
    "    return df.pivot_table(\n",
    "        index=[\"users\", \"times_of_day\"], columns=\"actions\", values=\"cost\"\n",
    "    )\n",
    "\n",
    "\n",
    "def run_simulation(\n",
    "    vw,\n",
    "    num_iterations,\n",
    "    users,\n",
    "    times_of_day,\n",
    "    actions,\n",
    "    cost_function,\n",
    "    switch_rewards,\n",
    "    do_learn=True,\n",
    "):\n",
    "    \"\"\"Run the recommendation loop and return the running click-through rate.\n",
    "\n",
    "    Every iteration samples a context, asks ``vw`` for an action, scores it with\n",
    "    ``cost_function(context, action, switched)`` -- where ``switched`` becomes\n",
    "    true once the 1-based iteration number exceeds ``switch_rewards`` -- and,\n",
    "    when ``do_learn`` is true, feeds the labelled example back to ``vw``.\n",
    "\n",
    "    The global ``random`` generator is re-seeded with 5 on every call.\n",
    "\n",
    "    Returns a list with one entry per iteration: the negated average cost so far.\n",
    "    \"\"\"\n",
    "    random.seed(5)\n",
    "    total_cost = 0.0\n",
    "    running_ctr = []\n",
    "\n",
    "    for step in range(1, num_iterations + 1):\n",
    "        # Sample the context: the user is drawn first, then the time of day\n",
    "        context = {\n",
    "            \"user\": choose_user(users),\n",
    "            \"time_of_day\": choose_time_of_day(times_of_day),\n",
    "        }\n",
    "\n",
    "        # Let the learner pick an article for this context\n",
    "        chosen_action, chosen_prob = get_action(vw, context, actions)\n",
    "\n",
    "        # Score the pick; preferences change once step passes switch_rewards\n",
    "        rewards_switched = step > switch_rewards\n",
    "        step_cost = cost_function(context, chosen_action, rewards_switched)\n",
    "        total_cost += step_cost\n",
    "\n",
    "        if do_learn:\n",
    "            # Report the observed (action, cost, probability) back to VW\n",
    "            labelled_example = to_vw_example_format(\n",
    "                context, actions, (chosen_action, step_cost, chosen_prob)\n",
    "            )\n",
    "            vw.learn(\n",
    "                vw.parse(labelled_example, vowpalwabbit.LabelType.CONTEXTUAL_BANDIT)\n",
    "            )\n",
    "\n",
    "        # Negated so the curve shows reward being maximised, not cost minimised\n",
    "        running_ctr.append(-1 * total_cost / step)\n",
    "\n",
    "    return running_ctr\n",
    "\n",
    "\n",
    "def plot_ctr(num_iterations, ctr):\n",
    "    \"\"\"Plot ``ctr`` against the 1-based iteration number on a 0-to-1 y axis.\"\"\"\n",
    "    iteration_axis = range(1, num_iterations + 1)\n",
    "    plt.plot(iteration_axis, ctr)\n",
    "    plt.xlabel(\"num_iterations\", fontsize=14)\n",
    "    plt.ylabel(\"ctr\", fontsize=14)\n",
    "    plt.ylim([0, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbea157d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare fixed-epsilon learners with epsilon-decay learners on a problem whose\n",
    "# rewards switch after `switch_rewards` iterations.\n",
    "num_iterations = 100000\n",
    "switch_rewards = 25000\n",
    "\n",
    "figure(figsize=(15, 9), dpi=80)\n",
    "iterations = range(1, num_iterations + 1)\n",
    "\n",
    "# Baselines: fixed exploration rates from 0 to 0.1 in steps of 0.01\n",
    "for ep in np.linspace(0, 0.1, 11):\n",
    "    vw = vowpalwabbit.Workspace(\"--cb_explore_adf -q UA --quiet --epsilon \" + str(ep))\n",
    "    ctr = run_simulation(\n",
    "        vw, num_iterations, users, times_of_day, actions, get_cost, switch_rewards\n",
    "    )\n",
    "    plt.plot(iterations, ctr, label=\"epsilon = {:.2f}\".format(ep))\n",
    "\n",
    "# Non-stationary epsilon-decay learners, one run per model count (dashed lines)\n",
    "for model_count in (1, 4, 32):\n",
    "    vw = vowpalwabbit.Workspace(\n",
    "        \"--cb_explore_adf -q UA --quiet --epsilon_decay --model_count {}\".format(\n",
    "            model_count\n",
    "        )\n",
    "    )\n",
    "    ctr = run_simulation(\n",
    "        vw, num_iterations, users, times_of_day, actions, get_cost, switch_rewards\n",
    "    )\n",
    "    plt.plot(\n",
    "        iterations,\n",
    "        ctr,\n",
    "        \"--\",\n",
    "        label=\"Non-Stationary {} Model{}\".format(\n",
    "            model_count, \"\" if model_count == 1 else \"s\"\n",
    "        ),\n",
    "    )\n",
    "\n",
    "plt.title(\"Accuracy of Models with Varying Epsilon Values\", fontsize=17)\n",
    "plt.xlabel(\"Number of Iterations\", fontsize=14)\n",
    "plt.ylabel(\"Click Through Rate\", fontsize=14)\n",
    "plt.ylim([0.8, 1])\n",
    "# Legend entries come from the label= given to each plot call\n",
    "plt.legend()\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ae05310",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Longer run comparing only the epsilon-decay learners across model counts.\n",
    "num_iterations = 1000000\n",
    "switch_rewards = 250000\n",
    "\n",
    "figure(figsize=(15, 9), dpi=80)\n",
    "iterations = range(1, num_iterations + 1)\n",
    "\n",
    "# Non-stationary epsilon-decay learners, one run per model count (dashed lines)\n",
    "for model_count in (1, 4, 32):\n",
    "    vw = vowpalwabbit.Workspace(\n",
    "        \"--cb_explore_adf -q UA --quiet --epsilon_decay --model_count {}\".format(\n",
    "            model_count\n",
    "        )\n",
    "    )\n",
    "    ctr = run_simulation(\n",
    "        vw, num_iterations, users, times_of_day, actions, get_cost, switch_rewards\n",
    "    )\n",
    "    plt.plot(\n",
    "        iterations,\n",
    "        ctr,\n",
    "        \"--\",\n",
    "        label=\"Non-Stationary {} Model{}\".format(\n",
    "            model_count, \"\" if model_count == 1 else \"s\"\n",
    "        ),\n",
    "    )\n",
    "\n",
    "plt.title(\"Accuracy of Models with Varying Non-Stationary Model Counts\", fontsize=17)\n",
    "plt.xlabel(\"Number of Iterations\", fontsize=14)\n",
    "plt.ylabel(\"Click Through Rate\", fontsize=14)\n",
    "plt.ylim([0.8, 1])\n",
    "# Legend entries come from the label= given to each plot call\n",
    "plt.legend()\n",
    "\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "mystnb": {
    "execution_mode": "off"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
