# -*- coding: utf-8 -*-
r"""post_processing.ipynb

Automatically generated by Colab.

## Dynamics

| Dynamics | $\partial_{\tau}x_i=$ |
| :--------: | :-------: |
| Biochemical | $F -B x_i - R \sum_j A_{ij} x_i x_j$ |
| Epidemics | $-B x_i + R \sum_j A_{ij} (1-x_i)x_j$ |
| Population | $-B x_i^{b} + R \sum_j A_{ij} x_j^a$ |
| Synchronization | $\omega_i + R \sum_j A_{ij} \sin(x_j-x_i)$ |
"""

# Commented out IPython magic to ensure Python compatibility.
# %load_ext autoreload
# %autoreload 2

"""## Importing"""

import os
# Set before any of the imports below run. PyTorch's reproducibility notes ask
# for this cuBLAS workspace setting when deterministic algorithms are enabled
# on CUDA (set_pytorch_seed enables them further down in this file).
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

from utils.utils import *
import optuna
from optuna.storages import JournalStorage
from optuna.storages.journal import JournalFileBackend
from experiments.experiments_gkan import ExperimentsGKAN
from experiments.experiments_mpnn import ExperimentsMPNN
from train_and_eval import get_pred_batch
import sympytorch
import itertools
from sklearn.metrics import mean_absolute_error, mean_squared_error, root_mean_squared_error
import pickle

# Optuna storage backed by a journal file; the relative path is resolved
# against the current working directory.
# NOTE(review): `storage` is not referenced anywhere else in the visible part
# of this file - confirm it is still needed.
storage = JournalStorage(JournalFileBackend("optuna_journal_storage.log"))

import random

def set_pytorch_seed(seed=42):
    """Seed every RNG this script relies on and request deterministic kernels.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU plus all CUDA devices), then disables cuDNN autotuning and switches
    PyTorch to deterministic algorithms. ``warn_only=True`` means an operation
    without a deterministic implementation emits a warning instead of raising.

    Args:
        seed: Integer seed shared by all generators.
    """
    # The generators are independent of each other, so each one is seeded
    # separately with the same value.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Determinism switches.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.use_deterministic_algorithms(True, warn_only=True)


"""## Utils"""

from models.utils.MPNN import MPNN
from models.baseline.MPNN_ODE import MPNN_ODE
from models.baseline.MLP_ODE import MLP_ODE
from datasets.SyntheticData import SyntheticData
from sympy import symbols, sin, summation, simplify
import networkx as nx
from torch_geometric.utils import from_networkx
from utils.utils import integrate
from torch_geometric.data import Data
from models.kan.KAN import KAN
from models.GKAN_ODE import GKAN_ODE
import torch
import numpy as np


import optuna

import warnings
# Install a blanket "ignore" filter: from this point on, warnings of every
# category are suppressed for the whole process.
warnings.filterwarnings("ignore")

from sympy import latex
from torch.utils.data import DataLoader

def get_model(g, h, message_passing=True, include_time=False, atol=1e-5, rtol=1e-5, integration_method = 'scipy_solver',
              eval=True, options = None, all_t = False, exploit_graph=True, pred_deriv=False):
    """Wrap two callables ``g`` and ``h`` into an ODE model.

    ``g`` and ``h`` are handed to ``MPNN`` as ``g_net`` and ``h_net``; the
    resulting convolution is wrapped in ``MPNN_ODE`` when ``exploit_graph`` is
    true and in ``MLP_ODE`` otherwise. Both wrappers receive the same keyword
    arguments.

    Args:
        g: Callable passed to ``MPNN`` as ``g_net``.
        h: Callable passed to ``MPNN`` as ``h_net``.
        message_passing: Forwarded to ``MPNN``.
        include_time: Forwarded to ``MPNN``.
        atol: Absolute tolerance forwarded to the ODE wrapper.
        rtol: Relative tolerance forwarded to the ODE wrapper.
        integration_method: Solver name forwarded to the ODE wrapper.
        eval: When true the model is switched to evaluation mode before it is
            returned. (The name shadows the builtin but is part of the public
            signature, so it is kept.)
        options: Solver options forwarded to the ODE wrapper. ``None`` (the
            default) means "a fresh empty dict".
        all_t: Forwarded to the ODE wrapper.
        exploit_graph: Selects ``MPNN_ODE`` (true) or ``MLP_ODE`` (false) and
            is forwarded to ``MPNN`` as ``exploit_graph_struct``.
        pred_deriv: Forwarded to the ODE wrapper as ``predict_deriv``.

    Returns:
        The constructed ODE model.
    """
    # A literal ``{}`` default is created once and shared by every call, so a
    # wrapper that stores or mutates it would leak state between models.
    solver_options = {} if options is None else options

    conv = MPNN(
        g_net = g,
        h_net = h,
        message_passing=message_passing,
        include_time=include_time,
        exploit_graph_struct=exploit_graph
    )

    # Both wrappers are built with identical arguments; only the class differs.
    ode_cls = MPNN_ODE if exploit_graph else MLP_ODE
    symb = ode_cls(
        conv=conv,
        model_path="./saved_models_optuna/tmp_symb",
        adjoint=True,
        integration_method=integration_method,
        atol=atol,
        rtol=rtol,
        options = solver_options,
        all_t=all_t,
        predict_deriv=pred_deriv
    )

    if eval:
        symb = symb.eval()
    return symb


def make_callable(expr):
    """Turn a symbolic expression into a function of a 2-column tensor.

    The returned function takes ``x`` and feeds ``x[:, 0]`` as ``x_i`` and/or
    ``x[:, 1]`` as ``x_j`` to a ``sympytorch.SymPyModule`` built from ``expr``.
    An expression without free symbols becomes a function returning a
    ``(x.shape[0], 1)`` tensor filled with the constant, using the dtype and
    device of ``x``.

    Args:
        expr: Object exposing ``free_symbols``; it must support ``float()``
            when it has no free symbols.

    Returns:
        A callable mapping ``x`` to the evaluated expression.

    Raises:
        ValueError: If the expression has free symbols but neither ``x_i`` nor
            ``x_j`` is among them.
    """
    free_syms = expr.free_symbols

    if not free_syms:
        # Constant expression: nothing to evaluate, just broadcast the value.
        const_value = float(expr)

        def constant_fn(x):
            return torch.full((x.shape[0], 1), const_value, dtype=x.dtype, device=x.device)

        return constant_fn

    # expr = sympytorch.hide_floats(expr)
    sym_module = sympytorch.SymPyModule(expressions=[expr])
    names = {str(s) for s in free_syms}
    uses_i = 'x_i' in names
    uses_j = 'x_j' in names

    if uses_i and uses_j:
        return lambda x: sym_module(x_i=x[:, 0], x_j=x[:, 1])
    if uses_i:
        return lambda x: sym_module(x_i=x[:, 0])
    if uses_j:
        return lambda x: sym_module(x_j=x[:, 1])
    raise ValueError(f"Unexpected symbols in expression: {free_syms}")


def get_test_pred(g_symb, h_symb, test_set, message_passing=False, include_time=False, atol=1e-5, rtol=1e-5, method='scipy_solver',
                        is_symb = True, device='cuda', exploit_graph = True):
    """Run a (symbolic or callable) model over every dataset in ``test_set``.

    Args:
        g_symb: Expression (when ``is_symb``) or ready-made callable used as
            the ``g`` part of the model.
        h_symb: Expression (when ``is_symb``) or ready-made callable used as
            the ``h`` part of the model.
        test_set: Iterable of datasets; each one is evaluated as one batch
            containing all of its samples.
        message_passing: Forwarded to ``get_model``.
        include_time: Forwarded to ``get_model``.
        atol: Forwarded to ``get_model``.
        rtol: Forwarded to ``get_model``.
        method: Forwarded to ``get_model`` as ``integration_method``.
        is_symb: When true, ``g_symb`` and ``h_symb`` are converted with
            ``make_callable`` first.
        device: Device the model is moved to; also forwarded to
            ``get_pred_batch``.
        exploit_graph: Forwarded to ``get_model``.

    Returns:
        ``(y_pred, y_true)``: two lists with one entry per dataset, holding
        whatever ``get_pred_batch`` returns for that dataset.
    """
    if is_symb:
        # Plain Python numbers have no ``free_symbols`` attribute, which
        # make_callable reads. Lift them to SymPy first; floats are handled
        # the same way as ints.
        if isinstance(g_symb, (int, float)):
            g_symb = sp.sympify(g_symb)

        if isinstance(h_symb, (int, float)):
            h_symb = sp.sympify(h_symb)

        g_symb = make_callable(g_symb)
        h_symb = make_callable(h_symb)

    symb = get_model(
        g=g_symb,
        h=h_symb,
        message_passing=message_passing,
        include_time=include_time,
        atol=atol,
        rtol=rtol,
        integration_method=method,
        exploit_graph=exploit_graph
    )
    symb = symb.to(torch.device(device))
    symb = symb.eval()

    # Identity collate: the loader yields the raw list of samples.
    def collate_fn(samples_list):
        return samples_list

    y_pred = []
    y_true = []

    with torch.no_grad():
        for ts in test_set:
            # One batch holding the whole dataset.
            # NOTE(review): shuffle=True only permutes samples inside that
            # single batch - confirm it is intentional.
            test_loader = DataLoader(ts, batch_size=len(ts), shuffle=True, collate_fn=collate_fn)

            y_pred_batch, y_true_batch = get_pred_batch(
                model=symb,
                loader=test_loader,
                pred_deriv=False,
                device=device,
                scaler=None
            )
            y_pred.append(y_pred_batch)
            y_true.append(y_true_batch)

    return y_pred, y_true


def get_list_test_errors(y_pred, y_true, criterion = mean_absolute_error):
    """Score each prediction/target pair with ``criterion``.

    Every tensor is detached, moved to CPU, converted to a float32 NumPy array
    and flattened before scoring.

    Args:
        y_pred: Sequence of predicted tensors.
        y_true: Sequence of target tensors, paired with ``y_pred`` by position.
        criterion: Callable invoked as ``criterion(target, prediction)``.

    Returns:
        List with one score per pair.
    """
    def as_flat_float32(tensor):
        return tensor.detach().cpu().numpy().astype(np.float32).flatten()

    # The target is passed first, matching the (y_true, y_pred) metric convention.
    return [
        criterion(as_flat_float32(target), as_flat_float32(pred))
        for pred, target in zip(y_pred, y_true)
    ]


def get_symb_test_error(g_symb, h_symb, test_set, message_passing=False, include_time=False, atol=1e-5, rtol=1e-5, method='scipy_solver',
                        is_symb = True, device = 'cuda', criterion = mean_absolute_error):
    """Predict on ``test_set`` and return one error value per dataset.

    Thin composition of ``get_test_pred`` (prediction) and
    ``get_list_test_errors`` (scoring).

    Args:
        g_symb, h_symb: Expressions or callables, see ``get_test_pred``.
        test_set: Iterable of datasets to evaluate.
        message_passing, include_time, atol, rtol, method, is_symb, device:
            Forwarded unchanged to ``get_test_pred``.
        criterion: Metric forwarded to ``get_list_test_errors``.

    Returns:
        List of scores, one per dataset in ``test_set``.
    """
    prediction_kwargs = dict(
        g_symb=g_symb,
        h_symb=h_symb,
        test_set=test_set,
        message_passing=message_passing,
        include_time=include_time,
        atol=atol,
        rtol=rtol,
        method=method,
        is_symb=is_symb,
        device=device,
    )
    predictions, targets = get_test_pred(**prediction_kwargs)
    return get_list_test_errors(predictions, targets, criterion=criterion)



def get_test_set(dynamics, device='cuda', input_range=(0, 1), t_span = (0, 1), **integration_kwargs):
    """Return the test set for ``dynamics``, generating and caching it if needed.

    The test set is a list with one entry per graph (Barabasi-Albert,
    Watts-Strogatz, Erdos-Renyi, each with a fixed seed); every entry is the
    list returned by ``integrate_test_set`` for that graph.

    The result is cached in ``./data_test_set/<dynamics>.pkl``. The cache key
    is the dynamics name only: on a cache hit ``device``, ``input_range``,
    ``t_span`` and ``integration_kwargs`` are ignored, so delete the file to
    regenerate with different settings.

    Args:
        dynamics: Name of the dynamics; also used as the cache file name.
        device: Forwarded to ``integrate_test_set``.
        input_range: Forwarded to ``integrate_test_set``.
        t_span: Forwarded to ``integrate_test_set``.
        **integration_kwargs: Forwarded to ``integrate_test_set``.

    Returns:
        The loaded or freshly generated test set.
    """
    save_dir = "./data_test_set"
    os.makedirs(save_dir, exist_ok=True)
    fpath = os.path.join(save_dir, f"{dynamics}.pkl")

    if os.path.exists(fpath):
        print(f"Loading test set from {fpath}")
        # NOTE: unpickling can execute arbitrary code; this is only safe
        # because the file is a local cache written by this function.
        with open(fpath, 'rb') as f:
            return pickle.load(f)

    # Cache miss: only now build the graphs, so a cache hit does no graph work.
    seeds = [12345, 67890, 111213]
    graphs = [
        nx.barabasi_albert_graph(70, 3, seed=seeds[0]),
        nx.watts_strogatz_graph(50, 6, 0.3, seed=seeds[1]),
        nx.erdos_renyi_graph(100, 0.05, seed=seeds[2])
    ]

    test_set = [
        integrate_test_set(
            graph=graph,
            dynamics=dynamics,
            seed=graph_seed,
            device=device,
            input_range=input_range,
            t_span=t_span,
            **integration_kwargs
        )
        for graph, graph_seed in zip(graphs, seeds)
    ]

    with open(fpath, 'wb') as f:
        pickle.dump(test_set, f)

    return test_set



def integrate_test_set(graph, dynamics, seed=12345, device='cuda', input_range = (0, 1), t_span = (0, 1), **integration_kwargs):
    """Simulate ``dynamics`` on ``graph`` and pack the result into one snapshot.

    The trajectory comes from ``integrate`` (1000 evaluation steps, NumPy
    generator seeded with ``seed``). Its first entry along the leading axis
    becomes the input ``x`` (with an extra leading dimension), the remaining
    entries become the target ``y``, and the full trajectory is kept as
    ``raw_data``.

    Args:
        graph: networkx graph; supplies the edge index and is passed to
            ``integrate``.
        dynamics: Dynamics name forwarded to ``integrate``.
        seed: Seed for the NumPy generator handed to ``integrate``.
        device: Device for the edge index; also forwarded to ``integrate``.
        input_range: Forwarded to ``integrate``.
        t_span: Forwarded to ``integrate``.
        **integration_kwargs: Forwarded to ``integrate``.

    Returns:
        A one-element list containing the ``Data`` snapshot.
    """
    target_device = torch.device(device)
    edge_index = from_networkx(graph).edge_index.to(target_device)
    generator = np.random.default_rng(seed=seed)

    trajectory, times = integrate(
        input_range=input_range,
        t_span = t_span,
        t_eval_steps=1000,
        dynamics=dynamics,
        device=device,
        graph=graph,
        rng = generator,
        **integration_kwargs
    )

    # presumably index 0 is the initial time step - TODO confirm against integrate()
    initial_state = trajectory[0].unsqueeze(0)
    later_states = trajectory[1:]

    return [
        Data(
            x = initial_state,
            y = later_states,
            edge_index=edge_index,
            edge_attr=None,
            t_span = times,
            raw_data = trajectory
        )
    ]


def build_model_from_file(model_path, message_passing, include_time, method='dopri5', adjoint=False, atol=1e-5, rtol=1e-5,
                          compute_mult=True, device='cuda'):
    """Rebuild a trained GKAN ODE model from its saved hyper-parameters and weights.

    Reads ``<model_path>/best_params.json`` for the architecture, builds the
    ``g`` and ``h`` KANs, wraps them in ``MPNN`` and ``GKAN_ODE``, and loads
    ``<model_path>/gkan/state_dict.pth`` into the result.

    Args:
        model_path: Directory holding ``best_params.json`` and
            ``gkan/state_dict.pth``.
        message_passing: Forwarded to ``MPNN``; also decides whether ``h``
            takes 2 inputs (true) or 1 (false).
        include_time: Forwarded to ``MPNN``; adds one more input to ``h``.
        method: Integration method forwarded to ``GKAN_ODE``.
        adjoint: Forwarded to ``GKAN_ODE``.
        atol: Forwarded to ``GKAN_ODE``.
        rtol: Forwarded to ``GKAN_ODE``.
        compute_mult: Forwarded to both KANs.
        device: Device for the KANs, the model and the loaded weights.

    Returns:
        The ``GKAN_ODE`` model with weights loaded, on ``device``.
    """
    with open(f"{model_path}/best_params.json", 'r') as f:
        hp = json.load(f)

    def make_kan(in_dim, net):
        # ``net`` is the hyper-parameter suffix: 'g_net' or 'h_net'.
        return KAN(
            layers_hidden=[in_dim, hp[f'hidden_dim_{net}'], 1],
            grid_size=hp[f'grid_size_{net}'],
            spline_order=hp[f'spline_order_{net}'],
            grid_range=[-hp[f'range_limit_{net}'], hp[f'range_limit_{net}']],
            mu_1=hp[f'mu_1_{net}'],
            mu_2=hp[f'mu_2_{net}'],
            device=device,
            compute_mult=compute_mult,
            store_act=True
        )

    # g always sees two inputs; h sees one or two, plus an optional time input.
    h_inputs = (2 if message_passing else 1) + (1 if include_time else 0)

    # Build g before h, as the original did.
    g_net = make_kan(2, 'g_net')
    h_net = make_kan(h_inputs, 'h_net')

    conv = MPNN(
        h_net=h_net,
        g_net=g_net,
        message_passing=message_passing,
        include_time=include_time
    )

    model = GKAN_ODE(
        conv=conv,
        model_path='./saved_models_optuna/tmp',
        lmbd_g=hp['lamb_g_net'],
        lmbd_h=hp['lamb_h_net'],
        integration_method=method,
        adjoint=adjoint,
        atol=atol,
        rtol=rtol
    )

    target_device = torch.device(device)
    model = model.to(target_device)

    # NOTE: weights_only=False performs a full unpickle; only load trusted checkpoints.
    weights = torch.load(f"{model_path}/gkan/state_dict.pth", weights_only=False, map_location=torch.device(device))
    model.load_state_dict(weights)

    return model


def valid_symb_model(
    config,
    model_path_gkan,
    device='cuda',
    atol=1e-5,
    rtol=1e-5,
    method='dopri5',
    black_box_fitting=True,
    depth_g=2,
    depth_h=2,
    sample_size=10000,
    grid_orig = None
):
    """Grid-search the symbolic-fitting settings on a validation graph and refit the best.

    A validation trajectory is generated on a fixed Barabasi-Albert graph
    (100 nodes, seed 9999). Every point of the search grid is fitted quietly,
    scored with ``get_symb_test_error`` on that trajectory, and the setting
    with the lowest score is fitted once more with ``verbose=True``; that
    final fit is what gets returned.

    Three modes:
      * ``black_box_fitting=True``: grid over ``(model_selection,
        n_iterations)`` for ``fit_black_box_from_kan``.
      * ``black_box_fitting=False`` and ``grid_orig is None``: grid over
        ``(sort_by, theta, cut_threshold)`` for ``fit_model``.
      * ``black_box_fitting=False`` and ``grid_orig`` given: ``grid_orig`` is
        a triple of value lists ``(grid_range, theta, weight_simple)`` and
        ``fit_model`` is called with ``fit_orig=True``.

    Args:
        config: Dynamics config; ``name``, ``input_range`` and
            ``integration_kwargs`` are read.
        model_path_gkan: Model path forwarded to the fitting routine.
        device: Device for data generation, fitting and evaluation.
        atol: Forwarded to the evaluation.
        rtol: Forwarded to the evaluation.
        method: Integration method forwarded to the evaluation.
        black_box_fitting: Selects black-box versus spline-wise fitting.
        depth_g: Forwarded to the fitting routine.
        depth_h: Forwarded to the fitting routine.
        sample_size: Forwarded to the fitting routine.
        grid_orig: Optional custom grid, see above.

    Returns:
        ``(gkan_symb, symb_g, symb_h, exec_time)`` from the final refit.

    Raises:
        ValueError: If the search grid is empty.
    """
    seed = 9999
    graph = nx.barabasi_albert_graph(100, 3, seed=seed)
    use_orig_fit = grid_orig is not None

    # Prepare validation set
    valid_set = integrate_test_set(
        graph=graph,
        dynamics=config['name'],
        seed=seed,
        device=device,
        input_range=config['input_range'],
        t_span=(0, 1),
        **config['integration_kwargs']
    )

    def evaluate_model(g_symb, h_symb, is_symb=True):
        """Return the validation error of one candidate (g, h) pair."""
        errs = get_symb_test_error(
            g_symb=g_symb,
            h_symb=h_symb,
            test_set=[valid_set],
            message_passing=False,
            include_time=False,
            method=method,
            atol=atol,
            rtol=rtol,
            is_symb=is_symb,
            device=device
        )
        return errs[0]

    def run_fit(params, verbose):
        """Fit one grid point; shared by the search and the final refit so both use identical settings."""
        if black_box_fitting:
            model_selection, n_iterations = params
            return fit_black_box_from_kan(
                model_path=model_path_gkan,
                depth_g=depth_g,
                depth_h=depth_h,
                device=device,
                theta=-np.inf,
                pysr_model=lambda: get_pysr_model(
                    model_selection=model_selection,
                    n_iterations=n_iterations,
                ),
                sample_size=sample_size,
                message_passing=False,
                verbose=verbose,
                # Passed during the search as well, so the search scores the
                # same configuration that is refitted at the end.
                include_time=False
            )

        if use_orig_fit:
            grid_range, theta, weight_simple = params
            return fit_model(
                model_path=model_path_gkan,
                depth_g=depth_g,
                depth_h=depth_h,
                theta=theta,
                message_passing=False,
                include_time=False,
                sample_size=sample_size,
                verbose=verbose,
                fit_orig=True,
                a_range = grid_range,
                b_range = grid_range,
                weight_simple = weight_simple
            )

        sort_by, theta, cut_threshold = params
        return fit_model(
            model_path=model_path_gkan,
            depth_g=depth_g,
            depth_h=depth_h,
            theta=theta,
            message_passing=False,
            include_time=False,
            sample_size=sample_size,
            sort_by=sort_by,
            verbose=verbose,
            cut_threshold=cut_threshold
        )

    def announce(params):
        """Log which grid point is about to be fitted."""
        if black_box_fitting:
            print(f"Fitting black-box model with {params[0]} and {params[1]} iterations")
        elif use_orig_fit:
            print(f"Fitting symbolic model with {params[0]}, theta {params[1]} and ws {params[2]}")
        else:
            print(f"Fitting symbolic model with {params[0]}, theta {params[1]} and cutting threshold {params[2]}")

    def as_record(params, loss):
        """Describe one trial as a dict (used for selection and logging)."""
        if black_box_fitting:
            return {'model_selection': params[0], 'param': params[1], 'valid_loss': loss}
        if use_orig_fit:
            return {'grid_range': params[0], 'theta': params[1], "ws": params[2], 'valid_loss': loss}
        return {'sort_by': params[0], 'theta': params[1], 'cut_threshold': params[2], 'valid_loss': loss}

    if black_box_fitting:
        param_grid = (["score", "accuracy"], [50, 100, 200])
    elif use_orig_fit:
        param_grid = grid_orig
    else:
        param_grid = (
            ["score", "log_loss"],
            [0.01, 0.05, 0.1],
            [0.1, 0.01, 0.001]
        )

    search_space = list(itertools.product(*param_grid))
    if not search_space:
        raise ValueError("Empty hyper-parameter grid: nothing to validate")

    trials = []
    for params in search_space:
        announce(params)
        _, g_symb, h_symb, _ = run_fit(params, verbose=False)
        try:
            loss = evaluate_model(g_symb, h_symb)
        except AssertionError:
            # An evaluation that trips an assertion gets a large sentinel loss
            # so that it is never preferred over a working candidate.
            loss = 1e8
        trials.append((params, as_record(params, loss)))

    # Select best performing configuration (the first one wins on ties).
    best_params, best = min(trials, key=lambda trial: trial[1]['valid_loss'])

    # Final refit with best config
    print(f"Refitting best model with {best}")
    gkan_symb, symb_g, symb_h, exec_time = run_fit(best_params, verbose=True)

    return gkan_symb, symb_g, symb_h, exec_time

def post_process_gkan(
    config,
    model_path,
    test_set,
    device='cuda',
    sample_size=10000,
    message_passing=False,
    include_time=False,
    atol=1e-5,
    rtol=1e-5,
    method='dopri5',
    adjoint=True,
    eval_model=True,
    res_file_name = 'post_process_res_more_metrics.json',
    compute_mult = True,
    grid_orig = None,
    skip_bb = False,
):
    """Extract symbolic formulas from a trained GKAN, score them and save a JSON report.

    Steps, in order:
      1. unless ``skip_bb``: black-box symbolic fit via ``valid_symb_model``;
      2. spline-wise symbolic fit via ``valid_symb_model``;
      3. if ``eval_model``: rebuild the trained network and score it directly.

    Each candidate is scored on ``test_set`` with MAE, MSE and RMSE; for every
    metric the mean, variance and standard deviation over the datasets are
    stored. The report is written to ``<model_path>/<res_file_name>``.

    Args:
        config: Dynamics config forwarded to ``valid_symb_model``.
        model_path: Directory with ``best_params.json`` and the ``gkan``
            sub-directory; also where the report is written.
        test_set: Datasets used for scoring.
        device: Device for fitting and evaluation.
        sample_size: Forwarded to ``valid_symb_model``.
        message_passing: Used when scoring and when rebuilding the network.
        include_time: Used when scoring and when rebuilding the network.
        atol: Solver tolerance used throughout.
        rtol: Solver tolerance used throughout.
        method: Integration method used throughout.
        adjoint: Forwarded to ``build_model_from_file``.
        eval_model: Whether to score the raw trained network.
        res_file_name: File name of the JSON report.
        compute_mult: Forwarded to ``build_model_from_file``.
        grid_orig: Optional custom grid for the spline-wise search.
        skip_bb: Skip the black-box fit when true.
    """
    results_dict = {}

    metric_fns = (
        ("MAE", mean_absolute_error),
        ("MSE", mean_squared_error),
        ("RMSE", root_mean_squared_error),
    )

    def get_avg_test_error(g_symb, h_symb, is_symb=True):
        """Return {metric: (mean, var, std)} over the test datasets; all-inf when evaluation asserts."""
        try:
            y_pred_test, y_true_test = get_test_pred(
                g_symb=g_symb,
                h_symb=h_symb,
                test_set=test_set,
                message_passing=message_passing,
                include_time=include_time,
                atol=atol,
                rtol=rtol,
                method=method,
                is_symb=is_symb,
                device=device
            )

            res = {}
            for label, metric in metric_fns:
                per_dataset = get_list_test_errors(y_pred_test, y_true_test, criterion=metric)
                res[label] = (np.mean(per_dataset), np.var(per_dataset), np.std(per_dataset))
            return res
        except AssertionError:
            print("Evaluation failed!")
            failed = (np.inf, np.inf, np.inf)
            return {"MAE": failed, "MSE": failed, "RMSE": failed}

    def store_metrics(prefix, res):
        """Copy one candidate's metric summary into the report under ``prefix``."""
        # The MAE variance/std keys carry no metric infix (``_test_Var`` / ``_test_Std``).
        for label, infix in (("MAE", ""), ("MSE", "_MSE"), ("RMSE", "_RMSE")):
            mean, var, std = res[label]
            results_dict[f"{prefix}_test_{label}"] = mean
            results_dict[f"{prefix}_test{infix}_Var"] = var
            results_dict[f"{prefix}_test{infix}_Std"] = std

    with open(f"{model_path}/best_params.json", 'r') as f:
        best_hyperparams = json.load(f)

    # Depth is the number of hidden layers (default 1) plus one.
    depth_g = best_hyperparams.get('n_hidden_layers_g_net', 1) + 1
    depth_h = best_hyperparams.get('n_hidden_layers_h_net', 1) + 1

    # Arguments shared by both symbolic searches.
    search_kwargs = dict(
        config=config,
        model_path_gkan=f"{model_path}/gkan",
        device=device,
        atol=atol,
        rtol=rtol,
        method=method,
        depth_g=depth_g,
        depth_h=depth_h,
        sample_size=sample_size,
    )

    if not skip_bb:
        print("Black-Box fitting \n")
        bb_symb, bb_g_symb, bb_h_symb, exec_time = valid_symb_model(
            black_box_fitting=True,
            **search_kwargs
        )

        print(latex(quantise(bb_symb)))
        res = get_avg_test_error(g_symb=bb_g_symb, h_symb=bb_h_symb)

        results_dict["black_box_symb_quant"] = str(quantise(bb_symb))
        results_dict["black_box_symb"] = str(bb_symb)
        store_metrics("black_box_symb", res)
        results_dict["black_box_exec_time"] = exec_time

    print("Spline-wise fitting\n")
    spline_symb, spl_g_symb, spl_h_symb, exec_time = valid_symb_model(
        black_box_fitting=False,
        grid_orig=grid_orig,
        **search_kwargs
    )
    print(latex(quantise(spline_symb)))
    res_sw = get_avg_test_error(g_symb=spl_g_symb, h_symb=spl_h_symb)

    results_dict["spline_wise_symb_quant"] = str(quantise(spline_symb))
    results_dict["spline_wise_symb"] = str(spline_symb)
    results_dict["spline_wise_exec_time"] = exec_time
    results_dict["spline_wise_g_symb"] = str(spl_g_symb)
    results_dict["spline_wise_h_symb"] = str(spl_h_symb)
    store_metrics("spline_wise_symb", res_sw)

    if eval_model:
        print("Evaluate raw model\n")
        # Loading best model
        best_model = build_model_from_file(
            model_path=model_path,
            message_passing=message_passing,
            include_time=include_time,
            method=method,
            adjoint=adjoint,
            atol=atol,
            rtol=rtol,
            device=device,
            compute_mult=compute_mult
        )

        tot_params = sum(p.numel() for p in best_model.parameters() if p.requires_grad)
        print(f"Number of model's parameters: {tot_params}\n")
        results_dict["Number of params"] = tot_params

        best_model = best_model.eval()
        # The trained sub-networks are already callables, hence is_symb=False.
        res_model = get_avg_test_error(
            g_symb=best_model.conv.model.g_net,
            h_symb=best_model.conv.model.h_net,
            is_symb=False
        )
        store_metrics("model", res_model)

    with open(f"{model_path}/{res_file_name}", 'w') as file:
        json.dump(results_dict, file, indent=4)



def plot_predictions(y_true, y_pred, node_index = 0, save_path = None, show=True, title = None):
    """Plot true versus predicted values of one node over time.

    Args:
        y_true: Array-like indexed as ``[:, node_index, :]``.
        y_pred: Array-like indexed as ``[:, node_index, :]``.
        node_index: Node whose series is plotted.
        save_path: Directory for the PNG. It is created if missing; ``None``
            (the default) disables saving.
        show: Whether to display the figure.
        title: Figure title and PNG base name; defaults to
            ``'y_true vs y_pred for Node <node_index>'``.
    """
    title_ = f'y_true vs y_pred for Node {node_index}' if title is None else title

    # Only touch the filesystem when saving was requested: os.makedirs(None)
    # raises TypeError, which made the default save_path=None unusable.
    if save_path is not None:
        os.makedirs(save_path, exist_ok=True)

    plt.figure(figsize=(16, 8))
    plt.plot(y_true[:, node_index, :], label='y_true', marker='o')
    plt.plot(y_pred[:, node_index, :], label='y_pred', marker='o')
    plt.xlabel('Time step')
    plt.ylabel('Value')
    plt.title(title_)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()

    # Save before showing: once show() returns, the figure may already have
    # been closed by the backend, and saving afterwards can write an empty image.
    if save_path is not None:
        plt.savefig(f"{save_path}/{title_}.png")
    if show:
        plt.show()

    plt.clf()
    plt.close()


if __name__ == '__main__':
    set_pytorch_seed(0)

    def _load_case(config_path):
        """Load one dynamics config together with its (cached) test set."""
        cfg = load_config(config_path)
        data = get_test_set(
            dynamics=cfg['name'],
            device='cuda',
            input_range=cfg['input_range'],
            **cfg['integration_kwargs']
        )
        return cfg, data

    def _report_lower_bound(g_true, h_true, data):
        """Score the hand-written g/h callables on ``data`` and print mean/var/std of the loss."""
        losses = get_symb_test_error(
            g_symb=g_true,
            h_symb=h_true,
            test_set=data,
            message_passing=True,
            include_time=False,
            is_symb=False
        )
        loss_mean = np.mean(losses)
        loss_var = np.var(losses)
        loss_std = np.std(losses)

        print(f"Mean Test loss of symbolic formula: {loss_mean}")
        print(f"Var Test loss of symbolic formula: {loss_var}")
        print(f"Std Test loss of symbolic formula: {loss_std}")

    def _post_process_family(cfg, data, clean_path, noisy_paths):
        """Post-process the clean model (IC=1) first, then each noisy (SNR) model."""
        post_process_gkan(
            config=cfg,
            model_path=clean_path,
            test_set=data,
            device='cuda',
            sample_size=10000,
            message_passing=False,
            include_time=False,
            atol=1e-5,
            rtol=1e-5,
            method="dopri5"
        )

        for noisy_path in noisy_paths:
            print(noisy_path)
            post_process_gkan(
                config=cfg,
                model_path=noisy_path,
                test_set=data,
                device='cuda',
                sample_size=10000,
                message_passing=False,
                include_time=False,
                atol=1e-5,
                rtol=1e-5,
                method="dopri5",
                eval_model=True,
                compute_mult=True
            )

    # ------------------------------------------------------------------
    # LB losses: error of the hand-written formulas on each test set
    # ------------------------------------------------------------------

    # Kuramoto
    kur_config, KUR = _load_case("./configs/config_pred_deriv/config_ic1/config_kuramoto.yml")
    _report_lower_bound(
        g_true=lambda x: torch.sin(x[:, 1] - x[:, 0]).unsqueeze(-1),
        h_true=lambda x: 2.0 + 0.5 * x[:, 1].unsqueeze(-1),
        data=KUR
    )

    # Epidemics
    epid_config, EPID = _load_case("./configs/config_pred_deriv/config_ic1/config_epidemics.yml")
    _report_lower_bound(
        g_true=lambda x: 0.5*x[:, 1].unsqueeze(-1) * (1 - x[:, 0].unsqueeze(-1)),
        h_true=lambda x: x[:, 1].unsqueeze(1) - 0.5 * x[:, 0].unsqueeze(-1),
        data=EPID
    )

    # Population
    pop_config, POP = _load_case("./configs/config_pred_deriv/config_ic1/config_population.yml")
    _report_lower_bound(
        g_true=lambda x: 0.2*torch.pow(x[:, 1].unsqueeze(-1), 3),
        h_true=lambda x: -0.5 * x[:, 0].unsqueeze(-1) + x[:, 1].unsqueeze(1),
        data=POP
    )

    # Biochemical
    bio_config, BIO = _load_case("./configs/config_pred_deriv/config_ic1/config_biochemical.yml")
    _report_lower_bound(
        g_true=lambda x: (-0.5*x[:, 1] * x[:, 0]).unsqueeze(-1),
        h_true=lambda x: (1.0 - 0.5 * x[:, 0]).unsqueeze(-1)  + x[:, 1].unsqueeze(-1),
        data=BIO
    )

    # ------------------------------------------------------------------
    # Symb Reg: clean model (IC=1) followed by the SNR variants
    # ------------------------------------------------------------------

    # Biochemical
    _post_process_family(
        cfg=bio_config,
        data=BIO,
        clean_path="./saved_models_optuna/model-biochemical-gkan/biochemical_gkan_no_fp_true/0",
        noisy_paths=[
            "./saved_models_optuna/model-biochemical-gkan/biochemical_gkan_den_true_70db/0",
            "./saved_models_optuna/model-biochemical-gkan/biochemical_gkan_den_true_50db/0",
            "./saved_models_optuna/model-biochemical-gkan/biochemical_gkan_den_true_20db/0"
        ]
    )

    # Kuramoto
    _post_process_family(
        cfg=kur_config,
        data=KUR,
        clean_path="./saved_models_optuna/model-kuramoto-gkan/kuramoto_gkan_no_fp_true/0",
        noisy_paths=[
            "./saved_models_optuna/model-kuramoto-gkan/kuramoto_gkan_den_true_70db/0",
            "./saved_models_optuna/model-kuramoto-gkan/kuramoto_gkan_den_true_50db/0",
            "./saved_models_optuna/model-kuramoto-gkan/kuramoto_gkan_den_true_20db/0",
        ]
    )

    # Epidemics
    _post_process_family(
        cfg=epid_config,
        data=EPID,
        clean_path="./saved_models_optuna/model-epidemics-gkan/epidemics_gkan_no_fp_true_2/0",
        noisy_paths=[
            "./saved_models_optuna/model-epidemics-gkan/epidemics_gkan_den_true_70db_2/0",
            "./saved_models_optuna/model-epidemics-gkan/epidemics_gkan_den_true_50db_2/0",
            "./saved_models_optuna/model-epidemics-gkan/epidemics_gkan_den_true_20db_2/0",
        ]
    )

    # Population
    _post_process_family(
        cfg=pop_config,
        data=POP,
        clean_path="./saved_models_optuna/model-population-gkan/population_gkan_no_fp_true/0",
        noisy_paths=[
            "./saved_models_optuna/model-population-gkan/population_gkan_den_true_70db/0",
            "./saved_models_optuna/model-population-gkan/population_gkan_den_true_50db/0",
            "./saved_models_optuna/model-population-gkan/population_gkan_den_true_20db/0",
        ]
    )
