{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "04f2b39f",
   "metadata": {},
   "source": [
    "# Real Epidemics: evaluating inferred symbolic models on COVID-19 and H1N1 data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fc917be",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "from main import set_pytorch_seed\n",
    "from post_processing import get_model, make_callable, plot_predictions\n",
    "import pandas as pd\n",
    "import sympy as sp\n",
    "from sklearn.metrics import mean_absolute_error, mean_squared_error, root_mean_squared_error\n",
    "from post_processing import build_model_from_file as build_kan\n",
    "from post_processing_mpnn import build_model_from_file_mpnn as build_mpnn\n",
    "from post_processing_mpnn import build_model_from_file_llc as build_llc\n",
    "\n",
    "set_pytorch_seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1eef0454",
   "metadata": {},
   "outputs": [],
   "source": [
    "from fine_tuning_coefficients import get_scaler\n",
    "\n",
    "def eval_real_epid_int(data, countries_dict, build_symb_model, inferred_coeffs, scaler=None, use_euler=False, tr_perc = 0.8,\n",
    "                       mask = None, device='cuda:0', model_name='', results_dict = None):\n",
    "    \"\"\"Score per-country symbolic models on the trajectory predicted from the first sample.\n",
    "\n",
    "    For each (country_name, node_idx) pair in countries_dict a model is built with\n",
    "    build_symb_model(country_name, inferred_coeffs) and called on a copy of data[0];\n",
    "    only the node_idx slice of its output is kept. MAE, MSE and RMSE against data[0].y\n",
    "    are computed on all time steps and on the steps from int(tr_perc * n_steps) onwards,\n",
    "    printed, and appended to results_dict[model_name] under the *_Traj keys.\n",
    "\n",
    "    Args:\n",
    "        data: indexable dataset; data[0] must expose x, y and (when use_euler) t_span.\n",
    "        countries_dict: mapping from country name to node index.\n",
    "        build_symb_model: callable (country_name, inferred_coeffs) -> model.\n",
    "        inferred_coeffs: coefficients forwarded unchanged to build_symb_model.\n",
    "        scaler: optional object with transform / inverse_transform; transform is applied\n",
    "            to the input features and inverse_transform to the model output.\n",
    "        use_euler: if True, set the model's integration_method to 'euler' and replace\n",
    "            t_span with torch.arange(n_steps + 1).\n",
    "        tr_perc: fraction of leading time steps excluded from the test metrics.\n",
    "        mask: optional index applied to the second axis of targets and predictions.\n",
    "        device: device the model and the sample are moved to.\n",
    "        model_name: key under which the metrics are stored in results_dict.\n",
    "        results_dict: mapping model_name -> metric name -> list. When omitted, a\n",
    "            throw-away nested defaultdict is used, so the metrics are only printed.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (y_true, y_pred, y_true_val, y_pred_val) of numpy arrays.\n",
    "    \"\"\"\n",
    "    import copy\n",
    "    from collections import defaultdict\n",
    "\n",
    "    # A literal {} default would be shared between calls and cannot be indexed by model name.\n",
    "    if results_dict is None:\n",
    "        results_dict = defaultdict(lambda: defaultdict(list))\n",
    "\n",
    "    y_true = data[0].y.detach().cpu().numpy()\n",
    "    if mask is not None:\n",
    "        y_true = y_true[:, mask, :]\n",
    "    y_pred = np.zeros_like(y_true)\n",
    "\n",
    "    for country_name, node_idx in countries_dict.items():\n",
    "        symb_model = build_symb_model(country_name, inferred_coeffs)\n",
    "        symb_model = symb_model.to(device)\n",
    "\n",
    "        # Work on a shallow copy: x and t_span are reassigned below, and writing them onto\n",
    "        # the dataset's own object would re-scale already scaled features for the next country.\n",
    "        data_0 = copy.copy(data[0]).cpu()\n",
    "        if scaler is not None:\n",
    "            data_0.x = scaler.transform(data_0.x)\n",
    "\n",
    "        data_0 = data_0.to(device)\n",
    "        if use_euler:\n",
    "            symb_model.integration_method = \"euler\"\n",
    "            data_0.t_span = torch.arange(y_true.shape[0] + 1, device=data_0.x.device, dtype=data_0.t_span.dtype)\n",
    "\n",
    "        try:\n",
    "            pred = symb_model(data_0).detach().cpu().numpy()\n",
    "            if mask is not None:\n",
    "                pred = pred[:, mask, :]\n",
    "        except AssertionError:\n",
    "            # The model raised an assertion; this country's slice of y_pred stays at zero.\n",
    "            print(\"Failed\")\n",
    "            continue\n",
    "\n",
    "        if scaler is not None:\n",
    "            pred = scaler.inverse_transform(pred)\n",
    "\n",
    "        # Each country-specific model only contributes the prediction for its own node.\n",
    "        y_pred[:, node_idx, :] = pred[:, node_idx, :]\n",
    "\n",
    "    # Steps before tr_end are treated as the training range; the rest is the test range.\n",
    "    tr_len = y_true.shape[0]\n",
    "    tr_end = int(tr_perc * tr_len)\n",
    "    y_true_val = y_true[tr_end:, :, :]\n",
    "    y_pred_val = y_pred[tr_end:, :, :]\n",
    "\n",
    "    mae_test = mean_absolute_error(y_true_val.flatten(), y_pred_val.flatten())\n",
    "    mse_test = mean_squared_error(y_true_val.flatten(), y_pred_val.flatten())\n",
    "    rmse_test = root_mean_squared_error(y_true_val.flatten(), y_pred_val.flatten())\n",
    "\n",
    "    mae_all = mean_absolute_error(y_true.flatten(), y_pred.flatten())\n",
    "    mse_all = mean_squared_error(y_true.flatten(), y_pred.flatten())\n",
    "    rmse_all = root_mean_squared_error(y_true.flatten(), y_pred.flatten())\n",
    "\n",
    "    model_results = results_dict[model_name]\n",
    "    for metric_name, metric_value in (\n",
    "        (\"Test MAE_Traj\", mae_test),\n",
    "        (\"Test MSE_Traj\", mse_test),\n",
    "        (\"Test RMSE_Traj\", rmse_test),\n",
    "        (\"All MAE_Traj\", mae_all),\n",
    "        (\"All MSE_Traj\", mse_all),\n",
    "        (\"All RMSE_Traj\", rmse_all),\n",
    "    ):\n",
    "        model_results[metric_name].append(metric_value)\n",
    "\n",
    "    print(f\"Test MAE: {mae_test}\")\n",
    "    print(f\"Overall MAE: {mae_all}\")\n",
    "\n",
    "    print(f\"\\nTest MSE: {mse_test}\")\n",
    "    print(f\"Overall MSE: {mse_all}\")\n",
    "\n",
    "    print(f\"\\nTest RMSE: {rmse_test}\")\n",
    "    print(f\"Overall RMSE: {rmse_all}\")\n",
    "\n",
    "    return y_true, y_pred, y_true_val, y_pred_val\n",
    "\n",
    "\n",
    "def eval_real_epid_journal(data, countries_dict, build_symb_model, inferred_coeffs, tr_perc = 0.8, step_size=1.0, scaler = None,\n",
    "                           device='cpu', mask=None, model_name = '', results_dict = None):\n",
    "    \"\"\"Score per-country symbolic models via cumulative sums of predicted derivatives.\n",
    "\n",
    "    For each (country_name, node_idx) pair a model is built with\n",
    "    build_symb_model(country_name, inferred_coeffs), switched to predict_deriv = True and\n",
    "    evaluated on every snapshot of data. The prediction at step 0 is x0 and the prediction\n",
    "    at step k >= 1 is x0 + step_size * (sum of the first k derivative outputs), where x0 is\n",
    "    the (optionally scaled) data[0].x. Only the node_idx slice of each model's prediction\n",
    "    is kept. MAE, MSE and RMSE against the stacked snapshot features are computed on all\n",
    "    steps and on the steps from int(tr_perc * n_steps) onwards, printed, and appended to\n",
    "    results_dict[model_name] under the *_Eul keys.\n",
    "\n",
    "    Args:\n",
    "        data: iterable and indexable collection of snapshots exposing x.\n",
    "        countries_dict: mapping from country name to node index.\n",
    "        build_symb_model: callable (country_name, inferred_coeffs) -> model.\n",
    "        inferred_coeffs: coefficients forwarded unchanged to build_symb_model.\n",
    "        tr_perc: fraction of leading time steps excluded from the test metrics.\n",
    "        step_size: multiplier applied to the accumulated derivatives.\n",
    "        scaler: optional object with transform / inverse_transform; transform is applied\n",
    "            to x0 and to every snapshot, inverse_transform to the assembled prediction.\n",
    "        device: device the model and the snapshots are moved to.\n",
    "        mask: optional index applied to the second axis of targets and predictions.\n",
    "        model_name: key under which the metrics are stored in results_dict.\n",
    "        results_dict: mapping model_name -> metric name -> list. When omitted, a\n",
    "            throw-away nested defaultdict is used, so the metrics are only printed.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (y_true, y_pred, y_true_val, y_pred_val) of numpy arrays.\n",
    "    \"\"\"\n",
    "    import copy\n",
    "    from collections import defaultdict\n",
    "\n",
    "    # A literal {} default would be shared between calls and cannot be indexed by model name.\n",
    "    if results_dict is None:\n",
    "        results_dict = defaultdict(lambda: defaultdict(list))\n",
    "\n",
    "    def get_dxdt_pred(data, symb_model):\n",
    "        # Evaluate the model on every snapshot and stack the outputs along a new first axis.\n",
    "        dxdt_pred = []\n",
    "        for snapshot in data:\n",
    "            # Shallow copy so that scaling / device moves never touch the caller's snapshot;\n",
    "            # otherwise a dataset handing out shared objects gets re-scaled once per country.\n",
    "            snapshot = copy.copy(snapshot)\n",
    "            if scaler is not None:\n",
    "                snapshot.x = scaler.transform(snapshot.x)\n",
    "            snapshot = snapshot.to(device)\n",
    "            dxdt_pred.append(symb_model(snapshot))\n",
    "\n",
    "        return torch.stack(dxdt_pred, dim=0)\n",
    "\n",
    "    def sum_over_dxdt(dxdt_pred):\n",
    "        # Running sum over the first axis: entry i is the sum of dxdt_pred[0..i].\n",
    "        return torch.cumsum(dxdt_pred, dim=0)\n",
    "\n",
    "    def integrate(out, x0):\n",
    "        # Step 0 is x0 itself; step i + 1 adds the accumulated derivative out[i].\n",
    "        steps = [x0 + step_size * out[i, :, :] for i in range(out.shape[0] - 1)]\n",
    "        return torch.stack([x0] + steps, dim=0)\n",
    "\n",
    "    x0 = data[0].x\n",
    "    if scaler is not None:\n",
    "        x0 = scaler.transform(x0)\n",
    "    x0 = x0.to(device)\n",
    "    y_true = torch.stack([d.x for d in data], dim=0).detach().cpu().numpy()\n",
    "    if mask is not None:\n",
    "        y_true = y_true[:, mask, :]\n",
    "    y_pred = np.zeros_like(y_true)\n",
    "\n",
    "    for country_name, node_idx in countries_dict.items():\n",
    "        symb_model = build_symb_model(country_name, inferred_coeffs)\n",
    "        symb_model = symb_model.to(device)\n",
    "        symb_model.predict_deriv = True\n",
    "        dxdt_pred = get_dxdt_pred(data, symb_model)\n",
    "        out = sum_over_dxdt(dxdt_pred)\n",
    "        pred = integrate(out, x0).detach().cpu().numpy()\n",
    "        if mask is not None:\n",
    "            pred = pred[:, mask, :]\n",
    "        # Each country-specific model only contributes the prediction for its own node.\n",
    "        y_pred[:, node_idx, :] = pred[:, node_idx, :]\n",
    "\n",
    "    if scaler is not None:\n",
    "        y_pred = scaler.inverse_transform(y_pred)\n",
    "\n",
    "    # Steps before tr_end are treated as the training range; the rest is the test range.\n",
    "    tr_len = y_true.shape[0]\n",
    "    tr_end = int(tr_perc * tr_len)\n",
    "    y_true_val = y_true[tr_end:, :, :]\n",
    "    y_pred_val = y_pred[tr_end:, :, :]\n",
    "\n",
    "    mae_test = mean_absolute_error(y_true_val.flatten(), y_pred_val.flatten())\n",
    "    mse_test = mean_squared_error(y_true_val.flatten(), y_pred_val.flatten())\n",
    "    rmse_test = root_mean_squared_error(y_true_val.flatten(), y_pred_val.flatten())\n",
    "\n",
    "    mae_all = mean_absolute_error(y_true.flatten(), y_pred.flatten())\n",
    "    mse_all = mean_squared_error(y_true.flatten(), y_pred.flatten())\n",
    "    rmse_all = root_mean_squared_error(y_true.flatten(), y_pred.flatten())\n",
    "\n",
    "    model_results = results_dict[model_name]\n",
    "    for metric_name, metric_value in (\n",
    "        (\"Test MAE_Eul\", mae_test),\n",
    "        (\"Test MSE_Eul\", mse_test),\n",
    "        (\"Test RMSE_Eul\", rmse_test),\n",
    "        (\"All MAE_Eul\", mae_all),\n",
    "        (\"All MSE_Eul\", mse_all),\n",
    "        (\"All RMSE_Eul\", rmse_all),\n",
    "    ):\n",
    "        model_results[metric_name].append(metric_value)\n",
    "\n",
    "    print(f\"Test MAE: {mae_test}\")\n",
    "    print(f\"Overall MAE: {mae_all}\")\n",
    "\n",
    "    print(f\"\\nTest MSE: {mse_test}\")\n",
    "    print(f\"Overall MSE: {mse_all}\")\n",
    "\n",
    "    print(f\"\\nTest RMSE: {rmse_test}\")\n",
    "    print(f\"Overall RMSE: {rmse_all}\")\n",
    "\n",
    "    return y_true, y_pred, y_true_val, y_pred_val"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fe263fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def save_country_predictions(y_true, preds_dict, countries_dict, save_dir=\"./outputs/covid\"):\n",
    "    \"\"\"Write one true-vs-predicted comparison figure per country into save_dir.\n",
    "\n",
    "    Args:\n",
    "        y_true: array indexed as [:, node_idx, 0] to obtain a country's ground-truth curve.\n",
    "        preds_dict: mapping model name -> (y_pred, color); y_pred is indexed like y_true.\n",
    "        countries_dict: mapping from country name to node index.\n",
    "        save_dir: output directory, created if missing. Each figure is saved as\n",
    "            <country_name>_comparison.png at 150 dpi.\n",
    "    \"\"\"\n",
    "    os.makedirs(save_dir, exist_ok=True)\n",
    "\n",
    "    for country_name, node_idx in countries_dict.items():\n",
    "        fig, ax = plt.subplots(figsize=(10, 6))\n",
    "\n",
    "        # Ground truth as a solid black line.\n",
    "        ax.plot(y_true[:, node_idx, 0], label=\"True\", linewidth=2, color=\"black\")\n",
    "\n",
    "        # One dashed line per model, in that model's assigned colour.\n",
    "        for model_name, (y_pred, color) in preds_dict.items():\n",
    "            ax.plot(y_pred[:, node_idx, 0], linestyle=\"--\", label=model_name, color=color)\n",
    "\n",
    "        ax.set_title(f\"{country_name}\")\n",
    "        ax.set_xlabel(\"Days\")\n",
    "        ax.set_ylabel(\"Infected Count\")\n",
    "        ax.legend()\n",
    "        fig.tight_layout()\n",
    "\n",
    "        out_path = os.path.join(save_dir, f\"{country_name}_comparison.png\")\n",
    "        fig.savefig(out_path, dpi=150)\n",
    "        # Close explicitly so figures do not accumulate across countries.\n",
    "        plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26a6a087",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45ac0814",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams.update({\n",
    "    \"font.size\": 18,         # base font size\n",
    "    \"axes.titlesize\": 22,    # title\n",
    "    \"axes.labelsize\": 16,    # x/y labels\n",
    "    \"xtick.labelsize\": 15,   # x-tick labels\n",
    "    \"ytick.labelsize\": 15,   # y-tick labels\n",
    "    \"legend.fontsize\": 13,   # legend\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "634f126a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets.RealEpidemics import RealEpidemics\n",
    "\n",
    "real_epid_data = RealEpidemics(\n",
    "    root = './data_real_epid_covid_int',\n",
    "    name = 'RealEpid',\n",
    "    predict_deriv=False,\n",
    "    history=1,\n",
    "    horizon=44,\n",
    "    scale=False\n",
    ")\n",
    "\n",
    "data_real_epid_orig = RealEpidemics(\n",
    "    root = './data_real_epid_covid_orig',\n",
    "    name = 'RealEpid',\n",
    "    predict_deriv=True,\n",
    "    scale=False,\n",
    ")\n",
    "\n",
    "with open('./data_real_epid_covid_int/RealEpid/countries_dict.json', 'r') as f:\n",
    "    countries_dict = json.load(f)\n",
    "    \n",
    "all_res_covid_traj = {}\n",
    "all_res_covid_eul = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c33e780",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "res_dict_tss = defaultdict(lambda: defaultdict(list))\n",
    "model_name_tss = \"TPSINDy\"\n",
    "\n",
    "res_dict_gkan_bb = defaultdict(lambda: defaultdict(list))\n",
    "model_name_gkan_bb = \"GKAN-ODE+GP\"\n",
    "\n",
    "res_dict_gkan_sw = defaultdict(lambda: defaultdict(list))\n",
    "model_name_gkan_sw = \"GKAN-ODE+SW\"\n",
    "\n",
    "res_dict_gmlp = defaultdict(lambda: defaultdict(list))\n",
    "model_name_gmlp = \"GMLP-ODE+GP\"\n",
    "\n",
    "res_dict_llc = defaultdict(lambda: defaultdict(list))\n",
    "model_name_llc = \"LLC+GP\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c128f6d",
   "metadata": {},
   "source": [
    "### TSS 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf52e680",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_symb_model_tss(country, inf_coeff):\n",
    "    \"\"\"Build the TPSINDy symbolic model for one country.\n",
    "\n",
    "    The country's node index is looked up in the notebook-level countries_dict.\n",
    "    Row 1 of inf_coeff scales the pairwise term 1 / (1 + exp(-(x_j - x_i))) and\n",
    "    row 0 scales the self term x_i.\n",
    "\n",
    "    Args:\n",
    "        country: key of countries_dict.\n",
    "        inf_coeff: 2-D array indexed as [row, country_idx].\n",
    "\n",
    "    Returns:\n",
    "        The model returned by get_model, configured with the 'rk4' integration method.\n",
    "    \"\"\"\n",
    "    x_i, x_j = sp.symbols('x_i x_j')\n",
    "    node = countries_dict[country]\n",
    "\n",
    "    pairwise_expr = inf_coeff[1, node] * (1 / (1 + sp.exp(- (x_j - x_i))))\n",
    "    self_expr = inf_coeff[0, node] * x_i\n",
    "\n",
    "    g_fn = make_callable(pairwise_expr)\n",
    "    h_fn = make_callable(self_expr)\n",
    "\n",
    "    return get_model(\n",
    "        g = g_fn,\n",
    "        h = h_fn,\n",
    "        message_passing=False,\n",
    "        include_time=False,\n",
    "        integration_method='rk4'\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74e6aa6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "inf_coeff_covid = pd.read_csv(\"./inferred_coeffs/tpsindy/inf_coeffs_all_covid.csv\").values\n",
    "\n",
    "\n",
    "y_true_tss, y_pred_tss, y_true_val_tss, y_pred_val_tss = eval_real_epid_int(\n",
    "    data = real_epid_data,\n",
    "    countries_dict=countries_dict,\n",
    "    inferred_coeffs=inf_coeff_covid,\n",
    "    build_symb_model=build_symb_model_tss,\n",
    "    use_euler=True,\n",
    "    tr_perc=0.9,\n",
    "    device='cpu',\n",
    "    model_name=model_name_tss,\n",
    "    results_dict=res_dict_tss\n",
    ")\n",
    "\n",
    "print(\"Mae Eul\\n\")\n",
    "inf_coeff_covid = pd.read_csv(\"./inferred_coeffs/tpsindy/inf_coeffs_test_covid.csv\").values\n",
    "\n",
    "y_true_tss_jrn, y_pred_tss_jrn, y_true_val_tss_jrn, y_pred_val_tss_jrn = eval_real_epid_journal(\n",
    "    data = data_real_epid_orig,\n",
    "    countries_dict=countries_dict,\n",
    "    build_symb_model=build_symb_model_tss,\n",
    "    inferred_coeffs=inf_coeff_covid,\n",
    "    tr_perc=0.9,\n",
    "    step_size=1.0,\n",
    "    device='cpu',\n",
    "    model_name=model_name_tss,\n",
    "    results_dict=res_dict_tss\n",
    ")\n",
    "\n",
    "all_res_covid_traj[\"TPSINDy\"] = (y_pred_tss.copy(), \"red\")\n",
    "all_res_covid_eul[\"TPSINDy\"] = (y_pred_tss_jrn.copy(), \"red\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc8ab708",
   "metadata": {},
   "source": [
    "### GKAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9f0ca26",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_symb_model_gkan(country, inf_coeff):\n",
    "    \"\"\"Build the GKAN-ODE symbolic model for one country.\n",
    "\n",
    "    The first three entries of the country's column in inf_coeff are read as (b, a, c):\n",
    "    the pairwise term is c * exp(x_j) and the self term is a * x_i + b.\n",
    "\n",
    "    Args:\n",
    "        country: column label of inf_coeff.\n",
    "        inf_coeff: table supporting inf_coeff[country].iloc[k] lookups.\n",
    "\n",
    "    Returns:\n",
    "        The model returned by get_model, configured with the 'dopri5' integration method.\n",
    "    \"\"\"\n",
    "    x_i, x_j = sp.symbols('x_i x_j')\n",
    "\n",
    "    column = inf_coeff[country]\n",
    "    # Positional order in the coefficient file is intercept, slope, coupling weight.\n",
    "    b, a, c = (column.iloc[k] for k in range(3))\n",
    "\n",
    "    pairwise_expr = c * sp.exp(x_j)\n",
    "    self_expr = a * x_i + b\n",
    "\n",
    "    g_fn = make_callable(pairwise_expr)\n",
    "    h_fn = make_callable(self_expr)\n",
    "\n",
    "    return get_model(\n",
    "        g = g_fn,\n",
    "        h = h_fn,\n",
    "        message_passing=False,\n",
    "        include_time=False,\n",
    "        integration_method='dopri5'\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83ad8700",
   "metadata": {},
   "outputs": [],
   "source": [
    "scaler = get_scaler(data = real_epid_data, tr_perc=0.8)\n",
    "x_i, x_j = sp.symbols('x_i x_j')\n",
    "inf_coeff_covid = pd.read_csv(\"./inferred_coeffs/gkan/inferred_coeffs_covid_ts.csv\")\n",
    "\n",
    "y_true_gkan, y_pred_gkan, y_true_val_gkan, y_pred_val_gkan = eval_real_epid_int(\n",
    "    data = real_epid_data,\n",
    "    countries_dict=countries_dict,\n",
    "    build_symb_model=build_symb_model_gkan,\n",
    "    scaler=scaler,\n",
    "    use_euler=False,\n",
    "    inferred_coeffs=inf_coeff_covid,\n",
    "    tr_perc=0.9,\n",
    "    device='cpu',\n",
    "    model_name=model_name_gkan_bb,\n",
    "    results_dict=res_dict_gkan_bb\n",
    ")\n",
    "\n",
    "all_res_covid_traj[\"GKAN-ODE+GP\"] = (y_pred_gkan.copy(), \"#5fa2d1\")\n",
    "\n",
    "t = real_epid_data.t_sampled\n",
    "epsilon = t[0][1] - t[0][0]\n",
    "scaler = get_scaler(data = real_epid_data, tr_perc=0.8)\n",
    "\n",
    "print(\"Mae Eul\\n\")\n",
    "\n",
    "y_true_gkan_jrn, y_pred_gkan_jrn, y_true_val_gkan_jrn, y_pred_val_gkan_jrn = eval_real_epid_journal(\n",
    "    data = data_real_epid_orig,\n",
    "    countries_dict=countries_dict,\n",
    "    build_symb_model=build_symb_model_gkan,\n",
    "    tr_perc=0.9,\n",
    "    step_size=epsilon.item(),\n",
    "    inferred_coeffs=inf_coeff_covid,\n",
    "    scaler=scaler,\n",
    "    device='cpu',\n",
    "    model_name=model_name_gkan_bb,\n",
    "    results_dict=res_dict_gkan_bb\n",
    ")\n",
    "\n",
    "all_res_covid_eul[\"GKAN-ODE+GP\"] = (y_pred_gkan_jrn.copy(), \"#5fa2d1\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2eb0457a",
   "metadata": {},
   "source": [
    "### GKAN SW"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30fbb9ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "scaler = get_scaler(data = real_epid_data, tr_perc=0.8)\n",
    "x_i, x_j = sp.symbols('x_i x_j')\n",
    "inf_coeff_covid = pd.read_csv(\"./inferred_coeffs/gkan_sw/inferred_coeffs_covid_sw.csv\")\n",
    "\n",
    "def build_model_sw(country, inf_coeff):\n",
    "    \"\"\"Build the GKAN-ODE+SW symbolic model for one country.\n",
    "\n",
    "    The country's column of inf_coeff must hold exactly 18 coefficients, unpacked\n",
    "    positionally into the names used by the two tanh-based expressions below.\n",
    "\n",
    "    Args:\n",
    "        country: column label of inf_coeff.\n",
    "        inf_coeff: table supporting inf_coeff[country].iloc[...] lookups.\n",
    "\n",
    "    Returns:\n",
    "        The model returned by get_model, configured with the 'dopri5' integration method.\n",
    "    \"\"\"\n",
    "    # Create the symbols locally (as the TSS / GKAN builders do) so the function does not\n",
    "    # depend on notebook-global x_i / x_j being defined by another statement or cell.\n",
    "    x_i, x_j = sp.symbols('x_i x_j')\n",
    "\n",
    "    coeffs = inf_coeff[country]\n",
    "    # The unpacking order mirrors the row order of the coefficient file.\n",
    "    r, h, l, i, k, j, m, q, n, o, p, g, a, c, b, d, f, e = coeffs.iloc[0:]\n",
    "\n",
    "    # Pairwise term (passed to get_model as g).\n",
    "    expr1 = a*sp.tanh(b*x_i + c) + d*sp.tanh(e*x_j + f) + g\n",
    "    # Self term (passed to get_model as h).\n",
    "    expr2 = h*sp.tanh(i*sp.tanh(j*x_i + k) + l) + m*sp.tanh(n*x_i**3 + o*x_i**2 + p*x_i + q) + r\n",
    "\n",
    "    g_symb = make_callable(expr1)\n",
    "    h_symb = make_callable(expr2)\n",
    "\n",
    "    symb_model = get_model(\n",
    "        g = g_symb,\n",
    "        h = h_symb,\n",
    "        message_passing=False,\n",
    "        include_time=False,\n",
    "        integration_method='dopri5'\n",
    "    )\n",
    "\n",
    "    return symb_model\n",
    "\n",
    "\n",
    "y_true_gkan, y_pred_gkan, y_true_val_gkan, y_pred_val_gkan = eval_real_epid_int(\n",
    "    data = real_epid_data,\n",
    "    countries_dict=countries_dict,\n",
    "    build_symb_model=build_model_sw,\n",
    "    scaler=scaler,\n",
    "    use_euler=False,\n",
    "    inferred_coeffs=inf_coeff_covid,\n",
    "    tr_perc=0.9,\n",
    "    device='cpu',\n",
    "    model_name=model_name_gkan_sw,\n",
    "    results_dict=res_dict_gkan_sw\n",
    ")\n",
    "\n",
    "t = real_epid_data.t_sampled\n",
    "epsilon = t[0][1] - t[0][0]\n",
    "\n",
    "print(\"Mae Eul\\n\")\n",
    "\n",
    "y_true_gkan_jrn, y_pred_gkan_jrn, y_true_val_gkan_jrn, y_pred_val_gkan_jrn = eval_real_epid_journal(\n",
    "    data = data_real_epid_orig,\n",
    "    countries_dict=countries_dict,\n",
    "    build_symb_model=build_model_sw,\n",
    "    tr_perc=0.9,\n",
    "    step_size=epsilon,\n",
    "    inferred_coeffs=inf_coeff_covid,\n",
    "    scaler=scaler,\n",
    "    device='cpu',\n",
    "    model_name=model_name_gkan_sw,\n",
    "    results_dict=res_dict_gkan_sw\n",
    ")\n",
    "\n",
    "all_res_covid_traj[\"GKAN-ODE+SW\"] = (y_pred_gkan.copy(), \"#a2c8e3\")\n",
    "all_res_covid_eul[\"GKAN-ODE+SW\"] = (y_pred_gkan_jrn.copy(), \"#a2c8e3\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59f345fb",
   "metadata": {},
   "source": [
    "### MPNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e18d7c4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_symb_model_mpnn(country, inf_coeff):\n",
    "    \"\"\"Build the GMLP-ODE symbolic model for one country.\n",
    "\n",
    "    Entries 0, 2 and 4 of the country's column in inf_coeff are read as (a, b, c);\n",
    "    entries 1 and 3 are not used. Both logarithms are clamped from below at 1e-6.\n",
    "\n",
    "    Args:\n",
    "        country: column label of inf_coeff.\n",
    "        inf_coeff: table supporting inf_coeff[country].iloc[k] lookups.\n",
    "\n",
    "    Returns:\n",
    "        The model returned by get_model, configured with the 'dopri5' integration method.\n",
    "    \"\"\"\n",
    "    # Create the symbol locally so the function does not depend on a notebook-global x_i\n",
    "    # that happens to be defined in another cell.\n",
    "    x_i = sp.Symbol('x_i')\n",
    "\n",
    "    coeffs = inf_coeff[country]\n",
    "    a, b, c = coeffs.iloc[0], coeffs.iloc[2], coeffs.iloc[4]\n",
    "\n",
    "    # NOTE(review): the term passed as g depends only on x_i, not on x_j - confirm intended.\n",
    "    expr1 = sp.ln(sp.Max(sp.tan(x_i + c)**2 + 1, 1e-6))\n",
    "    expr2 = a * sp.ln(sp.Max(x_i + b, 1e-6))\n",
    "\n",
    "    g_symb = make_callable(expr1)\n",
    "    h_symb = make_callable(expr2)\n",
    "\n",
    "    symb_model = get_model(\n",
    "        g = g_symb,\n",
    "        h = h_symb,\n",
    "        message_passing=False,\n",
    "        include_time=False,\n",
    "        integration_method='dopri5'\n",
    "    )\n",
    "\n",
    "    return symb_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fdb7422",
   "metadata": {},
   "outputs": [],
   "source": [
    "scaler = get_scaler(data = real_epid_data, tr_perc=0.8)\n",
    "x_i, x_j = sp.symbols('x_i x_j')\n",
    "inf_coeff_covid = pd.read_csv(\"./inferred_coeffs/mpnn/inferred_coeffs_covid_ts.csv\")\n",
    "\n",
    "y_true_mpnn, y_pred_mpnn, y_true_val_mpnn, y_pred_val_mpnn = eval_real_epid_int(\n",
    "    data = real_epid_data,\n",
    "    countries_dict=countries_dict,\n",
    "    build_symb_model=build_symb_model_mpnn,\n",
    "    scaler=scaler,\n",
    "    use_euler=False,\n",
    "    inferred_coeffs=inf_coeff_covid,\n",
    "    tr_perc=0.9,\n",
    "    device='cpu',\n",
    "    model_name=model_name_gmlp,\n",
    "    results_dict=res_dict_gmlp\n",
    ")\n",
    "\n",
    "all_res_covid_traj[\"GMLP-ODE+GP\"] = (y_pred_mpnn.copy(), \"#fcb97d\")\n",
    "\n",
    "t = real_epid_data.t_sampled\n",
    "epsilon = t[0][1] - t[0][0]\n",
    "scaler = get_scaler(data = real_epid_data, tr_perc=0.8)\n",
    "\n",
    "print(\"Mae Eul\\n\")\n",
    "\n",
    "y_true_mpnn_jrn, y_pred_mpnn_jrn, y_true_val_mpnn_jrn, y_pred_val_mpnn_jrn = eval_real_epid_journal(\n",
    "    data = data_real_epid_orig,\n",
    "    countries_dict=countries_dict,\n",
    "    build_symb_model=build_symb_model_mpnn,\n",
    "    tr_perc=0.9,\n",
    "    step_size=epsilon,\n",
    "    inferred_coeffs=inf_coeff_covid,\n",
    "    scaler=scaler,\n",
    "    device='cpu',\n",
    "    model_name=model_name_gmlp,\n",
    "    results_dict=res_dict_gmlp\n",
    ")\n",
    "\n",
    "all_res_covid_eul[\"GMLP-ODE+GP\"] = (y_pred_mpnn_jrn.copy(), \"#fcb97d\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4471629",
   "metadata": {},
   "source": [
    "### LLC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91b2a83f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_symb_model_llc(country, inf_coeff):\n",
    "    \"\"\"Build the LLC symbolic model for one country.\n",
    "\n",
    "    The first three entries of the country's column in inf_coeff are read as (a, b, c):\n",
    "    the pairwise term is c * (x_i - x_j) * exp(-x_j) and the self term is a * tanh(x_i + b).\n",
    "\n",
    "    Args:\n",
    "        country: column label of inf_coeff.\n",
    "        inf_coeff: table supporting inf_coeff[country].iloc[k] lookups.\n",
    "\n",
    "    Returns:\n",
    "        The model returned by get_model, configured with the 'dopri5' integration method.\n",
    "    \"\"\"\n",
    "    # Create the symbols locally so the function does not depend on notebook-global\n",
    "    # x_i / x_j that happen to be defined in another cell.\n",
    "    x_i, x_j = sp.symbols('x_i x_j')\n",
    "\n",
    "    coeffs = inf_coeff[country]\n",
    "    a, b, c = coeffs.iloc[0], coeffs.iloc[1], coeffs.iloc[2]\n",
    "\n",
    "    expr1 = c*((x_i - x_j) * sp.exp(- x_j))\n",
    "    expr2 = a * sp.tanh(x_i + b)\n",
    "\n",
    "    g_symb = make_callable(expr1)\n",
    "    h_symb = make_callable(expr2)\n",
    "\n",
    "    symb_model = get_model(\n",
    "        g = g_symb,\n",
    "        h = h_symb,\n",
    "        message_passing=False,\n",
    "        include_time=False,\n",
    "        integration_method='dopri5'\n",
    "    )\n",
    "\n",
    "    return symb_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce3d5bc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "scaler_covid = get_scaler(data = real_epid_data, tr_perc=0.8)\n",
    "x_i, x_j = sp.symbols('x_i x_j')\n",
    "inf_coeff_covid = pd.read_csv(\"./inferred_coeffs/llc/inferred_coeffs_covid_new.csv\")\n",
    "\n",
    "y_true_llc, y_pred_llc, y_true_val_llc, y_pred_val_llc = eval_real_epid_int(\n",
    "    data = real_epid_data,\n",
    "    countries_dict=countries_dict,\n",
    "    inferred_coeffs=inf_coeff_covid,\n",
    "    build_symb_model=build_symb_model_llc,\n",
    "    scaler=scaler_covid,\n",
    "    tr_perc=0.9,\n",
    "    device='cpu',\n",
    "    model_name=model_name_llc,\n",
    "    results_dict=res_dict_llc\n",
    ")\n",
    "\n",
    "print(\"Mae Eul\\n\")\n",
    "\n",
    "t = real_epid_data.t_sampled\n",
    "epsilon = t[0][1] - t[0][0]\n",
    "    \n",
    "y_true_llc_jrn, y_pred_llc_jrn, y_true_val_llc_jrn, y_pred_val_llc_jrn = eval_real_epid_journal(\n",
    "    data = data_real_epid_orig,\n",
    "    countries_dict=countries_dict,\n",
    "    build_symb_model=build_symb_model_llc,\n",
    "    inferred_coeffs=inf_coeff_covid,\n",
    "    scaler=scaler_covid,\n",
    "    step_size=epsilon,\n",
    "    tr_perc=0.9,\n",
    "    device='cpu',\n",
    "    model_name=model_name_llc,\n",
    "    results_dict=res_dict_llc\n",
    ")\n",
    "\n",
    "all_res_covid_traj[\"LLC+GP\"] = (y_pred_llc.copy(), \"#34eb6e\")\n",
    "all_res_covid_eul[\"LLC+GP\"] = (y_pred_llc_jrn.copy(), \"#34eb6e\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b70913db",
   "metadata": {},
   "source": [
    "### Plot Comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1476f99b",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_country_predictions(\n",
    "    y_true_llc,\n",
    "    preds_dict=all_res_covid_traj,\n",
    "    countries_dict=countries_dict,\n",
    "    save_dir = \"./outputs/covid_traj\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "595066a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_country_predictions(\n",
    "    y_true_llc_jrn,\n",
    "    preds_dict=all_res_covid_eul,\n",
    "    countries_dict=countries_dict,\n",
    "    save_dir=\"./outputs/covid_eul\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "638fa278",
   "metadata": {},
   "source": [
    "## Generalization on H1N1 data "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db7136b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "real_epid_h1n1 = RealEpidemics(\n",
    "    root = './data_real_epid_h1n1_int',\n",
    "    name = 'RealEpid',\n",
    "    predict_deriv=False,\n",
    "    history=1,\n",
    "    horizon=44,\n",
    "    scale=False,\n",
    "    infection_data=\"./data/RealEpidemics/infected_numbers_H1N1.csv\",\n",
    "    inf_threshold=100\n",
    ")\n",
    "\n",
    "data_real_epid_orig_h1n1 = RealEpidemics(\n",
    "    root = './data_real_epid_h1n1_orig',\n",
    "    name = 'RealEpid',\n",
    "    predict_deriv=True,\n",
    "    scale=False,\n",
    "    infection_data=\"./data/RealEpidemics/infected_numbers_H1N1.csv\",\n",
    "    inf_threshold=100\n",
    ")\n",
    "\n",
    "with open('./data_real_epid_h1n1_int/RealEpid/countries_dict.json', 'r') as f:\n",
    "    countries_dict = json.load(f)\n",
    "    \n",
    "all_res_h1n1_traj = {}\n",
    "all_res_h1n1_eul = {}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38287347",
   "metadata": {},
   "source": [
    "### TSS 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdc67ddc",
   "metadata": {},
   "outputs": [],
   "source": [
    "inf_coeff_h1n1 = pd.read_csv(\"./inferred_coeffs/tpsindy/inf_coeffs_all_h1n1.csv\").values\n",
    "\n",
    "y_true_tss, y_pred_tss, y_true_val_tss, y_pred_val_tss = eval_real_epid_int(\n",
    "    data = real_epid_h1n1,\n",
    "    countries_dict=countries_dict,\n",
    "    inferred_coeffs=inf_coeff_h1n1,\n",
    "    build_symb_model=build_symb_model_tss,\n",
    "    use_euler=True,\n",
    "    tr_perc=0.9,\n",
    "    device='cpu',\n",
    "    model_name=model_name_tss,\n",
    "    results_dict=res_dict_tss\n",
    ")\n",
    "\n",
    "print(\"Mae Eul\\n\")\n",
    "inf_coeff_h1n1 = pd.read_csv(\"./inferred_coeffs/tpsindy/inf_coeffs_test_h1n1.csv\").values\n",
    "\n",
    "y_true_tss_jrn, y_pred_tss_jrn, y_true_val_tss_jrn, y_pred_val_tss_jrn = eval_real_epid_journal(\n",
    "    data = data_real_epid_orig_h1n1,\n",
    "    countries_dict=countries_dict,\n",
    "    build_symb_model=build_symb_model_tss,\n",
    "    inferred_coeffs=inf_coeff_h1n1,\n",
    "    tr_perc=0.9,\n",
    "    step_size=1.0,\n",
    "    device='cpu',\n",
    "    model_name=model_name_tss,\n",
    "    results_dict=res_dict_tss\n",
    ")\n",
    "\n",
    "\n",
    "all_res_h1n1_traj[\"TPSINDy\"] = (y_pred_tss.copy(), \"red\")\n",
    "all_res_h1n1_eul[\"TPSINDy\"] = (y_pred_tss_jrn.copy(), \"red\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4732d559",
   "metadata": {},
   "source": [
    "### GKAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7b6e9db",
   "metadata": {},
   "outputs": [],
   "source": [
    "scaler_h1n1 = get_scaler(data = real_epid_h1n1, tr_perc=0.8)\n",
    "x_i, x_j = sp.symbols('x_i x_j')\n",
    "inf_coeff_h1n1 = pd.read_csv(\"./inferred_coeffs/gkan/inferred_coeffs_h1n1_ts.csv\")\n",
    "\n",
    "y_true_gkan, y_pred_gkan, y_true_val_gkan, y_pred_val_gkan = eval_real_epid_int(\n",
    "    data = real_epid_h1n1,\n",
    "    countries_dict=countries_dict,\n",
    "    inferred_coeffs=inf_coeff_h1n1,\n",
    "    build_symb_model=build_symb_model_gkan,\n",
    "    scaler=scaler_h1n1,\n",
    "    tr_perc=0.9,\n",
    "    device='cpu',\n",
    "    model_name=model_name_gkan_bb,\n",
    "    results_dict=res_dict_gkan_bb\n",
    ")\n",
    "\n",
    "all_res_h1n1_traj[\"GKAN-ODE+GP\"] = (y_pred_gkan.copy(), \"#5fa2d1\")\n",
    "\n",
    "print(\"\\nMae Eul\\n\")\n",
    "\n",
    "t = real_epid_h1n1.t_sampled\n",
    "epsilon = t[0][1] - t[0][0]\n",
    "scaler_h1n1 = get_scaler(data = real_epid_h1n1, tr_perc=0.8)\n",
    "print(epsilon)\n",
    "\n",
    "y_true_gkan_jrn, y_pred_gkan_jrn, y_true_val_gkan_jrn, y_pred_val_gkan_jrn = eval_real_epid_journal(\n",
    "    data = data_real_epid_orig_h1n1,\n",
    "    countries_dict=countries_dict,\n",
    "    build_symb_model=build_symb_model_gkan,\n",
    "    inferred_coeffs=inf_coeff_h1n1,\n",
    "    scaler=scaler_h1n1,\n",
    "    step_size=epsilon,\n",
    "    tr_perc=0.9,\n",
    "    device='cpu',\n",
    "    model_name=model_name_gkan_bb,\n",
    "    results_dict=res_dict_gkan_bb\n",
    ")\n",
    "\n",
    "all_res_h1n1_eul[\"GKAN-ODE+GP\"] = (y_pred_gkan_jrn.copy(), \"#5fa2d1\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35dd356e",
   "metadata": {},
   "source": [
    "### GKAN SW"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "484e60a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "scaler_h1n1 = get_scaler(data = real_epid_h1n1, tr_perc=0.8)\n",
    "x_i, x_j = sp.symbols('x_i x_j')\n",
    "inf_coeff_h1n1 = pd.read_csv(\"./inferred_coeffs/gkan_sw/inferred_coeffs_h1n1_sw.csv\")\n",
    "\n",
    "y_true_gkan, y_pred_gkan, y_true_val_gkan, y_pred_val_gkan = eval_real_epid_int(\n",
    "    data = real_epid_h1n1,\n",
    "    countries_dict=countries_dict,\n",
    "    inferred_coeffs=inf_coeff_h1n1,\n",
    "    build_symb_model=build_model_sw,\n",
    "    scaler=scaler_h1n1,\n",
    "    tr_perc=0.9,\n",
    "    device='cpu',\n",
    "    model_name=model_name_gkan_sw,\n",
    "    results_dict=res_dict_gkan_sw\n",
    ")\n",
    "\n",
    "print(\"\\nMae Eul\\n\")\n",
    "t = real_epid_h1n1.t_sampled\n",
    "epsilon = t[0][1] - t[0][0]\n",
    "\n",
    "y_true_gkan_jrn, y_pred_gkan_jrn, y_true_val_gkan_jrn, y_pred_val_gkan_jrn = eval_real_epid_journal(\n",
    "    data = data_real_epid_orig_h1n1,\n",
    "    countries_dict=countries_dict,\n",
    "    build_symb_model=build_model_sw,\n",
    "    inferred_coeffs=inf_coeff_h1n1,\n",
    "    scaler=scaler_h1n1,\n",
    "    step_size=epsilon,\n",
    "    tr_perc=0.9,\n",
    "    device='cpu',\n",
    "    model_name=model_name_gkan_sw,\n",
    "    results_dict=res_dict_gkan_sw\n",
    ")\n",
    "\n",
    "all_res_h1n1_traj[\"GKAN-ODE+SW\"] = (y_pred_gkan.copy(), \"#a2c8e3\")\n",
    "all_res_h1n1_eul[\"GKAN-ODE+SW\"] = (y_pred_gkan_jrn.copy(), \"#a2c8e3\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7a61df5",
   "metadata": {},
   "source": [
    "### MPNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b1a0cb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MPNN symbolic models evaluated on the H1N1 data.\n",
    "# NOTE(review): the scaler uses tr_perc=0.8 while both evaluations use tr_perc=0.9 - confirm this is intended.\n",
    "scaler_h1n1 = get_scaler(data=real_epid_h1n1, tr_perc=0.8)\n",
    "x_i, x_j = sp.symbols('x_i x_j')\n",
    "inf_coeff_h1n1 = pd.read_csv(\"./inferred_coeffs/mpnn/inferred_coeffs_h1n1_ts.csv\")\n",
    "\n",
    "# Keyword arguments common to both evaluation calls in this cell.\n",
    "shared_eval_kwargs = dict(\n",
    "    countries_dict=countries_dict,\n",
    "    build_symb_model=build_symb_model_mpnn,\n",
    "    inferred_coeffs=inf_coeff_h1n1,\n",
    "    scaler=scaler_h1n1,\n",
    "    tr_perc=0.9,\n",
    "    device='cpu',\n",
    "    model_name=model_name_gmlp,\n",
    "    results_dict=res_dict_gmlp,\n",
    ")\n",
    "\n",
    "y_true_mpnn, y_pred_mpnn, y_true_val_mpnn, y_pred_val_mpnn = eval_real_epid_int(\n",
    "    data=real_epid_h1n1, **shared_eval_kwargs\n",
    ")\n",
    "\n",
    "print(\"Mae Eul\\n\")\n",
    "\n",
    "# step_size is the gap between the first two entries of the first row of t_sampled.\n",
    "t = real_epid_h1n1.t_sampled\n",
    "epsilon = t[0][1] - t[0][0]\n",
    "\n",
    "y_true_mpnn_jrn, y_pred_mpnn_jrn, y_true_val_mpnn_jrn, y_pred_val_mpnn_jrn = eval_real_epid_journal(\n",
    "    data=data_real_epid_orig_h1n1, step_size=epsilon, **shared_eval_kwargs\n",
    ")\n",
    "\n",
    "# Register a copy of each prediction array, with its colour string, under the method label.\n",
    "all_res_h1n1_traj[\"GMLP-ODE+GP\"] = (y_pred_mpnn.copy(), \"#fcb97d\")\n",
    "all_res_h1n1_eul[\"GMLP-ODE+GP\"] = (y_pred_mpnn_jrn.copy(), \"#fcb97d\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10976e38",
   "metadata": {},
   "source": [
    "### LLC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d72934d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# LLC symbolic models evaluated on the H1N1 data.\n",
    "# NOTE(review): the scaler uses tr_perc=0.8 while both evaluations use tr_perc=0.9 - confirm this is intended.\n",
    "scaler_h1n1 = get_scaler(data=real_epid_h1n1, tr_perc=0.8)\n",
    "x_i, x_j = sp.symbols('x_i x_j')\n",
    "inf_coeff_h1n1 = pd.read_csv(\"./inferred_coeffs/llc/inferred_coeffs_h1n1_new.csv\")\n",
    "\n",
    "# Keyword arguments common to both evaluation calls in this cell.\n",
    "shared_eval_kwargs = dict(\n",
    "    countries_dict=countries_dict,\n",
    "    build_symb_model=build_symb_model_llc,\n",
    "    inferred_coeffs=inf_coeff_h1n1,\n",
    "    scaler=scaler_h1n1,\n",
    "    tr_perc=0.9,\n",
    "    device='cpu',\n",
    "    model_name=model_name_llc,\n",
    "    results_dict=res_dict_llc,\n",
    ")\n",
    "\n",
    "y_true_llc, y_pred_llc, y_true_val_llc, y_pred_val_llc = eval_real_epid_int(\n",
    "    data=real_epid_h1n1, **shared_eval_kwargs\n",
    ")\n",
    "\n",
    "print(\"Mae Eul\\n\")\n",
    "\n",
    "# step_size is the gap between the first two entries of the first row of t_sampled.\n",
    "t = real_epid_h1n1.t_sampled\n",
    "epsilon = t[0][1] - t[0][0]\n",
    "\n",
    "y_true_llc_jrn, y_pred_llc_jrn, y_true_val_llc_jrn, y_pred_val_llc_jrn = eval_real_epid_journal(\n",
    "    data=data_real_epid_orig_h1n1, step_size=epsilon, **shared_eval_kwargs\n",
    ")\n",
    "\n",
    "# Register a copy of each prediction array, with its colour string, under the method label.\n",
    "all_res_h1n1_traj[\"LLC+GP\"] = (y_pred_llc.copy(), \"#34eb6e\")\n",
    "all_res_h1n1_eul[\"LLC+GP\"] = (y_pred_llc_jrn.copy(), \"#34eb6e\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a96c5617",
   "metadata": {},
   "source": [
    "### Plot Comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b698445c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both calls use the ground truth returned by the LLC evaluation, so this cell\n",
    "# assumes the H1N1 LLC cell was the last one to assign y_true_llc / y_true_llc_jrn.\n",
    "\n",
    "# Predictions collected from the eval_real_epid_int calls.\n",
    "save_country_predictions(\n",
    "    y_true_llc, preds_dict=all_res_h1n1_traj, countries_dict=countries_dict, save_dir=\"./outputs/h1n1_traj\"\n",
    ")\n",
    "\n",
    "# Predictions collected from the eval_real_epid_journal calls.\n",
    "save_country_predictions(\n",
    "    y_true_llc_jrn, preds_dict=all_res_h1n1_eul, countries_dict=countries_dict, save_dir=\"./outputs/h1n1_eul\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b28afe8f",
   "metadata": {},
   "source": [
    "## Generalization SARS Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61dae203",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets.RealEpidemics import RealEpidemics\n",
    "\n",
    "# Settings shared by both SARS dataset objects.\n",
    "sars_dataset_kwargs = dict(\n",
    "    name='RealEpid',\n",
    "    scale=False,\n",
    "    infection_data=\"./data/RealEpidemics/infected_numbers_sars.csv\",\n",
    "    inf_threshold=100,\n",
    ")\n",
    "\n",
    "# Dataset handed to eval_real_epid_int in the cells below.\n",
    "real_epid_sars = RealEpidemics(\n",
    "    root='./data_real_epid_sars_int',\n",
    "    predict_deriv=False,\n",
    "    history=1,\n",
    "    horizon=44,\n",
    "    **sars_dataset_kwargs\n",
    ")\n",
    "\n",
    "# Dataset handed to eval_real_epid_journal in the cells below.\n",
    "data_real_epid_orig_sars = RealEpidemics(\n",
    "    root='./data_real_epid_sars_orig',\n",
    "    predict_deriv=True,\n",
    "    **sars_dataset_kwargs\n",
    ")\n",
    "\n",
    "# Rebind countries_dict to the mapping stored under the SARS dataset root.\n",
    "with open('./data_real_epid_sars_int/RealEpid/countries_dict.json', 'r') as dict_file:\n",
    "    countries_dict = json.load(dict_file)\n",
    "\n",
    "# Per-method containers filled by the SARS evaluation cells below.\n",
    "all_res_sars_traj, all_res_sars_eul = {}, {}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "afee6503",
   "metadata": {},
   "source": [
    "### TSS2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e06f95f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TPSINDy symbolic models evaluated on the SARS data (no scaler is passed).\n",
    "inf_coeff_sars = pd.read_csv(\"./inferred_coeffs/tpsindy/inf_coeffs_all_sars.csv\").values\n",
    "\n",
    "# Keyword arguments common to both evaluation calls in this cell.\n",
    "shared_eval_kwargs = dict(\n",
    "    countries_dict=countries_dict,\n",
    "    build_symb_model=build_symb_model_tss,\n",
    "    tr_perc=0.9,\n",
    "    device='cpu',\n",
    "    model_name=model_name_tss,\n",
    "    results_dict=res_dict_tss,\n",
    ")\n",
    "\n",
    "y_true_tss, y_pred_tss, y_true_val_tss, y_pred_val_tss = eval_real_epid_int(\n",
    "    data=real_epid_sars, inferred_coeffs=inf_coeff_sars, use_euler=True, **shared_eval_kwargs\n",
    ")\n",
    "\n",
    "print(\"Mae Eul\\n\")\n",
    "# The second evaluation reads its coefficients from a different file.\n",
    "inf_coeff_sars = pd.read_csv(\"./inferred_coeffs/tpsindy/inf_coeffs_test_sars.csv\").values\n",
    "\n",
    "y_true_tss_jrn, y_pred_tss_jrn, y_true_val_tss_jrn, y_pred_val_tss_jrn = eval_real_epid_journal(\n",
    "    data=data_real_epid_orig_sars, inferred_coeffs=inf_coeff_sars, step_size=1.0, **shared_eval_kwargs\n",
    ")\n",
    "\n",
    "# Register a copy of each prediction array, with its colour string, under the method label.\n",
    "all_res_sars_traj[\"TPSINDy\"] = (y_pred_tss.copy(), \"red\")\n",
    "all_res_sars_eul[\"TPSINDy\"] = (y_pred_tss_jrn.copy(), \"red\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3760587",
   "metadata": {},
   "source": [
    "### GKAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12311805",
   "metadata": {},
   "outputs": [],
   "source": [
    "# GKAN symbolic models evaluated on the SARS data.\n",
    "# NOTE(review): the scaler uses tr_perc=0.8 while both evaluations use tr_perc=0.9 - confirm this is intended.\n",
    "scaler_sars = get_scaler(data=real_epid_sars, tr_perc=0.8)\n",
    "x_i, x_j = sp.symbols('x_i x_j')\n",
    "inf_coeff_sars = pd.read_csv(\"./inferred_coeffs/gkan/inferred_coeffs_sars_ts.csv\")\n",
    "\n",
    "# Keyword arguments common to both evaluation calls in this cell.\n",
    "shared_eval_kwargs = dict(\n",
    "    countries_dict=countries_dict,\n",
    "    build_symb_model=build_symb_model_gkan,\n",
    "    inferred_coeffs=inf_coeff_sars,\n",
    "    scaler=scaler_sars,\n",
    "    tr_perc=0.9,\n",
    "    device='cpu',\n",
    "    model_name=model_name_gkan_bb,\n",
    "    results_dict=res_dict_gkan_bb,\n",
    ")\n",
    "\n",
    "y_true_gkan, y_pred_gkan, y_true_val_gkan, y_pred_val_gkan = eval_real_epid_int(\n",
    "    data=real_epid_sars, **shared_eval_kwargs\n",
    ")\n",
    "\n",
    "print(\"Mae Eul\\n\")\n",
    "\n",
    "# step_size is the gap between the first two entries of the first row of t_sampled.\n",
    "t = real_epid_sars.t_sampled\n",
    "epsilon = t[0][1] - t[0][0]\n",
    "\n",
    "y_true_gkan_jrn, y_pred_gkan_jrn, y_true_val_gkan_jrn, y_pred_val_gkan_jrn = eval_real_epid_journal(\n",
    "    data=data_real_epid_orig_sars, step_size=epsilon, **shared_eval_kwargs\n",
    ")\n",
    "\n",
    "# Register a copy of each prediction array, with its colour string, under the method label.\n",
    "all_res_sars_traj[\"GKAN-ODE+GP\"] = (y_pred_gkan.copy(), \"#5fa2d1\")\n",
    "all_res_sars_eul[\"GKAN-ODE+GP\"] = (y_pred_gkan_jrn.copy(), \"#5fa2d1\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "027638e5",
   "metadata": {},
   "source": [
    "### GKAN SW"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8906400c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# GKAN SW symbolic models evaluated on the SARS data.\n",
    "# NOTE(review): the scaler uses tr_perc=0.8 while both evaluations use tr_perc=0.9 - confirm this is intended.\n",
    "scaler_sars = get_scaler(data=real_epid_sars, tr_perc=0.8)\n",
    "x_i, x_j = sp.symbols('x_i x_j')\n",
    "inf_coeff_sars = pd.read_csv(\"./inferred_coeffs/gkan_sw/inferred_coeffs_sars_sw.csv\")\n",
    "\n",
    "# Keyword arguments common to both evaluation calls in this cell.\n",
    "shared_eval_kwargs = dict(\n",
    "    countries_dict=countries_dict,\n",
    "    build_symb_model=build_model_sw,\n",
    "    inferred_coeffs=inf_coeff_sars,\n",
    "    scaler=scaler_sars,\n",
    "    tr_perc=0.9,\n",
    "    device='cpu',\n",
    "    model_name=model_name_gkan_sw,\n",
    "    results_dict=res_dict_gkan_sw,\n",
    ")\n",
    "\n",
    "y_true_gkan, y_pred_gkan, y_true_val_gkan, y_pred_val_gkan = eval_real_epid_int(\n",
    "    data=real_epid_sars, **shared_eval_kwargs\n",
    ")\n",
    "\n",
    "print(\"Mae Eul\\n\")\n",
    "\n",
    "# step_size is the gap between the first two entries of the first row of t_sampled.\n",
    "t = real_epid_sars.t_sampled\n",
    "epsilon = t[0][1] - t[0][0]\n",
    "\n",
    "y_true_gkan_jrn, y_pred_gkan_jrn, y_true_val_gkan_jrn, y_pred_val_gkan_jrn = eval_real_epid_journal(\n",
    "    data=data_real_epid_orig_sars, step_size=epsilon, **shared_eval_kwargs\n",
    ")\n",
    "\n",
    "# Register a copy of each prediction array, with its colour string, under the method label.\n",
    "all_res_sars_traj[\"GKAN-ODE+SW\"] = (y_pred_gkan.copy(), \"#a2c8e3\")\n",
    "all_res_sars_eul[\"GKAN-ODE+SW\"] = (y_pred_gkan_jrn.copy(), \"#a2c8e3\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5926ba1",
   "metadata": {},
   "source": [
    "### MPNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae6e00ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MPNN symbolic models evaluated on the SARS data.\n",
    "# NOTE(review): the scaler uses tr_perc=0.8 while both evaluations use tr_perc=0.9 - confirm this is intended.\n",
    "scaler_sars = get_scaler(data=real_epid_sars, tr_perc=0.8)\n",
    "x_i, x_j = sp.symbols('x_i x_j')\n",
    "inf_coeff_sars = pd.read_csv(\"./inferred_coeffs/mpnn/inferred_coeffs_sars_ts.csv\")\n",
    "\n",
    "# Keyword arguments common to both evaluation calls in this cell.\n",
    "shared_eval_kwargs = dict(\n",
    "    countries_dict=countries_dict,\n",
    "    build_symb_model=build_symb_model_mpnn,\n",
    "    inferred_coeffs=inf_coeff_sars,\n",
    "    scaler=scaler_sars,\n",
    "    tr_perc=0.9,\n",
    "    device='cpu',\n",
    "    model_name=model_name_gmlp,\n",
    "    results_dict=res_dict_gmlp,\n",
    ")\n",
    "\n",
    "y_true_mpnn, y_pred_mpnn, y_true_val_mpnn, y_pred_val_mpnn = eval_real_epid_int(\n",
    "    data=real_epid_sars, **shared_eval_kwargs\n",
    ")\n",
    "\n",
    "print(\"Mae Eul\\n\")\n",
    "\n",
    "# step_size is the gap between the first two entries of the first row of t_sampled.\n",
    "t = real_epid_sars.t_sampled\n",
    "epsilon = t[0][1] - t[0][0]\n",
    "\n",
    "y_true_mpnn_jrn, y_pred_mpnn_jrn, y_true_val_mpnn_jrn, y_pred_val_mpnn_jrn = eval_real_epid_journal(\n",
    "    data=data_real_epid_orig_sars, step_size=epsilon, **shared_eval_kwargs\n",
    ")\n",
    "\n",
    "# Register a copy of each prediction array, with its colour string, under the method label.\n",
    "all_res_sars_traj[\"GMLP-ODE+GP\"] = (y_pred_mpnn.copy(), \"#fcb97d\")\n",
    "all_res_sars_eul[\"GMLP-ODE+GP\"] = (y_pred_mpnn_jrn.copy(), \"#fcb97d\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81dad070",
   "metadata": {},
   "source": [
    "### LLC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbbf81fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# LLC symbolic models evaluated on the SARS data.\n",
    "# NOTE(review): the scaler uses tr_perc=0.8 while both evaluations use tr_perc=0.9 - confirm this is intended.\n",
    "scaler_sars = get_scaler(data=real_epid_sars, tr_perc=0.8)\n",
    "x_i, x_j = sp.symbols('x_i x_j')\n",
    "inf_coeff_sars = pd.read_csv(\"./inferred_coeffs/llc/inferred_coeffs_sars_new.csv\")\n",
    "\n",
    "# Keyword arguments common to both evaluation calls in this cell.\n",
    "shared_eval_kwargs = dict(\n",
    "    countries_dict=countries_dict,\n",
    "    build_symb_model=build_symb_model_llc,\n",
    "    inferred_coeffs=inf_coeff_sars,\n",
    "    scaler=scaler_sars,\n",
    "    tr_perc=0.9,\n",
    "    device='cpu',\n",
    "    model_name=model_name_llc,\n",
    "    results_dict=res_dict_llc,\n",
    ")\n",
    "\n",
    "y_true_llc, y_pred_llc, y_true_val_llc, y_pred_val_llc = eval_real_epid_int(\n",
    "    data=real_epid_sars, **shared_eval_kwargs\n",
    ")\n",
    "\n",
    "print(\"\\nMae Eul\\n\")\n",
    "\n",
    "# step_size is the gap between the first two entries of the first row of t_sampled.\n",
    "t = real_epid_sars.t_sampled\n",
    "epsilon = t[0][1] - t[0][0]\n",
    "\n",
    "y_true_llc_jrn, y_pred_llc_jrn, y_true_val_llc_jrn, y_pred_val_llc_jrn = eval_real_epid_journal(\n",
    "    data=data_real_epid_orig_sars, step_size=epsilon, **shared_eval_kwargs\n",
    ")\n",
    "\n",
    "# Register a copy of each prediction array, with its colour string, under the method label.\n",
    "all_res_sars_traj[\"LLC+GP\"] = (y_pred_llc.copy(), \"#34eb6e\")\n",
    "all_res_sars_eul[\"LLC+GP\"] = (y_pred_llc_jrn.copy(), \"#34eb6e\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d016d76",
   "metadata": {},
   "source": [
    "### Plot Comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0e2b7e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both calls use the ground truth returned by the LLC evaluation, so this cell\n",
    "# assumes the SARS LLC cell was the last one to assign y_true_llc / y_true_llc_jrn.\n",
    "\n",
    "# Predictions collected from the eval_real_epid_int calls.\n",
    "save_country_predictions(\n",
    "    y_true_llc, preds_dict=all_res_sars_traj, countries_dict=countries_dict, save_dir=\"./outputs/sars_traj\"\n",
    ")\n",
    "\n",
    "# Predictions collected from the eval_real_epid_journal calls.\n",
    "save_country_predictions(\n",
    "    y_true_llc_jrn, preds_dict=all_res_sars_eul, countries_dict=countries_dict, save_dir=\"./outputs/sars_eul\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "498af80e",
   "metadata": {},
   "source": [
    "## Save results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5994e1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def dd_to_dict(d):\n",
    "    \"\"\"Recursively convert nested mappings (including defaultdicts) into plain dicts.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    d : object\n",
    "        Value to convert. Any ``dict`` instance, including subclasses such as\n",
    "        ``collections.defaultdict``, is rebuilt as a plain ``dict`` whose values\n",
    "        are converted recursively. Any other value is returned unchanged.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    object\n",
    "        A plain ``dict`` when ``d`` is a mapping, otherwise ``d`` itself.\n",
    "    \"\"\"\n",
    "    # defaultdict is a dict subclass, so a single check covers both cases and the\n",
    "    # function no longer needs the name `defaultdict` to be imported.\n",
    "    if isinstance(d, dict):\n",
    "        return {k: dd_to_dict(v) for k, v in d.items()}\n",
    "    return d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78ff7ff9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dump the TPSINDy results dictionary (converted to plain dicts) as indented JSON.\n",
    "tss_results_path = \"./saved_models_optuna/tss/real_epid_covid/post_process_res.json\"\n",
    "with open(tss_results_path, \"w\") as out_file:\n",
    "    json.dump(dd_to_dict(res_dict_tss), out_file, indent=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b923e5c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dump the GKAN results dictionary (converted to plain dicts) as indented JSON.\n",
    "gkan_bb_results_path = \"./saved_models_optuna/model-real-epid-gkan/real_epid_gkan_7/0/post_process_res_bb.json\"\n",
    "with open(gkan_bb_results_path, \"w\") as out_file:\n",
    "    json.dump(dd_to_dict(res_dict_gkan_bb), out_file, indent=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "265d5852",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dump the GKAN SW results dictionary (converted to plain dicts) as indented JSON.\n",
    "gkan_sw_results_path = \"./saved_models_optuna/model-real-epid-gkan/real_epid_gkan_7/0/post_process_res_sw.json\"\n",
    "with open(gkan_sw_results_path, \"w\") as out_file:\n",
    "    json.dump(dd_to_dict(res_dict_gkan_sw), out_file, indent=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4198706c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dump the MPNN (GMLP) results dictionary (converted to plain dicts) as indented JSON.\n",
    "gmlp_results_path = \"./saved_models_optuna/model-real-epid-mpnn/real_epid_mpnn_7/0/post_process_res.json\"\n",
    "with open(gmlp_results_path, \"w\") as out_file:\n",
    "    json.dump(dd_to_dict(res_dict_gmlp), out_file, indent=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3912faf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dump the LLC results dictionary (converted to plain dicts) as indented JSON.\n",
    "llc_results_path = \"./saved_models_optuna/model-real-epid-llc/real_epid_llc_3/0/post_process_res.json\"\n",
    "with open(llc_results_path, \"w\") as out_file:\n",
    "    json.dump(dd_to_dict(res_dict_llc), out_file, indent=2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "myenv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
