{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from vowpalwabbit import pyvw\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate count data with a Poisson distribution:\n",
    "#   z ~ Poisson(x + y), with integer x and y each drawn uniformly from [0, 10)\n",
    "# Seed the global NumPy RNG first so that Restart & Run All reproduces the same\n",
    "# samples (and therefore the same bias figures) on every run.\n",
    "np.random.seed(0)\n",
    "x = np.random.choice(range(0,10), 100)\n",
    "y = np.random.choice(range(0,10), 100)\n",
    "z = np.random.poisson(x + y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will model this data in two ways:\n",
    "* log transform the labels and use linear prediction (square loss)\n",
    "* model it directly using poisson loss\n",
    "\n",
    "The first model predicts mean(log(label)), while the second predicts log(mean(label)). Due to [Jensen's inequality](https://en.wikipedia.org/wiki/Jensen's_inequality), the first approach produces a systematic negative bias."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the log-transform model (squared loss on log-transformed labels)\n",
    "# The small 0.001 offset keeps log() finite when a count is zero.\n",
    "logz = np.log(0.001 + z)\n",
    "training_samples = [\n",
    "    \"{label} | x:{x} y:{y}\".format(label=target, x=xi, y=yi)\n",
    "    for target, xi, yi in zip(logz, x, y)\n",
    "]\n",
    "vw = pyvw.vw(\"-b 2 --loss_function squared -l 0.1 --holdout_off -f vw.log.model --readable_model vw.readable.log.model\")\n",
    "# Make one hundred passes over the data; the model is stored in vw.log.model\n",
    "for _ in range(100):\n",
    "    for sample in training_samples:\n",
    "        vw.learn(sample)\n",
    "vw.finish()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate predictions from the log-transform model (-t: test only, no learning)\n",
    "vw = pyvw.vw(\"-i vw.log.model -t\")\n",
    "# Collect the predictions in an array so the arithmetic below is explicit ndarray math\n",
    "log_predictions = np.array([vw.predict(sample) for sample in training_samples])\n",
    "vw.finish()  # release the prediction-only VW instance once it is no longer needed\n",
    "# Measure bias in the log domain and, after exponentiating, in the original domain\n",
    "log_bias = np.mean(log_predictions - logz)\n",
    "bias = np.mean(np.exp(log_predictions) - z)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Although the model is relatively unbiased in the log-domain where we trained our model, in the original domain there is underprediction, as we expected from Jensen's inequality."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the original-domain model using Poisson regression (raw counts as labels)\n",
    "training_samples = [\n",
    "    \"{label} | x:{x} y:{y}\".format(label=count, x=xi, y=yi)\n",
    "    for count, xi, yi in zip(z, x, y)\n",
    "]\n",
    "vw = pyvw.vw(\"-b 2 --loss_function poisson -l 0.1 --holdout_off -f vw.poisson.model --readable_model vw.readable.poisson.model\")\n",
    "# Make one hundred passes over the data; the model is stored in vw.poisson.model\n",
    "for _ in range(100):\n",
    "    for sample in training_samples:\n",
    "        vw.learn(sample)\n",
    "vw.finish()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate predictions from the poisson model\n",
    "# -t loads the model in test-only mode, matching the log-transform prediction cell\n",
    "vw = pyvw.vw(\"-i vw.poisson.model -t\")\n",
    "# The raw predictions are exponentiated here to obtain values in the original (count) domain\n",
    "poisson_predictions = np.array([np.exp(vw.predict(sample)) for sample in training_samples])\n",
    "vw.finish()  # release the prediction-only VW instance once it is no longer needed\n",
    "poisson_bias = np.mean(poisson_predictions - z)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the three bias measurements side by side.\n",
    "# Each panel is (title, labels for the x-axis, predictions for the y-axis).\n",
    "panels = [\n",
    "    ('Log-domain bias:%f'%(log_bias), logz, log_predictions),\n",
    "    ('Original-domain bias:%f'%(bias), z, np.exp(log_predictions)),\n",
    "    ('Poisson bias:%f'%(poisson_bias), z, poisson_predictions),\n",
    "]\n",
    "fig, axes = plt.subplots(1, 3, figsize=(18,6))\n",
    "for ax, (title, labels, predictions) in zip(axes, panels):\n",
    "    ax.plot(labels, predictions, '.')\n",
    "    ax.plot(labels, labels, 'r')  # red y = x reference: perfect predictions fall on this line\n",
    "    ax.set_title(title)\n",
    "    ax.set_xlabel('label')\n",
    "    ax.set_ylabel('prediction')\n",
    "plt.show()  # render the figure without echoing the last call's Text(...) repr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
