{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "from itertools import product\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Experiment parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset label used in the LaTeX table headers and captions\n",
    "data_name = \"School\"\n",
    "\n",
    "# Number of time windows\n",
    "T = 18\n",
    "\n",
    "# Miscoverage level for conformal prediction (target coverage is 1 - alpha)\n",
    "alpha = 0.1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Figure output settings"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Options for saving figures: set `save_figs = True` to write them as PDFs under `figures/`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure saving options; the filename prefix follows data_name so the two cannot drift apart.\n",
    "save_figs = False\n",
    "output_file_prefix = f'figures/{data_name}_'\n",
    "output_file_suffix = '_10_100_50.pdf'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 18 tick labels (T = 18): the same 9-slot pattern for each of the two days,\n",
    "# with the day name on a second line beneath the middle slot.\n",
    "xlabels = [\n",
    "    label\n",
    "    for day in (1, 2)\n",
    "    for label in ['', '10am', '', '12pm', f'\\nDay {day}', '2pm', '', '4pm', '']\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keys into each results dict: results[method][GNN_model][regime][output].\n",
    "# In the LaTeX tables BD is shown as \"Block\", UA as \"U\", and \"Assisted Semi-Ind\" as \"Temp. Trans.\".\n",
    "methods = ['BD', 'UA']\n",
    "GNN_models = ['GCN', 'GAT']\n",
    "regimes = ['Assisted Semi-Ind', 'Trans', 'Semi-Ind']\n",
    "outputs = ['Accuracy', 'Avg Size', 'Coverage']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "props = []\n",
    "results_filenames = []\n",
    "# Sort so the order of df_summaries (and of the generated tables) is reproducible:\n",
    "# os.listdir returns entries in arbitrary order.\n",
    "for filename in sorted(os.listdir('results')):\n",
    "    if \"].pkl\" in filename:\n",
    "        results_filenames.append(filename)\n",
    "        # Split proportions are embedded in the filename as \"[p1 p2 ...]\"; split() with no\n",
    "        # argument also tolerates repeated spaces (e.g. numpy's column-aligned array str).\n",
    "        prop_vals = np.array(filename.split('[')[1].split(']')[0].split()).astype(float)\n",
    "        # Compare with a tolerance: decimal proportions need not sum to exactly 1.0 in floating point.\n",
    "        assert np.isclose(prop_vals.sum(), 1), f\"proportions in {filename} do not sum to 1\"\n",
    "        props.append(prop_vals)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "def summarise_results(results, split_label):\n",
    "    \"\"\"Flatten one results dict into long-format summary records.\n",
    "\n",
    "    For every (method, GNN model, regime, output), record the mean and standard deviation of\n",
    "    results[method][GNN_model][regime][output]['All']. Then, for every (method, GNN model,\n",
    "    regime), record 'TSC': the lowest per-time-window mean coverage and the standard deviation\n",
    "    of coverage within that window. Only windows with at least one coverage value are\n",
    "    considered. All values are rounded to 3 decimal places.\n",
    "    \"\"\"\n",
    "    records = []\n",
    "\n",
    "    def add_stats(method, GNN_model, regime, output, mean, std):\n",
    "        # One 'Mean' row followed by one 'St Dev' row.\n",
    "        for statistic, value in (('Mean', mean), ('St Dev', std)):\n",
    "            records.append({\n",
    "                'method': method,\n",
    "                'GNN model': GNN_model,\n",
    "                'regime': regime,\n",
    "                'output': output,\n",
    "                'statistic': statistic,\n",
    "                'value': np.round(value, 3),\n",
    "                'data split': split_label,\n",
    "            })\n",
    "\n",
    "    for method, GNN_model, regime, output in product(methods, GNN_models, regimes, outputs):\n",
    "        values = results[method][GNN_model][regime][output]['All']\n",
    "        add_stats(method, GNN_model, regime, output, np.mean(values), np.std(values))\n",
    "\n",
    "    for method, GNN_model, regime in product(methods, GNN_models, regimes):\n",
    "        coverage = results[method][GNN_model][regime]['Coverage']\n",
    "        # Coverage values per non-empty time window; the number of runs comes from the data\n",
    "        # rather than being hard-coded per regime.\n",
    "        window_covs = [np.asarray(coverage[t], dtype=float) for t in range(T) if len(coverage[t]) > 0]\n",
    "        window_means = [np.mean(c) for c in window_covs]\n",
    "        worst = int(np.argmin(window_means))\n",
    "        add_stats(method, GNN_model, regime, 'TSC', window_means[worst], np.std(window_covs[worst]))\n",
    "\n",
    "    return records\n",
    "\n",
    "\n",
    "df_summaries = []\n",
    "for prop, filename in zip(props, results_filenames):\n",
    "    # NOTE: pickle.load can execute arbitrary code; only load results files you produced yourself.\n",
    "    with open(f'results/{filename}', 'rb') as file:\n",
    "        results = pickle.load(file)\n",
    "    df_summary = pd.DataFrame(summarise_results(results, \"/\".join(prop.astype(str))))\n",
    "    df_summaries.append(df_summary)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Display the summary statistics for the first data split."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   method GNN model             regime    output statistic  value  \\\n",
      "0      BD       GCN  Assisted Semi-Ind  Accuracy      Mean  0.112   \n",
      "1      BD       GCN  Assisted Semi-Ind  Accuracy    St Dev  0.010   \n",
      "2      BD       GCN  Assisted Semi-Ind  Avg Size      Mean  9.134   \n",
      "3      BD       GCN  Assisted Semi-Ind  Avg Size    St Dev  0.164   \n",
      "4      BD       GCN  Assisted Semi-Ind  Coverage      Mean  0.909   \n",
      "..    ...       ...                ...       ...       ...    ...   \n",
      "91     UA       GAT  Assisted Semi-Ind       TSC    St Dev  0.042   \n",
      "92     UA       GAT              Trans       TSC      Mean  0.813   \n",
      "93     UA       GAT              Trans       TSC    St Dev  0.071   \n",
      "94     UA       GAT           Semi-Ind       TSC      Mean  0.842   \n",
      "95     UA       GAT           Semi-Ind       TSC    St Dev  0.053   \n",
      "\n",
      "             data split  \n",
      "0   0.25/0.25/0.25/0.25  \n",
      "1   0.25/0.25/0.25/0.25  \n",
      "2   0.25/0.25/0.25/0.25  \n",
      "3   0.25/0.25/0.25/0.25  \n",
      "4   0.25/0.25/0.25/0.25  \n",
      "..                  ...  \n",
      "91  0.25/0.25/0.25/0.25  \n",
      "92  0.25/0.25/0.25/0.25  \n",
      "93  0.25/0.25/0.25/0.25  \n",
      "94  0.25/0.25/0.25/0.25  \n",
      "95  0.25/0.25/0.25/0.25  \n",
      "\n",
      "[96 rows x 7 columns]\n"
     ]
    }
   ],
   "source": [
    "df_summaries[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>method</th>\n",
       "      <th>GNN model</th>\n",
       "      <th>regime</th>\n",
       "      <th>output</th>\n",
       "      <th>statistic</th>\n",
       "      <th>value</th>\n",
       "      <th>data split</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>BD</td>\n",
       "      <td>GCN</td>\n",
       "      <td>Assisted Semi-Ind</td>\n",
       "      <td>Coverage</td>\n",
       "      <td>Mean</td>\n",
       "      <td>0.910</td>\n",
       "      <td>0.05/0.35/0.3/0.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>BD</td>\n",
       "      <td>GCN</td>\n",
       "      <td>Assisted Semi-Ind</td>\n",
       "      <td>Coverage</td>\n",
       "      <td>St Dev</td>\n",
       "      <td>0.017</td>\n",
       "      <td>0.05/0.35/0.3/0.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>BD</td>\n",
       "      <td>GCN</td>\n",
       "      <td>Trans</td>\n",
       "      <td>Coverage</td>\n",
       "      <td>Mean</td>\n",
       "      <td>0.901</td>\n",
       "      <td>0.05/0.35/0.3/0.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>BD</td>\n",
       "      <td>GCN</td>\n",
       "      <td>Trans</td>\n",
       "      <td>Coverage</td>\n",
       "      <td>St Dev</td>\n",
       "      <td>0.013</td>\n",
       "      <td>0.05/0.35/0.3/0.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>BD</td>\n",
       "      <td>GCN</td>\n",
       "      <td>Semi-Ind</td>\n",
       "      <td>Coverage</td>\n",
       "      <td>Mean</td>\n",
       "      <td>0.870</td>\n",
       "      <td>0.05/0.35/0.3/0.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>BD</td>\n",
       "      <td>GCN</td>\n",
       "      <td>Semi-Ind</td>\n",
       "      <td>Coverage</td>\n",
       "      <td>St Dev</td>\n",
       "      <td>0.039</td>\n",
       "      <td>0.05/0.35/0.3/0.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>BD</td>\n",
       "      <td>GAT</td>\n",
       "      <td>Assisted Semi-Ind</td>\n",
       "      <td>Coverage</td>\n",
       "      <td>Mean</td>\n",
       "      <td>0.908</td>\n",
       "      <td>0.05/0.35/0.3/0.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>BD</td>\n",
       "      <td>GAT</td>\n",
       "      <td>Assisted Semi-Ind</td>\n",
       "      <td>Coverage</td>\n",
       "      <td>St Dev</td>\n",
       "      <td>0.019</td>\n",
       "      <td>0.05/0.35/0.3/0.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>BD</td>\n",
       "      <td>GAT</td>\n",
       "      <td>Trans</td>\n",
       "      <td>Coverage</td>\n",
       "      <td>Mean</td>\n",
       "      <td>0.901</td>\n",
       "      <td>0.05/0.35/0.3/0.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>BD</td>\n",
       "      <td>GAT</td>\n",
       "      <td>Trans</td>\n",
       "      <td>Coverage</td>\n",
       "      <td>St Dev</td>\n",
       "      <td>0.013</td>\n",
       "      <td>0.05/0.35/0.3/0.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>34</th>\n",
       "      <td>BD</td>\n",
       "      <td>GAT</td>\n",
       "      <td>Semi-Ind</td>\n",
       "      <td>Coverage</td>\n",
       "      <td>Mean</td>\n",
       "      <td>0.786</td>\n",
       "      <td>0.05/0.35/0.3/0.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>35</th>\n",
       "      <td>BD</td>\n",
       "      <td>GAT</td>\n",
       "      <td>Semi-Ind</td>\n",
       "      <td>Coverage</td>\n",
       "      <td>St Dev</td>\n",
       "      <td>0.069</td>\n",
       "      <td>0.05/0.35/0.3/0.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>40</th>\n",
       "      <td>UA</td>\n",
       "      <td>GCN</td>\n",
       "      <td>Assisted Semi-Ind</td>\n",
       "      <td>Coverage</td>\n",
       "      <td>Mean</td>\n",
       "      <td>0.902</td>\n",
       "      <td>0.05/0.35/0.3/0.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>41</th>\n",
       "      <td>UA</td>\n",
       "      <td>GCN</td>\n",
       "      <td>Assisted Semi-Ind</td>\n",
       "      <td>Coverage</td>\n",
       "      <td>St Dev</td>\n",
       "      <td>0.014</td>\n",
       "      <td>0.05/0.35/0.3/0.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>46</th>\n",
       "      <td>UA</td>\n",
       "      <td>GCN</td>\n",
       "      <td>Trans</td>\n",
       "      <td>Coverage</td>\n",
       "      <td>Mean</td>\n",
       "      <td>0.901</td>\n",
       "      <td>0.05/0.35/0.3/0.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>47</th>\n",
       "      <td>UA</td>\n",
       "      <td>GCN</td>\n",
       "      <td>Trans</td>\n",
       "      <td>Coverage</td>\n",
       "      <td>St Dev</td>\n",
       "      <td>0.013</td>\n",
       "      <td>0.05/0.35/0.3/0.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>52</th>\n",
       "      <td>UA</td>\n",
       "      <td>GCN</td>\n",
       "      <td>Semi-Ind</td>\n",
       "      <td>Coverage</td>\n",
       "      <td>Mean</td>\n",
       "      <td>0.897</td>\n",
       "      <td>0.05/0.35/0.3/0.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>53</th>\n",
       "      <td>UA</td>\n",
       "      <td>GCN</td>\n",
       "      <td>Semi-Ind</td>\n",
       "      <td>Coverage</td>\n",
       "      <td>St Dev</td>\n",
       "      <td>0.020</td>\n",
       "      <td>0.05/0.35/0.3/0.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>58</th>\n",
       "      <td>UA</td>\n",
       "      <td>GAT</td>\n",
       "      <td>Assisted Semi-Ind</td>\n",
       "      <td>Coverage</td>\n",
       "      <td>Mean</td>\n",
       "      <td>0.901</td>\n",
       "      <td>0.05/0.35/0.3/0.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>59</th>\n",
       "      <td>UA</td>\n",
       "      <td>GAT</td>\n",
       "      <td>Assisted Semi-Ind</td>\n",
       "      <td>Coverage</td>\n",
       "      <td>St Dev</td>\n",
       "      <td>0.013</td>\n",
       "      <td>0.05/0.35/0.3/0.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>64</th>\n",
       "      <td>UA</td>\n",
       "      <td>GAT</td>\n",
       "      <td>Trans</td>\n",
       "      <td>Coverage</td>\n",
       "      <td>Mean</td>\n",
       "      <td>0.902</td>\n",
       "      <td>0.05/0.35/0.3/0.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>65</th>\n",
       "      <td>UA</td>\n",
       "      <td>GAT</td>\n",
       "      <td>Trans</td>\n",
       "      <td>Coverage</td>\n",
       "      <td>St Dev</td>\n",
       "      <td>0.013</td>\n",
       "      <td>0.05/0.35/0.3/0.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>70</th>\n",
       "      <td>UA</td>\n",
       "      <td>GAT</td>\n",
       "      <td>Semi-Ind</td>\n",
       "      <td>Coverage</td>\n",
       "      <td>Mean</td>\n",
       "      <td>0.908</td>\n",
       "      <td>0.05/0.35/0.3/0.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>71</th>\n",
       "      <td>UA</td>\n",
       "      <td>GAT</td>\n",
       "      <td>Semi-Ind</td>\n",
       "      <td>Coverage</td>\n",
       "      <td>St Dev</td>\n",
       "      <td>0.023</td>\n",
       "      <td>0.05/0.35/0.3/0.3</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   method GNN model             regime    output statistic  value  \\\n",
       "4      BD       GCN  Assisted Semi-Ind  Coverage      Mean  0.910   \n",
       "5      BD       GCN  Assisted Semi-Ind  Coverage    St Dev  0.017   \n",
       "10     BD       GCN              Trans  Coverage      Mean  0.901   \n",
       "11     BD       GCN              Trans  Coverage    St Dev  0.013   \n",
       "16     BD       GCN           Semi-Ind  Coverage      Mean  0.870   \n",
       "17     BD       GCN           Semi-Ind  Coverage    St Dev  0.039   \n",
       "22     BD       GAT  Assisted Semi-Ind  Coverage      Mean  0.908   \n",
       "23     BD       GAT  Assisted Semi-Ind  Coverage    St Dev  0.019   \n",
       "28     BD       GAT              Trans  Coverage      Mean  0.901   \n",
       "29     BD       GAT              Trans  Coverage    St Dev  0.013   \n",
       "34     BD       GAT           Semi-Ind  Coverage      Mean  0.786   \n",
       "35     BD       GAT           Semi-Ind  Coverage    St Dev  0.069   \n",
       "40     UA       GCN  Assisted Semi-Ind  Coverage      Mean  0.902   \n",
       "41     UA       GCN  Assisted Semi-Ind  Coverage    St Dev  0.014   \n",
       "46     UA       GCN              Trans  Coverage      Mean  0.901   \n",
       "47     UA       GCN              Trans  Coverage    St Dev  0.013   \n",
       "52     UA       GCN           Semi-Ind  Coverage      Mean  0.897   \n",
       "53     UA       GCN           Semi-Ind  Coverage    St Dev  0.020   \n",
       "58     UA       GAT  Assisted Semi-Ind  Coverage      Mean  0.901   \n",
       "59     UA       GAT  Assisted Semi-Ind  Coverage    St Dev  0.013   \n",
       "64     UA       GAT              Trans  Coverage      Mean  0.902   \n",
       "65     UA       GAT              Trans  Coverage    St Dev  0.013   \n",
       "70     UA       GAT           Semi-Ind  Coverage      Mean  0.908   \n",
       "71     UA       GAT           Semi-Ind  Coverage    St Dev  0.023   \n",
       "\n",
       "           data split  \n",
       "4   0.05/0.35/0.3/0.3  \n",
       "5   0.05/0.35/0.3/0.3  \n",
       "10  0.05/0.35/0.3/0.3  \n",
       "11  0.05/0.35/0.3/0.3  \n",
       "16  0.05/0.35/0.3/0.3  \n",
       "17  0.05/0.35/0.3/0.3  \n",
       "22  0.05/0.35/0.3/0.3  \n",
       "23  0.05/0.35/0.3/0.3  \n",
       "28  0.05/0.35/0.3/0.3  \n",
       "29  0.05/0.35/0.3/0.3  \n",
       "34  0.05/0.35/0.3/0.3  \n",
       "35  0.05/0.35/0.3/0.3  \n",
       "40  0.05/0.35/0.3/0.3  \n",
       "41  0.05/0.35/0.3/0.3  \n",
       "46  0.05/0.35/0.3/0.3  \n",
       "47  0.05/0.35/0.3/0.3  \n",
       "52  0.05/0.35/0.3/0.3  \n",
       "53  0.05/0.35/0.3/0.3  \n",
       "58  0.05/0.35/0.3/0.3  \n",
       "59  0.05/0.35/0.3/0.3  \n",
       "64  0.05/0.35/0.3/0.3  \n",
       "65  0.05/0.35/0.3/0.3  \n",
       "70  0.05/0.35/0.3/0.3  \n",
       "71  0.05/0.35/0.3/0.3  "
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_summaries[0].loc[lambda df: df['output'] == 'Coverage']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{table}[ht]\n",
      "\\centering\n",
      "\n",
      "\\begin{subtable}{\\textwidth}\n",
      "\\centering\n",
      "\\begin{tabular}{|l|l|l|l|}\n",
      "\\hline\n",
      "Embedding & \\multicolumn{3}{c|}{School} \\\\ \n",
      "\\cline{2-4}\n",
      "& Trans. & Semi-ind. & Temp. Trans. \\\\ \n",
      "\\hline\n",
      "Block GCN & \\textbf{0.901 $\\pm$ 0.014} & 0.833 $\\pm$ 0.046 & \\textbf{0.909 $\\pm$ 0.019} \\\\ \\hline\n",
      "UGCN & \\textbf{0.901 $\\pm$ 0.014} & \\textbf{0.943 $\\pm$ 0.018} & \\textbf{0.903 $\\pm$ 0.014} \\\\ \\hline\n",
      "Block GAT & \\textbf{0.902 $\\pm$ 0.014} & 0.681 $\\pm$ 0.102 & \\textbf{0.916 $\\pm$ 0.026} \\\\ \\hline\n",
      "UGAT & \\textbf{0.901 $\\pm$ 0.014} & \\textbf{0.912 $\\pm$ 0.024} & \\textbf{0.901 $\\pm$ 0.014} \\\\ \\hline\n",
      "\\end{tabular}\n",
      "\\caption{Coverage for the School experiment with data split 25/25/25/25.}\n",
      "\\end{subtable}\n",
      "\n",
      "\\begin{subtable}{\\textwidth}\n",
      "\\centering\n",
      "\\begin{tabular}{|l|l|l|l|}\n",
      "\\hline\n",
      "Embedding & \\multicolumn{3}{c|}{School} \\\\ \n",
      "\\cline{2-4}\n",
      "& Trans. & Semi-ind. & Temp. Trans. \\\\ \n",
      "\\hline\n",
      "Block GCN & \\textbf{0.903 $\\pm$ 0.015} & \\textbf{0.826 $\\pm$ 0.075} & \\textbf{0.913 $\\pm$ 0.02} \\\\ \\hline\n",
      "UGCN & \\textbf{0.902 $\\pm$ 0.016} & \\textbf{0.939 $\\pm$ 0.02} & \\textbf{0.902 $\\pm$ 0.016} \\\\ \\hline\n",
      "Block GAT & \\textbf{0.904 $\\pm$ 0.015} & 0.692 $\\pm$ 0.109 & \\textbf{0.926 $\\pm$ 0.029} \\\\ \\hline\n",
      "UGAT & \\textbf{0.903 $\\pm$ 0.016} & \\textbf{0.901 $\\pm$ 0.035} & \\textbf{0.902 $\\pm$ 0.015} \\\\ \\hline\n",
      "\\end{tabular}\n",
      "\\caption{Coverage for the School experiment with data split 50/10/20/20.}\n",
      "\\end{subtable}\n",
      "\n",
      "\\begin{subtable}{\\textwidth}\n",
      "\\centering\n",
      "\\begin{tabular}{|l|l|l|l|}\n",
      "\\hline\n",
      "Embedding & \\multicolumn{3}{c|}{School} \\\\ \n",
      "\\cline{2-4}\n",
      "& Trans. & Semi-ind. & Temp. Trans. \\\\ \n",
      "\\hline\n",
      "Block GCN & \\textbf{0.901 $\\pm$ 0.011} & 0.811 $\\pm$ 0.05 & \\textbf{0.905 $\\pm$ 0.017} \\\\ \\hline\n",
      "UGCN & \\textbf{0.901 $\\pm$ 0.011} & 0.866 $\\pm$ 0.021 & \\textbf{0.899 $\\pm$ 0.011} \\\\ \\hline\n",
      "Block GAT & \\textbf{0.902 $\\pm$ 0.011} & 0.619 $\\pm$ 0.096 & \\textbf{0.906 $\\pm$ 0.014} \\\\ \\hline\n",
      "UGAT & \\textbf{0.901 $\\pm$ 0.011} & \\textbf{0.91 $\\pm$ 0.019} & \\textbf{0.901 $\\pm$ 0.011} \\\\ \\hline\n",
      "\\end{tabular}\n",
      "\\caption{Coverage for the School experiment with data split 10/10/40/40.}\n",
      "\\end{subtable}\n",
      "\n",
      "\\begin{subtable}{\\textwidth}\n",
      "\\centering\n",
      "\\begin{tabular}{|l|l|l|l|}\n",
      "\\hline\n",
      "Embedding & \\multicolumn{3}{c|}{School} \\\\ \n",
      "\\cline{2-4}\n",
      "& Trans. & Semi-ind. & Temp. Trans. \\\\ \n",
      "\\hline\n",
      "Block GCN & \\textbf{0.901 $\\pm$ 0.012} & 0.823 $\\pm$ 0.049 & \\textbf{0.906 $\\pm$ 0.017} \\\\ \\hline\n",
      "UGCN & \\textbf{0.901 $\\pm$ 0.012} & \\textbf{0.889 $\\pm$ 0.02} & \\textbf{0.903 $\\pm$ 0.012} \\\\ \\hline\n",
      "Block GAT & \\textbf{0.901 $\\pm$ 0.012} & 0.654 $\\pm$ 0.093 & \\textbf{0.914 $\\pm$ 0.022} \\\\ \\hline\n",
      "UGAT & \\textbf{0.901 $\\pm$ 0.012} & \\textbf{0.907 $\\pm$ 0.024} & \\textbf{0.901 $\\pm$ 0.012} \\\\ \\hline\n",
      "\\end{tabular}\n",
      "\\caption{Coverage for the School experiment with data split 20/10/35/35.}\n",
      "\\end{subtable}\n",
      "\n",
      "\\begin{subtable}{\\textwidth}\n",
      "\\centering\n",
      "\\begin{tabular}{|l|l|l|l|}\n",
      "\\hline\n",
      "Embedding & \\multicolumn{3}{c|}{School} \\\\ \n",
      "\\cline{2-4}\n",
      "& Trans. & Semi-ind. & Temp. Trans. \\\\ \n",
      "\\hline\n",
      "Block GCN & \\textbf{0.906 $\\pm$ 0.022} & \\textbf{0.845 $\\pm$ 0.07} & \\textbf{0.918 $\\pm$ 0.031} \\\\ \\hline\n",
      "UGCN & \\textbf{0.905 $\\pm$ 0.022} & \\textbf{0.903 $\\pm$ 0.03} & \\textbf{0.906 $\\pm$ 0.026} \\\\ \\hline\n",
      "Block GAT & \\textbf{0.906 $\\pm$ 0.022} & 0.705 $\\pm$ 0.115 & \\textbf{0.946 $\\pm$ 0.046} \\\\ \\hline\n",
      "UGAT & \\textbf{0.906 $\\pm$ 0.022} & \\textbf{0.876 $\\pm$ 0.042} & \\textbf{0.907 $\\pm$ 0.026} \\\\ \\hline\n",
      "\\end{tabular}\n",
      "\\caption{Coverage for the School experiment with data split 50/30/10/10.}\n",
      "\\end{subtable}\n",
      "\n",
      "\\begin{subtable}{\\textwidth}\n",
      "\\centering\n",
      "\\begin{tabular}{|l|l|l|l|}\n",
      "\\hline\n",
      "Embedding & \\multicolumn{3}{c|}{School} \\\\ \n",
      "\\cline{2-4}\n",
      "& Trans. & Semi-ind. & Temp. Trans. \\\\ \n",
      "\\hline\n",
      "Block GCN & \\textbf{0.901 $\\pm$ 0.013} & \\textbf{0.87 $\\pm$ 0.039} & \\textbf{0.91 $\\pm$ 0.017} \\\\ \\hline\n",
      "UGCN & \\textbf{0.901 $\\pm$ 0.013} & \\textbf{0.897 $\\pm$ 0.02} & \\textbf{0.902 $\\pm$ 0.014} \\\\ \\hline\n",
      "Block GAT & \\textbf{0.901 $\\pm$ 0.013} & 0.786 $\\pm$ 0.069 & \\textbf{0.908 $\\pm$ 0.019} \\\\ \\hline\n",
      "UGAT & \\textbf{0.902 $\\pm$ 0.013} & \\textbf{0.908 $\\pm$ 0.023} & \\textbf{0.901 $\\pm$ 0.013} \\\\ \\hline\n",
      "\\end{tabular}\n",
      "\\caption{Coverage for the School experiment with data split 5/35/30/30.}\n",
      "\\end{subtable}\n",
      "\n",
      "\n",
      "\\caption{}\n",
      "\\end{table}\n"
     ]
    }
   ],
   "source": [
    "def lookup_stat(df, method, GNN_model, regime, output, statistic):\n",
    "    \"\"\"Return the single 'value' in a df_summaries frame matching all five keys.\"\"\"\n",
    "    mask = (\n",
    "        (df[\"method\"] == method)\n",
    "        & (df[\"GNN model\"] == GNN_model)\n",
    "        & (df[\"regime\"] == regime)\n",
    "        & (df[\"output\"] == output)\n",
    "        & (df[\"statistic\"] == statistic)\n",
    "    )\n",
    "    return df.loc[mask, \"value\"].values[0]\n",
    "\n",
    "\n",
    "def format_mean_std(mean, std, output):\n",
    "    \"\"\"Format a 'mean +/- std' LaTeX table cell.\n",
    "\n",
    "    Coverage cells are bolded when mean + std reaches the target coverage 1 - alpha;\n",
    "    cells for any other output are never bolded.\n",
    "    \"\"\"\n",
    "    value = f\"{mean} $\\\\pm$ {std}\"\n",
    "    if output == \"Coverage\" and (mean + std) >= 1 - alpha:\n",
    "        return f\"\\\\textbf{{{value}}}\"\n",
    "    return value\n",
    "\n",
    "\n",
    "output = \"Coverage\"\n",
    "replace_dict = {\"BD\": \"Block \", \"UA\": \"U\"}\n",
    "# Table columns from left to right; \"Assisted Semi-Ind\" is labelled \"Temp. Trans.\".\n",
    "table_regimes = [\"Trans\", \"Semi-Ind\", \"Assisted Semi-Ind\"]\n",
    "\n",
    "table_start_start_str = \"\\\\begin{table}[ht]\\n\\\\centering\\n\"\n",
    "print(table_start_start_str)\n",
    "for prop, df_summary in zip(props, df_summaries):\n",
    "    table_start_str = (\n",
    "        \"\\\\begin{subtable}{\\\\textwidth}\\n\\\\centering\\n\\\\begin{tabular}{|l|l|l|l|}\\n\\\\hline\\n\"\n",
    "        f\"Embedding & \\\\multicolumn{{3}}{{c|}}{{{data_name}}} \\\\\\\\ \\n\"\n",
    "        \"\\\\cline{2-4}\\n& Trans. & Semi-ind. & Temp. Trans. \\\\\\\\ \\n\\\\hline\\n\"\n",
    "    )\n",
    "\n",
    "    table_data = \"\"\n",
    "    for GNN_model, method in product(GNN_models, methods):\n",
    "        cells = [\n",
    "            format_mean_std(\n",
    "                lookup_stat(df_summary, method, GNN_model, regime, output, \"Mean\"),\n",
    "                lookup_stat(df_summary, method, GNN_model, regime, output, \"St Dev\"),\n",
    "                output,\n",
    "            )\n",
    "            for regime in table_regimes\n",
    "        ]\n",
    "        table_data += f\"{replace_dict[method]}{GNN_model} & \" + \" & \".join(cells) + \" \\\\\\\\ \\\\hline\\n\"\n",
    "\n",
    "    # Round rather than truncate: int(0.29 * 100) == 28.\n",
    "    split_pct = \"/\".join(np.rint(prop * 100).astype(int).astype(str))\n",
    "    table_str = (\n",
    "        table_start_str\n",
    "        + table_data\n",
    "        + \"\\\\end{tabular}\\n\"\n",
    "        + f\"\\\\caption{{{output} for the {data_name} experiment with data split {split_pct}.}}\\n\"\n",
    "        + \"\\\\end{subtable}\"\n",
    "    )\n",
    "\n",
    "    print(table_str)\n",
    "    print(\"\")\n",
    "\n",
    "table_end_end_str = \"\\n\\\\caption{}\\n\\\\end{table}\"\n",
    "print(table_end_end_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{table}[ht]\n",
      "\\centering\n",
      "\n",
      "\\begin{subtable}{\\textwidth}\n",
      "\\centering\n",
      "\\begin{tabular}{|l|l|l|l|}\n",
      "\\hline\n",
      "Embedding & \\multicolumn{3}{c|}{School} \\\\ \n",
      "\\cline{2-4}\n",
      "& Trans. & Semi-ind. & Temp. Trans. \\\\ \n",
      "\\hline\n",
      "Block GCN & 0.888 $\\pm$ 0.013 & 0.109 $\\pm$ 0.025 & 0.112 $\\pm$ 0.01 \\\\ \\hline\n",
      "UGCN & \\textbf{0.935 $\\pm$ 0.007} & \\textbf{0.896 $\\pm$ 0.01} & \\textbf{0.938 $\\pm$ 0.006} \\\\ \\hline\n",
      "Block GAT & 0.834 $\\pm$ 0.021 & 0.108 $\\pm$ 0.028 & 0.106 $\\pm$ 0.018 \\\\ \\hline\n",
      "UGAT & \\textbf{0.908 $\\pm$ 0.019} & \\textbf{0.883 $\\pm$ 0.014} & \\textbf{0.904 $\\pm$ 0.024} \\\\ \\hline\n",
      "\\end{tabular}\n",
      "\\caption{Accuracy for the School experiment with data split 25/25/25/25.}\n",
      "\\end{subtable}\n",
      "\n",
      "\\begin{subtable}{\\textwidth}\n",
      "\\centering\n",
      "\\begin{tabular}{|l|l|l|l|}\n",
      "\\hline\n",
      "Embedding & \\multicolumn{3}{c|}{School} \\\\ \n",
      "\\cline{2-4}\n",
      "& Trans. & Semi-ind. & Temp. Trans. \\\\ \n",
      "\\hline\n",
      "Block GCN & 0.911 $\\pm$ 0.01 & 0.108 $\\pm$ 0.035 & 0.118 $\\pm$ 0.015 \\\\ \\hline\n",
      "UGCN & \\textbf{0.931 $\\pm$ 0.01} & \\textbf{0.971 $\\pm$ 0.008} & \\textbf{0.923 $\\pm$ 0.009} \\\\ \\hline\n",
      "Block GAT & 0.88 $\\pm$ 0.016 & 0.111 $\\pm$ 0.034 & 0.095 $\\pm$ 0.032 \\\\ \\hline\n",
      "UGAT & \\textbf{0.912 $\\pm$ 0.017} & \\textbf{0.962 $\\pm$ 0.01} & \\textbf{0.908 $\\pm$ 0.014} \\\\ \\hline\n",
      "\\end{tabular}\n",
      "\\caption{Accuracy for the School experiment with data split 50/10/20/20.}\n",
      "\\end{subtable}\n",
      "\n",
      "\\begin{subtable}{\\textwidth}\n",
      "\\centering\n",
      "\\begin{tabular}{|l|l|l|l|}\n",
      "\\hline\n",
      "Embedding & \\multicolumn{3}{c|}{School} \\\\ \n",
      "\\cline{2-4}\n",
      "& Trans. & Semi-ind. & Temp. Trans. \\\\ \n",
      "\\hline\n",
      "Block GCN & 0.78 $\\pm$ 0.019 & 0.108 $\\pm$ 0.017 & 0.112 $\\pm$ 0.01 \\\\ \\hline\n",
      "UGCN & \\textbf{0.928 $\\pm$ 0.008} & \\textbf{0.916 $\\pm$ 0.013} & \\textbf{0.927 $\\pm$ 0.006} \\\\ \\hline\n",
      "Block GAT & 0.706 $\\pm$ 0.028 & 0.106 $\\pm$ 0.02 & 0.114 $\\pm$ 0.008 \\\\ \\hline\n",
      "UGAT & \\textbf{0.877 $\\pm$ 0.026} & \\textbf{0.898 $\\pm$ 0.016} & \\textbf{0.885 $\\pm$ 0.013} \\\\ \\hline\n",
      "\\end{tabular}\n",
      "\\caption{Accuracy for the School experiment with data split 10/10/40/40.}\n",
      "\\end{subtable}\n",
      "\n",
      "\\begin{subtable}{\\textwidth}\n",
      "\\centering\n",
      "\\begin{tabular}{|l|l|l|l|}\n",
      "\\hline\n",
      "Embedding & \\multicolumn{3}{c|}{School} \\\\ \n",
      "\\cline{2-4}\n",
      "& Trans. & Semi-ind. & Temp. Trans. \\\\ \n",
      "\\hline\n",
      "Block GCN & 0.871 $\\pm$ 0.01 & 0.105 $\\pm$ 0.024 & 0.113 $\\pm$ 0.009 \\\\ \\hline\n",
      "UGCN & \\textbf{0.933 $\\pm$ 0.006} & \\textbf{0.913 $\\pm$ 0.009} & \\textbf{0.953 $\\pm$ 0.006} \\\\ \\hline\n",
      "Block GAT & 0.808 $\\pm$ 0.018 & 0.107 $\\pm$ 0.025 & 0.109 $\\pm$ 0.01 \\\\ \\hline\n",
      "UGAT & \\textbf{0.901 $\\pm$ 0.021} & \\textbf{0.899 $\\pm$ 0.012} & \\textbf{0.91 $\\pm$ 0.014} \\\\ \\hline\n",
      "\\end{tabular}\n",
      "\\caption{Accuracy for the School experiment with data split 20/10/35/35.}\n",
      "\\end{subtable}\n",
      "\n",
      "\\begin{subtable}{\\textwidth}\n",
      "\\centering\n",
      "\\begin{tabular}{|l|l|l|l|}\n",
      "\\hline\n",
      "Embedding & \\multicolumn{3}{c|}{School} \\\\ \n",
      "\\cline{2-4}\n",
      "& Trans. & Semi-ind. & Temp. Trans. \\\\ \n",
      "\\hline\n",
      "Block GCN & 0.91 $\\pm$ 0.016 & 0.113 $\\pm$ 0.038 & 0.135 $\\pm$ 0.026 \\\\ \\hline\n",
      "UGCN & \\textbf{0.934 $\\pm$ 0.014} & \\textbf{0.996 $\\pm$ 0.006} & \\textbf{0.998 $\\pm$ 0.002} \\\\ \\hline\n",
      "Block GAT & 0.881 $\\pm$ 0.017 & 0.116 $\\pm$ 0.035 & 0.114 $\\pm$ 0.041 \\\\ \\hline\n",
      "UGAT & \\textbf{0.913 $\\pm$ 0.018} & \\textbf{0.991 $\\pm$ 0.007} & \\textbf{0.979 $\\pm$ 0.013} \\\\ \\hline\n",
      "\\end{tabular}\n",
      "\\caption{Accuracy for the School experiment with data split 50/30/10/10.}\n",
      "\\end{subtable}\n",
      "\n",
      "\\begin{subtable}{\\textwidth}\n",
      "\\centering\n",
      "\\begin{tabular}{|l|l|l|l|}\n",
      "\\hline\n",
      "Embedding & \\multicolumn{3}{c|}{School} \\\\ \n",
      "\\cline{2-4}\n",
      "& Trans. & Semi-ind. & Temp. Trans. \\\\ \n",
      "\\hline\n",
      "Block GCN & 0.583 $\\pm$ 0.031 & 0.105 $\\pm$ 0.022 & 0.105 $\\pm$ 0.011 \\\\ \\hline\n",
      "UGCN & \\textbf{0.918 $\\pm$ 0.026} & \\textbf{0.897 $\\pm$ 0.022} & \\textbf{0.939 $\\pm$ 0.009} \\\\ \\hline\n",
      "Block GAT & 0.543 $\\pm$ 0.03 & 0.106 $\\pm$ 0.025 & 0.111 $\\pm$ 0.01 \\\\ \\hline\n",
      "UGAT & \\textbf{0.838 $\\pm$ 0.031} & \\textbf{0.874 $\\pm$ 0.022} & \\textbf{0.877 $\\pm$ 0.032} \\\\ \\hline\n",
      "\\end{tabular}\n",
      "\\caption{Accuracy for the School experiment with data split 5/35/30/30.}\n",
      "\\end{subtable}\n",
      "\n",
      "\n",
      "\\caption{}\n",
      "\\end{table}\n"
     ]
    }
   ],
   "source": [
    "# Build one LaTeX subtable of accuracy (mean +/- std) per data split and print them\n",
    "# wrapped in a single LaTeX table environment.\n",
    "table_start_start_str = \"\\\\begin{table}[ht]\\n\\\\centering\\n\"\n",
    "print(table_start_start_str)\n",
    "\n",
    "output = \"Accuracy\"\n",
    "\n",
    "# Row-label prefixes: \"BD\" + \"GCN\" -> \"Block GCN\", \"UA\" + \"GCN\" -> \"UGCN\".\n",
    "replace_dict = {\n",
    "    \"BD\": \"Block \",\n",
    "    \"UA\": \"U\",\n",
    "}\n",
    "\n",
    "# Regime keys in table-column order (Trans. | Semi-ind. | Temp. Trans.);\n",
    "# the \"Temp. Trans.\" column is reported from the \"Assisted Semi-Ind\" regime.\n",
    "table_regimes = [\"Trans\", \"Semi-Ind\", \"Assisted Semi-Ind\"]\n",
    "\n",
    "# The header does not depend on the data split, so build it once.\n",
    "table_start_str = (\n",
    "    \"\\\\begin{subtable}{\\\\textwidth}\\n\\\\centering\\n\\\\begin{tabular}{|l|l|l|l|}\\n\\\\hline\\n\"\n",
    "    f\"Embedding & \\\\multicolumn{{3}}{{c|}}{{{data_name}}} \\\\\\\\ \\n\\\\cline{{2-4}}\\n\"\n",
    "    \"& Trans. & Semi-ind. & Temp. Trans. \\\\\\\\ \\n\\\\hline\\n\"\n",
    ")\n",
    "\n",
    "\n",
    "def get_stat(df, method, GNN_model, regime, output, statistic):\n",
    "    \"\"\"Return `value` from the first row of `df` matching all of the given keys.\n",
    "\n",
    "    Raises IndexError if no row matches.\n",
    "    \"\"\"\n",
    "    mask = (\n",
    "        (df[\"method\"] == method)\n",
    "        & (df[\"GNN model\"] == GNN_model)\n",
    "        & (df[\"regime\"] == regime)\n",
    "        & (df[\"output\"] == output)\n",
    "        & (df[\"statistic\"] == statistic)\n",
    "    )\n",
    "    return df.loc[mask, \"value\"].values[0]\n",
    "\n",
    "\n",
    "def format_accuracy_cell(mean, std, max_mean, output):\n",
    "    \"\"\"Format mean and std as a LaTeX cell with 3 decimals, bolded when it is the best Accuracy mean.\"\"\"\n",
    "    value = f\"{mean:.3f} $\\\\pm$ {std:.3f}\"\n",
    "    if output == \"Accuracy\" and mean == max_mean:\n",
    "        return f\"\\\\textbf{{{value}}}\"\n",
    "    return value\n",
    "\n",
    "\n",
    "for prop, df_summary in zip(props, df_summaries):\n",
    "    table_data = \"\"\n",
    "    # One row per (GNN model, method) pair; one column per regime.\n",
    "    for GNN_model in GNN_models:\n",
    "        # Best mean per regime for this GNN model, taken over every method in\n",
    "        # df_summary; the matching cell(s) are bolded.\n",
    "        max_means = {\n",
    "            regime: df_summary.loc[\n",
    "                (df_summary[\"regime\"] == regime)\n",
    "                & (df_summary[\"output\"] == output)\n",
    "                & (df_summary[\"GNN model\"] == GNN_model)\n",
    "                & (df_summary[\"statistic\"] == \"Mean\"),\n",
    "                \"value\",\n",
    "            ].max()\n",
    "            for regime in table_regimes\n",
    "        }\n",
    "        for method in methods:\n",
    "            cells = [\n",
    "                format_accuracy_cell(\n",
    "                    get_stat(df_summary, method, GNN_model, regime, output, \"Mean\"),\n",
    "                    get_stat(df_summary, method, GNN_model, regime, output, \"St Dev\"),\n",
    "                    max_means[regime],\n",
    "                    output,\n",
    "                )\n",
    "                for regime in table_regimes\n",
    "            ]\n",
    "            table_data += f\"{replace_dict[method]}{GNN_model} & \" + \" & \".join(cells) + \" \\\\\\\\ \\\\hline\\n\"\n",
    "\n",
    "    # Round rather than truncate: e.g. 0.29 * 100 == 28.999999999999996 would print as 28.\n",
    "    split_label = \"/\".join((np.asarray(prop) * 100).round().astype(int).astype(str))\n",
    "    table_str = (\n",
    "        table_start_str\n",
    "        + table_data\n",
    "        + \"\\\\end{tabular}\\n\\\\caption{\"\n",
    "        + f\"{output} for the {data_name} experiment with data split {split_label}.\"\n",
    "        + \"}\\n\\\\end{subtable}\"\n",
    "    )\n",
    "\n",
    "    print(table_str)\n",
    "    print(\"\")\n",
    "\n",
    "table_end_end_str = \"\\n\\\\caption{}\\n\\\\end{table}\"\n",
    "print(table_end_end_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{table}[h]\n",
      "\\centering\n",
      "\\begin{tabular}{|l|l|l|}\n",
      "\\hline\n",
      "Embedding & \\multicolumn{2}{c|}{School} \\\\ \n",
      "\\cline{2-3}\n",
      "& Trans. & Semi-ind.  \\\\ \n",
      "\\hline\n",
      "Block GCN & \\textbf{0.781 $\\pm$ 0.115} & \\textbf{0.865 $\\pm$ 0.059} \\\\ \\hline\n",
      "UGCN & 0.712 $\\pm$ 0.075 & $0.707 \\pm 0.060$ \\\\ \\hline\n",
      "Block GAT & 0.802 $\\pm$ 0.089 & $0.779 \\pm 0.078$ \\\\ \\hline\n",
      "UGAT & \\textbf{0.842 $\\pm$ 0.072} & \\textbf{0.861 $\\pm$ 0.053} \\\\ \\hline\n",
      "\\end{tabular}\n",
      "\\caption{}\n",
      "\\label{tab:School_TSC}\n",
      "\\end{table}\n"
     ]
    }
   ],
   "source": [
    "# NOTE(review): every line of this cell is commented out, so it does nothing on\n",
    "# Restart & Run All, and the output saved with it cannot be reproduced. When active\n",
    "# it printed one 2-column (Trans. / Semi-ind.) LaTeX table for output=\"TSC\",\n",
    "# reading `df_summary`, which this cell never assigns (it would pick up whatever\n",
    "# value an earlier cell left in the kernel).\n",
    "# TODO(review): \"TSC\" is not in the `outputs` list defined near the top of the\n",
    "# notebook -- confirm the summaries still contain it before re-enabling.\n",
    "# NOTE(review): the max_* lookups below do not filter on statistic == \"Mean\",\n",
    "# and non-bold Semi-ind. cells use \"$a \\pm b$\" while all other cells use\n",
    "# \"a $\\pm$ b\" (visible in the saved output) -- align both before reuse.\n",
    "\n",
    "# df_summary[\"name\"] = df_summary[\"method\"] + \" \" + df_summary[\"GNN model\"]\n",
    "\n",
    "# replace_dict = {\n",
    "#     \"BD GCN\": \"Block GCN\",\n",
    "#     \"BD GAT\": \"Block GAT\",\n",
    "#     \"UA GCN\": \"UGCN\",\n",
    "#     \"UA GAT\": \"UGAT\",\n",
    "# }\n",
    "\n",
    "# df_summary[\"name\"] = df_summary[\"name\"].replace(replace_dict)\n",
    "\n",
    "\n",
    "# output=\"TSC\"\n",
    "\n",
    "# table_start_str = (\n",
    "#     \"\\\\begin{table}[h]\\n\\\\centering\\n\\\\begin{tabular}{|l|l|l|}\\n\\\\hline\\nEmbedding & \\\\multicolumn{2}{c|}{\"\n",
    "#     + data_name\n",
    "#     + \"} \\\\\\ \\n\\\\cline{2-3}\\n& Trans. & Semi-ind.  \\\\\\ \\n\\\\hline\\n\"\n",
    "# )\n",
    "\n",
    "# # table_data_1 = \"ISE & $0.505 \\\\pm 0.000$ & $0.248 \\\\pm 0.000$ \\\\\\\\ \\\\hline\\n\"\n",
    "\n",
    "\n",
    "# table_data = \"\"\n",
    "# # Loop over each (GNN model, method) pair\n",
    "# for (GNN_model, method) in product(GNN_models, methods):\n",
    "\n",
    "#     max_trans_acc = df_summary[\n",
    "#         (df_summary[\"regime\"] == \"Trans\") & (df_summary[\"output\"] == output) & (df_summary[\"GNN model\"] == GNN_model)\n",
    "#     ][\"value\"].max()\n",
    "#     max_semi_ind_acc = df_summary[\n",
    "#         (df_summary[\"regime\"] == \"Semi-Ind\") & (df_summary[\"output\"] == output) & (df_summary[\"GNN model\"] == GNN_model)\n",
    "#     ][\"value\"].max()\n",
    "\n",
    "#     trans_acc = df_summary[\n",
    "#         (df_summary[\"method\"] == method)\n",
    "#         & (df_summary[\"GNN model\"] == GNN_model)\n",
    "#         & (df_summary[\"regime\"] == \"Trans\")\n",
    "#         & (df_summary[\"output\"] == output)\n",
    "#         & (df_summary[\"statistic\"] == \"Mean\")\n",
    "#     ][\"value\"].values[0]\n",
    "#     trans_std = df_summary[\n",
    "#         (df_summary[\"method\"] == method)\n",
    "#         & (df_summary[\"GNN model\"] == GNN_model)\n",
    "#         & (df_summary[\"regime\"] == \"Trans\")\n",
    "#         & (df_summary[\"output\"] == output)\n",
    "#         & (df_summary[\"statistic\"] == \"St Dev\")\n",
    "#     ][\"value\"].values[0]\n",
    "#     semi_ind_acc = df_summary[\n",
    "#         (df_summary[\"method\"] == method)\n",
    "#         & (df_summary[\"GNN model\"] == GNN_model)\n",
    "#         & (df_summary[\"regime\"] == \"Semi-Ind\")\n",
    "#         & (df_summary[\"output\"] == output)\n",
    "#         & (df_summary[\"statistic\"] == \"Mean\")\n",
    "#     ][\"value\"].values[0]\n",
    "#     semi_ind_std = df_summary[\n",
    "#         (df_summary[\"method\"] == method)\n",
    "#         & (df_summary[\"GNN model\"] == GNN_model)\n",
    "#         & (df_summary[\"regime\"] == \"Semi-Ind\")\n",
    "#         & (df_summary[\"output\"] == output)\n",
    "#         & (df_summary[\"statistic\"] == \"St Dev\")\n",
    "#     ][\"value\"].values[0]\n",
    "\n",
    "#     method_name = df_summary[\n",
    "#         (df_summary[\"method\"] == method)\n",
    "#         & (df_summary[\"GNN model\"] == GNN_model)\n",
    "#     ][\"name\"].values[0]\n",
    "    \n",
    "\n",
    "#     if trans_acc != max_trans_acc:\n",
    "#         table_data += f\"{method_name} & {trans_acc:.3f} $\\\\pm$ {trans_std:.3f}\"\n",
    "#     else:\n",
    "#         table_data += (\n",
    "#             method_name + \" & \\\\textbf{\" + f\"{trans_acc:.3f} $\\\\pm$ {trans_std:.3f}\" + \"}\"\n",
    "#         )\n",
    "\n",
    "#     if semi_ind_acc != max_semi_ind_acc:\n",
    "#         table_data += f\" & ${semi_ind_acc:.3f} \\\\pm {semi_ind_std:.3f}$ \\\\\\\\ \\\\hline\\n\"\n",
    "#     else:\n",
    "#         table_data += (\n",
    "#             \" & \\\\textbf{\"\n",
    "#             + f\"{semi_ind_acc:.3f} $\\\\pm$ {semi_ind_std:.3f}\"\n",
    "#             + \"} \\\\\\\\ \\\\hline\\n\"\n",
    "#         )\n",
    "\n",
    "# table_end_str = (\n",
    "#     \"\\\\end{tabular}\\n\\\\caption{}\\n\\\\label{tab:\"\n",
    "#     + data_name\n",
    "#     + \"_\"+output+\"}\\n\\\\end{table}\"\n",
    "# )\n",
    "\n",
    "# full_table = table_start_str + table_data + table_end_str\n",
    "\n",
    "# print(full_table)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "TGB_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
