{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c5e36e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import numpy as np\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.decomposition import PCA\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.image as mpimg\n",
    "from matplotlib.font_manager import FontProperties\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "from matplotlib.patches import Patch\n",
    "import pandas as pd\n",
    "import itertools\n",
    "from scipy.stats import gaussian_kde\n",
    "from pathlib import Path\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70c2003c",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# List of filenames\n",
    "file_names = [\n",
    "    \"ground_truth_dictionary_subfield.pkl\",\n",
    "    \"generated_dictionary_subfield.pkl\",\n",
    "    \"random_dictionary_subfield.pkl\"\n",
    "]\n",
    "\n",
    "def inspect_pickle(file_path, file_label):\n",
    "    with open(file_path, 'rb') as f:\n",
    "        data = pickle.load(f)\n",
    "\n",
    "    data_type = type(data)\n",
    "    num_entries = len(data)\n",
    "    \n",
    "    first_key = list(data.keys())[10]\n",
    "    first_value = data[first_key]\n",
    "\n",
    "    value_type = type(first_value)\n",
    "    top_keys = list(first_value.keys())\n",
    "\n",
    "    focal_embedding = first_value.get('focal_embedding')\n",
    "    if focal_embedding is not None:\n",
    "        focal_shape = focal_embedding.shape\n",
    "        focal_type = type(focal_embedding)\n",
    "\n",
    "    ref_embeddings = first_value.get('reference_embeddings')\n",
    "    if ref_embeddings is not None:\n",
    "        num_refs = len(ref_embeddings)\n",
    "        sample_ref_key = list(ref_embeddings.keys())[0]\n",
    "        ref_shape = ref_embeddings[sample_ref_key].shape\n",
    "        ref_type = type(ref_embeddings[sample_ref_key])\n",
    "\n",
    "# Run the structural inspection once per dataset pickle\n",
    "for name in file_names:\n",
    "    inspect_pickle(os.path.join(base_path, name), name)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a940b3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Dataset keys (same order as `file_names`), their display names and plot colours\n",
    "labels = [\"ground_truth\", \"generated\", \"random\"]\n",
    "label_map = {\n",
    "    \"ground_truth\": \"Ground truth\",\n",
    "    \"generated\": \"Generated\",\n",
    "    \"random\": \"Random\"\n",
    "}\n",
    "colors = {\n",
    "    'ground_truth': '#76C1FA',\n",
    "    'generated':    '#F78FB3',\n",
    "    'random':       '#A8E6CF'\n",
    "}\n",
    "\n",
    "# One row per graph: the mean of its focal embedding and all of its reference embeddings\n",
    "mean_embeddings = []\n",
    "class_labels = []\n",
    "\n",
    "for file_name, label in zip(file_names, labels):\n",
    "    file_path = os.path.join(base_path, file_name)\n",
    "    # NOTE: pickle.load can execute arbitrary code; only open trusted files.\n",
    "    with open(file_path, 'rb') as f:\n",
    "        data = pickle.load(f)\n",
    "\n",
    "    for graph_id, graph_data in data.items():\n",
    "        embeddings = []\n",
    "\n",
    "        if 'focal_embedding' in graph_data:\n",
    "            embeddings.append(np.asarray(graph_data['focal_embedding']).reshape(-1))\n",
    "\n",
    "        if 'reference_embeddings' in graph_data:\n",
    "            for ref in graph_data['reference_embeddings'].values():\n",
    "                embeddings.append(np.asarray(ref).reshape(-1))\n",
    "\n",
    "        # A graph without any embedding contributes no point\n",
    "        if len(embeddings) == 0:\n",
    "            continue\n",
    "\n",
    "        mean_embeddings.append(np.mean(embeddings, axis=0))\n",
    "        class_labels.append(label)\n",
    "\n",
    "mean_embeddings = np.vstack(mean_embeddings)\n",
    "\n",
    "# Project all graphs of all datasets into one shared 2-D PCA space. random_state is\n",
    "# pinned because PCA's automatic solver choice may be a randomised one, which would\n",
    "# make the coordinates differ between runs.\n",
    "pca = PCA(n_components=2, random_state=0)\n",
    "mean_emb_pca = pca.fit_transform(mean_embeddings)\n",
    "\n",
    "pca_embeddings = mean_emb_pca\n",
    "mean_labels = np.array(class_labels)\n",
    "\n",
    "# Evaluation grid for the density contours: fixed x-range, data-driven y-range with 5% padding\n",
    "x_min, x_max = -0.35, 0.65\n",
    "y_vals = pca_embeddings[:, 1]\n",
    "y_pad = 0.05 * (y_vals.max() - y_vals.min() + 1e-9)\n",
    "y_min, y_max = y_vals.min() - y_pad, y_vals.max() + y_pad\n",
    "\n",
    "nx, ny = 300, 300\n",
    "xx, yy = np.meshgrid(\n",
    "    np.linspace(x_min, x_max, nx),\n",
    "    np.linspace(y_min, y_max, ny)\n",
    ")\n",
    "grid_positions = np.vstack([xx.ravel(), yy.ravel()])\n",
    "\n",
    "plt.figure(figsize=(6, 5))\n",
    "\n",
    "for label in labels:\n",
    "    idx = np.where(mean_labels == label)[0]\n",
    "    if idx.size == 0:\n",
    "        continue\n",
    "\n",
    "    class_pca = pca_embeddings[idx].T  # shape (2, n_points)\n",
    "\n",
    "    # Density contours are only drawn for classes with at least three points\n",
    "    if class_pca.shape[1] >= 3:\n",
    "        kde = gaussian_kde(class_pca)\n",
    "        zz = np.reshape(kde(grid_positions).T, xx.shape)\n",
    "\n",
    "        plt.contour(\n",
    "            xx, yy, zz,\n",
    "            levels=10,\n",
    "            linewidths=1.5,\n",
    "            colors=[colors[label]],\n",
    "        )\n",
    "\n",
    "    plt.scatter(\n",
    "        class_pca[0], class_pca[1],\n",
    "        s=8, alpha=0.75,\n",
    "        color=colors[label],\n",
    "        edgecolors='white',\n",
    "        linewidths=0.3,\n",
    "        label=label_map[label]\n",
    "    )\n",
    "\n",
    "plt.xticks(fontsize=9, fontfamily='serif')\n",
    "plt.yticks(fontsize=9, fontfamily='serif')\n",
    "plt.xlabel(\"PCA 1\", fontsize=10)\n",
    "plt.ylabel(\"PCA 2\", fontsize=10)\n",
    "plt.xlim(x_min, x_max)\n",
    "plt.ylim(y_min, y_max)\n",
    "\n",
    "# The size goes inside `prop`: matplotlib ignores a separate `fontsize` argument\n",
    "# whenever `prop` is given, so the former fontsize=8 never took effect.\n",
    "leg = plt.legend(\n",
    "    prop={'family': 'serif', 'size': 8},\n",
    "    markerscale=2.0,\n",
    "    handletextpad=0.6,\n",
    "    borderpad=0.3\n",
    ")\n",
    "leg.get_frame().set_linewidth(0.5)\n",
    "leg.get_frame().set_edgecolor(\"black\")\n",
    "\n",
    "ax = plt.gca()\n",
    "for spine in ax.spines.values():\n",
    "    spine.set_linewidth(0.6)\n",
    "plt.tight_layout()\n",
    "\n",
    "out_path = base_path / \"kde_mean_embeddings_2d_contoursub.png\"\n",
    "plt.savefig(out_path, dpi=600, bbox_inches=\"tight\")\n",
    "\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa09c2a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per graph: mean Euclidean distance between the focal embedding and each reference embedding\n",
    "mean_l2_per_dataset = {lbl: [] for lbl in labels}\n",
    "\n",
    "for file_name, label in zip(file_names, labels):\n",
    "    file_path = os.path.join(base_path, file_name)\n",
    "    with open(file_path, 'rb') as f:\n",
    "        data = pickle.load(f)\n",
    "\n",
    "    dists = []\n",
    "    for graph_id, graph_data in data.items():\n",
    "        # Graphs lacking either kind of embedding are skipped\n",
    "        if 'focal_embedding' not in graph_data or 'reference_embeddings' not in graph_data:\n",
    "            continue\n",
    "\n",
    "        focal = np.asarray(graph_data['focal_embedding']).reshape(-1)\n",
    "        refs = [np.asarray(ref).reshape(-1) for ref in graph_data['reference_embeddings'].values()]\n",
    "        if not refs:\n",
    "            continue\n",
    "\n",
    "        dist_list = [np.linalg.norm(focal - ref) for ref in refs]\n",
    "        dists.append(float(np.mean(dist_list)))\n",
    "\n",
    "    mean_l2_per_dataset[label] = dists\n",
    "\n",
    "# Long-format table for plotting: one row per graph\n",
    "plot_rows = []\n",
    "for lbl in labels:\n",
    "    display_name = label_map[lbl]\n",
    "    plot_rows.extend([(display_name, d) for d in mean_l2_per_dataset[lbl]])\n",
    "\n",
    "df_plot_mean_l2 = pd.DataFrame(plot_rows, columns=[\"Dataset\", \"Mean Euclidean Distance\"])\n",
    "\n",
    "plt.figure(figsize=(7, 3))\n",
    "sns.boxenplot(\n",
    "    x=\"Dataset\",\n",
    "    y=\"Mean Euclidean Distance\",\n",
    "    data=df_plot_mean_l2,\n",
    "    # The palette is positional, so the category order is fixed explicitly; otherwise a\n",
    "    # dataset that yields no rows would shift every following colour by one.\n",
    "    order=[label_map[lbl] for lbl in labels],\n",
    "    palette=[colors[lbl] for lbl in labels],\n",
    "    showfliers=False,\n",
    "    width=0.9\n",
    ")\n",
    "plt.ylim(0.4, 1.4)\n",
    "plt.yticks([0.4, 0.6, 0.8, 1, 1.2, 1.4])\n",
    "\n",
    "plt.xlabel(\"\", fontsize=20, fontfamily='serif')\n",
    "plt.ylabel(\"Mean \\n Euclidean distance\", fontsize=20, fontfamily='serif')\n",
    "plt.xticks(fontsize=20, fontfamily='serif')\n",
    "plt.yticks(fontsize=20, fontfamily='serif')\n",
    "\n",
    "# NOTE(review): tight_layout() runs after savefig(), so it only affects the inline display.\n",
    "# The order is kept because the final composite positions this PNG by hand.\n",
    "out_path = base_path / \"high_quality_graphEuclideansub.png\"\n",
    "plt.savefig(out_path, dpi=600, bbox_inches='tight')\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a59cfb45",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# File name -> short dataset key (later cells read `name_map` as well)\n",
    "name_map = {\n",
    "    \"ground_truth_dictionary_subfield.pkl\": \"ground_truth\",\n",
    "    \"generated_dictionary_subfield.pkl\": \"generated\",\n",
    "    \"random_dictionary_subfield.pkl\": \"random\"\n",
    "}\n",
    "\n",
    "# `label_map`, `labels` and `colors` come from the KDE-contour cell; the identical copy of\n",
    "# `label_map` that used to be redefined here has been dropped.\n",
    "\n",
    "# Per graph: Euclidean distance between the focal embedding and the centroid of its references\n",
    "centroid_l2_per_dataset = {}\n",
    "\n",
    "for name in file_names:\n",
    "    full_path = os.path.join(base_path, name)\n",
    "    with open(full_path, 'rb') as f:\n",
    "        data = pickle.load(f)\n",
    "\n",
    "    dists = []\n",
    "    for graph_content in data.values():\n",
    "        focal = np.asarray(graph_content['focal_embedding']).reshape(-1)\n",
    "        refs = [np.asarray(ref).reshape(-1) for ref in graph_content['reference_embeddings'].values()]\n",
    "\n",
    "        # No references means no centroid to compare against\n",
    "        if len(refs) == 0:\n",
    "            continue\n",
    "\n",
    "        centroid = np.mean(refs, axis=0)\n",
    "        dists.append(float(np.linalg.norm(focal - centroid)))\n",
    "\n",
    "    centroid_l2_per_dataset[name_map[name]] = dists\n",
    "\n",
    "# Long-format table for plotting, in the shared `labels` order\n",
    "plot_data_l2 = []\n",
    "for key in labels:\n",
    "    display_name = label_map[key]\n",
    "    plot_data_l2.extend([(display_name, d) for d in centroid_l2_per_dataset[key]])\n",
    "\n",
    "df_plot_l2 = pd.DataFrame(plot_data_l2, columns=[\"Dataset\", \"Euclidean Distance (focal vs. refs centroid)\"])\n",
    "\n",
    "plt.figure(figsize=(7, 3))\n",
    "sns.boxenplot(\n",
    "    x=\"Dataset\",\n",
    "    y=\"Euclidean Distance (focal vs. refs centroid)\",\n",
    "    data=df_plot_l2,\n",
    "    # The palette is positional, so the category order is fixed explicitly; otherwise a\n",
    "    # dataset that yields no rows would shift every following colour by one.\n",
    "    order=[label_map[lbl] for lbl in labels],\n",
    "    palette=[colors[lbl] for lbl in labels],\n",
    "    showfliers=False\n",
    ")\n",
    "plt.ylim(0.4, 1.4)\n",
    "plt.yticks([0.4, 0.6, 0.8, 1, 1.2, 1.4])\n",
    "plt.xlabel(\"\", fontsize=20, fontfamily='serif')\n",
    "plt.ylabel(\"Euclidean distance \\n to centroid\", fontsize=20, fontfamily='serif')\n",
    "plt.xticks(fontsize=20, fontfamily='serif')\n",
    "plt.yticks(fontsize=20, fontfamily='serif')\n",
    "\n",
    "# NOTE(review): tight_layout() runs after savefig(), so it only affects the inline display.\n",
    "# The order is kept because the final composite positions this PNG by hand.\n",
    "plt.savefig(\n",
    "    base_path / \"high_quality_graph_centroid_l2sub.png\",\n",
    "    dpi=600,\n",
    "    bbox_inches=\"tight\"\n",
    ")\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69d9fe38",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Per graph: mean Euclidean distance over every pair among the focal and reference embeddings\n",
    "pairwise_avg_l2_per_dataset = {}\n",
    "\n",
    "for name in file_names:\n",
    "    full_path = os.path.join(base_path, name)\n",
    "    with open(full_path, 'rb') as f:\n",
    "        data = pickle.load(f)\n",
    "\n",
    "    pairwise_avg_l2 = []\n",
    "    for graph_content in data.values():\n",
    "        focal = np.asarray(graph_content['focal_embedding']).reshape(-1)\n",
    "        refs = [np.asarray(v).reshape(-1) for v in graph_content['reference_embeddings'].values()]\n",
    "\n",
    "        # At least two vectors are needed to form a pair\n",
    "        all_embeddings = [focal] + refs\n",
    "        if len(all_embeddings) < 2:\n",
    "            continue\n",
    "\n",
    "        pairwise_dists = [\n",
    "            np.linalg.norm(v1 - v2)\n",
    "            for v1, v2 in itertools.combinations(all_embeddings, 2)\n",
    "        ]\n",
    "        pairwise_avg_l2.append(float(np.mean(pairwise_dists)))\n",
    "\n",
    "    pairwise_avg_l2_per_dataset[name_map[name]] = pairwise_avg_l2\n",
    "\n",
    "# Long-format table for plotting: one row per graph\n",
    "plot_data_pairwise_l2 = []\n",
    "for key, values in pairwise_avg_l2_per_dataset.items():\n",
    "    display_name = label_map[key]\n",
    "    plot_data_pairwise_l2.extend([(display_name, v) for v in values])\n",
    "\n",
    "df_plot_pairwise_l2 = pd.DataFrame(plot_data_pairwise_l2, columns=[\"Dataset\", \"Average Pairwise Euclidean Distance\"])\n",
    "\n",
    "plt.figure(figsize=(7, 3))\n",
    "sns.boxenplot(\n",
    "    x=\"Dataset\",\n",
    "    y=\"Average Pairwise Euclidean Distance\",\n",
    "    data=df_plot_pairwise_l2,\n",
    "    # The palette follows `labels` while the rows follow `file_names`; fixing the category\n",
    "    # order keeps each dataset on its own colour even if one dataset yields no rows.\n",
    "    order=[label_map[lbl] for lbl in labels],\n",
    "    palette=[colors[lbl] for lbl in labels],\n",
    "    showfliers=False\n",
    ")\n",
    "plt.ylim(0.4, 1.4)\n",
    "plt.yticks([0.4, 0.6, 0.8, 1, 1.2, 1.4])\n",
    "plt.xlabel(\"\", fontsize=20, fontfamily='serif')\n",
    "plt.ylabel(\"Mean pairwise \\n Euclidean distance\", fontsize=20, fontfamily='serif')\n",
    "plt.xticks(fontsize=20, fontfamily='serif')\n",
    "plt.yticks(fontsize=20, fontfamily='serif')\n",
    "\n",
    "# NOTE(review): tight_layout() runs after savefig(), so it only affects the inline display.\n",
    "# The order is kept because the final composite positions this PNG by hand.\n",
    "plt.savefig(base_path / \"high_quality_graph_pairwise_l2sub.png\",\n",
    "            dpi=600, bbox_inches='tight')\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b15de05",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# PCA projection: reuse `mean_emb_pca` from the KDE-contour cell instead of fitting a second\n",
    "# PCA. These panels are later combined with that figure, and an unseeded refit is not\n",
    "# guaranteed to give the same coordinates when PCA picks a randomised solver.\n",
    "\n",
    "# Font configs\n",
    "label_font = {'fontsize': 17, 'fontfamily': 'serif'}\n",
    "tick_fontsize = 17\n",
    "\n",
    "# One panel per dataset, stacked vertically, all on the same axis limits\n",
    "fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(4, 9), dpi=300)\n",
    "\n",
    "for ax, label in zip(axes, labels):\n",
    "    idx = [i for i, l in enumerate(class_labels) if l == label]\n",
    "    x = mean_emb_pca[idx, 0]\n",
    "    y = mean_emb_pca[idx, 1]\n",
    "\n",
    "    ax.scatter(x, y,\n",
    "               alpha=0.75,\n",
    "               s=8,\n",
    "               color=colors[label],\n",
    "               edgecolors='white',\n",
    "               linewidth=0.4,\n",
    "               marker='o')\n",
    "\n",
    "    ax.set_xlabel('', fontdict=label_font)\n",
    "    ax.set_ylabel('', fontdict=label_font)\n",
    "    ax.set_xlim(-0.35, 0.35)\n",
    "    ax.set_xticks([-0.3, -0.1, 0.1, 0.3])\n",
    "    ax.set_ylim(-0.35, 0.35)\n",
    "    ax.set_yticks([-0.3, -0.1, 0.1, 0.3])\n",
    "\n",
    "    for tick_label in ax.get_xticklabels() + ax.get_yticklabels():\n",
    "        tick_label.set_fontfamily('serif')\n",
    "\n",
    "    ax.tick_params(width=0.1, length=1.2)\n",
    "    ax.tick_params(axis='x', labelsize=tick_fontsize, pad=1)\n",
    "    ax.tick_params(axis='y', labelsize=tick_fontsize, pad=1)\n",
    "\n",
    "    for spine in ax.spines.values():\n",
    "        spine.set_linewidth(0.4)\n",
    "\n",
    "# NOTE(review): tight_layout() runs after savefig(), so it only affects the inline display.\n",
    "# The order is kept because the combined figure positions this PNG by hand.\n",
    "plt.savefig(base_path / \"pca_separate_plotssub.png\", dpi=600, bbox_inches='tight')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b513f97e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compose the KDE-contour figure with the per-dataset PCA panels laid over it\n",
    "left_img_path = base_path / \"kde_mean_embeddings_2d_contoursub.png\"\n",
    "right_img_path = base_path / \"pca_separate_plotssub.png\"\n",
    "\n",
    "fig = plt.figure(figsize=(10, 4.7), dpi=600)\n",
    "\n",
    "# (image file, [left, bottom, width, height] in figure fractions), in creation order\n",
    "panel_layout = (\n",
    "    (left_img_path, [0.03, 0.05, 0.43, 0.9]),\n",
    "    (right_img_path, [0.185, 0.286, 0.43, 0.424]),\n",
    ")\n",
    "for image_path, rect in panel_layout:\n",
    "    panel_ax = fig.add_axes(rect)\n",
    "    panel_ax.imshow(mpimg.imread(image_path))\n",
    "    panel_ax.axis('off')\n",
    "\n",
    "plt.savefig(base_path / \"pca_combined_finalsub.png\", bbox_inches='tight')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94662776",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_cosine_similarity(vec1, vec2):\n",
    "    if np.linalg.norm(vec1) == 0 or np.linalg.norm(vec2) == 0:\n",
    "        return np.nan\n",
    "    return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))\n",
    "\n",
    "# `label_map`, `labels` and `colors` come from the KDE-contour cell; the identical copy of\n",
    "# `label_map` that used to be redefined here has been dropped.\n",
    "\n",
    "# Per graph: mean cosine similarity between the focal embedding and each reference embedding\n",
    "avg_similarities_per_dataset = {}\n",
    "\n",
    "for name in file_names:\n",
    "    full_path = os.path.join(base_path, name)\n",
    "    with open(full_path, 'rb') as f:\n",
    "        data = pickle.load(f)\n",
    "\n",
    "    avg_sims = []\n",
    "    for graph_id, graph_content in data.items():\n",
    "        # Flatten with np.asarray(...).reshape(-1), the same way the Euclidean cells do,\n",
    "        # instead of a shape-dependent squeeze that needs a `.shape` attribute.\n",
    "        focal_embedding = np.asarray(graph_content['focal_embedding']).reshape(-1)\n",
    "        ref_embeddings = graph_content['reference_embeddings']\n",
    "\n",
    "        similarities = []\n",
    "        for node_emb in ref_embeddings.values():\n",
    "            node_emb = np.asarray(node_emb).reshape(-1)\n",
    "            sim = compute_cosine_similarity(focal_embedding, node_emb)\n",
    "            # Zero-norm vectors give NaN and are left out of the mean\n",
    "            if not np.isnan(sim):\n",
    "                similarities.append(sim)\n",
    "\n",
    "        avg_sims.append(np.mean(similarities) if similarities else np.nan)\n",
    "\n",
    "    avg_similarities_per_dataset[name_map[name]] = avg_sims\n",
    "\n",
    "# Long-format table for plotting; graphs without a valid similarity are dropped\n",
    "plot_data = []\n",
    "for dataset, avg_sims in avg_similarities_per_dataset.items():\n",
    "    display_name = label_map[dataset]\n",
    "    plot_data.extend([(display_name, sim) for sim in avg_sims if not np.isnan(sim)])\n",
    "\n",
    "df_plot = pd.DataFrame(plot_data, columns=[\"Dataset\", \"Average Cosine Similarity\"])\n",
    "\n",
    "plt.figure(figsize=(7, 3))\n",
    "sns.boxenplot(\n",
    "    x=\"Dataset\",\n",
    "    y=\"Average Cosine Similarity\",\n",
    "    data=df_plot,\n",
    "    # The palette follows `labels` while the rows follow `file_names`; fixing the category\n",
    "    # order keeps each dataset on its own colour even if one dataset yields no rows.\n",
    "    order=[label_map[lbl] for lbl in labels],\n",
    "    palette=[colors[lbl] for lbl in labels],\n",
    "    showfliers=False\n",
    ")\n",
    "plt.ylim(0.0, 1)\n",
    "plt.yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1])\n",
    "\n",
    "plt.xlabel(\"\", fontsize=20, fontfamily='serif')\n",
    "plt.ylabel(\"Mean \\n cosine similarity\", fontsize=20, fontfamily='serif')\n",
    "plt.xticks(fontsize=20, fontfamily='serif')\n",
    "plt.yticks(fontsize=20, fontfamily='serif')\n",
    "\n",
    "# NOTE(review): tight_layout() runs after savefig(), so it only affects the inline display.\n",
    "# The order is kept because the final composite positions this PNG by hand.\n",
    "plt.savefig(base_path / \"high_quality_graph4sub.png\", dpi=600, bbox_inches='tight')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc6cce3f",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# `compute_cosine_similarity` is defined in the focal-vs-reference cosine cell above. The two\n",
    "# identical copies that were redefined in this cell, and the doubled initialisation of the\n",
    "# result dictionary, have been removed.\n",
    "\n",
    "# Per graph: mean cosine similarity over every pair among the focal and reference embeddings\n",
    "pairwise_avg_similarities_per_dataset = {}\n",
    "\n",
    "for name in file_names:\n",
    "    full_path = os.path.join(base_path, name)\n",
    "    with open(full_path, 'rb') as f:\n",
    "        data = pickle.load(f)\n",
    "\n",
    "    pairwise_avg_sims = []\n",
    "    for graph_content in data.values():\n",
    "        # Flatten with np.asarray(...).reshape(-1), the same way the Euclidean cells do,\n",
    "        # instead of a shape-dependent squeeze that needs a `.shape` attribute.\n",
    "        all_embeddings = [np.asarray(graph_content['focal_embedding']).reshape(-1)]\n",
    "        all_embeddings.extend(\n",
    "            np.asarray(emb).reshape(-1)\n",
    "            for emb in graph_content['reference_embeddings'].values()\n",
    "        )\n",
    "\n",
    "        pairwise_sims = [\n",
    "            compute_cosine_similarity(v1, v2)\n",
    "            for v1, v2 in itertools.combinations(all_embeddings, 2)\n",
    "        ]\n",
    "        # Pairs involving a zero-norm vector give NaN and are left out of the mean\n",
    "        valid_sims = [s for s in pairwise_sims if not np.isnan(s)]\n",
    "        pairwise_avg_sims.append(np.mean(valid_sims) if valid_sims else np.nan)\n",
    "\n",
    "    pairwise_avg_similarities_per_dataset[name_map[name]] = pairwise_avg_sims\n",
    "\n",
    "# Long-format table for plotting; graphs without a valid similarity are dropped\n",
    "plot_data_pairwise = []\n",
    "for key, values in pairwise_avg_similarities_per_dataset.items():\n",
    "    display_name = label_map[key]\n",
    "    plot_data_pairwise.extend([(display_name, sim) for sim in values if not np.isnan(sim)])\n",
    "\n",
    "df_plot_pairwise = pd.DataFrame(plot_data_pairwise, columns=[\"Dataset\", \"Average Pairwise Cosine Similarity\"])\n",
    "\n",
    "plt.figure(figsize=(7, 3))\n",
    "sns.boxenplot(\n",
    "    x=\"Dataset\",\n",
    "    y=\"Average Pairwise Cosine Similarity\",\n",
    "    data=df_plot_pairwise,\n",
    "    # The palette follows `labels` while the rows follow `file_names`; fixing the category\n",
    "    # order keeps each dataset on its own colour even if one dataset yields no rows.\n",
    "    order=[label_map[lbl] for lbl in labels],\n",
    "    palette=[colors[lbl] for lbl in labels],\n",
    "    showfliers=False\n",
    ")\n",
    "plt.ylim(0.0, 1)\n",
    "plt.yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1])\n",
    "\n",
    "plt.xlabel(\"\", fontsize=20, fontfamily='serif')\n",
    "plt.ylabel(\"Mean pairwise \\n cosine similarity\", fontsize=20, fontfamily='serif')\n",
    "plt.xticks(fontsize=20, fontfamily='serif')\n",
    "plt.yticks(fontsize=20, fontfamily='serif')\n",
    "\n",
    "# NOTE(review): tight_layout() runs after savefig(), so it only affects the inline display.\n",
    "# The order is kept because the final composite positions this PNG by hand.\n",
    "plt.savefig(base_path / \"high_quality_graph3sub.png\", dpi=600, bbox_inches='tight')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c69500eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per graph: cosine similarity between the focal embedding and the centroid of its references\n",
    "centroid_similarities_per_dataset = {}\n",
    "\n",
    "for name in file_names:\n",
    "    full_path = os.path.join(base_path, name)\n",
    "    with open(full_path, 'rb') as f:\n",
    "        data = pickle.load(f)\n",
    "\n",
    "    sims = []\n",
    "    for graph_content in data.values():\n",
    "        # Focal and reference vectors are now flattened the same way; previously the focal\n",
    "        # vector was squeezed while the references were reshaped.\n",
    "        focal_embedding = np.asarray(graph_content['focal_embedding']).reshape(-1)\n",
    "        refs = [\n",
    "            np.asarray(emb).reshape(-1)\n",
    "            for emb in graph_content['reference_embeddings'].values()\n",
    "        ]\n",
    "\n",
    "        # No references means no centroid; record NaN so the graph is filtered out below\n",
    "        if len(refs) == 0:\n",
    "            sims.append(np.nan)\n",
    "            continue\n",
    "\n",
    "        centroid = np.mean(refs, axis=0)\n",
    "        sims.append(compute_cosine_similarity(focal_embedding, centroid))\n",
    "\n",
    "    centroid_similarities_per_dataset[name_map[name]] = sims\n",
    "\n",
    "# Long-format table for plotting; NaN similarities are dropped\n",
    "plot_data_centroid = []\n",
    "for key, values in centroid_similarities_per_dataset.items():\n",
    "    display_name = label_map[key]\n",
    "    plot_data_centroid.extend([(display_name, sim) for sim in values if not np.isnan(sim)])\n",
    "\n",
    "df_plot_centroid = pd.DataFrame(plot_data_centroid, columns=[\"Dataset\", \"Cosine(focal, refs centroid)\"])\n",
    "\n",
    "plt.figure(figsize=(7, 3))\n",
    "sns.boxenplot(\n",
    "    x=\"Dataset\",\n",
    "    y=\"Cosine(focal, refs centroid)\",\n",
    "    data=df_plot_centroid,\n",
    "    # The palette follows `labels` while the rows follow `file_names`; fixing the category\n",
    "    # order keeps each dataset on its own colour even if one dataset yields no rows.\n",
    "    order=[label_map[lbl] for lbl in labels],\n",
    "    palette=[colors[lbl] for lbl in labels],\n",
    "    showfliers=False\n",
    ")\n",
    "plt.ylim(0.0, 1)\n",
    "plt.yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1])\n",
    "plt.xlabel(\"\", fontsize=20, fontfamily='serif')\n",
    "plt.ylabel(\"Cosine similarity \\n to centroid\", fontsize=20, fontfamily='serif')\n",
    "plt.xticks(fontsize=20, fontfamily='serif')\n",
    "plt.yticks(fontsize=20, fontfamily='serif')\n",
    "\n",
    "# NOTE(review): tight_layout() runs after savefig(), so it only affects the inline display.\n",
    "# The order is kept because the final composite positions this PNG by hand.\n",
    "plt.savefig(base_path / \"high_quality_graph_centroidsub.png\", dpi=600, bbox_inches='tight')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00cb863c",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Assemble the final multi-panel figure from previously saved PNGs.\n",
    "# NOTE(review): `base_path2` is not assigned anywhere in this notebook - define it before running.\n",
    "left_img_path = base_path2 / \"Slide4.png\"\n",
    "middle_img_path = base_path / \"pca_combined_finalsub.png\"\n",
    "\n",
    "# Cosine-similarity column\n",
    "right_top_img_path = base_path / \"high_quality_graph4sub.png\"\n",
    "right_middle_img_path = base_path / \"high_quality_graph3sub.png\"\n",
    "right_bottom_img_path = base_path / \"high_quality_graph_centroidsub.png\"\n",
    "\n",
    "# Euclidean-distance column\n",
    "right1_top_img_path = base_path / \"high_quality_graphEuclideansub.png\"\n",
    "right1_middle_img_path = base_path / \"high_quality_graph_pairwise_l2sub.png\"\n",
    "right1_bottom_img_path = base_path / \"high_quality_graph_centroid_l2sub.png\"\n",
    "\n",
    "fig = plt.figure(figsize=(5.5, 4), dpi=1800)\n",
    "\n",
    "# (image file, [left, bottom, width, height] in figure fractions), kept in the original\n",
    "# creation order: left, middle, then top / bottom / middle of each right-hand column\n",
    "panel_layout = (\n",
    "    (left_img_path, [-0.102, 0.09, 0.53, 0.6]),\n",
    "    (middle_img_path, [0.24, 0.205, 0.38, 0.38]),\n",
    "    (right_top_img_path, [0.59, 0.36, 0.16, 0.32]),\n",
    "    (right_bottom_img_path, [0.59, 0.136, 0.16, 0.32]),\n",
    "    (right_middle_img_path, [0.59, 0.249, 0.16, 0.32]),\n",
    "    (right1_top_img_path, [0.757, 0.36, 0.16, 0.32]),\n",
    "    (right1_bottom_img_path, [0.757, 0.136, 0.16, 0.32]),\n",
    "    (right1_middle_img_path, [0.757, 0.249, 0.16, 0.32]),\n",
    ")\n",
    "for image_path, rect in panel_layout:\n",
    "    panel_ax = fig.add_axes(rect)\n",
    "    panel_ax.imshow(mpimg.imread(image_path))\n",
    "    panel_ax.axis('off')\n",
    "\n",
    "plt.savefig(base_path / \"graph_with_images1sub.png\",  bbox_inches='tight')\n",
    "\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4039389",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
