{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load some libraries\n",
    "import sys\n",
    "import os\n",
    "import pickle\n",
    "import gzip\n",
    "# add '..' to the module search path and move the working directory up one level\n",
    "# NOTE(review): os.chdir('..') is relative, so re-running this cell moves up one more level\n",
    "sys.path.insert(1, '..')\n",
    "os.chdir('..')\n",
    "\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import darts\n",
    "from darts import metrics\n",
    "\n",
    "from lib.gluformer.model import *\n",
    "from lib.latent_ode.trainer_glunet import *\n",
    "from utils.darts_processing import *\n",
    "from utils.darts_dataset import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MODELS: TRANSFORMER, NHiTS, TFT, XGBOOST, LINEAR REGRESSION\n",
    "\n",
    "# Per-model settings: the darts model class, the inference dataset class\n",
    "# (deep models only) and the covariate options forwarded to load_data.\n",
    "model_params = {\n",
    "    'transformer': {'darts': models.TransformerModel, 'darts_data': SamplingDatasetInferencePast,\n",
    "                    'use_covs': False, 'use_static_covs': False, 'cov_type': 'past'},\n",
    "    'nhits': {'darts': models.NHiTSModel, 'darts_data': SamplingDatasetInferencePast,\n",
    "              'use_covs': False, 'use_static_covs': False, 'cov_type': 'past'},\n",
    "    'tft': {'darts': models.TFTModel, 'darts_data': SamplingDatasetInferenceMixed,\n",
    "            'use_covs': False, 'use_static_covs': True, 'cov_type': 'mixed'},\n",
    "    'xgboost': {'darts': models.XGBModel,\n",
    "                'use_covs': False, 'use_static_covs': False, 'cov_type': 'past'},\n",
    "    'linreg': {'darts': models.LinearRegressionModel,\n",
    "               'use_covs': False, 'use_static_covs': False, 'cov_type': 'past'},\n",
    "}\n",
    "# data sets\n",
    "datasets = ['weinstock', 'dubosson', 'colas', 'iglu', 'hall']\n",
    "# result containers, keyed by '<model>_<dataset>'\n",
    "save_trues, save_forecasts, save_inputs = {}, {}, {}\n",
    "\n",
    "# iterate through models and datasets\n",
    "for model_name, cfg in model_params.items():\n",
    "    for dataset in datasets:\n",
    "        print(f'Testing {model_name} for {dataset}')\n",
    "        key = f'{model_name}_{dataset}'\n",
    "        formatter, series, scalers = load_data(seed=0, study_file=None, dataset=dataset,\n",
    "                                               use_covs=cfg['use_covs'],\n",
    "                                               use_static_covs=cfg['use_static_covs'],\n",
    "                                               cov_type=cfg['cov_type'])\n",
    "        # values every branch below needs\n",
    "        params = formatter.params[model_name]\n",
    "        horizon = formatter.params['length_pred']\n",
    "        target_scaler = scalers['target']\n",
    "        test_target = series['test']['target']\n",
    "\n",
    "        if model_name in ('tft', 'transformer', 'nhits'):\n",
    "            # deep models: load the best checkpoint from ./output\n",
    "            model = cfg['darts'](input_chunk_length=params['in_len'],\n",
    "                                 output_chunk_length=horizon)\n",
    "            model = model.load_from_checkpoint(f'tensorboard_{model_name}_{dataset}',\n",
    "                                               work_dir='./output', best=True)\n",
    "            # dataset used for inference on the test series\n",
    "            test_dataset = cfg['darts_data'](target_series=test_target,\n",
    "                                             n=horizon,\n",
    "                                             input_chunk_length=params['in_len'],\n",
    "                                             output_chunk_length=horizon,\n",
    "                                             use_static_covariates=cfg['use_static_covs'],\n",
    "                                             max_samples_per_ts=None)\n",
    "            # predictions (20 samples for TFT, a single one otherwise), inverse-transformed\n",
    "            num_samples = 20 if model_name == 'tft' else 1\n",
    "            forecasts = model.predict_from_dataset(n=horizon,\n",
    "                                                   input_series_dataset=test_dataset,\n",
    "                                                   verbose=True,\n",
    "                                                   num_samples=num_samples)\n",
    "            save_forecasts[key] = target_scaler.inverse_transform(forecasts)\n",
    "            # true values for every sample of the inference dataset, inverse-transformed\n",
    "            trues = [test_dataset.evalsample(idx) for idx in range(len(test_dataset))]\n",
    "            save_trues[key] = target_scaler.inverse_transform(trues)\n",
    "            # model inputs, mapped back with the scaler's min_ and scale_\n",
    "            inputs = np.array([test_dataset[idx][0] for idx in range(len(test_dataset))])\n",
    "            save_inputs[key] = (inputs - target_scaler.min_) / target_scaler.scale_\n",
    "        elif model_name in ('xgboost', 'linreg'):\n",
    "            # regression models: build from formatter.params and refit on the training series\n",
    "            if model_name == 'xgboost':\n",
    "                model = cfg['darts'](lags=params['in_len'],\n",
    "                                     learning_rate=params['lr'],\n",
    "                                     subsample=params['subsample'],\n",
    "                                     min_child_weight=params['min_child_weight'],\n",
    "                                     colsample_bytree=params['colsample_bytree'],\n",
    "                                     max_depth=params['max_depth'],\n",
    "                                     gamma=params['gamma'],\n",
    "                                     reg_alpha=params['alpha'],\n",
    "                                     reg_lambda=params['lambda_'],\n",
    "                                     n_estimators=params['n_estimators'],\n",
    "                                     random_state=0)\n",
    "            else:\n",
    "                model = cfg['darts'](lags=params['in_len'],\n",
    "                                     output_chunk_length=horizon)\n",
    "            model.fit(series['train']['target'])\n",
    "            # forecasts over the test series at stride 1 without retraining;\n",
    "            # progress output is only shown for xgboost\n",
    "            forecasts = model.historical_forecasts(test_target,\n",
    "                                                   forecast_horizon=horizon,\n",
    "                                                   stride=1,\n",
    "                                                   retrain=False,\n",
    "                                                   verbose=(model_name == 'xgboost'),\n",
    "                                                   last_points_only=False)\n",
    "            save_forecasts[key] = [target_scaler.inverse_transform(forecast) for forecast in forecasts]\n",
    "            # true values: the inverse-transformed test series\n",
    "            save_trues[key] = target_scaler.inverse_transform(test_target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MODELS: LATENT ODE and GLUFORMER\n",
    "# use the GPU when one is available, otherwise fall back to the CPU so the cell still runs\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "for dataset in datasets:\n",
    "    print(f'Testing {dataset}')\n",
    "    formatter, series, scalers = load_data(seed=0, study_file=None, dataset=dataset, use_covs=True, use_static_covs=True)\n",
    "    # define dataset for inference: gluformer\n",
    "    dataset_test_glufo = SamplingDatasetInferenceDual(target_series=series['test']['target'],\n",
    "                                                      covariates=series['test']['future'],\n",
    "                                                      input_chunk_length=formatter.params['gluformer']['in_len'],\n",
    "                                                      output_chunk_length=formatter.params['length_pred'],\n",
    "                                                      use_static_covariates=True,\n",
    "                                                      array_output_only=True)\n",
    "    # define dataset for inference: latent ode (separate dataset because its in_len is read from its own params)\n",
    "    dataset_test_latod = SamplingDatasetInferenceDual(target_series=series['test']['target'],\n",
    "                                                      covariates=series['test']['future'],\n",
    "                                                      input_chunk_length=formatter.params['latentode']['in_len'],\n",
    "                                                      output_chunk_length=formatter.params['length_pred'],\n",
    "                                                      use_static_covariates=True,\n",
    "                                                      array_output_only=True)\n",
    "    # load model: gluformer\n",
    "    # feature counts are read from the last training series\n",
    "    num_dynamic_features = series['train']['future'][-1].n_components\n",
    "    num_static_features = series['train']['static'][-1].n_components\n",
    "    glufo = Gluformer(d_model = formatter.params['gluformer']['d_model'],\n",
    "                      n_heads = formatter.params['gluformer']['n_heads'],\n",
    "                      d_fcn = formatter.params['gluformer']['d_fcn'],\n",
    "                      r_drop = 0.2,\n",
    "                      activ = 'relu',\n",
    "                      num_enc_layers = formatter.params['gluformer']['num_enc_layers'],\n",
    "                      num_dec_layers = formatter.params['gluformer']['num_dec_layers'],\n",
    "                      distil = True,\n",
    "                      len_seq = formatter.params['gluformer']['in_len'],\n",
    "                      label_len = formatter.params['gluformer']['in_len'] // 3,\n",
    "                      len_pred = formatter.params['length_pred'],\n",
    "                      num_dynamic_features = num_dynamic_features,\n",
    "                      num_static_features = num_static_features,)\n",
    "    glufo.to(device)\n",
    "    # map_location lets a checkpoint saved on another device load onto the selected one\n",
    "    glufo.load_state_dict(torch.load(f'./output/tensorboard_gluformer_{dataset}/model.pt', map_location=torch.device(device)))\n",
    "    # load model: latent ode\n",
    "    latod = LatentODEWrapper(device = device,\n",
    "                             latents = formatter.params['latentode']['latents'],\n",
    "                             rec_dims = formatter.params['latentode']['rec_dims'],\n",
    "                             rec_layers = formatter.params['latentode']['rec_layers'],\n",
    "                             gen_layers = formatter.params['latentode']['gen_layers'],\n",
    "                             units = formatter.params['latentode']['units'],\n",
    "                             gru_units = formatter.params['latentode']['gru_units'],)\n",
    "    latod.load(f'./output/tensorboard_latentode_{dataset}/model.ckpt', device)\n",
    "    # get predictions: gluformer (the second return value of predict is not used here)\n",
    "    print('Gluformer')\n",
    "    forecasts, _ = glufo.predict(dataset_test_glufo,\n",
    "                                 batch_size=8,\n",
    "                                 num_samples=10,\n",
    "                                 device=device,\n",
    "                                 use_tqdm=True)\n",
    "    # forecasts and inputs are arrays, so they are mapped back with the scaler's min_ and scale_\n",
    "    forecasts = (forecasts - scalers['target'].min_) / scalers['target'].scale_\n",
    "    trues = [dataset_test_glufo.evalsample(i) for i in range(len(dataset_test_glufo))]\n",
    "    trues = scalers['target'].inverse_transform(trues)\n",
    "    inputs = [dataset_test_glufo[i][0] for i in range(len(dataset_test_glufo))]\n",
    "    inputs = (np.array(inputs) - scalers['target'].min_) / scalers['target'].scale_\n",
    "    save_forecasts[f'gluformer_{dataset}'] = forecasts\n",
    "    save_trues[f'gluformer_{dataset}'] = trues\n",
    "    save_inputs[f'gluformer_{dataset}'] = inputs\n",
    "    # get predictions: latent ode\n",
    "    print('Latent ODE')\n",
    "    forecasts = latod.predict(dataset_test_latod,\n",
    "                              batch_size=32,\n",
    "                              num_samples=20,\n",
    "                              device=device,\n",
    "                              use_tqdm=True,)\n",
    "    forecasts = (forecasts - scalers['target'].min_) / scalers['target'].scale_\n",
    "    trues = [dataset_test_latod.evalsample(i) for i in range(len(dataset_test_latod))]\n",
    "    trues = scalers['target'].inverse_transform(trues)\n",
    "    inputs = [dataset_test_latod[i][0] for i in range(len(dataset_test_latod))]\n",
    "    inputs = (np.array(inputs) - scalers['target'].min_) / scalers['target'].scale_\n",
    "    save_forecasts[f'latentode_{dataset}'] = forecasts\n",
    "    save_trues[f'latentode_{dataset}'] = trues\n",
    "    save_inputs[f'latentode_{dataset}'] = inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# save forecasts, true values and inputs as gzip-compressed pickles\n",
    "output_dir = './paper_results/data'\n",
    "# create the target directory up front so the collected results are not lost to a missing folder\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "for name, payload in [('forecasts', save_forecasts), ('trues', save_trues), ('inputs', save_inputs)]:\n",
    "    with gzip.open(f'{output_dir}/compressed_{name}.pkl', 'wb') as file:\n",
    "        pickle.dump(payload, file)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "glunet",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
