import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K
from spektral.layers import ops
from spektral.utils import laplacian, sp_matrix_to_sp_tensor
from magni.src.modules.data import make_dataset
from magni.src.modules.logging import logdir
from magni.src.modules.losses import quadratic_loss_tf
from magni.src.modules.transforms import normalize_point_cloud
from magni.src.modules.utils import to_numpy


from magni.src.modules.structure_difference import wasserstein_distance_eigenvalues, wasserstein_distance_eigenvalues_power
from magni.src.modules.train_mag_utils import prepare_mag, compute_mag_diff

import os
import json

# Device selection, executed once at import time.
# With `cpu` set to True, GPUs are hidden from TensorFlow and tf.function-decorated
# code is forced to run eagerly; otherwise memory growth is enabled on the first
# GPU, if any is found.
cpu=True
if cpu:
    # Hide CUDA devices at the driver level as well as from TensorFlow itself.
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    tf.config.experimental.set_visible_devices([], 'GPU') # Disable GPU
    # Run tf.function bodies eagerly instead of tracing them into graphs.
    tf.config.run_functions_eagerly(True)
    physical_devices = tf.config.list_physical_devices("CPU")
else:
    physical_devices = tf.config.list_physical_devices("GPU")
    if len(physical_devices) > 0:
        # Only the first listed GPU is configured for memory growth.
        tf.config.experimental.set_memory_growth(physical_devices[0], True)

def laplacian_tf(A):
    """
    Compute the unnormalized Laplacian D - A of an adjacency matrix.
    :param A: adjacency matrix, given as either a sparse or a dense tensor
    :return: result of tf.sparse.add(D, -A), where D is the degree matrix of A
    """
    degree = ops.degree_matrix(A, return_sparse_batch=True)

    # Sparse tensors are negated through scalar multiplication, dense tensors
    # through unary minus.
    negated_adjacency = A * -1 if K.is_sparse(A) else -A

    return tf.sparse.add(degree, negated_adjacency)

def compute_loss(X, A, L, model):
    """
    Pool the graph with the model and score the result with the quadratic loss.
    :param X: node features
    :param A: adjacency matrix
    :param L: laplacian matrix
    :param model: GNN with pooling layer; called as model([X, A]) and expected
        to return exactly three outputs, of which the third is ignored here
    :return: value of quadratic_loss_tf(X, X_pool, L, L_pool)
    """
    pooled_features, pooled_adjacency, _ = model([X, A])
    pooled_laplacian = laplacian_tf(pooled_adjacency)
    loss_value = quadratic_loss_tf(X, pooled_features, L, pooled_laplacian)
    return loss_value


def compute_loss_final(X, A, L, model):
    """
    Evaluate the quadratic loss together with the eigenvalue-based distances
    between the original and the pooled graph.
    :param X: node features
    :param A: adjacency matrix
    :param L: laplacian matrix
    :param model: GNN with pooling layer; called as model([X, A]) and expected
        to return exactly three outputs, of which the third is ignored here
    :return: 5-tuple (quadratic loss, wd, sp, wd_n, sp_n), where (wd, sp) is the
        pair returned by wasserstein_distance_eigenvalues and (wd_n, sp_n) the
        pair returned by wasserstein_distance_eigenvalues_power with power=-0.5
    """
    pooled_features, pooled_adjacency, _ = model([X, A])
    pooled_laplacian = laplacian_tf(pooled_adjacency)

    # Distances computed from the two Laplacians alone.
    wd, sp = wasserstein_distance_eigenvalues(L, pooled_laplacian)

    # Variant that additionally receives both adjacency matrices and a power.
    wd_n, sp_n = wasserstein_distance_eigenvalues_power(
        L, pooled_laplacian, A, pooled_adjacency, power=-0.5
    )

    quadratic = quadratic_loss_tf(X, pooled_features, L, pooled_laplacian)
    return quadratic, wd, sp, wd_n, sp_n

def load_data(name):
    """
    Load a dataset and build the graph signal used in the experiment.
    :param name: dataset name, forwarded to make_dataset (with seed=0)
    :return: tuple (X, A, L): the graph signal, the adjacency matrix returned by
        make_dataset, and the laplacian computed from it
    """
    adjacency, features, _ = make_dataset(name, seed=0)
    features = normalize_point_cloud(features)
    lap = laplacian(adjacency)

    # eigh returns (eigenvalues, eigenvectors); only the eigenvectors are used.
    eigenvectors = np.linalg.eigh(lap.toarray())[1]

    # The final graph signal is the normalized node features concatenated with
    # the first 10 eigenvector columns of the laplacian.
    signal = np.concatenate([features, eigenvectors[:, :10]], axis=-1)

    return signal, adjacency, lap


def main(X, A, L, create_model, learning_rate, es_patience, es_tol):
    """
    Train a pooling model on a single graph with early stopping.
    :param X: node features
    :param A: adjacency matrix
    :param L: laplacian matrix
    :param create_model: callable invoked as create_model(k=..., ratio=0.5) that
        returns the model to train; k is half the number of nodes, rounded up
    :param learning_rate: learning rate of the Adam optimizer
    :param es_patience: number of consecutive non-improving epochs after which
        training stops; values <= 0 stop at the first non-improving epoch
    :param es_tol: minimum decrease of the loss that counts as an improvement
    :return: the trained model, restored to the weights saved at the best epoch
        (left at its final weights if no epoch ever improved on the initial
        infinite loss, e.g. because the loss was NaN)
    """
    K.clear_session()

    # Create model and set up training
    model = create_model(k=int(np.ceil(X.shape[-2] / 2)), ratio=0.5)
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    @tf.function
    def train_step(X, A):
        with tf.GradientTape() as tape:
            main_loss = compute_loss(X, A, L, model)  # Main loss
            loss_value = main_loss + sum(model.losses)  # Auxiliary losses of the model

        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        # Only the main loss drives early stopping; auxiliary losses are excluded.
        return main_loss

    # Fit model
    patience = es_patience
    best_loss = np.inf
    best_weights = None
    ep = 0
    while True:
        ep += 1
        loss = train_step(X, A)

        if loss + es_tol < best_loss:
            best_loss = loss
            patience = es_patience
            best_weights = model.get_weights()
            print("Epoch {} - New best loss: {:.4e}".format(ep, best_loss))
        else:
            patience -= 1
            # `<=` rather than `==`: with es_patience <= 0 the counter goes
            # negative and an equality test would never terminate the loop.
            if patience <= 0:
                break

    # best_weights stays None when no epoch ever improved (e.g. NaN loss on the
    # first step); in that case there is nothing to restore.
    if best_weights is not None:
        model.set_weights(best_weights)
    return model

def run_experiment(
    name, method, create_model, learning_rate, es_patience, es_tol, runs, ratio=0.5
):
    """
    Train a pooling model `runs` times on one dataset and save the results.

    Writes two files into the directory returned by logdir(name): a JSON file
    with the per-run metrics and an .npz file with the original and pooled
    graph produced by the model of the last run.
    :param name: dataset name, forwarded to load_data and logdir
    :param method: label of the pooling method, used in the output file names
        and stored in the JSON file
    :param create_model: model factory forwarded to main
    :param learning_rate: learning rate forwarded to main
    :param es_patience: early-stopping patience forwarded to main
    :param es_tol: early-stopping tolerance forwarded to main
    :param runs: number of independent training runs
    :param ratio: value recorded in the JSON file. NOTE(review): it is not
        forwarded to main, which builds the model with its own ratio — confirm
        that the logged value matches the one actually used.
    :return: None
    """
    log_dir = logdir(name)

    # Load data
    X, A, L = load_data(name)

    X = tf.convert_to_tensor(X.astype("f4"))
    A = sp_matrix_to_sp_tensor(A.astype("f4"))
    L = sp_matrix_to_sp_tensor(L.astype("f4"))

    mag_dict, all_distfns, methods_mag, dist_names = prepare_mag(A, X, L, mag_fun=False)

    # Per-run metric accumulators
    results = []
    results_wasserstein = []
    results_wasserstein_normalized = []
    results_spectral = []
    results_spectral_normalized = []
    all_mag_results = {ind: [] for ind, _ in enumerate(all_distfns)}

    # Stays None when runs == 0, so the export below can be skipped cleanly.
    model = None

    # Run main
    for r in range(runs):
        print("{} of {}".format(r + 1, runs))
        model = main(X, A, L, create_model, learning_rate, es_patience, es_tol)

        loss1, loss2, loss3, loss4, loss5 = compute_loss_final(X, A, L, model)

        # Convert EagerTensors to Python types
        results.append(float(loss1.numpy()))
        results_wasserstein.append(float(loss2))
        results_spectral.append(float(loss3))
        results_wasserstein_normalized.append(float(loss4))
        results_spectral_normalized.append(float(loss5))

        for ind, dist_fn in enumerate(all_distfns):
            mag, ts = mag_dict[ind]
            mag_method = methods_mag[ind]
            mag_diff = compute_mag_diff(X, A, L, model, dist_fn, ts, mag, method=mag_method)
            all_mag_results[ind].append(float(mag_diff))

    experiment_data = {
        "dataset": name,
        "method": method,
        "experiment": "spectral_similarity",
        "runs": runs,
        "ratio": ratio,
        "learning_rate": learning_rate,
        "es_patience": es_patience,
        "es_tol": es_tol,
        "loss": "quadratic_spectral_loss",
        "results": {
            "loss": results,
            "wasserstein_distance": results_wasserstein,
            "wasserstein_distance2": results_wasserstein_normalized,
            "spectral_distance": results_spectral,
            "spectral_distance2": results_spectral_normalized,
        }
    }

    # Every method other than "spread" is reported under the "mag" prefix.
    for ind, _ in enumerate(all_distfns):
        prefix = "spread" if methods_mag[ind] == "spread" else "mag"
        experiment_data["results"][f"{prefix}_diff_{dist_names[ind]}"] = all_mag_results[ind]

    # Save the experimental data to a JSON file
    json_file_path = os.path.join(log_dir, f"{method}_{name}_experiment.json")

    with open(json_file_path, "w") as json_file:
        json.dump(experiment_data, json_file, indent=4)
    print(f"Experiment data saved to {json_file_path}")

    if model is None:
        # No run was executed, so there is no trained model to export.
        return

    # Run trained model (from the last run) to get pooled graph
    X_pool, A_pool, S = model([X, A])

    # Convert selection mask to selection matrix
    S = to_numpy(S)
    if S.ndim == 1:
        S = np.eye(S.shape[0])[:, S.astype(bool)]

    # Save data for plotting. The path is built with os.path.join, like the JSON
    # path above, so both files land in log_dir whether or not it ends with a
    # path separator.
    np.savez(
        os.path.join(log_dir, "{}_{}_matrices.npz".format(method, name)),
        X=to_numpy(X),
        A=to_numpy(A),
        X_pool=to_numpy(X_pool),
        A_pool=to_numpy(A_pool),
        S=S,
    )