# -*- coding: utf-8 -*-
"""post_processing_mpnn.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/18MmEO1-fqM1P5uxZ-3fkh5ch7s6bYqua

## Importing
"""

# Commented out IPython magic to ensure Python compatibility.
# %load_ext autoreload
# %autoreload 2

import os
# cuBLAS only behaves deterministically when this workspace config is present
# before its first call, so it is set here, ahead of the torch-based imports
# below. It pairs with torch.use_deterministic_algorithms in set_pytorch_seed.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

from utils.utils import *
import optuna
from optuna.storages import JournalStorage
from optuna.storages.journal import JournalFileBackend
from experiments.experiments_gkan import ExperimentsGKAN
from experiments.experiments_mpnn import ExperimentsMPNN
import sympytorch

import warnings
# Silence all warnings for the whole process.
# NOTE(review): this also hides the warnings that
# torch.use_deterministic_algorithms(True, warn_only=True) (see set_pytorch_seed)
# emits for non-deterministic ops, so reproducibility problems pass silently.
# Consider a narrower filter.
warnings.filterwarnings("ignore")

import random

def set_pytorch_seed(seed=42):
    """Seed every RNG the experiments use and request deterministic kernels.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU plus
    all CUDA devices), disables cuDNN autotuning, and turns on PyTorch's
    deterministic-algorithm mode. With ``warn_only=True``, ops lacking a
    deterministic implementation emit a warning instead of raising.

    Args:
        seed: integer seed shared by all generators (default 42).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False  # autotuning may pick different kernels per run
    cudnn.deterministic = True
    torch.use_deterministic_algorithms(True, warn_only=True)

"""## Utils"""

from models.utils.MPNN import MPNN
from models.baseline.MPNN_ODE import MPNN_ODE
from train_and_eval import eval_model
from datasets.SyntheticData import SyntheticData
from sympy import symbols, sin, summation, simplify
import networkx as nx
from torch_geometric.utils import from_networkx
from utils.utils import integrate
from torch_geometric.data import Data
from experiments.experiments_mpnn import activations
from models.utils.MLP import MLP
from models.baseline.LLC import LLC_ODE
from models.baseline.LLC_Conv import Q_inter, Q_self
from experiments.experiments_mpnn import activations

from sympy import latex
from torch.utils.data import DataLoader

from sympy import latex
from torch.utils.data import DataLoader
from post_processing import get_model, make_callable, get_symb_test_error, get_test_set, integrate_test_set, plot_predictions, get_list_test_errors, get_test_pred
from sklearn.metrics import mean_absolute_error, mean_squared_error, root_mean_squared_error


def build_model_from_file_mpnn(model_path, message_passing=False, include_time=False, method='dopri5', adjoint=True, atol=1e-5, rtol=1e-5, device='cuda'):
    """Rebuild a trained ``MPNN_ODE`` from an Optuna run directory and load its weights.

    Reads the tuned hyper-parameters from ``<model_path>/best_params.json``.
    It then re-creates the ``g_net`` and ``h_net`` MLPs with the same layer
    sizes, activations and dropout, wraps them in ``MPNN`` + ``MPNN_ODE``,
    and loads ``<model_path>/mpnn/state_dict.pth``.

    Args:
        model_path: run directory containing ``best_params.json`` and ``mpnn/``.
        message_passing: forwarded to ``MPNN``. When True, ``h_net`` gets one
            extra input.
        include_time: forwarded to ``MPNN``. When True, ``h_net`` gets one
            extra input.
        method: ODE integration method forwarded to ``MPNN_ODE``.
        adjoint: ODE adjoint flag forwarded to ``MPNN_ODE``.
        atol: ODE absolute tolerance forwarded to ``MPNN_ODE``.
        rtol: ODE relative tolerance forwarded to ``MPNN_ODE``.
        device: device the model is moved to and the checkpoint tensors are
            mapped to. Defaults to ``'cuda'``, the previously hard-coded value,
            so existing callers are unaffected.

    Returns:
        The ``MPNN_ODE`` model on *device* with the stored weights loaded.
        Callers in this file switch it to ``eval()`` before inference.
    """
    best_params_file = f"{model_path}/best_params.json"
    best_state_path = f"{model_path}/mpnn/state_dict.pth"
    with open(best_params_file, 'r', encoding='utf-8') as f:
        best_hyperparams = json.load(f)

    in_dim = 1  # per-node feature size used for inputs/outputs below

    # g_net: 2 * in_dim inputs -> hidden layers -> in_dim output.
    g_hidden = [best_hyperparams["hidden_dims_g_net"]] * best_hyperparams["n_hidden_layers_g_net"]
    g_net = MLP(
        hidden_layers=[2 * in_dim] + g_hidden + [in_dim],
        af=activations[best_hyperparams['af_g_net']],
        dropout_rate=best_hyperparams['drop_p_g_net'],
    )

    # h_net: one input, plus one more each for message passing and time.
    in_dim_h = (2 if message_passing else 1) + (1 if include_time else 0)
    h_hidden = [best_hyperparams["hidden_dims_h_net"]] * best_hyperparams["n_hidden_layers_h_net"]
    h_net = MLP(
        hidden_layers=[in_dim_h] + h_hidden + [in_dim],
        af=activations[best_hyperparams['af_h_net']],
        dropout_rate=best_hyperparams['drop_p_h_net'],
    )

    mpnn = MPNN(
        h_net=h_net,
        g_net=g_net,
        message_passing=message_passing,
        include_time=include_time
    )

    model = MPNN_ODE(
        conv=mpnn,
        model_path='./saved_models_optuna/tmp',
        integration_method=method,
        adjoint=adjoint,
        atol=atol,
        rtol=rtol
    )

    target = torch.device(device)
    model = model.to(target)
    # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
    model.load_state_dict(torch.load(best_state_path, weights_only=False, map_location=target))

    return model


def build_model_from_file_llc(model_path, message_passing=False, include_time=False, method='dopri5', adjoint=True, atol=1e-5, rtol=1e-5, device='cuda'):
    """Rebuild a trained ``LLC_ODE`` from an Optuna run directory and load its weights.

    Reads the tuned hyper-parameters from ``<model_path>/best_params.json`` and
    builds the pieces from them:

    - a ``Q_inter`` g_net made of three sub-networks ``g0``, ``g1``, ``g2``;
    - a ``Q_self`` h_net.

    It wraps them in ``MPNN`` + ``LLC_ODE`` and loads
    ``<model_path>/llc/state_dict.pth``.

    Args:
        model_path: run directory containing ``best_params.json`` and ``llc/``.
        message_passing: forwarded to ``MPNN``. When True, ``h_net`` gets
            ``in_dim`` extra inputs.
        include_time: forwarded to ``MPNN``. When True, ``h_net`` gets one
            extra input.
        method: ODE integration method forwarded to ``LLC_ODE``.
        adjoint: ODE adjoint flag forwarded to ``LLC_ODE``.
        atol: ODE absolute tolerance forwarded to ``LLC_ODE``.
        rtol: ODE relative tolerance forwarded to ``LLC_ODE``.
        device: device the model is moved to and the checkpoint tensors are
            mapped to. Defaults to ``'cuda'``, the previously hard-coded value.

    Returns:
        The ``LLC_ODE`` model on *device* with the stored weights loaded.
    """
    best_params_file = f"{model_path}/best_params.json"
    best_state_path = f"{model_path}/llc/state_dict.pth"

    with open(best_params_file, 'r', encoding='utf-8') as f:
        best_hyperparams = json.load(f)

    in_dim = 1
    time_dim = 1 if include_time else 0

    # === g_net config (for Q_inter) ===
    def build_q_layer_config(prefix):
        """Return the Q_inter kwargs (layer sizes, activation, dropout) for sub-net *prefix*."""
        n_layers = best_hyperparams[f'n_hidden_layers_{prefix}']
        hidden_dim = best_hyperparams[f'hidden_dims_{prefix}']
        activation = activations[best_hyperparams[f'af_{prefix}']]
        dropout = best_hyperparams[f'drop_p_{prefix}']

        # g0 takes 2 * in_dim inputs; g1 and g2 take in_dim.
        input_dim = 2 * in_dim if prefix == "g0" else in_dim
        hidden_layers = [hidden_dim] * n_layers
        return {
            f'hidden_layers_{prefix}': [input_dim] + hidden_layers + [in_dim],
            f'af_{prefix}': activation,
            f'dr_{prefix}': dropout,
        }

    g0_config = build_q_layer_config("g0")
    g1_config = build_q_layer_config("g1")
    g2_config = build_q_layer_config("g2")
    g_net = Q_inter(**{**g0_config, **g1_config, **g2_config})

    # === h_net config (for Q_self) ===
    in_dim_h = (2 * in_dim if message_passing else in_dim) + time_dim
    hidden_layers_h = [best_hyperparams["hidden_dims_h_net"]] * best_hyperparams["n_hidden_layers_h_net"]
    h_net = Q_self(
        hidden_layers=[in_dim_h] + hidden_layers_h + [in_dim],
        af=activations[best_hyperparams['af_h_net']],
        dropout_rate=best_hyperparams['drop_p_h_net']
    )

    # === Full MPNN and ODE wrapper ===
    mpnn = MPNN(
        h_net=h_net,
        g_net=g_net,
        message_passing=message_passing,
        include_time=include_time
    )

    model = LLC_ODE(
        conv=mpnn,
        model_path='./saved_models_optuna/tmp',
        adjoint=adjoint,
        integration_method=method,
        atol=atol,
        rtol=rtol
    )

    target = torch.device(device)
    model = model.to(target)
    # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
    model.load_state_dict(torch.load(best_state_path, weights_only=False, map_location=target))

    return model


def _pick_best_config(valid_losses):
    """Return the entry of *valid_losses* with the smallest finite ``'valid_loss'``.

    Non-finite losses (NaN, +/-inf) rank last. A plain
    ``min(..., key=lambda x: x['valid_loss'])`` would keep a NaN that happens
    to be first, because every ``<`` comparison against NaN is False. Ties keep
    the earliest entry. If every loss is non-finite, the first entry is returned.
    """
    import math

    def _rank(entry):
        loss = float(entry['valid_loss'])
        return (0, loss) if math.isfinite(loss) else (1, 0.0)

    return min(valid_losses, key=_rank)


def valid_symb_model(
    config,
    model_path_gkan,
    device='cuda',
    atol=1e-5,
    rtol=1e-5,
    method='dopri5',
    sample_size=10000
):
    """Grid-search PySR settings on a validation set, then refit with the best one.

    For every combination of PySR ``model_selection`` in ("score", "accuracy")
    and ``n_iterations`` in (50, 100, 200), the function:

    1. fits symbolic ``g``/``h`` formulas to the trained network at
       *model_path_gkan* via ``fit_mpnn``;
    2. scores them on a validation trajectory. The trajectory is integrated for
       ``t_span=(0, 1)`` on a fixed 100-node Barabási–Albert graph (seed 9999).

    Evaluations that raise ``AssertionError`` score ``1e8``. NaN/inf scores
    are never selected while a finite one exists.

    Args:
        config: dynamics config with ``'name'``, ``'input_range'`` and
            ``'integration_kwargs'`` keys.
        model_path_gkan: directory of the trained network passed to ``fit_mpnn``.
        device: device used for integration and fitting.
        atol: ODE absolute tolerance used when scoring.
        rtol: ODE relative tolerance used when scoring.
        method: ODE integration method used when scoring.
        sample_size: number of samples ``fit_mpnn`` draws for the regression.

    Returns:
        The 4-tuple from the final ``fit_mpnn`` call, unpacked here as
        ``(symbolic_model, g_symb, h_symb, exec_time)``.
    """
    seed = 9999
    graph = nx.barabasi_albert_graph(100, 3, seed=seed)

    # Prepare validation/test set
    valid_set = integrate_test_set(
        graph=graph,
        dynamics=config['name'],
        seed=seed,
        device=device,
        input_range=config['input_range'],
        t_span=(0, 1),
        **config['integration_kwargs']
    )

    def evaluate_model(g_symb, h_symb, is_symb=True):
        """Validation error of the given g/h pair on ``valid_set``."""
        errs = get_symb_test_error(
            g_symb=g_symb,
            h_symb=h_symb,
            test_set=[valid_set],
            message_passing=False,
            include_time=False,
            method=method,
            atol=atol,
            rtol=rtol,
            is_symb=is_symb
        )
        return errs[0]

    def fit_single_model(param1, param2):
        """Fit g/h formulas with PySR model_selection=*param1*, n_iterations=*param2*."""
        print(f"Fitting black-box model with {param1} and {param2} iterations")
        pysr_model = lambda: get_pysr_model(
            model_selection=param1,
            n_iterations=param2,
            # parallelism="serial",
            # random_state = seed,
            # deterministic = True
        )
        _, g_symb, h_symb, _ = fit_mpnn(
            device=device,
            model_path=model_path_gkan,
            pysr_model=pysr_model,
            sample_size=sample_size,
            message_passing=False,
            verbose=False
        )
        return g_symb, h_symb

    param_grid = (["score", "accuracy"], [50, 100, 200])
    search_space = [(mod, val) for mod in param_grid[0] for val in param_grid[1]]
    valid_losses = []

    for mod, val in search_space:
        g_symb, h_symb = fit_single_model(mod, val)
        try:
            loss = evaluate_model(g_symb, h_symb, is_symb=True)
        except AssertionError:
            loss = 1e8  # integration failed: treat as a very poor fit
        valid_losses.append({'model_selection': mod, 'param': val, 'valid_loss': loss})

    best = _pick_best_config(valid_losses)

    print(f"Refitting best model with {best}")

    # NOTE(review): PySR is re-run from scratch here with random_state/deterministic
    # commented out, so the refitted formula is not guaranteed to match the one
    # that won validation.
    gkan_symb, symb_g, symb_h, exec_time = fit_mpnn(
        model_path=model_path_gkan,
        device=device,
        pysr_model=lambda: get_pysr_model(
            model_selection=best['model_selection'],
            n_iterations=best['param'],
            # parallelism="serial",
            # random_state = seed,
            # deterministic = True
        ),
        sample_size=sample_size,
        message_passing=False,
        verbose=True,
        include_time=False
    )

    return gkan_symb, symb_g, symb_h, exec_time


def _flatten_metrics(prefix, res):
    """Flatten ``{"MAE"|"MSE"|"RMSE": (mean, var, std)}`` into result-file keys.

    Reproduces the existing JSON schema. MAE is stored as ``<prefix>_MAE``,
    ``<prefix>_Var`` and ``<prefix>_Std``. MSE and RMSE also carry the metric
    name in their Var/Std keys (e.g. ``<prefix>_MSE_Var``). Keys are inserted
    in MAE, MSE, RMSE order.
    """
    flat = {}
    for metric in ("MAE", "MSE", "RMSE"):
        mean, var, std = res[metric]
        tag = "" if metric == "MAE" else f"_{metric}"
        flat[f"{prefix}_{metric}"] = mean
        flat[f"{prefix}{tag}_Var"] = var
        flat[f"{prefix}{tag}_Std"] = std
    return flat


def post_process_mpnn(
    config,
    model_path,
    test_set,
    device='cuda',
    sample_size=10000,
    message_passing=False,
    include_time=False,
    atol=1e-5,
    rtol=1e-5,
    method='dopri5',
    adjoint=True,
    eval_model=True,
    model_type="MPNN",
    res_file_name="post_process_res_more_metrics.json"
):
    """Fit a symbolic surrogate of a trained model and write test metrics to JSON.

    Steps:

    1. ``valid_symb_model`` selects PySR settings and fits ``g``/``h`` formulas
       from ``<model_path>/mpnn`` (MPNN) or ``<model_path>/llc`` (LLC).
    2. The formulas are evaluated on *test_set*. MAE, MSE and RMSE are each
       summarised as mean, variance and std over the list returned by
       ``get_list_test_errors``.
    3. If *eval_model* is True, the trained network is rebuilt from disk and
       evaluated the same way, and its trainable-parameter count is recorded.
    4. All results are written to ``<model_path>/<res_file_name>``.

    An evaluation that raises ``AssertionError`` is reported as inf for every
    statistic.

    Note: inside this function the ``eval_model`` flag shadows the
    ``eval_model`` function imported from ``train_and_eval``. The parameter
    name is kept for API compatibility.

    Raises:
        NotImplementedError: if *model_type* is neither "MPNN" nor "LLC".
            This is checked before any (expensive) fitting starts.
    """
    # Fail fast: the PySR grid search below is expensive, so reject unknown
    # model types before starting it.
    sub_dirs = {"MPNN": "mpnn", "LLC": "llc"}
    if model_type not in sub_dirs:
        raise NotImplementedError("Not supported model!")

    results_dict = {}
    criteria = {
        "MAE": mean_absolute_error,
        "MSE": mean_squared_error,
        "RMSE": root_mean_squared_error,
    }

    def get_avg_test_error(g_symb, h_symb, is_symb=True):
        """Return {metric: (mean, var, std)} over *test_set* for the given g/h pair."""
        try:
            y_pred_test, y_true_test = get_test_pred(
                g_symb=g_symb,
                h_symb=h_symb,
                test_set=test_set,
                message_passing=message_passing,
                include_time=include_time,
                atol=atol,
                rtol=rtol,
                method=method,
                is_symb=is_symb,
                device=device
            )
            res = {}
            for name, criterion in criteria.items():
                errors = get_list_test_errors(y_pred_test, y_true_test, criterion=criterion)
                res[name] = (np.mean(errors), np.var(errors), np.std(errors))
            return res
        except AssertionError:
            print("Evaluation failed!")
            return {name: (np.inf, np.inf, np.inf) for name in criteria}

    print("Black-Box fitting \n")
    bb_symb, bb_g_symb, bb_h_symb, exec_time = valid_symb_model(
        config=config,
        model_path_gkan=f"{model_path}/{sub_dirs[model_type]}",
        device=device,
        atol=atol,
        rtol=rtol,
        method=method,
        sample_size=sample_size
    )

    bb_symb_quant = quantise(bb_symb)
    print(latex(bb_symb_quant))
    res_bb = get_avg_test_error(g_symb=bb_g_symb, h_symb=bb_h_symb)

    results_dict["black_box_symb_quant"] = str(bb_symb_quant)
    results_dict["black_box_symb"] = str(bb_symb)
    results_dict["black_box_exec_time"] = exec_time
    results_dict.update(_flatten_metrics("black_box_symb_test", res_bb))

    if eval_model:
        print("Evaluate raw model\n")
        builder = build_model_from_file_mpnn if model_type == "MPNN" else build_model_from_file_llc
        best_model = builder(
            model_path=model_path,
            message_passing=message_passing,
            include_time=include_time,
            method=method,
            adjoint=adjoint,
            atol=atol,
            rtol=rtol
        )

        tot_params = sum(p.numel() for p in best_model.parameters() if p.requires_grad)
        print(f"Number of model's parameters: {tot_params}\n")
        results_dict["Number of params"] = tot_params

        best_model = best_model.eval()
        res_model = get_avg_test_error(
            g_symb=best_model.conv.model.g_net,
            h_symb=best_model.conv.model.h_net,
            is_symb=False
        )
        results_dict.update(_flatten_metrics("model_test", res_model))

    with open(f"{model_path}/{res_file_name}", 'w', encoding='utf-8') as file:
        json.dump(results_dict, file, indent=4)

import argparse

_CONFIG_DIR = "./configs/config_pred_deriv/config_ic1"
_MODELS_DIR = "./saved_models_optuna"
_SNR_LEVELS_DB = (70, 50, 20)

# Ground-truth message (g) and self (h) functions used to compute the
# lower-bound ("LB") test loss of the exact formulas. Each function indexes
# columns 0 and 1 of x and returns a column (unsqueezed) tensor. What each
# column represents is defined by get_symb_test_error (not visible here).
# Insertion order is the order in which configs/test sets are loaded.
_GROUND_TRUTH = {
    "kuramoto": (
        lambda x: torch.sin(x[:, 1] - x[:, 0]).unsqueeze(-1),
        lambda x: 2.0 + 0.5 * x[:, 1].unsqueeze(-1),
    ),
    "epidemics": (
        lambda x: 0.5 * x[:, 1].unsqueeze(-1) * (1 - x[:, 0].unsqueeze(-1)),
        lambda x: x[:, 1].unsqueeze(-1) - 0.5 * x[:, 0].unsqueeze(-1),
    ),
    "biochemical": (
        lambda x: (-0.5 * x[:, 1] * x[:, 0]).unsqueeze(-1),
        lambda x: (1.0 - 0.5 * x[:, 0]).unsqueeze(-1) + x[:, 1].unsqueeze(-1),
    ),
    "population": (
        lambda x: 0.2 * torch.pow(x[:, 1].unsqueeze(-1), 3),
        lambda x: -0.5 * x[:, 0].unsqueeze(-1) + x[:, 1].unsqueeze(-1),
    ),
}

# Order in which trained models are post-processed (differs from the LB order).
_SYMB_REG_ORDER = ("biochemical", "kuramoto", "epidemics", "population")

# Settings shared by every post_process_mpnn call.
_POST_PROCESS_KWARGS = dict(
    device='cuda',
    sample_size=10000,
    message_passing=False,
    include_time=False,
    atol=1e-5,
    rtol=1e-5,
    method="dopri5",
)


def _model_paths(dynamics, is_llc):
    """Return the trained-model directories for *dynamics*.

    The clean (IC=1) model comes first, followed by the SNR variants in
    ``_SNR_LEVELS_DB`` order (70, 50, 20 dB).
    """
    if is_llc:
        root = f"{_MODELS_DIR}/model-{dynamics}-llc"
        clean = f"{root}/{dynamics}_llc_2/0"
        noisy = [f"{root}/{dynamics}_llc_{db}db_3/0" for db in _SNR_LEVELS_DB]
    else:
        root = f"{_MODELS_DIR}/model-{dynamics}-mpnn"
        clean = f"{root}/{dynamics}_mpnn_ic1_s5_pd_mult_12/0"
        noisy = [f"{root}/{dynamics}_mpnn_ic1_s5_pd_mult_noise_{db}db_2/0" for db in _SNR_LEVELS_DB]
    return [clean] + noisy


def main():
    """Report lower-bound losses of the exact formulas, then post-process all trained models."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--is_llc",
        action="store_true",
        help="Set this flag to enable LLC mode"
    )
    args = parser.parse_args()

    set_pytorch_seed(0)
    is_llc = args.is_llc

    print(f"\n\nIS LLC? {is_llc}\n\n")

    # ## LB losses: test error of the ground-truth formulas for each dynamics.
    configs, test_sets = {}, {}
    for dynamics, (g_symb, h_symb) in _GROUND_TRUTH.items():
        config = load_config(f"{_CONFIG_DIR}/config_{dynamics}.yml")
        test_set = get_test_set(
            dynamics=config['name'],
            device='cuda',
            input_range=config['input_range'],
            **config['integration_kwargs']
        )
        configs[dynamics] = config
        test_sets[dynamics] = test_set

        test_losses = get_symb_test_error(
            g_symb=g_symb,
            h_symb=h_symb,
            test_set=test_set,
            message_passing=True,
            include_time=False,
            is_symb=False
        )
        print(f"[LB] {dynamics}")
        print(f"Mean Test loss of symbolic formula: {np.mean(test_losses)}")
        print(f"Var Test loss of symbolic formula: {np.var(test_losses)}")
        print(f"Std Test loss of symbolic formula: {np.std(test_losses)}")

    # ## Symb reg: the clean (IC=1) model first, then its SNR variants.
    model_type = "LLC" if is_llc else "MPNN"
    for dynamics in _SYMB_REG_ORDER:
        for model_path in _model_paths(dynamics, is_llc):
            print(model_path)
            post_process_mpnn(
                config=configs[dynamics],
                model_path=model_path,
                test_set=test_sets[dynamics],
                model_type=model_type,
                **_POST_PROCESS_KWARGS
            )


if __name__ == '__main__':
    main()