from torch_geometric.utils import index_to_mask

from torch_geometric.datasets import (
    Planetoid,
    Coauthor,
    Amazon,
    HeterophilousGraphDataset,
)
from ogb.nodeproppred import PygNodePropPredDataset

import pandas as pd

# Widen pandas' display limits for console output (rows, columns, line width).
pd.options.display.max_rows = 100
pd.options.display.max_columns = 100
pd.options.display.width = 1000

import glob

import os

import sys

import warnings

# Ignore FutureWarnings (the ones caused by PygNodePropPredDataset). The filter
# has no module restriction, so it applies process-wide to every library.
warnings.filterwarnings("ignore", category=FutureWarning)

if __name__ == "__main__":
    num_runs_per_experiment = 10
    save_path = "../Data/Supervised_Clustering/"

    # (dataset, max_clusters, split, num_train_nodes_per_class, num_val_nodes);
    # pd.NA leaves the node counts unset for the "default" heterophilous splits.
    dataset_configs = [
        ("CS", 100, "default", 20, 500),
        ("CS", 100, "sparse", 2, 50),
        ("Physics", 100, "default", 20, 500),
        ("Physics", 100, "sparse", 2, 50),
        ("Computers", 100, "default", 20, 500),
        ("Computers", 100, "sparse", 2, 50),
        ("Photo", 100, "default", 20, 500),
        ("Photo", 100, "sparse", 2, 50),
        ("Roman-empire", 100, "default", pd.NA, pd.NA),
        ("Roman-empire", 100, "sparse", 2, 50),
        ("Amazon-ratings", 100, "default", pd.NA, pd.NA),
        ("Amazon-ratings", 100, "sparse", 2, 50),
        # ("ogbn-arxiv", 100, "default", pd.NA, pd.NA),
        # ("ogbn-arxiv", 100, "sparse", 2, 50),
        # ogbn-products default has only 1 training node per class for some classes
        # ("ogbn-products", 100, "default", pd.NA, pd.NA),
    ]

    # (pooling, reconstruct_attributes, gnn, xnn, snn, bnn)
    model_configs = [
        # no clustering
        (None, False, "GCN", None, None, None),
        (None, False, "GraphSAGE", None, None, None),
        (None, False, None, None, "Transformer", None),
        (None, False, None, None, "MLP", None),
        # NOCD
        ("NOCD", False, "GCN", None, None, None),
        ("NOCD", False, "GraphSAGE", None, None, None),
        ("NOCD", True, "GCN", "Transformer", None, None),
        ("NOCD", True, "GraphSAGE", "Transformer", None, None),
        ("NOCD", True, "GCN", "MLP", None, None),
        ("NOCD", True, "GraphSAGE", "MLP", None, None),
        ("NOCD", False, None, None, "Transformer", None),
        ("NOCD", False, None, None, "MLP", None),
        # DMoN
        ("DMoNPool", False, "GCN", None, None, None),
        ("DMoNPool", False, "GraphSAGE", None, None, None),
        ("DMoNPool", True, "GCN", "Transformer", None, None),
        ("DMoNPool", True, "GraphSAGE", "Transformer", None, None),
        ("DMoNPool", True, "GCN", "MLP", None, None),
        ("DMoNPool", True, "GraphSAGE", "MLP", None, None),
        ("DMoNPool", False, None, None, "Transformer", None),
        ("DMoNPool", False, None, None, "MLP", None),
        # Neuromap
        ("MapEqPool", False, "GCN", None, None, None),
        ("MapEqPool", False, "GraphSAGE", None, None, None),
        ("MapEqPool", True, "GCN", "Transformer", None, None),
        ("MapEqPool", True, "GraphSAGE", "Transformer", None, None),
        ("MapEqPool", True, "GCN", "MLP", None, None),
        ("MapEqPool", True, "GraphSAGE", "MLP", None, None),
        ("MapEqPool", False, None, None, "Transformer", None),
        ("MapEqPool", False, None, None, "MLP", None),
        # SBM
        ("SBMPool", False, "GCN", None, None, "Transformer"),
        ("SBMPool", False, "GraphSAGE", None, None, "Transformer"),
        ("SBMPool", False, "GCN", None, None, "MLP"),
        ("SBMPool", False, "GraphSAGE", None, None, "MLP"),
        ("SBMPool", True, "GCN", "Transformer", None, "Transformer"),
        ("SBMPool", True, "GraphSAGE", "Transformer", None, "Transformer"),
        ("SBMPool", True, "GCN", "MLP", None, "MLP"),
        ("SBMPool", True, "GraphSAGE", "MLP", None, "MLP"),
        ("SBMPool", False, None, None, "Transformer", "Transformer"),
        ("SBMPool", False, None, None, "MLP", "MLP"),
    ]

    # NOCD and SBMPool are only run with graph_reconstruction_method="sample";
    # every other pooling (and the no-pooling baselines) only with None.
    sampling_poolings = ["NOCD", "SBMPool"]

    # One tuple per (experiment, run); slurm ids are 1-based and consecutive,
    # so the next id is always len(records) + 1.
    records = []
    experiment_id = 0
    for dataset, max_clusters, split, num_train_nodes_per_class, num_val_nodes in (
        dataset_configs
    ):
        for pooling, reconstruct_attributes, gnn, xnn, snn, bnn in model_configs:
            # skip datasets > 100,000 nodes for Transformer NNs and DMoNPool pooling (quadratic memory requirements)
            # if ("ogbn" in dataset) and (
            #     ("Transformer" in [xnn, snn, bnn]) or (pooling in ["DMoNPool"])
            # ):
            #     continue

            graph_reconstruction_method = (
                "sample" if pooling in sampling_poolings else None
            )

            for regularization in [None, "DMoNPool", "L2"]:
                # skip ogbn datasets for DMoNPool regularization (quadratic memory requirements)
                # if ("ogbn" in dataset) and (regularization in ["DMoNPool"]):
                #     continue

                # unsupervised runs only exist when there is a pooling layer
                supervision_modes = [True, False] if pooling is not None else [True]
                for supervised in supervision_modes:
                    # unsupervised runs only use the "loss" stopping criterion
                    stopping_criteria = ["loss", "mcc"] if supervised else ["loss"]
                    for stopping_criterion in stopping_criteria:
                        # if supervised and (stopping_criterion == "loss"):
                        #     continue

                        for optimization in ["AdamW"]:
                            experiment_id += 1
                            experiment_config = (
                                dataset,
                                split,
                                max_clusters,
                                num_train_nodes_per_class,
                                num_val_nodes,
                                pooling,
                                reconstruct_attributes,
                                gnn,
                                xnn,
                                snn,
                                bnn,
                                graph_reconstruction_method,
                                regularization,
                                supervised,
                                stopping_criterion,
                                optimization,
                            )
                            first_slurm_id = len(records) + 1
                            records.extend(
                                [
                                    (
                                        save_path,
                                        first_slurm_id + run_id,
                                        experiment_id,
                                        run_id,
                                    )
                                    + experiment_config
                                    for run_id in range(num_runs_per_experiment)
                                ]
                            )

    # column name -> pandas nullable dtype (column order of the output table)
    dtype_dict = {
        "slurm_id": "Int64",
        "experiment_id": "Int64",
        "run_id": "Int64",
        "dataset": "string",
        "split": "string",
        "max_clusters": "Int64",
        "num_train_nodes_per_class": "Int64",
        "num_val_nodes": "Int64",
        "pooling": "string",
        "reconstruct_attributes": "boolean",
        "gnn": "string",
        "xnn": "string",
        "snn": "string",
        "bnn": "string",
        "graph_reconstruction_method": "string",
        "regularization": "string",
        "supervised": "boolean",
        "stopping_criterion": "string",
        "optimization": "string",
    }

    experiments = pd.DataFrame.from_records(
        records,
        columns=["save_path", *dtype_dict],
    ).astype({"save_path": "string", **dtype_dict})

    # ensure null values are consistent: replace every null entry with None.
    # NOTE(review): with the nullable "string"/"Int64"/"boolean" dtypes set in
    # dtype_dict, None is stored back as pd.NA, so this is a safeguard for any
    # column that is not covered by those dtypes.
    experiments = experiments.where(pd.notnull(experiments), None)

    # save experiment configurations. DataFrame.to_csv does not create missing
    # parent directories, and on a fresh checkout nothing has created save_path
    # yet (the dataset downloads happen later), so create it here.
    os.makedirs(save_path, exist_ok=True)
    experiments.to_csv(save_path + "experiments.csv", index=False)

    # get all unique datasets
    datasets = experiments["dataset"].unique()

    # dataset names grouped by the loader that provides them
    planetoid_datasets = ["Cora", "CiteSeer", "PubMed"]
    coauthor_datasets = ["CS", "Physics"]
    amazon_datasets = ["Computers", "Photo"]
    heterophilous_datasets = [
        "Roman-empire",
        "Amazon-ratings",
        "Minesweeper",
        "Tolokers",
        "Questions",
    ]

    def print_class_counts(title, labels):
        """Print ``title`` followed by the per-class counts of ``labels``, ascending."""
        print(title, labels.bincount().sort().values.tolist())

    # pre-download all datasets and check validity of splits
    for dataset in datasets:
        if dataset in planetoid_datasets:
            data = Planetoid(root=save_path, name=dataset)[0]
        elif dataset in coauthor_datasets:
            data = Coauthor(root=save_path, name=dataset)[0]
        elif dataset in amazon_datasets:
            data = Amazon(root=save_path, name=dataset)[0]
        elif dataset in heterophilous_datasets:
            data = HeterophilousGraphDataset(root=save_path, name=dataset)[0]
        elif "ogbn" in dataset:
            ogb_dataset = PygNodePropPredDataset(root=save_path, name=dataset)
            data = ogb_dataset[0]

            # get_idx_split() returns node indices; convert them to boolean masks
            split_idx = ogb_dataset.get_idx_split()
            data.train_mask = index_to_mask(split_idx["train"], size=data.num_nodes)
            data.val_mask = index_to_mask(split_idx["valid"], size=data.num_nodes)
            data.test_mask = index_to_mask(split_idx["test"], size=data.num_nodes)

            # drop the trailing label dimension
            data.y = data.y.squeeze(dim=-1)
        else:
            # Without this branch `data` would be undefined on the first
            # iteration, or would silently still hold the previous dataset.
            raise ValueError(f"Unknown dataset: {dataset!r}")

        if not hasattr(data, "num_classes"):
            data.num_classes = data.y.unique().shape[0]

        print("\nDataset:", dataset)
        print("Number of nodes:", data.num_nodes)
        print("Number of edges:", data.num_edges)
        print("Number of features:", data.num_features)
        print("Number of classes:", data.num_classes)

        if dataset in coauthor_datasets + amazon_datasets:
            # no split in Coauthor or Amazon datasets so just check total number of nodes per class
            print_class_counts("Number of nodes per class:", data.y)
        elif dataset in heterophilous_datasets:
            # HeterophilousGraphDataset masks hold one column per predefined
            # split, so check every column rather than assuming a fixed count
            for split_id in range(data.train_mask.size(1)):
                print("\nSplit:", split_id)
                print_class_counts(
                    "Number of training nodes per class:",
                    data.y[data.train_mask[:, split_id]],
                )
                print_class_counts(
                    "Number of validation nodes per class:",
                    data.y[data.val_mask[:, split_id]],
                )
                print_class_counts(
                    "Number of test nodes per class:",
                    data.y[data.test_mask[:, split_id]],
                )
        elif "ogbn" in dataset:
            print_class_counts(
                "Number of training nodes per class:",
                data.y[data.train_mask],
            )
            print_class_counts(
                "Number of validation nodes per class:",
                data.y[data.val_mask],
            )
            print_class_counts(
                "Number of test nodes per class:",
                data.y[data.test_mask],
            )

    print(
        "\nTotal number of experiments:",
        len(experiments),
        "\n",
    )

    # check if any .zip files with "hpc-results" in their file name exist; the
    # SLURM id is parsed from the last "-"-separated token of each file name
    all_results = glob.glob(save_path + "*hpc-results-*.zip")

    # set to True to archive existing results and schedule every run again
    fresh_start = False
    if fresh_start:
        if len(all_results) > 0:
            timestamp = pd.Timestamp.now().strftime("%Y-%m-%d-%H-%M-%S")
            archive_dir = save_path + "archive/" + timestamp + "/"
            os.makedirs(archive_dir, exist_ok=True)
            for result in all_results:
                print("Archiving", result)
                sys.stdout.flush()
                os.rename(result, archive_dir + os.path.basename(result))

        remaining_ids = experiments["slurm_id"].tolist()
    elif len(all_results) > 0:
        # get a list of all the experiment ids that have already been run
        completed_ids = []
        for result in all_results:
            id_token = os.path.basename(result).split("-")[-1].split(".")[0]
            try:
                completed_ids.append(int(id_token))
            except ValueError:
                # e.g. "hpc-results-12 (1).zip" from a duplicate download:
                # report it instead of aborting the whole script
                print("Skipping result file with unrecognized name:", result)
                sys.stdout.flush()

        remaining_ids = experiments[~experiments["slurm_id"].isin(completed_ids)][
            "slurm_id"
        ].tolist()

        print(
            "\nNumber of completed runs:",
            len(completed_ids),
            "\n",
        )
    else:
        remaining_ids = experiments["slurm_id"].tolist()

    print(
        "\nNumber of remaining runs:",
        len(remaining_ids),
        "\n",
    )

    print("\nRemaining SLURM_ARRAY_TASK_ID to run:")
    print(remaining_ids)
    print("\n")

    # save remaining experiment ids as a text file, one id per line
    with open(save_path + "remaining.txt", "w") as f:
        for item in remaining_ids:
            f.write("%s\n" % item)

    # Write the column -> dtype mapping next to the experiment table as a
    # two-column ("arg", "dtype") CSV of strings.
    dtype_table = pd.DataFrame(
        {
            "arg": list(dtype_dict.keys()),
            "dtype": list(dtype_dict.values()),
        }
    ).astype("string")
    dtype_table.to_csv(save_path + "dtype_dict.csv", index=False)
