{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab9ee274",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2860baf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset names; the next cell loads the per-dataset result files for each one.\n",
    "datasets = [\n",
    "    \"MUTAG\",\n",
    "    \"COX2\",\n",
    "    \"ENZYMES\",\n",
    "    \"PROTEINS\",\n",
    "    \"Mutagenicity\",\n",
    "    \"AIDS\",\n",
    "    \"IMDB-BINARY\",\n",
    "    \"IMDB-MULTI\",\n",
    "    \"NCI1\",\n",
    "    \"NCI109\",\n",
    "    \"BZR\",\n",
    "    \"DHFR\",\n",
    "    \"ogbg-molhiv\",\n",
    "    \"BZR_MD\",\n",
    "    \"COX2_MD\",\n",
    "    \"DHFR_MD\",\n",
    "    \"ER_MD\",\n",
    "    \"PROTEINS_full\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00746bfe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gather the mean / std test metrics of every dataset into two wide tables\n",
    "# (`results`, `stds`) and the per-dataset accuracies into `acc_per_dataset`.\n",
    "results_dir = \"../src/graph_classification/results\"\n",
    "metric_cols = [\"loss\", \"accuracy\"]\n",
    "\n",
    "acc_per_dataset = {}\n",
    "# Start from None rather than keying on the loop index: if the first dataset\n",
    "# fails to load, the next one that loads must still initialise the tables.\n",
    "results = None\n",
    "stds = None\n",
    "for d in datasets:\n",
    "    prefix = f\"{results_dir}/{d}/{d}_GNN\"\n",
    "    try:\n",
    "        data = pd.read_csv(f\"{prefix}_means.csv\", index_col=0, header=0)\n",
    "        std = pd.read_csv(f\"{prefix}_stds.csv\", index_col=0, header=0)\n",
    "        with open(f\"{prefix}_accuracies.json\", \"r\") as f:\n",
    "            accuracies = json.load(f)\n",
    "\n",
    "        # Keep only the two metrics and tag the columns with the dataset name\n",
    "        # so the columns of different datasets can sit side by side.\n",
    "        column_names = [\"test_loss_\" + d, \"test_accuracy_\" + d]\n",
    "        data = data[metric_cols].copy()\n",
    "        std = std[metric_cols].copy()\n",
    "        data.columns = column_names\n",
    "        std.columns = column_names\n",
    "    except Exception as err:\n",
    "        # Best effort: report the dataset together with the reason and carry on.\n",
    "        # Unlike a bare `except:`, this lets KeyboardInterrupt through.\n",
    "        print(f\"Error with {d}: {err!r}\")\n",
    "        continue\n",
    "\n",
    "    # Only update the accumulators once everything for this dataset has loaded,\n",
    "    # so the three outputs always describe the same set of datasets.\n",
    "    acc_per_dataset[d] = accuracies\n",
    "    if results is None:\n",
    "        results, stds = data, std\n",
    "    else:\n",
    "        results = pd.merge(results, data, left_index=True, right_index=True, how=\"outer\")\n",
    "        stds = pd.merge(stds, std, left_index=True, right_index=True, how=\"outer\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b0f04ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "from magni.src.graph_classification.jsons_to_csvs import get_method_names, json_to_df\n",
    "\n",
    "# Datasets to include. Alternatives that were commented out here:\n",
    "# NCI1, PROTEINS, MUTAG, COX2, Mutagenicity, NCI109.\n",
    "datasets = [\"ENZYMES\"]\n",
    "# Pooling ratios to look up (0.125 and 0.95 were commented out).\n",
    "ratios = [0.0625, 0.2, 0.25, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]\n",
    "# NOTE(review): this resets the dict filled by the result-aggregation cell and\n",
    "# is not used below -- confirm it is still needed.\n",
    "acc_per_dataset = {}\n",
    "model = \"GNN\"\n",
    "path = \"../src/graph_classification/results/\"\n",
    "\n",
    "methods = get_method_names()\n",
    "\n",
    "\n",
    "def _load_times(json_path):\n",
    "    \"\"\"Return the (`time`, `time_per_run`) entries of one result file.\n",
    "\n",
    "    Both values are obtained through `json_to_df`; any exception it raises\n",
    "    propagates to the caller.\n",
    "    \"\"\"\n",
    "    timei = json_to_df(json_path, key=\"time\")\n",
    "    timet = json_to_df(json_path, key=\"time_per_run\")\n",
    "    return timei, timet\n",
    "\n",
    "\n",
    "records = []\n",
    "for dataset in datasets:\n",
    "    for method in methods:\n",
    "        for ratio in ratios:\n",
    "            stem = f\"{model}_{method}_{dataset}_stratified\"\n",
    "            candidates = [os.path.join(path, dataset, f\"{stem}_ratio_{round(ratio, 3)}.json\")]\n",
    "            if ratio == 0.5:\n",
    "                # For ratio 0.5 also try the file name without the ratio suffix.\n",
    "                candidates.append(os.path.join(path, dataset, f\"{stem}.json\"))\n",
    "\n",
    "            for json_path in candidates:\n",
    "                try:\n",
    "                    timei, timet = _load_times(json_path)\n",
    "                except Exception as err:\n",
    "                    # Name the file that actually failed (the old message printed\n",
    "                    # a stale variable from another cell) and try the next one.\n",
    "                    print(f\"Error with {json_path}: {err!r}\")\n",
    "                    continue\n",
    "                # As before, a row is recorded even when the values are None.\n",
    "                records.append((method, dataset, ratio, timet, timei))\n",
    "                if timei is not None:\n",
    "                    # Usable timing found; no need to look at the fallback file.\n",
    "                    break\n",
    "\n",
    "df = pd.DataFrame(records, columns=[\"method\", \"dataset\", \"ratio\", \"time_per_run\", \"time\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a36fccd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Discard every row of the timing table that has a missing value in any column.\n",
    "df = df.dropna(axis=0, how=\"any\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40efae04",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Raw method identifiers to select from `df`.\n",
    "methods = [\n",
    "    \"MAG_EDGE_diffusion_distance\", \"SPREAD_EDGE_diffusion_distance\", \"NMF\",\n",
    "    \"NDP\", \"TopK\", \"SAGPool\", \"Graclus\", \"DiffPool\", \"MinCut\",\n",
    "]\n",
    "\n",
    "# Raw method identifier -> label shown in the legend.\n",
    "name_to = {\n",
    "    \"MAG_EDGE_diffusion_distance\": \"MagEdgePool\",\n",
    "    \"SPREAD_EDGE_diffusion_distance\": \"SpreadEdgePool\",\n",
    "    \"Flat\": \"No Pooling\",\n",
    "    \"TopK\": \"TopK\",\n",
    "    \"SAGPool\": \"SAGPool\",\n",
    "    \"Graclus\": \"Graclus\",\n",
    "    \"NMF\": \"NMF\",\n",
    "    \"NDP\": \"NDP\",\n",
    "    \"DiffPool\": \"DiffPool\",\n",
    "    \"MinCut\": \"MinCut\",\n",
    "}\n",
    "\n",
    "# Legend order. Labels missing from this list (DiffPool, MinCut) turn into NaN\n",
    "# in the categorical below -- presumably intended to leave them out; confirm.\n",
    "sort_methods = [\n",
    "    \"MagEdgePool\", \"SpreadEdgePool\",\n",
    "    \"NDP\", \"Graclus\",\n",
    "    \"NMF\", \"TopK\", \"SAGPool\",\n",
    "]\n",
    "\n",
    "# `.assign` builds a new frame, so the filtered slice of `df` is never written\n",
    "# to (the previous in-place column assignment raised SettingWithCopyWarning).\n",
    "df_to_plot = df[df[\"method\"].isin(methods)].assign(\n",
    "    method=lambda frame: pd.Categorical(\n",
    "        frame[\"method\"].map(name_to), categories=sort_methods, ordered=True\n",
    "    )\n",
    ")\n",
    "\n",
    "sns.lineplot(data=df_to_plot, x=\"ratio\", y=\"time_per_run\", hue=\"method\", style=\"dataset\", markers=True, dashes=False)\n",
    "plt.gca().invert_xaxis()\n",
    "plt.xlabel(\"pooling ratio\")\n",
    "plt.ylabel(\"time in seconds\")\n",
    "plt.legend(title=\"Method\", bbox_to_anchor=(1.05, 1), loc='upper left', frameon=False)\n",
    "sns.despine()\n",
    "\n",
    "# Name the file after the datasets that are actually plotted instead of a loop\n",
    "# variable left over from another cell.\n",
    "plotted_datasets = \"_\".join(df_to_plot[\"dataset\"].unique())\n",
    "os.makedirs(\"../plots\", exist_ok=True)\n",
    "plt.savefig(f\"../plots/{model}_{plotted_datasets}_pooling_time_per_run.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38295047",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Raw method identifiers to select from `df`.\n",
    "methods = [\n",
    "    \"MAG_EDGE_diffusion_distance\", \"SPREAD_EDGE_diffusion_distance\", \"NMF\",\n",
    "    \"NDP\", \"TopK\", \"SAGPool\", \"Graclus\", \"DiffPool\", \"MinCut\",\n",
    "]\n",
    "\n",
    "# Raw method identifier -> label shown in the legend.\n",
    "name_to = {\n",
    "    \"MAG_EDGE_diffusion_distance\": \"MagEdgePool\",\n",
    "    \"SPREAD_EDGE_diffusion_distance\": \"SpreadEdgePool\",\n",
    "    \"Flat\": \"No Pooling\",\n",
    "    \"TopK\": \"TopK\",\n",
    "    \"SAGPool\": \"SAGPool\",\n",
    "    \"Graclus\": \"Graclus\",\n",
    "    \"NMF\": \"NMF\",\n",
    "    \"NDP\": \"NDP\",\n",
    "    \"DiffPool\": \"DiffPool\",\n",
    "    \"MinCut\": \"MinCut\",\n",
    "}\n",
    "\n",
    "# Legend order. Labels missing from this list (DiffPool, MinCut) turn into NaN\n",
    "# in the categorical below -- presumably intended to leave them out; confirm.\n",
    "sort_methods = [\n",
    "    \"MagEdgePool\", \"SpreadEdgePool\",\n",
    "    \"NDP\", \"Graclus\",\n",
    "    \"NMF\", \"TopK\", \"SAGPool\",\n",
    "]\n",
    "\n",
    "# `.assign` builds a new frame, so the filtered slice of `df` is never written\n",
    "# to (the previous in-place column assignment raised SettingWithCopyWarning).\n",
    "df_to_plot = df[df[\"method\"].isin(methods)].assign(\n",
    "    method=lambda frame: pd.Categorical(\n",
    "        frame[\"method\"].map(name_to), categories=sort_methods, ordered=True\n",
    "    )\n",
    ")\n",
    "\n",
    "sns.lineplot(data=df_to_plot, x=\"ratio\", y=\"time\", hue=\"method\", style=\"dataset\", markers=True, dashes=False)\n",
    "plt.gca().invert_xaxis()\n",
    "plt.xlabel(\"pooling ratio\")\n",
    "plt.ylabel(\"time in seconds\")\n",
    "plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', frameon=False)\n",
    "sns.despine()\n",
    "\n",
    "# Name the file after the datasets that are actually plotted rather than the\n",
    "# loop variable leaked from the loading cell (which only names the last one).\n",
    "plotted_datasets = \"_\".join(df_to_plot[\"dataset\"].unique())\n",
    "plt.savefig(f\"pooling_time_{plotted_datasets}_{model}.svg\", bbox_inches='tight')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
