{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "294e0cc8-3171-48fa-ab74-1336baa2f4b3",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-04-23T14:05:13.319024Z",
     "iopub.status.busy": "2023-04-23T14:05:13.318787Z",
     "iopub.status.idle": "2023-04-23T14:05:13.339560Z",
     "shell.execute_reply": "2023-04-23T14:05:13.339038Z",
     "shell.execute_reply.started": "2023-04-23T14:05:13.319001Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "import functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b55a3015",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset and checkpoint locations. Set the environment variables to run elsewhere\n",
    "# without editing the notebook; otherwise the paths below are used.\n",
    "DATASET_PATH = os.environ.get(\"LAMAH_CE_DATASET_PATH\", \"/scratch/kirschstein/LamaH-CE\")\n",
    "CHECKPOINT_PATH = os.environ.get(\"TOPOLOGY_CHECKPOINT_PATH\", \"./checkpoints/topology\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1db0a41d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Static per-edge attributes from the processed adjacency table. The next cell stacks them\n",
    "# with the learned edge weights; presumably row i corresponds to edge i of the trained\n",
    "# models -- TODO confirm.\n",
    "adj_df = pd.read_csv(f\"{DATASET_PATH}/processed/adjacency.csv\")\n",
    "weight_cols = adj_df[[\"dist_hdn\", \"elev_diff\", \"strm_slope\"]].to_numpy()\n",
    "stream_length, elevation_difference, average_slope = (\n",
    "    torch.tensor(column, dtype=torch.float) for column in weight_cols.T\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8d2bf516-fcda-484f-bd45-ae60c3e6614a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-04-23T14:11:53.469015Z",
     "iopub.status.busy": "2023-04-23T14:11:53.468745Z",
     "iopub.status.idle": "2023-04-23T14:11:53.492694Z",
     "shell.execute_reply": "2023-04-23T14:11:53.492020Z",
     "shell.execute_reply.started": "2023-04-23T14:11:53.468993Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ResGCN downstream\n",
      "correlation matrix mean:\n",
      "tensor([[ 1.0000, -0.2213,  0.1000,  0.1683],\n",
      "        [-0.2213,  1.0000,  0.3331, -0.1741],\n",
      "        [ 0.1000,  0.3331,  1.0000,  0.5844],\n",
      "        [ 0.1683, -0.1741,  0.5844,  1.0000]])\n",
      "correlation matrix std:\n",
      "tensor([[6.1559e-08, 9.8565e-02, 2.1019e-02, 3.8168e-02],\n",
      "        [9.8565e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n",
      "        [2.1019e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n",
      "        [3.8168e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00]])\n",
      "\n",
      "ResGCN upstream\n",
      "correlation matrix mean:\n",
      "tensor([[ 1.0000,  0.0422, -0.3080, -0.2930],\n",
      "        [ 0.0422,  1.0000,  0.3331, -0.1741],\n",
      "        [-0.3080,  0.3331,  1.0000,  0.5844],\n",
      "        [-0.2930, -0.1741,  0.5844,  1.0000]])\n",
      "correlation matrix std:\n",
      "tensor([[0.0000, 0.0075, 0.0145, 0.0092],\n",
      "        [0.0075, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0145, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0092, 0.0000, 0.0000, 0.0000]])\n",
      "\n",
      "ResGCN bidirectional\n",
      "correlation matrix mean:\n",
      "tensor([[ 1.0000, -0.0023, -0.2353, -0.2400],\n",
      "        [-0.0023,  1.0000,  0.3331, -0.1741],\n",
      "        [-0.2353,  0.3331,  1.0000,  0.5844],\n",
      "        [-0.2400, -0.1741,  0.5844,  1.0000]])\n",
      "correlation matrix std:\n",
      "tensor([[6.1559e-08, 1.5621e-02, 1.4013e-02, 1.1748e-02],\n",
      "        [1.5621e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n",
      "        [1.4013e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n",
      "        [1.1748e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00]])\n",
      "\n",
      "GCNII downstream\n",
      "correlation matrix mean:\n",
      "tensor([[ 1.0000, -0.2337, -0.1703, -0.0367],\n",
      "        [-0.2337,  1.0000,  0.3331, -0.1741],\n",
      "        [-0.1703,  0.3331,  1.0000,  0.5844],\n",
      "        [-0.0367, -0.1741,  0.5844,  1.0000]])\n",
      "correlation matrix std:\n",
      "tensor([[4.8667e-08, 1.1841e-02, 3.2512e-03, 7.4772e-03],\n",
      "        [1.1841e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n",
      "        [3.2512e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n",
      "        [7.4772e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00]])\n",
      "\n",
      "GCNII upstream\n",
      "correlation matrix mean:\n",
      "tensor([[ 1.0000, -0.1440,  0.0266,  0.0901],\n",
      "        [-0.1440,  1.0000,  0.3331, -0.1741],\n",
      "        [ 0.0266,  0.3331,  1.0000,  0.5844],\n",
      "        [ 0.0901, -0.1741,  0.5844,  1.0000]])\n",
      "correlation matrix std:\n",
      "tensor([[2.4333e-08, 5.7313e-03, 7.4662e-03, 9.3634e-03],\n",
      "        [5.7313e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n",
      "        [7.4662e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n",
      "        [9.3634e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00]])\n",
      "\n",
      "GCNII bidirectional\n",
      "correlation matrix mean:\n",
      "tensor([[ 1.0000,  0.0541, -0.1031, -0.1626],\n",
      "        [ 0.0541,  1.0000,  0.3331, -0.1741],\n",
      "        [-0.1031,  0.3331,  1.0000,  0.5844],\n",
      "        [-0.1626, -0.1741,  0.5844,  1.0000]])\n",
      "correlation matrix std:\n",
      "tensor([[4.8667e-08, 3.4349e-02, 3.4657e-02, 1.2373e-02],\n",
      "        [3.4349e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n",
      "        [3.4657e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n",
      "        [1.2373e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00]])\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Pearson correlation (torch.corrcoef) between the learned edge weights and the static edge\n",
    "# attributes, summarised by mean and std over the 6 fold checkpoints of each configuration.\n",
    "# Row/column order of each matrix: learned weight, stream length, elevation difference, average slope.\n",
    "for architecture in [\"ResGCN\", \"GCNII\"]:\n",
    "    for edge_orientation in [\"downstream\", \"upstream\", \"bidirectional\"]:\n",
    "        print(architecture, edge_orientation)\n",
    "        corrmats = []\n",
    "        for fold in range(6):\n",
    "            # map_location=\"cpu\" lets checkpoints saved from GPU tensors load on CPU-only machines\n",
    "            chkpt = torch.load(f\"{CHECKPOINT_PATH}/{architecture}_{edge_orientation}_learned_{fold}.run\", map_location=\"cpu\")\n",
    "            # NOTE(review): the statistics cell clamps these weights at 0, this cell does not -- confirm which is intended\n",
    "            learned_weights = chkpt[\"history\"][\"best_model_params\"][\"edge_weights\"].nan_to_num().cpu()\n",
    "            corrmats.append(torch.corrcoef(torch.stack([learned_weights, stream_length, elevation_difference, average_slope])))\n",
    "        corr_stack = torch.stack(corrmats)  # (n_folds, 4, 4)\n",
    "        print(\"correlation matrix mean:\")\n",
    "        print(corr_stack.mean(dim=0))\n",
    "        print(\"correlation matrix std:\")\n",
    "        print(corr_stack.std(dim=0))\n",
    "        print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "74a3dd0b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ResGCN downstream\n",
      "GCNII downstream\n",
      "ResGCN upstream\n",
      "GCNII upstream\n",
      "ResGCN bidirectional\n",
      "GCNII bidirectional\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>downstream_ResGCN</th>\n",
       "      <th>downstream_GCNII</th>\n",
       "      <th>upstream_ResGCN</th>\n",
       "      <th>upstream_GCNII</th>\n",
       "      <th>bidirectional_ResGCN</th>\n",
       "      <th>bidirectional_GCNII</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>0.989 ± 0.013</td>\n",
       "      <td>0.768 ± 0.002</td>\n",
       "      <td>0.666 ± 0.011</td>\n",
       "      <td>0.793 ± 0.008</td>\n",
       "      <td>0.917 ± 0.006</td>\n",
       "      <td>0.955 ± 0.008</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>0.511 ± 0.212</td>\n",
       "      <td>0.665 ± 0.025</td>\n",
       "      <td>0.537 ± 0.006</td>\n",
       "      <td>0.825 ± 0.022</td>\n",
       "      <td>0.635 ± 0.036</td>\n",
       "      <td>0.630 ± 0.026</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>min</th>\n",
       "      <td>0.109 ± 0.268</td>\n",
       "      <td>0.000 ± 0.000</td>\n",
       "      <td>0.000 ± 0.000</td>\n",
       "      <td>0.000 ± 0.000</td>\n",
       "      <td>0.000 ± 0.000</td>\n",
       "      <td>0.000 ± 0.000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25%</th>\n",
       "      <td>0.624 ± 0.160</td>\n",
       "      <td>0.279 ± 0.028</td>\n",
       "      <td>0.201 ± 0.022</td>\n",
       "      <td>0.227 ± 0.022</td>\n",
       "      <td>0.451 ± 0.032</td>\n",
       "      <td>0.473 ± 0.021</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>median</th>\n",
       "      <td>1.042 ± 0.019</td>\n",
       "      <td>0.599 ± 0.021</td>\n",
       "      <td>0.588 ± 0.026</td>\n",
       "      <td>0.570 ± 0.015</td>\n",
       "      <td>0.851 ± 0.024</td>\n",
       "      <td>0.919 ± 0.017</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75%</th>\n",
       "      <td>1.365 ± 0.151</td>\n",
       "      <td>1.172 ± 0.031</td>\n",
       "      <td>1.049 ± 0.027</td>\n",
       "      <td>1.134 ± 0.037</td>\n",
       "      <td>1.298 ± 0.036</td>\n",
       "      <td>1.306 ± 0.027</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>max</th>\n",
       "      <td>3.257 ± 0.983</td>\n",
       "      <td>5.463 ± 0.895</td>\n",
       "      <td>2.217 ± 0.052</td>\n",
       "      <td>6.772 ± 0.489</td>\n",
       "      <td>3.197 ± 0.256</td>\n",
       "      <td>3.515 ± 0.286</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       downstream_ResGCN downstream_GCNII upstream_ResGCN upstream_GCNII  \\\n",
       "mean       0.989 ± 0.013    0.768 ± 0.002   0.666 ± 0.011  0.793 ± 0.008   \n",
       "std        0.511 ± 0.212    0.665 ± 0.025   0.537 ± 0.006  0.825 ± 0.022   \n",
       "min        0.109 ± 0.268    0.000 ± 0.000   0.000 ± 0.000  0.000 ± 0.000   \n",
       "25%        0.624 ± 0.160    0.279 ± 0.028   0.201 ± 0.022  0.227 ± 0.022   \n",
       "median     1.042 ± 0.019    0.599 ± 0.021   0.588 ± 0.026  0.570 ± 0.015   \n",
       "75%        1.365 ± 0.151    1.172 ± 0.031   1.049 ± 0.027  1.134 ± 0.037   \n",
       "max        3.257 ± 0.983    5.463 ± 0.895   2.217 ± 0.052  6.772 ± 0.489   \n",
       "\n",
       "       bidirectional_ResGCN bidirectional_GCNII  \n",
       "mean          0.917 ± 0.006       0.955 ± 0.008  \n",
       "std           0.635 ± 0.036       0.630 ± 0.026  \n",
       "min           0.000 ± 0.000       0.000 ± 0.000  \n",
       "25%           0.451 ± 0.032       0.473 ± 0.021  \n",
       "median        0.851 ± 0.024       0.919 ± 0.017  \n",
       "75%           1.298 ± 0.036       1.306 ± 0.027  \n",
       "max           3.197 ± 0.256       3.515 ± 0.286  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Distribution of the learned edge weights (clamped at 0) per configuration; each entry is\n",
    "# the mean ± std of that statistic across the 6 fold checkpoints.\n",
    "descriptors = [\"mean\", \"std\", \"min\", \"25%\", \"median\", \"75%\", \"max\"]\n",
    "weight_stats = {}\n",
    "for edge_orientation in [\"downstream\", \"upstream\", \"bidirectional\"]:\n",
    "    for architecture in [\"ResGCN\", \"GCNII\"]:\n",
    "        print(architecture, edge_orientation)\n",
    "        fold_rows = []\n",
    "        for fold in range(6):\n",
    "            # map_location=\"cpu\" lets checkpoints saved from GPU tensors load on CPU-only machines\n",
    "            chkpt = torch.load(f\"{CHECKPOINT_PATH}/{architecture}_{edge_orientation}_learned_{fold}.run\", map_location=\"cpu\")\n",
    "            learned_weights = chkpt[\"history\"][\"best_model_params\"][\"edge_weights\"].nan_to_num().cpu().clamp(min=0)\n",
    "            # quantile(0.5) interpolates like the 25%/75% rows; Tensor.median() would return the\n",
    "            # lower of the two middle values when the number of edges is even.\n",
    "            fold_rows.append(torch.stack([learned_weights.mean(),\n",
    "                                          learned_weights.std(),\n",
    "                                          learned_weights.min(),\n",
    "                                          learned_weights.quantile(0.25),\n",
    "                                          learned_weights.quantile(0.5),\n",
    "                                          learned_weights.quantile(0.75),\n",
    "                                          learned_weights.max()]))\n",
    "        stats = torch.stack(fold_rows)  # (n_folds, len(descriptors))\n",
    "        stat_means, stat_stds = stats.mean(dim=0), stats.std(dim=0)\n",
    "        weight_stats[f\"{edge_orientation}_{architecture}\"] = {\n",
    "            descriptor: f\"{stat_mean:.3f} ± {stat_std:.3f}\"\n",
    "            for descriptor, stat_mean, stat_std in zip(descriptors, stat_means, stat_stds)\n",
    "        }\n",
    "# Build the table once: outer keys become columns, descriptors become the row index.\n",
    "weight_stats_df = pd.DataFrame(weight_stats, index=descriptors)\n",
    "weight_stats_df"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
