{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d7c8155",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b00a2ebb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.utils import load_config\n",
    "from datasets.SyntheticData import SyntheticData\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from torch_geometric.utils import to_dense_adj\n",
    "import os\n",
    "from datasets.RealEpidemics import RealEpidemics\n",
    "from main import set_pytorch_seed\n",
    "\n",
    "os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n",
    "set_pytorch_seed(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d4398ae",
   "metadata": {},
   "source": [
    "## two-step-SINDy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7551adc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from models.baseline.TSS.NumericalDerivatives import NumericalDeriv\n",
    "from models.baseline.TSS.ElementaryFunctions_Matrix import ElementaryFunctions_Matrix\n",
    "from models.baseline.TSS.TwoPhaseInference import TwoPhaseInference"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58b79170",
   "metadata": {},
   "source": [
    "## Utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35d4be08",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_data_tss(config, snr_db = -1, real_epid=False, denoise=False):\n",
    "    \"\"\"Build the requested dataset and extract the arrays used by the TSS baseline.\n",
    "\n",
    "    Args:\n",
    "        config: experiment configuration mapping; only read when `real_epid` is False.\n",
    "        snr_db: forwarded unchanged to `SyntheticData`; unused for the real dataset.\n",
    "        real_epid: when True, load `RealEpidemics` instead of `SyntheticData`.\n",
    "        denoise: forwarded unchanged to `SyntheticData`; unused for the real dataset.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(raw_data, A, time)`: `dataset.raw_data_sampled` as a NumPy array,\n",
    "        the dense adjacency (NumPy) built from the first item's `edge_index`, and\n",
    "        `dataset.t_sampled` exactly as stored on the dataset.\n",
    "    \"\"\"\n",
    "    if real_epid:\n",
    "        dataset = RealEpidemics(\n",
    "            root='./data_real_epid_covid_orig',\n",
    "            name='RealEpid',\n",
    "            predict_deriv=True,\n",
    "            scale=False,\n",
    "        )\n",
    "    else:\n",
    "        dataset = SyntheticData(\n",
    "            root=config['data_folder'],\n",
    "            dynamics=config['name'],\n",
    "            t_span=config['t_span'],\n",
    "            t_max=config['t_eval_steps'],\n",
    "            num_samples=config['num_samples'],\n",
    "            seed=config['seed'],\n",
    "            n_ics=config['n_iter'],\n",
    "            input_range=config['input_range'],\n",
    "            # The device is pinned to CUDA here; config['device'] is not consulted.\n",
    "            device='cuda',\n",
    "            horizon=config['horizon'],\n",
    "            history=config['history'],\n",
    "            stride=config.get('stride', 5),\n",
    "            predict_deriv=config.get(\"predict_deriv\", False),\n",
    "            snr_db=snr_db,\n",
    "            denoise=denoise,\n",
    "            **config['integration_kwargs']\n",
    "        )\n",
    "\n",
    "    # Per the original note, the sampled data is laid out as (ics, time_steps, n_nodes, 1).\n",
    "    sampled = dataset.raw_data_sampled\n",
    "    raw_data = sampled.cpu().detach().numpy()\n",
    "    time = dataset.t_sampled\n",
    "\n",
    "    # The adjacency is taken from the first dataset item only.\n",
    "    edge_index = dataset[0].edge_index\n",
    "    dense = to_dense_adj(edge_index)\n",
    "    A = dense[0].cpu().detach().numpy()\n",
    "\n",
    "    return raw_data, A, time\n",
    "\n",
    "\n",
    "def get_matrix_tss(raw_data, time, A, Dim=1, selfPolyOrder = 3, act_index=False, method=\"five_point\"):\n",
    "    \"\"\"Assemble numerical derivatives and the candidate-function library for every trajectory.\n",
    "\n",
    "    Args:\n",
    "        raw_data: array indexed as `raw_data[ic]`; the trailing axis of each entry is squeezed away.\n",
    "        time: 2-D time grid; the step size is read from `time[0, 1] - time[0, 0]`.\n",
    "        A: adjacency matrix; `A.shape[0]` is used as the node count of the library.\n",
    "        Dim: state dimension forwarded to `ElementaryFunctions_Matrix`.\n",
    "        selfPolyOrder: forwarded to `ElementaryFunctions_Matrix`.\n",
    "        act_index: switches both `ActivationIndex` and `CoupledActivationIndex`.\n",
    "        method: differentiation scheme forwarded to `NumericalDeriv`.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(Matrix, num_deriv, data)`: the concatenated library with every column\n",
    "        containing inf/NaN dropped, the concatenated derivatives, and the concatenated states.\n",
    "    \"\"\"\n",
    "    step = (time[0, 1] - time[0, 0]).item()\n",
    "    Nnodes = A.shape[0]\n",
    "\n",
    "    # The library switches do not depend on the trajectory, so they are spelled out once.\n",
    "    library_options = dict(\n",
    "        coupledPolyOrder=1,\n",
    "        PolynomialIndex=True,\n",
    "        TrigonometricIndex=True,\n",
    "        ExponentialIndex=True,\n",
    "        FractionalIndex=True,\n",
    "        ActivationIndex=act_index,\n",
    "        RescalingIndex=False,\n",
    "        CoupledPolynomialIndex=True,\n",
    "        CoupledTrigonometricIndex=True,\n",
    "        CoupledExponentialIndex=True,\n",
    "        CoupledFractionalIndex=True,\n",
    "        CoupledActivationIndex=act_index,\n",
    "        CoupledRescalingIndex=False,\n",
    "    )\n",
    "\n",
    "    states, derivatives, libraries = [], [], []\n",
    "    for ic in range(raw_data.shape[0]):\n",
    "        series = raw_data[ic].squeeze(-1)  # shape: (time_steps, n_nodes)\n",
    "        deriv_ic = NumericalDeriv(\n",
    "            TimeSeries=series,\n",
    "            dim=1,\n",
    "            Nnodes=series.shape[1],\n",
    "            deltT=step,\n",
    "            method=method\n",
    "        )  # concatenated with pd.concat below, i.e. a pandas object\n",
    "\n",
    "        # Drop two samples at each end for the five-point scheme -- presumably because\n",
    "        # NumericalDeriv returns no derivative there (TODO confirm against NumericalDeriv).\n",
    "        if method == \"five_point\":\n",
    "            series = series[2:-2, :]\n",
    "\n",
    "        states.append(series)\n",
    "        derivatives.append(deriv_ic)\n",
    "        libraries.append(\n",
    "            ElementaryFunctions_Matrix(series, Dim, Nnodes, A, selfPolyOrder, **library_options)\n",
    "        )\n",
    "\n",
    "    data = np.concatenate(states, axis=0)\n",
    "    num_deriv = pd.concat(derivatives, ignore_index=True)\n",
    "    Matrix = pd.concat(libraries, ignore_index=True)\n",
    "    # Infinite entries are turned into NaN so that dropna removes the whole column.\n",
    "    Matrix = Matrix.replace([np.inf, -np.inf], np.nan).dropna(axis=1)\n",
    "\n",
    "    return Matrix, num_deriv, data\n",
    "\n",
    "\n",
    "def two_step_sindy(Matrix, num_deriv, Nnodes, out_path, Dim = 1, plotstart = 0.5, plotend = 0.9, Keep = 10, SampleTimes = 20, Batchsize = 1,\n",
    "                   snr_db = -1, denoise=False):\n",
    "    \"\"\"Run `TwoPhaseInference` once per state dimension and write each result table to CSV.\n",
    "\n",
    "    Args:\n",
    "        Matrix, num_deriv, Nnodes: forwarded positionally to `TwoPhaseInference`.\n",
    "        out_path: output directory; created if it does not exist.\n",
    "        Dim: number of state dimensions to loop over.\n",
    "        plotstart, plotend, Keep, SampleTimes, Batchsize: forwarded to `TwoPhaseInference`.\n",
    "        snr_db: a negative value selects the plain file name; otherwise the value is\n",
    "            embedded in the file name.\n",
    "        denoise: appends `_denoise` to the noisy-run file name.\n",
    "\n",
    "    Returns:\n",
    "        None. One CSV per dimension is written under `out_path`.\n",
    "    \"\"\"\n",
    "    Lambda = pd.DataFrame([[0.01, 0.5, 1]])\n",
    "    os.makedirs(out_path, exist_ok=True)\n",
    "\n",
    "    # Loop-invariant. Note the noisy file name ends in `_db_` + suffix, so a denoised\n",
    "    # run is saved as `..._db__denoise.csv` (double underscore); readers rely on that.\n",
    "    suffix = \"_denoise\" if denoise else \"\"\n",
    "\n",
    "    for dim in range(Dim):\n",
    "        inferred, _, _, _ = TwoPhaseInference(\n",
    "            Matrix, num_deriv, Nnodes, dim, Dim, Keep, SampleTimes, Batchsize, Lambda, plotstart, plotend\n",
    "        )\n",
    "\n",
    "        if snr_db < 0:\n",
    "            save_file = f\"{out_path}/results_dim={dim}.csv\"\n",
    "        else:\n",
    "            save_file = f\"{out_path}/results_dim={dim}_{snr_db}_db_{suffix}.csv\"\n",
    "        inferred.to_csv(save_file)\n",
    "    \n",
    "    \n",
    "        "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "939af7de",
   "metadata": {},
   "source": [
    "## Two Phase Inference"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56665128",
   "metadata": {},
   "source": [
    "### Clean data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c738a649",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clean-data run using the default derivative method of get_matrix_tss.\n",
    "configs = [\n",
    "    'configs/config_pred_deriv/config_ic1/config_kuramoto.yml',\n",
    "    'configs/config_pred_deriv/config_ic1/config_biochemical.yml',\n",
    "    'configs/config_pred_deriv/config_ic1/config_epidemics.yml',\n",
    "    'configs/config_pred_deriv/config_ic1/config_population.yml'\n",
    "]\n",
    "\n",
    "for conf_path in configs:\n",
    "    conf = load_config(config_path=conf_path)\n",
    "    raw_data, A, time = load_data_tss(conf)\n",
    "\n",
    "    Matrix, num_deriv, _ = get_matrix_tss(\n",
    "        raw_data=raw_data,\n",
    "        time=time,\n",
    "        A=A,\n",
    "        Dim=1,\n",
    "        selfPolyOrder=3\n",
    "    )\n",
    "    # Matrix.to_csv(f\"./saved_models_optuna/tss/{conf['name']}-{conf['n_iter']}/Matrix.csv\")\n",
    "    # print(f\"Matrix dim: {Matrix.values.shape[1]}\")\n",
    "\n",
    "    # The f-string is double-quoted on purpose: reusing the enclosing quote character\n",
    "    # inside the replacement fields only parses on Python >= 3.12 (PEP 701).\n",
    "    two_step_sindy(\n",
    "        Matrix=Matrix,\n",
    "        num_deriv=num_deriv,\n",
    "        Nnodes=A.shape[0],\n",
    "        out_path=f\"./saved_models_optuna/tss/{conf['name']}-{conf['n_iter']}\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5bb6789",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clean-data run with method=\"finite_diff\"; results go to the `-no_fp` folders.\n",
    "configs = [\n",
    "    'configs/config_pred_deriv/config_ic1/config_kuramoto.yml',\n",
    "    'configs/config_pred_deriv/config_ic1/config_biochemical.yml',\n",
    "    'configs/config_pred_deriv/config_ic1/config_epidemics.yml',\n",
    "    'configs/config_pred_deriv/config_ic1/config_population.yml'\n",
    "]\n",
    "\n",
    "for conf_path in configs:\n",
    "    conf = load_config(config_path=conf_path)\n",
    "    raw_data, A, time = load_data_tss(conf)\n",
    "\n",
    "    Matrix, num_deriv, _ = get_matrix_tss(\n",
    "        raw_data=raw_data,\n",
    "        time=time,\n",
    "        A=A,\n",
    "        Dim=1,\n",
    "        selfPolyOrder=3,\n",
    "        method=\"finite_diff\"\n",
    "    )\n",
    "\n",
    "    # The f-string is double-quoted on purpose: reusing the enclosing quote character\n",
    "    # inside the replacement fields only parses on Python >= 3.12 (PEP 701).\n",
    "    two_step_sindy(\n",
    "        Matrix=Matrix,\n",
    "        num_deriv=num_deriv,\n",
    "        Nnodes=A.shape[0],\n",
    "        out_path=f\"./saved_models_optuna/tss/{conf['name']}-{conf['n_iter']}-no_fp\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a8d13efe",
   "metadata": {},
   "source": [
    "### Noise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5aa69f81",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Noisy, denoised runs: every listed config is repeated for each SNR level.\n",
    "configs = [\n",
    "    'configs/config_pred_deriv/config_ic1/config_kuramoto.yml',\n",
    "    # 'configs/config_pred_deriv/config_ic1/config_biochemical.yml',\n",
    "    'configs/config_pred_deriv/config_ic1/config_epidemics.yml',\n",
    "    'configs/config_pred_deriv/config_ic1/config_population.yml'\n",
    "]\n",
    "\n",
    "snr_db_levels = [70, 50, 20]\n",
    "\n",
    "for conf_path in configs:\n",
    "    for snr_db in snr_db_levels:\n",
    "\n",
    "        conf = load_config(config_path=conf_path)\n",
    "        raw_data, A, time = load_data_tss(conf, snr_db=snr_db, denoise=True)\n",
    "\n",
    "        Matrix, num_deriv, _ = get_matrix_tss(\n",
    "            raw_data=raw_data,\n",
    "            time=time,\n",
    "            A=A,\n",
    "            Dim=1,\n",
    "            selfPolyOrder=3\n",
    "        )\n",
    "\n",
    "        # The f-string is double-quoted on purpose: reusing the enclosing quote character\n",
    "        # inside the replacement fields only parses on Python >= 3.12 (PEP 701).\n",
    "        two_step_sindy(\n",
    "            Matrix=Matrix,\n",
    "            num_deriv=num_deriv,\n",
    "            Nnodes=A.shape[0],\n",
    "            out_path=f\"./saved_models_optuna/tss/{conf['name']}-{conf['n_iter']}_denoise_2\",\n",
    "            snr_db=snr_db,\n",
    "            denoise=True\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f079631",
   "metadata": {},
   "source": [
    "## Post Processing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30c1644a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n",
    "\n",
    "from post_processing import set_pytorch_seed, get_test_set, get_symb_test_error\n",
    "from utils.utils import load_config\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def _build_test_set(config):\n",
    "    \"\"\"Create the test set described by a loaded experiment config (always on CUDA).\"\"\"\n",
    "    return get_test_set(\n",
    "        dynamics=config['name'],\n",
    "        device='cuda',\n",
    "        input_range=config['input_range'],\n",
    "        **config['integration_kwargs']\n",
    "    )\n",
    "\n",
    "\n",
    "def _symbolic_test_losses(g_fn, h_fn, test_set):\n",
    "    \"\"\"Evaluate the callables `g_fn` / `h_fn` on `test_set` with get_symb_test_error.\"\"\"\n",
    "    return get_symb_test_error(\n",
    "        g_symb=g_fn,\n",
    "        h_symb=h_fn,\n",
    "        test_set=test_set,\n",
    "        message_passing=True,\n",
    "        include_time=False,\n",
    "        is_symb=False\n",
    "    )\n",
    "\n",
    "\n",
    "def _report_losses(losses):\n",
    "    \"\"\"Print mean, variance and standard deviation of `losses`; return them as a tuple.\"\"\"\n",
    "    mean = np.mean(losses)\n",
    "    var = np.var(losses)\n",
    "    std = np.std(losses)\n",
    "    print(f\"Mean Test loss of symbolic formula: {mean}\")\n",
    "    print(f\"Var Test loss of symbolic formula: {var}\")\n",
    "    print(f\"Std Test loss of symbolic formula: {std}\")\n",
    "    return mean, var, std\n",
    "\n",
    "\n",
    "set_pytorch_seed(0)\n",
    "\n",
    "# ### Kuramoto\n",
    "kur_config = load_config(\"./configs/config_pred_deriv/config_ic1/config_kuramoto.yml\")\n",
    "KUR = _build_test_set(kur_config)\n",
    "\n",
    "g_symb = lambda x: torch.sin(x[:, 1] - x[:, 0]).unsqueeze(-1)\n",
    "h_symb = lambda x: 2.0 + 0.5 * x[:, 1].unsqueeze(-1)\n",
    "\n",
    "test_losses = _symbolic_test_losses(g_symb, h_symb, KUR)\n",
    "ts_mean, ts_var, ts_std = _report_losses(test_losses)\n",
    "\n",
    "# ### Epidemics\n",
    "epid_config = load_config(\"./configs/config_pred_deriv/config_ic1/config_epidemics.yml\")\n",
    "EPID = _build_test_set(epid_config)\n",
    "\n",
    "g_symb = lambda x: 0.5*x[:, 1].unsqueeze(-1) * (1 - x[:, 0].unsqueeze(-1))\n",
    "h_symb = lambda x: x[:, 1].unsqueeze(1) - 0.5 * x[:, 0].unsqueeze(-1)\n",
    "\n",
    "test_losses = _symbolic_test_losses(g_symb, h_symb, EPID)\n",
    "ts_mean, ts_var, ts_std = _report_losses(test_losses)\n",
    "\n",
    "# ### Population\n",
    "pop_config = load_config(\"./configs/config_pred_deriv/config_ic1/config_population.yml\")\n",
    "POP = _build_test_set(pop_config)\n",
    "\n",
    "g_symb = lambda x: 0.2*torch.pow(x[:, 1].unsqueeze(-1), 3)\n",
    "h_symb = lambda x: -0.5 * x[:, 0].unsqueeze(-1) + x[:, 1].unsqueeze(1)\n",
    "\n",
    "test_losses = _symbolic_test_losses(g_symb, h_symb, POP)\n",
    "ts_mean, ts_var, ts_std = _report_losses(test_losses)\n",
    "\n",
    "# ### Biochemical\n",
    "bio_config = load_config(\"./configs/config_pred_deriv/config_ic1/config_biochemical.yml\")\n",
    "BIO = _build_test_set(bio_config)\n",
    "\n",
    "g_symb = lambda x: (-0.5*x[:, 1] * x[:, 0]).unsqueeze(-1)\n",
    "h_symb = lambda x: (1.0 - 0.5 * x[:, 0]).unsqueeze(-1)  + x[:, 1].unsqueeze(-1)\n",
    "\n",
    "test_losses = _symbolic_test_losses(g_symb, h_symb, BIO)\n",
    "ts_mean, ts_var, ts_std = _report_losses(test_losses)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd3d9e36",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sympy as sp\n",
    "import json\n",
    "from sklearn.metrics import mean_absolute_error, mean_squared_error, root_mean_squared_error, r2_score, mean_absolute_percentage_error\n",
    "from post_processing import get_test_pred, get_list_test_errors, make_callable, get_model\n",
    "from torch_geometric.data import Data\n",
    "\n",
    "\n",
    "def compute_five_point_fd(raw_data, time):\n",
    "    \"\"\"Estimate the time derivative of a trajectory with finite differences.\n",
    "\n",
    "    Interior samples use the five-point central stencil; the first and last samples\n",
    "    use one-sided differences and their neighbours use three-point central differences.\n",
    "\n",
    "    Args:\n",
    "        raw_data: tensor with exactly three axes whose first axis is time. At least three\n",
    "            time samples are needed for the boundary formulas.\n",
    "        time: tensor of sample times; the step is `time[1] - time[0]` and is applied to\n",
    "            every sample, so a uniform grid is assumed.\n",
    "\n",
    "    Returns:\n",
    "        Tensor with the same shape, dtype and device as `raw_data` holding the estimate.\n",
    "    \"\"\"\n",
    "    delta_t = (time[1] - time[0]).item()\n",
    "\n",
    "    T, _, _ = raw_data.shape\n",
    "    derivative = torch.zeros_like(raw_data)\n",
    "\n",
    "    # Five-point stencil on all interior samples at once. The slices line up as\n",
    "    # t+2, t+1, t-1, t-2 for t = 2 .. T-3, replacing the per-timestep Python loop\n",
    "    # while keeping the same arithmetic order.\n",
    "    if T > 4:\n",
    "        derivative[2:-2] = (\n",
    "            -raw_data[4:] + 8 * raw_data[3:-1] - 8 * raw_data[1:-3] + raw_data[:-4]\n",
    "        ) / (12 * delta_t)\n",
    "\n",
    "    # Lower-order differences near the boundaries (forward, central, central, backward).\n",
    "    derivative[0] = (raw_data[1] - raw_data[0]) / delta_t\n",
    "    derivative[1] = (raw_data[2] - raw_data[0]) / (2 * delta_t)\n",
    "    derivative[-2] = (raw_data[-1] - raw_data[-3]) / (2 * delta_t)\n",
    "    derivative[-1] = (raw_data[-1] - raw_data[-2]) / delta_t\n",
    "\n",
    "    return derivative\n",
    "\n",
    "\n",
    "def get_error_through_time(y_true, y_pred):\n",
    "    \"\"\"Mean absolute error per time step, averaged over axis 1, for each trajectory pair.\n",
    "\n",
    "    Args:\n",
    "        y_true: iterable of tensors; per the original note each has shape (T, N, 1).\n",
    "        y_pred: iterable of tensors paired element-wise with `y_true`.\n",
    "\n",
    "    Returns:\n",
    "        float32 NumPy array stacking one (T, 1) error curve per pair along axis 0.\n",
    "    \"\"\"\n",
    "    def _as_float32(tensor):\n",
    "        return tensor.detach().cpu().numpy().astype(np.float32)\n",
    "\n",
    "    curves = [\n",
    "        np.mean(np.abs(_as_float32(truth) - _as_float32(pred)), axis=1)  # shape (T, 1)\n",
    "        for pred, truth in zip(y_pred, y_true)\n",
    "    ]\n",
    "    return np.stack(curves, axis=0)  # shape (n_pairs, T, 1)\n",
    "\n",
    "\n",
    "def save_pred_deriv(g_symb, h_symb, test_set, suffix, message_passing = False, include_time = False,\n",
    "                    is_symb=True, device='cuda', save_dir = None):\n",
    "    \"\"\"Evaluate a g/h model snapshot by snapshot and optionally dump derivatives to JSON.\n",
    "\n",
    "    Args:\n",
    "        g_symb, h_symb: model components. With `is_symb=True`, plain ints are passed through\n",
    "            `sp.sympify` and both are converted with `make_callable`; otherwise used as given.\n",
    "        test_set: iterable whose items expose `item[0].raw_data`, `.t_span` and `.edge_index`.\n",
    "        suffix: embedded in the output file name `pred_deriv_{suffix}.json`.\n",
    "        message_passing, include_time: forwarded to `get_model`.\n",
    "        is_symb: see `g_symb` / `h_symb`.\n",
    "        device: device for the model and the per-snapshot graphs.\n",
    "        save_dir: when not None, results are written to `<save_dir>/pred_deriv/`.\n",
    "\n",
    "    Returns:\n",
    "        None. Nothing is written when `save_dir` is None.\n",
    "    \"\"\"\n",
    "    if is_symb:\n",
    "        # Plain Python ints are promoted to SymPy objects before make_callable.\n",
    "        g_symb = sp.sympify(g_symb) if isinstance(g_symb, int) else g_symb\n",
    "        h_symb = sp.sympify(h_symb) if isinstance(h_symb, int) else h_symb\n",
    "\n",
    "        g_symb = make_callable(g_symb)\n",
    "        h_symb = make_callable(h_symb)\n",
    "\n",
    "    model = get_model(\n",
    "        g=g_symb,\n",
    "        h=h_symb,\n",
    "        message_passing=message_passing,\n",
    "        include_time=include_time,\n",
    "        pred_deriv=True,\n",
    "    )\n",
    "    model = model.to(torch.device(device)).eval()\n",
    "\n",
    "    predicted, reference, trajectories = [], [], []\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for ts in test_set:\n",
    "            sample = ts[0]\n",
    "            data = sample.raw_data\n",
    "            time = sample.t_span\n",
    "            # Finite-difference derivative of the stored trajectory, shape (T, N, 1).\n",
    "            reference.append(compute_five_point_fd(data, time))\n",
    "            edge_index = sample.edge_index.to(device)\n",
    "            trajectories.append(data)\n",
    "\n",
    "            # One forward pass per time step on a graph that carries only that snapshot.\n",
    "            step_preds = []\n",
    "            for t in range(data.shape[0]):\n",
    "                graph = Data(\n",
    "                    edge_index=edge_index,\n",
    "                    edge_attr=None,\n",
    "                    x=data[t].to(device)\n",
    "                ).to(device)\n",
    "                step_preds.append(model(graph))\n",
    "\n",
    "            predicted.append(torch.stack(step_preds, dim=0))  # shape (T, N, 1)\n",
    "\n",
    "    if save_dir is None:\n",
    "        return\n",
    "\n",
    "    def _to_nested_list(tensor):\n",
    "        return tensor.detach().cpu().numpy().tolist()\n",
    "\n",
    "    preds_dir = os.path.join(save_dir, \"pred_deriv\")\n",
    "    os.makedirs(preds_dir, exist_ok=True)\n",
    "    save_path = os.path.join(preds_dir, f\"pred_deriv_{suffix}.json\")\n",
    "    preds_dict = {\n",
    "        \"pred_deriv\": [_to_nested_list(item) for item in predicted],\n",
    "        \"true_deriv\": [_to_nested_list(item) for item in reference],\n",
    "        \"states\": [_to_nested_list(item) for item in trajectories]\n",
    "    }\n",
    "    with open(save_path, \"w\") as f:\n",
    "        json.dump(preds_dict, f)\n",
    "    print(f\"Saved predicted and true derivatives to {save_path}\")\n",
    "\n",
    "\n",
    "def get_tss_test_error(\n",
    "    text_sympy_mapping_g,\n",
    "    text_sympy_mapping_h,\n",
    "    row_means,\n",
    "    test_set,\n",
    "    result_dict,\n",
    "    suffix = '',\n",
    "    method = \"dopri5\",\n",
    "    device = 'cuda',\n",
    "    save_dir = None\n",
    "):\n",
    "    \"\"\"Score a TSS-inferred symbolic model on `test_set` and record metrics in `result_dict`.\n",
    "\n",
    "    g and h are each built as the sum of `row_means[label] * expression` over the\n",
    "    corresponding mapping, starting from SymPy zero.\n",
    "\n",
    "    Args:\n",
    "        text_sympy_mapping_g: label -> SymPy expression for the g part.\n",
    "        text_sympy_mapping_h: label -> SymPy expression for the h part.\n",
    "        row_means: lookup from label to coefficient; must contain every label used.\n",
    "        test_set: forwarded to `get_test_pred` (and `save_pred_deriv` when saving).\n",
    "        result_dict: mutated in place; metric keys end with `suffix`.\n",
    "        suffix: tag appended to every key and used in the derivative dump's file name.\n",
    "        method: integration method forwarded to `get_test_pred`.\n",
    "        device: forwarded to `get_test_pred` and `save_pred_deriv`.\n",
    "        save_dir: when not None, predictions go to `<save_dir>/test_preds/preds_tss.json`\n",
    "            (the same file name for every suffix) and derivatives are dumped as well.\n",
    "\n",
    "    Returns:\n",
    "        None. If an AssertionError is raised inside the evaluation, a message is printed\n",
    "        and `result_dict['error_<suffix>']` is set instead.\n",
    "    \"\"\"\n",
    "    def _weighted_sum(mapping):\n",
    "        total = sp.S(0)\n",
    "        for label, expression in mapping.items():\n",
    "            total += row_means[label] * expression\n",
    "        return total\n",
    "\n",
    "    g_symb = _weighted_sum(text_sympy_mapping_g)\n",
    "    h_symb = _weighted_sum(text_sympy_mapping_h)\n",
    "\n",
    "    try:\n",
    "        y_pred_test, y_true_test = get_test_pred(\n",
    "            g_symb=g_symb,\n",
    "            h_symb=h_symb,\n",
    "            test_set=test_set,\n",
    "            message_passing=False,\n",
    "            include_time=False,\n",
    "            method=method,\n",
    "            atol=1e-5,\n",
    "            rtol=1e-5,\n",
    "            is_symb=True,\n",
    "            device=device\n",
    "        )\n",
    "\n",
    "        if save_dir is not None:\n",
    "            preds_dir = os.path.join(save_dir, \"test_preds\")\n",
    "            os.makedirs(preds_dir, exist_ok=True)\n",
    "            save_path = os.path.join(preds_dir, \"preds_tss.json\")\n",
    "            preds_dict = {\n",
    "                \"y_pred\": [yp.detach().cpu().numpy().tolist() for yp in y_pred_test],\n",
    "                \"y_true\": [yt.detach().cpu().numpy().tolist() for yt in y_true_test]\n",
    "            }\n",
    "            with open(save_path, \"w\") as f:\n",
    "                json.dump(preds_dict, f)\n",
    "            print(f\"Saved predictions to {save_path}\")\n",
    "\n",
    "            save_pred_deriv(\n",
    "                g_symb=g_symb,\n",
    "                h_symb=h_symb,\n",
    "                test_set=test_set,\n",
    "                suffix=suffix,\n",
    "                message_passing=False,\n",
    "                include_time=False,\n",
    "                is_symb=True,\n",
    "                device=device,\n",
    "                save_dir=save_dir\n",
    "            )\n",
    "\n",
    "        # All per-trajectory scores are computed before anything is written to result_dict.\n",
    "        criteria = (\n",
    "            ('mae', mean_absolute_error),\n",
    "            ('mse', mean_squared_error),\n",
    "            ('rmse', root_mean_squared_error),\n",
    "            ('r2', r2_score),\n",
    "            ('mape', mean_absolute_percentage_error),\n",
    "        )\n",
    "        scores = [\n",
    "            (tag, get_list_test_errors(y_pred_test, y_true_test, criterion=criterion))\n",
    "            for tag, criterion in criteria\n",
    "        ]\n",
    "\n",
    "        mae_tt = get_error_through_time(y_true=y_true_test, y_pred=y_pred_test)  # (n_pairs, T, 1)\n",
    "        mae_tt_avg = np.mean(mae_tt, axis=0)  # shape (T, 1)\n",
    "        mae_tt_std = np.std(mae_tt, axis=0)  # shape (T, 1)\n",
    "\n",
    "        result_dict[f'tss_mae_tt_{suffix}'] = mae_tt_avg.tolist()\n",
    "        result_dict[f'tss_std_tt_{suffix}'] = mae_tt_std.tolist()\n",
    "\n",
    "        for tag, values in scores:\n",
    "            # The MAE spread keys carry no metric tag (tss_test_var_*, tss_test_std_*);\n",
    "            # every other metric uses tss_test_<tag>_var_* / tss_test_<tag>_std_*.\n",
    "            spread_prefix = 'tss_test' if tag == 'mae' else f'tss_test_{tag}'\n",
    "            result_dict[f'tss_test_{tag}_{suffix}'] = np.mean(values)\n",
    "            result_dict[f'{spread_prefix}_var_{suffix}'] = np.var(values)\n",
    "            result_dict[f'{spread_prefix}_std_{suffix}'] = np.std(values)\n",
    "\n",
    "    except AssertionError:\n",
    "        print(\"Evaluation failed !\")\n",
    "        result_dict[f'error_{suffix}'] = 'Evaluation failed !'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c0ffd81",
   "metadata": {},
   "source": [
    "### KUR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bd58544",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Filled in place by the get_tss_test_error calls of the Kuramoto section.\n",
    "results_kur = dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd3ca037",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row-wise mean of the saved inference table; the first CSV column becomes the index.\n",
    "df = pd.read_csv(\n",
    "    \"./saved_models_optuna/tss/Kuramoto-1/results_dim=0.csv\", header=None\n",
    ").set_index(0)\n",
    "row_means = df.mean(axis=1)\n",
    "\n",
    "row_means"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc7efa76",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "x_i, x_j = sp.symbols('x_i x_j')\n",
    "\n",
    "# Library-term label -> SymPy expression; each label is looked up in `row_means`.\n",
    "text_sympy_mapping_g = {\"sinx1jMinusx1i\": sp.sin(x_j - x_i)}\n",
    "text_sympy_mapping_h = {\"constant\": sp.S(1.0)}\n",
    "\n",
    "# Clean-data evaluation; predictions are also written under the Kuramoto-1 folder.\n",
    "get_tss_test_error(\n",
    "    text_sympy_mapping_g=text_sympy_mapping_g,\n",
    "    text_sympy_mapping_h=text_sympy_mapping_h,\n",
    "    row_means=row_means,\n",
    "    test_set=KUR,\n",
    "    result_dict=results_kur,\n",
    "    save_dir=\"./saved_models_optuna/tss/Kuramoto-1\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f1aef6b",
   "metadata": {},
   "source": [
    "#### 70 DB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d292c27",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row-wise mean of the 70 dB denoised result table; first CSV column becomes the index.\n",
    "df = pd.read_csv(\n",
    "    \"./saved_models_optuna/tss/Kuramoto-1_denoise_2/results_dim=0_70_db__denoise.csv\", header=None\n",
    ").set_index(0)\n",
    "row_means = df.mean(axis=1)\n",
    "row_means"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7074e512",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_i, x_j = sp.symbols('x_i x_j')\n",
    "\n",
    "# Library-term label -> SymPy expression; each label is looked up in `row_means`.\n",
    "text_sympy_mapping_g = {\"sinx1jMinusx1i\": sp.sin(x_j - x_i)}\n",
    "text_sympy_mapping_h = {\"constant\": sp.S(1.0)}\n",
    "\n",
    "# Metrics are stored under keys ending in '70db'; nothing is written to disk here.\n",
    "get_tss_test_error(\n",
    "    text_sympy_mapping_g=text_sympy_mapping_g,\n",
    "    text_sympy_mapping_h=text_sympy_mapping_h,\n",
    "    row_means=row_means,\n",
    "    test_set=KUR,\n",
    "    result_dict=results_kur,\n",
    "    suffix='70db'\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "724a6bde",
   "metadata": {},
   "source": [
    "#### 50 DB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b2670a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row-wise mean of the 50 dB denoised result table; first CSV column becomes the index.\n",
    "df = pd.read_csv(\n",
    "    \"./saved_models_optuna/tss/Kuramoto-1_denoise_2/results_dim=0_50_db__denoise.csv\", header=None\n",
    ").set_index(0)\n",
    "row_means = df.mean(axis=1)\n",
    "row_means"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc88d1c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_i, x_j = sp.symbols('x_i x_j')\n",
    "\n",
    "# Library-term label -> SymPy expression; each label is looked up in `row_means`.\n",
    "text_sympy_mapping_g = {\"sinx1jMinusx1i\": sp.sin(x_j - x_i)}\n",
    "text_sympy_mapping_h = {\"constant\": sp.S(1.0)}\n",
    "\n",
    "# Metrics are stored under keys ending in '50db'; nothing is written to disk here.\n",
    "get_tss_test_error(\n",
    "    text_sympy_mapping_g=text_sympy_mapping_g,\n",
    "    text_sympy_mapping_h=text_sympy_mapping_h,\n",
    "    row_means=row_means,\n",
    "    test_set=KUR,\n",
    "    result_dict=results_kur,\n",
    "    suffix='50db'\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "00179cdd",
   "metadata": {},
   "source": [
    "#### 20 DB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "165edd00",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row-wise mean of the 20 dB denoised result table; first CSV column becomes the index.\n",
    "df = pd.read_csv(\n",
    "    \"./saved_models_optuna/tss/Kuramoto-1_denoise_2/results_dim=0_20_db__denoise.csv\", header=None\n",
    ").set_index(0)\n",
    "row_means = df.mean(axis=1)\n",
    "row_means"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84adce87",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_i, x_j = sp.symbols('x_i x_j')\n",
    "\n",
    "# Library-term label -> SymPy expression; each label is looked up in `row_means`.\n",
    "text_sympy_mapping_g = {\"sinx1jMinusx1i\": sp.sin(x_j - x_i)}\n",
    "text_sympy_mapping_h = {\"constant\": sp.S(1.0)}\n",
    "\n",
    "# Metrics are stored under keys ending in '20db'; the solver is passed explicitly here.\n",
    "get_tss_test_error(\n",
    "    text_sympy_mapping_g=text_sympy_mapping_g,\n",
    "    text_sympy_mapping_h=text_sympy_mapping_h,\n",
    "    row_means=row_means,\n",
    "    test_set=KUR,\n",
    "    result_dict=results_kur,\n",
    "    suffix='20db',\n",
    "    method=\"dopri5\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43222550",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist every Kuramoto metric collected in `results_kur` as indented JSON.\n",
    "with open(\"./saved_models_optuna/tss/Kuramoto-1/post_process_tmlr_revs.json\", 'w') as file:\n",
    "    json.dump(results_kur, file, indent=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b07014c",
   "metadata": {},
   "source": [
    "### EPID"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5454bd11",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Filled in place by the get_tss_test_error calls of the Epidemics section.\n",
    "results_epid = dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3b999cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_i, x_j = sp.symbols('x_i x_j')\n",
    "\n",
    "# Row-wise mean of the saved inference table; the first CSV column becomes the index.\n",
    "df = pd.read_csv(\n",
    "    \"./saved_models_optuna/tss/Epidemics-1/results_dim=0.csv\", header=None\n",
    ").set_index(0)\n",
    "row_means = df.mean(axis=1)\n",
    "\n",
    "row_means"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5929d030",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Library-term label -> SymPy expression; each label is looked up in `row_means`.\n",
    "text_sympy_mapping_g = {\"expx1jMinusx1i\": sp.exp(x_j - x_i)}\n",
    "\n",
    "text_sympy_mapping_h = {\n",
    "    \"constant\": sp.S(1.0),\n",
    "    # \"x1x1x1\": x_i * x_i * x_i\n",
    "}\n",
    "\n",
    "# Clean-data evaluation; predictions are also written under the Epidemics-1 folder.\n",
    "get_tss_test_error(\n",
    "    text_sympy_mapping_g=text_sympy_mapping_g,\n",
    "    text_sympy_mapping_h=text_sympy_mapping_h,\n",
    "    row_means=row_means,\n",
    "    test_set=EPID,\n",
    "    result_dict=results_epid,\n",
    "    save_dir=\"./saved_models_optuna/tss/Epidemics-1\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a2eae4a4",
   "metadata": {},
   "source": [
    "#### 70 DB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5978c9f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_i, x_j = sp.symbols('x_i x_j')\n",
    "\n",
    "# Index the header-less CSV by its first column, then average each row over the remaining columns.\n",
    "df = pd.read_csv(\"./saved_models_optuna/tss/Epidemics-1_denoise_2/results_dim=0_70_db__denoise.csv\", header=None).set_index(0)\n",
    "row_means = df.mean(axis=1)\n",
    "\n",
    "row_means"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "523be7b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# g has no terms here; h maps each named term to its symbolic expression in x_i.\n",
    "text_sympy_mapping_h = dict(x1x1=x_i * x_i, sinx1=sp.sin(x_i))\n",
    "text_sympy_mapping_g = dict()\n",
    "\n",
    "get_tss_test_error(\n",
    "    row_means=row_means,\n",
    "    text_sympy_mapping_g=text_sympy_mapping_g,\n",
    "    text_sympy_mapping_h=text_sympy_mapping_h,\n",
    "    test_set=EPID,\n",
    "    result_dict=results_epid,\n",
    "    suffix='70db',\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "145a6c8f",
   "metadata": {},
   "source": [
    "#### 50 DB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef2d2d83",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_i, x_j = sp.symbols('x_i x_j')\n",
    "\n",
    "# Index the header-less CSV by its first column, then average each row over the remaining columns.\n",
    "df = pd.read_csv(\"./saved_models_optuna/tss/Epidemics-1_denoise_2/results_dim=0_50_db__denoise.csv\", header=None).set_index(0)\n",
    "row_means = df.mean(axis=1)\n",
    "\n",
    "row_means"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "859291d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# g has no terms here; h maps each named term to its symbolic expression in x_i.\n",
    "text_sympy_mapping_h = dict(x1x1=x_i * x_i, sinx1=sp.sin(x_i))\n",
    "text_sympy_mapping_g = dict()\n",
    "\n",
    "get_tss_test_error(\n",
    "    row_means=row_means,\n",
    "    text_sympy_mapping_g=text_sympy_mapping_g,\n",
    "    text_sympy_mapping_h=text_sympy_mapping_h,\n",
    "    test_set=EPID,\n",
    "    result_dict=results_epid,\n",
    "    suffix='50db',\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac14e251",
   "metadata": {},
   "source": [
    "#### 20 DB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc771a58",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_i, x_j = sp.symbols('x_i x_j')\n",
    "\n",
    "# Index the header-less CSV by its first column, then average each row over the remaining columns.\n",
    "df = pd.read_csv(\"./saved_models_optuna/tss/Epidemics-1_denoise_2/results_dim=0_20_db__denoise.csv\", header=None).set_index(0)\n",
    "row_means = df.mean(axis=1)\n",
    "\n",
    "row_means"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89ffd112",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Symbolic expression for each named term: g depends on x_i and x_j, h on x_i only.\n",
    "text_sympy_mapping_h = dict(x1x1x1=x_i * x_i * x_i, constant=sp.S(1.0))\n",
    "text_sympy_mapping_g = dict(x1jMinusx1i=x_j - x_i)\n",
    "\n",
    "get_tss_test_error(\n",
    "    row_means=row_means,\n",
    "    text_sympy_mapping_g=text_sympy_mapping_g,\n",
    "    text_sympy_mapping_h=text_sympy_mapping_h,\n",
    "    test_set=EPID,\n",
    "    result_dict=results_epid,\n",
    "    suffix='20db',\n",
    "    method=\"dopri5\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "947b393b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the collected Epidemics results to disk as indented JSON.\n",
    "with open(\"./saved_models_optuna/tss/Epidemics-1/post_process_tmlr_revs.json\", 'w') as file:\n",
    "    json.dump(results_epid, file, indent=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6967546",
   "metadata": {},
   "source": [
    "### BIO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77f01615",
   "metadata": {},
   "outputs": [],
   "source": [
    "results_bio = dict()  # passed as result_dict to the get_tss_test_error calls below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bd02524",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_i, x_j = sp.symbols('x_i x_j')\n",
    "\n",
    "# Index the header-less CSV by its first column, then average each row over the remaining columns.\n",
    "df = pd.read_csv(\"./saved_models_optuna/tss/Biochemical-1/results_dim=0.csv\", header=None).set_index(0)\n",
    "row_means = df.mean(axis=1)\n",
    "\n",
    "row_means"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83d394ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Symbolic expression for each named term: g depends on x_i and x_j, h on x_i only.\n",
    "text_sympy_mapping_h = dict(constant=sp.S(1.0))\n",
    "text_sympy_mapping_g = dict(x1ix1j=x_i * x_j)\n",
    "\n",
    "get_tss_test_error(\n",
    "    row_means=row_means,\n",
    "    text_sympy_mapping_g=text_sympy_mapping_g,\n",
    "    text_sympy_mapping_h=text_sympy_mapping_h,\n",
    "    test_set=BIO,\n",
    "    result_dict=results_bio,\n",
    "    save_dir=\"./saved_models_optuna/tss/Biochemical-1\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b134b234",
   "metadata": {},
   "source": [
    "#### 70 DB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cee3b347",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_i, x_j = sp.symbols('x_i x_j')\n",
    "\n",
    "# Index the header-less CSV by its first column, then average each row over the remaining columns.\n",
    "df = pd.read_csv(\"./saved_models_optuna/tss/Biochemical-1_denoise_2/results_dim=0_70_db__denoise.csv\", header=None).set_index(0)\n",
    "row_means = df.mean(axis=1)\n",
    "\n",
    "row_means"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65d53094",
   "metadata": {},
   "outputs": [],
   "source": [
    "# g has no terms here; h maps each named term to its symbolic expression in x_i.\n",
    "text_sympy_mapping_h = dict(sinx1=sp.sin(x_i), constant=sp.S(1.0))\n",
    "text_sympy_mapping_g = dict()\n",
    "\n",
    "get_tss_test_error(\n",
    "    row_means=row_means,\n",
    "    text_sympy_mapping_g=text_sympy_mapping_g,\n",
    "    text_sympy_mapping_h=text_sympy_mapping_h,\n",
    "    test_set=BIO,\n",
    "    result_dict=results_bio,\n",
    "    suffix='70db',\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5051ca17",
   "metadata": {},
   "source": [
    "#### 50 DB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b22deb71",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_i, x_j = sp.symbols('x_i x_j')\n",
    "\n",
    "# Index the header-less CSV by its first column, then average each row over the remaining columns.\n",
    "df = pd.read_csv(\"./saved_models_optuna/tss/Biochemical-1_denoise_2/results_dim=0_50_db__denoise.csv\", header=None).set_index(0)\n",
    "row_means = df.mean(axis=1)\n",
    "\n",
    "row_means"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66f6f5e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# g has no terms here; h maps each named term to its symbolic expression in x_i.\n",
    "text_sympy_mapping_h = dict(sinx1=sp.sin(x_i), constant=sp.S(1.0))\n",
    "text_sympy_mapping_g = dict()\n",
    "\n",
    "get_tss_test_error(\n",
    "    row_means=row_means,\n",
    "    text_sympy_mapping_g=text_sympy_mapping_g,\n",
    "    text_sympy_mapping_h=text_sympy_mapping_h,\n",
    "    test_set=BIO,\n",
    "    result_dict=results_bio,\n",
    "    suffix='50db',\n",
    "    method=\"dopri5\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7eaf2d8",
   "metadata": {},
   "source": [
    "#### 20 DB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6726b8d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_i, x_j = sp.symbols('x_i x_j')\n",
    "\n",
    "# Index the header-less CSV by its first column, then average each row over the remaining columns.\n",
    "df = pd.read_csv(\"./saved_models_optuna/tss/Biochemical-1_denoise_2/results_dim=0_20_db__denoise.csv\", header=None).set_index(0)\n",
    "row_means = df.mean(axis=1)\n",
    "\n",
    "row_means"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49da7d7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# g has no terms here; h maps each named term to its symbolic expression in x_i.\n",
    "text_sympy_mapping_h = dict(sinx1=sp.sin(x_i), constant=sp.S(1.0))\n",
    "text_sympy_mapping_g = dict()\n",
    "\n",
    "get_tss_test_error(\n",
    "    row_means=row_means,\n",
    "    text_sympy_mapping_g=text_sympy_mapping_g,\n",
    "    text_sympy_mapping_h=text_sympy_mapping_h,\n",
    "    test_set=BIO,\n",
    "    result_dict=results_bio,\n",
    "    suffix='20db',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e9b9568",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the collected Biochemical results to disk as indented JSON.\n",
    "with open(\"./saved_models_optuna/tss/Biochemical-1/post_process_tmlr_revs.json\", 'w') as file:\n",
    "    json.dump(results_bio, file, indent=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d5bb772",
   "metadata": {},
   "source": [
    "### POP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c101796",
   "metadata": {},
   "outputs": [],
   "source": [
    "results_pop = dict()  # passed as result_dict to the get_tss_test_error calls below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fb328f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_i, x_j = sp.symbols('x_i x_j')\n",
    "\n",
    "# Index the header-less CSV by its first column, then average each row over the remaining columns.\n",
    "df = pd.read_csv(\"./saved_models_optuna/tss/Population-1/results_dim=0.csv\", header=None).set_index(0)\n",
    "row_means = df.mean(axis=1)\n",
    "\n",
    "row_means"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a532fe0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Symbolic expression for each named term: g is written in x_j here, h is a constant.\n",
    "text_sympy_mapping_h = dict(constant=sp.S(1.0))\n",
    "text_sympy_mapping_g = dict(sinx1j=sp.sin(x_j), x1j=x_j)\n",
    "\n",
    "get_tss_test_error(\n",
    "    row_means=row_means,\n",
    "    text_sympy_mapping_g=text_sympy_mapping_g,\n",
    "    text_sympy_mapping_h=text_sympy_mapping_h,\n",
    "    test_set=POP,\n",
    "    result_dict=results_pop,\n",
    "    save_dir=\"./saved_models_optuna/tss/Population-1\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "402522c8",
   "metadata": {},
   "source": [
    "#### 70 DB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77005796",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_i, x_j = sp.symbols('x_i x_j')\n",
    "\n",
    "# Index the header-less CSV by its first column, then average each row over the remaining columns.\n",
    "df = pd.read_csv(\"./saved_models_optuna/tss/Population-1_denoise_2/results_dim=0_70_db__denoise.csv\", header=None).set_index(0)\n",
    "row_means = df.mean(axis=1)\n",
    "\n",
    "row_means"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7bb51202",
   "metadata": {},
   "outputs": [],
   "source": [
    "# g has no terms here; h maps each named term to its symbolic expression in x_i.\n",
    "text_sympy_mapping_h = dict(sinx1=sp.sin(x_i), x1x1x1=x_i*x_i*x_i)\n",
    "text_sympy_mapping_g = dict()\n",
    "\n",
    "get_tss_test_error(\n",
    "    row_means=row_means,\n",
    "    text_sympy_mapping_g=text_sympy_mapping_g,\n",
    "    text_sympy_mapping_h=text_sympy_mapping_h,\n",
    "    test_set=POP,\n",
    "    result_dict=results_pop,\n",
    "    suffix='70db',\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d0eb289",
   "metadata": {},
   "source": [
    "#### 50 DB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6767fbdb",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_i, x_j = sp.symbols('x_i x_j')\n",
    "\n",
    "# Index the header-less CSV by its first column, then average each row over the remaining columns.\n",
    "df = pd.read_csv(\"./saved_models_optuna/tss/Population-1_denoise_2/results_dim=0_50_db__denoise.csv\", header=None).set_index(0)\n",
    "row_means = df.mean(axis=1)\n",
    "\n",
    "row_means"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7abd2e66",
   "metadata": {},
   "outputs": [],
   "source": [
    "# g has no terms here; h maps each named term to its symbolic expression in x_i.\n",
    "text_sympy_mapping_h = dict(sinx1=sp.sin(x_i), x1x1x1=x_i*x_i*x_i)\n",
    "text_sympy_mapping_g = dict()\n",
    "\n",
    "get_tss_test_error(\n",
    "    row_means=row_means,\n",
    "    text_sympy_mapping_g=text_sympy_mapping_g,\n",
    "    text_sympy_mapping_h=text_sympy_mapping_h,\n",
    "    test_set=POP,\n",
    "    result_dict=results_pop,\n",
    "    suffix='50db',\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "638fb582",
   "metadata": {},
   "source": [
    "#### 20 DB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1453afa4",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_i, x_j = sp.symbols('x_i x_j')\n",
    "\n",
    "# Index the header-less CSV by its first column, then average each row over the remaining columns.\n",
    "df = pd.read_csv(\"./saved_models_optuna/tss/Population-1_denoise_2/results_dim=0_20_db__denoise.csv\", header=None).set_index(0)\n",
    "row_means = df.mean(axis=1)\n",
    "\n",
    "row_means"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94b50095",
   "metadata": {},
   "outputs": [],
   "source": [
    "# g has no terms here; h maps each named term to its symbolic expression in x_i.\n",
    "text_sympy_mapping_h = dict(sinx1=sp.sin(x_i), x1x1x1=x_i*x_i*x_i)\n",
    "text_sympy_mapping_g = dict()\n",
    "\n",
    "get_tss_test_error(\n",
    "    row_means=row_means,\n",
    "    text_sympy_mapping_g=text_sympy_mapping_g,\n",
    "    text_sympy_mapping_h=text_sympy_mapping_h,\n",
    "    test_set=POP,\n",
    "    result_dict=results_pop,\n",
    "    suffix='20db',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bc226a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the collected Population results to disk as indented JSON.\n",
    "with open(\"./saved_models_optuna/tss/Population-1/post_process_tmlr_revs.json\", 'w') as file:\n",
    "    json.dump(results_pop, file, indent=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c796e91",
   "metadata": {},
   "source": [
    "## Re-fitting coefficients\n",
    "\n",
    "We fine-tune the coefficients of the TPSINDy formula using the method proposed by the authors. However, we perform the fitting only on the first 90% of the observations, in order to have a fair comparison with the neural-based formulas."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17022c43",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import sympy as sp\n",
    "from post_processing import make_callable, get_model, set_pytorch_seed\n",
    "import numpy as np\n",
    "\n",
    "set_pytorch_seed(0)\n",
    "\n",
    "def build_symb_model_tss():\n",
    "    \"\"\"Build the symbolic model with h(x_i) = x_i and g(x_i, x_j) = 1 / (1 + exp(-(x_j - x_i))).\n",
    "\n",
    "    Both expressions are wrapped with ``make_callable`` and handed to ``get_model``\n",
    "    (no message passing, no time input, 'rk4' integration). The returned model has\n",
    "    ``predict_deriv`` set to True.\n",
    "    \"\"\"\n",
    "    x_i, x_j = sp.symbols('x_i x_j')\n",
    "\n",
    "    pairwise_expr = 1 / (1 + sp.exp(-(x_j - x_i)))  # logistic function of (x_j - x_i)\n",
    "    self_expr = x_i\n",
    "\n",
    "    model = get_model(\n",
    "        g=make_callable(pairwise_expr),\n",
    "        h=make_callable(self_expr),\n",
    "        message_passing=False,\n",
    "        include_time=False,\n",
    "        integration_method='rk4',\n",
    "    )\n",
    "    model.predict_deriv = True\n",
    "    return model\n",
    "\n",
    "\n",
    "def get_dxdt_pred(data, symb_model, device='cpu'):\n",
    "    \"\"\"Evaluate the model on every snapshot and collect its self and pairwise terms.\n",
    "\n",
    "    After each forward pass the values of ``symb_model.conv.model.h_out`` and\n",
    "    ``symb_model.conv.model.g_out`` are recorded. The per-snapshot tensors are stacked\n",
    "    along dim=1 and flattened, so for (N, 1) outputs the result is node-major:\n",
    "    all time steps of node 0 first, then node 1, and so on.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(self_int, pair_int)`` of flat NumPy arrays.\n",
    "    \"\"\"\n",
    "    h_per_step, g_per_step = [], []\n",
    "    for snap in data:\n",
    "        symb_model(snap.to(device))  # forward pass is run only for the h_out / g_out it leaves behind\n",
    "        terms = symb_model.conv.model\n",
    "        h_per_step.append(terms.h_out)\n",
    "        g_per_step.append(terms.g_out)\n",
    "\n",
    "    def _flatten(tensors):\n",
    "        return torch.stack(tensors, dim=1).cpu().detach().numpy().flatten()\n",
    "\n",
    "    return _flatten(h_per_step), _flatten(g_per_step)\n",
    "\n",
    "\n",
    "def sum_over_dxdt(self_int, pair_int, n_nodes, T):\n",
    "    \"\"\"Cumulatively sum both interaction terms over time, separately for every node.\n",
    "\n",
    "    Args:\n",
    "        self_int: flat array of n_nodes * T values, node-major (all T steps of node 0 first).\n",
    "        pair_int: flat array with the same layout as self_int.\n",
    "        n_nodes: number of nodes.\n",
    "        T: number of time steps per node.\n",
    "\n",
    "    Returns:\n",
    "        Array of shape (n_nodes * T, 2). Row i * T + t holds the sums of self_int and\n",
    "        pair_int over steps 0..t of node i.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if an input does not contain exactly n_nodes * T values.\n",
    "    \"\"\"\n",
    "    columns = []\n",
    "    for flat in (self_int, pair_int):\n",
    "        per_node = np.asarray(flat).reshape(n_nodes, T)\n",
    "        # cumsum yields every prefix sum in one pass instead of re-summing each slice\n",
    "        columns.append(np.cumsum(per_node, axis=1).ravel())\n",
    "    return np.column_stack(columns)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c618976f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Output folder for the library CSVs and fitted coefficients written by the cells below.\n",
    "coeffs_path = \"./inferred_coeffs/tpsindy\"\n",
    "os.makedirs(coeffs_path, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cccd3027",
   "metadata": {},
   "source": [
    "### COVID"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7309215c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets.RealEpidemics import RealEpidemics\n",
    "import pandas as pd\n",
    "\n",
    "# COVID dataset loaded with scale=False and predict_deriv=True.\n",
    "data_real_epid_orig = RealEpidemics(\n",
    "    root='./data_real_epid_covid_orig', name='RealEpid', predict_deriv=True, scale=False\n",
    ")\n",
    "symb_model_tss = build_symb_model_tss()\n",
    "\n",
    "# Evaluate the self (h) and pairwise (g) terms on every snapshot of the dataset.\n",
    "self_int, pair_int = get_dxdt_pred(\n",
    "    data_real_epid_orig, symb_model_tss, device=data_real_epid_orig.device\n",
    ")\n",
    "\n",
    "# Per-node running sums over time, one row per (node, time step).\n",
    "lib = sum_over_dxdt(\n",
    "    self_int, pair_int, n_nodes=data_real_epid_orig[0].x.shape[0], T=len(data_real_epid_orig)\n",
    ")\n",
    "\n",
    "lib_df = pd.DataFrame(lib, columns=['x', 'sigxjminusxi'])\n",
    "lib_df.to_csv(f'{coeffs_path}/lib_new_covid.csv', index=False)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a72f2118",
   "metadata": {},
   "source": [
    "### H1N1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd0240ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# H1N1 dataset loaded with scale=False, predict_deriv=True and inf_threshold=100.\n",
    "data_real_epid_orig_h1n1 = RealEpidemics(\n",
    "    root='./data_real_epid_h1n1_orig',\n",
    "    name='RealEpid',\n",
    "    predict_deriv=True,\n",
    "    scale=False,\n",
    "    infection_data=\"./data/RealEpidemics/infected_numbers_H1N1.csv\",\n",
    "    inf_threshold=100,\n",
    ")\n",
    "symb_model_tss = build_symb_model_tss()\n",
    "\n",
    "# Evaluate the self (h) and pairwise (g) terms on every snapshot of the dataset.\n",
    "self_int, pair_int = get_dxdt_pred(\n",
    "    data_real_epid_orig_h1n1, symb_model_tss, device=data_real_epid_orig_h1n1.device\n",
    ")\n",
    "\n",
    "# Per-node running sums over time, one row per (node, time step).\n",
    "lib = sum_over_dxdt(\n",
    "    self_int,\n",
    "    pair_int,\n",
    "    n_nodes=data_real_epid_orig_h1n1[0].x.shape[0],\n",
    "    T=len(data_real_epid_orig_h1n1),\n",
    ")\n",
    "\n",
    "lib_df = pd.DataFrame(lib, columns=['x', 'sigxjminusxi'])\n",
    "lib_df.to_csv(f'{coeffs_path}/lib_new_h1n1.csv', index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ea03b6bf",
   "metadata": {},
   "source": [
    "### SARS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4bda44ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets.RealEpidemics import RealEpidemics\n",
    "import pandas as pd\n",
    "\n",
    "# SARS dataset loaded with scale=False, predict_deriv=True and inf_threshold=100.\n",
    "data_real_epid_orig_sars = RealEpidemics(\n",
    "    root='./data_real_epid_sars_orig',\n",
    "    name='RealEpid',\n",
    "    predict_deriv=True,\n",
    "    scale=False,\n",
    "    infection_data=\"./data/RealEpidemics/infected_numbers_sars.csv\",\n",
    "    inf_threshold=100,\n",
    ")\n",
    "symb_model_tss = build_symb_model_tss()\n",
    "\n",
    "# Evaluate the self (h) and pairwise (g) terms on every snapshot of the dataset.\n",
    "self_int, pair_int = get_dxdt_pred(\n",
    "    data_real_epid_orig_sars, symb_model_tss, device=data_real_epid_orig_sars.device\n",
    ")\n",
    "\n",
    "# Per-node running sums over time, one row per (node, time step).\n",
    "lib = sum_over_dxdt(\n",
    "    self_int,\n",
    "    pair_int,\n",
    "    n_nodes=data_real_epid_orig_sars[0].x.shape[0],\n",
    "    T=len(data_real_epid_orig_sars),\n",
    ")\n",
    "\n",
    "lib_df = pd.DataFrame(lib, columns=['x', 'sigxjminusxi'])\n",
    "lib_df.to_csv(f'{coeffs_path}/lib_new_sars.csv', index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d35feb37",
   "metadata": {},
   "source": [
    "### Fair fitting of coefficients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb3275e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LinearRegression\n",
    "\n",
    "# Right-hand-side libraries (columns 'x' and 'sigxjminusxi') written by the COVID / H1N1 / SARS cells.\n",
    "right_sides = [\n",
    "    f\"{coeffs_path}/lib_new_covid.csv\",\n",
    "    f\"{coeffs_path}/lib_new_h1n1.csv\",\n",
    "    f\"{coeffs_path}/lib_new_sars.csv\"\n",
    "]\n",
    "\n",
    "# Matching left-hand-side targets (column 'X' is used), in the same dataset order.\n",
    "left_sides = [\n",
    "    \"./inferred_coeffs/tpsindy/left_side_components_covid.csv\",\n",
    "    \"./inferred_coeffs/tpsindy/left_side_components_H1N1.csv\",\n",
    "    \"./inferred_coeffs/tpsindy/left_side_components_Sars.csv\"\n",
    "]\n",
    "\n",
    "# Number of per-node blocks in each file; must equal the n_nodes used to build the matching lib_new_*.csv.\n",
    "n_nodes = [82, 21, 4]\n",
    "names = ['covid', 'h1n1', 'sars']\n",
    "\n",
    "# Only the first 90% of each node's rows are used for fitting.\n",
    "train_fraction = 0.9\n",
    "\n",
    "for name, N, rs, ls in zip(names, n_nodes, right_sides, left_sides):\n",
    "    X_all = pd.read_csv(rs)\n",
    "    y_all = pd.read_csv(ls)\n",
    "    L = len(X_all) // N  # rows per node\n",
    "\n",
    "    # One column per node: row 0 is the coefficient of 'x', row 1 of 'sigxjminusxi'.\n",
    "    Coef = np.zeros(shape=(2, N))\n",
    "    for i in range(N):\n",
    "        X = X_all.iloc[i*L:(i+1)*L, :]\n",
    "        y = y_all.iloc[i*L:(i+1)*L, :]\n",
    "\n",
    "        cutoff = int(train_fraction * len(X))\n",
    "        Xin = X.iloc[:cutoff][['x', 'sigxjminusxi']]\n",
    "        y1 = y.iloc[:cutoff]['X']\n",
    "\n",
    "        model = LinearRegression(fit_intercept=False)\n",
    "        model.fit(Xin, y1)\n",
    "        # coef_ is 1-D with one entry per feature; assign it directly rather than\n",
    "        # squeezing size-1 arrays into scalar slots (deprecated since NumPy 1.25).\n",
    "        Coef[:, i] = model.coef_\n",
    "\n",
    "    pd.DataFrame(Coef).to_csv(f\"{coeffs_path}/inf_coeffs_test_{name}.csv\", index=False)\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "myenv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
