import numpy as np
import tensorflow as tf
from spektral.data import Loader
from spektral.layers import ops
from tensorflow.keras.layers import Lambda

from magni.src.graph_classification.training import load_dataset, main
from magni.src.models.classifiers import MfreeModel, MfreeModelGIN
import os
import json
from magni.src.modules.loaders import MFreeLoader
import time
import pandas as pd
from magni.src.modules.dataset_to_structure_difference import graph_structure_eval_trainable

def downsampling(inputs):
    """Project node features through a selection/assignment matrix.

    ``inputs`` is an ``(X, S)`` pair. The result is ``ops.modal_dot(S, X)``
    with ``S`` transposed (``transpose_a=True``), i.e. ``S^T X``.
    NOTE(review): presumably S maps original nodes to pooled nodes, so the
    output holds the coarsened node features — confirm shapes against the
    pooling layer.
    """
    node_features, selection = inputs
    pooled = ops.modal_dot(selection, node_features, transpose_a=True)
    return pooled

# Device setup, executed once at import time.
physical_devices = tf.config.list_physical_devices("GPU")
if physical_devices:
    # TensorFlow requires the memory-growth setting to be the same on every
    # visible GPU; configuring only the first one fails on multi-GPU machines
    # ("Memory growth cannot differ between GPU devices").
    for _gpu in physical_devices:
        tf.config.experimental.set_memory_growth(_gpu, True)
else:
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    tf.config.experimental.set_visible_devices([], 'GPU')  # Disable GPU
    physical_devices = tf.config.list_physical_devices("CPU")

def run_experiments(runs, pooling, dataset_name, learning_rate, batch_size, patience, method, splits="stratified",
                    model_name = None, get_magdiff=False, save_predictions=True, ratio = 0.5):
    """Pool a graph-classification dataset, train ``runs`` times and log the results.

    Every split returned by ``load_dataset`` is coarsened in place with
    ``pooling``. A model built by the local ``create_model`` is then trained
    and evaluated ``runs`` times through ``main``. A JSON summary is written
    under ``./results/<dataset_name>/``, optionally together with per-run
    target/prediction CSVs. When magdiff is active, the output of
    ``graph_structure_eval_trainable`` on the pooled graphs from the first run
    is written as well.

    Args:
        runs: number of training repetitions; must be at least 1.
        pooling: callable invoked as ``pooling([x], [a], ratio=ratio)`` that
            returns four sequences ``(x, a, a_1, s)``; only their first
            element is kept on each graph.
        dataset_name: passed to ``load_dataset`` and used in output paths.
        learning_rate, batch_size, patience: forwarded to ``main``.
        method: forwarded to ``load_dataset``/``main`` and used in output paths.
        splits: split strategy forwarded to ``main`` and used in output paths.
        model_name: ``"GNN"`` or ``"GIN"``; ``create_model`` raises
            ``ValueError`` for any other value.
        get_magdiff: request pooled graphs from the first run and evaluate
            them. It is forced on for ENZYMES, PROTEINS, NCI1, IMDB-MULTI and
            MUTAG when ``model_name`` is GNN or GIN.
        save_predictions: also write per-run target/prediction CSV files.
        ratio: pooling ratio forwarded to ``pooling`` and ``main``.

    Returns:
        ``(avg_results, std_results)``: the mean and standard deviation over
        runs of ``[test_loss, test_accuracy]``.

    Raises:
        ValueError: if ``runs`` is smaller than 1.
    """
    # Fail before the expensive loading/pooling: with no runs the per-run
    # statistics below would divide by zero.
    if runs < 1:
        raise ValueError(f"runs must be at least 1, got {runs}")

    # Data
    dataset, dataset_te, data_va, dataset_tr = load_dataset(dataset_name, method)

    def pooling_transform(graph):
        # ``pooling`` operates on lists of graphs: wrap the single graph's
        # features/adjacency in lists, then unwrap the first result.
        graph.x, graph.a, graph.a_1, graph.s = pooling([graph.x], [graph.a], ratio=ratio)
        graph.x, graph.a, graph.a_1, graph.s = (
            graph.x[0],
            graph.a[0],
            graph.a_1[0],
            graph.s[0],
        )
        return graph

    time_start = time.time()
    for split in (dataset, dataset_te, data_va, dataset_tr):
        if split is not None:
            split.apply(pooling_transform)
    time_stop = time.time()

    if dataset is not None:
        if (dataset.name in ["ENZYMES", "PROTEINS", "NCI1", "IMDB-MULTI", "MUTAG"]) and (model_name in ["GNN", "GIN"]):
            get_magdiff = True

    time_restart = time.time()

    # Model
    def create_model(n_out, model_name=model_name, **kwargs):
        pool = Lambda(downsampling)
        if model_name == "GNN":
            model = MfreeModel(n_out, pool)
        elif model_name == "GIN":
            model = MfreeModelGIN(n_out, pool)
        else:
            raise ValueError(f"Unknown model: {model_name}")

        return model

    log_dir = f"./results/{dataset_name}/"
    os.makedirs(log_dir, exist_ok=True)

    # The default configuration (ratio 0.5, no magdiff) keeps untagged file
    # names; every other configuration gets a "_ratio_<r>" tag.
    default_naming = (ratio == 0.5) and (not get_magdiff)
    ratio_tag = "" if default_naming else f"_ratio_{str(round(ratio, 3))}"
    stem = f"{model_name}_{method}_{dataset_name}_{splits}"

    # Arguments shared by every call to ``main``.
    main_kwargs = dict(
        dataset_tr=dataset_tr,
        dataset_te=dataset_te,
        splits=splits,
        method=method,
        dataset_va=data_va,
        save_predictions=save_predictions,
        ratio=ratio,
    )

    results = []
    all_pooled_graphs = None
    for r in range(runs):
        print("{} of {}".format(r + 1, runs))
        if get_magdiff and r == 0:
            # Only the first run also returns the pooled graphs for the
            # magdiff evaluation.
            best_te_loss, best_te_acc, target_full, predictions_full, all_pooled_graphs = main(
                dataset,
                create_model,
                MFreeLoader,
                learning_rate,
                batch_size,
                patience,
                run=r,
                return_pooled_graphs=True,
                **main_kwargs,
            )
        else:
            best_te_loss, best_te_acc, target_full, predictions_full = main(
                dataset,
                create_model,
                MFreeLoader,
                learning_rate,
                batch_size,
                patience,
                run=r,
                **main_kwargs,
            )
        results.append([best_te_loss, best_te_acc])

        if save_predictions:
            predictions_file_path = os.path.join(log_dir, f"{stem}_predictions_{r}{ratio_tag}.csv")
            targets_file_path = os.path.join(log_dir, f"{stem}_targets_{r}{ratio_tag}.csv")
            pd.DataFrame(target_full).to_csv(targets_file_path, index=False)
            pd.DataFrame(predictions_full).to_csv(predictions_file_path, index=False)

    time_end = time.time()

    # Total wall time minus the bookkeeping between pooling and model setup.
    time_diff = time_end - time_start - (time_restart - time_stop)

    experiment_data = {
        "dataset": dataset_name,
        "method": method,
        "experiment": "graph_classification",
        "model": model_name,
        "runs": runs,
        "split_strategy": splits,
        "time": time_diff,
        "time_per_run": time_diff / runs,
        "learning_rate": learning_rate,
        "es_patience": patience,
        "batch_size": batch_size,
        "loss": "categorical_crossentropy",
        "results": {
            "loss": [float(r[0]) for r in results],
            "accuracy": [float(r[1]) for r in results],
        },
        "ratio": ratio
    }

    json_file_path = os.path.join(log_dir, f"{stem}{ratio_tag}.json")
    mag_file_path = os.path.join(log_dir, f"{stem}{ratio_tag}_mag.json")
    graphs_dir = "graphs" if default_naming else f"graphs_{str(round(ratio, 3))}"
    path_to_save_graphs = os.path.join(log_dir, f"{graphs_dir}/{model_name}_{method}_{dataset_name}/")

    all_mag_results = {}
    if get_magdiff:
        # Only ENZYMES and MUTAG get a graph output directory; for every other
        # dataset the evaluation receives None. The directory is created only
        # when it is actually handed over, so no empty directories are left.
        if dataset_name in ["ENZYMES", "MUTAG"]:
            os.makedirs(path_to_save_graphs, exist_ok=True)
        else:
            path_to_save_graphs = None
        all_mag_results["dataset"] = graph_structure_eval_trainable(all_pooled_graphs, path_to_save_graphs=path_to_save_graphs, nt=True)
        with open(mag_file_path, "w") as json_file:
            json.dump(all_mag_results, json_file, indent=4)
        print(f"Magni results saved to {mag_file_path}")

    with open(json_file_path, "w") as json_file:
        json.dump(experiment_data, json_file, indent=4)
    print(f"Experiment data saved to {json_file_path}")

    avg_results = np.mean(results, axis=0)
    std_results = np.std(results, axis=0)

    print(
        "{} - Test loss: {:.4f} +- {:.4f} - Test accuracy: {:.4f} +- {:.4f}".format(
            dataset_name, avg_results[0], std_results[0], avg_results[1], std_results[1]
        )
    )

    return avg_results, std_results

def results_to_file(dataset, method, avg_results, std_results):
    """Append one summary line for ``method`` to ``./results/<dataset>/<dataset>_result.csv``.

    The line is ``"<method>, <avg[0]> +- <std[0]>, <avg[1]> +- <std[1]>"``
    with four decimals. Index 0 is the loss and index 1 the accuracy, as in
    the pair returned by ``run_experiments``. The results directory is created
    if it does not exist yet, so the call no longer fails with
    FileNotFoundError on a fresh checkout.
    """
    result_dir = f"./results/{dataset}"
    os.makedirs(result_dir, exist_ok=True)
    filename = f"{result_dir}/{dataset}_result.csv"
    line = "{}, {:.4f} +- {:.4f}, {:.4f} +- {:.4f}\n".format(
        method, avg_results[0], std_results[0], avg_results[1], std_results[1]
    )
    with open(filename, "a") as f:
        f.write(line)
