
from magni.src.modules.compute_graph_magnitude import compute_magnitude_subgraphs
import networkx as nx
from magni.src.modules.compute_graph_magnitude import compute_magnitude_graph
from magni.src.modules.compare_graphs import choose_graph_metric
from magni.src.modules.utils import to_numpy
import numpy as np
import os
import json
import pandas as pd
from magni.src.modules.compute_graph_magnitude import median_heuristic

def prepare_mag(A, X, L, mag_fun=False):
    """
    Compute reference magnitudes of a graph for every method/metric pair.

    A networkx graph is built from ``A`` and each row of ``X`` is stored on
    the matching node under the ``"feature"`` attribute. The magnitude is
    then computed with the "cholesky" and "spread" methods, each over its
    list of structural distance metrics ("spread" additionally uses
    ``"shortest_path_distance"``).

    :param A: adjacency matrix; converted with ``to_numpy`` before use
    :param X: node features, indexed by row and exposing ``shape[0]``
    :param L: not used by this function
    :param mag_fun: if truthy, ``compute_magnitude_graph`` is asked for 10
        scales and no ``ts`` is passed; otherwise a single scale is passed
        (``1`` for the diffusion distance, the ``median_heuristic`` result
        for every other metric)
    :return: four parallel lists ``(mag_dict, all_distfns, methods_mag,
        dist_names)``; despite its name ``mag_dict`` is a list of
        ``[mag, ts]`` pairs
    """
    graph = nx.from_numpy_array(to_numpy(A))
    nx.set_node_attributes(
        graph,
        {node: {"feature": X[node]} for node in range(X.shape[0])},
    )

    shared_metrics = [
        "diffusion_distance",
        "heat_kernel_distance",
        "resistance_distance",
    ]
    # Insertion order matters: results are reported for "cholesky" first.
    metrics_by_method = {
        "cholesky": shared_metrics,
        "spread": shared_metrics + ["shortest_path_distance"],
    }

    mag_dict, all_distfns, methods_mag, dist_names = [], [], [], []
    for mag_method, metric_names in metrics_by_method.items():
        for metric in metric_names:
            dist_fn = choose_graph_metric(metric, mode="structure")
            if mag_fun:
                # No scales supplied here; 10 are requested via n_ts.
                mag, ts = compute_magnitude_graph(
                    graph,
                    dist_fn=dist_fn,
                    get_weights=False,
                    method=mag_method,
                    n_ts=10,
                    scale_finding="convergence",
                )
            else:
                # Single evaluation scale: fixed at 1 for the diffusion
                # distance, otherwise taken from the median heuristic.
                scales = (
                    [1]
                    if metric == "diffusion_distance"
                    else [median_heuristic(dist_fn=dist_fn, G=graph)]
                )
                mag, ts = compute_magnitude_graph(
                    graph,
                    dist_fn=dist_fn,
                    ts=scales,
                    get_weights=False,
                    n_ts=1,
                    scale_finding="convergence",
                    method=mag_method,
                )

            mag_dict.append([mag, ts])
            all_distfns.append(dist_fn)
            methods_mag.append(mag_method)
            dist_names.append(metric)

    return mag_dict, all_distfns, methods_mag, dist_names

def compute_mag_diff(X, A, L, model, dist_fn, ts, mag, method):
    """
    Run ``model`` on the graph and return the resulting magnitude difference.

    ``model`` is called with ``[X, A]`` and must return exactly three values;
    the first two are used as the pooled features and pooled adjacency, the
    third is discarded. The comparison itself is delegated to
    ``compute_mag_diff_nt`` with ``L_pool`` set to ``None``.

    :param X: node features of the original graph
    :param A: adjacency of the original graph
    :param L: forwarded unchanged to ``compute_mag_diff_nt``
    :param model: callable accepting ``[X, A]``
    :param dist_fn: distance function forwarded to ``compute_mag_diff_nt``
    :param ts: scales forwarded to ``compute_mag_diff_nt``
    :param mag: reference magnitude forwarded to ``compute_mag_diff_nt``
    :param method: magnitude method forwarded to ``compute_mag_diff_nt``
    :return: the value returned by ``compute_mag_diff_nt``
    """
    pooled_features, pooled_adjacency, _ = model([X, A])
    return compute_mag_diff_nt(
        X,
        A,
        L,
        pooled_features,
        pooled_adjacency,
        None,
        dist_fn,
        ts,
        mag,
        method,
    )


def compute_mag_diff_nt(X, A, L, X_pool, A_pool, L_pool, dist_fn, ts, mag, method):
    """
    Compare a reference magnitude with the magnitude of a pooled graph.

    A networkx graph is built from ``A_pool``, the rows of ``X_pool`` are
    stored on its nodes under the ``"feature"`` attribute, and its magnitude
    is computed with ``compute_magnitude_subgraphs`` at the scales ``ts``.

    :param X: not used by this function
    :param A: not used by this function
    :param L: not used by this function
    :param X_pool: pooled node features, indexed by row, exposing ``shape[0]``
    :param A_pool: pooled adjacency; converted with ``to_numpy`` before use
    :param L_pool: not used by this function
    :param dist_fn: distance function passed to ``compute_magnitude_subgraphs``
    :param ts: scales passed to ``compute_magnitude_subgraphs``; the last
        entry normalises the result when ``mag`` has more than one value
    :param mag: reference magnitude value(s); must support ``len`` and
        subtraction with the pooled magnitude
    :param method: magnitude method passed to ``compute_magnitude_subgraphs``
    :return: sum of absolute differences between ``mag`` and the pooled
        magnitude, divided by ``ts[-1]`` when ``len(mag) > 1``
    """
    g_pool = nx.from_numpy_array(to_numpy(A_pool))
    # Pooled nodes must carry the pooled features. Indexing the original X
    # here (as before) attached the features of the first k original nodes.
    features_dict = {i: {"feature": X_pool[i]} for i in range(X_pool.shape[0])}
    nx.set_node_attributes(g_pool, features_dict)

    # The scales returned alongside the magnitude are not needed.
    msub, _ = compute_magnitude_subgraphs(
        g_pool, dist_fn=dist_fn, ts=ts, get_weights=False, method=method
    )

    abs_diff = np.sum(np.abs(mag - msub))
    if len(mag) > 1:
        # Several scales were evaluated: normalise by the last scale.
        return abs_diff / ts[-1]
    return abs_diff

def json_to_df(json_file):
    """
    Convert the ``"results"`` entry of a JSON file to a DataFrame.

    :param json_file: path to the JSON file; its top-level object must
        contain a ``"results"`` key
    :return: DataFrame built from the value stored under ``"results"``
    :raises KeyError: if the file has no ``"results"`` key
    """
    # JSON is UTF-8 by specification; be explicit so that decoding does not
    # depend on the platform's locale encoding.
    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)
    return pd.DataFrame(data["results"])

