{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "149a8aad",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import glob\n",
    "import os\n",
    "\n",
    "import matplotlib.cm as cm\n",
    "import matplotlib.colors as mcolors\n",
    "import matplotlib.pyplot as plt\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import scipy.sparse as sp\n",
    "from spektral.utils import degree_power, laplacian\n",
    "\n",
    "parser = argparse.ArgumentParser()\n",
    "parser.add_argument(\"--path\", type=str, default=\"results/Ring/\")\n",
    "parser.add_argument(\"--scale\", action=\"store_true\")\n",
    "parser.add_argument(\"--show\", action=\"store_true\")\n",
    "parser.add_argument(\"--fmt\", default=\"png\")\n",
    "parser.add_argument(\"--methods\", type=str)\n",
    "# parse_known_args() rather than parse_args(): inside Jupyter, sys.argv holds the\n",
    "# kernel's own \"-f <connection file>\" option, which parse_args() rejects with\n",
    "# SystemExit. Unrecognised options are returned separately and ignored.\n",
    "args, _ = parser.parse_known_args()\n",
    "\n",
    "datasets = [\"Ring\", \"barbell\", \"Sensor\"]\n",
    "paths = [f\"../src/spectral_similarity/{p}/\" for p in datasets]\n",
    "\n",
    "# Which group of pooling methods to plot (\"overview\" unless --methods is given).\n",
    "methods = args.methods or \"overview\"\n",
    "\n",
    "if methods == \"baseline\":\n",
    "    names = [\"DiffPool\", \"MinCut\", \"NMF\",  # \"LaPool\",\n",
    "             \"TopK\", \"SAGPool\", \"NDP\", \"Graclus\"]\n",
    "    # For the baseline group the file prefixes double as panel titles.\n",
    "    names_nice = list(names)\n",
    "elif methods == \"overview\":\n",
    "    names = [\"MAG_EDGE_diffusion_distance\", \"SPREAD_EDGE_diffusion_distance\"] + [\"NDP\", \"Graclus\", \"NMF\", \"TopK\", \"SAGPool\", \"DiffPool\", \"MinCut\"]\n",
    "    names_nice = [\"MagEdgePool\", \"SpreadEdgePool\"] + [\"NDP\", \"Graclus\", \"NMF\", \"TopK\", \"SAGPool\", \"DiffPool\", \"MinCut\"]\n",
    "else:\n",
    "    raise ValueError(f\"Unknown --methods value {methods!r}; expected 'baseline' or 'overview'.\")\n",
    "\n",
    "# Config\n",
    "n_rows = len(datasets)  # one row per dataset\n",
    "n_cols = len(names) + 1  # \"Original\" column + one column per method\n",
    "threshold = 1e-9  # Sparsification threshold\n",
    "scale = 1.1  # Scale of the whole figure\n",
    "# Methods whose pooled node positions are computed as S^T X from the stored\n",
    "# matrix S (as are the \"MAG\" variants); all other methods store them in X_pool.\n",
    "s_projected_methods = {\"DiffPool\", \"MinCut\", \"TopK\", \"SAGPool\"}\n",
    "\n",
    "# One colour per method column. plt.get_cmap replaces matplotlib.cm.get_cmap,\n",
    "# which was deprecated in Matplotlib 3.7 and removed in 3.9.\n",
    "cmap = plt.get_cmap(\"tab10\", len(names_nice))\n",
    "\n",
    "\n",
    "def find_result_file(result_dir, name):\n",
    "    \"\"\"Return the first (sorted) .npz file in `result_dir` whose name starts with `name`.\n",
    "\n",
    "    Raises FileNotFoundError naming the glob pattern, instead of an opaque IndexError.\n",
    "    \"\"\"\n",
    "    pattern = os.path.join(result_dir, f\"{name}*.npz\")\n",
    "    matches = sorted(glob.glob(pattern))\n",
    "    if not matches:\n",
    "        raise FileNotFoundError(f\"No pooling result matches {pattern}\")\n",
    "    return matches[0]\n",
    "\n",
    "\n",
    "def load_pooling_result(npz_path):\n",
    "    \"\"\"Read X, A, X_pool, A_pool and S from one result archive, opening it only once.\n",
    "\n",
    "    The `with` block closes the archive's file handle. allow_pickle=True mirrors the\n",
    "    original loading code (presumably some arrays are stored as object arrays);\n",
    "    only use it on trusted result files.\n",
    "    \"\"\"\n",
    "    with np.load(npz_path, allow_pickle=True) as archive:\n",
    "        return {key: archive[key] for key in (\"X\", \"A\", \"X_pool\", \"A_pool\", \"S\")}\n",
    "\n",
    "\n",
    "plt.figure(figsize=(n_cols * 1 * scale, n_rows * 1.2 * scale))\n",
    "\n",
    "################################################################################\n",
    "# NX graphs\n",
    "################################################################################\n",
    "for row, result_dir in enumerate(paths):\n",
    "    results = [load_pooling_result(find_result_file(result_dir, n)) for n in names]\n",
    "\n",
    "    # Sparsify pooled adjacency matrices\n",
    "    for res in results:\n",
    "        res[\"A_pool\"][res[\"A_pool\"] < threshold] = 0\n",
    "\n",
    "    # First column: the original graph, taken from the first method's archive.\n",
    "    plt.subplot(n_rows, n_cols, row * n_cols + 1)\n",
    "    nx.draw(\n",
    "        nx.Graph(results[0][\"A\"]), pos=results[0][\"X\"][:, :2],\n",
    "        node_size=3, node_color=\"k\", edge_color=\"#00000022\", width=0.8,\n",
    "    )\n",
    "    if row == 0:\n",
    "        plt.title(\"Original\", fontsize=8)\n",
    "\n",
    "    for col, (name, res) in enumerate(zip(names, results)):\n",
    "        plt.subplot(n_rows, n_cols, row * n_cols + col + 2)\n",
    "        if name in s_projected_methods or \"MAG\" in name:\n",
    "            # S has one row per original node (required by S.T.dot(X)), so the\n",
    "            # product has one row per pooled node.\n",
    "            pos = res[\"S\"].T.dot(res[\"X\"])\n",
    "        else:\n",
    "            pos = res[\"X_pool\"]\n",
    "\n",
    "        colour = cmap(col)\n",
    "        n_pooled = res[\"A_pool\"].shape[0]\n",
    "        nx.draw(\n",
    "            nx.Graph(res[\"A_pool\"]), pos=pos[:, :2],\n",
    "            node_color=[colour] * n_pooled, node_size=3,\n",
    "            edge_color=\"#00000022\", width=0.8,\n",
    "        )\n",
    "        if row == 0:\n",
    "            plt.title(names_nice[col], fontsize=8, color=colour)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.subplots_adjust(wspace=0.1, hspace=0.2)\n",
    "\n",
    "os.makedirs(\"../plots\", exist_ok=True)\n",
    "plt.savefig(\"../plots/\" + f\"spectral_similarity_{methods}.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc3fbeb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "path = \"../src/spectral_similarity\"\n",
    "# Datasets whose summary tables (<dataset>_means.csv / <dataset>_stds.csv) are collected.\n",
    "summary_datasets = [\"barbell\", \"torus\", \"erdosrenyi\", \"davidsensornet\", \"barabasialbert\", \"community\", \"Ring\", \"Sensor\", \"Grid2d\"]\n",
    "# Pooling methods (table rows) to keep. Named `method_order` so this cell does not\n",
    "# overwrite the `methods` string the plotting cell uses in its output file name.\n",
    "method_order = [\"DiffPool\", \"MinCut\", \"NMF\", \"TopK\", \"SAGPool\", \"NDP\", \"Graclus\", \"MagEdgePool\", \"SpreadEdgePool\"]\n",
    "# Harmonise the name of the normalised Wasserstein column across tables.\n",
    "column_renames = {\"wasserstein_distance2\": \"wasserstein_distance_normalised\"}\n",
    "\n",
    "\n",
    "def to_numeric_if_possible(column):\n",
    "    \"\"\"Return `column` converted to numbers, or unchanged if any entry fails to parse.\n",
    "\n",
    "    Same result as pd.to_numeric(column, errors=\"ignore\"), whose \"ignore\" mode is\n",
    "    deprecated since pandas 2.2.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        return pd.to_numeric(column)\n",
    "    except (ValueError, TypeError):\n",
    "        return column\n",
    "\n",
    "\n",
    "ranks_dict = {}\n",
    "results_list = []\n",
    "stds_list = []\n",
    "for dataset in summary_datasets:\n",
    "    result = pd.read_csv(f\"{path}/{dataset}/{dataset}_means.csv\", header=0, index_col=0)\n",
    "    stds = pd.read_csv(f\"{path}/{dataset}/{dataset}_stds.csv\", header=0, index_col=0)\n",
    "\n",
    "    result = result.apply(to_numeric_if_possible)\n",
    "    # .copy() so adding the \"dataset\" column below writes to an independent frame.\n",
    "    result = result.rename(columns=column_renames).loc[method_order, :].copy()\n",
    "    stds = stds.rename(columns=column_renames).loc[method_order, :].copy()\n",
    "\n",
    "    # Per-metric rank of each method (ascending: smallest value ranks first),\n",
    "    # computed before the \"dataset\" label column is added.\n",
    "    ranks_dict[dataset] = result.rank(axis=0, ascending=True)\n",
    "    result[\"dataset\"] = dataset\n",
    "    stds[\"dataset\"] = dataset\n",
    "    results_list.append(result)\n",
    "    stds_list.append(stds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91316540",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "\n",
    "# Metrics shown as heatmaps (CSV column names) and their display titles.\n",
    "performance_metrics = [\"mag_diff_diffusion_distance\", \"spectral_distance2\", \"loss\"]\n",
    "performance_metrics_nice = [\"Magnitude Difference\", \"Normalised Spectral Difference\", \"Spectral Loss\"]\n",
    "# Heatmap row (method) and column (dataset) order.\n",
    "order_index = [\"MagEdgePool\", \"SpreadEdgePool\",\n",
    "               \"NDP\", \"Graclus\", \"NMF\", \"TopK\", \"SAGPool\", \"DiffPool\", \"MinCut\"]\n",
    "heatmap_datasets = [\"Ring\", \"barbell\", \"Sensor\"]\n",
    "\n",
    "# Long tables with one row per (method, dataset). reset_index() moves the method\n",
    "# names into a column; the code below expects it to be called \"index\", which\n",
    "# presumably holds because the CSV index column is unnamed (confirm if the CSV\n",
    "# headers change).\n",
    "all_results = pd.concat(results_list, axis=0).reset_index()\n",
    "stds = pd.concat(stds_list, axis=0).reset_index()\n",
    "\n",
    "\n",
    "def to_heatmap_table(long_table, metric):\n",
    "    \"\"\"Pivot `long_table` to a method x dataset table of `metric`, in plotting order.\"\"\"\n",
    "    wide = long_table.pivot(index=\"index\", columns=\"dataset\", values=metric)\n",
    "    return wide.loc[order_index, heatmap_datasets]\n",
    "\n",
    "\n",
    "def normalise_columns(table):\n",
    "    \"\"\"Scale each column into [0, 1] for the heatmap colours.\n",
    "\n",
    "    The lower anchor is the 1% quantile, so the column minimum can land slightly\n",
    "    below 0; clip() keeps every value inside [0, 1]. A constant column has zero\n",
    "    range, which is replaced by 1 so the column maps to 0 instead of NaN.\n",
    "    \"\"\"\n",
    "    span = (table.max() - table.min()).replace(0, 1)\n",
    "    return table.subtract(table.quantile(0.01), axis=1).div(span, axis=1).clip(0, 1)\n",
    "\n",
    "\n",
    "os.makedirs(\"../plots\", exist_ok=True)\n",
    "for pm, pm_title in zip(performance_metrics, performance_metrics_nice):\n",
    "    means_ordered = to_heatmap_table(all_results, pm)\n",
    "    stds_ordered = to_heatmap_table(stds, pm)\n",
    "    colour_values = normalise_columns(means_ordered)\n",
    "\n",
    "    # Cell annotations read \"mean ± std\", each rounded to one decimal.\n",
    "    means_and_stds = means_ordered.round(1).astype(str) + \" ± \" + stds_ordered.round(1).astype(str)\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(3, 3))\n",
    "    sns.heatmap(colour_values, annot=means_and_stds, cmap=\"rocket_r\", cbar=False, fmt=\"\", ax=ax)\n",
    "    ax.set_title(pm_title)\n",
    "    ax.set_xlabel(\"\")\n",
    "    ax.set_ylabel(\"\")\n",
    "    fig.savefig(f\"../plots/{pm}_overview.pdf\", bbox_inches=\"tight\")\n",
    "    plt.show()\n",
    "    plt.close(fig)\n",
    "\n",
    "    display(means_ordered)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}