{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import glob\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from copy import deepcopy\n",
    "from publib import set_style, fix_style\n",
    "\n",
    "\n",
    "set_style(['article']) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Experiment Results Visualization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Plotting Setups"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment tag: names the sub-folder read from both result trees and prefixes the saved figures\n",
    "exp_name = 'May08'\n",
    "baseline_results_folder = f'./results/{exp_name}'\n",
    "influence_results_dir = f'./influence_results/{exp_name}'\n",
    "\n",
    "# Figures are saved here; create the folder if it does not exist yet\n",
    "output_folder = './experiment_figs'\n",
    "os.makedirs(output_folder, exist_ok=True)\n",
    "\n",
    "# Datasets, models and layer counts used for the baseline figure\n",
    "datasets = ['paris', 'shanghai', 'la', 'london', 'cora']\n",
    "models = ['mlp', 'cheb', 'gcn', 'sage', 'sgformer']\n",
    "n_layers_list = [2, 4, 8, 16]\n",
    "\n",
    "# Display names for titles and legends, keyed by dataset / model identifier\n",
    "d_labels = {\n",
    "    # datasets\n",
    "    'paris': 'Paris',\n",
    "    'shanghai': 'Shanghai',\n",
    "    'la': 'Los Angeles',\n",
    "    'london': 'London',\n",
    "    'cora': 'Cora',\n",
    "    'citeseer': 'CiteSeer',\n",
    "    'ogbn-arxiv': 'ogbn-arxiv',\n",
    "    'PascalVOC-SP': 'PascalVOC-Transductive',\n",
    "    'COCO-SP': 'COCO-Transductive',\n",
    "    # models\n",
    "    'mlp': 'MLP',\n",
    "    'cheb': 'ChebNet',\n",
    "    'gcn': 'GCN',\n",
    "    'gat': 'GAT',\n",
    "    'sage': 'GraphSAGE',\n",
    "    'sgformer': 'SGFormer',\n",
    "}\n",
    "\n",
    "# Line colour per model (second hex value is the one noted next to it originally)\n",
    "d_color = {\n",
    "    'sgformer': '#c5373e',  # red, rgb(197, 55, 62); also noted: #9c251c\n",
    "    'gcn': '#006eae',  # blue, rgb(0, 110, 174); also noted: #00498d\n",
    "    'sage': '#439130',  # green, rgb(67, 145, 48); also noted: #1c6e2b\n",
    "    'cheb': '#e96a00',  # orange, rgb(233, 106, 0); also noted: #b34900\n",
    "    'mlp': '#734e3e',  # brown, rgb(115, 78, 62)\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Load Baseline Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline numbers live in results[dataset][model]; each leaf starts as an empty dict\n",
    "# and is later filled with {n_layers: [mean, std]}\n",
    "results = {d_name: {m_name: {} for m_name in models} for d_name in datasets}\n",
    "\n",
    "# Function to parse the final test results from a file\n",
    "def parse_results(file_path):\n",
    "    \"\"\"Collect the final test accuracy of every run stored in a result folder.\n",
    "\n",
    "    Every ``*.json`` file directly inside ``file_path`` is read. The layer\n",
    "    count is the integer that follows ``nlayers-`` in the file name, up to\n",
    "    the next underscore.\n",
    "\n",
    "    Args:\n",
    "        file_path: Folder holding the JSON result files.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping ``n_layers`` (int) to ``[mean, std]`` as read from the\n",
    "        ``mean_test_acc`` and ``std_test_acc`` entries. Files without a\n",
    "        ``mean_test_acc`` value are left out; ``std`` is ``None`` when the\n",
    "        file has no ``std_test_acc`` entry.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If a file name has no integer after ``nlayers-``.\n",
    "    \"\"\"\n",
    "    results_ = {}\n",
    "    for name in glob.glob(f'{file_path}/*.json'):\n",
    "        # Parse the tag from the file name only, so that 'nlayers-' or an\n",
    "        # underscore in the folder part of the path cannot disturb it.\n",
    "        layer_tag = os.path.basename(name).split('nlayers-')[-1].split('_')[0]\n",
    "        n_layers = int(layer_tag)\n",
    "        with open(name, 'r') as file:\n",
    "            result_single = json.load(file)\n",
    "        mean = result_single.get('mean_test_acc')\n",
    "        std = result_single.get('std_test_acc')\n",
    "        # Compare against None: a plain truth test would also discard a\n",
    "        # legitimate accuracy of 0.\n",
    "        if mean is not None:\n",
    "            results_[n_layers] = [mean, std]\n",
    "    return results_\n",
    "\n",
    "# Fill `results` from every '<dataset>_<model>' folder that exists on disk\n",
    "for dataset in datasets:\n",
    "    for model in models:\n",
    "        file_folder = f\"{baseline_results_folder}/{dataset}_{model}\"\n",
    "        if not os.path.exists(file_folder):\n",
    "            continue\n",
    "        results[dataset][model].update(parse_results(file_folder))\n",
    "        # Rebuild the dict so its entries are in ascending layer order\n",
    "        snapshot = deepcopy(results[dataset][model])\n",
    "        results[dataset][model] = dict(sorted(snapshot.items()))\n",
    "\n",
    "# Print everything that was loaded, grouped by dataset\n",
    "for dataset_name, per_model in results.items():\n",
    "    print(f\"results for {dataset_name}\")\n",
    "    for model_name, per_depth in per_model.items():\n",
    "        print(f\"{model_name}: {per_depth}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Plot Baseline Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tick_size = 24\n",
    "n_panels = len(datasets)\n",
    "\n",
    "# One panel per dataset, 5 inches wide each (25 x 6 for five datasets).\n",
    "fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 6), sharey=False)\n",
    "axes = np.atleast_1d(axes)  # a single panel comes back as a bare Axes, not an array\n",
    "\n",
    "for idx, dataset in enumerate(datasets):\n",
    "    ax = axes[idx]\n",
    "    for model in models:\n",
    "        per_depth = results[dataset][model]  # {n_layers: [mean, std]}\n",
    "        if not per_depth:\n",
    "            continue  # nothing was loaded for this dataset/model pair\n",
    "\n",
    "        # Hand matplotlib plain lists rather than dict views.\n",
    "        n_layers_plot = list(per_depth.keys())\n",
    "        means = [entry[0] for entry in per_depth.values()]\n",
    "        # A missing std is drawn as a zero-width band instead of raising on `m - None`.\n",
    "        stds = [entry[1] if entry[1] is not None else 0.0 for entry in per_depth.values()]\n",
    "        # Plain arithmetic: a truth test on the mean would turn an accuracy of 0 into None.\n",
    "        lower_bound = [m - s for m, s in zip(means, stds)]\n",
    "        upper_bound = [m + s for m, s in zip(means, stds)]\n",
    "\n",
    "        ax.plot(n_layers_plot, means, '.-', label=d_labels[model],\n",
    "                color=d_color[model], markersize=20)  # mean line\n",
    "        ax.fill_between(n_layers_plot, lower_bound, upper_bound,\n",
    "                        color=d_color[model], alpha=0.2)  # +/- one std band\n",
    "\n",
    "    # Log-2 x axis with one labelled tick per evaluated depth\n",
    "    ax.set_xscale('log', base=2)\n",
    "    ax.set_xticks(n_layers_list)\n",
    "    ax.tick_params(axis='both', labelsize=tick_size)\n",
    "    ax.set_xticklabels([str(x) for x in n_layers_list])\n",
    "    ax.grid(True, which='major', axis='both', color='gray', alpha=0.5, linestyle='--', linewidth=0.7)\n",
    "\n",
    "    # Same y range on every panel; as before, the last panel keeps automatic y ticks.\n",
    "    ax.set_ylim(10, 85)\n",
    "    if idx < n_panels - 1:\n",
    "        ax.set_yticks([20, 40, 60, 80])\n",
    "\n",
    "    ax.set_title(d_labels[dataset], fontsize=tick_size+4)\n",
    "    ax.set_xlabel('# Layers', fontsize=tick_size+2)\n",
    "    if idx == 0:\n",
    "        ax.set_ylabel('Test Acc', fontsize=tick_size+4)\n",
    "\n",
    "# Build the shared legend from every panel, so a model missing from one dataset\n",
    "# still gets an entry; entries follow the order of `models`.\n",
    "legend_entries = {}\n",
    "for panel in axes:\n",
    "    for handle, label in zip(*panel.get_legend_handles_labels()):\n",
    "        legend_entries.setdefault(label, handle)\n",
    "labels = [d_labels[model] for model in models if d_labels[model] in legend_entries]\n",
    "handles = [legend_entries[label] for label in labels]\n",
    "\n",
    "# Global legend placed below the panels\n",
    "fig.legend(handles=handles, labels=labels, loc='upper center', bbox_to_anchor=(0.5, 0.02),\n",
    "           frameon=True, fancybox=False, shadow=False, fontsize=tick_size-1, ncol=8, markerscale=1)\n",
    "\n",
    "fix_style('article')\n",
    "plt.savefig(f\"{output_folder}/{exp_name}_baseline_results.jpg\", bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Load Influence Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings for the influence figure; these replace the baseline-plot values of\n",
    "# `datasets`, `models`, `d_color` and `results`\n",
    "datasets = ['paris', 'shanghai', 'la', 'london', 'cora', 'citeseer']\n",
    "models = ['gcn', 'cheb', 'sage', 'sgformer']\n",
    "\n",
    "# Line colour per dataset\n",
    "d_color = {\n",
    "    'paris': '#c5373e',  # red, rgb(197, 55, 62)\n",
    "    'shanghai': '#006eae',  # blue, rgb(0, 110, 174)\n",
    "    'la': '#439130',  # green, rgb(67, 145, 48)\n",
    "    'london': '#e96a00',  # orange, rgb(233, 106, 0)\n",
    "    'cora': '#a54891',  # purple, rgb(165, 72, 145)\n",
    "    'citeseer': '#734e3e',  # brown, rgb(115, 78, 62)\n",
    "    'PascalVOC-SP': '#202124',  # near black, rgb(32, 33, 36)\n",
    "    'COCO-SP': '#6e788e',  # grey, rgb(110, 120, 142)\n",
    "    'ogbn-arxiv': '#dca22c',  # yellow, rgb(220, 162, 44)\n",
    "}\n",
    "\n",
    "# Marker + line format per dataset; datasets in the same group share a marker\n",
    "dot_type_dict = {\n",
    "    name: fmt\n",
    "    for fmt, names in (\n",
    "        ('o-', ('paris', 'shanghai', 'la', 'london')),\n",
    "        ('s-', ('cora', 'citeseer', 'ogbn-arxiv')),\n",
    "        ('v-', ('PascalVOC-SP', 'COCO-SP')),\n",
    "    )\n",
    "    for name in names\n",
    "}\n",
    "\n",
    "# results[model][dataset] holds the two arrays loaded for that pair\n",
    "results = {model: {} for model in models}\n",
    "\n",
    "for dataset in datasets:\n",
    "    for model in models:\n",
    "        prefix = os.path.join(influence_results_dir, f\"{dataset}_{model}\")\n",
    "        inf_path = f\"{prefix}_avg_tot_inf.npy\"\n",
    "        # Only the avg_tot_inf file decides whether a pair is loaded\n",
    "        if not os.path.exists(inf_path):\n",
    "            continue\n",
    "        entry = results[model][dataset] = {}\n",
    "        entry['avg_tot_inf'] = np.load(inf_path)\n",
    "        entry['R'] = np.load(f\"{prefix}_R.npy\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Plot Influence Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tick_size = 20\n",
    "R_dict = {model: {} for model in models}\n",
    "\n",
    "# One panel per model, 5 inches wide each (20 x 5 for four models).\n",
    "n_panels = len(models)\n",
    "fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5), sharey=False)\n",
    "axes = np.atleast_1d(axes)  # a single panel comes back as a bare Axes, not an array\n",
    "\n",
    "for i, model in enumerate(models):\n",
    "    ax = axes[i]\n",
    "    for dataset in datasets:\n",
    "        if dataset not in results[model]:\n",
    "            continue  # no influence files were loaded for this pair\n",
    "        entry = results[model][dataset]\n",
    "        norm_avg_tot_inf = entry['avg_tot_inf']\n",
    "        # Keep the rounded R values so the next cell can print them\n",
    "        R_dict[model][dataset] = entry['R'].round(3)\n",
    "        x = np.arange(len(norm_avg_tot_inf))  # one x position per entry: 0, 1, 2, ...\n",
    "        ax.plot(x, norm_avg_tot_inf, dot_type_dict[dataset],\n",
    "                label=d_labels[dataset], color=d_color[dataset], markersize=8)\n",
    "\n",
    "    ax.set_title(d_labels[model], fontsize=tick_size+2)\n",
    "    ax.tick_params(axis='both', labelsize=tick_size)\n",
    "    ax.set_xlabel(r\"Hop $h$\", fontsize=tick_size+2)\n",
    "    ax.set_yscale('log', base=10)\n",
    "    if i == 0:\n",
    "        ax.set_ylabel(r\"$\\bar{T}_h$ / $\\bar{T}_0$\", fontsize=tick_size+3)\n",
    "\n",
    "# Build the shared legend from every panel (not just one), so a dataset that is\n",
    "# missing for one model still gets an entry; entries follow the order of `datasets`.\n",
    "legend_entries = {}\n",
    "for panel in axes:\n",
    "    for handle, label in zip(*panel.get_legend_handles_labels()):\n",
    "        legend_entries.setdefault(label, handle)\n",
    "labels = [d_labels[dataset] for dataset in datasets if d_labels[dataset] in legend_entries]\n",
    "handles = [legend_entries[label] for label in labels]\n",
    "\n",
    "# Global legend placed below the panels\n",
    "fig.legend(handles=handles, labels=labels, loc='upper center', bbox_to_anchor=(0.5, 0.02),\n",
    "           frameon=True, fancybox=False, shadow=False, fontsize=tick_size-3, ncol=9, markerscale=1)\n",
    "\n",
    "fix_style('article')\n",
    "plt.savefig(f\"{output_folder}/{exp_name}_influence.jpg\", bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Influence-weighted receptive field R: one printed line per model\n",
    "for model_name in R_dict:\n",
    "    print(model_name, R_dict[model_name])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Calculation of the Lower Bound in Section 5\n",
    "\n",
    "Here we compute the lower bound of Equation (6) in Theorem 5.1, which bounds the second-largest positive eigenvalue of the normalized adjacency operator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Maximum degree per graph, keyed by a short dataset name\n",
    "max_degree = dict(\n",
    "    paris=15, shanghai=8, la=9, london=10,\n",
    "    pascal=10, coco=10,\n",
    "    cora=168, citeseer=99, ogbn=13000,\n",
    ")\n",
    "# Diameter per graph, using the same keys as `max_degree`\n",
    "diameter = dict(\n",
    "    paris=121, shanghai=123, la=171, london=404,\n",
    "    pascal=28, coco=27,\n",
    "    cora=19, citeseer=28, ogbn=25,\n",
    ")\n",
    "\n",
    "def get_lower_bound(max_degree, diameter):\n",
    "    \"\"\"Evaluate the lower bound of Equation (6) for one graph.\n",
    "\n",
    "    Args:\n",
    "        max_degree: Maximum degree of the graph.\n",
    "        diameter: Diameter of the graph.\n",
    "\n",
    "    Returns:\n",
    "        ``a - b * (1 + a)`` with ``a = 2 * (max_degree - 1) ** 0.5 / max_degree``\n",
    "        and ``b = 2 / diameter``.\n",
    "    \"\"\"\n",
    "    degree_term = 2 * (max_degree - 1) ** 0.5 / max_degree\n",
    "    diameter_term = 2 / diameter\n",
    "    return degree_term - diameter_term * (1 + degree_term)\n",
    "\n",
    "for name, degree in max_degree.items():\n",
    "    bound = get_lower_bound(degree, diameter[name])\n",
    "    print(f\"{name}: {bound:.4f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
