{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e416144",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from rae import PROJECT_ROOT\n",
    "import math"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50808ba9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def encoder_factory(encoder_type, num_layers: int, in_channels: int, out_channels: int, **params):\n",
    "    assert num_layers > 0\n",
    "    if encoder_type == \"GCN2Conv\":\n",
    "        convs = []\n",
    "        for layer in range(num_layers):\n",
    "            convs.append(GCN2Conv(layer=layer + 1, channels=out_channels, **params))\n",
    "        return nn.ModuleList(convs)\n",
    "\n",
    "    elif encoder_type == \"GCNConv\":\n",
    "        convs = []\n",
    "        convs = [\n",
    "            GCNConv(\n",
    "                in_channels=in_channels,\n",
    "                out_channels=out_channels,\n",
    "                **params,\n",
    "            )\n",
    "        ]\n",
    "        in_channels = out_channels\n",
    "        for layer in range(num_layers - 1):\n",
    "            convs.append(\n",
    "                GCNConv(\n",
    "                    in_channels=in_channels,\n",
    "                    out_channels=out_channels,\n",
    "                    **params,\n",
    "                )\n",
    "            )\n",
    "        return nn.ModuleList(convs)\n",
    "\n",
    "    elif encoder_type == \"GATConv\":\n",
    "        convs = []\n",
    "        convs = [\n",
    "            GATConv(\n",
    "                in_channels=in_channels,\n",
    "                out_channels=out_channels,\n",
    "                **params,\n",
    "            )\n",
    "        ]\n",
    "        in_channels = out_channels\n",
    "        for layer in range(num_layers - 1):\n",
    "            convs.append(\n",
    "                GATConv(\n",
    "                    in_channels=in_channels,\n",
    "                    out_channels=out_channels,\n",
    "                    **params,\n",
    "                )\n",
    "            )\n",
    "\n",
    "        return nn.ModuleList(convs)\n",
    "\n",
    "    elif encoder_type == \"GINConv\":\n",
    "        convs = []\n",
    "        current_in_channels = in_channels\n",
    "        for layer in range(num_layers):\n",
    "            convs.append(\n",
    "                GINConv(\n",
    "                    nn=nn.Linear(\n",
    "                        in_features=current_in_channels,\n",
    "                        out_features=out_channels,\n",
    "                    ),\n",
    "                    **params,\n",
    "                )\n",
    "            )\n",
    "            current_in_channels = out_channels\n",
    "        return nn.ModuleList(convs)\n",
    "\n",
    "    else:\n",
    "        raise NotImplementedError"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d76c519",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7873038",
   "metadata": {},
   "outputs": [],
   "source": [
    "experiments = torch.load(PROJECT_ROOT / \"experiments\" / \"sec:data-manifold\" / f\"{'Cora'}_data_manifold_experiments.pt\")\n",
    "len(experiments)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87271d68",
   "metadata": {},
   "outputs": [],
   "source": [
    "stats = pd.read_csv(\n",
    "    PROJECT_ROOT / \"experiments\" / \"sec:data-manifold\" / f\"{'Cora'}_data_manifold_stats.tsv\", sep=\"\\t\", index_col=0\n",
    ")\n",
    "stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3274863",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c94f5bf9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Filter experiments that reach at least 0.7 acc.\n",
    "VAL_ACC_LOWER_BOUND = 0.5\n",
    "\n",
    "df_max_acc = stats.groupby([\"experiment\"]).agg([np.max])[\"val_acc\"]\n",
    "best_experiments = df_max_acc.loc[df_max_acc[\"amax\"] > VAL_ACC_LOWER_BOUND]\n",
    "best_experiments = best_experiments.reset_index().experiment\n",
    "df_filtered = stats[stats[\"experiment\"].isin(best_experiments)]\n",
    "df_filtered, len(set(df_filtered.experiment))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7000ed53",
   "metadata": {},
   "outputs": [],
   "source": [
    "experiments_valacc_similarity_correlation = []\n",
    "for exp in set(stats.experiment):\n",
    "    d_exp = df_filtered.loc[stats[\"experiment\"] == exp]\n",
    "    exp_corr = d_exp.corr(method=\"pearson\")\n",
    "    corr = exp_corr[\"val_acc\"][\"reference_distance\"]\n",
    "    if not math.isnan(corr):\n",
    "        experiments_valacc_similarity_correlation.append(corr)\n",
    "p_corr = np.mean(experiments_valacc_similarity_correlation)\n",
    "\n",
    "print(\"Pearson correlation val_acc - ref_similarity: \", p_corr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9e3ff3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "best_run = torch.load(PROJECT_ROOT / \"experiments\" / \"sec:data-manifold\" / f\"{'Cora'}_best_run.pt\")\n",
    "best_run_latents = [best_run[\"best_epoch\"][\"rel_x\"]]\n",
    "best_run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2aa56bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import *\n",
    "\n",
    "\n",
    "def get_distance(latents1: torch.Tensor, latents_ref: Sequence[torch.Tensor]):\n",
    "    assert not isinstance(latents_ref, (np.ndarray, torch.Tensor))\n",
    "    dists = [F.cosine_similarity(latents1, latent_ref).mean().item() for latent_ref in latents_ref]\n",
    "    return np.mean(dists)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1dcf8dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "filtered_experiments = [\n",
    "    x for x in experiments if not math.isnan(x[\"best_epoch\"][\"loss\"]) or not np.isnan(x[\"best_epoch\"][\"rel_x\"]).any()\n",
    "]\n",
    "len(filtered_experiments)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d645865e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "points = {\"score\": [], \"similarity\": [], \"hyperparams\": [], \"optimizer\": [], \"encoder\": [], \"color\": []}\n",
    "for run in filtered_experiments:\n",
    "    distance = get_distance(latents1=F.normalize(run[\"best_epoch\"][\"rel_x\"], dim=-1, p=2), latents_ref=best_run_latents)\n",
    "    if np.isnan(distance):\n",
    "        continue\n",
    "    score = run[\"best_epoch\"][\"val_acc\"]\n",
    "    points[\"score\"].append(score)\n",
    "    points[\"similarity\"].append(distance)\n",
    "    hyperparams = {}\n",
    "    for key in (\"encoder\", \"lr\", \"hidden_fn\", \"conv_fn\", \"optimizer\"):\n",
    "        run_value = run[key]\n",
    "        if key == \"encoder\":\n",
    "            run_value = run_value[0]\n",
    "        elif \"_fn\" in key:\n",
    "            run_value = type(run_value).__name__\n",
    "        elif key == \"optimizer\":\n",
    "            run_value = run_value.__name__\n",
    "\n",
    "        hyperparams[key] = run_value\n",
    "    points[\"optimizer\"].append(hyperparams[\"optimizer\"])\n",
    "    points[\"encoder\"].append(hyperparams[\"encoder\"])\n",
    "    points[\"hyperparams\"].append(json.dumps(hyperparams))\n",
    "    points[\"color\"].append(f'{points[\"optimizer\"]}_{points[\"encoder\"]}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cb9417b",
   "metadata": {},
   "outputs": [],
   "source": [
    "max(points[\"similarity\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5eb3b540",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "from tueplots import bundles\n",
    "from tueplots import figsizes\n",
    "\n",
    "\n",
    "N_ROWS = 1\n",
    "N_COLS = 1\n",
    "RATIO = 1\n",
    "\n",
    "plt.rcParams.update(bundles.iclr2023(usetex=True))\n",
    "plt.rcParams.update(figsizes.iclr2023(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=RATIO))\n",
    "\n",
    "fig, ax = plt.subplots(nrows=N_ROWS, ncols=N_COLS, dpi=150)\n",
    "\n",
    "\n",
    "def plot_points(ax, pts, s=5):\n",
    "    df = pd.DataFrame(pts)\n",
    "    ax.set_aspect(\"auto\")\n",
    "\n",
    "    ax.scatter(df.similarity, df.score, s=s)\n",
    "\n",
    "    z = np.polyfit(df.similarity, df.score, 1)\n",
    "    trend_line = np.poly1d(z)\n",
    "    ax.plot(np.asarray(sorted(df.similarity)), trend_line(sorted(df.similarity)), \"C3--\")\n",
    "\n",
    "\n",
    "#     ax.set_xlabel('Similarity')\n",
    "#     ax.set_ylabel('Score')\n",
    "\n",
    "plot_points(ax, points)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "293dbadf",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "034b1d97",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig.savefig(\"score_vs_distance.svg\", bbox_inches=\"tight\", pad_inches=0)\n",
    "!rsvg-convert -f pdf -o score_vs_distance.pdf score_vs_distance.svg\n",
    "!rm score_vs_distance.svg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc64926f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly.express as px\n",
    "from plotly.subplots import make_subplots\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4ab6690",
   "metadata": {},
   "source": [
    "# Correlation over time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca0db554",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_lightning import seed_everything\n",
    "import random\n",
    "\n",
    "\n",
    "# Filter experiments that reach at least 0.7 acc.\n",
    "VAL_ACC_LOWER_BOUND = 0.9\n",
    "\n",
    "df_max_acc = stats.groupby([\"experiment\"]).agg([np.max])[\"val_acc\"]\n",
    "best_experiments = df_max_acc.loc[df_max_acc[\"amax\"] > VAL_ACC_LOWER_BOUND]\n",
    "best_experiments = best_experiments.reset_index().experiment\n",
    "df = stats[stats[\"experiment\"].isin(best_experiments)]\n",
    "available_experiments = sorted(set(df.experiment))\n",
    "\n",
    "\n",
    "N_ROWS = 1\n",
    "N_COLS = 1\n",
    "RATIO = 1\n",
    "\n",
    "plt.rcParams.update(bundles.iclr2023(usetex=True))\n",
    "plt.rcParams.update(figsizes.iclr2023(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=RATIO))\n",
    "\n",
    "\n",
    "def plot_score_dist_over_time(ax, df):\n",
    "    ax.set_aspect(\"auto\")\n",
    "    ax2 = ax.twinx()\n",
    "    ax.plot(df.epoch, df.val_acc, \"C0-\")\n",
    "    # ax.set_ylabel(\"Validation Accuracy  \", color=\"C0\")\n",
    "\n",
    "    ax2.plot(df.epoch, df.reference_distance, \"C1-\")\n",
    "    # ax2.set_ylabel(\"Reference similarity\", color=\"C1\")\n",
    "\n",
    "\n",
    "#     ax.set_xlabel(\"epochs\")\n",
    "\n",
    "\n",
    "fig, axes = plt.subplots(nrows=N_ROWS, ncols=N_COLS, dpi=150)\n",
    "plot_score_dist_over_time(axes, df.loc[df[\"experiment\"] == available_experiments[5]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ba280b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig.savefig(\"correlation_over_time.svg\", bbox_inches=\"tight\", pad_inches=0)\n",
    "!rsvg-convert -f pdf -o correlation_over_time.pdf correlation_over_time.svg\n",
    "!rm correlation_over_time.svg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbd656e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot both figures!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3124a56c",
   "metadata": {},
   "outputs": [],
   "source": [
    "N_ROWS = 1\n",
    "N_COLS = 2\n",
    "RATIO = 0.8\n",
    "\n",
    "plt.rcParams.update(bundles.iclr2023(usetex=True))\n",
    "plt.rcParams.update(figsizes.iclr2023(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=RATIO))\n",
    "\n",
    "fig, [col1, col2] = plt.subplots(nrows=N_ROWS, ncols=N_COLS, dpi=150)\n",
    "\n",
    "plot_points(col1, points, s=1)\n",
    "plot_score_dist_over_time(col2, df.loc[df[\"experiment\"] == available_experiments[5]])\n",
    "plt.subplots_adjust(wspace=0.4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c7d88e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig.savefig(\"correlation_subfigure.svg\", bbox_inches=\"tight\", pad_inches=0)\n",
    "!rsvg-convert -f pdf -o correlation_subfigure.pdf correlation_subfigure.svg\n",
    "!rm correlation_subfigure.svg"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5d4b1ec",
   "metadata": {},
   "source": [
    "# Correlation grid (supmat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "194dd6a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Filter experiments that reach at least 0.7 acc.\n",
    "VAL_ACC_LOWER_BOUND = 0.5\n",
    "\n",
    "df_max_acc = stats.groupby([\"experiment\"]).agg([np.max])[\"val_acc\"]\n",
    "best_experiments = df_max_acc.loc[df_max_acc[\"amax\"] > VAL_ACC_LOWER_BOUND]\n",
    "best_experiments = best_experiments.reset_index().experiment\n",
    "df = stats[stats[\"experiment\"].isin(best_experiments)]\n",
    "available_experiments = sorted(set(df.experiment))\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "581b5164",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_lightning import seed_everything\n",
    "import random\n",
    "\n",
    "seed_everything(0)\n",
    "random.shuffle(available_experiments)\n",
    "\n",
    "N_ROWS = 10\n",
    "N_COLS = 10\n",
    "RATIO = 1\n",
    "\n",
    "plt.rcParams.update(bundles.iclr2023(usetex=True))\n",
    "plt.rcParams.update(figsizes.iclr2023(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=RATIO))\n",
    "\n",
    "\n",
    "fig, axes = plt.subplots(nrows=N_ROWS, ncols=N_COLS, dpi=200, figsize=(15, 15))\n",
    "\n",
    "\n",
    "def plot_score_dist_over_time(ax, df):\n",
    "    ax2 = ax.twinx()\n",
    "    ax.plot(df.epoch, df.val_acc, \"C0-\")\n",
    "    #     ax.set_ylabel('Validation Accuracy  ', color='C0')\n",
    "\n",
    "    ax2.plot(df.epoch, df.reference_distance, \"C1-\")\n",
    "    #         ax2.set_ylabel('Reference similarity', color='C1')\n",
    "\n",
    "    ax.set_yticklabels([])\n",
    "    ax2.set_yticklabels([])\n",
    "    ax.set_xticklabels([])\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    ax2.set_yticks([])\n",
    "    ax.set_aspect(\"auto\")\n",
    "\n",
    "\n",
    "i = 0\n",
    "for row in axes:\n",
    "    for ax in row:\n",
    "        df_plot = df.loc[df[\"experiment\"] == available_experiments[i]]\n",
    "        plot_score_dist_over_time(ax, df_plot)\n",
    "        i += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d51322a",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig.savefig(\"correlation_grid.svg\", bbox_inches=\"tight\", pad_inches=0)\n",
    "!rsvg-convert -f pdf -o correlation_grid.pdf correlation_grid.svg\n",
    "!rm correlation_grid.svg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13696d84",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
