import numpy as np
import tensorflow as tf
from spektral.utils import laplacian

from magni.src.modules.logging import logdir
from magni.src.modules.losses import quadratic_loss
from magni.src.modules.utils import to_numpy
from magni.src.spectral_similarity.training import load_data

from magni.src.modules.structure_difference import wasserstein_distance_eigenvalues, wasserstein_distance_eigenvalues_power
from magni.src.modules.train_mag_utils import prepare_mag, compute_mag_diff_nt

import os
import json

# Device selection for TensorFlow, applied at import time.
# cpu=True: hide all GPUs so everything runs on the CPU.
# cpu=False: use the GPUs, with on-demand ("memory growth") allocation.
cpu = True
if cpu:
    # Both settings only take effect if TensorFlow/CUDA have not initialised
    # their devices yet, which is why this runs at module import.
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    tf.config.experimental.set_visible_devices([], 'GPU')  # Disable GPU
    physical_devices = tf.config.list_physical_devices("CPU")
else:
    physical_devices = tf.config.list_physical_devices("GPU")
    # TensorFlow requires the memory-growth setting to be identical on every
    # GPU; enabling it only on the first device fails on multi-GPU machines.
    # With no GPUs the loop is a no-op.
    for gpu in physical_devices:
        tf.config.experimental.set_memory_growth(gpu, True)

def _as_selection_matrix(S):
    """Convert a pooling method's selection output into a NumPy selection matrix.

    ``S`` is first converted with ``to_numpy``. A 1-D result is treated as a
    per-node keep-mask and expanded into a matrix with one row per node and
    one column per kept node, where column ``j`` is the one-hot indicator of
    the ``j``-th kept node. Any other result is returned unchanged.
    """
    S = to_numpy(S)
    if S.ndim == 1:
        S = np.eye(S.shape[0])[:, S.astype(bool)]
    return S


def run_experiment(name, method, pooling, runs, ratio=0.5):
    """Evaluate a non-trainable pooling method and log graph-similarity metrics.

    For each of ``runs`` repetitions, the graph loaded for ``name`` is pooled with
    ``pooling`` (after seeding NumPy's global RNG with the run index). The pooled
    graph is compared against the original using ``quadratic_loss``, two
    Wasserstein eigenvalue distances, two spectral distances, and every
    magnitude/spread difference configured by ``prepare_mag``.

    Outputs:
        * ``<log_dir>/<method>_<name>_experiment.json``: metadata plus the
          per-run value lists of every metric.
        * ``<log_dir>/<method>_<name>_matrices.npz``: ``X``, ``A`` and the
          ``X_pool``, ``A_pool``, ``S`` of the last run.

    Args:
        name: Dataset name, passed to ``logdir`` and ``load_data``.
        method: Label of the pooling method, used in file names and the JSON.
        pooling: Callable ``pooling(X, A)`` returning a 4-tuple. Its first entry
            is ignored. The others are node features (reduced by ``S`` to the
            pooled features), the pooled adjacency, and a selection mask or
            selection matrix ``S``.
        runs: Number of repetitions; must be at least 1.
        ratio: Only recorded in the JSON metadata; it is not passed to
            ``pooling`` here.

    Raises:
        ValueError: If ``runs`` is less than 1.
    """
    if runs < 1:
        # The last run's matrices are saved at the end, so at least one run
        # is required; zero runs previously crashed only after writing the JSON.
        raise ValueError(f"runs must be at least 1, got {runs}")

    log_dir = logdir(name)

    # Load data
    X, A, L = load_data(name)

    mag_dict, all_distfns, methods_mag, dist_names = prepare_mag(A, X, L, mag_fun=False)

    # Per-run metric values.
    results = []
    results_wasserstein = []
    results_wasserstein_normalized = []
    results_spectral = []
    results_spectral_normalized = []
    # One list of per-run values for each configured magnitude distance.
    all_mag_results = {ind: [] for ind, _ in enumerate(all_distfns)}

    for r in range(runs):
        print(f"{r + 1} of {runs}")

        # NOTE(review): only NumPy's global RNG is seeded; if `pooling` draws
        # from another RNG (e.g. TensorFlow's), runs may not be reproducible.
        np.random.seed(r)
        _, X_new, A_pool, S = pooling(X, A)

        S = _as_selection_matrix(S)
        X_pool = S.T.dot(X_new)
        L_pool = laplacian(A_pool)
        loss = quadratic_loss(X, X_pool, L, L_pool)

        # The power=-0.5 variants are reported under the "...2" keys below.
        wd_normalized, sp_normalized = wasserstein_distance_eigenvalues_power(
            L, L_pool, A, A_pool, power=-0.5
        )
        wd, sp = wasserstein_distance_eigenvalues(L, L_pool)

        results_wasserstein.append(float(wd))
        results_wasserstein_normalized.append(float(wd_normalized))
        results_spectral.append(float(sp))
        results_spectral_normalized.append(float(sp_normalized))
        results.append(float(loss))

        for ind, dist_fn in enumerate(all_distfns):
            mag, ts = mag_dict[ind]
            mag_diff = compute_mag_diff_nt(
                X, A, L, X_pool, A_pool, L_pool, dist_fn, ts, mag,
                method=methods_mag[ind],
            )
            all_mag_results[ind].append(float(mag_diff))

    experiment_data = {
        "dataset": name,
        "method": method,
        "experiment": "spectral_similarity",
        "runs": runs,
        "ratio": ratio,
        "learning_rate": None,
        "es_patience": None,
        "es_tol": None,
        "loss": "not_trainable",
        "results": {
            "loss": results,
            "wasserstein_distance": results_wasserstein,
            "wasserstein_distance2": results_wasserstein_normalized,
            "spectral_distance": results_spectral,
            "spectral_distance2": results_spectral_normalized,
        },
    }

    for ind, _ in enumerate(all_distfns):
        # Any method other than "spread" is reported under the "mag" prefix.
        prefix = "spread" if methods_mag[ind] == "spread" else "mag"
        experiment_data["results"][f"{prefix}_diff_{dist_names[ind]}"] = all_mag_results[ind]

    # Save the experimental data to a JSON file
    json_file_path = os.path.join(log_dir, f"{method}_{name}_experiment.json")
    with open(json_file_path, "w", encoding="utf-8") as json_file:
        json.dump(experiment_data, json_file, indent=4)
    print(f"Experiment data saved to {json_file_path}")

    # Same directory as the JSON: os.path.join instead of string concatenation,
    # which dropped the separator when log_dir had no trailing slash.
    np.savez(
        os.path.join(log_dir, f"{method}_{name}_matrices.npz"),
        X=to_numpy(X),
        A=to_numpy(A),
        X_pool=to_numpy(X_pool),
        A_pool=to_numpy(A_pool),
        S=S,
    )