{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "239a0c3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "from spektral.datasets import TUDataset\n",
    "\n",
    "from magni.src.edge_dropping_magnitude import edge_pooling_magnitude_repeated\n",
    "from magni.src.edge_dropping_magnitude import get_scores_edge\n",
    "from magni.src.modules.compare_graphs import choose_graph_metric\n",
    "from magni.src.modules.compute_graph_magnitude import median_heuristic, compute_magnitude_graph, compute_magnitude_subgraphs\n",
    "from magni.src.modules.pooling_utils import to_nx_graph\n",
    "from magni.src.modules.transforms import Float\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64f8fe0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the ENZYMES TU benchmark; every graph passes through the project's Float()\n",
    "# transform (presumably a float cast of the features — confirm in magni.src.modules.transforms).\n",
    "dataset = TUDataset(name=\"ENZYMES\", transforms=[Float()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e904087d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pair each graph's node count with its integer class label. spektral's TUDataset\n",
    "# presumably stores graph labels one-hot (an ndarray); a tuple holding an ndarray is\n",
    "# unhashable, so the value_counts() in the next cell would raise TypeError. Collapse\n",
    "# the label to an int (argmax for array labels, int() for scalar ones).\n",
    "n_nodes = [\n",
    "    (g.n_nodes, int(np.argmax(g.y)) if np.ndim(g.y) > 0 else int(g.y))\n",
    "    for g in dataset\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7295bd7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# How often each (node count, class label) pair occurs across the dataset.\n",
    "node_counts = pd.Series(n_nodes).value_counts()\n",
    "node_counts.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "409199f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "for g in dataset[:1]:\n",
    "    # Convert the spektral graph to networkx; `g` is rebound to the networkx graph,\n",
    "    # which is what every later step operates on.\n",
    "    g = to_nx_graph(g.x, g.a)\n",
    "    print(g)\n",
    "    dist_fn = choose_graph_metric(\"diffusion_distance\", mode=\"structure\")\n",
    "    method = \"cholesky\"\n",
    "\n",
    "    # Magnitude of the unpooled graph; the returned scale(s) `ts` are reused for every\n",
    "    # pooled graph below.\n",
    "    mag, ts = compute_magnitude_graph(g, ts=[1], dist_fn=dist_fn, get_weights=False, n_ts=1, scale_finding=\"ts1\", method=method)\n",
    "    print(\"SCALE: \", round(ts[0],2))\n",
    "    print(\"MAG: \", round(mag[0],2))\n",
    "    print(\"-\")\n",
    "    print(\"----------------\")\n",
    "\n",
    "    def get_distance_weighted_graph(A, D):\n",
    "        \"\"\"Return an undirected graph with an edge (i, j) wherever A[i, j] is non-zero.\n",
    "\n",
    "        Each edge's `weight` attribute is set to D[i, j]. Only the upper triangle of A\n",
    "        is read, so A is assumed symmetric. Nodes without any incident edge are not added.\n",
    "        \"\"\"\n",
    "        G = nx.Graph()\n",
    "        num_nodes = A.shape[0]\n",
    "        for i in range(num_nodes):\n",
    "            for j in range(i + 1, num_nodes):\n",
    "                # Any non-zero entry is an edge; testing == 1 would silently drop edges\n",
    "                # of weighted adjacency matrices (nx.to_numpy_array keeps edge weights).\n",
    "                if A[i, j] != 0:\n",
    "                    G.add_edge(i, j, weight=D[i, j])\n",
    "        return G\n",
    "\n",
    "    def get_scores_weighted_graph(edges, scores):\n",
    "        \"\"\"Return a graph built from `edges`, where the k-th edge gets weight scores[k].\n",
    "\n",
    "        Raises ValueError if `edges` and `scores` differ in length.\n",
    "        \"\"\"\n",
    "        G_mag = nx.Graph()\n",
    "        for (i, j), score in zip(edges, scores, strict=True):\n",
    "            G_mag.add_edge(i, j, weight=score)\n",
    "        return G_mag\n",
    "\n",
    "    # --- Plotting ---\n",
    "    def plot_graph(G, pos, weights, label, cbar=True):\n",
    "        \"\"\"Draw G at positions `pos`, colouring each edge by its entry in `weights`.\n",
    "\n",
    "        `weights` must follow the order of G.edges(), which is the edge order\n",
    "        nx.draw_networkx_edges uses by default. Colours come from the plasma colormap\n",
    "        normalised to [min(weights), max(weights)]. When `cbar` is true, a tick-less\n",
    "        colorbar titled `label` is attached to the current axes.\n",
    "        \"\"\"\n",
    "        cmap = plt.cm.plasma\n",
    "        # min()/max() raise on an empty list (a graph with no edges); fall back to a\n",
    "        # unit range, since there are no edges to colour in that case anyway.\n",
    "        if len(weights) > 0:\n",
    "            norm = plt.Normalize(vmin=min(weights), vmax=max(weights))\n",
    "        else:\n",
    "            norm = plt.Normalize(vmin=0.0, vmax=1.0)\n",
    "        edge_colors = [cmap(norm(w)) for w in weights]\n",
    "\n",
    "        # Draw nodes and labels\n",
    "        nx.draw_networkx_nodes(G, pos, node_size=80, node_color='w', alpha=0.8)\n",
    "        nx.draw_networkx_labels(G, pos, font_size=6, font_color='black')\n",
    "\n",
    "        # Draw edges with colormap\n",
    "        nx.draw_networkx_edges(G, pos, edge_color=edge_colors, width=2)\n",
    "\n",
    "        if cbar:\n",
    "            sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)\n",
    "            sm.set_array([])\n",
    "            plt.colorbar(sm, ax=plt.gca(), label=label).set_ticks([])\n",
    "\n",
    "        plt.axis('off')\n",
    "        sns.despine()\n",
    "\n",
    "    def s_to_order(S):\n",
    "        \"\"\"Return one string label per row of S, built from np.nonzero of that row.\n",
    "\n",
    "        For a 1-D row this is the printed array of non-zero column indices, e.g.\n",
    "        '[0 3 5]'. Presumably row r is pooled node r and its non-zero columns are the\n",
    "        original nodes merged into it (TODO confirm against edge_pooling_magnitude_repeated).\n",
    "        \"\"\"\n",
    "        def _row_label(row_idx):\n",
    "            return ''.join(f\"{part}\" for part in np.nonzero(S[row_idx, :]))\n",
    "\n",
    "        return [_row_label(r) for r in range(S.shape[0])]\n",
    "\n",
    "    plt.figure(figsize=(20, 3))\n",
    "    # One spring layout of the original graph is shared by all panels; each pooled node\n",
    "    # is drawn at the position of the first original node listed in its label.\n",
    "    pos = nx.spring_layout(g, seed=7)\n",
    "    for j, ratio in enumerate([0.1, 0.4, 0.6, 0.8], start=1):\n",
    "        # Reseed per panel so every pooling run starts from the same np.random state.\n",
    "        np.random.seed(5)\n",
    "        # Pool a fresh copy of the original graph for round(ratio * n_nodes) steps.\n",
    "        (G_this, _, _, S, _) = edge_pooling_magnitude_repeated(g.copy(), dist_fn=dist_fn, ts=ts, method=method, n_steps=round(ratio * g.number_of_nodes()))\n",
    "\n",
    "        # Rebuild the pooled graph with distance-weighted edges, then name each pooled\n",
    "        # node after the non-zero columns of its row in S (see s_to_order).\n",
    "        D_this = dist_fn(G_this)\n",
    "        A_this = nx.to_numpy_array(G_this)\n",
    "        G_this = get_distance_weighted_graph(A_this, D_this)\n",
    "        labs = s_to_order(S)\n",
    "        G_this = nx.relabel_nodes(G_this, {i: labs[i] for i in range(len(labs))})\n",
    "\n",
    "        this_edges = list(G_this.edges())\n",
    "        # NOTE(review): original_magni is the magnitude of the unpooled graph `g`, not of\n",
    "        # G_this; confirm that is the intended reference for these edge scores.\n",
    "        mag_score = get_scores_edge(G_this, edges=this_edges, original_magni=mag, dist_fn=dist_fn, ts=ts, method=method)\n",
    "        G_this = get_scores_weighted_graph(this_edges, mag_score)\n",
    "        weights = [data['weight'] for u, v, data in G_this.edges(data=True)]\n",
    "\n",
    "        # Labels look like '[3 7 12]': take the first index. Splitting on whitespace/commas\n",
    "        # instead of slicing two characters keeps indices >= 100 intact.\n",
    "        pos_new = {\n",
    "            node: pos[int(str(node).strip(\"[]\").replace(\",\", \" \").split()[0])]\n",
    "            for node in G_this.nodes()\n",
    "        }\n",
    "        plt.subplot(1, 4, j)\n",
    "        # Only panels with ratio >= 0.7 (here just the last one) get a colorbar.\n",
    "        plot_graph(G_this, pos_new, weights, \"Magnitude Difference\", cbar=ratio >= 0.7)\n",
    "        plt.title(f\"Pooling Ratio {(1-ratio)*100:.0f}%\")\n",
    "\n",
    "    # savefig raises if the output directory is missing, so create it first.\n",
    "    os.makedirs(\"../plots\", exist_ok=True)\n",
    "    plt.savefig(\"../plots/magnitude_enzyme_diffusion_distance.pdf\", bbox_inches='tight')\n",
    "    plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
