{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.ticker as mtick\n",
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-run evaluation results; the CSV's first column becomes the row index,\n",
    "# so rows can be selected by label with .loc.\n",
    "TOPOLOGY_RESULTS_CSV = \"results/results_topology.csv\"\n",
    "topology_results = pd.read_csv(TOPOLOGY_RESULTS_CSV, index_col=0)\n",
    "\n",
    "def extract_results(architecture, edge_orientation, adjacency_type, metric, n_folds=6):\n",
    "    \"\"\"Return the rows of `topology_results` for one configuration, one row per fold.\n",
    "\n",
    "    Rows are selected by the labels\n",
    "    \"{architecture}_{edge_orientation}_{adjacency_type}_{fold}_{metric}\"\n",
    "    for fold = 0, ..., n_folds - 1, and are returned in fold order.\n",
    "    `n_folds` defaults to 6 folds. pandas `.loc` raises KeyError if any\n",
    "    requested label is missing from the index.\n",
    "    \"\"\"\n",
    "    labels = [\n",
    "        f\"{architecture}_{edge_orientation}_{adjacency_type}_{fold}_{metric}\"\n",
    "        for fold in range(n_folds)\n",
    "    ]\n",
    "    return topology_results.loc[labels]\n",
    "\n",
    "def print_table(architecture,\n",
    "                edge_orientations=(\"downstream\", \"upstream\", \"bidirectional\"),\n",
    "                adjacency_types=(\"isolated\", \"binary\", \"stream_length\", \"elevation_difference\",\n",
    "                                 \"average_slope\", \"learned\")):\n",
    "    \"\"\"Display an MSE / NSE summary table for one architecture.\n",
    "\n",
    "    Rows are adjacency types; each edge orientation contributes an MSE column and\n",
    "    an NSE column (NSE shown in percent). For every configuration, each fold's\n",
    "    score is the mean over the columns of its `topology_results` row; the cell\n",
    "    shows the mean ± standard deviation (pandas default, ddof=1) of those fold\n",
    "    scores, formatted to two decimals.\n",
    "\n",
    "    The table is rendered with `display`; the function returns None.\n",
    "    \"\"\"\n",
    "    def format_mean_std(fold_scores, scale=1, suffix=\"\"):\n",
    "        # Scale the aggregated mean/std (not the raw scores) before formatting.\n",
    "        return (f\"{scale * fold_scores.mean():.2f}{suffix}\"\n",
    "                f\" ± {scale * fold_scores.std():.2f}{suffix}\")\n",
    "\n",
    "    rows = []\n",
    "    for adjacency_type in adjacency_types:\n",
    "        row = []\n",
    "        for edge_orientation in edge_orientations:\n",
    "            # Look up each metric once and reuse it for both the mean and the std.\n",
    "            mse = extract_results(architecture, edge_orientation, adjacency_type, \"MSE\").mean(1)\n",
    "            nse = extract_results(architecture, edge_orientation, adjacency_type, \"NSE\").mean(1)\n",
    "            row += [format_mean_std(mse), format_mean_std(nse, scale=100, suffix=\"%\")]\n",
    "        rows.append(row)\n",
    "\n",
    "    # Derive labels from the same sequences that built the rows so they cannot drift apart.\n",
    "    columns = [f\"{edge_orientation} ({metric})\"\n",
    "               for edge_orientation in edge_orientations for metric in (\"MSE\", \"NSE\")]\n",
    "    display(pd.DataFrame(rows, columns=columns, index=list(adjacency_types)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>downstream (MSE)</th>\n",
       "      <th>downstream (NSE)</th>\n",
       "      <th>upstream (MSE)</th>\n",
       "      <th>upstream (NSE)</th>\n",
       "      <th>bidirectional (MSE)</th>\n",
       "      <th>bidirectional (NSE)</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>isolated</th>\n",
       "      <td>899.80 ± 1329.17</td>\n",
       "      <td>80.85% ± 11.66%</td>\n",
       "      <td>899.80 ± 1329.17</td>\n",
       "      <td>80.85% ± 11.66%</td>\n",
       "      <td>899.80 ± 1329.17</td>\n",
       "      <td>80.85% ± 11.66%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>binary</th>\n",
       "      <td>353.54 ± 80.90</td>\n",
       "      <td>83.53% ± 5.63%</td>\n",
       "      <td>372.67 ± 61.11</td>\n",
       "      <td>84.99% ± 5.10%</td>\n",
       "      <td>741.20 ± 166.26</td>\n",
       "      <td>85.34% ± 4.86%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>stream_length</th>\n",
       "      <td>524.03 ± 100.46</td>\n",
       "      <td>83.42% ± 5.59%</td>\n",
       "      <td>435.66 ± 60.49</td>\n",
       "      <td>84.74% ± 5.02%</td>\n",
       "      <td>785.38 ± 171.49</td>\n",
       "      <td>85.31% ± 4.92%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>elevation_difference</th>\n",
       "      <td>407.67 ± 95.16</td>\n",
       "      <td>83.46% ± 5.60%</td>\n",
       "      <td>456.32 ± 63.80</td>\n",
       "      <td>83.76% ± 4.80%</td>\n",
       "      <td>773.95 ± 182.22</td>\n",
       "      <td>85.16% ± 4.93%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>average_slope</th>\n",
       "      <td>327.22 ± 75.81</td>\n",
       "      <td>83.45% ± 5.60%</td>\n",
       "      <td>425.95 ± 86.43</td>\n",
       "      <td>84.10% ± 5.18%</td>\n",
       "      <td>656.52 ± 170.12</td>\n",
       "      <td>85.23% ± 4.92%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>learned</th>\n",
       "      <td>345.57 ± 199.76</td>\n",
       "      <td>83.50% ± 5.40%</td>\n",
       "      <td>366.94 ± 80.72</td>\n",
       "      <td>85.63% ± 4.65%</td>\n",
       "      <td>567.39 ± 160.84</td>\n",
       "      <td>85.94% ± 4.52%</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                      downstream (MSE) downstream (NSE)    upstream (MSE)  \\\n",
       "isolated              899.80 ± 1329.17  80.85% ± 11.66%  899.80 ± 1329.17   \n",
       "binary                  353.54 ± 80.90   83.53% ± 5.63%    372.67 ± 61.11   \n",
       "stream_length          524.03 ± 100.46   83.42% ± 5.59%    435.66 ± 60.49   \n",
       "elevation_difference    407.67 ± 95.16   83.46% ± 5.60%    456.32 ± 63.80   \n",
       "average_slope           327.22 ± 75.81   83.45% ± 5.60%    425.95 ± 86.43   \n",
       "learned                345.57 ± 199.76   83.50% ± 5.40%    366.94 ± 80.72   \n",
       "\n",
       "                       upstream (NSE) bidirectional (MSE) bidirectional (NSE)  \n",
       "isolated              80.85% ± 11.66%    899.80 ± 1329.17     80.85% ± 11.66%  \n",
       "binary                 84.99% ± 5.10%     741.20 ± 166.26      85.34% ± 4.86%  \n",
       "stream_length          84.74% ± 5.02%     785.38 ± 171.49      85.31% ± 4.92%  \n",
       "elevation_difference   83.76% ± 4.80%     773.95 ± 182.22      85.16% ± 4.93%  \n",
       "average_slope          84.10% ± 5.18%     656.52 ± 170.12      85.23% ± 4.92%  \n",
       "learned                85.63% ± 4.65%     567.39 ± 160.84      85.94% ± 4.52%  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Topology ablation summary for ResGCN (MSE and NSE, mean ± std across folds).\n",
    "print_table(architecture=\"ResGCN\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>downstream (MSE)</th>\n",
       "      <th>downstream (NSE)</th>\n",
       "      <th>upstream (MSE)</th>\n",
       "      <th>upstream (NSE)</th>\n",
       "      <th>bidirectional (MSE)</th>\n",
       "      <th>bidirectional (NSE)</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>isolated</th>\n",
       "      <td>289.71 ± 50.01</td>\n",
       "      <td>85.95% ± 4.97%</td>\n",
       "      <td>289.71 ± 50.01</td>\n",
       "      <td>85.95% ± 4.97%</td>\n",
       "      <td>289.71 ± 50.01</td>\n",
       "      <td>85.95% ± 4.97%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>binary</th>\n",
       "      <td>277.50 ± 33.57</td>\n",
       "      <td>86.17% ± 4.69%</td>\n",
       "      <td>312.31 ± 43.98</td>\n",
       "      <td>85.75% ± 5.03%</td>\n",
       "      <td>355.95 ± 65.61</td>\n",
       "      <td>86.44% ± 4.64%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>stream_length</th>\n",
       "      <td>343.86 ± 29.33</td>\n",
       "      <td>86.17% ± 4.66%</td>\n",
       "      <td>311.32 ± 43.91</td>\n",
       "      <td>85.72% ± 5.01%</td>\n",
       "      <td>393.39 ± 81.15</td>\n",
       "      <td>86.37% ± 4.67%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>elevation_difference</th>\n",
       "      <td>302.76 ± 48.07</td>\n",
       "      <td>86.11% ± 4.69%</td>\n",
       "      <td>314.72 ± 42.75</td>\n",
       "      <td>85.35% ± 5.28%</td>\n",
       "      <td>411.96 ± 80.55</td>\n",
       "      <td>86.33% ± 4.71%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>average_slope</th>\n",
       "      <td>276.88 ± 40.39</td>\n",
       "      <td>86.08% ± 4.67%</td>\n",
       "      <td>279.22 ± 41.44</td>\n",
       "      <td>85.44% ± 5.32%</td>\n",
       "      <td>364.96 ± 79.10</td>\n",
       "      <td>86.26% ± 4.79%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>learned</th>\n",
       "      <td>169.93 ± 33.40</td>\n",
       "      <td>86.14% ± 4.87%</td>\n",
       "      <td>280.07 ± 46.97</td>\n",
       "      <td>86.03% ± 4.80%</td>\n",
       "      <td>323.54 ± 83.12</td>\n",
       "      <td>86.48% ± 4.69%</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                     downstream (MSE) downstream (NSE)  upstream (MSE)  \\\n",
       "isolated               289.71 ± 50.01   85.95% ± 4.97%  289.71 ± 50.01   \n",
       "binary                 277.50 ± 33.57   86.17% ± 4.69%  312.31 ± 43.98   \n",
       "stream_length          343.86 ± 29.33   86.17% ± 4.66%  311.32 ± 43.91   \n",
       "elevation_difference   302.76 ± 48.07   86.11% ± 4.69%  314.72 ± 42.75   \n",
       "average_slope          276.88 ± 40.39   86.08% ± 4.67%  279.22 ± 41.44   \n",
       "learned                169.93 ± 33.40   86.14% ± 4.87%  280.07 ± 46.97   \n",
       "\n",
       "                      upstream (NSE) bidirectional (MSE) bidirectional (NSE)  \n",
       "isolated              85.95% ± 4.97%      289.71 ± 50.01      85.95% ± 4.97%  \n",
       "binary                85.75% ± 5.03%      355.95 ± 65.61      86.44% ± 4.64%  \n",
       "stream_length         85.72% ± 5.01%      393.39 ± 81.15      86.37% ± 4.67%  \n",
       "elevation_difference  85.35% ± 5.28%      411.96 ± 80.55      86.33% ± 4.71%  \n",
       "average_slope         85.44% ± 5.32%      364.96 ± 79.10      86.26% ± 4.79%  \n",
       "learned               86.03% ± 4.80%      323.54 ± 83.12      86.48% ± 4.69%  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Topology ablation summary for GCNII (MSE and NSE, mean ± std across folds).\n",
    "print_table(architecture=\"GCNII\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>downstream (MSE)</th>\n",
       "      <th>downstream (NSE)</th>\n",
       "      <th>upstream (MSE)</th>\n",
       "      <th>upstream (NSE)</th>\n",
       "      <th>bidirectional (MSE)</th>\n",
       "      <th>bidirectional (NSE)</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>isolated</th>\n",
       "      <td>354.45 ± 71.39</td>\n",
       "      <td>85.56% ± 4.93%</td>\n",
       "      <td>354.45 ± 71.39</td>\n",
       "      <td>85.56% ± 4.93%</td>\n",
       "      <td>354.45 ± 71.39</td>\n",
       "      <td>85.56% ± 4.93%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>binary</th>\n",
       "      <td>4871.21 ± 3464.82</td>\n",
       "      <td>28.79% ± 21.49%</td>\n",
       "      <td>5444.81 ± 1363.50</td>\n",
       "      <td>33.98% ± 17.36%</td>\n",
       "      <td>4715.03 ± 1359.65</td>\n",
       "      <td>69.80% ± 8.08%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>stream_length</th>\n",
       "      <td>3184.44 ± 752.60</td>\n",
       "      <td>33.87% ± 18.94%</td>\n",
       "      <td>5041.33 ± 1397.70</td>\n",
       "      <td>47.14% ± 17.37%</td>\n",
       "      <td>3778.37 ± 575.68</td>\n",
       "      <td>76.74% ± 5.75%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>elevation_difference</th>\n",
       "      <td>5316.44 ± 3411.55</td>\n",
       "      <td>28.74% ± 21.42%</td>\n",
       "      <td>5577.10 ± 1259.32</td>\n",
       "      <td>35.71% ± 18.25%</td>\n",
       "      <td>3132.41 ± 799.16</td>\n",
       "      <td>78.17% ± 5.87%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>average_slope</th>\n",
       "      <td>9436.05 ± 4272.67</td>\n",
       "      <td>10.32% ± 30.85%</td>\n",
       "      <td>5060.77 ± 1535.16</td>\n",
       "      <td>34.76% ± 19.29%</td>\n",
       "      <td>4257.56 ± 1619.34</td>\n",
       "      <td>72.59% ± 8.65%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>learned</th>\n",
       "      <td>1067.59 ± 325.14</td>\n",
       "      <td>34.99% ± 18.23%</td>\n",
       "      <td>4750.12 ± 1299.36</td>\n",
       "      <td>37.87% ± 19.49%</td>\n",
       "      <td>1868.82 ± 533.35</td>\n",
       "      <td>75.48% ± 7.22%</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                       downstream (MSE) downstream (NSE)     upstream (MSE)  \\\n",
       "isolated                 354.45 ± 71.39   85.56% ± 4.93%     354.45 ± 71.39   \n",
       "binary                4871.21 ± 3464.82  28.79% ± 21.49%  5444.81 ± 1363.50   \n",
       "stream_length          3184.44 ± 752.60  33.87% ± 18.94%  5041.33 ± 1397.70   \n",
       "elevation_difference  5316.44 ± 3411.55  28.74% ± 21.42%  5577.10 ± 1259.32   \n",
       "average_slope         9436.05 ± 4272.67  10.32% ± 30.85%  5060.77 ± 1535.16   \n",
       "learned                1067.59 ± 325.14  34.99% ± 18.23%  4750.12 ± 1299.36   \n",
       "\n",
       "                       upstream (NSE) bidirectional (MSE) bidirectional (NSE)  \n",
       "isolated               85.56% ± 4.93%      354.45 ± 71.39      85.56% ± 4.93%  \n",
       "binary                33.98% ± 17.36%   4715.03 ± 1359.65      69.80% ± 8.08%  \n",
       "stream_length         47.14% ± 17.37%    3778.37 ± 575.68      76.74% ± 5.75%  \n",
       "elevation_difference  35.71% ± 18.25%    3132.41 ± 799.16      78.17% ± 5.87%  \n",
       "average_slope         34.76% ± 19.29%   4257.56 ± 1619.34      72.59% ± 8.65%  \n",
       "learned               37.87% ± 19.49%    1868.82 ± 533.35      75.48% ± 7.22%  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Topology ablation summary for GCN (MSE and NSE, mean ± std across folds).\n",
    "print_table(architecture=\"GCN\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
