"""
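Utility functions for logging, configuration loading, plotting, data sampling,
KAN pruning and symbolic fitting (SciPy- and PySR-based) of the trained models.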
"""


from pysr import PySRRegressor
import torch
import yaml
import matplotlib.pyplot as plt
import os 
import numpy as np
from datasets.data_utils import numerical_integration
import sympy as sp
import json
from models.kan.KanLayer import KANLayer
from models.kan.KAN import KAN
from collections import defaultdict
from scipy.optimize import curve_fit
from sklearn.metrics import mean_squared_error
from sympy import count_ops
import warnings
import sympy
import re
import pandas as pd
from sklearn.linear_model import Lasso
from sklearn.model_selection import train_test_split
from sympy import lambdify
import time
from utils.orig_sw_fitting import *


# Silence FutureWarning globally, and SciPy's OptimizeWarning, for the whole process.
# NOTE(review): presumably the OptimizeWarning comes from the `curve_fit` calls below
# when the covariance cannot be estimated — confirm before re-enabling.
warnings.filterwarnings("ignore", category=FutureWarning)
from scipy.optimize import OptimizeWarning
warnings.simplefilter("ignore", OptimizeWarning)


# Maps a metric name to the torch loss module that computes it.
SCORES = {
    'MSE': torch.nn.MSELoss(),
    'MAE': torch.nn.L1Loss()
}


# Library of candidate univariate functions, evaluated numerically with NumPy.
# Keys must match the keys of SYMBOLIC_LIB_SYMPY: `fit_params_scipy` fits with the
# NumPy version and builds the symbolic expression by looking the same key up there.
SYMBOLIC_LIB_NUMPY = {
    'x': lambda x: x,
    'x^2': lambda x: x**2,
    'x^3': lambda x: x**3,
    # Output is capped at 1e5 (no lower bound) to keep the exponential bounded.
    'exp': lambda x: np.clip(np.exp(x), a_min=None, a_max=1e5),
    'abs': lambda x: np.abs(x),
    'sin': lambda x: np.sin(x),
    'cos': lambda x: np.cos(x),
    'tan': lambda x: np.tan(x),
    'tanh': lambda x: np.tanh(x),
    'ln': lambda x: np.log(x),
    # Constant-zero function that keeps the shape of the input.
    '0': lambda x: x*0,
}

# SymPy counterparts of the functions in SYMBOLIC_LIB_NUMPY (same keys), used to
# build the symbolic expression once the numeric fit has produced its parameters.
# Unlike the NumPy version, 'exp' is not clipped here.
SYMBOLIC_LIB_SYMPY = {
    'x': lambda x: x,
    'x^2': lambda x: x**2,
    'x^3': lambda x: x**3,
    'exp': lambda x: sp.exp(x),
    'abs': lambda x: sp.Abs(x),
    'sin': lambda x: sp.sin(x),
    'cos': lambda x: sp.cos(x),
    'tan': lambda x: sp.tan(x),
    'tanh': lambda x: sp.tanh(x),
    'ln': lambda x: sp.ln(x),
    '0': lambda x: 0 * x,
}


def save_logs(file_name, log_message, save_updates=True):
    """
    Print a message and append it to a log file.

    Args:
        - file_name : Path of the log file (opened in append mode)
        - log_message : Text to print; it is written to the file preceded by a newline
        - save_updates : When False nothing is printed nor written
    """
    if not save_updates:
        return

    print(log_message)
    with open(file_name, 'a') as log_file:
        log_file.write('\n' + log_message)



def load_config(config_path='config.yml'):
    """
    Parse the YAML file at `config_path` with `yaml.safe_load` and return the result.
    """
    with open(config_path, 'r') as stream:
        return yaml.safe_load(stream)



def plot(folder_path, layers, show_plots=False):
    '''
    Plots the shape of all the activation functions of the given KAN layers.

    One figure is produced per (layer, output neuron, input) triple and saved as
    `{folder_path}/out_{layer}_{neuron}_{input}.png`. Curves are drawn in red for
    neurons flagged in `layer.multiplicative_mask`, blue otherwise.

    Args:
        - folder_path : Output directory (created if missing)
        - layers : Iterable of layers exposing `cache_act`, `cache_preact`,
          `multiplicative_mask`, `out_features` and `in_features`
        - show_plots : Whether to also display each figure interactively
    '''
    # exist_ok avoids the check-then-create race of `if not exists: makedirs`.
    os.makedirs(folder_path, exist_ok=True)

    for l, layer in enumerate(layers):
        assert layer.cache_act is not None and layer.cache_preact is not None, 'Populate model activations before plotting'
        activations = layer.cache_act
        # Sort every input column independently so each curve is drawn left to right.
        preact_sorted, indices = torch.sort(layer.cache_preact, dim=0)
        mult_mask = layer.multiplicative_mask
        for j in range(layer.out_features):
            color = "red" if mult_mask[j] else "blue"
            for i in range(layer.in_features):
                # Reorder the activation values with the permutation that sorted input i.
                out = activations[:, j, i][indices[:, i]]
                plt.figure()
                plt.plot(preact_sorted[:, i].cpu().detach().numpy(), out.cpu().detach().numpy(), linewidth=2.5, color=color)
                plt.title(f"Act. (Layer: {l}, Neuron: {j}, Input: {i})")
                plt.savefig(f"{folder_path}/out_{l}_{j}_{i}.png")
                if show_plots:
                    plt.show()
                # Release the figure so that long runs do not accumulate open figures.
                plt.clf()
                plt.close()
                

def integrate(
    input_range,
    t_span,
    t_eval_steps,
    dynamics,
    device,   
    graph,
    rng,
    **integration_kwargs
):
    """
    Integrates the specified dynamics over the given graph.

    A random initial state (one value per graph node) is drawn uniformly from
    `input_range` with `rng` and handed to `numerical_integration`.

    Returns:
        The integrated states as a float tensor with an extra trailing dimension,
        and the evaluation times as a float tensor, both moved to `device`.
    """
    n_nodes = graph.number_of_nodes()
    low, high = input_range[0], input_range[1]
    initial_state = rng.uniform(low, high, n_nodes).astype(np.float64)

    xs, t = numerical_integration(
        G=graph,
        dynamics=dynamics,
        initial_state=initial_state,
        time_span=t_span,
        t_eval_steps=t_eval_steps,
        **integration_kwargs
    )

    states = torch.from_numpy(xs).float().unsqueeze(2).to(device)
    times = torch.from_numpy(t).float().to(device)
    return states, times


def sample_from_spatio_temporal_graph(dataset, edge_index, edge_attr, t=None, sample_size=32):
    """
    Sample evenly spaced snapshots and stack them into one big disconnected graph.

    `dataset` is indexed as (time step, node, feature). Snapshot `s` occupies rows
    `s * num_nodes ... (s + 1) * num_nodes - 1` of the returned node matrix, and its
    copy of `edge_index` is shifted by `s * num_nodes` accordingly.

    Args:
        - dataset : Tensor of node states, one snapshot per entry of the first dimension
        - edge_index : Edge list of a single snapshot (2 x num_edges)
        - edge_attr : Optional edge attributes of a single snapshot, or None
        - t : Optional 1-D tensor with the time stamp of every snapshot
        - sample_size : Number of snapshots to keep (-1 keeps all of them)

    Returns:
        (node features, edge index, per-row time stamps, edge attributes or None).
        When `t` is None the time tensor is empty with shape (0, 1).
    """
    device = dataset.device
    num_nodes = dataset.size(1)

    sample_size = sample_size if sample_size != -1 else len(dataset)
    interval = len(dataset) // sample_size
    sampled_indices = torch.arange(sample_size, device=device) * interval

    samples = dataset[sampled_indices]
    concatenated_x = torch.reshape(samples, (-1, samples.size(2))).to(device)

    t_sampled = t[sampled_indices] if t is not None else torch.tensor([], device=device)
    # Rows are laid out snapshot-major (all nodes of snapshot 0, then snapshot 1, ...),
    # so each time stamp must be repeated `num_nodes` times consecutively. Tiling the
    # whole time vector instead would pair rows with the wrong time stamps.
    concatenated_t = t_sampled.unsqueeze(1).repeat(1, num_nodes).reshape(-1, 1)

    # One shifted copy of the edge list per snapshot.
    shifted_edges = [edge_index + i * num_nodes for i in range(sample_size)]
    concatenated_edge_index = torch.cat(shifted_edges, dim=1).to(device)

    if edge_attr is not None:
        # torch.cat allocates a new tensor, so the source attributes are never aliased.
        concatenated_edge_attr = torch.cat([edge_attr] * sample_size, dim=0).to(device)
    else:
        concatenated_edge_attr = None

    return concatenated_x, concatenated_edge_index, concatenated_t, concatenated_edge_attr
    

def sample_irregularly_per_ics(data, time, num_samples):
    """
    Randomly subsample time steps independently for every initial condition.

    Args:
        - data : Tensor unpacked as (ics, n_step, n_nodes, in_dim)
        - time : Tensor indexed as time[ic, step]
        - num_samples : Number of time steps to keep per initial condition;
          a non-positive value keeps all `n_step` steps

    Returns:
        (sampled_data, sampled_times), where the kept steps of each initial
        condition are in increasing step order.
    """
    ics, n_step, n_nodes, in_dim = data.shape
    num_samples = num_samples if num_samples > 0 else n_step
    sampled_data = torch.zeros((ics, num_samples, n_nodes, in_dim), dtype=data.dtype, device=data.device)
    sampled_times = torch.zeros((ics, num_samples), dtype=time.dtype, device=time.device)

    for i in range(ics):
        # Unique random steps, sorted so the temporal order is preserved.
        indices = torch.sort(torch.randperm(n_step)[:num_samples]).values

        sampled_data[i] = data[i, indices, :, :]
        sampled_times[i] = time[i, indices]

    return sampled_data, sampled_times


def save_acts(layers, folder_path):
    """
    Persist the cached activations of every layer with `torch.save`.

    For layer index `l` the following files are written inside `folder_path`:
    `cache_preact_{l}`, `cache_act_{l}`, `cache_act_scale_spline_{l}` and
    `cache_mult_mask_{l}`.

    Args:
        - layers : Iterable of layers exposing `cache_act`, `cache_preact`,
          `acts_scale_spline` and `multiplicative_mask`
        - folder_path : Output directory (created if missing)
    """
    # exist_ok avoids the check-then-create race of `if not exists: makedirs`.
    os.makedirs(folder_path, exist_ok=True)

    for l, layer in enumerate(layers):
        assert layer.cache_act is not None and layer.cache_preact is not None, 'Populate model activations before saving them'
        torch.save(layer.cache_preact, f"{folder_path}/cache_preact_{l}")
        torch.save(layer.cache_act, f"{folder_path}/cache_act_{l}")
        torch.save(layer.acts_scale_spline, f"{folder_path}/cache_act_scale_spline_{l}")
        torch.save(layer.multiplicative_mask, f"{folder_path}/cache_mult_mask_{l}")
        
        

def save_black_box_to_file(folder_path, cache_input, cache_output):
    """
    Save the cached input and output of a black-box model with `torch.save`.

    The files are written as `{folder_path}/cached_input` and
    `{folder_path}/cached_output`; the directory is created if missing.
    """
    # exist_ok avoids the check-then-create race of `if not exists: makedirs`.
    os.makedirs(folder_path, exist_ok=True)

    torch.save(cache_input, f'{folder_path}/cached_input')
    torch.save(cache_output, f'{folder_path}/cached_output')
 
 
def pruning(kan_acts, kan_preacts, kan_masks_mult, theta = 0.01, verbose=False):
    """
    Drop hidden nodes whose outgoing activations are all negligible.

    For every layer `l` except the last, a node is removed when the largest
    scale of the edges leaving it (computed on layer `l+1` as
    sum|activation| / sum|pre-activation| over the first dimension) is below `theta`.
    The input lists are shallow-copied; the tensors they hold are only sliced.

    Args:
        - kan_acts : Per-layer activations, indexed as [sample, out, in]
        - kan_preacts : Per-layer pre-activations, indexed as [sample, in]
        - kan_masks_mult : Per-layer masks with one entry per output node
        - theta : Pruning threshold
        - verbose : Whether to print the pruned nodes

    Returns:
        (pruned_acts, pruned_preacts, pruned_masks_mult)
    """
    pruned_acts = kan_acts.copy()
    pruned_preacts = kan_preacts.copy()
    pruned_masks_mult = kan_masks_mult.copy()

    def edge_scale(layer_idx):
        # Ratio between the output and the input magnitude of every edge of the layer.
        in_magnitude = pruned_preacts[layer_idx].abs().sum(dim=0)
        out_magnitude = pruned_acts[layer_idx].abs().sum(dim=0)
        return out_magnitude / in_magnitude

    for l in range(len(kan_acts) - 1):
        # Strongest outgoing edge of every node of layer l.
        outgoing_scale = edge_scale(l + 1).max(dim=0).values
        is_pruned = outgoing_scale < theta
        keep = torch.where(~is_pruned)[0]

        # Remove the node as an output of layer l and as an input of layer l+1.
        pruned_acts[l] = pruned_acts[l][:, keep, :]
        pruned_acts[l + 1] = pruned_acts[l + 1][:, :, keep]
        pruned_preacts[l + 1] = pruned_preacts[l + 1][:, keep]
        pruned_masks_mult[l] = pruned_masks_mult[l][keep]

        if verbose:
            for j in torch.where(is_pruned)[0].tolist():
                print(f"Pruning node ({l},{j})")

    return pruned_acts, pruned_preacts, pruned_masks_mult



def get_pysr_model(
    n_iterations=100,
    binary_operators=None,
    extra_sympy_mappings=None,
    unary_operators=None,
    **kwargs):
    """
    Build a `PySRRegressor` with the defaults used across this module.

    Args:
        - n_iterations : Number of search iterations
        - binary_operators : Binary operators; defaults to ['+', '-', '*', '/']
        - extra_sympy_mappings : Extra operator-name -> callable mappings, merged
          on top of the built-in "zero" mapping
        - unary_operators : Unary operators; a default set (including the custom
          `zero` operator) is used when None
        - **kwargs : Forwarded verbatim to `PySRRegressor`

    Returns:
        The configured, unfitted regressor.
    """
    # None sentinels instead of mutable defaults: a fresh list/dict per call, so a
    # regressor can never share (and mutate) the operator list of another one.
    if binary_operators is None:
        binary_operators = ['+', '-', '*', '/']

    extra_mapping = {"zero": lambda x: x*0}
    if extra_sympy_mappings:
        extra_mapping.update(extra_sympy_mappings)

    if unary_operators is None:
        unary_operators = [
            "exp",
            "sin",
            "neg",
            "square",
            "cube",
            "abs",
            "tan",
            "tanh",
            "log",
            "log1p",
            "zero(x) = 0*x"
        ]

    model = PySRRegressor(
        niterations=n_iterations,  # Number of iterations
        unary_operators=unary_operators,
        binary_operators=binary_operators,
        elementwise_loss="loss(prediction, target) = (prediction - target)^2",
        maxsize=7,
        maxdepth=5,
        verbosity=0,
        extra_sympy_mappings=extra_mapping,
        delete_tempfiles=True,
        temp_equation_file=True,
        tempdir='./pysr',
        progress=False,
        **kwargs
    )

    return model
    


def fit_acts_pysr(x, y, pysr_model = None, sample_size = -1, seed=42):
    """
    Fit a PySR model on (x, y) and return its five best-scoring equations.

    Args:
        - x, y : Inputs and targets, indexable with an array of row indices
        - pysr_model : Zero-argument factory returning the model to fit;
          `get_pysr_model()` is used when None
        - sample_size : When positive and smaller than len(x), fit on that many
          rows drawn without replacement
        - seed : Seed of the generator used for the subsampling

    Returns:
        The five rows of `model.equations_` with the largest 'score'.
    """
    rng = np.random.default_rng(seed)
    model = get_pysr_model() if pysr_model is None else pysr_model()

    if 0 < sample_size < len(x):
        subset = rng.choice(len(x), sample_size, replace=False)
        x, y = x[subset], y[subset]

    model.fit(x, y)
    return model.equations_.nlargest(5, 'score')


def penalized_loss(y_true, y_pred, func_symb, alpha=0.01):
    """
    Mean squared error plus a complexity penalty.

    The penalty is `alpha` times the number of operations in `func_symb`,
    as counted by sympy's `count_ops`.
    """
    fit_error = mean_squared_error(y_true, y_pred)
    return fit_error + alpha * count_ops(func_symb)


def fit_params_scipy(x_train, y_train, func, func_name, x_symb = None, alpha=0.1, cut_threshold=1e-3):
    """
    Fit one candidate function to (x_train, y_train) with `scipy.optimize.curve_fit`.

    The parametric form depends on `func_name`:
        - 'x' / 'neg' : a*x + b
        - 'x^2'       : a*x**2 + b*x + c
        - 'x^3'       : a*x**3 + b*x**2 + c*x + d
        - otherwise   : c * func(a*x + b) + d

    Args:
        - x_train, y_train : Training samples (arrays)
        - func : Numeric implementation of the candidate function
        - func_name : Key of the candidate; for the generic form it must also be a
          key of SYMBOLIC_LIB_SYMPY
        - x_symb : Symbol used in the symbolic expression (defaults to `x0`)
        - alpha : Weight of the complexity penalty (see `penalized_loss`)
        - cut_threshold : Quantisation step handed to `quantise`

    Returns:
        (penalized loss, fitted parameters, quantised sympy expression, fitted callable).
        On failure ('ln' with non-positive inputs, a `RuntimeError` from `curve_fit`,
        or NaN/inf predictions) the sentinel `(1e8, [], 0, lambda x: x*0)` is returned.
    """
    # Pick the parametric form to optimise and its starting point.
    if func_name == 'x' or func_name == 'neg':
        func_optim = lambda x, a, b: a*x + b
        init_params = [1., 0.]
    elif func_name=='x^2':
        func_optim = lambda x, a, b, c: a*x**2 + b*x + c
        init_params = [1., 0., 0]
    elif func_name == 'x^3':
        func_optim = lambda x, a, b, c, d: a*x**3 + b*x**2 + c*x + d
        init_params = [1., 0., 0., 0.]
    elif (func_name == 'ln') and (np.any(x_train <= 0)):
        # The logarithm is rejected outright when any input is non-positive.
        return 1e8, [], 0, lambda x: x*0
    else:
        func_optim = lambda x, a, b, c, d: c * func(a*x + b) + d
        init_params = [1., 0., 1., 0.]

    try:
        params, _ = curve_fit(func_optim, x_train, y_train, p0=init_params, nan_policy='omit')
    except RuntimeError:
        # The optimiser failed: report a huge loss so this candidate is never selected.
        return 1e8, [], 0, lambda x: x*0

    if x_symb is None:
        x_symb = sp.Symbol('x0')    # This symbol must be x0 in order to work with the rest of the code

    # Evaluate the fitted function on the training inputs and build its symbolic form.
    if func_name == 'x' or func_name == 'neg':
        post_fun = params[0] * func(x_train) + params[1]
        fun_sympy = params[0] * x_symb + params[1]
    elif func_name == 'x^2':
        post_fun = params[0] * x_train**2 + params[1] * x_train + params[2]
        fun_sympy = params[0] * x_symb**2 + params[1] * x_symb + params[2]
    elif func_name == 'x^3':
        post_fun = params[0]*x_train**3 + params[1]*x_train**2 + params[2]*x_train + params[3]
        fun_sympy = params[0] * x_symb**3 + params[1] * x_symb**2 + params[2] * x_symb + params[3]
    else:
        post_fun = params[2] * func(params[0]*x_train + params[1]) + params[3]
        fun_sympy = params[2] * SYMBOLIC_LIB_SYMPY[func_name](params[0] * x_symb + params[1]) + params[3]


    # Reject fits whose predictions are not finite.
    if np.any(np.isnan(post_fun)) or np.any(np.isinf(post_fun)):
        return 1e8, [], 0, lambda x: x*0

    # The complexity penalty is computed on the quantised expression.
    fun_sympy_quantized = quantise(fun_sympy, cut_threshold, is_removing=True)
    mse = penalized_loss(y_train, post_fun, fun_sympy_quantized, alpha=alpha)
    return mse, params, fun_sympy_quantized, func_optim


def fit_acts_scipy(x, y, x_symb = None, alpha=0.1, cut_threshold = 1e-3):
    """
    Fit every function of SYMBOLIC_LIB_NUMPY to (x, y) and keep the best one.

    Candidates are ranked by the penalized loss returned by `fit_params_scipy`;
    on ties the first candidate in library order wins.

    Returns:
        (best sympy expression, its fitted parameters, its fitted callable)
    """
    candidates = [
        fit_params_scipy(x, y, func, name, x_symb=x_symb, alpha=alpha, cut_threshold=cut_threshold)
        for name, func in SYMBOLIC_LIB_NUMPY.items()
    ]

    # Every candidate is (loss, params, expression, callable).
    _, best_params, best_fun_sympy, best_func_optim = min(candidates, key=lambda cand: cand[0])
    return best_fun_sympy, best_params, best_func_optim


def find_best_symbolic_func(x_train, y_train, x_val, y_val, alpha_grid, x_symb = None, sort_by='score',
                            cut_threshold=1e-3, fit_orig = False, **kwargs):
    """
    Rank candidate symbolic functions for a single activation.

    With `fit_orig=True` the search is delegated to `suggest_symbolic` and only the
    "symbolic_function" column is returned. Otherwise one candidate is fitted per
    value of `alpha_grid`; candidates are ordered by complexity and scored with the
    negative finite difference of the validation log-MSE with respect to complexity
    (the simplest candidate gets score 0, candidates with the same complexity as
    their predecessor are dropped).

    Returns:
        A DataFrame with columns "symbolic_function", "complexity", "log_loss",
        "alpha" and "score", sorted by `sort_by` ("score" descending or
        "log_loss" ascending).
    """
    if fit_orig:
        top_eq = suggest_symbolic(x_train, y_train, **kwargs)
        return pd.DataFrame(top_eq, columns=["symbolic_function"])

    assert sort_by in ["score", "log_loss"], "Not supported sorting method"
    ascending = sort_by == "log_loss"

    # One candidate per penalty weight: (expression, complexity, log validation MSE, alpha).
    candidates = []
    for alpha in alpha_grid:
        symb_func, params, func_optim = fit_acts_scipy(x_train, y_train, x_symb=x_symb, alpha=alpha,
                                                       cut_threshold=cut_threshold)
        val_mse = mean_squared_error(y_val, func_optim(x_val, *params))
        candidates.append((symb_func, count_ops(symb_func), np.log(val_mse), alpha))

    # Sort by complexity to compute the finite-difference derivative.
    candidates.sort(key=lambda cand: cand[1])
    scored = [(*candidates[0], 0)]

    for previous, current in zip(candidates, candidates[1:]):
        c1, l1 = previous[1], previous[2]
        c2, l2 = current[1], current[2]

        if c1 == c2:
            continue

        dlogloss_dcomplexity = (l2 - l1) / (c2 - c1)
        scored.append((*current, -dlogloss_dcomplexity))

    top_equations = pd.DataFrame(scored, columns=["symbolic_function", "complexity", "log_loss", "alpha", "score"])
    return top_equations.sort_values(by=sort_by, ascending=ascending).reset_index(drop=True)

def fit_layer(cached_act, cached_preact, symb_xs, mask_mult, val_ratio=0.2, seed=42, sample_size = -1, sort_by='score',
              cut_threshold=1e-3, alpha_grid=None, fit_orig=False, **kwargs):
    """
    Fit a symbolic function to every edge of one KAN layer and compose the outputs.

    Args:
        - cached_act : Activations indexed as [sample, out, in]
        - cached_preact : Pre-activations indexed as [sample, in]
        - symb_xs : Symbolic expression of every layer input
        - mask_mult : Per-output flags; flagged outputs multiply their edge
          functions, the others add them
        - val_ratio : Fraction of the samples held out for validation
        - seed : Seed for the subsampling and for the train/validation split
        - sample_size : When positive and smaller than the number of samples,
          fit on that many randomly chosen samples
        - sort_by : Ranking criterion forwarded to `find_best_symbolic_func`
        - cut_threshold : Quantisation step forwarded to `find_best_symbolic_func`
        - alpha_grid : Complexity-penalty weights to try; defaults to five
          log-spaced values between 1e-5 and 10
        - fit_orig : Forwarded to `find_best_symbolic_func`
        - **kwargs : Forwarded to `find_best_symbolic_func`

    Returns:
        (symbolic output of every neuron,
         {out: {in: [candidate strings], "mode": 'add' | 'mult'}},
         {(in, out): DataFrame of ranked candidates})
    """
    rng = np.random.default_rng(seed)

    if alpha_grid is None:
        # Plain Python floats: iterating the tensor itself would hand 0-dim tensors
        # to the penalized loss (mixing them with sympy/numpy values) and store
        # tensor objects in the "alpha" column of the result tables.
        alpha_grid = torch.logspace(-5, 1, steps=5).tolist()

    symb_layer_acts = []
    symbolic_functions = defaultdict(dict)
    top_equations = {}

    in_dim = cached_act.shape[2]
    out_dim = cached_act.shape[1]

    for j in range(out_dim):
        # Neutral element of the aggregation: 1 for a product, 0 for a sum.
        symb_out = 0 if not mask_mult[j] else 1.

        for i in range(in_dim):
            x = cached_preact[:, i].reshape(-1)
            y = cached_act[:, j, i].reshape(-1)

            if sample_size > 0 and sample_size < len(x):
                indices = rng.choice(len(x), sample_size, replace=False)
                x_sampled = x[indices]
                y_sampled = y[indices]
            else:
                x_sampled = x
                y_sampled = y

            x_train, x_val, y_train, y_val = train_test_split(x_sampled, y_sampled, test_size=val_ratio, random_state=seed)

            top_eqs = find_best_symbolic_func(
                x_train, y_train, x_val, y_val,
                alpha_grid=alpha_grid,
                sort_by=sort_by,
                cut_threshold=cut_threshold,
                fit_orig=fit_orig,
                **kwargs
            )
            top_equations[(i, j)] = top_eqs

            symbolic_functions[j][i] = [str(expr) for expr in top_eqs["symbolic_function"]]

            # Candidates are expressed in the placeholder symbol x0; plug in the
            # actual symbolic input of this edge.
            best_symb_func = top_eqs["symbolic_function"][0]
            best_symb_func = best_symb_func.subs(sp.Symbol('x0'), symb_xs[i])

            if mask_mult[j]:
                symbolic_functions[j]["mode"] = 'mult'
                symb_out *= best_symb_func
            else:
                symb_out += best_symb_func
                symbolic_functions[j]["mode"] = 'add'

        symb_layer_acts.append(symb_out)

    return symb_layer_acts, symbolic_functions, top_equations


def fit_kan(kan_acts, kan_preacts, kan_masks_mult, symb_xs, model_path='./models', seed=42, sample_size = -1, sort_by='score', 
            verbose = False, cut_threshold = 1e-3, alpha_grid = None, fit_orig=False, **kwargs):
    """
    Fit a whole KAN symbolically, layer by layer.

    The symbolic outputs of a layer become the symbolic inputs of the next one.
    Side effects: the ranked candidates of every edge are written to
    `{model_path}/top_eqs/top_equations({layer}, {out}, {in}).csv`, and the selected
    functions to `{model_path}/symb_functions.json`.

    Args:
        - kan_acts, kan_preacts, kan_masks_mult : Per-layer tensors (see `fit_layer`)
        - symb_xs : Symbolic expressions of the network inputs
        - model_path : Directory where the results are stored
        - seed, sample_size, sort_by, alpha_grid, fit_orig, **kwargs : Forwarded to `fit_layer`
        - verbose : Whether to print the execution time
        - cut_threshold : Quantisation step, used both while fitting each layer and
          for the final quantisation of the output formulas

    Returns:
        (list of output expressions, execution time in seconds). Outputs with fewer
        than 20 operations are simplified and quantised; the others are returned as fitted.
    """
    start_time = time.time()
    n_layers = len(kan_acts)
    all_functions = {}
    top_5_save_path = f"{model_path}/top_eqs"
    os.makedirs(top_5_save_path, exist_ok=True)

    for l in range(n_layers):
        acts = kan_acts[l].cpu().detach().numpy()
        preacts = kan_preacts[l].cpu().detach().numpy()
        mask_mult = kan_masks_mult[l].cpu().detach().numpy()

        symb_xs, symb_functions, top_equations = fit_layer(
            cached_act=acts,
            cached_preact=preacts,
            symb_xs=symb_xs,
            mask_mult=mask_mult,
            seed=seed,
            sample_size=sample_size,
            sort_by=sort_by,
            # Forwarded so the per-edge fits use the same threshold as the final
            # quantisation instead of silently falling back to fit_layer's default.
            cut_threshold=cut_threshold,
            fit_orig=fit_orig,
            alpha_grid=alpha_grid,
            **kwargs
        )

        all_functions[l] = symb_functions

        # Keys of top_equations are (input index, output index).
        for k, df in top_equations.items():
            df.to_csv(f"{top_5_save_path}/top_equations({l}, {k[1]}, {k[0]}).csv")

    end_time = time.time()
    exec_time = end_time - start_time
    if verbose:
        print(f"Execution time: {exec_time:.6f} seconds")

    save_path = f"{model_path}/symb_functions.json"
    with open(save_path, "w") as f:
        json.dump(all_functions, f)

    out = []
    for symbx in symb_xs:
        # Simplification is skipped for large expressions (20 operations or more).
        out.append(
            quantise(sp.simplify(symbx), quantise_to=cut_threshold, is_removing=True) if count_ops(symbx) < 20 else symbx
        )

    return out, exec_time
                
        
def load_cached_data(cached_acts_path, cached_preacts_path, cached_mask_mult_path, device='cpu'):
    """
    Load the cached tensors of one layer onto `device`.

    Files are read with `torch.load(..., weights_only=False)`, i.e. with full
    unpickling, so they must come from a trusted source.

    Returns:
        (activations (batch_dim, out_dim, in_dim), pre-activations (batch_dim, in_dim),
        multiplicative mask). When the mask file does not exist, an all-False boolean
        mask with one entry per output is returned.
    """
    target = torch.device(device)

    def _load(path):
        return torch.load(path, weights_only=False, map_location=target)

    cached_act = _load(cached_acts_path)
    cached_preact = _load(cached_preacts_path)

    if not os.path.isfile(cached_mask_mult_path):
        # No mask stored: treat every output node as additive.
        no_mult = torch.zeros(cached_act.shape[1], dtype=torch.bool, device=target)
        return cached_act, cached_preact, no_mult

    return cached_act, cached_preact, _load(cached_mask_mult_path)


def get_kan_arch(n_layers, model_path):
    """
    Load the cached tensors of the first `n_layers` layers of a KAN.

    Files are read from `{model_path}/cached_acts/` through `load_cached_data`.

    Returns:
        Three lists with one entry per layer:
        (activations, pre-activations, multiplicative masks)
    """
    cache_dir = f'{model_path}/cached_acts'
    acts, preacts, masks_mult = [], [], []

    for l in range(n_layers):
        layer_act, layer_preact, layer_mask = load_cached_data(
            cached_acts_path=f'{cache_dir}/cache_act_{l}',
            cached_preacts_path=f'{cache_dir}/cache_preact_{l}',
            cached_mask_mult_path=f'{cache_dir}/cache_mult_mask_{l}'
        )
        acts.append(layer_act)
        preacts.append(layer_preact)
        masks_mult.append(layer_mask)

    return acts, preacts, masks_mult


def fit_model(depth_h, depth_g, model_path, theta=0.1, message_passing=True, include_time=False, seed=42, sample_size=-1,
              sort_by='score', verbose=False, cut_threshold = 1e-3, alpha_grid = None, fit_orig=False, **kwargs):
    """
    Symbolically fit the two KANs (G_Net, then H_Net) stored under `model_path`.

    Each network is loaded from `{model_path}/g_net` / `{model_path}/h_net`, pruned
    with threshold `theta` and fitted with `fit_kan`. The first output of G_Net is
    wrapped in a `\\sum_{j}( ... )` symbol, which is fed to H_Net as its second input
    when `message_passing` is True; a `t` symbol is appended when `include_time` is True.

    Returns:
        (final formula, G_Net formula, H_Net formula, mean of the two execution times).
        Without message passing the final formula is the H_Net formula plus the
        aggregation symbol.
    """
    # Options shared by the two fit_kan calls.
    shared_fit_args = dict(
        seed=seed,
        sample_size=sample_size,
        sort_by=sort_by,
        verbose=verbose,
        cut_threshold=cut_threshold,
        fit_orig=fit_orig,
        alpha_grid=alpha_grid,
    )

    # G_net
    g_dir = f'{model_path}/g_net'
    g_acts, g_preacts, g_masks = pruning(
        *get_kan_arch(n_layers=depth_g, model_path=g_dir), theta=theta, verbose=verbose
    )

    if verbose:
        print("Fitting G_Net...")

    if fit_orig:
        print("Fit orig GNet")

    g_outputs, exec_time_g = fit_kan(
        g_acts,
        g_preacts,
        kan_masks_mult=g_masks,
        symb_xs=[sp.Symbol('x_i'), sp.Symbol('x_j')],
        model_path=g_dir,
        **shared_fit_args,
        **kwargs
    )
    symb_g = g_outputs[0]  # Univariate functions

    # H_Net
    h_dir = f'{model_path}/h_net'
    h_acts, h_preacts, h_masks = pruning(
        *get_kan_arch(n_layers=depth_h, model_path=h_dir), theta=theta, verbose=verbose
    )

    # The aggregated message is represented by a single opaque symbol.
    aggr_term = sp.Symbol(r'\sum_{j}( ' + str(symb_g) + ')')
    symb_h_in = [sp.Symbol('x_i')]
    if message_passing:
        symb_h_in.append(aggr_term)
    if include_time:
        symb_h_in.append(sp.Symbol('t'))

    if verbose:
        print()
        print("Fitting H_Net...")

    if fit_orig:
        print("Fit orig HNet")

    h_outputs, exec_time_h = fit_kan(
        h_acts,
        h_preacts,
        kan_masks_mult=h_masks,
        symb_xs=symb_h_in,
        model_path=h_dir,
        **shared_fit_args,
        **kwargs
    )
    symb_h = h_outputs[0]  # Univariate functions

    if message_passing:
        out_formula = symb_h
    else:
        out_formula = symb_h + aggr_term

    return out_formula, symb_g, symb_h, (exec_time_g + exec_time_h) / 2


def fit_black_box(cached_input, cached_output, symb_xs, pysr_model = None, sample_size=-1, verbose = False):
    """
    Fit a black-box model symbolically with PySR from its cached input/output.

    The best-scoring equation is parsed with sympy and its placeholder symbols
    `x0, x1, ...` are replaced with the entries of `symb_xs`.

    Returns:
        (simplified expression,
         DataFrame with the "complexity", "loss", "score" and "sympy_format" columns
         of the top equations,
         execution time in seconds, measured before the final simplification)
    """
    start_time = time.time()

    n_inputs = cached_input.size(1)
    n_outputs = cached_output.size(1)
    x = cached_input.detach().cpu().numpy().reshape(-1, n_inputs)
    y = cached_output.detach().cpu().numpy().reshape(-1, n_outputs)

    top_5_eq = fit_acts_pysr(x, y, pysr_model=pysr_model, sample_size=sample_size)
    best_equation = top_5_eq["sympy_format"].iloc[0]

    # Map the generic PySR symbols onto the caller's symbolic inputs.
    placeholders = {sp.Symbol(f'x{idx}'): target for idx, target in enumerate(symb_xs)}
    symb_func = sp.sympify(best_equation).subs(placeholders)

    exec_time = time.time() - start_time
    if verbose:
        print(f"Execution time: {exec_time:.6f} seconds")

    return sp.simplify(symb_func), top_5_eq[["complexity", "loss", "score", "sympy_format"]], exec_time



def fit_mpnn(model_path, device='cpu', pysr_model = None, sample_size=-1, message_passing=True, include_time=False, verbose=False, exploit_graph_struct = True):
    """
    Symbolically fit a black-box MPNN (G_Net, then H_Net) from its cached input/output.

    Cached tensors are read from `{model_path}/{g_net,h_net}/cached_data/` with
    `torch.load(..., weights_only=False)` (full unpickling: trusted files only).
    The top equations are written to `{model_path}/top_5_equations_{g,h}.csv`.
    When `exploit_graph_struct` is False, G_Net is not fitted and its formula is 0.

    Returns:
        (final formula, G_Net formula, H_Net formula, mean of the two execution times).
        Without message passing the final formula is the H_Net formula plus the
        aggregation symbol.
    """
    def load_cached_io(net_name):
        # Cached (input, output) pair of one of the two networks.
        cache_dir = f'{model_path}/{net_name}/cached_data'
        net_input = torch.load(f'{cache_dir}/cached_input', weights_only=False, map_location=torch.device(device))
        net_output = torch.load(f'{cache_dir}/cached_output', weights_only=False, map_location=torch.device(device))
        return net_input, net_output

    if not exploit_graph_struct:
        symb_g = sp.S(0.)
        exec_time_g = 0
    else:
        # G_Net
        g_input, g_output = load_cached_io('g_net')
        if verbose:
            print("Fitting G_Net...")

        symb_g, top_5_eqs_g, exec_time_g = fit_black_box(
            g_input, g_output,
            symb_xs=[sp.Symbol('x_i'), sp.Symbol('x_j')],
            pysr_model=pysr_model,
            sample_size=sample_size,
            verbose=verbose
        )

        top_5_eqs_g.to_csv(f"{model_path}/top_5_equations_g.csv")

    # H_Net
    h_input, h_output = load_cached_io('h_net')

    # The aggregated message is represented by a single opaque symbol.
    aggr_term = sp.Symbol(r'\sum_{j}( ' + str(symb_g) + ')')
    symb_h_in = [sp.Symbol('x_i')]
    if message_passing:
        symb_h_in.append(aggr_term)
    if include_time:
        symb_h_in.append(sp.Symbol('t'))

    if verbose:
        print()
        print("Fitting H_Net...")

    symb_h, top_5_eqs_h, exec_time_h = fit_black_box(
        h_input,
        h_output,
        symb_xs=symb_h_in,
        pysr_model=pysr_model,
        sample_size=sample_size,
        verbose=verbose
    )

    top_5_eqs_h.to_csv(f"{model_path}/top_5_equations_h.csv")

    if message_passing:
        out_formula = symb_h
    else:
        out_formula = symb_h + aggr_term

    return out_formula, symb_g, symb_h, (exec_time_g + exec_time_h)/2


def fit_black_box_from_kan(
    model_path, 
    depth_g, 
    depth_h, 
    device='cpu', 
    theta=0.1, 
    pysr_model = None, 
    sample_size=-1,
    message_passing=True,
    include_time=False,
    verbose=False
    ):
    """
    Fit the two KANs as black boxes with PySR, using their cached activations.

    For each network the cached layers are loaded and pruned; the black-box input is
    the pre-activation of the first layer and the output is the activation of the last
    layer summed over its input dimension. The top equations are written to
    `{model_path}/black-box/top_5_equations_{g,h}.csv`.

    Args:
        - model_path : Directory containing the `g_net` and `h_net` caches
        - depth_g, depth_h : Number of layers of the two networks
        - device : Currently unused by this function
        - theta : Pruning threshold
        - pysr_model, sample_size, verbose : Forwarded to `fit_black_box`
        - message_passing : Whether the aggregated G_Net term is an input of H_Net
        - include_time : Whether a `t` symbol is appended to the H_Net inputs

    Returns:
        (final formula, G_Net formula, H_Net formula, mean of the two execution times)
    """
    #G_Net
    cache_acts, cache_preacts, cached_masks_mult = get_kan_arch(n_layers=depth_g, model_path=f'{model_path}/g_net')
    pruned_acts, pruned_preacts, _ = pruning(cache_acts, cache_preacts, cached_masks_mult, theta=theta, verbose=verbose)

    # Named net_input / net_output so the builtin `input` is not shadowed.
    net_input = pruned_preacts[0]
    net_output = pruned_acts[-1].sum(dim=2)

    if verbose:
        print("Fitting G_Net...")

    symb_g, top_5_eqs_g, exec_time_g = fit_black_box(
        net_input,
        net_output,
        symb_xs=[sp.Symbol('x_i'), sp.Symbol('x_j')],
        pysr_model=pysr_model,
        sample_size=sample_size,
        verbose=verbose
    )

    save_path = f"{model_path}/black-box"
    # exist_ok avoids the check-then-create race of `if not exists: makedirs`.
    os.makedirs(save_path, exist_ok=True)

    top_5_eqs_g.to_csv(f"{save_path}/top_5_equations_g.csv")

    #H_Net
    cache_acts, cache_preacts, cached_masks_mult = get_kan_arch(n_layers=depth_h, model_path=f'{model_path}/h_net')
    pruned_acts, pruned_preacts, _ = pruning(cache_acts, cache_preacts, cached_masks_mult, theta=theta, verbose=verbose)

    net_input = pruned_preacts[0]
    net_output = pruned_acts[-1].sum(dim=2)

    # The aggregated message is represented by a single opaque symbol.
    aggr_term = sp.Symbol(r'\sum_{j}( ' + str(symb_g) + ')')

    if message_passing:
        symb_h_in = [sp.Symbol('x_i'), aggr_term]
    else:
        symb_h_in = [sp.Symbol('x_i')]

    if include_time:
        symb_h_in += [sp.Symbol('t')]

    if verbose:
        print()
        print("Fitting H_Net...")

    symb_h, top_5_eqs_h, exec_time_h = fit_black_box(
        net_input,
        net_output,
        symb_xs=symb_h_in,
        pysr_model=pysr_model,
        sample_size=sample_size,
        verbose=verbose
    )

    top_5_eqs_h.to_csv(f"{save_path}/top_5_equations_h.csv")

    out_formula = symb_h if message_passing else symb_h + aggr_term

    return out_formula, symb_g, symb_h, (exec_time_g + exec_time_h)/2


def quantise(expr, quantise_to=0.01, is_removing=False):
    """
    Recursively quantise the floating-point constants of a sympy expression.

    Args:
        - expr : Sympy expression to process
        - quantise_to : Quantisation step; floats are rounded to its nearest multiple
        - is_removing : When True, floats keep their original value unless they
          round to zero (i.e. only negligible constants are zeroed out)

    Returns:
        The rebuilt expression. Symbols named `\\sum_{...}( <expr> )` have their inner
        expression parsed, quantised and written back into the symbol name; if the
        inner text cannot be parsed the symbol is returned unchanged.
    """
    if isinstance(expr, sympy.Float):
        quant = expr.func(round(float(expr) / quantise_to) * quantise_to)
        if is_removing and abs(quant) > 0.0:
            return expr
        return quant

    if isinstance(expr, (sympy.Symbol, sympy.Integer)):
        name = str(expr)
        match = re.match(r'\\sum_\{[^}]*\}\((.*)\)', name)
        if not match:
            return expr
        try:
            # Convert the inner string to a sympy expression and quantise it.
            inner_expr = sympy.sympify(match.group(1))
            quantised_inner = quantise(inner_expr, quantise_to, is_removing)
        except (sympy.SympifyError, SyntaxError):
            return expr  # If parsing fails, return the symbol as-is
        # A callable replacement is inserted literally, so backslashes in the
        # quantised text are never interpreted as regex template escapes.
        new_name = re.sub(r'\(.*\)', lambda _match: f'({quantised_inner})', name)
        return sympy.Symbol(new_name)

    if not expr.args:
        # Any other argument-less node (e.g. a Rational or a named constant) has
        # nothing to quantise; rebuilding it through `expr.func()` with no arguments
        # is not valid for every such class.
        return expr

    return expr.func(*[quantise(arg, quantise_to, is_removing) for arg in expr.args])
    
                
