{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63db1be4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "\n",
    "from magni.src.graph_classification.jsons_to_csvs import json_to_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19b4f4fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Experiment configuration ---\n",
    "dataset = \"ENZYMES\"\n",
    "datasets = [\"NCI1\", \"ENZYMES\", \"IMDB-MULTI\"]\n",
    "\n",
    "model = \"GIN\"\n",
    "path = \"../src/graph_classification/results/\"\n",
    "dataset_path = f\"{path}/{dataset}\"\n",
    "\n",
    "# Pooling ratios for which result files are looked up.\n",
    "ratios = [0.0625, 0.125, 0.2, 0.25, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]\n",
    "\n",
    "methods = [\n",
    "    \"MAG_EDGE_diffusion_distance\", \"SPREAD_EDGE_diffusion_distance\", \"NMF\",\n",
    "    \"NDP\", \"TopK\", \"SAGPool\", \"Graclus\", \"DiffPool\", \"MinCut\",\n",
    "]\n",
    "\n",
    "models = [\"GNN\", \"GIN\"]\n",
    "\n",
    "# --- Collect one frame per (dataset, model, method, ratio) ---\n",
    "# all_dfs gathers the \"results\" tables, all_dfs_mag the \"dataset\" tables of the *_mag.json files.\n",
    "all_dfs = []\n",
    "all_dfs_mag = []\n",
    "for dataset in datasets:\n",
    "    dataset_path = f\"{path}/{dataset}\"\n",
    "    if not os.path.exists(dataset_path):\n",
    "        print(f\"Path {dataset_path} does not exist.\")\n",
    "        continue\n",
    "    for model in models:\n",
    "        for method in methods:\n",
    "            for ratio in ratios:\n",
    "                print(f\"Method: {method}\")\n",
    "                ratio_tag = str(round(ratio, 3))\n",
    "                # Results for ratio 0.5 are read from the file name without a ratio suffix;\n",
    "                # the *_mag.json file name always carries the suffix.\n",
    "                if ratio == 0.5:\n",
    "                    json_path = f\"{path}/{dataset}/{model}_{method}_{dataset}_stratified.json\"\n",
    "                else:\n",
    "                    json_path = f\"{path}/{dataset}/{model}_{method}_{dataset}_stratified_ratio_{ratio_tag}.json\"\n",
    "                mag_paths = f\"{path}/{dataset}/{model}_{method}_{dataset}_stratified_ratio_{ratio_tag}_mag.json\"\n",
    "\n",
    "                # NOTE(review): json_to_df is expected to return None for a missing/empty file\n",
    "                # (see the message below) -- confirm against its implementation.\n",
    "                df = json_to_df(json_path, \"results\")\n",
    "                if df is None:\n",
    "                    print(f\"JSON file {json_path} not found or empty.\")\n",
    "                    continue\n",
    "                df[\"ratio\"] = ratio\n",
    "                df[\"method\"] = method\n",
    "                df[\"model\"] = model\n",
    "                df[\"dataset\"] = dataset\n",
    "                all_dfs.append(df)\n",
    "\n",
    "                # The MAG statistics are optional: a failed load only skips this entry.\n",
    "                if model in [\"GNN\", \"GIN\"]:\n",
    "                    try:\n",
    "                        df_mag = pd.DataFrame(json_to_df(mag_paths, \"dataset\"))\n",
    "                    except Exception:\n",
    "                        df_mag = None\n",
    "                else:\n",
    "                    df_mag = None\n",
    "                if df_mag is None:\n",
    "                    # Report the MAG file that failed, not the results file.\n",
    "                    print(f\"JSON file {mag_paths} not found or empty.\")\n",
    "                    continue\n",
    "                df_mag[\"ratio\"] = ratio\n",
    "                df_mag[\"method\"] = method\n",
    "                df_mag[\"model\"] = model\n",
    "                df_mag[\"dataset\"] = dataset\n",
    "                all_dfs_mag.append(df_mag)\n",
    "\n",
    "all_dfs = pd.concat(all_dfs)\n",
    "if all_dfs_mag:\n",
    "    all_dfs_mag = pd.concat(all_dfs_mag)\n",
    "else:\n",
    "    # pd.concat raises on an empty list; fall back to an empty frame instead.\n",
    "    all_dfs_mag = pd.DataFrame()\n",
    "    print(\"No MAG results found.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30ae67d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the combined results frame built by the loading cell.\n",
    "all_dfs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b97a876a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of result rows for every (ratio, method, model, dataset) combination.\n",
    "all_dfs.value_counts(subset=[\"ratio\", \"method\", \"model\", \"dataset\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbc18b6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of result rows per pooling method.\n",
    "all_dfs.value_counts(subset=[\"method\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e0761a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Columns needed for the accuracy plots; reset_index() moves the existing index into a column.\n",
    "df_to_plot = all_dfs[[\"ratio\", \"method\", \"model\", \"accuracy\", \"dataset\"]].reset_index().copy()\n",
    "\n",
    "# Raw method identifiers -> display names used in the figures.\n",
    "name_to = {\n",
    "    \"MAG_EDGE_diffusion_distance\": \"MagEdgePool\",\n",
    "    \"SPREAD_EDGE_diffusion_distance\": \"SpreadEdgePool\",\n",
    "    \"Flat\": \"No Pooling\",\n",
    "    \"TopK\": \"TopK\",\n",
    "    \"SAGPool\": \"SAGPool\",\n",
    "    \"Graclus\": \"Graclus\",\n",
    "    \"NMF\": \"NMF\",\n",
    "    \"NDP\": \"NDP\",\n",
    "    \"DiffPool\": \"DiffPool\",\n",
    "    \"MinCut\": \"MinCut\",\n",
    "}\n",
    "\n",
    "# Labels that are not keys of name_to pass through unchanged instead of becoming NaN.\n",
    "# This keeps the cell re-runnable after all_dfs[\"method\"] has itself been renamed in a later cell.\n",
    "df_to_plot[\"method\"] = df_to_plot[\"method\"].map(lambda label: name_to.get(label, label))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c14e998b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display order of the pooling methods (used as the categorical order for the hue).\n",
    "sort_methods = ['MagEdgePool', 'SpreadEdgePool',\n",
    "                'NDP', 'Graclus',\n",
    "                'NMF', 'TopK', 'SAGPool', \"DiffPool\", \"MinCut\"]\n",
    "\n",
    "# One figure per method subset; cutoff_ratio[p] is the largest ratio kept for subset p\n",
    "# and subset_names[p] is the suffix of the saved file.\n",
    "model_subsets = [\n",
    "    ['MagEdgePool', 'SpreadEdgePool', 'NDP', 'Graclus'],\n",
    "    ['MagEdgePool', 'SpreadEdgePool', 'TopK', 'SAGPool', \"DiffPool\", \"MinCut\"],\n",
    "]\n",
    "cutoff_ratio = [0.5, 0.9]\n",
    "subset_names = [\"not_trainable\", \"trainable\"]\n",
    "\n",
    "for dataset in datasets:\n",
    "    df_dataset = df_to_plot[df_to_plot[\"dataset\"] == dataset]\n",
    "    # Reference accuracies over all methods of this dataset, shared by both subsets.\n",
    "    max_acc = df_dataset[\"accuracy\"].max()\n",
    "    min_acc = df_dataset[\"accuracy\"].quantile(0.02)\n",
    "    for p, subsetm in enumerate(model_subsets):\n",
    "        keep = (\n",
    "            df_dataset[\"method\"].isin(subsetm)\n",
    "            & df_dataset[\"ratio\"].isin(ratios)\n",
    "            & (df_dataset[\"ratio\"] <= cutoff_ratio[p])\n",
    "        )\n",
    "        df_this_plot = df_dataset[keep]\n",
    "        for m in [\"GIN\"]:\n",
    "            # Copy before assigning a column so the write does not target a slice of df_to_plot.\n",
    "            this_df = df_this_plot[df_this_plot[\"model\"] == m].copy()\n",
    "            this_df[\"method\"] = pd.Categorical(this_df[\"method\"], categories=sort_methods, ordered=True)\n",
    "            for c in [\"accuracy\"]:\n",
    "                plt.figure(figsize=(3, 3))\n",
    "                # observed=True: only (ratio, method) pairs that actually occur get a mean.\n",
    "                this_means = this_df.groupby([\"ratio\", \"method\"], observed=True)[c].mean().reset_index()\n",
    "                sns.lineplot(data=this_df, x='ratio', y=c, errorbar='sd',\n",
    "                             palette='tab10', hue=\"method\", legend=False)\n",
    "                sns.scatterplot(data=this_means, x='ratio', y=c,\n",
    "                                palette='tab10', hue=\"method\", s=30, alpha=0.8)\n",
    "                plt.xlabel(\"pooling ratio\")\n",
    "                # Fully transparent reference lines (alpha=0); presumably they only serve to\n",
    "                # pin the autoscaled axis range across figures -- TODO confirm.\n",
    "                plt.axhline(min_acc, color='white', lw=1, alpha=0.)\n",
    "                plt.axhline(max_acc, color='white', lw=1, alpha=0)\n",
    "                plt.axvline(cutoff_ratio[p], color='white', lw=1, alpha=0.)\n",
    "\n",
    "                plt.gca().invert_xaxis()\n",
    "                plt.title(f'{m} {dataset}')\n",
    "                plt.legend(title=\"Method\", loc='upper left', bbox_to_anchor=(1, 1), frameon=False)\n",
    "                sns.despine()\n",
    "\n",
    "                plt.savefig(f\"../plots/{dataset}_{m}_{c}_{subset_names[p]}.pdf\", bbox_inches='tight', dpi=500)\n",
    "                plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3885603d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# mag_diffs and spread_diffs divided by the magnitude and spread columns of the same row.\n",
    "all_dfs_mag[\"relative_mag_diffs\"] = all_dfs_mag[\"mag_diffs\"].div(all_dfs_mag[\"magnitude\"])\n",
    "all_dfs_mag[\"relative_spread_diffs\"] = all_dfs_mag[\"spread_diffs\"].div(all_dfs_mag[\"spread\"])\n",
    "# Gap between the magnitude difference and the spread difference.\n",
    "all_dfs_mag[\"mag_spread_diffs\"] = all_dfs_mag[\"mag_diffs\"].sub(all_dfs_mag[\"spread_diffs\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2327ed7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean of n_nodes_sub / n_nodes for each requested pooling ratio.\n",
    "all_dfs_mag[\"n_nodes_sub\"].div(all_dfs_mag[\"n_nodes\"]).groupby(all_dfs_mag[\"ratio\"]).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64232371",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean n_nodes_sub for each pooling ratio.\n",
    "all_dfs_mag.groupby(\"ratio\")[\"n_nodes_sub\"].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa1480d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rename raw method identifiers to their display names in both frames.\n",
    "# Labels that are not keys of name_to (including labels already renamed by an earlier run\n",
    "# of this cell) are kept as they are, so re-running the cell no longer turns them into NaN.\n",
    "all_dfs_mag[\"method\"] = all_dfs_mag[\"method\"].map(lambda label: name_to.get(label, label))\n",
    "all_dfs[\"method\"] = all_dfs[\"method\"].map(lambda label: name_to.get(label, label))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "987fc088",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_model_names = [\n",
    "    \"MagEdgePool\", \"SpreadEdgePool\", \"NDP\", \"Graclus\", \"NMF\",\n",
    "    \"TopK\", \"SAGPool\", \"DiffPool\", \"MinCut\",\n",
    "]\n",
    "\n",
    "# Copy after filtering so the derived columns below are written to an independent frame.\n",
    "to_plot = all_dfs_mag[all_dfs_mag[\"method\"].isin(all_model_names)].copy()\n",
    "not_trained = [\"MagEdgePool\", \"SpreadEdgePool\", \"NDP\", \"Graclus\", \"NMF\"]\n",
    "to_plot[\"method\"] = pd.Categorical(to_plot[\"method\"], categories=all_model_names, ordered=True)\n",
    "\n",
    "to_plot[\"average_magnitude\"] = to_plot[\"magnitude\"] - to_plot[\"mag_diffs\"]\n",
    "to_plot[\"relative_mag_decrease\"] = 1 - to_plot[\"mag_diffs\"] / to_plot[\"magnitude\"]\n",
    "to_plot[\"relative_spread_decrease\"] = 1 - to_plot[\"spread_diffs\"] / to_plot[\"spread\"]\n",
    "\n",
    "to_plot[\"magnitude_minus_spread\"] = to_plot[\"magnitude\"] - to_plot[\"spread\"]\n",
    "# Fixed: this used to subtract magnitude_pooled from itself, which is always 0.\n",
    "to_plot[\"magnitude_minus_spread_pooled\"] = to_plot[\"magnitude_pooled\"] - to_plot[\"spread_pooled\"]\n",
    "\n",
    "# Columns plotted against the pooling ratio; nice_names[j] is the y-axis label of metric_columns[j].\n",
    "metric_columns = [\n",
    "    \"relative_mag_diffs\", \"relative_spread_diffs\",\n",
    "    \"spectral_distance_normalized\", \"relative_mag_decrease\",\n",
    "]\n",
    "nice_names = [\n",
    "    \"relative magnitude difference\\n with original graphs\",\n",
    "    \"relative spread difference\\n with original graphs\",\n",
    "    \"normalised spectral difference\\n with original graphs\",\n",
    "    \"relative difference in magnitude\\n compared to original graphs\",\n",
    "]\n",
    "\n",
    "# model_subsets, cutoff_ratio and subset_names come from the accuracy-plot cell above.\n",
    "for d in to_plot[\"dataset\"].unique():\n",
    "    for m in [\"GNN\"]:\n",
    "        for p, subsetm in enumerate(model_subsets):\n",
    "            selected = (\n",
    "                (to_plot[\"dataset\"] == d)\n",
    "                & to_plot[\"ratio\"].isin(ratios)\n",
    "                & (to_plot[\"ratio\"] <= cutoff_ratio[p])\n",
    "                & (to_plot[\"model\"] == m)\n",
    "                & to_plot[\"method\"].isin(subsetm)\n",
    "            )\n",
    "            to_plot1 = to_plot[selected]\n",
    "            for j, c in enumerate(metric_columns):\n",
    "                plt.figure(figsize=(3, 3))\n",
    "\n",
    "                sns.lineplot(data=to_plot1, x='ratio', y=c, errorbar='sd', palette='tab10', hue=\"method\", legend=False)\n",
    "\n",
    "                # observed=True: only (ratio, method) pairs that actually occur get a mean.\n",
    "                means = to_plot1.groupby([\"ratio\", \"method\"], observed=True)[c].mean().reset_index()\n",
    "                sns.scatterplot(data=means, x='ratio', y=c, palette='tab10', hue=\"method\", alpha=0.6)\n",
    "                plt.gca().invert_xaxis()\n",
    "                plt.title(f'{d} {m}')\n",
    "                plt.ylabel(nice_names[j])\n",
    "                plt.xlabel(\"pooling ratio\")\n",
    "                plt.legend(title=\"Method\", loc='upper left', bbox_to_anchor=(1, 1), frameon=False)\n",
    "                sns.despine()\n",
    "                plt.savefig(f\"../plots/structural_properties_{d}_{c}_{subset_names[p]}.pdf\", bbox_inches='tight', dpi=500)\n",
    "                plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09752182",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Magnitude against spread, per dataset, over every method and model in all_dfs_mag.\n",
    "to_plot = all_dfs_mag.copy()\n",
    "for d in to_plot[\"dataset\"].unique():\n",
    "    to_plot1 = to_plot[to_plot[\"dataset\"] == d]\n",
    "    to_plot1 = to_plot1[to_plot1[\"ratio\"].isin(ratios)]\n",
    "\n",
    "    print(to_plot1[[\"magnitude\",\"spread\"]].corr())\n",
    "\n",
    "    # Pooled graphs (displayed only, not saved).\n",
    "    plt.figure(figsize=(3, 3))\n",
    "    sns.scatterplot(data=to_plot1, x='magnitude_pooled', y=\"spread_pooled\", color=\"k\", legend=False, alpha=0.5, s=10)\n",
    "    plt.title(f'Magnitude vs. Spread {d}')\n",
    "    sns.despine()\n",
    "    plt.show()\n",
    "\n",
    "    # magnitude against spread columns.\n",
    "    plt.figure(figsize=(3, 3))\n",
    "    sns.scatterplot(data=to_plot1, x='magnitude', y=\"spread\", color=\"k\", legend=False, alpha=0.5, s=10)\n",
    "    plt.title(f'Magnitude vs. Spread {d}')\n",
    "    sns.despine()\n",
    "    plt.savefig(f\"../plots/magnitude_vs_spread_{d}.pdf\", bbox_inches='tight', dpi=100)\n",
    "    plt.savefig(f\"../plots/magnitude_vs_spread_{d}.png\", bbox_inches='tight', dpi=500)\n",
    "    # Show each figure right after saving so figures appear in order and do not stay open\n",
    "    # until the next iteration.\n",
    "    plt.show()\n",
    "\n",
    "    # Distribution of the magnitude / spread quotient.\n",
    "    hist = (to_plot1.magnitude/to_plot1.spread)\n",
    "    plt.figure(figsize=(4, 3))\n",
    "    sns.histplot(hist, color=\"k\", legend=False)\n",
    "    plt.title(f'Magnitude vs. Spread {d}')\n",
    "    plt.xlabel(\"Mag(G)/Sp(G)\")\n",
    "    plt.ylabel(\"number of graphs\")\n",
    "    sns.despine()\n",
    "    plt.savefig(f\"../plots/magnitude_vs_spread_hist_{d}.pdf\", bbox_inches='tight', dpi=100)\n",
    "    plt.savefig(f\"../plots/magnitude_vs_spread_hist_{d}.png\", bbox_inches='tight', dpi=500)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d63db01",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Histogram with a KDE overlay of magnitude / spread, using the to_plot1 and d left behind\n",
    "# by the loop in the previous cell (i.e. the last dataset it visited).\n",
    "hist = to_plot1[\"magnitude\"].div(to_plot1[\"spread\"])\n",
    "fig, ax = plt.subplots(figsize=(4, 3))\n",
    "sns.histplot(hist, color=\"k\", legend=False, kde=True, ax=ax)\n",
    "ax.set_title(f'Magnitude vs. Spread {d}')\n",
    "ax.set_xlabel(\"Mag(G)/Sp(G)\")\n",
    "ax.set_ylabel(\"number of graphs\")\n",
    "sns.despine()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd326dba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Correlation between the magnitude and spread columns of the current to_plot1 frame.\n",
    "to_plot1.loc[:, [\"magnitude\", \"spread\"]].corr()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03aceceb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Methods compared in the violin plots of the following cells.\n",
    "shown_methods = [\"MagEdgePool\", \"SpreadEdgePool\", \"NDP\", \"Graclus\", \"TopK\", \"SAGPool\"]\n",
    "\n",
    "# Copy after filtering so the column assignments below write to an independent frame.\n",
    "to_plot = all_dfs_mag[all_dfs_mag[\"method\"].isin(shown_methods)].copy()\n",
    "to_plot[\"method\"] = pd.Categorical(\n",
    "    to_plot[\"method\"],\n",
    "    categories=[\"MagEdgePool\", \"SpreadEdgePool\", \"NDP\", \"Graclus\", \"NMF\", \"TopK\", \"SAGPool\", \"DiffPool\", \"MinCut\"],\n",
    "    ordered=True,\n",
    ")\n",
    "\n",
    "to_plot[\"average_magnitude\"] = to_plot[\"magnitude\"] - to_plot[\"mag_diffs\"]\n",
    "to_plot[\"relative_mag_decrease\"] = (-to_plot[\"mag_diffs\"]) / to_plot[\"magnitude\"]\n",
    "\n",
    "# Fixed: this was the comparison `d == \"NCI1\"`, which had no effect and left d at whatever\n",
    "# value an earlier loop happened to finish with.\n",
    "d = \"NCI1\"\n",
    "to_plot1 = to_plot[to_plot[\"dataset\"] == d]\n",
    "to_plot1 = to_plot1[to_plot1[\"ratio\"].isin([0.5])]\n",
    "to_plot2 = to_plot1[to_plot1[\"model\"] == \"GNN\"].copy()\n",
    "to_plot2[\"relative_mag_difference\"] = to_plot2[\"mag_diffs\"] / to_plot2[\"magnitude\"]\n",
    "to_plot2[\"method\"] = pd.Categorical(to_plot2[\"method\"], categories=shown_methods, ordered=True)\n",
    "# Shorter labels for the x-axis ticks.\n",
    "to_plot2[\"method\"] = to_plot2[\"method\"].map({\n",
    "    \"MagEdgePool\": \"MagEdge\", \"SpreadEdgePool\": \"SpreadEdge\",\n",
    "    \"NDP\": \"NDP\", \"Graclus\": \"Graclus\", \"TopK\": \"TopK\", \"SAGPool\": \"SAGPool\",\n",
    "})\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd067037",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distribution of relative_mag_difference per pooling method in to_plot2.\n",
    "fig, ax = plt.subplots(figsize=(5, 3.5))\n",
    "sns.violinplot(data=to_plot2, x='method', y=\"relative_mag_difference\", hue=\"method\", legend=False, ax=ax)\n",
    "ax.set_title('Relative magnitude difference\\nat pooling ratio 0.5', fontsize=18)\n",
    "ax.set_xlabel(\"\")\n",
    "ax.set_ylabel(\"relative magnitude difference\\n between pooled and original graphs\", fontsize=10)\n",
    "plt.xticks(fontsize=8)\n",
    "sns.despine()\n",
    "fig.savefig(f\"../plots/relative_mag_diff_{d}_0.5.svg\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27fadb2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distribution of spectral_distance_normalized per pooling method in to_plot2.\n",
    "fig, ax = plt.subplots(figsize=(5, 3.5))\n",
    "sns.violinplot(data=to_plot2, x='method', y=\"spectral_distance_normalized\", hue=\"method\", legend=False, ax=ax)\n",
    "ax.set_title('Normalised spectral distance\\nat pooling ratio 0.5', fontsize=18)\n",
    "ax.set_ylabel(\"normalised spectral distance\\nbetween pooled and original graphs\", fontsize=10)\n",
    "ax.set_xlabel(\"\")\n",
    "plt.xticks(fontsize=8)\n",
    "sns.despine()\n",
    "fig.savefig(f\"../plots/normalised_spectral_{d}_0.5.svg\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90efecb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "shown_methods = [\"MagEdgePool\", \"SpreadEdgePool\", \"NDP\", \"Graclus\", \"TopK\", \"SAGPool\"]\n",
    "\n",
    "# Copy after filtering so the column assignments below write to an independent frame.\n",
    "to_plot = all_dfs_mag[all_dfs_mag[\"method\"].isin(shown_methods)].copy()\n",
    "to_plot[\"method\"] = pd.Categorical(\n",
    "    to_plot[\"method\"],\n",
    "    categories=[\"MagEdgePool\", \"SpreadEdgePool\", \"NDP\", \"Graclus\", \"NMF\", \"TopK\", \"SAGPool\", \"DiffPool\", \"MinCut\"],\n",
    "    ordered=True,\n",
    ")\n",
    "\n",
    "to_plot[\"average_magnitude\"] = to_plot[\"magnitude\"] - to_plot[\"mag_diffs\"]\n",
    "to_plot[\"relative_mag_decrease\"] = (-to_plot[\"mag_diffs\"]) / to_plot[\"magnitude\"]\n",
    "\n",
    "# (column, y-axis label, title of the summary figure, stem of the saved file name)\n",
    "ratio_metrics = [\n",
    "    (\"mag_diffs\", \"magnitude difference\", \"Magnitude Difference\", \"magnitude_vs_ratio\"),\n",
    "    (\"spectral_distance_normalized\", \"normalised spectral distance\", \"Normalised Spectral Distance\", \"spectral_vs_ratio\"),\n",
    "]\n",
    "\n",
    "# Axis limits are taken over all datasets so that figures are comparable.\n",
    "# Series.min/max/quantile skip NaN, unlike builtin min/max and np.quantile.\n",
    "ratio_xlim = (to_plot[\"ratio\"].min(), to_plot[\"ratio\"].max())\n",
    "metric_ylim = {\n",
    "    column: (to_plot[column].min(), to_plot[column].quantile(0.998))\n",
    "    for column, _, _, _ in ratio_metrics\n",
    "}\n",
    "\n",
    "for d in [\"NCI1\", \"ENZYMES\"]:\n",
    "    os.makedirs(f\"../plots/{d}/\", exist_ok=True)\n",
    "\n",
    "    to_plot1 = to_plot[(to_plot[\"dataset\"] == d) & to_plot[\"ratio\"].isin(ratios)]\n",
    "    to_plot2 = to_plot1[(to_plot1[\"method\"] == \"MagEdgePool\") & (to_plot1[\"model\"] == \"GNN\")]\n",
    "\n",
    "    # Pooled graphs: magnitude against spread (displayed only).\n",
    "    plt.figure(figsize=(3, 3))\n",
    "    sns.scatterplot(data=to_plot2, x='magnitude_pooled', y=\"spread_pooled\", color=\"k\", legend=False, alpha=0.5, s=10)\n",
    "    plt.title(f'Magnitude vs. Spread {d}')\n",
    "    sns.despine()\n",
    "    plt.show()\n",
    "\n",
    "    # magnitude against spread columns for the same rows.\n",
    "    plt.figure(figsize=(3, 3))\n",
    "    sns.scatterplot(data=to_plot2, x='magnitude', y=\"spread\", color=\"k\", legend=False, alpha=0.5, s=10)\n",
    "    plt.title(f'Magnitude vs. Spread {d}')\n",
    "    sns.despine()\n",
    "    plt.savefig(f\"../plots/{d}/magnitude_vs_spectral_{d}.pdf\", bbox_inches='tight')\n",
    "    plt.show()\n",
    "\n",
    "    # Magnitude difference against normalised spectral distance, all shown methods.\n",
    "    plt.figure(figsize=(4, 3))\n",
    "    sns.scatterplot(data=to_plot1, y='mag_diffs', x=\"spectral_distance_normalized\", hue=\"method\", legend=False, alpha=0.2, s=10)\n",
    "    plt.ylim(*metric_ylim[\"mag_diffs\"])\n",
    "    plt.xlim(to_plot[\"spectral_distance_normalized\"].min(), to_plot[\"spectral_distance_normalized\"].max())\n",
    "    plt.title(f'{d}')\n",
    "    sns.despine()\n",
    "    plt.show()\n",
    "\n",
    "    # Summary figures: mean of each metric per method as a function of the pooling ratio.\n",
    "    for column, ylabel, title, stem in ratio_metrics:\n",
    "        plt.figure(figsize=(4, 3))\n",
    "        sns.lineplot(data=to_plot1, x='ratio', y=column, errorbar=None, palette='tab10', hue=\"method\", legend=False)\n",
    "        means = to_plot1.groupby([\"ratio\", \"method\"], observed=True)[column].mean().reset_index()\n",
    "        sns.scatterplot(data=means, x='ratio', y=column, palette='tab10', hue=\"method\", alpha=0.6, legend=False)\n",
    "        plt.ylabel(ylabel)\n",
    "        plt.xlim(*ratio_xlim)\n",
    "        plt.gca().invert_xaxis()\n",
    "        plt.title(title)\n",
    "        plt.xlabel(\"pooling ratio\")\n",
    "        sns.despine()\n",
    "        plt.savefig(f\"../plots/{d}/{stem}_{d}.pdf\", bbox_inches='tight')\n",
    "        plt.show()\n",
    "\n",
    "    # Per-method figures: raw points plus mean line with a standard-deviation band.\n",
    "    for m in to_plot[\"method\"].unique():\n",
    "        method_df = to_plot1[to_plot1[\"method\"] == m]\n",
    "        for column, ylabel, _, stem in ratio_metrics:\n",
    "            plt.figure(figsize=(3, 3))\n",
    "            sns.scatterplot(data=method_df, y=column, x=\"ratio\", hue=\"method\", legend=False, alpha=0.2, s=10)\n",
    "            sns.lineplot(data=method_df, x='ratio', y=column, errorbar='sd', palette='tab10', hue=\"method\", legend=False)\n",
    "            plt.ylim(*metric_ylim[column])\n",
    "            plt.xlim(*ratio_xlim)\n",
    "            plt.gca().invert_xaxis()\n",
    "            plt.ylabel(ylabel)\n",
    "            plt.xlabel(\"pooling ratio\")\n",
    "            plt.title(f'{m}')\n",
    "            sns.despine()\n",
    "            plt.savefig(f\"../plots/{d}/{stem}_{d}_{m}.pdf\", bbox_inches='tight')\n",
    "            plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
