import numpy as np
import tensorflow as tf
from spektral.datasets import ModelNet, TUDataset, OGB
from spektral.transforms import OneHotLabels
from tensorflow.keras import backend as K
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.metrics import CategoricalAccuracy
from tensorflow.keras.optimizers import Adam
import json
from spektral.layers.base import GraphMasking

from magni.src.modules.transforms import Float, NormalizeSphere, RemoveEdgeFeats
import os 
from magni.src.modules.data import TorchDataset

from magni.src.modules.make_splits import make_splits 
from ogb.graphproppred import GraphPropPredDataset
import time
import pandas as pd
from torch_geometric.datasets import LRGBDataset
from magni.src.modules.dataset_to_structure_difference import graph_structure_eval_trainable
import pandas as pd
from magni.src.modules.make_splits import make_inner_split

# Device configuration, executed once at import time.
physical_devices = tf.config.list_physical_devices("GPU")
if physical_devices:
    # TensorFlow requires the memory-growth setting to be identical on every
    # visible GPU, so enable it on all of them rather than only on the first.
    for _gpu_device in physical_devices:
        tf.config.experimental.set_memory_growth(_gpu_device, True)
else:
    # No GPU found: hide CUDA devices and fall back to the CPU.
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    tf.config.experimental.set_visible_devices([], 'GPU')  # Disable GPU
    physical_devices = tf.config.list_physical_devices("CPU")
    # On CPU, tf.function-decorated code is executed eagerly.
    tf.config.run_functions_eagerly(True)
    

def load_dataset(name, method=""):
    """Load a graph dataset by name and normalise its graphs in place.

    Three sources are supported: Spektral's ``ModelNet`` (which ships its own
    test set), Spektral's ``TUDataset`` (no predefined split) and the OGB
    dataset ``ogbg-molhiv`` (predefined train/valid/test split).

    Every dataset that will be used downstream gets its edge features
    removed, scalar labels one-hot encoded, and its node features replaced by
    the row sums of the adjacency matrix when features are missing or have an
    inconsistent width across graphs.

    Args:
        name: Dataset identifier.
        method: Unused; kept for backward compatibility with callers.

    Returns:
        A 4-tuple ``(dataset, dataset_te, dataset_va, dataset_tr)``. Entries
        for which the source provides no predefined split are ``None``:
        ModelNet yields ``(dataset, dataset_te, None, None)``, TU datasets
        yield ``(dataset, None, None, None)``, and ``ogbg-molhiv`` yields all
        four entries.

    Raises:
        ValueError: If ``name`` is not a known dataset, or if the first graph
            has no adjacency matrix or no label.
    """
    dataset = dataset_tr = dataset_va = dataset_te = None

    if name in ModelNet.available_datasets():
        dataset = ModelNet(name, transforms=[NormalizeSphere(), Float()])
        dataset_te = ModelNet(name, test=True, transforms=[NormalizeSphere(), Float()])
        # Drop very large graphs from both the training and the test set.
        filter_at = 5000
        dataset.filter(lambda g: g.n_nodes <= filter_at)
        dataset_te.filter(lambda g: g.n_nodes <= filter_at)
        data_to_transform = [dataset, dataset_te]
    elif name in TUDataset.available_datasets():
        dataset = TUDataset(name, transforms=[Float()])
        data_to_transform = [dataset]
    elif name in ['ogbg-molhiv']:
        ogb_dataset = GraphPropPredDataset(name=name)
        dataset = OGB(ogb_dataset, transforms=[Float()])
        dataset.name = name
        # Keep only the first entry of each label array.
        for g in dataset:
            g.y = g.y[0]
        # Predefined train/valid/test split shipped with the OGB dataset.
        idx = ogb_dataset.get_idx_split()
        idx_tr, idx_va, idx_te = idx["train"], idx["valid"], idx["test"]
        dataset_tr = dataset[idx_tr]
        dataset_va = dataset[idx_va]
        dataset_te = dataset[idx_te]
        data_to_transform = [dataset_te, dataset_va, dataset_tr]
    else:
        raise ValueError(f"Unknown dataset {name}")

    # Report the shapes of the first graph and fail early on unusable data.
    print("Dataset name: ", name)
    first_graph = data_to_transform[0][0]
    if first_graph.a is not None:
        print("Adjacency matrix shape: ")
        print(first_graph.a.shape)
    else:
        print("Adjacency matrix is None!")
        raise ValueError("Adjacency matrix is None!")
    if first_graph.x is not None:
        print("Feature matrix shape: ")
        print(first_graph.x.shape)
    else:
        print("Feature matrix is None")
    if first_graph.y is not None:
        print("Label matrix shape: ")
        print(first_graph.y.shape)
        print(first_graph.y)
    else:
        print("Label matrix is None!")
        raise ValueError("Label matrix is None!")

    for dat in data_to_transform:
        dat.apply(RemoveEdgeFeats())

        if dat.n_labels == 1:
            # NOTE(review): the label set is collected per split, so a split
            # that lacks a class would get a narrower encoding -- confirm that
            # every split contains all classes.
            labels = dat.map(lambda g: g.y, reduce=np.unique)
            dat.apply(OneHotLabels(labels=labels))

        # NOTE(review): second application kept from the original code; it
        # looks redundant -- confirm before removing.
        dat.apply(RemoveEdgeFeats())

        # Count graphs whose feature width differs from the first graph's.
        dim_mismatch = 0
        if dat[0].x is not None:
            reference_width = dat[0].x.shape[1]
            for g in dat:
                if g.x.shape[1] != reference_width:
                    dim_mismatch += 1
                    print("Mismatch in dimensions")
            print("Graphs with different dimensions:", dim_mismatch)

        # Missing or inconsistent features: fall back to a single feature per
        # node, the row sum of the adjacency matrix.
        if dat[0].x is None or dim_mismatch > 0:
            for g in dat:
                g.x = np.sum(g.a, axis=1).reshape(-1, 1) if g.a is not None else np.zeros((g.n_nodes, 1))

    # The transforms above were applied in place, so the named variables can be
    # returned directly, always in the order the caller unpacks them. The
    # previous index-based return put the ModelNet training data in the
    # test-set slot and its test set in the validation slot.
    return dataset, dataset_te, dataset_va, dataset_tr

def split_dataset(dataset, dataset_tr=None, dataset_te=None, dataset_va=None, splits="random", run=0):
    """Derive the train/test/validation subsets that were not supplied.

    If all three of ``dataset_tr``, ``dataset_te`` and ``dataset_va`` are
    given they are returned untouched. Otherwise:

    * a predefined validation set makes the whole ``dataset`` the training
      set (``dataset_te`` is passed through as given);
    * a predefined test set makes ``dataset`` get divided into training and
      validation parts (80/20 for ``"random"``);
    * with nothing predefined, ``dataset`` is divided into training,
      validation and test parts (80/10/10 for ``"random"``).

    Args:
        dataset: Full dataset; must support ``len`` and indexing with an index
            array. For the non-random strategies its items must expose ``.y``.
        dataset_tr: Optional predefined training set.
        dataset_te: Optional predefined test set.
        dataset_va: Optional predefined validation set.
        splits: ``"random"``, ``"stratified"`` or ``"k_fold"``.
        run: Run index. Seeds the random and inner splits, and selects the
            outer fold (out of 10) for ``"stratified"`` / ``"k_fold"``.

    Returns:
        Tuple ``(dataset_tr, dataset_te, dataset_va)`` -- note the order.

    Raises:
        ValueError: If a split has to be computed and ``splits`` is not one of
            the supported strategies.
    """
    all_given = (
        dataset_tr is not None and dataset_te is not None and dataset_va is not None
    )
    if all_given:
        return dataset_tr, dataset_te, dataset_va

    if splits not in ("random", "stratified", "k_fold"):
        # Previously this fell through and returned None datasets, which only
        # failed later when the loaders were built.
        raise ValueError(
            f"Unknown split strategy {splits!r}; expected 'random', 'stratified' or 'k_fold'"
        )

    if splits == "random":
        np.random.seed(run)
        if dataset_va is not None:
            print("Using a predefined validation set")
            dataset_tr = dataset
        elif dataset_te is not None:
            n_graphs = len(dataset)
            order = np.random.permutation(n_graphs)
            idx_tr, idx_va = np.split(order, [int(0.8 * n_graphs)])
            dataset_tr = dataset[idx_tr]
            dataset_va = dataset[idx_va]
        else:
            n_graphs = len(dataset)
            order = np.random.permutation(n_graphs)
            cut_va, cut_te = int(0.8 * n_graphs), int(0.9 * n_graphs)
            idx_tr, idx_va, idx_te = np.split(order, [cut_va, cut_te])
            dataset_tr = dataset[idx_tr]
            dataset_va = dataset[idx_va]
            dataset_te = dataset[idx_te]
        return dataset_tr, dataset_te, dataset_va

    # "stratified" or "k_fold"
    if dataset_va is not None:
        print("Using a predefined validation set")
        dataset_tr = dataset
        return dataset_tr, dataset_te, dataset_va

    labels = np.array([graph.y for graph in dataset])
    # One-hot (or otherwise multi-column) labels are reduced to class indices.
    if labels.ndim > 1 and labels.shape[1] > 1:
        labels = np.argmax(labels, axis=1).astype(int)

    indices = np.arange(len(dataset))
    # "k_fold" switches stratification off explicitly; "stratified" relies on
    # the helpers' default behaviour.
    strategy_kwargs = {} if splits == "stratified" else {"stratify": False}

    if dataset_te is not None:
        np.random.seed(run)
        idx_tr, idx_va = make_inner_split(indices, labels, seed=run, **strategy_kwargs)
        dataset_tr = dataset[idx_tr]
        dataset_va = dataset[idx_va]
    else:
        # The outer folds use a fixed seed so every run sees the same
        # partition; ``run`` only selects which fold is used.
        split_list = make_splits(
            indices, labels, outer_k=10, inner_k=None, holdout_test_size=0.1, seed=42,
            **strategy_kwargs
        )
        idx_tr, idx_va, idx_te = split_list[run][0], split_list[run][1], split_list[run][2]
        dataset_tr = dataset[idx_tr]
        dataset_va = dataset[idx_va]
        dataset_te = dataset[idx_te]

    return dataset_tr, dataset_te, dataset_va


def main(
    dataset,
    create_model,
    loader_class,
    learning_rate,
    batch_size,
    patience,
    dataset_te=None,
    splits="random",
    run = 0,
    method=None,
    dataset_tr=None,
    dataset_va=None,
    save_predictions=False,
    ratio = 0.5,
    return_pooled_graphs=False,
    #mask=False,
):
    """Train one model with early stopping and report its test performance.

    The test loss/accuracy reported are those measured at the epoch with the
    best (lowest) validation loss.

    Args:
        dataset: Full dataset, or ``None`` when ``dataset_tr`` is given.
        create_model: Callable ``create_model(n_out, k=..., ratio=..., mask=...)``
            returning the model to train.
        loader_class: Loader class, called as ``loader_class(ds, batch_size=...)``.
        learning_rate: Learning rate for the Adam optimizer.
        batch_size: Batch size for the train/validation/test loaders.
        patience: Number of consecutive non-improving epochs tolerated before
            training stops.
        dataset_te: Optional predefined test set.
        splits: Split strategy forwarded to ``split_dataset``.
        run: Run index; seeds NumPy/TensorFlow and is forwarded to
            ``split_dataset``.
        method: Unused; kept for backward compatibility with callers.
        dataset_tr: Optional predefined training set.
        dataset_va: Optional predefined validation set.
        save_predictions: If true, collect test-set targets and predictions
            each time the validation loss improves.
        ratio: Pooling ratio; ``k`` is ``int(ratio * N_avg)`` where ``N_avg``
            is the ceiling of the mean number of nodes per graph.
        return_pooled_graphs: If true, also return the outputs of
            ``model._pooling`` for every train, test and validation graph.

    Returns:
        ``(best_te_loss, best_te_acc, target_full, predictions_full)``, with
        ``all_pooled_graphs`` appended as a fifth element when
        ``return_pooled_graphs`` is true. ``target_full`` and
        ``predictions_full`` are ``None`` unless ``save_predictions`` is true
        and the validation loss improved at least once.
    """
    K.clear_session()
    if dataset is not None:
        N_avg = dataset.map(lambda g: g.n_nodes, reduce=lambda res: np.ceil(np.mean(res)))
        n_out = dataset.n_labels
    else:
        N_avg = dataset_tr.map(lambda g: g.n_nodes, reduce=lambda res: np.ceil(np.mean(res)))
        n_out = dataset_tr.n_labels

    k = int(ratio * N_avg)

    dataset_tr,  dataset_te, dataset_va = split_dataset(dataset=dataset, dataset_te=dataset_te, dataset_va=dataset_va, dataset_tr=dataset_tr,  splits=splits, run=run)

    # Create loaders
    loader_tr = loader_class(dataset_tr, batch_size=batch_size)
    loader_va = loader_class(dataset_va, batch_size=batch_size)
    loader_te = loader_class(dataset_te, batch_size=batch_size)

    # Loaders exposing a ``mask`` attribute get masking switched on.
    use_mask = hasattr(loader_tr, "mask")
    if use_mask:
        loader_tr.mask = True
        loader_va.mask = True
        loader_te.mask = True

    ## set the random seed for reproducibility
    np.random.seed(run)
    tf.random.set_seed(run)

    # Model
    model = create_model(n_out, k=k, ratio=ratio, mask=use_mask)

    # Training
    opt = Adam(learning_rate=learning_rate)
    loss_fn = CategoricalCrossentropy()

    def acc_fn(target, predictions):
        # A Keras metric object (e.g. CategoricalAccuracy) is stateful: each
        # call accumulates over all earlier calls, which would blend training,
        # validation and test batches into a single running accuracy. Use the
        # stateless per-sample function and average it over the batch instead.
        return tf.reduce_mean(
            tf.keras.metrics.categorical_accuracy(target, predictions)
        )

    input_signature = loader_tr.tf_signature()

    @tf.function(input_signature=input_signature, experimental_relax_shapes=True)
    def train_step(inputs, target):
        with tf.GradientTape() as tape:
            predictions = model(inputs, training=True)
            loss_value = loss_fn(target, predictions)  # Main loss
            loss_value += sum(model.losses)  # Auxiliary losses of the model
            acc_value = acc_fn(target, predictions)
        grads = tape.gradient(loss_value, model.trainable_variables)
        opt.apply_gradients(zip(grads, model.trainable_variables))
        return loss_value, acc_value

    @tf.function(input_signature=input_signature, experimental_relax_shapes=True)
    def _eval_step(inputs, target):
        predictions = model(inputs, training=False)
        return loss_fn(target, predictions), acc_fn(target, predictions)

    def _predict_step(inputs):
        return model(inputs, training=False)

    def _pooling_step(inputs):
        return model._pooling(inputs, training=False)

    def pool_graphs(dataset, batch_size, mask=False):
        """Run ``model._pooling`` over one pass of ``dataset``; return ``[outs, target]`` pairs."""
        loader = loader_class(dataset, batch_size=batch_size)
        if mask:
            loader.mask = True

        pooled_graphs = []
        for _ in range(loader.steps_per_epoch):
            inputs, target = next(loader)
            pooled_graphs.append([_pooling_step(inputs), target])
        return pooled_graphs

    def evaluate(loader):
        """Return the mean (loss, accuracy) over one epoch of ``loader``."""
        output = []
        for _ in range(loader.steps_per_epoch):
            inputs, target = next(loader)
            output.append(_eval_step(inputs, target))
        return np.mean(output, 0)

    def get_predictions(dataset, batch_size, mask=False):
        """Return concatenated (targets, predictions) for one pass over ``dataset``."""
        loader = loader_class(dataset, batch_size=batch_size)
        if mask:
            loader.mask = True
        target_full = []
        predictions_full = []

        for _ in range(loader.steps_per_epoch):
            inputs, target = next(loader)
            target_full.append(target)
            predictions_full.append(_predict_step(inputs))

        target_full = tf.concat(target_full, axis=0)
        predictions_full = tf.concat(predictions_full, axis=0)
        return target_full, predictions_full

    # Training loop
    epoch = step = model_loss = model_acc = best_va_acc = 0
    best_va_loss = np.inf
    best_te_loss = best_te_acc = None
    # Defined up front so the return below is valid even if the validation
    # loss never improves (e.g. it is NaN from the first epoch on).
    target_full = predictions_full = None
    patience_remain = patience
    es_tol = 1e-6

    # NOTE(review): assumes the loader keeps yielding batches across epochs
    # until the early-stopping break below -- confirm for custom loaders.
    for batch in loader_tr:
        outs = train_step(*batch)
        model_loss += outs[0]
        model_acc += outs[1]
        step += 1
        if step == loader_tr.steps_per_epoch:
            epoch += 1
            model_loss /= loader_tr.steps_per_epoch
            model_acc /= loader_tr.steps_per_epoch

            # Compute validation loss and accuracy
            va_loss, va_acc = evaluate(loader_va)

            # Check if loss improved for early stopping
            if va_loss + es_tol < best_va_loss:
                te_loss, te_acc = evaluate(loader_te)
                print(
                    "Epoch {} acc: {:.4f} va_acc: {:.4f} te_acc: {:.4f}".format(
                        epoch, model_acc, va_acc, te_acc
                    )
                )
                best_va_loss = va_loss
                best_va_acc = va_acc
                best_te_loss = te_loss
                best_te_acc = te_acc
                patience_remain = patience

                if save_predictions:
                    target_full, predictions_full = get_predictions(dataset_te, 1, mask=use_mask)
                else:
                    target_full, predictions_full = None, None

            else:
                patience_remain -= 1
                # ``<=`` rather than ``==`` so a non-positive ``patience``
                # still terminates instead of looping forever.
                if patience_remain <= 0:
                    break
            model_loss = model_acc = step = 0

    if return_pooled_graphs:
        pooled_graphs = pool_graphs(dataset_tr, batch_size=1, mask=use_mask)
        pool_te = pool_graphs(dataset_te, batch_size=1, mask=use_mask)
        pool_va = pool_graphs(dataset_va, batch_size=1, mask=use_mask)
        all_pooled_graphs = pooled_graphs + pool_te + pool_va
        return best_te_loss, best_te_acc, target_full, predictions_full, all_pooled_graphs
    print("Loss: {} - Acc: {}".format(best_te_loss, best_te_acc))

    return best_te_loss, best_te_acc, target_full, predictions_full


def run_experiments(
    runs, create_model, loader_class, dataset_name, learning_rate, batch_size, patience, method, 
    splits="stratified", model_name = None, save_predictions = True, ratio = 0.5, get_magdiff=False,
):
    """Repeat training ``runs`` times and store the results on disk.

    Loads the dataset once, calls ``main`` for every run index, optionally
    writes per-run targets/predictions as CSV, and finally writes a JSON
    summary of the experiment to ``./results/<dataset_name>/``. When
    ``get_magdiff`` is enabled (explicitly, or automatically for the
    ENZYMES / NCI1 / IMDB-MULTI datasets combined with the GIN / GNN models),
    the pooled graphs of the first run are passed to
    ``graph_structure_eval_trainable`` and the result is stored in a separate
    ``*_mag.json`` file.

    Output file names carry a ``_ratio_<ratio>`` tag unless ``ratio`` is 0.5
    and ``get_magdiff`` is off.

    Returns:
        ``(avg_results, std_results)``: mean and standard deviation over runs
        of ``[test loss, test accuracy]``.
    """
    dataset, dataset_te, dataset_va, dataset_tr = load_dataset(dataset_name, method)

    log_dir = f"./results/{dataset_name}/"
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    magdiff_dataset = dataset_name in ["ENZYMES", "NCI1", "IMDB-MULTI"]
    magdiff_model = model_name in ["GIN", "GNN"]
    if magdiff_dataset and magdiff_model:
        print("Computing mag differences!")
        get_magdiff = True

    # Pieces shared by every output file name.
    file_stem = f"{model_name}_{method}_{dataset_name}_{splits}"
    default_naming = (ratio == 0.5) and (not get_magdiff)
    ratio_tag = str(round(ratio, 3))
    ratio_suffix = "" if default_naming else f"_ratio_{ratio_tag}"

    # Keyword arguments that are identical for every call to ``main``.
    shared_kwargs = dict(
        dataset_te=dataset_te,
        splits=splits,
        method=method,
        dataset_va=dataset_va,
        dataset_tr=dataset_tr,
        save_predictions=save_predictions,
        ratio=ratio,
    )

    time_start = time.time()
    results = []
    for r in range(runs):
        print("{} of {}".format(r + 1, runs))

        # Pooled graphs are only requested from the first run.
        if get_magdiff and r == 0:
            (
                best_te_loss,
                best_te_acc,
                target_full,
                predictions_full,
                all_pooled_graphs,
            ) = main(
                dataset,
                create_model,
                loader_class,
                learning_rate,
                batch_size,
                patience,
                run=r,
                return_pooled_graphs=get_magdiff,
                **shared_kwargs,
            )
        else:
            best_te_loss, best_te_acc, target_full, predictions_full = main(
                dataset,
                create_model,
                loader_class,
                learning_rate,
                batch_size,
                patience,
                run=r,
                **shared_kwargs,
            )
        results.append([best_te_loss, best_te_acc])

        if not save_predictions:
            continue

        predictions_file_path = os.path.join(
            log_dir, f"{file_stem}_predictions_{r}{ratio_suffix}.csv"
        )
        targets_file_path = os.path.join(
            log_dir, f"{file_stem}_targets_{r}{ratio_suffix}.csv"
        )
        targets = pd.DataFrame(target_full)
        predictions = pd.DataFrame(predictions_full)
        targets.to_csv(targets_file_path, index=False)
        predictions.to_csv(predictions_file_path, index=False)

    time_diff = time.time() - time_start

    avg_results = np.mean(results, axis=0)
    std_results = np.std(results, axis=0)

    losses = [float(res[0]) for res in results]
    accuracies = [float(res[1]) for res in results]
    experiment_data = {
        "dataset": dataset_name,
        "method": method,
        "model": model_name,
        "experiment": "graph_classification",
        "runs": runs,
        "split_strategy": splits,
        "time": time_diff,
        "time_per_run": time_diff / runs,
        "learning_rate": learning_rate,
        "es_patience": patience,
        "batch_size": batch_size,
        "loss": "categorical_crossentropy",
        "results": {
            "loss": losses,
            "accuracy": accuracies,
        },
        "ratio": ratio,
    }

    json_file_path = os.path.join(log_dir, f"{file_stem}{ratio_suffix}.json")
    mag_file_path = os.path.join(log_dir, f"{file_stem}{ratio_suffix}_mag.json")
    graphs_dir_name = "graphs" if default_naming else f"graphs_{ratio_tag}"
    path_to_save_graphs = os.path.join(
        log_dir, f"{graphs_dir_name}/{model_name}_{method}_{dataset_name}/"
    )
    if not os.path.exists(path_to_save_graphs):
        os.makedirs(path_to_save_graphs)

    all_mag_results = {}
    if get_magdiff:
        # Graphs are only saved to disk for these two datasets.
        if dataset_name not in ["ENZYMES", "MUTAG"]:
            path_to_save_graphs = None
        all_mag_results["dataset"] = graph_structure_eval_trainable(
            all_pooled_graphs, path_to_save_graphs=path_to_save_graphs
        )
        with open(mag_file_path, "w") as json_file:
            json.dump(all_mag_results, json_file, indent=4)
        print(f"Magni results saved to {mag_file_path}")

    with open(json_file_path, "w") as json_file:
        json.dump(experiment_data, json_file, indent=4)
    print(f"Experiment data saved to {json_file_path}")

    summary = "{} - Test loss: {:.4f} +- {:.4f} - Test accuracy: {:.4f} +- {:.4f}".format(
        dataset_name, avg_results[0], std_results[0], avg_results[1], std_results[1]
    )
    print(summary)

    return avg_results, std_results

def results_to_file(dataset, method, avg_results, std_results):
    """Append one summary line to ``./results/<dataset>/<dataset>_result.csv``.

    The line has the form ``method, loss +- std, accuracy +- std`` with four
    decimals per number. The output directory is created if it does not
    exist yet.

    Args:
        dataset: Dataset name; used for both the directory and the file name.
        method: Label written in the first column.
        avg_results: Sequence whose first two entries are the mean loss and
            the mean accuracy.
        std_results: Sequence whose first two entries are the matching
            standard deviations.
    """
    # Appending would otherwise fail with FileNotFoundError on a fresh checkout.
    os.makedirs(f"./results/{dataset}", exist_ok=True)

    filename = f"./results/{dataset}/{dataset}_result.csv"
    line = "{}, {:.4f} +- {:.4f}, {:.4f} +- {:.4f}\n".format(
        method, avg_results[0], std_results[0], avg_results[1], std_results[1]
    )
    with open(filename, "a") as f:
        f.write(line)
