{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training L1 models and saving their predictions\n",
    "This notebook shows how to train the base (L1) models and save their predictions to `artifacts/`. These saved predictions can then be reused to train the L2 and L3 models without retraining the base models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import fev\n",
    "import pprint\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from autogluon.timeseries import TimeSeriesPredictor, TimeSeriesDataFrame\n",
    "from autogluon.common.savers import save_pkl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dataset_name='monash_nn5_weekly'\n",
      "{'dataset_path': 'autogluon/chronos_datasets',\n",
      " 'dataset_config': 'monash_nn5_weekly',\n",
      " 'horizon': 8,\n",
      " 'cutoff': -8,\n",
      " 'lead_time': 1,\n",
      " 'min_ts_length': 9,\n",
      " 'max_context_length': None,\n",
      " 'seasonality': 1,\n",
      " 'eval_metric': 'SQL',\n",
      " 'extra_metrics': ['MASE', 'WAPE', 'WQL', 'SQL'],\n",
      " 'quantile_levels': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],\n",
      " 'id_column': 'id',\n",
      " 'timestamp_column': 'timestamp',\n",
      " 'target_column': 'target',\n",
      " 'multiple_target_columns': None,\n",
      " 'past_dynamic_columns': [],\n",
      " 'excluded_columns': []}\n"
     ]
    }
   ],
   "source": [
    "# Load the task definition.\n",
    "#\n",
    "# Chronos datasets are used by default. To use the LOTSA datasets instead, first run\n",
    "# `datasets/download_lotsa.py` to download the data, then point BENCHMARK_CONFIG at\n",
    "# \"../configs/tasks_lotsa.yaml\".\n",
    "BENCHMARK_CONFIG = \"../configs/tasks_chronos.yaml\"\n",
    "TASK_INDEX = 13  # position of the selected task inside `benchmark.tasks`\n",
    "\n",
    "benchmark = fev.Benchmark.from_yaml(BENCHMARK_CONFIG)\n",
    "task = benchmark.tasks[TASK_INDEX]\n",
    "\n",
    "# Prefer the dataset config as the name; fall back to the dataset name when it is unset\n",
    "dataset_name = task.dataset_config or task.dataset_name\n",
    "print(f\"{dataset_name=}\")\n",
    "pprint.pprint(task.to_dict(), sort_dicts=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2bd20b3f06804dc3a01cadac213553aa",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating dataset format (num_proc=32):   0%|          | 0/111 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Only the first element returned by the adapter (the training data) is used below\n",
    "train_data, _ = fev.convert_input_data(\n",
    "    task,\n",
    "    adapter=\"autogluon\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "MAX_TIME_LIMIT = 9000  # value of `max_time_limit` shared by every model below\n",
    "\n",
    "\n",
    "def _with_time_limit(**hyperparameters):\n",
    "    \"\"\"Return the given hyperparameters followed by the shared `ag_args_fit` time limit.\"\"\"\n",
    "    return {**hyperparameters, \"ag_args_fit\": {\"max_time_limit\": MAX_TIME_LIMIT}}\n",
    "\n",
    "\n",
    "# Used in the paper\n",
    "all_models = {\n",
    "    \"SeasonalNaive\": _with_time_limit(),\n",
    "    \"AutoETS\": _with_time_limit(),\n",
    "    \"DynamicOptimizedTheta\": _with_time_limit(),\n",
    "    \"DeepAR\": _with_time_limit(),\n",
    "    \"PatchTST\": _with_time_limit(),\n",
    "    \"TemporalFusionTransformer\": _with_time_limit(),\n",
    "    \"DirectTabular\": _with_time_limit(),\n",
    "    \"RecursiveTabular\": _with_time_limit(),\n",
    "    \"TiDEModel\": _with_time_limit(),\n",
    "    \"DLinearModel\": _with_time_limit(scaling=\"std\"),\n",
    "    \"Chronos\": _with_time_limit(model_path=\"bolt_base\"),\n",
    "}\n",
    "\n",
    "# Example for quick results (each entry is a fresh dict, not shared with `all_models`)\n",
    "dummy_models = {\n",
    "    \"SeasonalNaive\": _with_time_limit(),\n",
    "    \"DynamicOptimizedTheta\": _with_time_limit(),\n",
    "    \"Chronos\": _with_time_limit(model_path=\"bolt_base\"),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<autogluon.timeseries.predictor.TimeSeriesPredictor at 0x7fe30b4c5b10>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Forecast horizon and evaluation metric both come from the task definition\n",
    "predictor = TimeSeriesPredictor(prediction_length=task.horizon, eval_metric=task.eval_metric, verbosity=0)\n",
    "\n",
    "fit_options = {\n",
    "    \"hyperparameters\": dummy_models,\n",
    "    \"refit_every_n_windows\": 1,\n",
    "    \"enable_ensemble\": False,\n",
    "    \"num_val_windows\": 5,\n",
    "    \"time_limit\": 400000,  # just a large number for the per-model time limit to kick in\n",
    "}\n",
    "predictor.fit(train_data, **fit_options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_full_data(task: fev.Task) -> TimeSeriesDataFrame:\n",
    "    \"\"\"Load full_data to compute the base model predictions.\n",
    "\n",
    "    The task's train rows are concatenated with its future rows, after the future rows\n",
    "    have had their target column filled in from the task's test data.\n",
    "    \"\"\"\n",
    "    past_df, horizon_df, static_features = fev.convert_input_data(task, adapter=\"pandas\", trust_remote_code=True)\n",
    "    held_out = task.get_test_data()\n",
    "    # Flatten the per-series test targets into one array that is written onto `horizon_df`\n",
    "    # (assumes both follow the same series order -- TODO confirm)\n",
    "    horizon_targets = np.concatenate(held_out[task.target_column])\n",
    "    horizon_df[task.target_column] = horizon_targets\n",
    "    combined_df = pd.concat([past_df, horizon_df])\n",
    "    return TimeSeriesDataFrame.from_data_frame(\n",
    "        combined_df,\n",
    "        id_column=task.id_column,\n",
    "        timestamp_column=task.timestamp_column,\n",
    "        static_features_df=static_features,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the full dataset for the selected task\n",
    "full_data = load_full_data(task=task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save OOF predictions for validation and test data\n",
    "artifact_path = f\"../artifacts/{dataset_name}.pkl\"\n",
    "artifact = predictor._simulation_artifact(full_data)\n",
    "save_pkl.save(artifact_path, artifact)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "atse",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
