{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# COV-19 Case Prediction\n",
    "This notebook aims to create models to predict COV-19 cases in 313 different areas of the world, using GluonTS models.\n",
    "\n",
    "The data set is downloaded from Kaggle(https://www.kaggle.com/c/covid19-global-forecasting-week-5), you can download them and put all the csv files under a folder called \\\"covid19-global-forecasting-week-4\\\" in the same directory of this notebook.\n",
    "\n",
    "**NOTE: this notebook is for illustration purposes only, it has not been reviewed by epidemiological experts, and we do not claim that accurate epidemiological predictions can be made with the code that follows.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import mxnet as mx\n",
    "from mxnet import gluon\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import json\n",
    "import os\n",
    "from tqdm.autonotebook import tqdm\n",
    "from pathlib import Path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prediction_length = 20"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data and preprocessing\n",
    "We first load the data from files. Since the original data doesn't meet the requirements of GluonTS models, we need to do data preprocessing and generate new dataframe where each row represents a time series for a certain place."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "total = pd.read_csv(\"./covid19-global-forecasting-week-5/train.csv\", index_col=False)\n",
    "test = pd.read_csv(\"./covid19-global-forecasting-week-5/test.csv\", index_col=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "total = total[total[\"Target\"]==\"ConfirmedCases\"]\n",
    "test = test[test[\"Target\"]==\"ConfirmedCases\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "total.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "total = total.fillna(\"\")\n",
    "total[\"name\"] = total[\"Country_Region\"] + \"_\" + total[\"Province_State\"] + \"_\" + total[\"County\"]\n",
    "total.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "country_list = sorted(list(set(total[\"name\"])))\n",
    "date_list = sorted(list(set(total[\"Date\"])))\n",
    "data_dic = {\"name\": country_list}\n",
    "\n",
    "for date in date_list:\n",
    "    tmp = total[total[\"Date\"]==date][[\"name\", \"Date\", \"TargetValue\"]]\n",
    "    tmp = tmp.pivot(index=\"name\", columns=\"Date\", values=\"TargetValue\")\n",
    "    tmp_values = tmp[date].values\n",
    "    data_dic[date] = tmp_values\n",
    "new_df = pd.DataFrame(data_dic)\n",
    "new_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "The original data is daily confirmed cases, we tansform it into a accumulative one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "total_values = new_df.drop(\"name\", axis=1).values\n",
    "row, col = total_values.shape\n",
    "for i in range(row):\n",
    "    tmp = total_values[i]\n",
    "    for j in range(col):\n",
    "        if j > 0:\n",
    "            tmp[j] = tmp[j] + tmp[j - 1]\n",
    "    total_values[i] = tmp"
   ]
  },
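  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The daily-to-cumulative transformation can be checked on a toy series. The following is a minimal sketch using only the standard library; the variable names and the case counts are made up for illustration and do not come from the Kaggle data.\n",
    "\n",
    "```python\n",
    "from itertools import accumulate\n",
    "\n",
    "# Hypothetical daily new-case counts for one place (made-up numbers).\n",
    "daily_cases = [0, 2, 3, 0, 5]\n",
    "\n",
    "# Running total: entry j is the sum of daily_cases[0] through daily_cases[j].\n",
    "cumulative_cases = list(accumulate(daily_cases))\n",
    "print(cumulative_cases)\n",
    "```\n",
    "\n",
    "A cumulative series built from non-negative daily counts never decreases, which is a quick sanity check to run on the real data as well."
   ]
  },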
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get the features for populations and weight, and apply min-max scale to it, also divide all the countries into three different type according to the weight it has."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_dic_population = {}\n",
    "feature_dic_weight = {}\n",
    "for date in date_list:\n",
    "    tmp = total[total[\"Date\"]==date][[\"Date\", \"name\", \"Population\", \"Weight\"]]\n",
    "    population = tmp.pivot(index=\"name\", columns=\"Date\", values=\"Population\")\n",
    "    weight = tmp.pivot(index=\"name\", columns=\"Date\", values=\"Weight\")\n",
    "    feature_dic_population[date] = population[date].values\n",
    "    feature_dic_weight[date] = weight[date].values\n",
    "feature_df_population = pd.DataFrame(feature_dic_population)\n",
    "feature_df_weight = pd.DataFrame(feature_dic_weight)\n",
    "# feature_df_population.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "populations = []\n",
    "weights = []\n",
    "for i in range(feature_df_population.shape[0]):\n",
    "    populations.append(feature_df_population.values[i][0])\n",
    "    weights.append(feature_df_weight.values[i][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def min_max_scale(lst):\n",
    "    minimum = min(lst)\n",
    "    maximum = max(lst)\n",
    "    new = []\n",
    "    for i in range(len(lst)):\n",
    "        new.append((lst[i] - minimum) / (maximum - minimum))\n",
    "    return new"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scaled_populations = min_max_scale(populations)\n",
    "scaled_weights = min_max_scale(weights)\n",
    "stat_real_features = []\n",
    "stat_cat_features = []\n",
    "for i in range(len(scaled_weights)):\n",
    "    if 0 <= scaled_weights[i] <= 0.33:\n",
    "#         country with small number of people\n",
    "        stat_cat_features.append([1])\n",
    "    elif 0.33 < scaled_weights[i] <= 0.67:\n",
    "#         country with median number of people\n",
    "        stat_cat_features.append([2])\n",
    "    else:\n",
    "#         country with large number of people\n",
    "        stat_cat_features.append([3])\n",
    "    stat_real_features.append([scaled_weights[i]])"
   ]
  },
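  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see how min-max scaling and the three-way bucketing interact, here is a small self-contained sketch. The helper `weight_bucket`, the zero-based category ids, and the sample weights are assumptions made for this illustration; only the 0.33 and 0.67 thresholds are taken from the notebook.\n",
    "\n",
    "```python\n",
    "def min_max_scale(values):\n",
    "    # Linearly map the values onto [0, 1]; a constant list maps to all zeros.\n",
    "    lo, hi = min(values), max(values)\n",
    "    if hi == lo:\n",
    "        return [0.0 for _ in values]\n",
    "    return [(v - lo) / (hi - lo) for v in values]\n",
    "\n",
    "\n",
    "def weight_bucket(scaled):\n",
    "    # Hypothetical helper: 0 = low, 1 = medium, 2 = high scaled weight.\n",
    "    if scaled <= 0.33:\n",
    "        return 0\n",
    "    if scaled <= 0.67:\n",
    "        return 1\n",
    "    return 2\n",
    "\n",
    "\n",
    "# Made-up weights for four places.\n",
    "weights = [10, 20, 30, 50]\n",
    "scaled = min_max_scale(weights)\n",
    "buckets = [weight_bucket(s) for s in scaled]\n",
    "print(scaled, buckets)\n",
    "```\n",
    "\n",
    "Because the scaling depends on the minimum and maximum of the whole list, a single extreme weight can push most places into the lowest bucket."
   ]
  },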
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Create training dataset and train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from gluonts.dataset.common import load_datasets, ListDataset\n",
    "from gluonts.dataset.field_names import FieldName\n",
    "from copy import copy\n",
    "\n",
    "train_df = new_df.drop([\"name\"], axis=1)\n",
    "test_target_values = total_values.copy()\n",
    "train_target_values = [ts[:-prediction_length] for ts in total_values]\n",
    "cat_cardinality = [3]\n",
    "\n",
    "start_date = [pd.Timestamp(\"2020-01-23\", freq='1D') for _ in range(len(new_df))]\n",
    "train_ds = ListDataset([\n",
    "    {\n",
    "        FieldName.TARGET: target,\n",
    "        FieldName.START: start,\n",
    "        FieldName.FEAT_STATIC_REAL: static_real,\n",
    "        FieldName.FEAT_STATIC_CAT: static_cat\n",
    "    }\n",
    "    for (target, start, static_real,  static_cat) in zip(train_target_values,\n",
    "                                         start_date,\n",
    "                                         stat_real_features,\n",
    "                                        stat_cat_features)\n",
    "], freq=\"D\")\n",
    "\n",
    "test_ds = ListDataset([\n",
    "    {\n",
    "        FieldName.TARGET: target,\n",
    "        FieldName.START: start,\n",
    "        FieldName.FEAT_STATIC_REAL: static_real,\n",
    "        FieldName.FEAT_STATIC_CAT: static_cat\n",
    "    }\n",
    "    for (target, start, static_real,  static_cat) in zip(test_target_values,\n",
    "                                         start_date,\n",
    "                                        stat_real_features, \n",
    "                                        stat_cat_features)\n",
    "], freq=\"D\")"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
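  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The relationship between the training and test targets is a plain hold-out split on the time axis. A minimal sketch on a toy series, with a made-up series and a shortened `prediction_length` chosen only for illustration:\n",
    "\n",
    "```python\n",
    "# Shortened horizon for the toy example (the notebook uses 20).\n",
    "prediction_length = 3\n",
    "\n",
    "# Hypothetical cumulative series for one place.\n",
    "series = [1, 2, 4, 7, 11, 16]\n",
    "\n",
    "# The training target drops the last prediction_length points ...\n",
    "train_target = series[:-prediction_length]\n",
    "# ... while the test target keeps the full series.\n",
    "test_target = series\n",
    "# The points the forecast is compared against during evaluation.\n",
    "held_out = test_target[-prediction_length:]\n",
    "print(train_target, held_out)\n",
    "```\n",
    "\n",
    "The model only ever sees `train_target` during training, so the held-out tail gives an out-of-sample check of the forecast."
   ]
  },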
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "next(iter(train_ds))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from gluonts.model.deepar import DeepAREstimator\n",
    "from gluonts.distribution.neg_binomial import NegativeBinomialOutput\n",
    "from gluonts.mx.trainer import Trainer\n",
    "\n",
    "n = 50\n",
    "estimator = DeepAREstimator(\n",
    "    prediction_length=prediction_length,\n",
    "    freq=\"D\",\n",
    "    distr_output = NegativeBinomialOutput(),\n",
    "    use_feat_static_real=True,\n",
    "#     use_feat_static_cat=True,\n",
    "#     cardinality=cat_cardinality,\n",
    "    trainer=Trainer(\n",
    "        learning_rate=1e-5,\n",
    "        epochs=n,\n",
    "        num_batches_per_epoch=50,\n",
    "        batch_size=32\n",
    "    )\n",
    ")\n",
    "\n",
    "predictor = estimator.train(train_ds)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Evaluate the model"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from gluonts.evaluation.backtest import make_evaluation_predictions\n",
    "\n",
    "forecast_it, ts_it = make_evaluation_predictions(\n",
    "    dataset=test_ds,\n",
    "    predictor=predictor,\n",
    "    num_samples=100\n",
    ")\n",
    "\n",
    "print(\"Obtaining time series conditioning values ...\")\n",
    "tss = list(tqdm(ts_it, total=len(test_ds)))\n",
    "print(\"Obtaining time series predictions ...\")\n",
    "forecasts = list(tqdm(forecast_it, total=len(test_ds)))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from gluonts.evaluation import Evaluator\n",
    "\n",
    "\n",
    "class CustomEvaluator(Evaluator):\n",
    "\n",
    "    def get_metrics_per_ts(self, time_series, forecast):\n",
    "        successive_diff = np.diff(time_series.values.reshape(len(time_series)))\n",
    "        successive_diff = successive_diff ** 2\n",
    "        successive_diff = successive_diff[:-prediction_length]\n",
    "        denom = np.mean(successive_diff)\n",
    "        pred_values = forecast.samples.mean(axis=0)\n",
    "        true_values = time_series.values.reshape(len(time_series))[-prediction_length:]\n",
    "        num = np.mean((pred_values - true_values) ** 2)\n",
    "        rmsse = num / denom\n",
    "        metrics = super().get_metrics_per_ts(time_series, forecast)\n",
    "        metrics[\"RMSSE\"] = rmsse\n",
    "        return metrics\n",
    "\n",
    "    def get_aggregate_metrics(self, metric_per_ts):\n",
    "        wrmsse = metric_per_ts[\"RMSSE\"].mean()\n",
    "        agg_metric, _ = super().get_aggregate_metrics(metric_per_ts)\n",
    "        agg_metric[\"MRMSSE\"] = wrmsse\n",
    "        return agg_metric, metric_per_ts\n",
    "\n",
    "\n",
    "evaluator = CustomEvaluator(quantiles=[0.5, 0.67, 0.95, 0.99])\n",
    "agg_metrics, item_metrics = evaluator(iter(tss), iter(forecasts), num_series=len(test_ds))\n",
    "print(json.dumps(agg_metrics, indent=4))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
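  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the scaled error metric concrete, here is a minimal standard-library sketch of RMSSE on a toy series: the mean squared forecast error is divided by the mean squared one-step change of the history, and the square root of that ratio is taken. The function name `rmsse` and all numbers are hypothetical and chosen for illustration; this is not the GluonTS implementation.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def rmsse(history, actual, predicted):\n",
    "    # Scale: mean squared one-step change within the history.\n",
    "    diffs = [(b - a) ** 2 for a, b in zip(history, history[1:])]\n",
    "    scale = sum(diffs) / len(diffs)\n",
    "    # Mean squared error of the forecast over the held-out horizon.\n",
    "    mse = sum((p - y) ** 2 for p, y in zip(predicted, actual)) / len(actual)\n",
    "    return math.sqrt(mse / scale)\n",
    "\n",
    "\n",
    "history = [0, 2, 4, 6]   # every step changes by 2, so the scale is 4\n",
    "actual = [8, 10]         # held-out values\n",
    "predicted = [9, 11]      # each forecast is off by 1, so the MSE is 1\n",
    "score = rmsse(history, actual, predicted)\n",
    "print(score)\n",
    "```\n",
    "\n",
    "A value below 1 means the forecast error is smaller than the typical day-to-day change in the history; a value above 1 means it is larger."
   ]
  },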
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot graphs for the results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_log_path = \"./plots/\"\n",
    "directory = os.path.dirname(plot_log_path)\n",
    "if not os.path.exists(directory):\n",
    "    os.makedirs(directory)\n",
    "    \n",
    "\n",
    "def plot_prob_forecasts(ts_entry, forecast_entry, path, sample_id, inline=True):\n",
    "    plot_length = 150\n",
    "    prediction_intervals = (50, 67, 95, 99)\n",
    "    legend = [\"observations\", \"median prediction\"] + [f\"{k}% prediction interval\" for k in prediction_intervals][::-1]\n",
    "\n",
    "    _, ax = plt.subplots(1, 1, figsize=(10, 7))\n",
    "    ts_entry[-plot_length:].plot(ax=ax)\n",
    "    forecast_entry.plot(prediction_intervals=prediction_intervals, color='g')\n",
    "    ax.axvline(ts_entry.index[-prediction_length], color='r')\n",
    "    plt.legend(legend, loc=\"upper left\")\n",
    "    if inline:\n",
    "        plt.show()\n",
    "        plt.clf()\n",
    "    else:\n",
    "        plt.savefig('{}forecast_{}.pdf'.format(path, sample_id))\n",
    "        plt.close()\n",
    "\n",
    "print(\"Plotting time series predictions ...\")\n",
    "for i in tqdm(range(5)):\n",
    "    ts_entry = tss[i]\n",
    "    forecast_entry = forecasts[i]\n",
    "    plot_prob_forecasts(ts_entry, forecast_entry, plot_log_path, i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Comments\n",
    "The result is seemingly good but there is still much space for improvements. The main problem is that the data got from kaggle contain only a few features which limits us from creating more precise models. The current is very close to a baseline model because it contains only one extra feature. The next thing to do is to find additional data on kaggle or from the internet to improve the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
