{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import lightgbm\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "from pathlib import Path\n",
    "from tqdm import tqdm\n",
    "import pickle as pkl\n",
    "\n",
    "from testbed.models.ngboost import NGBoostGaussian, NGBoostMixtureGaussian\n",
    "from testbed.models.treeffuser import Treeffuser\n",
    "\n",
    "from functools import partial\n",
    "\n",
    "from jaxtyping import Float, Array\n",
    "from typing import List, Callable\n",
    "\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from testbed.metrics.log_likelihood import LogLikelihoodFromSamplesMetric\n",
    "\n",
    "path = \"../src/testbed/data/m5\"\n",
    "\n",
    "# load autoreload extension\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# These are config variables\n",
    "\n",
    "PROCESS_FROM_SCRATCH = True\n",
    "USE_SUBSET = True\n",
    "CONTEXT_LENGTH = 20\n",
    "RUN_DEPRECATED = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the three raw M5 tables from the data directory\n",
    "\n",
    "sell_prices_df = pd.read_csv(Path(path) / \"sell_prices.csv\")\n",
    "sales_train_validation_df = pd.read_csv(Path(path) / \"sales_train_validation.csv\")\n",
    "calendar_df = pd.read_csv(Path(path) / \"calendar.csv\")\n",
    "\n",
    "# List the columns of each table. Plain loops are used because printing is a\n",
    "# side effect; a list comprehension would only build a throwaway list of None.\n",
    "print(\"\\ncolumns of sell_prices_df:\")\n",
    "for col in sell_prices_df.columns:\n",
    "    print(col)\n",
    "\n",
    "# omit d_1, d_2, ... (the per-day sales columns)\n",
    "print(\"\\ncolumns of sales_train_validation_df:\")\n",
    "for col in sales_train_validation_df.columns:\n",
    "    if not col.startswith(\"d_\"):\n",
    "        print(col)\n",
    "\n",
    "print(\"\\ncolumns of calendar_df:\")\n",
    "for col in calendar_df.columns:\n",
    "    if not col.startswith(\"d_\"):\n",
    "        print(col)\n",
    "\n",
    "# print number of zeros\n",
    "print(\"number of zeros in sales_train_validation_df: \", (sales_train_validation_df == 0).sum().sum())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Share of zero entries among the per-day sales columns (d_1, d_2, ...)\n",
    "items_sold_cols = sales_train_validation_df.columns[sales_train_validation_df.columns.str.startswith(\"d_\")]\n",
    "daily_sales = sales_train_validation_df[items_sold_cols]\n",
    "\n",
    "num_zeros = (daily_sales == 0).sum().sum()\n",
    "n_rows, n_cols = daily_sales.shape\n",
    "total_entries = n_rows * n_cols\n",
    "\n",
    "print(f\"number of zeros in sales_train_validation_df: {num_zeros} out of {total_entries} entries\")\n",
    "print(f\"percentage of zeros in sales_train_validation_df: {num_zeros / total_entries * 100:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse `date` and expose its day / month / year as explicit columns.\n",
    "# Later keyword arguments of `assign` see the columns set by earlier ones,\n",
    "# so the three parts are read from the already-parsed `date`.\n",
    "calendar_df = calendar_df.assign(\n",
    "    date=lambda df: pd.to_datetime(df[\"date\"]),\n",
    "    day=lambda df: df[\"date\"].dt.day,\n",
    "    month=lambda df: df[\"date\"].dt.month,\n",
    "    year=lambda df: df[\"date\"].dt.year,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Brief snapshots of the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "calendar_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sales_train_validation_df.iloc[:10, :20]"
   ]
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "sales_train_validation_df.shape"
   ],
   "metadata": {
    "collapsed": false
   },
   "execution_count": null
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sell_prices_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Process the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "TOTAL_ITEMS = 1_000\n",
    "\n",
    "if USE_SUBSET:\n",
    "    # Draw a reproducible random subset of the values in the `id` column.\n",
    "    np.random.seed(0)\n",
    "    unique_ids = sales_train_validation_df[\"id\"].unique()\n",
    "    # Never ask for more ids than exist: choice(..., replace=False) would raise.\n",
    "    n_items = min(TOTAL_ITEMS, len(unique_ids))\n",
    "    ids = np.random.choice(unique_ids, n_items, replace=False)\n",
    "    sales_train_validation_df_sub = sales_train_validation_df[sales_train_validation_df[\"id\"].isin(ids)]\n",
    "    # Keep only the price rows of items that survived the subsetting.\n",
    "    item_ids = sales_train_validation_df_sub[\"item_id\"].unique()\n",
    "    sell_prices_df_sub = sell_prices_df[sell_prices_df[\"item_id\"].isin(item_ids)]\n",
    "else:\n",
    "    # Later cells always read the *_sub names, so they must exist without subsetting too.\n",
    "    sales_train_validation_df_sub = sales_train_validation_df\n",
    "    sell_prices_df_sub = sell_prices_df\n",
    "\n",
    "# The calendar is not filtered by item, so it is used as is in both cases.\n",
    "calendar_df_sub = calendar_df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "columns_sales_train_validation.head)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The strategy for processing the data is the following. 1) We build `X` and `y`, where `y` is the next day's sales for a given product. 2) Each row of `X` is made up of the sales on the previous `CONTEXT_LENGTH` days, the day of the week, the month, the store id, the two event names, the sell price, the item id, and the index of the day being predicted."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _m5_window_example(row, start, context_length, calendar_by_day, price):\n",
    "    \"\"\"\n",
    "    Build one (features, target) pair from a single sales row.\n",
    "\n",
    "    The context is the sales on days d_<start> .. d_<start + context_length - 1>\n",
    "    and the target is the sales on the following day.\n",
    "    \"\"\"\n",
    "    target_day = start + context_length\n",
    "    target_col = f\"d_{target_day}\"\n",
    "    day_info = calendar_by_day[target_col]\n",
    "\n",
    "    features = [row[f\"d_{start + k}\"] for k in range(context_length)]\n",
    "    features.extend([\n",
    "        day_info[\"wday\"],\n",
    "        day_info[\"month\"],\n",
    "        row[\"store_id\"],\n",
    "        day_info[\"event_name_1\"],\n",
    "        day_info[\"event_name_2\"],\n",
    "        price,\n",
    "        row[\"item_id\"],\n",
    "        target_day,\n",
    "    ])\n",
    "    return features, row[target_col]\n",
    "\n",
    "\n",
    "def proc_train_test(sales_train_validation_df: pd.DataFrame, calendar_df: pd.DataFrame, sell_prices_df: pd.DataFrame, context_length: int, test_days: int, percentage_omittied: float = 0):\n",
    "    \"\"\"\n",
    "    Build next-day sales examples and split them into train and test by time.\n",
    "\n",
    "    Every example uses the sales of `context_length` consecutive days as context\n",
    "    and the sales of the following day as target. Examples whose target day is one\n",
    "    of the last `test_days` days form the test set; all earlier ones are training\n",
    "    candidates.\n",
    "\n",
    "    Args:\n",
    "        sales_train_validation_df: one row per series, with columns `item_id`,\n",
    "            `store_id` and daily sales columns named d_1, d_2, ...\n",
    "        calendar_df: one row per day, with columns `d` (matching the d_* column\n",
    "            names), `wday`, `month`, `event_name_1` and `event_name_2`.\n",
    "        sell_prices_df: columns `item_id`, `store_id` and `sell_price`. Only the\n",
    "            first price listed for each (item_id, store_id) pair is used.\n",
    "        context_length: number of past days of sales used as features.\n",
    "        test_days: number of trailing days whose sales are test targets.\n",
    "        percentage_omittied: despite the name, a fraction in [0, 1]: the share of\n",
    "            the training examples of each series that is randomly dropped. Test\n",
    "            examples are never dropped.\n",
    "\n",
    "    Returns:\n",
    "        undifferentiated: (X_train, y_train, X_test, y_test) pooled over all series.\n",
    "        differentiated: (X_train_prod, y_train_prod, X_test_prod, y_test_prod),\n",
    "            where entry i holds the examples of the i-th row of\n",
    "            `sales_train_validation_df`.\n",
    "\n",
    "        A feature row is laid out as: `context_length` sales values, wday, month,\n",
    "        store_id, event_name_1, event_name_2, sell_price, item_id and the number\n",
    "        of the target day.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the arguments leave no room for a training window.\n",
    "    \"\"\"\n",
    "    if not 0 <= percentage_omittied <= 1:\n",
    "        raise ValueError(f\"percentage_omittied must be a fraction in [0, 1], got {percentage_omittied}\")\n",
    "\n",
    "    # The day numbers are read off the d_<n> column names.\n",
    "    day_numbers = [int(col.split(\"_\")[1]) for col in sales_train_validation_df.columns if col.startswith(\"d_\")]\n",
    "    if not day_numbers:\n",
    "        raise ValueError(\"sales_train_validation_df has no d_<n> columns\")\n",
    "    total_days = max(day_numbers)\n",
    "    train_days = total_days - test_days\n",
    "    if test_days < 0 or train_days <= context_length:\n",
    "        raise ValueError(\"test_days and context_length leave no complete training window\")\n",
    "    print(\"train days\", train_days)\n",
    "    print(\"test days\", total_days - train_days)\n",
    "    print(\"total days\", total_days)\n",
    "\n",
    "    # A window is identified by its first context day; its target day is\n",
    "    # start + context_length. Training windows (target <= train_days) therefore\n",
    "    # start at 1 .. train_days - context_length, and test windows\n",
    "    # (target > train_days) start right after, up to total_days - context_length.\n",
    "    first_test_start = train_days - context_length + 1\n",
    "    train_starts = np.arange(1, first_test_start)\n",
    "    test_starts = list(range(first_test_start, total_days - context_length + 1))\n",
    "    n_train_kept = int(len(train_starts) * (1 - percentage_omittied))\n",
    "\n",
    "    # Local generator: reproducible without reseeding NumPy's global random state.\n",
    "    rng = np.random.RandomState(0)\n",
    "\n",
    "    # Precompute the lookups used for every window.\n",
    "    calendar_by_day = calendar_df.set_index(\"d\").to_dict(orient=\"index\")\n",
    "    first_price = sell_prices_df.groupby([\"item_id\", \"store_id\"])[\"sell_price\"].first().to_dict()\n",
    "\n",
    "    X_train, y_train, X_test, y_test = [], [], [], []\n",
    "    X_train_prod, y_train_prod, X_test_prod, y_test_prod = [], [], [], []\n",
    "\n",
    "    # Wrapping the iterator lets tqdm finish the progress bar by itself.\n",
    "    for _, row in tqdm(sales_train_validation_df.iterrows(), total=len(sales_train_validation_df)):\n",
    "        price = first_price[(row[\"item_id\"], row[\"store_id\"])]\n",
    "        # Only training windows are subsampled; every test window is kept.\n",
    "        kept_train_starts = rng.choice(train_starts, n_train_kept, replace=False)\n",
    "\n",
    "        train_pairs = [\n",
    "            _m5_window_example(row, start, context_length, calendar_by_day, price)\n",
    "            for start in kept_train_starts\n",
    "        ]\n",
    "        test_pairs = [\n",
    "            _m5_window_example(row, start, context_length, calendar_by_day, price)\n",
    "            for start in test_starts\n",
    "        ]\n",
    "\n",
    "        X_train_prod.append([pair[0] for pair in train_pairs])\n",
    "        y_train_prod.append([pair[1] for pair in train_pairs])\n",
    "        X_test_prod.append([pair[0] for pair in test_pairs])\n",
    "        y_test_prod.append([pair[1] for pair in test_pairs])\n",
    "\n",
    "        X_train.extend(X_train_prod[-1])\n",
    "        y_train.extend(y_train_prod[-1])\n",
    "        X_test.extend(X_test_prod[-1])\n",
    "        y_test.extend(y_test_prod[-1])\n",
    "\n",
    "    undifferentiated = (X_train, y_train, X_test, y_test)\n",
    "    differentiated = (X_train_prod, y_train_prod, X_test_prod, y_test_prod)\n",
    "    return undifferentiated, differentiated"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "TEST_DAYS = 50  # number of trailing days held out as test targets\n",
    "FRACTION_TRAIN_OMITTED = 0.95  # share of each series' training windows that is randomly dropped\n",
    "\n",
    "if PROCESS_FROM_SCRATCH:\n",
    "    undifferentiated, differentiated = proc_train_test(\n",
    "        sales_train_validation_df_sub, calendar_df, sell_prices_df_sub, CONTEXT_LENGTH, TEST_DAYS, FRACTION_TRAIN_OMITTED\n",
    "    )\n",
    "    X_train, y_train, X_test, y_test = undifferentiated\n",
    "    X_train_prod, y_train_prod, X_test_prod, y_test_prod = differentiated\n",
    "else:\n",
    "    # Fail here with a clear message instead of a NameError in the next cell.\n",
    "    raise NotImplementedError(\"PROCESS_FROM_SCRATCH=False is not supported: no cell above loads previously processed data\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(X_train), len(y_train), len(X_test), len(y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column layout of the feature rows: the context-window sales first, then the extra features.\n",
    "COL_NAMES = [f\"day_{i + 1}\" for i in range(CONTEXT_LENGTH)]\n",
    "COL_NAMES += [\"wday\", \"month\", \"store_id\", \"event_name_1\", \"event_name_2\", \"sell_price\", \"item_id\", \"day\"]\n",
    "\n",
    "# Columns treated as categorical, and their positions within COL_NAMES.\n",
    "CAT_COLS = [\"store_id\", \"event_name_1\", \"event_name_2\", \"item_id\", \"wday\", \"month\"]\n",
    "CAT_COLS_IDX = list(map(COL_NAMES.index, CAT_COLS))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the pooled lists in DataFrames; only the feature frames get named columns.\n",
    "X_train_df, X_test_df = pd.DataFrame(X_train), pd.DataFrame(X_test)\n",
    "y_train_df, y_test_df = pd.DataFrame(y_train), pd.DataFrame(y_test)\n",
    "\n",
    "for features_df in (X_train_df, X_test_df):\n",
    "    features_df.columns = COL_NAMES"
   ]
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "pd.DataFrame(X_test)[26].value_counts().shape"
   ],
   "metadata": {
    "collapsed": false
   },
   "execution_count": null
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "X_train_df"
   ]
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "X_train_df[\"item_id\"].value_counts()"
   ],
   "metadata": {
    "collapsed": false
   },
   "execution_count": null
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode the categorical columns as integer codes\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "\n",
    "\n",
    "def _second_token(item_id):\n",
    "    # Second \"_\"-separated field of the raw item_id string.\n",
    "    return item_id.split(\"_\")[1]\n",
    "\n",
    "\n",
    "def _encode_product_frame(rows):\n",
    "    # Same treatment as the pooled frames, but reusing the already-fitted encoders.\n",
    "    frame = pd.DataFrame(rows, columns=COL_NAMES)\n",
    "    frame[\"item_id\"] = frame[\"item_id\"].apply(_second_token)\n",
    "    for cat_col in CAT_COLS:\n",
    "        frame[cat_col] = label_encoders[cat_col].transform(frame[cat_col])\n",
    "    return frame\n",
    "\n",
    "\n",
    "X_train_df[\"item_id\"] = X_train_df[\"item_id\"].apply(_second_token)\n",
    "X_test_df[\"item_id\"] = X_test_df[\"item_id\"].apply(_second_token)\n",
    "\n",
    "# One encoder per categorical column: fitted on the training frame, applied to the test frame.\n",
    "label_encoders = {col: LabelEncoder() for col in CAT_COLS}\n",
    "for col, le in label_encoders.items():\n",
    "    X_train_df[col] = le.fit_transform(X_train_df[col])\n",
    "    X_test_df[col] = le.transform(X_test_df[col])\n",
    "\n",
    "# Per-product frames, in the same order as X_train_prod / X_test_prod.\n",
    "X_train_prod_processed = [_encode_product_frame(X_train_prod[i]) for i in range(len(X_train_prod))]\n",
    "X_test_prod_processed = [_encode_product_frame(X_test_prod[i]) for i in range(len(X_train_prod))]\n",
    "\n",
    "X_train_df.head()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PPC"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### \"Standard PPCs\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def max_ppc(y_true: Float[Array, \"batch y_dim\"], y_samples: Float[Array, \"samples batch y_dim\"], number=0, name=\"\") -> \"tuple[np.ndarray, np.ndarray, str]\":\n",
    "    \"\"\"\n",
    "    Posterior predictive check on the maximum.\n",
    "\n",
    "    Args:\n",
    "        y_true: observed targets.\n",
    "        y_samples: model samples; the maximum is taken over axis 1.\n",
    "        number, name: unused; accepted so that every check can be called the same way.\n",
    "\n",
    "    Returns:\n",
    "        Tuple of the flattened maxima of `y_samples` over axis 1, the maximum of\n",
    "        `y_true` as a length-1 array, and the label \"max_ppc\".\n",
    "    \"\"\"\n",
    "    # Distinct local name: the original local shadowed the function itself.\n",
    "    sample_max = np.max(y_samples, axis=1)\n",
    "    true_max = np.max(y_true)\n",
    "\n",
    "    return sample_max.flatten(), true_max.flatten(), \"max_ppc\"\n",
    "\n",
    "def quantile_ppc(y_true: Float[Array, \"batch y_dim\"], y_samples: Float[Array, \"samples batch y_dim\"], quantile=0.5, number=0, name=\"\") -> \"tuple[np.ndarray, np.ndarray, str]\":\n",
    "    \"\"\"\n",
    "    Posterior predictive check on a quantile.\n",
    "\n",
    "    Args:\n",
    "        y_true: observed targets.\n",
    "        y_samples: model samples; the quantile is taken over axis 1.\n",
    "        quantile: quantile level passed to `np.quantile`.\n",
    "        number, name: unused; accepted so that every check can be called the same way.\n",
    "\n",
    "    Returns:\n",
    "        Tuple of the flattened quantiles of `y_samples` over axis 1, the quantile\n",
    "        of all of `y_true` as a length-1 array, and the label\n",
    "        f\"quantile_ppc_{quantile}\".\n",
    "    \"\"\"\n",
    "    q = np.quantile(y_samples, quantile, axis=1)\n",
    "    true_q = np.quantile(y_true, quantile)\n",
    "    return q.flatten(), true_q.flatten(), f\"quantile_ppc_{quantile}\"\n",
    "\n",
    "def zeros(y_true: Float[Array, \"batch y_dim\"], y_samples: Float[Array, \"samples batch y_dim\"], number=0, name=\"\") -> \"tuple[np.ndarray, np.ndarray, str]\":\n",
    "    \"\"\"\n",
    "    Posterior predictive check on the number of (near-)zero values.\n",
    "\n",
    "    A value counts as zero when it is below 0.1.\n",
    "\n",
    "    Args:\n",
    "        y_true: observed targets.\n",
    "        y_samples: model samples; zeros are counted over axis 1.\n",
    "        number, name: unused; accepted so that every check can be called the same way.\n",
    "\n",
    "    Returns:\n",
    "        Tuple of the flattened zero counts of `y_samples` over axis 1, the zero\n",
    "        count of all of `y_true` as a length-1 array, and the label \"zeros\".\n",
    "    \"\"\"\n",
    "    # Distinct local name: the original local shadowed the function itself.\n",
    "    sample_zero_counts = np.sum(y_samples < 0.1, axis=1)\n",
    "    true_zeros = np.sum(y_true < 0.1)\n",
    "\n",
    "    return sample_zero_counts.flatten(), true_zeros.flatten(), \"zeros\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_ppcs(y_true: Float[Array, \"batch y_dim\"], y_samples: Float[Array, \"samples batch y_dim\"], ppcs: List[Callable],\n",
    "              number=0, name=\"\") -> None:\n",
    "    \"\"\"\n",
    "    Call every check in `ppcs` with (y_true, y_samples, number=number, name=name).\n",
    "\n",
    "    The values returned by the checks are discarded and nothing is returned.\n",
    "    \"\"\"\n",
    "    for check in ppcs:\n",
    "        check(y_true, y_samples, number=number, name=name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### \"Complex PPCs\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def plot_model_comparisons(data, y_true, figsize=(12, 8), model_names=None):\n",
    "    \"\"\"\n",
    "    Plots model predictions against true values for each day.\n",
    "\n",
    "    For every day one box per model summarises that model's samples, and the\n",
    "    true values are overlaid as red dots.\n",
    "\n",
    "    :param data: numpy array of shape [models, samples, days] containing model predictions\n",
    "    :param y_true: array of shape [days] containing the true values\n",
    "    :param figsize: tuple indicating the size of the figure\n",
    "    :param model_names: one label per model; defaults to \"Model 0\", \"Model 1\", ...\n",
    "    \"\"\"\n",
    "    sns.set(style=\"whitegrid\")\n",
    "    data = np.asarray(data)\n",
    "    models, samples, days = data.shape\n",
    "\n",
    "    if model_names is None:\n",
    "        model_names = [f\"Model {i}\" for i in range(models)]\n",
    "\n",
    "    # Long-format table with one row per (model, day, sample), built with array\n",
    "    # operations instead of a triple Python loop. The row order is model-major,\n",
    "    # then day, then sample.\n",
    "    model_idx, day_idx, _ = np.indices((models, days, samples))\n",
    "    plot_data = pd.DataFrame({\n",
    "        \"Day\": day_idx.ravel(),\n",
    "        \"Value\": data.transpose(0, 2, 1).ravel(),\n",
    "        \"Model\": np.asarray(model_names, dtype=object)[model_idx.ravel()],\n",
    "    })\n",
    "\n",
    "    # Everything is drawn on the explicit axes rather than through pyplot's current-axes state.\n",
    "    fig, ax = plt.subplots(figsize=figsize)\n",
    "    sns.boxplot(x=\"Day\", y=\"Value\", hue=\"Model\", data=plot_data, ax=ax, width=0.6)\n",
    "\n",
    "    # Plot true values\n",
    "    ax.plot(y_true, 'o', color='red', label='True Values')\n",
    "\n",
    "    # Setting labels and title\n",
    "    ax.set_xticks(np.arange(days))\n",
    "    ax.set_xticklabels([f\"Day {i+1}\" for i in range(days)])\n",
    "    ax.set_xlabel('Days')\n",
    "    ax.set_ylabel('Values')\n",
    "    ax.set_title('Model Predictions vs. True Values')\n",
    "    ax.legend()\n",
    "\n",
    "    # Show the plot\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_results_to_pkl(results: dict, dir_name, name):\n",
    "    \"\"\"\n",
    "    Pickle `results` to `<dir_name>/<name>.pkl`.\n",
    "\n",
    "    `dir_name` and any missing parent directories are created first.\n",
    "    \"\"\"\n",
    "    out_dir = Path(dir_name)\n",
    "    # exist_ok=True replaces the separate exists() test, which could race with\n",
    "    # another process creating the directory in between.\n",
    "    out_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "    with open(out_dir / f\"{name}.pkl\", \"wb\") as f:\n",
    "        pkl.dump(results, f)\n",
    "\n",
    "\n",
    "def load_results_from_pkl(dir_name, name):\n",
    "    \"\"\"\n",
    "    Return the object unpickled from `<dir_name>/<name>.pkl`.\n",
    "    \"\"\"\n",
    "    # NOTE: unpickling can run arbitrary code, so only load files you trust.\n",
    "    with (Path(dir_name) / f\"{name}.pkl\").open(\"rb\") as handle:\n",
    "        return pkl.load(handle)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simple helper function to train a model and plot ppcs\n",
    "\n",
    "def get_ppcs(y_samples, X_test, y_test, ppcs, number=0, name=\"\") -> dict:\n",
    "    \"\"\"\n",
    "    Evaluate every posterior predictive check in `ppcs` on the samples.\n",
    "\n",
    "    The samples are clipped below at 0 before the checks are computed.\n",
    "\n",
    "    Args:\n",
    "        y_samples: model samples (anything `np.array` accepts).\n",
    "        X_test: unused; kept so that existing calls keep working.\n",
    "        y_test: true targets, forwarded to every check.\n",
    "        ppcs: callables `ppc(y_true, y_samples, number=..., name=...)` returning\n",
    "            `(samples_statistic, true_statistic, ppc_name)`.\n",
    "        number, name: forwarded unchanged to every check.\n",
    "\n",
    "    Returns:\n",
    "        Dictionary mapping each `ppc_name` to `{\"samples\": ..., \"true\": ...}`.\n",
    "    \"\"\"\n",
    "    y_samples = np.maximum(np.array(y_samples), 0)\n",
    "\n",
    "    ppc_results = {}\n",
    "    for ppc in ppcs:\n",
    "        # The returned label gets its own variable: rebinding `name` here would\n",
    "        # hand the previous check's label to the next check.\n",
    "        samples, true, ppc_name = ppc(y_test, y_samples, number=number, name=name)\n",
    "        ppc_results[ppc_name] = {\"samples\": samples, \"true\": true}\n",
    "\n",
    "    return ppc_results\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(X_train_df.head())"
   ]
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "X_test_df[\"day\"].value_counts()"
   ],
   "metadata": {
    "collapsed": false
   },
   "execution_count": null
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "EVAL_VALUES = 10000 # How many test points to use\n",
    "np.random.seed(0)\n",
    "\n",
    "# Never ask for more evaluation points than the test set holds:\n",
    "# choice(..., replace=False) raises when the sample exceeds the population.\n",
    "n_eval = min(EVAL_VALUES, len(X_test_df))\n",
    "eval_idx = np.random.choice(len(X_test_df), n_eval, replace=False)\n",
    "\n",
    "# The last column (\"day\") is kept separately and left out of the model inputs.\n",
    "X_train_np = X_train_df.values[:, :-1]\n",
    "X_days_np = X_train_df[\"day\"].values\n",
    "X_test_np = X_test_df.values[eval_idx][:, :-1]\n",
    "X_test_days_np = X_test_df[\"day\"].values[eval_idx]\n",
    "X_test_df_sub = X_test_df.iloc[eval_idx]\n",
    "\n",
    "y_train_np = y_train_df.values\n",
    "y_test_np = y_test_df.values[eval_idx]\n",
    "\n",
    "# change to float to prevent errors\n",
    "y_train_np = y_train_np.astype(np.float32)\n",
    "y_test_np = y_test_np.astype(np.float32)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from testbed.models import BayesOptProbabilisticModel\n",
    "from testbed.models.ngboost import NGBoostPoisson\n",
    "from testbed.models.quantile_regression import QuantileRegressionTree\n",
    "from testbed.models.ibug_ import IBugXGBoost\n",
    "from testbed.models.treeffuser import Treeffuser\n",
    "# from testbed.models.lightning_uq_models import DeepEnsemble\n",
    "\n",
    "\n",
    "MODEL_CLASSES = [Treeffuser, NGBoostPoisson, QuantileRegressionTree]\n",
    "NAMES = [\"Treeffuser\", \"NGBoostPoisson\", \"QuantileRegressionTree\"]\n",
    "\n",
    "NUM_SAMPLES = 100\n",
    "\n",
    "# One record per fitted model: {\"model\": ..., \"model_name\": ...}.\n",
    "# It has to be created here because the loop below appends to it; re-running\n",
    "# this cell therefore starts again from an empty list.\n",
    "results = []\n",
    "\n",
    "# Tuned variant (disabled): fits every model with Bayesian hyperparameter search.\n",
    "# for i in range(3):\n",
    "#     print(f\"Fitting model {NAMES[i]}\")\n",
    "#     model_cls = MODEL_CLASSES[i]\n",
    "#     model = BayesOptProbabilisticModel(model_cls, n_iter_bayes_opt=20, frac_validation=0.1)\n",
    "#     model.fit(X_train_np, y_train_np)\n",
    "#\n",
    "#     results.append({\n",
    "#         \"model\": model,\n",
    "#         \"model_name\": NAMES[i]\n",
    "#     })\n",
    "\n",
    "# change to range(3) to fit all models\n",
    "for i in [0]:\n",
    "    print(f\"Fitting model {NAMES[i]}\")\n",
    "    model_cls = MODEL_CLASSES[i]\n",
    "    # Default hyperparameters\n",
    "    model = model_cls()\n",
    "    model.fit(X_train_np, y_train_np)\n",
    "\n",
    "    results.append({\n",
    "        \"model\": model,\n",
    "        \"model_name\": NAMES[i] + \"_default\"\n",
    "    })\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Draw NUM_SAMPLES samples from every fitted model that has none stored yet\n",
    "for i, result in enumerate(results):\n",
    "    model, model_name = result[\"model\"], result[\"model_name\"]\n",
    "    if \"y_samples\" not in result:\n",
    "        y_samples = model.sample(X_test_np, NUM_SAMPLES)\n",
    "        result[\"y_samples\"] = y_samples"
   ]
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "# implement the newsvendor utility function\n",
    "\n",
    "def newsvendor_utility(\n",
    "    y_true: Float[Array, \"batch 1\"], quantity_ordered : Float[Array, \"batch 1\"], prices : Float[Array, \"batch 1\"],\n",
    "    stocking_cost: Float[Array, \"batch 1\"]\n",
    "    ) -> Float[Array, \"batch 1\"]:\n",
    "    \"\"\"\n",
    "    Newsvendor utility for demand y, stock q, selling price p and stocking cost c:\n",
    "\n",
    "        U(y, q, p, c) = p * min(y, q) - c * q\n",
    "\n",
    "    Only `y_true` is flattened; the other arguments are used as given.\n",
    "    \"\"\"\n",
    "    demand = y_true.flatten()\n",
    "    # Revenue comes from the units actually sold; every stocked unit is paid for.\n",
    "    units_sold = np.minimum(demand, quantity_ordered)\n",
    "    revenue = prices * units_sold\n",
    "    stocking_total = stocking_cost * quantity_ordered\n",
    "    return revenue - stocking_total\n",
    "\n",
    "def newsvendor_optimal_quantity(y_samples: Float[Array, \"samples batch 1\"], \n",
    "prices: Float[Array, \"batch 1\"], stocking_cost: Float[Array, \"batch 1\"]) -> Float[Array, \"batch 1\"]:\n",
    "    \"\"\"\n",
    "    Returns the optimal quantity to order for the newsvendor problem.\n",
    "    It is given theoretically by:\n",
    "    q* = argmax_{q} E[U(y, q, p, c)] which has a closed form solution\n",
    "    q* = F^{-1}( (p - c) / p)\n",
    "    where F is the CDF of the demand distribution.\n",
    "\n",
    "    F is replaced by the empirical distribution of the samples of each batch\n",
    "    entry. The target level (p - c) / p is limited to [0, 1], so a cost above\n",
    "    the price gives the smallest sample and a negative cost the largest one.\n",
    "    Prices are assumed to be non-zero.\n",
    "\n",
    "    Returns an array with one quantity per batch entry.\n",
    "    \"\"\"\n",
    "    # compute the target quantiles (p - c) / p\n",
    "    target_quantiles = (prices - stocking_cost) / prices\n",
    "    # np.quantile only accepts levels in [0, 1]; bound both ends, not just the lower one.\n",
    "    target_quantiles = np.minimum(np.maximum(target_quantiles, 0.0), 1.0)\n",
    "\n",
    "    # Every batch entry has its own level, so the empirical quantile is taken per entry.\n",
    "    optimal_quantities = np.array([\n",
    "        np.quantile(y_samples[:, i, 0], target_quantiles[i])\n",
    "        for i in range(y_samples.shape[1])\n",
    "    ])\n",
    "    return optimal_quantities\n",
    "    \n",
    "    \n",
    "def evaluate_models(results, X_test_df, y_test, profit_margin=0.5):\n",
    "    \"\"\"\n",
    "    Score every model in `results` with the newsvendor utility on the test data.\n",
    "\n",
    "    Prices are read from the `sell_price` column of `X_test_df` and the stocking\n",
    "    cost is set to price / (1 + profit_margin). Each entry of `results` must\n",
    "    provide \"model_name\" and \"y_samples\".\n",
    "\n",
    "    Prints the total utility under perfect information and of every model, and\n",
    "    returns a DataFrame with one row per test point holding day, y_true,\n",
    "    sell_price, stocking_cost, perfect_quantities and, per model,\n",
    "    `<model_name>_quantities` and `<model_name>_utility`.\n",
    "    \"\"\"\n",
    "    prices = X_test_df[\"sell_price\"].values\n",
    "    stocking_cost =  prices / (1 + profit_margin)\n",
    "\n",
    "    # Upper bound: ordering exactly the realised demand (floored at 0).\n",
    "    perfect_quantities = np.maximum(y_test.flatten(), 0)\n",
    "    perfect_utility = newsvendor_utility(y_test, perfect_quantities, prices, stocking_cost)\n",
    "    print(f\"Knowing the perfect demand would make : ${perfect_utility.sum():.2f}\")\n",
    "\n",
    "    base_columns = {\n",
    "        \"day\": X_test_df[\"day\"].values,\n",
    "        \"y_true\": y_test.flatten(),\n",
    "        \"sell_price\": prices,\n",
    "        \"stocking_cost\": stocking_cost,\n",
    "        \"perfect_quantities\": perfect_quantities,\n",
    "    }\n",
    "    res_df = pd.DataFrame(base_columns)\n",
    "\n",
    "    for entry in results:\n",
    "        model_name = entry[\"model_name\"]\n",
    "        optimal_quantities = newsvendor_optimal_quantity(entry[\"y_samples\"], prices, stocking_cost)\n",
    "        utility = newsvendor_utility(y_test, optimal_quantities, prices, stocking_cost)\n",
    "        res_df[f\"{model_name}_quantities\"] = optimal_quantities\n",
    "        res_df[f\"{model_name}_utility\"] = utility\n",
    "        print(f\"Model {model_name} makes {utility.sum():.2f} big bucks! {'😊' if utility.sum() > 0 else '😢'}\")\n",
    "\n",
    "    return res_df\n",
    "    \n",
    "    \n",
    "    "
   ],
   "metadata": {
    "collapsed": false
   },
   "execution_count": null
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "import matplotlib as mpl\n",
    "\n",
    "sns.set_theme(style=\"white\", rc={\"axes.facecolor\": (0, 0, 0, 0)})\n",
    "mpl.rcParams.update(\n",
    "    {\n",
    "        \"text.usetex\": False,\n",
    "        \"font.family\": \"serif\",\n",
    "        \"font.serif\": [\"Times New Roman\"],  # [\"Computer Modern\"],  #\n",
    "        \"font.size\": 12,\n",
    "    }\n",
    ")"
   ],
   "metadata": {
    "collapsed": false
   },
   "execution_count": null
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "X_train_df.shape"
   ],
   "metadata": {
    "collapsed": false
   },
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "source": [
    "# pickle the data\n",
    "with open(\"big-dump.pkl\", \"wb\") as f:\n",
    "    pkl.dump(\n",
    "        {\n",
    "            \"X_train_df\": X_train_df,\n",
    "            \"X_test_df\": X_test_df,\n",
    "            \"y_test_df\": y_test_df,\n",
    "            \"y_train_df\": y_train_df,\n",
    "            \"X_train_prod_processed\": X_train_prod_processed,\n",
    "            \"X_test_prod_processed\": X_test_prod_processed,\n",
    "            \"label_encoders\": label_encoders,\n",
    "            \"results\": results,\n",
    "        },\n",
    "        f,\n",
    "    )"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "res_df = evaluate_models(results, X_test_df_sub, y_test_np, profit_margin=0.5)\n",
    "\n",
    "# Cumulative profit of every model over the test days (one column per model).\n",
    "utility_cols = [c for c in res_df.columns if \"utility\" in c]\n",
    "res_df_pivot = res_df.groupby(\"day\")[utility_cols].sum().sort_index().cumsum()\n",
    "\n",
    "sns.set(style=\"white\")\n",
    "fig, ax = plt.subplots(figsize=(5, 3))\n",
    "\n",
    "# utility column -> label of the model variant\n",
    "name_mapping = {\n",
    "    \"Treeffuser_utility\": \"Treeffuser\",\n",
    "    \"NGBoostPoisson_utility\": \"NGBoost-Poisson\",\n",
    "    \"QuantileRegressionTree_default_utility\": \"QuantileRegression-Tree (def.)\",\n",
    "    \"QuantileRegressionTree_utility\": \"QuantileRegression-Tree\",\n",
    "    \"Treeffuser_default_utility\": \"Treeffuser (def.)\",\n",
    "    \"NGBoostPoisson_default_utility\": \"NGBoostPoisson (def.)\",\n",
    "}\n",
    "# model variant -> model family (drawn as the line colour)\n",
    "palette = {\n",
    "    \"Treeffuser\": \"Treeffuser\",\n",
    "    \"Treeffuser (def.)\": \"Treeffuser\",\n",
    "    \"NGBoost-Poisson\": \"NGBoost-Poisson\",\n",
    "    \"NGBoostPoisson (def.)\": \"NGBoost-Poisson\",\n",
    "    \"QuantileRegression-Tree\": \"QuantileRegression\",\n",
    "    \"QuantileRegression-Tree (def.)\": \"QuantileRegression\",\n",
    "}\n",
    "# model variant -> hyperparameter setting (drawn as the line style)\n",
    "ls = {\n",
    "    \"Treeffuser\": \"Tuned\",\n",
    "    \"NGBoost-Poisson\": \"Tuned\",\n",
    "    \"QuantileRegression-Tree\": \"Tuned\",\n",
    "    \"Treeffuser (def.)\": \"Default\",\n",
    "    \"NGBoostPoisson (def.)\": \"Default\",\n",
    "    \"QuantileRegression-Tree (def.)\": \"Default\",\n",
    "}\n",
    "\n",
    "data_plot = res_df_pivot.copy().reset_index()\n",
    "# shift day to start from 1\n",
    "data_plot[\"day\"] = data_plot[\"day\"] - data_plot[\"day\"].min() + 1\n",
    "data_plot = data_plot.melt(id_vars=\"day\", var_name=\"model_name\", value_name=\"utility\")\n",
    "# Indexing the dict (instead of .map) keeps an unknown model name a loud KeyError.\n",
    "data_plot[\"model_name\"] = data_plot[\"model_name\"].apply(lambda x: name_mapping[x])\n",
    "data_plot[\"Model\"] = data_plot[\"model_name\"].map(palette)\n",
    "data_plot[\"Hyperparameters\"] = data_plot[\"model_name\"].map(ls)\n",
    "\n",
    "sns.lineplot(\n",
    "    data=data_plot,\n",
    "    ax=ax,\n",
    "    hue=\"Model\",\n",
    "    x=\"day\",\n",
    "    y=\"utility\",\n",
    "    style=\"Hyperparameters\",\n",
    "    lw=2,\n",
    ")\n",
    "\n",
    "ax.set_ylim(0)\n",
    "# Span all plotted days instead of hard-coding the number of test days.\n",
    "ax.set_xlim(0, data_plot[\"day\"].max())\n",
    "ax.set_xlabel(\"Day\")\n",
    "ax.set_ylabel(\"Cumulative profit\")\n",
    "\n",
    "# reduce space between legend columns\n",
    "legend = ax.legend(ncols=2, frameon=False, columnspacing=0.1, handletextpad=0.3, handlelength=1.5)\n",
    "\n",
    "# The legend lists the hue and style variable names as header entries. They are\n",
    "# found by their text rather than by fixed positions, which depend on how many\n",
    "# models were fitted.\n",
    "for text in legend.get_texts():\n",
    "    if text.get_text() in (\"Model\", \"Hyperparameters\"):\n",
    "        text.set_weight('bold')\n",
    "        x, y = text.get_position()\n",
    "        text.set_position((x - 15, y))\n",
    "\n",
    "# make the handle of each line thicker\n",
    "for line in legend.get_lines():\n",
    "    line.set_linewidth(2.0)\n",
    "\n",
    "plt.tight_layout()\n",
    "sns.despine()\n",
    "plt.savefig(\"profit_loss.pdf\", bbox_inches=\"tight\")"
   ],
   "metadata": {
    "collapsed": false
   },
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can compute the PPC statistics for each of the fitted models."
   ]
  },
  {
   "cell_type": "markdown",
   "source": [],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Statistics evaluated on the samples: zeros, two quantiles, the max, then two more quantiles.\n",
    "ppcs = [\n",
    "    zeros,\n",
    "    partial(quantile_ppc, quantile=0.9),\n",
    "    partial(quantile_ppc, quantile=0.99),\n",
    "    max_ppc,\n",
    "    partial(quantile_ppc, quantile=0.999),\n",
    "    partial(quantile_ppc, quantile=0.9999),\n",
    "]\n",
    "\n",
    "# Run every check on each model's samples and store the outcome on its result entry.\n",
    "for i, res in enumerate(results):\n",
    "    ppc_results = get_ppcs(\n",
    "        y_samples=res[\"y_samples\"],\n",
    "        X_test=X_test_np,\n",
    "        y_test=y_test_np,\n",
    "        ppcs=ppcs,\n",
    "        number=i,\n",
    "        name=res[\"model_name\"],\n",
    "    )\n",
    "    res[\"ppc_results\"] = ppc_results\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plot the PPCs"
   ]
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "# Debug helper: number of sample rows per model in the most recent PPC frame.\n",
    "# NOTE(review): in this part of the notebook `ppc_tmp` is only assigned by the PPC\n",
    "# plotting cell further down, so the lookup is guarded to keep Restart & Run All\n",
    "# from stopping here with a NameError.\n",
    "ppc_tmp[\"Model\"].value_counts() if \"ppc_tmp\" in globals() else None"
   ],
   "metadata": {
    "collapsed": false
   },
   "execution_count": null
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "# Debug helper: inspect the attributes of the last legend handle.\n",
    "# NOTE(review): in this part of the notebook `h` is only assigned inside the PPC\n",
    "# plotting cell further down, so the lookup is guarded to keep Restart & Run All\n",
    "# from stopping here with a NameError.\n",
    "h.__dict__ if \"h\" in globals() else None"
   ],
   "metadata": {
    "collapsed": false
   },
   "execution_count": null
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def proc_title(title):\n",
    "    x = title.replace(\"_\", \" \").capitalize()\n",
    "    x = x.replace(\"ppc\", \"\")\n",
    "    if x == \"Zeros\":\n",
    "        x = \"Proportion of zeros\"\n",
    "    return x\n",
    "\n",
    "ppc_number = len(ppcs)\n",
    "ppc_names = results[0][\"ppc_results\"].keys()\n",
    "\n",
    "# Two rows of panels; ceiling division leaves room for an odd number of checks.\n",
    "n_cols = (ppc_number + 1) // 2\n",
    "# plt.subplots returns (figure, axes); squeeze=False keeps `axes` 2-D even for one column.\n",
    "fig, axes = plt.subplots(\n",
    "    nrows=2, ncols=n_cols, figsize=(8, 4), squeeze=False\n",
    ")\n",
    "\n",
    "for ppc_idx, ppc_name in enumerate(ppc_names):\n",
    "    ax = axes[ppc_idx // n_cols, ppc_idx % n_cols]\n",
    "\n",
    "    # The reference value is read from the first result entry.\n",
    "    ppc_true = results[0][\"ppc_results\"][ppc_name][\"true\"]\n",
    "    samples_by_model = {\n",
    "        res[\"model_name\"]: res[\"ppc_results\"][ppc_name][\"samples\"] for res in results\n",
    "    }\n",
    "\n",
    "    # Long format (one row per model/sample pair), restricted to the three plotted models.\n",
    "    ppc_tmp = pd.DataFrame(samples_by_model).melt(var_name=\"Model\", value_name=\"Samples\")\n",
    "    shown = ppc_tmp[\"Model\"].isin([\"Treeffuser\", \"NGBoostPoisson\", \"QuantileRegressionTree_default\"])\n",
    "    # .copy() so the column assignments below act on an independent frame, not a slice.\n",
    "    ppc_tmp = ppc_tmp[shown].copy()\n",
    "    # Two-line labels keep the legend narrow.\n",
    "    ppc_tmp[\"Model\"] = ppc_tmp[\"Model\"].replace(\n",
    "        {\n",
    "            \"QuantileRegressionTree_default\": \"Quantile\\nRegression\",\n",
    "            \"NGBoostPoisson\": \"NGBoost\\nPoisson\",\n",
    "        }\n",
    "    )\n",
    "    if ppc_name == \"zeros\":\n",
    "        # NOTE(review): 10_000 is presumably the number of test points, turning the\n",
    "        # zero counts into proportions - confirm against the test set size.\n",
    "        ppc_tmp[\"Samples\"] = ppc_tmp[\"Samples\"] / 10_000\n",
    "        ppc_true = ppc_true / 10_000\n",
    "\n",
    "    bins = np.linspace(1, 5, 10) if ppc_name == \"quantile_ppc_0.9\" else \"auto\"\n",
    "\n",
    "    # The legend is drawn once, on the panel with index 3.\n",
    "    show_legend = ppc_idx == 3\n",
    "    sns.histplot(\n",
    "        ppc_tmp,\n",
    "        x=\"Samples\",\n",
    "        hue=\"Model\",\n",
    "        bins=bins,\n",
    "        stat=\"density\",\n",
    "        common_norm=True,\n",
    "        ax=ax,\n",
    "        palette=\"deep\",\n",
    "        multiple=\"layer\",\n",
    "        legend=show_legend,\n",
    "        alpha=0.6,\n",
    "        element=\"poly\",\n",
    "    )\n",
    "    if show_legend:\n",
    "        legend = ax.get_legend()\n",
    "        legend.set_title(\"\")\n",
    "        legend.get_frame().set_linewidth(0.0)\n",
    "        legend.get_frame().set_facecolor(\"none\")\n",
    "        # move the legend a tiny bit to the right\n",
    "        legend.set_bbox_to_anchor((0.44, 1))\n",
    "        # `legendHandles` was renamed to `legend_handles` in Matplotlib 3.7 and the old\n",
    "        # name was removed later; support both spellings.\n",
    "        handles = getattr(legend, \"legend_handles\", None)\n",
    "        if handles is None:\n",
    "            handles = legend.legendHandles\n",
    "        # change the size of the handles\n",
    "        for h in handles:\n",
    "            h.set_width(15)\n",
    "            h.set_height(10)\n",
    "        # smaller font, and move the text to the left\n",
    "        for t in legend.get_texts():\n",
    "            t.set_fontsize(11)\n",
    "            t.set_position((t.get_position()[0] - 10, t.get_position()[1]))\n",
    "\n",
    "    # Clip the x-axis at three times the reference value when the samples extend beyond it.\n",
    "    max_x = ppc_true * 3\n",
    "    if np.max(ppc_tmp[\"Samples\"]) > max_x:\n",
    "        ax.set_xlim(0, max_x)\n",
    "    # Hand-tuned ranges for two of the panels.\n",
    "    if ppc_name == \"zeros\":\n",
    "        ax.set_xlim(0.35, 0.6)\n",
    "    if ppc_name == \"quantile_ppc_0.9\":\n",
    "        ax.set_xlim(1, 5)\n",
    "\n",
    "    ax.axvline(ppc_true, color=\"red\", linestyle=\"--\", label=\"True Value\")\n",
    "    ax.set_title(proc_title(ppc_name))\n",
    "    ax.set_xlabel(\"\")\n",
    "    ax.set_ylabel(\"\")\n",
    "    ax.set_yticks([])\n",
    "\n",
    "# Hide any panel that did not receive a check (only happens for an odd count).\n",
    "for spare_ax in axes.ravel()[len(ppc_names):]:\n",
    "    spare_ax.set_visible(False)\n",
    "\n",
    "sns.despine()\n",
    "# reduce space between subplots\n",
    "plt.tight_layout()\n",
    "plt.subplots_adjust(wspace=0.15, hspace=0.5)\n",
    "\n",
    "plt.savefig(\"ppc.pdf\", bbox_inches=\"tight\")\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Speed of Treeffuser"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "tf_speed = Treeffuser()\n",
    "\n",
    "# training time; perf_counter is a monotonic, high-resolution clock meant for\n",
    "# measuring durations, whereas time.time() can jump when the system clock changes\n",
    "start = time.perf_counter()\n",
    "tf_speed.fit(X_train_np, y_train_np)\n",
    "end = time.perf_counter()\n",
    "print(f\"Treeffuser took {end - start} seconds to fit on data of shape {X_train_np.shape}\")"
   ],
   "metadata": {
    "collapsed": false
   },
   "execution_count": null
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "# sampling time; perf_counter is a monotonic, high-resolution clock meant for\n",
    "# measuring durations (the `time` module is imported in the fit-timing cell above)\n",
    "start = time.perf_counter()\n",
    "# NOTE(review): the second argument is presumably the number of samples per row - confirm.\n",
    "tf_speed.sample(X_test_np, 1)\n",
    "end = time.perf_counter()\n",
    "print(f\"Treeffuser took {end - start} seconds to sample on data of shape {X_test_np.shape}\")"
   ],
   "metadata": {
    "collapsed": false
   },
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
