from torch_geometric.datasets import (
    Planetoid,
    Coauthor,
    Amazon,
    HeterophilousGraphDataset,
)
from ogb.nodeproppred import PygNodePropPredDataset

from torch_geometric.utils import to_networkx

import cugraph
import networkx as nx

# https://github.com/rapidsai/cugraph/blob/main/notebooks/demo/accelerating_networkx.ipynb
# Configure NetworkX backend dispatch in code; per the trailing notes, each
# assignment mirrors the environment variable named next to it.
nx.config.backend_priority = ["cugraph"]  # NETWORKX_BACKEND_PRIORITY=cugraph
nx.config.cache_converted_graphs = True  # NETWORKX_CACHE_CONVERTED_GRAPHS=True

import matplotlib.pyplot as plt

import dask.dataframe as dd
import pandas as pd

# Raise pandas' console display limits (row count, column count, line width)
# for the DataFrames this script prints.
for _option, _value in (
    ("display.max_rows", 10000),
    ("display.max_columns", 100),
    ("display.width", 1000),
):
    pd.set_option(_option, _value)

import numpy as np

import glob
import sys
import os

from concurrent.futures import ThreadPoolExecutor

from contextlib import redirect_stdout

import warnings

# Ignore FutureWarning and DeprecationWarning for the remainder of the run.
# The filters are installed at this point in the module, so they only apply
# to warnings raised after these two calls execute.
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)

import time


def read_csv(file, dtype="string"):
    """Read one CSV results file into a pandas DataFrame.

    Columns are parsed as the pandas ``"string"`` dtype by default. This
    matches the other load paths in this script, which also pass
    ``dtype="string"``, and avoids per-file type inference producing
    different dtypes for the same column before the frames are concatenated
    (or altering values such as zero-padded identifiers).

    Args:
        file: Path or file-like object accepted by ``pandas.read_csv``.
        dtype: Forwarded to ``pandas.read_csv``. Pass ``None`` to fall back
            to pandas' own dtype inference.

    Returns:
        The parsed ``pandas.DataFrame``.
    """
    return pd.read_csv(file, dtype=dtype)


if __name__ == "__main__":
    save_path = "../Data/Supervised_Clustering/"

    start_time = time.time()

    compile_results = False
    if compile_results:
        compressed = True
        if compressed:
            # get a list of all zipped per-run result files (hpc-results-*.zip)
            files = glob.glob(save_path + "hpc-results-*.zip")

            print("\nFound", len(files), "data files. Loading data...\n")
            sys.stdout.flush()

            # use ThreadPoolExecutor to read files in parallel (dask can't read zip files in parallel https://github.com/dask/dask/issues/2554)
            with ThreadPoolExecutor() as executor:
                # Map the read_csv function to the list of files
                results = list(executor.map(read_csv, files))

            # combine all dataframes into a single dataframe
            results = pd.concat(results, ignore_index=True)
        else:
            # use dask to read all CSV files in parallel
            results = dd.read_csv(
                save_path + "hpc-results-*[!.zip]", dtype="string"
            ).compute()

        # save the combined dataframe
        results.to_csv(save_path + "hpc-results.zip", index=False, compression="zip")

        print("\nData compiled and saved.\n")
        sys.stdout.flush()
    else:
        if os.path.exists(save_path + "hpc-results.zip"):
            print("\nCompiled data found. Loading data...\n")
            sys.stdout.flush()

            results = pd.read_csv(save_path + "hpc-results.zip", dtype="string")

            print("\nData loaded.\n")
            sys.stdout.flush()
        else:
            print("\nNo results found.\n")
            sys.exit()

    end_time = time.time()
    print("Loading time: {:.3f} seconds".format(end_time - start_time))
    sys.stdout.flush()
    start_time = end_time

    # print all unique gpus
    print("\nGPU models used:")
    print(results["gpu"].unique())
    print("\n")

    # save raw results for plotting
    raw_results = results.copy()

    # drop columns that are not needed e.g. pred_labels
    results = results.drop(
        columns=[
            "pred_labels",
            "gpu",
        ]
    )

    # read experiments from csv file
    experiments = pd.read_csv(save_path + "experiments.csv", dtype="string")
    experiments = experiments.drop(columns=["save_path"])

    # get all planned slurm ids
    slurm_ids = experiments["slurm_id"].unique()

    # get matching columns in results and experiments
    matching_columns = set(results.columns).intersection(experiments.columns)
    result_columns = list(set(results.columns).difference(matching_columns))
    config_columns = list(matching_columns.difference(["slurm_id", "run_id"]))

    # set dtypes
    config_dtypes = {column: "string" for column in config_columns}
    result_dtypes = {column: "Float64" for column in result_columns}

    # check for null values in result_columns
    print("\nNull values in result columns:")
    print(results[result_columns].isnull().sum())
    print("\n")

    # convert columns to correct dtypes
    results = results.astype(config_dtypes)
    results = results.astype(result_dtypes)
    results = results.astype({"slurm_id": "string"})

    raw_results = raw_results.astype(config_dtypes)
    raw_results = raw_results.astype(result_dtypes)

    # ensure null values are consistent
    experiments = experiments.fillna("None")
    results = results.fillna("None")
    raw_results = raw_results.fillna("None")

    # convert pred_labels to numpy arrays
    raw_results["pred_labels"] = raw_results["pred_labels"].apply(
        lambda x: np.array(x.strip("[]").split(", "), dtype=int)
    )

    # take negative exponential of loss so that higher is better
    results["loss"] = results["loss"].apply(lambda x: np.exp(-x))
    raw_results["loss"] = raw_results["loss"].apply(lambda x: np.exp(-x))

    # archive (move, not delete) files whose slurm ids appear in results but not in experiments
    delete_rest = False
    if delete_rest:
        count = 0
        trash_ids = results[~results["slurm_id"].isin(slurm_ids)]["slurm_id"].unique()
        if len(trash_ids) == 0:
            print("No files to delete.")
        else:
            timestamp = pd.Timestamp.now().strftime("%Y-%m-%d-%H-%M-%S")
            os.makedirs(save_path + "archive/" + timestamp + "/", exist_ok=True)
            for trash_id in trash_ids:
                trash_files = glob.glob(save_path + "/*" + trash_id + "*")
                for trash_file in trash_files:
                    print("Archiving", trash_file)
                    sys.stdout.flush()
                    os.rename(
                        trash_file,
                        save_path
                        + "archive/"
                        + timestamp
                        + "/"
                        + os.path.basename(trash_file),
                    )
                    count += 1
        print("Deleted", count, "files.")

    # only keep results with planned slurm ids
    results = results[results["slurm_id"].isin(slurm_ids)]

    # remove slurm id and run id from results
    results = results.drop(columns=["slurm_id", "run_id"])

    # groupby matching columns and get mean and std for each result column
    results = results.groupby(config_columns)[result_columns].agg(["mean", "std"])
    results = results.reset_index()

    # merge gnn and snn columns
    results["model"] = results["gnn"] + results["snn"]
    results["model"] = results["model"].str.replace("None", "")
    config_columns.append("model")

    # print(results.head(100))

    # group by dataset and split
    groups = [
        "dataset",
        "split",
        "max_clusters",
        "num_train_nodes_per_class",
        "num_val_nodes",
    ]

    completed = results.groupby(groups).size()
    completed = completed.reset_index(name="completed")

    planned = experiments[experiments["run_id"] == "0"]
    planned = planned.groupby(groups).size()
    planned = planned.reset_index(name="planned")

    # handle messy columns
    messy_columns = ["max_clusters", "num_train_nodes_per_class", "num_val_nodes"]

    completed = (
        completed.replace("None", pd.NA)
        .astype({column: "Float64" for column in messy_columns})
        .astype({column: "Int64" for column in messy_columns})
        .astype("string")
        .fillna("default")
    )
    planned = (
        planned.replace("None", pd.NA)
        .astype({column: "Float64" for column in messy_columns})
        .astype({column: "Int64" for column in messy_columns})
        .astype("string")
        .fillna("default")
    )

    results[messy_columns] = (
        results[messy_columns]
        .replace("None", pd.NA)
        .astype("Float64")
        .astype("Int64")
        .astype("string")
        .fillna("default")
    )
    raw_results[messy_columns] = (
        raw_results[messy_columns]
        .replace("None", pd.NA)
        .astype("Float64")
        .astype("Int64")
        .astype("string")
        .fillna("default")
    )

    completed = completed.merge(planned, on=groups, how="outer")
    completed = completed.astype({"completed": "Int64", "planned": "Int64"})

    print("\nNumber of experiments per dataset and split:")
    print(completed)
    print("\n")

    # drop results not yet complete
    completed = completed[completed["completed"] == completed["planned"]]
    completed = completed.drop(columns=["completed", "planned"])

    latex_path = "../Papers/ICLR2025/assets/"
    latex_file = latex_path + "results.tex"
    summary_file = latex_path + "summary.tex"

    visualize_metric = "test_mcc"

    test_metrics = [
        "test_mcc",
        "test_acc",
        # "test_ami",
    ]

    selection_metrics = [
        "loss",
        "train_mcc",
        "val_mcc",
    ]

    summary_split = "sparse"
    summary_selection_metric = "val_mcc"

    summary = {}

    latex_columns = {
        "model": "Model",
        "pooling": "$ f $",
        "regularization": "$L_\\text{regularization}$",
        "reconstruct_attributes": "$L_V$",
        "supervised": "$L_S$",
        "xnn": "$\\text{NN}_{\mS \\to \mX}$",
        "stopping_criterion": "ES",
    }

    latex_values = {
        # datasets
        "CS": "Coauthor CS",
        "Physics": "Coauthor Physics",
        "Computers": "Amazon Computers",
        "Photo": "Amazon Photo",
        "Amazon-ratings": "\\texttt{amazon-ratings}",
        "Roman-empire": "\\texttt{roman-empire}",
        # models
        "True": "\\cmark",
        "False": "\\xmark",
        # pooling
        "DMoNPool": "DMoN",
        "MapEqPool": "Neuromap",
        "SBMPool": "$\\text{SBM}_\\text{NN}$",
        # early stopping
        "loss": "Loss",
        "mcc": "MCC",
    }

    latex_metric = {
        "loss": "training loss",
        "train_mcc": "training set MCC",
        "val_mcc": "validation set MCC",
        "test_acc": "Accuracy",
        "test_mcc": "MCC",
        "test_ami": "AMI",
    }

    # replace "reconstruct_attributes" values with "N/A" when model in ["Transformer", "MLP"]
    results["reconstruct_attributes"] = results["reconstruct_attributes"].mask(
        results["model"].isin(["Transformer", "MLP"]), "N/A"
    )

    results = results.sort_values(by="experiment_id", ascending=True)

    end_time = time.time()
    print("Preprocessing time: {:.3f} seconds".format(end_time - start_time))

    sys.stdout.flush()

    original_stdout = sys.stdout

    with open(latex_file, "w") as f:
        sys.stdout = f

        print("""\n
        \\section{Additional Results}
        \\label{sec:fullresults}

        \n""")

        generate_graphs = False
        if generate_graphs:
            sys.stdout = original_stdout
            print("\nGenerating graphs...\n")
            sys.stdout.flush()
            sys.stdout = f

        visualize = True
        if visualize:
            print("""\n
            \\subsection{Visualisations of Graph Labellings}
            \\label{subsec:visualisations}
            
            \n""")

            datasets = completed["dataset"].unique()

            for dataset in datasets:
                if generate_graphs:
                    start_time = time.time()
                    sys.stdout = original_stdout
                    print("\nLoading", dataset, "dataset...\n")
                    sys.stdout.flush()

                    if dataset in ["Cora", "CiteSeer", "PubMed"]:
                        data = Planetoid(
                            root=save_path,
                            name=dataset,
                        )[0]

                    if dataset in ["CS", "Physics"]:
                        data = Coauthor(
                            root=save_path,
                            name=dataset,
                        )[0]

                    if dataset in ["Computers", "Photo"]:
                        data = Amazon(
                            root=save_path,
                            name=dataset,
                        )[0]

                    if dataset in [
                        "Roman-empire",
                        "Amazon-ratings",
                        "Minesweeper",
                        "Tolokers",
                        "Questions",
                    ]:
                        data = HeterophilousGraphDataset(
                            root=save_path,
                            name=dataset,
                        )[0]

                    if "ogbn" in dataset:
                        _dataset = PygNodePropPredDataset(
                            root=save_path,
                            name=dataset,
                        )
                        data = _dataset[0]

                        data.y = data.y.squeeze(dim=-1)

                    if not hasattr(data, "num_classes"):
                        data.num_classes = data.y.unique().shape[0]

                    end_time = time.time()
                    print(
                        "\nLoaded dataset in {:.3f} seconds\n".format(
                            end_time - start_time
                        )
                    )

                    # create a stochastic block model to visualize the ground truth
                    start_time = end_time
                    print("\nGenerating stochastic block model...\n")

                    sys.stdout.flush()

                    nodelist = []
                    sizes = [0] + data.y.bincount().tolist()[:-1]
                    sizes = np.cumsum(sizes)
                    for i in range(len(data.y)):
                        cluster = data.y[i].item()
                        nodelist.append(sizes[cluster])
                        sizes[cluster] += 1

                    node_map = dict(zip(range(len(data.y)), nodelist))
                    nodelist = [node_map[i] for i in nodelist]

                    sizes = data.y.bincount().tolist()
                    n_clusters = len(sizes)
                    p = (
                        0.99 * np.eye(n_clusters)
                        + 0.01 * np.ones((n_clusters, n_clusters))
                    ).tolist()
                    seed = 42

                    G = nx.stochastic_block_model(sizes, p, seed=seed)

                    end_time = time.time()
                    print(
                        "\nGenerated stochastic block model in {:.3f} seconds\n".format(
                            end_time - start_time
                        )
                    )

                    # plot unsupervised vs supervised vs semi-supervised
                    start_time = end_time
                    print("\nGenerating graph layout...\n")

                    sys.stdout.flush()

                    pos = cugraph.force_atlas2(G).to_pandas()

                    end_time = time.time()
                    print(
                        "\nGenerated graph layout in {:.3f} seconds\n".format(
                            end_time - start_time
                        )
                    )
                    sys.stdout.flush()

                    pos = {
                        row["vertex"]: (row["x"], row["y"]) for _, row in pos.iterrows()
                    }

                    node_color = [i for i, j in enumerate(sizes) for k in range(j)]

                    start_time = time.time()

                    G = to_networkx(data, to_undirected=True)

                    node_color = [
                        node_color[node_map[i]] for i in range(len(node_color))
                    ]

                    node_map_inv = {j: i for i, j in node_map.items()}
                    pos = {node_map_inv[i]: j for i, j in pos.items()}

                for (
                    dataset,
                    split,
                    max_clusters,
                    num_train_nodes_per_class,
                    num_val_nodes,
                ) in completed[completed["dataset"] == dataset].itertuples(index=False):
                    # print current group
                    sys.stdout = f
                    print(
                        """\n
                    \\subsubsection{
                    %s dataset, %s split with %s train nodes per class and %s validation nodes.
                    }
                    \n"""
                        % (
                            latex_values[dataset],
                            split,
                            num_train_nodes_per_class,
                            num_val_nodes,
                        )
                    )

                    for selection_metric in selection_metrics:
                        if generate_graphs:
                            sys.stdout = original_stdout
                            start_time = time.time()
                            print(
                                "\nGetting best results for visualisations of %s split with %s train nodes per class and %s validation nodes based on %s...\n"
                                % (
                                    split,
                                    num_train_nodes_per_class,
                                    num_val_nodes,
                                    latex_metric[selection_metric],
                                )
                            )
                            sys.stdout.flush()

                            # get best unsupervised vs supervised vs semi-supervised
                            filtered = raw_results[
                                (raw_results["dataset"] == dataset)
                                & (raw_results["split"] == split)
                                & (raw_results["max_clusters"] == max_clusters)
                                & (
                                    raw_results["num_train_nodes_per_class"]
                                    == num_train_nodes_per_class
                                )
                                & (raw_results["num_val_nodes"] == num_val_nodes)
                            ]

                            bests = []

                            best_unsupervised = filtered[
                                (filtered["supervised"] == "False")
                            ]
                            best_supervised = filtered[
                                (filtered["supervised"] == "True")
                                & (filtered["pooling"] == "None")
                                & (filtered["regularization"] == "None")
                            ]
                            best_semi_supervised = filtered[
                                (filtered["supervised"] == "True")
                                & (filtered["pooling"] != "None")
                            ]

                            if (
                                best_unsupervised[selection_metric]
                                .isnull()
                                .all(axis=None)
                            ):
                                continue
                            if (
                                best_supervised[selection_metric]
                                .isnull()
                                .all(axis=None)
                            ):
                                continue
                            if (
                                best_semi_supervised[selection_metric]
                                .isnull()
                                .all(axis=None)
                            ):
                                continue

                            best_unsupervised = (
                                best_unsupervised.loc[
                                    best_unsupervised["loss"].idxmax()
                                ]
                                .to_frame()
                                .T
                            )
                            bests.append(best_unsupervised)

                            best_supervised = (
                                best_supervised.loc[
                                    best_supervised[selection_metric].idxmax()
                                ]
                                .to_frame()
                                .T
                            )
                            bests.append(best_supervised)

                            best_semi_supervised = (
                                best_semi_supervised.loc[
                                    best_semi_supervised[selection_metric].idxmax()
                                ]
                                .to_frame()
                                .T
                            )
                            bests.append(best_semi_supervised)

                            if len(bests) == 0:
                                continue

                            bests = pd.concat(bests).reset_index(drop=True)

                            end_time = time.time()
                            print(
                                "\nGot best results in {:.3f} seconds\n".format(
                                    end_time - start_time
                                )
                            )

                            start_time = end_time
                            print("\nPlotting graphs...\n")

                            sys.stdout.flush()

                            # create 4 subplots (in one row) for each of the 4 cases
                            fig, axs = plt.subplots(1, 4, figsize=(16, 4))

                            node_size = 10
                            alpha = 0.01

                            # plot the ground truth on the first subplot
                            axs[0].set_title("Ground Truth")
                            nx.draw_networkx_nodes(
                                G,
                                pos,
                                node_color=node_color,
                                node_size=node_size,
                                ax=axs[0],
                            )
                            nx.draw_networkx_edges(G, pos, alpha=alpha, ax=axs[0])

                            # plot the predicted labels on the other 3 subplots
                            for i, row in bests.iterrows():
                                axs[i + 1].set_title(
                                    (
                                        "Supervised "
                                        if row["supervised"] == "True"
                                        else "Unsupervised "
                                    )
                                    + (
                                        "Classification"
                                        if row["pooling"] == "None"
                                        else "Clustering"
                                    )
                                    + " ({:.3f})".format(row[visualize_metric])
                                )
                                nx.draw_networkx_nodes(
                                    G,
                                    pos,
                                    node_color=row["pred_labels"],
                                    node_size=node_size,
                                    ax=axs[i + 1],
                                )
                                nx.draw_networkx_edges(
                                    G, pos, alpha=alpha, ax=axs[i + 1]
                                )

                            plt.tight_layout()

                            end_time = time.time()
                            print(
                                "\nPlotted graphs in {:.3f} seconds\n".format(
                                    end_time - start_time
                                )
                            )
                            sys.stdout.flush()

                            # save the figure
                            fig.savefig(
                                latex_path
                                + "graphs/"
                                + dataset
                                + "-"
                                + split
                                + "-"
                                + str(num_train_nodes_per_class)
                                + "-"
                                + str(num_val_nodes)
                                + "-"
                                + selection_metric
                                + ".png"
                            )

                            plt.close(fig)

                        sys.stdout = f
                        print(
                            """\n
                        \\begin{figure}[H]
                        \\begin{center}
                        \\includegraphics[width=1.0\\textwidth]{./assets/graphs/%s-%s-%s-%s-%s.png}
                        \\end{center}
                        \\caption{
                        Visualisations of graph labellings for the %s dataset, %s split with %s train nodes per class and %s validation nodes.
                        Model selection based on %s.
                        }
                        \\label{fig:visualisations_%s_%s_%s_%s_%s}
                        \\end{figure}
                                
                        \n"""
                            % (
                                dataset,
                                split,
                                num_train_nodes_per_class,
                                num_val_nodes,
                                selection_metric,
                                latex_values[dataset],
                                split,
                                num_train_nodes_per_class,
                                num_val_nodes,
                                latex_metric[selection_metric],
                                dataset,
                                split,
                                num_train_nodes_per_class,
                                num_val_nodes,
                                selection_metric,
                            )
                        )

        print("""\n
        \\subsection{Additional Summary Results}
        \\label{subsec:supervised_vs_semi_supervised}

        In this section we present additional summary results for both default and sparse label splits of each dataset.
        We compare early stopping and model selection based on training loss, training set MCC, and validation set MCC.

        \n""")

        for (
            dataset,
            split,
            max_clusters,
            num_train_nodes_per_class,
            num_val_nodes,
        ) in completed.itertuples(index=False):
            # print current group
            print(
                """\n
            \\subsubsection{
            %s dataset, %s split with %s train nodes per class and %s validation nodes.
            }
            \n"""
                % (
                    latex_values[dataset],
                    split,
                    num_train_nodes_per_class,
                    num_val_nodes,
                )
            )

            for selection_metric in selection_metrics:
                # for each group and metric, compare:
                # - supervised vs semi-supervised for all models
                # - no attribute reconstruction vs attribute reconstruction for semi-supervised GCN2
                # - SBM vs DMoN for all models (for semi-supervised)
                # - various regularization methods for all models (including no regularization) for semi-supervised
                # - pooling+regularization (semi-supervised) vs no pooling+regularization (semi-supervised) vs only pooling (semi-supervised) vs only regularization (supervised) for all models
                # - xnn and bnn and snn for all models

                # supervised vs semi-supervised for all models
                bests = []
                for model in results["model"].unique():
                    filtered = results[
                        (results["model"] == model)
                        & (results["dataset"] == dataset)
                        & (results["split"] == split)
                        & (results["max_clusters"] == max_clusters)
                        & (
                            results["num_train_nodes_per_class"]
                            == num_train_nodes_per_class
                        )
                        & (results["num_val_nodes"] == num_val_nodes)
                        & (results["supervised"] == "True")
                    ]

                    # only keep the mean column for selection_metric (drop its std column)
                    filtered = filtered.drop(columns=[(selection_metric, "std")])

                    supervised = filtered[filtered["pooling"] == "None"]
                    semi_supervised = filtered[filtered["pooling"] != "None"]

                    if supervised[selection_metric].isnull().all(axis=None):
                        continue
                    if semi_supervised[selection_metric].isnull().all(axis=None):
                        continue

                    best_supervised = supervised.loc[
                        supervised[selection_metric].idxmax()
                    ]
                    bests.append(best_supervised)

                    best_semi_supervised = semi_supervised.loc[
                        semi_supervised[selection_metric].idxmax()
                    ]
                    bests.append(best_semi_supervised)

                if (split == summary_split) and (
                    selection_metric == summary_selection_metric
                ):
                    summary[dataset] = bests
                else:
                    if len(bests) == 0:
                        continue

                    bests = pd.concat(bests).reset_index(drop=True)

                    for test_metric in test_metrics:
                        test_latex = []
                        test_results = []

                        for model in bests["model"].unique():
                            model_bests = bests[bests["model"] == model].reset_index(
                                drop=True
                            )

                            # make mean $\pm$ std column for metric
                            model_latex = (
                                model_bests[test_metric]["mean"]
                                .apply(lambda x: f"{x:.3f}")
                                .astype("string")
                                + " $\pm$ "
                                + model_bests[test_metric]["std"]
                                .apply(lambda x: f"{x:.3f}")
                                .astype("string")
                            )

                            # highlight in bold best result
                            index = model_bests[test_metric]["mean"].idxmax()
                            model_latex.iloc[index] = (
                                "\\textbf{" + model_latex.iloc[index] + "}"
                            )

                            test_latex.append(model_latex)
                            test_results.append(model_bests[test_metric]["mean"])

                        test_latex = pd.concat(test_latex).to_numpy()
                        test_results = pd.concat(test_results).to_numpy()

                        # underline the best result for each dataset
                        index = test_results.argmax()
                        test_latex[index] = "\\underline{" + test_latex[index] + "}"

                        # drop mean and std columns
                        bests = bests.drop(columns=test_metric, level=0)

                        bests[test_metric] = test_latex

                    # Model names present in this table; consumed below to find
                    # the first row of each model in the rendered LaTeX.
                    # NOTE(review): the names are collected before
                    # `latex_values` is applied to the frame -- confirm that the
                    # replacement leaves model names recognisable in the rows.
                    models = set(bests["model"].unique())

                    # Keep only the table columns, rename them to their LaTeX
                    # headings, substitute display values, flatten the column
                    # index to one level and drop the "supervised" column.
                    bests = bests[list(latex_columns.keys()) + test_metrics]
                    bests = bests.rename(columns=latex_columns)
                    bests = bests.rename(
                        columns={
                            test_metric: latex_metric[test_metric]
                            for test_metric in test_metrics
                        }
                    )
                    bests = bests.replace(latex_values)
                    bests = bests.droplevel(1, axis=1)
                    bests = bests.drop(columns=latex_columns["supervised"])

                    # Opening of the table environment, with caption and label
                    # filled in from the current group and selection metric.
                    print(
                        """\n
                    \\begin{table}[H]
                    \\caption{
                    Comparing (semi-)supervised node classification with semi-supervised graph clustering. 
                    Model selection based on %s.
                    }
                    \\label{tab:supervised_vs_semi_supervised_%s_%s_%s_%s_%s}
                    \\begin{center}
                    \n"""
                        % (
                            latex_metric[selection_metric],
                            dataset,
                            split,
                            num_train_nodes_per_class,
                            num_val_nodes,
                            selection_metric,
                        )
                    )

                    table = bests.to_latex(index=False, escape=False)

                    # Split the rendered tabular into its first four lines, its
                    # last three lines and the lines in between; only the
                    # middle part is modified.
                    table = table.split("\n")
                    header = "\n".join(table[:4])
                    footer = "\n".join(table[-3:])

                    result = table[4:-3]

                    # Insert a dashed rule before the first row of every model
                    # except the first one matched. Only the first cell of a row
                    # is searched for the model name (substring match).
                    # `result` grows while it is being enumerated: after an
                    # insert the matched row moves one position down and is
                    # visited again, but its model has already been removed
                    # from `models`, so it is not matched by that name twice.
                    lines_added = 0
                    for index, line in enumerate(result):
                        line = line.split("&")[0]
                        for model in models:
                            if model in line:
                                if lines_added > 0:
                                    result.insert(index, "\\hdashline")

                                lines_added += 1

                                # Safe although `models` is being iterated: the
                                # loop is left immediately via `break`.
                                models.remove(model)
                                break

                    result = "\n".join(result)

                    print(header)
                    print(result)
                    print(footer)

                    # Close the environments opened above.
                    print(
                        """\n
                    \\end{center}
                    \\end{table}
                    \n"""
                    )

        # Heading and introduction of the attribute-reconstruction ablation
        # subsection.
        print("""\n
        \\subsection{No Attribute Reconstruction vs Attribute Reconstruction}
        \\label{subsec:no_attr_recon_vs_attr_recon}

        In this section we present results of an ablation of attribute reconstruction for semi-supervised graph clustering with graph neural networks GCN and GraphSAGE. 
        
        \n""")

        # GNN names to compare, without the "None" entry. A filter is used
        # instead of list.remove so that a missing "None" value does not raise,
        # and the list is built once because `results` is not modified by the
        # loops below.
        gnns = [gnn for gnn in results["gnn"].unique().tolist() if gnn != "None"]

        # One subsubsection per completed experiment group.
        for (
            dataset,
            split,
            max_clusters,
            num_train_nodes_per_class,
            num_val_nodes,
        ) in completed.itertuples(index=False):
            # print current group
            print(
                """\n
            \\subsubsection{
            %s dataset, %s split with %s train nodes per class and %s validation nodes.
            }
            \n"""
                % (
                    latex_values[dataset],
                    split,
                    num_train_nodes_per_class,
                    num_val_nodes,
                )
            )

            # One table per model-selection metric.
            for selection_metric in selection_metrics:
                bests = []

                for gnn in gnns:
                    # Supervised runs of this GNN in the current group that use
                    # a pooling method.
                    filtered = results[
                        (results["gnn"] == gnn)
                        & (results["dataset"] == dataset)
                        & (results["split"] == split)
                        & (results["max_clusters"] == max_clusters)
                        & (
                            results["num_train_nodes_per_class"]
                            == num_train_nodes_per_class
                        )
                        & (results["num_val_nodes"] == num_val_nodes)
                        & (results["supervised"] == "True")
                        & (results["pooling"] != "None")
                    ]

                    # only keep mean column for selection_metric
                    filtered = filtered.drop(columns=[(selection_metric, "std")])

                    no_attr_recon = filtered[
                        filtered["reconstruct_attributes"] == "False"
                    ]
                    attr_recon = filtered[filtered["reconstruct_attributes"] == "True"]

                    # A comparison needs both variants; skip the GNN when
                    # either side has no value for the selection metric.
                    if no_attr_recon[selection_metric].isnull().all(axis=None):
                        continue
                    if attr_recon[selection_metric].isnull().all(axis=None):
                        continue

                    # Best configuration of each variant according to the
                    # selection metric, appended in the order
                    # (without reconstruction, with reconstruction).
                    best_no_attr_recon = no_attr_recon.loc[
                        no_attr_recon[selection_metric].idxmax()
                    ]
                    bests.append(best_no_attr_recon)

                    best_attr_recon = attr_recon.loc[
                        attr_recon[selection_metric].idxmax()
                    ]
                    bests.append(best_attr_recon)

                if not bests:
                    continue

                bests = pd.concat(bests).reset_index(drop=True)

                for test_metric in test_metrics:
                    test_results = []

                    # Iterate over the GNNs that actually made it into `bests`
                    # (in row order). Looping over `gnns` would also visit GNNs
                    # skipped above, whose empty selection makes idxmax() raise.
                    for gnn in bests["gnn"].unique():
                        # Renumbered from 0 so the idxmax() label is also a
                        # valid .iloc position.
                        gnn_bests = bests[bests["gnn"] == gnn].reset_index(drop=True)

                        # make mean $\pm$ std column for metric
                        # (the backslash is escaped: "\p" is an invalid escape
                        # sequence in a normal string literal)
                        gnn_results = (
                            gnn_bests[test_metric]["mean"]
                            .apply(lambda x: f"{x:.3f}")
                            .astype("string")
                            + " $\\pm$ "
                            + gnn_bests[test_metric]["std"]
                            .apply(lambda x: f"{x:.3f}")
                            .astype("string")
                        )

                        # highlight in bold best result
                        index = gnn_bests[test_metric]["mean"].idxmax()
                        gnn_results.iloc[index] = (
                            "\\textbf{" + gnn_results.iloc[index] + "}"
                        )

                        test_results.append(gnn_results)

                    test_results = pd.concat(test_results).to_numpy()

                    # drop mean and std columns
                    bests = bests.drop(columns=test_metric, level=0)

                    # Positional assignment: the rows of `bests` are grouped by
                    # GNN in the same order as the pieces collected above.
                    bests[test_metric] = test_results

                # GNN names present in this table; consumed below to find the
                # first row of each GNN in the rendered LaTeX.
                models = set(bests["gnn"].unique())

                # Keep only the table columns, rename them to their LaTeX
                # headings, substitute display values, flatten the column index
                # to one level and drop the "supervised" column.
                bests = bests[list(latex_columns.keys()) + test_metrics]
                bests = bests.rename(columns=latex_columns)
                bests = bests.rename(
                    columns={
                        test_metric: latex_metric[test_metric]
                        for test_metric in test_metrics
                    }
                )
                bests = bests.replace(latex_values)
                bests = bests.droplevel(1, axis=1)
                bests = bests.drop(columns=latex_columns["supervised"])

                print(
                    """\n
                \\begin{table}[H]
                \\caption{
                Ablating attribute reconstruction for semi-supervised graph clustering with graph neural networks GCN and GraphSAGE.
                Model selection based on %s.
                }
                \\label{tab:no_attr_recon_vs_attr_recon_%s_%s_%s_%s_%s}
                \\begin{center}
                \n"""
                    % (
                        latex_metric[selection_metric],
                        dataset,
                        split,
                        num_train_nodes_per_class,
                        num_val_nodes,
                        selection_metric,
                    )
                )

                table = bests.to_latex(index=False, escape=False)

                # First four lines, last three lines, and the lines in between;
                # only the middle part is modified.
                table = table.split("\n")
                header = "\n".join(table[:4])
                footer = "\n".join(table[-3:])

                result = table[4:-3]

                # Insert a dashed rule before the first row of every GNN except
                # the first one matched. `result` grows while it is enumerated:
                # after an insert the matched row moves one position down and
                # is visited again, but its name has already been removed from
                # `models`, so it is not matched by that name twice.
                lines_added = 0
                for index, line in enumerate(result):
                    line = line.split("&")[0]
                    for model in models:
                        if model in line:
                            if lines_added > 0:
                                result.insert(index, "\\hdashline")

                            lines_added += 1

                            # Safe although `models` is being iterated: the
                            # loop is left immediately via `break`.
                            models.remove(model)
                            break

                result = "\n".join(result)

                print(header)
                print(result)
                print(footer)

                print(
                    """\n
                \\end{center}
                \\end{table}
                \n"""
                )

        # Heading and introduction of the pooling/regularisation subsection.
        print("""\n
        \\subsection{Comparison of Pooling Methods and Regularisation}
        \\label{subsec:pooling_vs_regularisation}

        In this section we present results comparing graph clustering and regularization objectives and the effect of ablating regularization objectives.
        We compare clustering with regularization, clustering without regularization, and regularization without clustering for each neural network.
              
        Our results show that unsupervised graph clustering objectives with regularization outperform regularization alone, ablating the effect of regularization.

        \n""")

        # One subsubsection per completed experiment group.
        for (
            dataset,
            split,
            max_clusters,
            num_train_nodes_per_class,
            num_val_nodes,
        ) in completed.itertuples(index=False):
            # print current group
            print(
                """\n
            \\subsubsection{
            %s dataset, %s split with %s train nodes per class and %s validation nodes.
            }
            \n"""
                % (
                    latex_values[dataset],
                    split,
                    num_train_nodes_per_class,
                    num_val_nodes,
                )
            )

            # One table per model-selection metric.
            for selection_metric in selection_metrics:
                # Best supervised run for every (model, pooling,
                # regularization) combination that has results in this group.
                bests = []
                for model in results["model"].unique():
                    for pooling in results["pooling"].unique():
                        for regularization in results["regularization"].unique():
                            filtered = results[
                                (results["model"] == model)
                                & (results["dataset"] == dataset)
                                & (results["split"] == split)
                                & (results["max_clusters"] == max_clusters)
                                & (
                                    results["num_train_nodes_per_class"]
                                    == num_train_nodes_per_class
                                )
                                & (results["num_val_nodes"] == num_val_nodes)
                                & (results["supervised"] == "True")
                                & (results["pooling"] == pooling)
                                & (results["regularization"] == regularization)
                            ]

                            # only keep mean column for selection_metric
                            filtered = filtered.drop(
                                columns=[(selection_metric, "std")]
                            )

                            # Nothing to select from for this combination.
                            if filtered[selection_metric].isnull().all(axis=None):
                                continue

                            best = filtered.loc[filtered[selection_metric].idxmax()]
                            bests.append(best)

                if not bests:
                    continue

                bests = pd.concat(bests).reset_index(drop=True)

                for test_metric in test_metrics:
                    test_results = []

                    for model in bests["model"].unique():
                        # Renumbered from 0 so the idxmax() label is also a
                        # valid .iloc position.
                        model_bests = bests[bests["model"] == model].reset_index(
                            drop=True
                        )

                        # make mean $\pm$ std column for metric
                        # (the backslash is escaped: "\p" is an invalid escape
                        # sequence in a normal string literal)
                        model_results = (
                            model_bests[test_metric]["mean"]
                            .apply(lambda x: f"{x:.3f}")
                            .astype("string")
                            + " $\\pm$ "
                            + model_bests[test_metric]["std"]
                            .apply(lambda x: f"{x:.3f}")
                            .astype("string")
                        )

                        # highlight in bold best result
                        index = model_bests[test_metric]["mean"].idxmax()
                        model_results.iloc[index] = (
                            "\\textbf{" + model_results.iloc[index] + "}"
                        )

                        test_results.append(model_results)

                    test_results = pd.concat(test_results).to_numpy()

                    # drop mean and std columns
                    bests = bests.drop(columns=test_metric, level=0)

                    # Positional assignment: the rows of `bests` were appended
                    # model by model, matching the order of the pieces above.
                    bests[test_metric] = test_results

                # Model names present in this table; consumed below to find the
                # first row of each model in the rendered LaTeX.
                models = set(bests["model"].unique())

                # Keep only the table columns, rename them to their LaTeX
                # headings, substitute display values, flatten the column index
                # to one level and drop the "supervised" column.
                bests = bests[list(latex_columns.keys()) + test_metrics]
                bests = bests.rename(columns=latex_columns)
                bests = bests.rename(
                    columns={
                        test_metric: latex_metric[test_metric]
                        for test_metric in test_metrics
                    }
                )
                bests = bests.replace(latex_values)
                bests = bests.droplevel(1, axis=1)
                bests = bests.drop(columns=latex_columns["supervised"])

                # Placeholder caption handed to to_latex; it is swapped for the
                # real captions after rendering (see the replace calls below).
                caption = "This is a caption."

                long_caption = """
                        Comparing graph clustering and regularization objectives with an ablation study of regularization.
                        The ablation compares clustering with regularization, clustering without regularization, and regularization without clustering for each neural network.
                        Model selection based on %s.
                    """ % latex_metric[selection_metric]

                short_caption = """
                        Comparing graph clustering and regularization objectives for each neural network.
                    """

                table = bests.to_latex(
                    index=False,
                    escape=False,
                    longtable=True,
                    caption=caption,
                    label="tab:pooling_vs_regularization_%s_%s_%s_%s_%s"
                    % (
                        dataset,
                        split,
                        num_train_nodes_per_class,
                        num_val_nodes,
                        selection_metric,
                    ),
                )

                table = table.split("\n")

                # Insert a dashed rule before the first row of every model
                # except the first one matched. All lines of the rendered table
                # are scanned, looking only at the text before the first "&".
                # `table` grows while it is enumerated: after an insert the
                # matched row moves one position down and is visited again, but
                # its model has already been removed from `models`.
                lines_added = 0
                for index, line in enumerate(table):
                    line = line.split("&")[0]
                    for model in models:
                        if model in line:
                            if lines_added > 0:
                                table.insert(index, "\\hdashline")

                            lines_added += 1

                            # Safe although `models` is being iterated: the
                            # loop is left immediately via `break`.
                            models.remove(model)
                            break

                table = "\n".join(table)

                # Swap the placeholder for the long caption, and the
                # "\caption[]" variant for the short one.
                table = table.replace(
                    "\\caption{%s}" % caption, "\\caption{%s}" % long_caption
                )
                table = table.replace(
                    "\\caption[]{%s}" % caption, "\\caption[]{%s}" % short_caption
                )

                print(table)

    # Point sys.stdout back at `original_stdout` (assigned earlier in the
    # script, outside this excerpt) and report how many datasets were
    # summarised.
    sys.stdout = original_stdout
    print("\nSummarised %d datasets.\n" % len(summary))
    sys.stdout.flush()

    # Write the cross-dataset summary table; everything printed inside the
    # redirect goes to `summary_file`.
    with open(summary_file, "w") as f:
        with redirect_stdout(f):
            # Placeholder caption handed to to_latex; it is swapped for the
            # real captions after rendering (see the replace calls below).
            caption = "This is a caption."

            long_caption = """
                    Summary results comparing semi-supervised graph clustering (where $ L_E $ is not ``None'') compared with semi-supervised node classification (where $ L_E $ is ``None'') for 4 neural network architectures -- Transformers, MultiLayer Perceptrons (MLPs), and Graph neural networks GCN and GraphSAGE -- evaluated on 6 label sparsified real-world attributed graph datasets.
                    The better result for each architecture is highlighted in \\textbf{bold} and the best result across all architectures is \\underline{underlined}.
                    The $ L_E $ and $ L_\\text{regularization} $ columns indicate which unsupervised clustering and regularization losses were used in addition to a supervised cross-entropy loss for training.
                    A \\: \\xmark \\: or \\: \\cmark \\: in the $ L_V $ column indicates if attributes were reconstructed, and
                    a \\: \\xmark \\: or \\: \\cmark \\: in the $ L_S $ column indicates if a supervised cross-entropy loss was used for training.
                    The ``ES'' column indicates if training loss or validation MCC was used as an early stopping criterion.
                """

            short_caption = """
                    Summary results on label-sparsified real-world attributed graph datasets.
                """

            results = []

            no_header = True
            for dataset, result in summary.items():
                # Datasets without any collected rows are left out.
                if len(result) == 0:
                    continue

                result = pd.concat(result).reset_index(drop=True)

                for test_metric in test_metrics:
                    test_latex = []
                    test_results = []

                    for model in result["model"].unique():
                        # Renumbered from 0 so the idxmax() label is also a
                        # valid .iloc position.
                        model_bests = result[result["model"] == model].reset_index(
                            drop=True
                        )

                        # make mean $\pm$ std column for metric
                        # (the backslash is escaped: "\p" is an invalid escape
                        # sequence in a normal string literal)
                        model_latex = (
                            model_bests[test_metric]["mean"]
                            .apply(lambda x: f"{x:.3f}")
                            .astype("string")
                            + " $\\pm$ "
                            + model_bests[test_metric]["std"]
                            .apply(lambda x: f"{x:.3f}")
                            .astype("string")
                        )

                        # highlight in bold best result supervised vs semi-supervised
                        index = model_bests[test_metric]["mean"].idxmax()
                        model_latex.iloc[index] = (
                            "\\textbf{" + model_latex.iloc[index] + "}"
                        )

                        test_latex.append(model_latex)
                        test_results.append(model_bests[test_metric]["mean"])

                    test_latex = pd.concat(test_latex).to_numpy()
                    test_results = pd.concat(test_results).to_numpy()

                    # underline the best result for each dataset
                    index = test_results.argmax()
                    test_latex[index] = "\\underline{" + test_latex[index] + "}"

                    # drop mean and std columns
                    result = result.drop(columns=test_metric, level=0)

                    # Positional assignment -- presumably the rows of `result`
                    # are grouped by model in order of first appearance; TODO
                    # confirm against the code that fills `summary`.
                    result[test_metric] = test_latex

                results.append(result)

            # NOTE(review): pd.concat raises ValueError when `results` is
            # empty, i.e. when every dataset was skipped above.
            results = pd.concat(results)

            results = results.replace(latex_values)

            results = results.sort_values(
                by=["dataset", "model", "pooling"],
                ascending=[True, True, False],
            ).reset_index(drop=True)

            # Dataset names in sorted row order, and the set of model names
            # (both taken after `latex_values` has been applied).
            datasets = results["dataset"].unique().tolist()
            models = set(results["model"].unique())

            # Keep only the table columns, rename them to their LaTeX headings,
            # flatten the column index to one level and drop the "supervised"
            # column.
            results = results[list(latex_columns.keys()) + test_metrics]
            results = results.rename(columns=latex_columns)
            results = results.rename(
                columns={
                    test_metric: latex_metric[test_metric]
                    for test_metric in test_metrics
                }
            )
            results = results.droplevel(1, axis=1)
            results = results.drop(columns=latex_columns["supervised"])

            table = results.to_latex(
                index=False,
                escape=False,
                longtable=True,
                caption=caption,
                label="tab:summary",
            )

            table = table.split("\n")

            all_models = models.copy()

            # Walk over the rendered lines and add separators. `models` holds
            # the models not yet seen in the current dataset group and is
            # refilled from `all_models` once it runs empty.
            #   * full `models` set and datasets left: a new dataset group
            #     starts, so a centred dataset title framed by \midrule lines
            #     is inserted (no leading \midrule before the first group);
            #   * otherwise: a new model within the group, so a dashed rule
            #     is inserted.
            # `table` grows while it is enumerated, so the matched row is
            # visited again after the inserted lines; `previous_model` keeps
            # it from being handled twice, also right after `models` has been
            # refilled.
            lines_added = 0
            previous_model = None
            for index, line in enumerate(table):
                first_word = line.split(" ")[0]

                if first_word in models:
                    model = first_word
                    if previous_model != model:
                        if len(models) == len(all_models) and len(datasets) > 0:
                            # Each insert happens at the same index, so the
                            # final top-to-bottom order is the reverse of the
                            # call order: [\midrule,] title, \midrule, row.
                            table.insert(index, "\\midrule")
                            table.insert(
                                index,
                                "\\multicolumn{%d}{c}{\\textbf{%s}} \\\\"
                                % (
                                    len(latex_columns) - 1 + len(test_metrics),
                                    datasets.pop(0),
                                ),
                            )
                            if lines_added > 0:
                                table.insert(index, "\\midrule")
                        else:
                            table.insert(index, "\\hdashline")

                        lines_added += 1

                        models.remove(model)

                        if len(models) == 0:
                            models = all_models.copy()

                    previous_model = model

            table = "\n".join(table)

            # Swap the placeholder for the long caption, and the "\caption[]"
            # variant for the short one.
            table = table.replace(
                "\\caption{%s}" % caption, "\\caption{%s}" % long_caption
            )
            table = table.replace(
                "\\caption[]{%s}" % caption, "\\caption[]{%s}" % short_caption
            )

            print(table)
