{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2b9a91a8",
   "metadata": {},
   "source": [
    "### Real-world data results"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1720b56d",
   "metadata": {},
   "source": [
    "##### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fda168c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from endo_regime_pcmci.persistent_endo_cit import PersistentEndoCIT\n",
    "from endo_regime_pcmci.sparse_endo_cit import SparseEndoCIT\n",
    "from endo_regime_pcmci.mixed_test_pcmci import PersistentEndoPCMCI, SparseEndoPCMCI, MixedTestPCMCI\n",
    "\n",
    "from generate_applied_dataset import get_dataset_for_era5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "937f283b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import seaborn as sns \n",
    "import pandas as pd\n",
    "from datetime import datetime, timedelta\n",
    "\n",
    "from tigramite.data_processing import DataFrame\n",
    "from matplotlib import pyplot as plt\n",
    "from tigramite import plotting as tp\n",
    "\n",
    "from tigramite.independence_tests.robust_parcorr import RobustParCorr\n",
    "from tigramite.independence_tests.cmiknn_mixed import CMIknnMixed\n",
    "from tigramite.independence_tests.cmiknn import CMIknn\n",
    "from tigramite.independence_tests.regressionCI import RegressionCI\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5070d696",
   "metadata": {},
   "source": [
    "##### Select the variables and years to keep (same setup as in the tutorial)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59653361",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ERA5 variables excluded from the analysis\n",
    "vals_to_drop = ['ef', 'swvl1', 'swvl2', 'swvl123']\n",
    "\n",
    "# Short display names for the retained ERA5 variables (also used as plot labels)\n",
    "columns_name = dict(slhf='LH', sshf='SH', swvl3='SM', t2m='T2m',\n",
    "                    tp='Prec', z='Stream', ssrd='SW')\n",
    "\n",
    "# NOTE(review): how these bounds are used is defined inside get_dataset_for_era5 -- confirm there\n",
    "ranges = [0.25, 0.75]\n",
    "\n",
    "era5_csv_path = './../data/data_era5.csv'\n",
    "pd_df, tig_df, type_mask = get_dataset_for_era5(\n",
    "    era5_csv_path,\n",
    "    years_to_keep=[1993, 2022],\n",
    "    vals_to_drop=vals_to_drop,\n",
    "    ranges=ranges,\n",
    "    columns_name=columns_name,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6bd3f55",
   "metadata": {},
   "source": [
    "##### Get the unique context values (excluding the 999 missing-value flag)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b27dc84e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distinct context labels in the data; 999. flags a missing value, not a real context\n",
    "all_contexts = pd_df['context'].unique()\n",
    "is_observed = all_contexts != 999.\n",
    "unique_contexts = np.sort(all_contexts[is_observed])\n",
    "print(unique_contexts)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dd27c258",
   "metadata": {},
   "source": [
    "##### Generate Tigramite dataframe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d788bd8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pd_df.to_numpy() drops the column labels, so pass them to tigramite explicitly;\n",
    "# otherwise the DataFrame only knows its variables by index.\n",
    "# 999. marks missing values in pd_df.\n",
    "dataframe = DataFrame(pd_df.to_numpy(),\n",
    "                      var_names=list(pd_df.columns),\n",
    "                      missing_flag=999.,\n",
    "                      data_type=type_mask)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1dbaf728",
   "metadata": {},
   "source": [
    "##### Run PAC-PCMCI (`PersistentEndoPCMCI`) for each context"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b652d38",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit PAC-PCMCI separately for every context value; results are keyed by context.\n",
    "results_pac_contexts = dict()\n",
    "for context in unique_contexts:\n",
    "    print(f'Running PCMCI for context {context}')\n",
    "\n",
    "    # NOTE(review): context_vars=[6] assumes the context variable is column 6 of pd_df -- confirm\n",
    "    cond_ind_test = PersistentEndoCIT(\n",
    "        mixed_cit=RegressionCI(),\n",
    "        cont_cit=RobustParCorr(),\n",
    "        context_vars=[6],\n",
    "        context_values=[context])\n",
    "\n",
    "    pac_pcmci = PersistentEndoPCMCI(dataframe=dataframe,\n",
    "                                    cond_ind_test=cond_ind_test,\n",
    "                                    verbosity=0)\n",
    "\n",
    "    results = pac_pcmci.run_pcmciplus_fullpcmci(tau_max=3)\n",
    "    if results is None:\n",
    "        # Report skipped contexts instead of dropping them silently\n",
    "        print(f'No PAC-PCMCI results for context {context}; skipping plot')\n",
    "        continue\n",
    "\n",
    "    results_pac_contexts[context] = results\n",
    "    fig, ax = tp.plot_graph(\n",
    "        val_matrix=results['val_matrix'],\n",
    "        graph=results['graph'],\n",
    "        show_autodependency_lags=False,\n",
    "        var_names=list(pd_df.columns))\n",
    "    fig.suptitle(f'PAC-PCMCI, Context {int(context)}', fontsize=15)\n",
    "    fig.savefig(f'./era5_persistent_pcmci_context_{context}.png', dpi=300)\n",
    "    # Render this figure now and release it so figures don't pile up across contexts\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d1bb70c5",
   "metadata": {},
   "source": [
    "##### Run SAC-PCMCI (`SparseEndoPCMCI`) for each context"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b9e4821",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit SAC-PCMCI separately for every context value; results are keyed by context.\n",
    "results_sac_contexts = dict()\n",
    "for context in unique_contexts:\n",
    "    print(f'Running PCMCI for context {context}')\n",
    "\n",
    "    # NOTE(review): context_vars=[6] assumes the context variable is column 6 of pd_df -- confirm\n",
    "    cond_ind_test = SparseEndoCIT(\n",
    "        mixed_cit=RegressionCI(),\n",
    "        cont_cit=RobustParCorr(),\n",
    "        context_vars=[6],\n",
    "        context_values=[context])\n",
    "\n",
    "    sac_pcmci = SparseEndoPCMCI(dataframe=dataframe,\n",
    "                                cond_ind_test=cond_ind_test,\n",
    "                                verbosity=0)\n",
    "\n",
    "    results = sac_pcmci.run_pcmciplus(tau_max=3)\n",
    "    if results is None:\n",
    "        # Report skipped contexts instead of dropping them silently\n",
    "        print(f'No SAC-PCMCI results for context {context}; skipping plot')\n",
    "        continue\n",
    "\n",
    "    results_sac_contexts[context] = results\n",
    "    fig, ax = tp.plot_graph(\n",
    "        val_matrix=results['val_matrix'],\n",
    "        graph=results['graph'],\n",
    "        show_autodependency_lags=False,\n",
    "        var_names=list(pd_df.columns))\n",
    "    fig.suptitle(f'SAC-PCMCI, Context {int(context)}', fontsize=15)\n",
    "    fig.savefig(f'./era5_sparse_pcmci_context_{context}.png', dpi=300)\n",
    "    # Render this figure now and release it so figures don't pile up across contexts\n",
    "    plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tigramite",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
