{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pandas DataFrame to VowpalWabbit Format Conversion\n",
    "\n",
    "\n",
    "```{admonition} Credit\n",
    "Thank you [@etiennekintzler](https://github.com/etiennekintzler) for contributing this tutorial.\n",
    "```\n",
    "\n",
    "The purpose of this tutorial is to show how to use the {py:class}`vowpalwabbit.dftovw.DFtoVW` class to convert a pandas DataFrame into a list of Vowpal Wabbit examples and to explore the outputs (model weights, VW output log) of the trained model. The VW output log is parsed using the class `VWLogParser` defined in this notebook.\n",
    "\n",
    "The task is to predict the concentration of [particulate matter](https://en.wikipedia.org/wiki/Particulates) (more specifically PM2.5) in the atmosphere of five Chinese cities. The original dataset contains 19 columns (targets, datetime and atmospheric features) and 167 358 observations.\n",
    "- For more details on the data, see the following UCI repository: https://archive.ics.uci.edu/ml/datasets/PM2.5+Data+of+Five+Chinese+Cities\n",
    "- For the associated academic papers, see [Liang, Xuan, et al. \"PM2.5 data reliability, consistency, and air quality assessment in five Chinese cities.\"](https://agupubs.onlinelibrary.wiley.com/doi/full/10.1002/2016JD024877) and [Liang, Xuan, et al. \"Assessing Beijing's PM2.5 pollution: severity, weather impact, APEC and winter heating.\"](https://royalsocietypublishing.org/doi/10.1098/rspa.2015.0257)\n",
    "\n",
    "The data can be downloaded from the following URL: https://archive.ics.uci.edu/ml/machine-learning-databases/00394/FiveCitiePMData.rar. The `download_data` function defined in this notebook downloads and extracts the data (this step can also be done manually). The folder containing the data is set by the constant `DATA_FOLDER` (`'PM_DATA'` by default).\n",
    "\n",
    "## Tutorial outline\n",
    "\n",
    "**1. Data**\n",
    "\n",
    "**2. Train a first model**\n",
    "\n",
    "**3. Visualizing model's outputs**\n",
    "\n",
    "**4. Train a more complex model**\n",
    "\n",
    "\n",
    "## Requirements\n",
    "\n",
    "The notebook was developed for VW 8.11.0.\n",
    "\n",
    "It should work with older versions (>= 8.10) except for one cell in section 4.3 where the attribute `name` of `Feature` is accessed.\n",
    "\n",
    "---\n",
    "\n",
    "\n",
    "## Importing packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from os.path import join\n",
    "import re\n",
    "\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "\n",
    "from vowpalwabbit.dftovw import DFtoVW\n",
    "from vowpalwabbit import Workspace"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Function and class definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# VW output parsing function/class\n",
    "class VWLogParser:\n",
    "    \"\"\"Parser for Vowpal Wabbit output log\"\"\"\n",
    "\n",
    "    def __init__(self, file_path_or_list):\n",
    "        \"\"\"The file name or list of lines to parse\"\"\"\n",
    "        if isinstance(file_path_or_list, (list, str)):\n",
    "            self.file_path_or_list = file_path_or_list\n",
    "        else:\n",
    "            raise TypeError(\n",
    "                \"Argument `file_path_or_list` should be a str (for file path) or a list of log lines\"\n",
    "            )\n",
    "\n",
    "    def parse(self):\n",
    "        \"\"\"Parse the output of the `vw` command and return the parameters (dict), the per-example table (DataFrame) and the summary metrics (dict).\"\"\"\n",
    "        # Init containers\n",
    "        self.table_lst = []\n",
    "        self.params = {}\n",
    "        self.metrics = {}\n",
    "\n",
    "        self.inside_table = False\n",
    "        self.after_table = False\n",
    "\n",
    "        if isinstance(self.file_path_or_list, list):\n",
    "            for row in self.file_path_or_list:\n",
    "                self._parse_vw_row(row)\n",
    "        else:\n",
    "            with open(self.file_path_or_list, \"r\") as f:\n",
    "                for row in f:\n",
    "                    self._parse_vw_row(row)\n",
    "\n",
    "        self.df = self._make_output_df(self.table_lst)\n",
    "\n",
    "        return self.params, self.df, self.metrics\n",
    "\n",
    "    def _cast_string(self, s):\n",
    "        \"\"\"Cast to float or int if possible\"\"\"\n",
    "        try:\n",
    "            out = float(s)\n",
    "        except ValueError:\n",
    "            out = s\n",
    "        else:\n",
    "            if out.is_integer():\n",
    "                out = int(out)\n",
    "\n",
    "        return out\n",
    "\n",
    "    def _make_output_df(self, lst):\n",
    "        \"\"\"Make dataframe from the list\"\"\"\n",
    "        # Make columns from first and second elements of the list\n",
    "        columns = [\n",
    "            f\"{first_row}_{second_row}\" for (first_row, second_row) in zip(*lst[:2])\n",
    "        ]\n",
    "\n",
    "        df = pd.DataFrame(data=lst[2:], columns=columns)\n",
    "\n",
    "        # Cast cols to appropriate types\n",
    "        int_cols = [\"example_counter\", \"current_features\"]\n",
    "        for col in int_cols:\n",
    "            df[col] = df[col].astype(int)\n",
    "\n",
    "        float_cols = df.columns.drop(int_cols)\n",
    "        for col in float_cols:\n",
    "            df[col] = df[col].astype(float)\n",
    "\n",
    "        return df\n",
    "\n",
    "    def _parse_vw_row(self, row):\n",
    "        \"\"\"Parse row and add parsed elements to instance attributes params, metrics and table_lst\"\"\"\n",
    "        if \"=\" in row:\n",
    "            param_name, value = [\n",
    "                element.strip() for element in row.split(\"=\", maxsplit=1)\n",
    "            ]\n",
    "            if self.after_table:\n",
    "                self.metrics[param_name] = self._cast_string(value)\n",
    "            else:\n",
    "                self.params[param_name] = self._cast_string(value)\n",
    "        elif \":\" in row:\n",
    "            param_name, value = [\n",
    "                element.strip() for element in row.split(\":\", maxsplit=1)\n",
    "            ]\n",
    "            self.params[param_name] = self._cast_string(value)\n",
    "\n",
    "        elif not self.after_table:\n",
    "            if re.match(r\"average\\s+since\", row):\n",
    "                self.inside_table = True\n",
    "            if row == \"\\n\":\n",
    "                self.inside_table = False\n",
    "                self.after_table = True\n",
    "            if self.inside_table:\n",
    "                self.table_lst += [row.split()]\n",
    "\n",
    "\n",
    "# Data import/download functions\n",
    "def download_data(dest_dir=\"PM_DATA\"):\n",
    "    import requests\n",
    "    from io import BytesIO\n",
    "    from rarfile import RarFile\n",
    "\n",
    "    URL_PM_CITIES = \"https://archive.ics.uci.edu/ml/machine-learning-databases/00394/FiveCitiePMData.rar\"\n",
    "\n",
    "    print(f\"Downloading data at {URL_PM_CITIES}\")\n",
    "    r = requests.get(URL_PM_CITIES)\n",
    "    bcontent = BytesIO(r.content)\n",
    "    rf = RarFile(bcontent)\n",
    "\n",
    "    print(f\"Extracting content in folder {repr(dest_dir)}\")\n",
    "    rf.extractall(dest_dir)\n",
    "\n",
    "\n",
    "def import_data(folder_path, verbose=True):\n",
    "    df_lst = []\n",
    "    for fname in os.listdir(folder_path):\n",
    "        fpath = join(folder_path, fname)\n",
    "        if verbose:\n",
    "            print(f\"Importing file: {fpath}\")\n",
    "        city_name = re.sub(\n",
    "            \"pm$\", repl=\"\", string=re.search(\"^[a-z]+\", string=fname.lower()).group()\n",
    "        )\n",
    "\n",
    "        df_city = pd.read_csv(fpath)\n",
    "        df_city_clean = (\n",
    "            df_city.assign(city=city_name)\n",
    "            .drop(\n",
    "                columns=[\"No\"]\n",
    "                + [\n",
    "                    col\n",
    "                    for col in df_city.columns\n",
    "                    if (\"PM\" in col) and (col != \"PM_US Post\")\n",
    "                ]\n",
    "            )\n",
    "            .rename(columns={\"PM_US Post\": \"PM\"})\n",
    "            .dropna(subset=[\"PM\"])\n",
    "        )\n",
    "        df_lst += [df_city_clean]\n",
    "\n",
    "    df_city = (\n",
    "        pd.concat(df_lst)  # append dataframes\n",
    "        .sample(frac=1, random_state=123)  # shuffle\n",
    "        .reset_index(drop=True)\n",
    "    )\n",
    "\n",
    "    return df_city\n",
    "\n",
    "\n",
    "# Model weight inspection functions\n",
    "def get_feature_names(df):\n",
    "    cat_names = get_cat_feature_names(df)\n",
    "    num_names = df.select_dtypes(np.number).columns.tolist()\n",
    "\n",
    "    return cat_names + num_names\n",
    "\n",
    "\n",
    "def get_cat_feature_names(df):\n",
    "    unique_values_cat = df.select_dtypes(object).apply(lambda s: s.dropna().unique())\n",
    "    cat_names = [\n",
    "        f\"{key}={value}\"\n",
    "        for (key, unique_values) in unique_values_cat.items()\n",
    "        for value in unique_values\n",
    "    ]\n",
    "\n",
    "    return cat_names\n",
    "\n",
    "\n",
    "def get_weight_from_name(model, feature_name, namespace_name=\" \"):\n",
    "    space_hash = model.hash_space(namespace_name)\n",
    "    feat_hash = model.hash_feature(feature_name, space_hash)\n",
    "    return model.get_weight(feat_hash)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Constants"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Data\n",
    "DATA_FOLDER = \"PM_DATA\"\n",
    "\n",
    "# Graphical\n",
    "SUPTITLE_FONTSIZE = 20\n",
    "SUPTITLE_FONTWEIGHT = \"bold\"\n",
    "TITLE_FONTSIZE = 15"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Data\n",
    "\n",
    "### 1.1. Import"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not (os.path.isdir(DATA_FOLDER) and len(os.listdir(DATA_FOLDER)) == 5):\n",
    "    download_data(DATA_FOLDER)\n",
    "\n",
    "df = import_data(DATA_FOLDER)\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<br>\n",
    "Full labels of the less obvious features:\n",
    "\n",
    "- PM: PM2.5 concentration (ug/m^3)\n",
    "- DEWP: Dew point (degrees Celsius)\n",
    "- TEMP: Temperature (degrees Celsius)\n",
    "- HUMI: Humidity (%)\n",
    "- PRES: Pressure (hPa)\n",
    "- cbwd: Combined wind direction\n",
    "- Iws: Cumulated wind speed (m/s)\n",
    "- precipitation: hourly precipitation (mm)\n",
    "- Iprec: Cumulated precipitation (mm)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.2. Types"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The types of the columns:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df.dtypes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Some columns (`year`, `month`, `day`, `hour`, `season`) have been imported as integer/float but **should be treated as categorical** by the model. Hence, we convert these columns to categorical (`str`) type."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "to_cat_cols = [\"year\", \"month\", \"day\", \"hour\", \"season\"]\n",
    "\n",
    "for col in to_cat_cols:\n",
    "    df[col] = df[col].astype(str)\n",
    "\n",
    "# Pandas converts np.nan to the string \"nan\" when casting from float to object/str, so restore the missing values\n",
    "df[df == \"nan\"] = np.nan"
   ]
  },
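  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `\"nan\"` clean-up above can be illustrated on a toy series. This is a minimal sketch, not part of the tutorial pipeline: the series `s` and its values are made up, and masking with `where` is one possible alternative to replacing the `\"nan\"` strings afterwards.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# Hypothetical toy column with one missing value\n",
    "s = pd.Series([2010.0, np.nan, 2011.0])\n",
    "\n",
    "# Cast to str, then put back a missing value wherever the input was missing\n",
    "s_cat = s.astype(str).where(s.notna())\n",
    "\n",
    "print(s_cat.isna().tolist())\n",
    "```\n",
    "\n",
    "Only the second entry is flagged as missing, so the categorical column keeps its missing values instead of gaining a spurious `nan` category."
   ]
  },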
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Also, we standardize the numerical variables so we can compare their relative importance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for col in df.select_dtypes(np.number).columns.difference([\"PM\", \"log_PM\"]):\n",
    "    df[col] = (df[col] - df[col].mean()) / df[col].std()"
   ]
  },
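  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the standardization step concrete, here is a small sketch on a made-up frame (`toy` and its values are hypothetical). It applies the same formula as the cell above and leaves the target column untouched.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Hypothetical toy data: one numerical feature and the target\n",
    "toy = pd.DataFrame({'TEMP': [10.0, 20.0, 30.0], 'PM': [50.0, 80.0, 120.0]})\n",
    "\n",
    "# Standardize every numerical column except the target\n",
    "for col in toy.columns.difference(['PM']):\n",
    "    toy[col] = (toy[col] - toy[col].mean()) / toy[col].std()\n",
    "\n",
    "print(toy['TEMP'].tolist())\n",
    "```\n",
    "\n",
    "With a mean of 20 and a sample standard deviation of 10, the standardized `TEMP` column becomes `[-1.0, 0.0, 1.0]`, while `PM` keeps its original scale."
   ]
  },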
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.3. Highly correlated features"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Pairs of variables with a correlation above 0.95:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "(\n",
    "    df.select_dtypes(np.number).corr()\n",
    "    .reset_index()\n",
    "    .melt(id_vars=\"index\", value_name=\"corr\")\n",
    "    .loc[\n",
    "        lambda df: df[\"index\"] < df[\"variable\"]\n",
    "    ]  # to get lower triangular part of the matrix\n",
    "    .loc[lambda df: df[\"corr\"] > 0.95]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We drop the `Iprec` variable since it is almost perfectly correlated with the `precipitation` variable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df.drop(columns=[\"Iprec\"], inplace=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## 2. Train a first model\n",
    "\n",
    "### 2.1. Converting DataFrame to Vowpal Wabbit input format\n",
    "\n",
    "We now use the `DFtoVW` class to convert this DataFrame to VW input format.\n",
    "\n",
    "There are 2 ways to use the class `DFtoVW`:\n",
    "- **Basic usage**, using the `DFtoVW.from_column_names` class method.\n",
    "- **Advanced usage**, which relies on the VW input format specification (see [Input format section of wiki](https://github.com/VowpalWabbit/vowpal_wabbit/wiki/Input-format)). It is built upon classes such as `Feature`, `Namespace`, `SimpleLabel`, `MulticlassLabel`, etc.\n",
    "\n",
    "The current section illustrates the basic usage. Section 4 will present the advanced usage."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y = \"PM\"\n",
    "x = [\n",
    "    \"year\",\n",
    "    \"month\",\n",
    "    \"day\",\n",
    "    \"hour\",\n",
    "    \"season\",\n",
    "    \"DEWP\",\n",
    "    \"HUMI\",\n",
    "    \"PRES\",\n",
    "    \"TEMP\",\n",
    "    \"cbwd\",\n",
    "    \"Iws\",\n",
    "    \"precipitation\",\n",
    "    \"city\",\n",
    "]\n",
    "\n",
    "print(\"label:\", y)\n",
    "print(\"features:\", x)\n",
    "\n",
    "converter = DFtoVW.from_column_names(df=df, y=y, x=x)\n",
    "examples = converter.convert_df()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can inspect the first few examples:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "examples[:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<br>\n",
    "\n",
    "For categorical features, the VW format is `feature_name=feature_value`, whereas for numerical features the format is `feature_name:feature_value`. One nice property of the class is that it builds the appropriate VW type (numerical or categorical) based on the types of the dataframe's columns.\n",
    "\n",
    "Also note that:\n",
    "- For categorical variables, VW adds `:1` behind the scenes. For instance `day=14` is equivalent to `day=14:1`\n",
    "- The `=` doesn't have any special meaning and another symbol could have been used. However it's quite standard to use `=`\n",
    "\n",
    "Finally, if a feature name provided by the user is not found in the dataframe, the class will raise a `ValueError`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    DFtoVW.from_column_names(df=df, y=y, x=[\"TEMP\", \"COLUMN_NOT_IN_DF\"])\n",
    "except Exception as e:\n",
    "    print(type(e))\n",
    "    print(e)"
   ]
  },
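  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `name=value` / `name:value` convention described above can be sketched in a few lines of plain Python. The helper `to_vw_line` below is hypothetical and only illustrates the text format; it is not how `DFtoVW` is implemented, and the record values are made up.\n",
    "\n",
    "```python\n",
    "def to_vw_line(label, row):\n",
    "    # Illustrative helper: format one record as a VW text line\n",
    "    tokens = []\n",
    "    for name, value in row.items():\n",
    "        if isinstance(value, str):\n",
    "            tokens.append(f'{name}={value}')  # categorical feature\n",
    "        else:\n",
    "            tokens.append(f'{name}:{value}')  # numerical feature\n",
    "    return f'{label} | ' + ' '.join(tokens)\n",
    "\n",
    "\n",
    "line = to_vw_line(31.0, {'day': '14', 'TEMP': -0.5, 'city': 'beijing'})\n",
    "print(line)\n",
    "```\n",
    "\n",
    "The printed line is `31.0 | day=14 TEMP:-0.5 city=beijing`: string columns become `name=value` tokens and numerical columns become `name:value` tokens."
   ]
  },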
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.2. Define and train model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now define the VW model. Note that we enable logging and also set the progress parameter (`P`) to 1 to log the information _for each example_."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Workspace(P=1, enable_logging=True)\n",
    "\n",
    "for ex in examples:\n",
    "    model.learn(ex)\n",
    "\n",
    "model.finish()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## 3. Visualizing model's outputs\n",
    "\n",
    "### 3.1. Retrieving model's parameters, losses/predictions and summary metrics from the log\n",
    "\n",
    "Since we enabled logging in the model definition (subsection 2.2), we can get the model's log. The log is returned as a list of strings by the model's `get_log` method. Below are the first 20 lines:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.get_log()[:20]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<br>\n",
    "And the last 10 lines:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.get_log()[-10:]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<br>\n",
    "\n",
    "The class `VWLogParser` can be used to parse this log. It will return the following objects:\n",
    "- the initial parameters (beginning of the log)\n",
    "- the information available for each example/iteration (middle of the log)\n",
    "- the summary metrics (end of the log)\n",
    "\n",
    "The parsed information is available as `dict` or `DataFrame` objects that can be easily manipulated."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "log_parser = VWLogParser(model.get_log())\n",
    "params, df_iter, summary_metrics = log_parser.parse()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Model's parameters\")\n",
    "display(params)\n",
    "print(\"\\n\")\n",
    "\n",
    "print(\"Information at each iteration\")\n",
    "display(df_iter)\n",
    "print(\"\\n\")\n",
    "\n",
    "print(\"Summary metrics\")\n",
    "display(summary_metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.2. Visualizing the average loss and the distribution of selected metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following plots represent the average loss through time and the instantaneous loss:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scatter_var = [\"average_loss\", \"since_last\"]\n",
    "g = sns.relplot(\n",
    "    data=df_iter[scatter_var + [\"example_counter\"]].melt(id_vars=\"example_counter\"),\n",
    "    x=\"example_counter\",\n",
    "    y=\"value\",\n",
    "    col=\"variable\",\n",
    "    col_wrap=3,\n",
    "    facet_kws={\"sharey\": False, \"sharex\": True},\n",
    "    kind=\"scatter\",\n",
    "    s=4,\n",
    "    height=6,\n",
    "    aspect=1.5,\n",
    "    alpha=0.5,\n",
    ")\n",
    "g.fig.suptitle(\n",
    "    \"Scatter plot of losses\", fontsize=SUPTITLE_FONTSIZE, fontweight=SUPTITLE_FONTWEIGHT\n",
    ")\n",
    "g.set_titles(\"{col_name}\", size=TITLE_FONTSIZE)\n",
    "g.fig.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In what follows, we consider the metrics recorded after the 50 000th iteration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "start_idx = 50_000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "distr_vars = [\"current_label\", \"current_predict\", \"current_features\", \"example_weight\"]\n",
    "\n",
    "g = sns.displot(\n",
    "    data=df_iter.loc[start_idx:, distr_vars].melt(),\n",
    "    x=\"value\",\n",
    "    col=\"variable\",\n",
    "    multiple=\"dodge\",\n",
    "    hue=\"variable\",\n",
    "    bins=60,\n",
    "    common_bins=False,\n",
    "    facet_kws=dict(sharex=False, sharey=False),\n",
    "    col_wrap=4,\n",
    "    height=5,\n",
    ")\n",
    "g.fig.suptitle(\n",
    "    \"Distribution of selected metrics\",\n",
    "    fontsize=SUPTITLE_FONTSIZE,\n",
    "    fontweight=SUPTITLE_FONTWEIGHT,\n",
    "    y=1.05,\n",
    ")\n",
    "g.set_titles(\"{col_name}\", size=TITLE_FONTSIZE)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We notice that the distribution of the predictions differs substantially from that of the labels.\n",
    "\n",
    "### 3.3. Visualizing the predictions of the model\n",
    "\n",
    "This section offers a visualization of the model's predictions and compares them with the labels (the truth). "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "error = df_iter.current_label - df_iter.current_predict\n",
    "\n",
    "f, (ax1, ax2) = plt.subplots(figsize=(20, 7), ncols=2)\n",
    "f.suptitle(\n",
    "    \"Predictions and errors\", fontsize=SUPTITLE_FONTSIZE, fontweight=SUPTITLE_FONTWEIGHT\n",
    ")\n",
    "\n",
    "# Scatterplot pred vs truth\n",
    "sns.scatterplot(\n",
    "    data=df_iter.loc[start_idx:],\n",
    "    x=\"current_predict\",\n",
    "    y=\"current_label\",\n",
    "    ax=ax1,\n",
    "    s=2,\n",
    "    alpha=0.15,\n",
    ")\n",
    "ax1.set_title(\"Prediction vs truth\", fontsize=TITLE_FONTSIZE)\n",
    "max_range = int(df_iter[[\"current_label\", \"current_predict\"]].quantile(0.99).max())\n",
    "ax1.set_xlim([0, max_range])\n",
    "ax1.set_ylim([0, max_range])\n",
    "\n",
    "# Adding x=y line\n",
    "# range_x = range(0, int(df_iter[[\"current_predict\", \"current_label\"]].max().min()))\n",
    "range_x = range(0, max_range)\n",
    "ax1.plot(range_x, range_x, linestyle=\":\", color=\"red\", linewidth=2.5)\n",
    "\n",
    "# Histogram of errors\n",
    "sns.histplot(error[start_idx:], ax=ax2)\n",
    "ax2.set_title(\"Distribution of errors\", fontsize=TITLE_FONTSIZE)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "The model tends to underestimate the concentration of PM. Another way to see it is that the distribution of errors (difference between the label and the prediction) is right-skewed."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.4. Visualizing learnt weights"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We build a dataframe with the model's weights:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get VW feature names\n",
    "feature_names = get_feature_names(df)\n",
    "\n",
    "# Get weights from feature names\n",
    "weights_df = pd.DataFrame(\n",
    "    [(name, get_weight_from_name(model, name), \"=\" in name) for name in feature_names],\n",
    "    columns=[\"vw_feature_name\", \"weight\", \"is_cat\"],\n",
    ")\n",
    "\n",
    "# Adding columns for easier visualization\n",
    "weights_df[\"feature_name\"] = weights_df.apply(\n",
    "    lambda row: (\n",
    "        row.vw_feature_name.split(\"=\")[0] if row.is_cat else row.vw_feature_name\n",
    "    ),\n",
    "    axis=1,\n",
    ")\n",
    "weights_df[\"feature_value\"] = weights_df.apply(\n",
    "    lambda row: (\n",
    "        row.vw_feature_name.split(\"=\")[1].zfill(2)\n",
    "        if row.is_cat\n",
    "        else row.vw_feature_name\n",
    "    ),\n",
    "    axis=1,\n",
    ")\n",
    "weights_df.sort_values([\"feature_name\", \"feature_value\"], inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "g = sns.catplot(\n",
    "    data=weights_df[lambda df: df.is_cat],\n",
    "    kind=\"bar\",\n",
    "    x=\"feature_value\",\n",
    "    y=\"weight\",\n",
    "    col=\"feature_name\",\n",
    "    hue=\"feature_name\",\n",
    "    col_wrap=3,\n",
    "    sharex=False,\n",
    "    sharey=False,\n",
    "    aspect=1.5,\n",
    "    dodge=False,\n",
    ")\n",
    "g.fig.suptitle(\n",
    "    \"Feature weights (categorical features)\",\n",
    "    fontsize=SUPTITLE_FONTSIZE,\n",
    "    fontweight=SUPTITLE_FONTWEIGHT,\n",
    ")\n",
    "g.set_titles(\"{col_name}\", size=TITLE_FONTSIZE)\n",
    "\n",
    "# Add horizontal bar at y=0\n",
    "for ax in g.axes.flat:\n",
    "    ax.axhline(0, color=\"gray\", linestyle=\":\")\n",
    "g.fig.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Based on the weights learnt by the model, the predicted value for PM will be higher:\n",
    "\n",
    "- for wind orientation \"calm and variable\" (CV)\n",
    "- for cities such as Beijing or Chengdu\n",
    "- for winter season/months\n",
    "- for evening hours\n",
    "- for certain days of the month such as 21-23 (oddly)\n",
    "\n",
    "The predicted value will be lower: \n",
    "\n",
    "- for winds from the north\n",
    "- for year 2015\n",
    "- for hours around noon\n",
    "- for certain cities such as Guangzhou and Shanghai."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f, ax = plt.subplots(figsize=(12, 5))\n",
    "ax = sns.barplot(\n",
    "    data=weights_df[lambda df: ~df.is_cat], x=\"feature_name\", y=\"weight\"\n",
    ")\n",
    "ax.set_title(\"Feature weights (numerical features)\", fontsize=TITLE_FONTSIZE)\n",
    "\n",
    "# Set xlabels in bold, remove x-axis title\n",
    "ax.set_xticklabels(ax.get_xticklabels(), fontweight=SUPTITLE_FONTWEIGHT)\n",
    "ax.set(xlabel=None)\n",
    "\n",
    "# Add horizontal line\n",
    "ax.axhline(0, linestyle=\":\", color=\"gray\")\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Higher cumulated wind speed (Iws) and higher temperature (TEMP) are associated with lower predicted values. \n",
    "\n",
    "Higher air pressure (PRES) is associated with higher predicted values."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## 4. Train a more complex model: using log transformed target and namespace interactions\n",
    "\n",
    "This section illustrates the **advanced usage** of the `DFtoVW` class. To do so, we need to import some additional classes from the `vowpalwabbit.dftovw` module."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from vowpalwabbit.dftovw import SimpleLabel, Namespace, Feature"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following drawing explains **how to use these classes**:\n",
    "\n",
    "![DFtoVW_usage.png](./DFtoVW_usage.png)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.1. Applying logarithm transformation to the target\n",
    "\n",
    "The distribution of the target is close to a log-normal distribution:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f, (ax1, ax2) = plt.subplots(figsize=(12, 5), ncols=2)\n",
    "f.suptitle(\n",
    "    \"Distribution of the target (PM)\",\n",
    "    fontsize=SUPTITLE_FONTSIZE,\n",
    "    fontweight=SUPTITLE_FONTWEIGHT,\n",
    ")\n",
    "\n",
    "df[\"PM\"].hist(bins=40, ax=ax1)\n",
    "ax1.set_title(\"No transformation\", fontsize=TITLE_FONTSIZE)\n",
    "\n",
    "np.log(df[\"PM\"]).hist(bins=40, ax=ax2)\n",
    "ax2.set_title(\"Log-transformed\", fontsize=TITLE_FONTSIZE)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We decide to train the model on the log-transformed version of the target (called `log_PM`).\n",
    "\n",
    "For a regression task, we use the `SimpleLabel` class (more details on this type in the [Input Format section of the wiki](https://github.com/VowpalWabbit/vowpal_wabbit/wiki/Input-format#simple)) to represent the target."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df[\"log_PM\"] = df[\"PM\"].apply(np.log)\n",
    "label = SimpleLabel(\"log_PM\")"
   ]
  },
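  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of the transformation, the sketch below applies the logarithm to a few made-up concentrations (the array `pm` is hypothetical) and maps them back with the exponential, which is the inverse used when comparing predictions on the original PM scale.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical PM concentrations (strictly positive, as the log requires)\n",
    "pm = np.array([10.0, 100.0, 400.0])\n",
    "\n",
    "log_pm = np.log(pm)  # scale on which the model is trained\n",
    "pm_back = np.exp(log_pm)  # back to the original scale\n",
    "\n",
    "print(np.allclose(pm_back, pm))\n",
    "```\n",
    "\n",
    "The round trip recovers the original values up to floating-point error."
   ]
  },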
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.2. Defining namespaces\n",
    "\n",
    "Namespaces are defined using the `Namespace` class. A namespace is built from a single `Feature` or a list of `Feature` objects and can optionally be given a name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Datetime namespace\n",
    "datetime_features = [\"year\", \"month\", \"day\", \"hour\", \"season\"]\n",
    "ns_datetime = Namespace(\n",
    "    features=[Feature(col) for col in datetime_features], name=\"datetime_ns\"\n",
    ")\n",
    "\n",
    "# City namespace\n",
    "ns_city = Namespace(features=Feature(\"city\"), name=\"city_ns\")\n",
    "\n",
    "# Weather namespace\n",
    "weather_features = [\"DEWP\", \"HUMI\", \"PRES\", \"TEMP\", \"cbwd\", \"Iws\", \"precipitation\"]\n",
    "ns_weather = Namespace(\n",
    "    features=[Feature(col) for col in weather_features], name=\"weather_ns\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.3. Converting to VW format and training model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "converter_advanced = DFtoVW(\n",
    "    df=df, namespaces=[ns_datetime, ns_weather, ns_city], label=label\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`Namespace` and `Feature` objects can be accessed using `DFtoVW`'s instance attributes `namespaces` and `features`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "for namespace in converter_advanced.namespaces:\n",
    "    print(\"namespace:\", namespace.name)\n",
    "    for feature in namespace.features:\n",
    "        print(\"\\tfeature:\", feature.name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can inspect the first few examples:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "examples_advanced = converter_advanced.convert_df()\n",
    "examples_advanced[:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this new model, we add interactions between the weather namespace (redefined as `W`) and all namespaces. In `--interactions W:`, the `:` acts as a wildcard that matches any namespace."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_advanced = Workspace(\n",
    "    arg_str=\"--redefine W:=weather_ns --interactions W:\", P=1, enable_logging=True\n",
    ")\n",
    "\n",
    "for ex in examples_advanced:\n",
    "    model_advanced.learn(ex)\n",
    "\n",
    "model_advanced.finish()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.4. Visualizing model's outputs\n",
    "\n",
    "We transform the labels and predictions using the exponential function (since the target is log-transformed) so that they can be compared with those of section 3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "params_advanced, df_iter_advanced, metrics_advanced = VWLogParser(\n",
    "    model_advanced.get_log()\n",
    ").parse()\n",
    "\n",
    "df_iter_advanced.current_label = np.exp(df_iter_advanced.current_label)\n",
    "df_iter_advanced.current_predict = np.exp(df_iter_advanced.current_predict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "error_advanced = df_iter_advanced.current_label - df_iter_advanced.current_predict\n",
    "\n",
    "f, ((ax1, ax2), (ax3, ax4)) = plt.subplots(figsize=(20, 12), ncols=2, nrows=2)\n",
    "f.suptitle(\n",
    "    \"Predictions, labels and errors\",\n",
    "    fontsize=SUPTITLE_FONTSIZE,\n",
    "    fontweight=SUPTITLE_FONTWEIGHT,\n",
    ")\n",
    "\n",
    "sns.histplot(df_iter_advanced.current_label.iloc[start_idx:], ax=ax1, color=\"blue\")\n",
    "ax1.set_title(\"Distribution of current_label\", fontsize=TITLE_FONTSIZE)\n",
    "\n",
    "sns.histplot(df_iter_advanced.current_predict.iloc[start_idx:], ax=ax2, color=\"orange\")\n",
    "ax2.set_title(\"Distribution of current_predict\", fontsize=TITLE_FONTSIZE)\n",
    "\n",
    "sns.scatterplot(\n",
    "    data=df_iter_advanced.iloc[start_idx:],\n",
    "    x=\"current_predict\",\n",
    "    y=\"current_label\",\n",
    "    ax=ax3,\n",
    "    s=2,\n",
    "    alpha=0.15,\n",
    ")\n",
    "ax3.set_title(\"Prediction vs truth\", fontsize=TITLE_FONTSIZE)\n",
    "\n",
    "# Add the y = x reference line, with axis limits based on this model's own labels and predictions\n",
    "max_range = int(df_iter_advanced[[\"current_label\", \"current_predict\"]].quantile(0.99).max())\n",
    "ax3.set_xlim(0, max_range)\n",
    "ax3.set_ylim(0, max_range)\n",
    "range_x = range(0, max_range)\n",
    "ax3.plot(range_x, range_x, linestyle=\":\", color=\"red\", linewidth=2.5)\n",
    "\n",
    "sns.histplot(error_advanced, ax=ax4)\n",
    "ax4.set_title(\"Distribution of errors\", fontsize=TITLE_FONTSIZE)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<br>\n",
    "\n",
    "In this new model, the distribution of the predictions is more in line with the distribution of the labels.\n",
    "\n",
    "The errors of this model are also closer to a normal distribution, even though the model still underestimates some observations.\n",
    "\n",
    "### 4.5. Comparing models' performance\n",
    "\n",
    "The model in section 2 is called \"simple\" and the current model (with interactions and log-transformed target) is called \"advanced\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_type = 1  # 1 for L1-loss, 2 for L2-loss\n",
    "n_iter = 10_000  # Number of iterations on which to compute the loss (for moving average and final value)\n",
    "\n",
    "df_loss = pd.concat([error, error_advanced], axis=1, keys=[\"simple\", \"advanced\"]).apply(\n",
    "    lambda x: np.power(np.abs(x), loss_type)\n",
    ")\n",
    "\n",
    "f, (ax1, ax2) = plt.subplots(ncols=2, figsize=(15, 5.5))\n",
    "f.suptitle(\n",
    "    \"Comparison of models' performance\",\n",
    "    fontweight=SUPTITLE_FONTWEIGHT,\n",
    "    fontsize=SUPTITLE_FONTSIZE,\n",
    ")\n",
    "\n",
    "df_loss.rolling(n_iter).mean().plot(ax=ax1)\n",
    "ax1.set_title(f\"Moving average of loss (over {n_iter:,} iterations)\", fontsize=15)\n",
    "\n",
    "loss_last_it = df_loss.tail(n_iter).mean()\n",
    "pct_diff_loss = 100 * (loss_last_it.advanced / loss_last_it.simple - 1)\n",
    "\n",
    "loss_last_it.plot.bar(ax=ax2, color=\"gray\")\n",
    "ax2.set_title(f\"Loss computed on the last {n_iter:,} iterations\", fontsize=15)\n",
    "ax2.text(\n",
    "    0.75,\n",
    "    loss_last_it.advanced * 1.025,\n",
    "    f\"Δ% loss= {pct_diff_loss:.2f}\",\n",
    "    fontsize=12,\n",
    "    fontweight=SUPTITLE_FONTWEIGHT,\n",
    ")\n",
    "ax2.set_xticklabels(ax2.get_xticklabels(), fontsize=13, rotation=0)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model with interactions and log-transformed target offers substantial improvements over the simple model: the **loss** (computed on the last 10 000 observations) **decreases by about 23%**."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Vowpal wabbit kernel",
   "language": "python",
   "name": ".vw"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
